diff --git a/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.h b/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.h index 747286b9ec3..3121259921a 100644 --- a/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.h +++ b/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.h @@ -184,10 +184,16 @@ withTokenCallback:(nullable void (^)(NSString *))callback error:(NSError **)error; /** - Stops any ongoing generation and cleans up internal resources. + Stops producing new tokens and terminates the current generation process. */ - (void)stop; +/** + Removes the prefilled tokens from the KV cache and resets the start position + to 0. It also clears the stats for previous runs. + */ +- (void)reset; + + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; diff --git a/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm b/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm index b95e480aded..bdf78d3f15e 100644 --- a/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm +++ b/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMMultimodalRunner.mm @@ -216,4 +216,10 @@ - (void)stop { } } +- (void)reset { + if (_runner) { + _runner->reset(); + } +} + @end diff --git a/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMTextRunner.h b/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMTextRunner.h index b2c628fadf6..ca9867ebbb0 100644 --- a/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMTextRunner.h +++ b/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMTextRunner.h @@ -64,10 +64,16 @@ withTokenCallback:(nullable void (^)(NSString *))callback error:(NSError **)error; /** - Stops any ongoing generation and cleans up internal resources. + Stops producing new tokens and terminates the current generation process. 
*/ - (void)stop; +/** + Removes the prefilled tokens from the KV cache and resets the start position + to 0. It also clears the stats for previous runs. + */ +- (void)reset; + + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; diff --git a/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMTextRunner.mm b/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMTextRunner.mm index ac50b000704..f4516009694 100644 --- a/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMTextRunner.mm +++ b/extension/llm/apple/ExecuTorchLLM/Exported/ExecuTorchLLMTextRunner.mm @@ -101,4 +101,10 @@ - (void)stop { } } +- (void)reset { + if (_runner) { + _runner->reset(); + } +} + @end diff --git a/extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift b/extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift index e1ee4372187..5176e193ab8 100644 --- a/extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift +++ b/extension/llm/apple/ExecuTorchLLM/__tests__/MultimodalRunnerTest.swift @@ -45,6 +45,11 @@ extension UIImage { } class MultimodalRunnerTest: XCTestCase { + let systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: " + let assistantPrompt = "ASSISTANT: " + let userPrompt = "What's on the picture?" + let sequenceLength = 768 + func test() { let bundle = Bundle(for: type(of: self)) guard let modelPath = bundle.path(forResource: "llava", ofType: "pte"), @@ -59,10 +64,25 @@ class MultimodalRunnerTest: XCTestCase { do { try runner.generate([ - MultimodalInput("A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. 
USER: "), + MultimodalInput(systemPrompt), + MultimodalInput(image.asImage()), + MultimodalInput("\(userPrompt) \(assistantPrompt)"), + ], sequenceLength: sequenceLength) { token in + text += token + } + } catch { + XCTFail("Failed to generate text with error \(error)") + } + XCTAssertTrue(text.lowercased().contains("waterfall")) + + text = "" + runner.reset() + do { + try runner.generate([ + MultimodalInput(systemPrompt), MultimodalInput(image.asImage()), - MultimodalInput("What's on the picture? ASSISTANT: "), - ], sequenceLength: 768) { token in + MultimodalInput("\(userPrompt) \(assistantPrompt)"), + ], sequenceLength: sequenceLength) { token in text += token } } catch { diff --git a/extension/llm/apple/ExecuTorchLLM/__tests__/TextRunnerTest.swift b/extension/llm/apple/ExecuTorchLLM/__tests__/TextRunnerTest.swift index 42dbac8ae30..6a91960b088 100644 --- a/extension/llm/apple/ExecuTorchLLM/__tests__/TextRunnerTest.swift +++ b/extension/llm/apple/ExecuTorchLLM/__tests__/TextRunnerTest.swift @@ -36,6 +36,9 @@ struct SpecialTokens { } class TextRunnerTest: XCTestCase { + let userPrompt = "The capital of France is called" + let sequenceLength = 128 + func test() { let bundle = Bundle(for: type(of: self)) guard let modelPath = bundle.path(forResource: "llama3_2-1B", ofType: "pte"), @@ -47,12 +50,23 @@ class TextRunnerTest: XCTestCase { var text = "" do { - try runner.generate("hello", sequenceLength: 2) { token in + try runner.generate(userPrompt, sequenceLength: sequenceLength) { token in + text += token + } + } catch { + XCTFail("Failed to generate text with error \(error)") + } + XCTAssertTrue(text.lowercased().contains("paris")) + + text = "" + runner.reset() + do { + try runner.generate(userPrompt, sequenceLength: sequenceLength) { token in text += token } } catch { XCTFail("Failed to generate text with error \(error)") } - XCTAssertEqual("hello,", text.lowercased()) + XCTAssertTrue(text.lowercased().contains("paris")) } }