59 changes: 33 additions & 26 deletions Example/LocalLLMClientExample/AI.swift
@@ -6,6 +6,7 @@ import LocalLLMClientLlama
 import UIKit
 #endif
 
+// TODO: Convert to struct
 enum LLMModel: Sendable, CaseIterable, Identifiable {
     case qwen3
     case qwen3_4b
@@ -45,7 +46,7 @@ enum LLMModel: Sendable, CaseIterable, Identifiable {
         }
     }
 
-    var clipFilename: String? {
+    var mmprojFilename: String? {
         switch self {
         case .qwen3, .qwen3_4b, .qwen2_5VL_3b, .gemma3: nil
 #if os(macOS)
@@ -58,10 +59,7 @@ enum LLMModel: Sendable, CaseIterable, Identifiable {
     }
 
     var isMLX: Bool {
-        switch self {
-        case .qwen3, .qwen3_4b, .qwen2_5VL_3b: true
-        case .gemma3, .gemma3_4b, .mobileVLM_3b: false
-        }
+        filename == nil
     }
 
     var supportsVision: Bool {
@@ -83,46 +81,55 @@ final class AI {
     private(set) var isLoading = false
     private(set) var downloadProgress: Double = 0
 
-    private var client: AnyLLMClient?
+    private var session: LLMSession?
+
+    var messages: [LLMInput.Message] {
+        get { session?.messages ?? [] }
+        set { session?.messages = newValue }
+    }
 
     func loadLLM() async {
         isLoading = true
         defer { isLoading = false }
 
         // Release memory first if a previous model was loaded
-        client = nil
+        session = nil
 
         do {
-            let downloader = Downloader(model: model)
-            if downloader.isDownloaded {
-                downloadProgress = 1
+            let downloadModel: LLMSession.DownloadModel = if model.isMLX {
+                .mlx(id: model.id)
             } else {
-                downloadProgress = 0
-                try await downloader.download { @MainActor [weak self] progress in
-                    self?.downloadProgress = progress
-                }
+                .llama(
+                    id: model.id,
+                    model: model.filename!,
+                    mmproj: model.mmprojFilename,
+                    parameter: .init(options: .init(verbose: true))
+                )
             }
 
-#if os(iOS)
-            while downloadProgress < 1 || UIApplication.shared.applicationState != .active {
-                try await Task.sleep(for: .seconds(2))
-            }
-#endif
-
-            if model.isMLX {
-                client = try await AnyLLMClient(LocalLLMClient.mlx(url: downloader.url))
-            } else {
-                client = try await AnyLLMClient(LocalLLMClient.llama(url: downloader.url, mmprojURL: downloader.clipURL, verbose: true))
-            }
+            try await downloadModel.downloadModel { @MainActor [weak self] progress in
+                self?.downloadProgress = progress
+                print("Download progress: \(progress)")
+            }
+
+            session = LLMSession(model: downloadModel)
        } catch {
            print("Failed to load LLM: \(error)")
        }
    }
 
-    func ask(_ messages: [LLMInput.Message]) async throws -> AsyncThrowingStream<String, any Error> {
-        guard let client else {
+    func ask(_ message: String, attachments: [LLMAttachment]) async throws -> AsyncThrowingStream<String, any Error> {
+        guard let session else {
             throw LLMError.failedToLoad(reason: "LLM not loaded")
         }
-        return try await client.textStream(from: .chat(messages))
+        return session.streamResponse(to: message, attachments: attachments)
     }
 }
+
+#if DEBUG
+extension AI {
+    func setSession(_ session: LLMSession) {
+        self.session = session
+    }
+}
+#endif
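Taken together, the AI.swift changes swap the hand-rolled Downloader/AnyLLMClient pair for LLMSession, which bundles download, model loading, and chat history into one object. A minimal sketch of the resulting call flow, inferred from this diff alone (the repo id and filenames below are hypothetical placeholders, and the code assumes an async throwing context):

import LocalLLMClient
import LocalLLMClientLlama

func demo() async throws {
    // Pick a backend: .mlx(id:) for MLX models, .llama(...) for GGUF weights.
    // The id and filenames here are hypothetical placeholders.
    let downloadModel = LLMSession.DownloadModel.llama(
        id: "example-org/example-model-GGUF",   // hypothetical Hugging Face repo
        model: "example-model-Q4_K_M.gguf",     // hypothetical weights file
        mmproj: nil,                            // multimodal projector, if the model has one
        parameter: .init(options: .init(verbose: true))
    )

    // Fetch the weights; the closure mirrors the progress reporting in loadLLM().
    try await downloadModel.downloadModel { progress in
        print("Download progress: \(progress)")
    }

    // The session owns the message history that AI.messages now proxies.
    let session = LLMSession(model: downloadModel)
    for try await token in session.streamResponse(to: "Hello", attachments: []) {
        print(token, terminator: "")
    }
}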
2 changes: 1 addition & 1 deletion Example/LocalLLMClientExample/App.swift
@@ -15,7 +15,7 @@ struct RootView: View {
 
     var body: some View {
         NavigationStack {
-            ChatView()
+            ChatView(viewModel: .init(ai: ai))
         }
         .disabled(ai.isLoading)
         .overlay {
54 changes: 31 additions & 23 deletions Example/LocalLLMClientExample/BottomBar.swift
@@ -4,7 +4,7 @@ import LocalLLMClient
 
 struct BottomBar: View {
     @Binding var text: String
-    @Binding var images: [ChatMessage.Image]
+    @Binding var attachments: [LLMAttachment]
     let isGenerating: Bool
     let onSubmit: (String) -> Void
     let onCancel: () -> Void
@@ -45,29 +45,37 @@ struct BottomBar: View {
             }
         }
         .safeAreaInset(edge: .top) {
-            if !images.isEmpty {
-                ScrollView(.horizontal) {
-                    HStack {
-                        ForEach(images) { image in
-                            Image(llm: image.value)
-                                .resizable()
-                                .aspectRatio(1, contentMode: .fill)
-                                .cornerRadius(8)
-                                .contextMenu {
-                                    Button {
-                                        images.removeAll { $0.id == image.id }
-                                    } label: {
-                                        Text("Remove")
-                                    }
-                                }
-                        }
-                    }
-                    .frame(height: 60)
-                }
-            }
-        }
-        .animation(.default, value: text.isEmpty)
-        .animation(.default, value: images.count)
-    }
+            if !attachments.isEmpty {
+                attachmentList
+            }
+        }
+        .animation(.default, value: text.isEmpty)
+        .animation(.default, value: attachments.count)
+    }
+
+    @ViewBuilder
+    private var attachmentList: some View {
+        ScrollView(.horizontal) {
+            HStack {
+                ForEach(attachments) { attachment in
+                    switch attachment.content {
+                    case let .image(image):
+                        Image(llm: image)
+                            .resizable()
+                            .aspectRatio(1, contentMode: .fill)
+                            .cornerRadius(8)
+                            .contextMenu {
+                                Button {
+                                    attachments.removeAll { $0.id == attachment.id }
+                                } label: {
+                                    Text("Remove")
+                                }
+                            }
+                    }
+                }
+            }
+            .frame(height: 60)
+        }
+    }
 
     @ViewBuilder
@@ -129,18 +137,18 @@ struct BottomBar: View {
             Task {
                 let data = try await item.loadTransferable(type: Data.self)
                 guard let data, let image = LLMInputImage(data: data) else { return }
-                images.append(.init(value: image))
+                attachments.append(.image(image))
             }
         }
     }
 }
 
 #Preview(traits: .sizeThatFitsLayout) {
     @Previewable @State var text = ""
-    @Previewable @State var images: [ChatMessage.Image] = [
-        .preview, .preview2
+    @Previewable @State var attachments: [LLMAttachment] = [
+        .imagePreview, .imagePreview2
     ]
 
-    BottomBar(text: $text, images: $images, isGenerating: false, onSubmit: { _ in }, onCancel: {})
+    BottomBar(text: $text, attachments: $attachments, isGenerating: false, onSubmit: { _ in }, onCancel: {})
         .environment(AI())
 }
76 changes: 44 additions & 32 deletions Example/LocalLLMClientExample/ChatView.swift
@@ -3,21 +3,21 @@ import LocalLLMClient
 import LocalLLMClientMLX
 
 struct ChatView: View {
-    @State var viewModel = ChatViewModel()
-    @State private var position = ScrollPosition(idType: ChatMessage.ID.self)
+    @State var viewModel: ChatViewModel
+    @State private var position = ScrollPosition(idType: LLMInput.Message.ID.self)
 
     @Environment(AI.self) private var ai
 
     var body: some View {
         VStack {
-            messageList
+            MessageList(messages: viewModel.messages)
 
             BottomBar(
                 text: $viewModel.inputText,
-                images: $viewModel.inputImages,
+                attachments: $viewModel.inputAttachments,
                 isGenerating: viewModel.isGenerating
             ) { _ in
-                viewModel.sendMessage(to: ai)
+                viewModel.sendMessage()
             } onCancel: {
                 viewModel.cancelGeneration()
             }
@@ -28,58 +28,64 @@ struct ChatView: View {
             ToolbarItem {
                 Menu {
                     Button("Clear Chat") {
-                        viewModel.clearMessages()
+                        ai.messages = []
                     }
                 } label: {
                     Image(systemName: "ellipsis.circle")
                 }
             }
         }
         .onChange(of: ai.model) { _, _ in
-            viewModel.clearMessages()
+            ai.messages = []
         }
     }
+}
+
+struct MessageList: View {
+    let messages: [LLMInput.Message]
+
+    @State private var position = ScrollPosition(idType: LLMInput.Message.ID.self)
 
-    @ViewBuilder
-    private var messageList: some View {
+    var body: some View {
         ScrollView {
             LazyVStack(spacing: 12) {
-                ForEach(viewModel.messages) { message in
+                ForEach(messages) { message in
                     ChatBubbleView(message: message)
                         .id(message.id)
                 }
             }
             .scrollTargetLayout()
             .padding(.horizontal)
         }
-        .onChange(of: viewModel.messages) { _, _ in
-            withAnimation {
-                position.scrollTo(edge: .bottom)
-            }
-        }
         .scrollPosition($position)
+        .onChange(of: messages) { _, _ in
+            position.scrollTo(edge: .bottom)
+        }
     }
 }
 
 struct ChatBubbleView: View {
-    let message: ChatMessage
+    let message: LLMInput.Message
 
     var body: some View {
         let isUser = message.role == .user
 
         VStack(alignment: isUser ? .trailing : .leading) {
             LazyVGrid(columns: [.init(.adaptive(minimum: 100))], alignment: .leading) {
-                ForEach(message.images) { image in
-                    Image(llm: image.value)
-                        .resizable()
-                        .scaledToFit()
-                        .cornerRadius(16)
+                ForEach(message.attachments) { attachment in
+                    switch attachment.content {
+                    case let .image(image):
+                        Image(llm: image)
+                            .resizable()
+                            .scaledToFit()
+                            .cornerRadius(16)
+                    }
                 }
             }
             .scaleEffect(x: isUser ? -1 : 1)
 
-            Text(message.text)
+            Text(message.content)
                 .padding(12)
                 .background(isUser ? Color.accentColor : .gray.opacity(0.2))
                 .foregroundColor(isUser ? .white : .primary)
@@ -91,17 +97,23 @@ struct ChatBubbleView: View {
     }
 }
 
 #Preview("Text") {
-    NavigationStack {
-        ChatView(viewModel: .init(messages: [
-            .init(role: .user, text: "Hello"),
-            .init(role: .assistant, text: "Hi! How can I help you?"),
-            .init(role: .user, text: "Hello", images: [.preview, .preview2]),
-        ]))
-    }
-    .environment(AI())
+    @Previewable @State var ai: AI = {
+        let ai = AI()
+        ai.setSession(.init(model: .mlx(id: ""), messages: [
+            .user("Hello"),
+            .assistant("Hi! How can I help you?"),
+            .user("What is in these images?", attachments: [.imagePreview, .imagePreview2])
+        ]))
+        return ai
+    }()
+
+    NavigationStack {
+        ChatView(viewModel: .init(ai: ai))
+    }
+    .environment(ai)
 }
 
-extension ChatMessage.Image {
-    static let preview = try! Self.init(value: LLMInputImage(data: .init(contentsOf: URL(string: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.jpeg")!))!)
-    static let preview2 = try! Self.init(value: LLMInputImage(data: .init(contentsOf: URL(string: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")!))!)
+extension LLMAttachment {
+    static let imagePreview = try! Self.image(LLMInputImage(data: .init(contentsOf: URL(string: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.jpeg")!))!)
+    static let imagePreview2 = try! Self.image(LLMInputImage(data: .init(contentsOf: URL(string: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")!))!)
 }
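ChatViewModel is referenced throughout but not included in this diff; its new surface is visible from the call sites above (init(ai:), inputText, inputAttachments, isGenerating, sendMessage(), cancelGeneration(), messages). A hypothetical reconstruction consistent with those call sites, assuming the session appends streamed tokens to its own history:

import Observation

// Hypothetical sketch -- inferred from how ChatView and BottomBar use the
// view model in this PR, not taken from the actual file.
@MainActor @Observable
final class ChatViewModel {
    var inputText = ""
    var inputAttachments: [LLMAttachment] = []
    private(set) var isGenerating = false

    private let ai: AI
    private var generationTask: Task<Void, Never>?

    init(ai: AI) {
        self.ai = ai
    }

    // ChatView renders these; the LLMSession behind AI owns the history.
    var messages: [LLMInput.Message] { ai.messages }

    func sendMessage() {
        let text = inputText
        let attachments = inputAttachments
        inputText = ""
        inputAttachments = []
        isGenerating = true
        generationTask = Task {
            defer { isGenerating = false }
            do {
                // Drain the stream; the session updates messages as tokens arrive.
                for try await _ in try await ai.ask(text, attachments: attachments) {}
            } catch {
                print("Generation failed: \(error)")
            }
        }
    }

    func cancelGeneration() {
        generationTask?.cancel()
        isGenerating = false
    }
}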