Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,23 @@ __attribute__((deprecated("This API is experimental.")))

@param modelPath File system path to the serialized model.
@param tokenizerPath File system path to the tokenizer data.
@param tokens An array of NSString special tokens to use during tokenization.
@return An initialized ExecuTorchLLMTextRunner instance.
*/
- (instancetype)initWithModelPath:(NSString *)modelPath
tokenizerPath:(NSString *)tokenizerPath;

/**
Initializes a text LLM runner with the given model and tokenizer paths,
and a list of special tokens to include in the tokenizer.

@param modelPath File system path to the serialized model.
@param tokenizerPath File system path to the tokenizer data.
@param specialTokens An array of NSString special tokens to use during tokenization.
@return An initialized ExecuTorchLLMTextRunner instance.
*/
- (instancetype)initWithModelPath:(NSString *)modelPath
tokenizerPath:(NSString *)tokenizerPath
specialTokens:(NSArray<NSString *> *)tokens
specialTokens:(NSArray<NSString *> *)specialTokens
NS_DESIGNATED_INITIALIZER;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,22 @@ @implementation ExecuTorchLLMTextRunner {
std::unique_ptr<llm::TextLLMRunner> _runner;
}

- (instancetype)initWithModelPath:(NSString*)modelPath
tokenizerPath:(NSString*)tokenizerPath {
return [self initWithModelPath:modelPath
tokenizerPath:tokenizerPath
specialTokens:@[]];
}

- (instancetype)initWithModelPath:(NSString*)modelPath
tokenizerPath:(NSString*)tokenizerPath
specialTokens:(NSArray<NSString*>*)tokens {
specialTokens:(NSArray<NSString*>*)specialTokens {
self = [super init];
if (self) {
_modelPath = [modelPath copy];
_tokenizerPath = [tokenizerPath copy];
_specialTokens = std::make_unique<std::vector<std::string>>();
for (NSString *token in tokens) {
for (NSString *token in specialTokens) {
_specialTokens->emplace_back(token.UTF8String);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class MultimodalRunnerTest: XCTestCase {
let userPrompt = "What's on the picture?"
let sequenceLength = 768

func test() {
func testLLaVA() {
let bundle = Bundle(for: type(of: self))
guard let modelPath = bundle.path(forResource: "llava", ofType: "pte"),
let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "bin"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TextRunnerTest: XCTestCase {
let userPrompt = "The capital of France is called"
let sequenceLength = 128

func test() {
func testLLaMA() {
let bundle = Bundle(for: type(of: self))
guard let modelPath = bundle.path(forResource: "llama3_2-1B", ofType: "pte"),
let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "model") else {
Expand Down Expand Up @@ -73,4 +73,39 @@ class TextRunnerTest: XCTestCase {
}
XCTAssertTrue(text.lowercased().contains("paris"))
}

func testPhi4() {
let bundle = Bundle(for: type(of: self))
guard let modelPath = bundle.path(forResource: "phi4-mini", ofType: "pte"),
let tokenizerPath = bundle.path(forResource: "tokenizer", ofType: "json") else {
XCTFail("Couldn't find model or tokenizer files")
return
}
let runner = TextRunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
var text = ""

do {
try runner.generate(userPrompt, Config {
$0.sequenceLength = sequenceLength
}) { token in
text += token
}
} catch {
XCTFail("Failed to generate text with error \(error)")
}
XCTAssertTrue(text.lowercased().contains("paris"))

text = ""
runner.reset()
do {
try runner.generate(userPrompt, Config {
$0.sequenceLength = sequenceLength
}) { token in
text += token
}
} catch {
XCTFail("Failed to generate text with error \(error)")
}
XCTAssertTrue(text.lowercased().contains("paris"))
}
}
Loading