diff --git a/extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/Exported/ExecutorchRuntimeEngine.h b/extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/Exported/ExecutorchRuntimeEngine.h index 633b6728699..a03f6b3c62f 100644 --- a/extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/Exported/ExecutorchRuntimeEngine.h +++ b/extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/Exported/ExecutorchRuntimeEngine.h @@ -21,7 +21,7 @@ NS_ASSUME_NONNULL_BEGIN modelMethodName:(NSString *)modelMethodName error:(NSError * _Nullable * _Nullable)error NS_DESIGNATED_INITIALIZER; -- (nullable NSArray *)infer:(NSArray *)input +- (nullable NSArray *)infer:(NSArray *)values error:(NSError * _Nullable * _Nullable)error NS_SWIFT_NAME(infer(input:)); @end diff --git a/extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/Exported/ExecutorchRuntimeEngine.mm b/extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/Exported/ExecutorchRuntimeEngine.mm index 1ebbefde3db..756ca94f114 100644 --- a/extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/Exported/ExecutorchRuntimeEngine.mm +++ b/extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/Exported/ExecutorchRuntimeEngine.mm @@ -13,32 +13,6 @@ #import -static int kInitFailed = 0; -static int kInferenceFailed = 1; - -static auto NSStringToString(NSString *string) -> std::string -{ - const char *cStr = [string cStringUsingEncoding:NSUTF8StringEncoding]; - if (cStr) { - return cStr; - } - - NSData *data = [string dataUsingEncoding:NSUTF8StringEncoding allowLossyConversion:NO]; - return {reinterpret_cast([data bytes]), [data length]}; -} - -static auto StringToNSString(const std::string &string) -> NSString * -{ - CFStringRef cfString = CFStringCreateWithBytes( - kCFAllocatorDefault, - reinterpret_cast(string.c_str()), - string.size(), - kCFStringEncodingUTF8, - false - ); - return (__bridge_transfer NSString *)cfString; -} - @implementation ExecutorchRuntimeEngine { NSString *_modelPath; @@ -48,27 +22,18 @@ @implementation ExecutorchRuntimeEngine - (instancetype)initWithModelPath:(NSString *)modelPath modelMethodName:(NSString *)modelMethodName - error:(NSError * _Nullable * _Nullable)error + error:(NSError **)error { if (self = [super init]) { _modelPath = modelPath; _modelMethodName = modelMethodName; - try { - _module = std::make_unique(NSStringToString(modelPath)); - const auto e = _module->load_method(NSStringToString(modelMethodName)); - if (e != executorch::runtime::Error::Ok) { - if (error) { - *error = [NSError errorWithDomain:@"ExecutorchRuntimeEngine" - code:kInitFailed - userInfo:@{NSDebugDescriptionErrorKey : StringToNSString(std::to_string(static_cast(e)))}]; - } - return nil; - } - } catch (...) { + _module = std::make_unique(modelPath.UTF8String); + const auto e = _module->load_method(modelMethodName.UTF8String); + if (e != executorch::runtime::Error::Ok) { if (error) { *error = [NSError errorWithDomain:@"ExecutorchRuntimeEngine" - code:kInitFailed - userInfo:@{NSDebugDescriptionErrorKey : @"Unknown error"}]; + code:(NSInteger)e + userInfo:nil]; } return nil; } @@ -76,38 +41,28 @@ - (instancetype)initWithModelPath:(NSString *)modelPath return self; } -- (nullable NSArray *)infer:(NSArray *)input - error:(NSError * _Nullable * _Nullable)error +- (nullable NSArray *)infer:(NSArray *)values + error:(NSError **)error { - try { - std::vector inputEValues; - inputEValues.reserve(input.count); - for (ExecutorchRuntimeValue *inputValue in input) { - inputEValues.push_back([inputValue getBackedValue]); - } - const auto result = _module->execute(NSStringToString(_modelMethodName), inputEValues); - if (!result.ok()) { - const auto executorchError = static_cast(result.error()); - if (error) { - *error = [NSError errorWithDomain:@"ExecutorchRuntimeEngine" - code:kInferenceFailed - userInfo:@{NSDebugDescriptionErrorKey : StringToNSString(std::to_string(executorchError))}]; - } - return nil; - } - NSMutableArray *const resultValues = [NSMutableArray new]; - for (const auto &evalue : result.get()) { - [resultValues addObject:[[ExecutorchRuntimeValue alloc] initWithEValue:evalue]]; - } - return resultValues; - } catch (...) { + std::vector inputEValues; + inputEValues.reserve(values.count); + for (ExecutorchRuntimeValue *inputValue in values) { + inputEValues.push_back([inputValue getBackedValue]); + } + const auto result = _module->execute(_modelMethodName.UTF8String, inputEValues); + if (!result.ok()) { if (error) { *error = [NSError errorWithDomain:@"ExecutorchRuntimeEngine" - code:kInferenceFailed - userInfo:@{NSDebugDescriptionErrorKey : @"Unknown error"}]; + code:(NSInteger)result.error() + userInfo:nil]; } return nil; } + NSMutableArray *const resultValues = [NSMutableArray new]; + for (const auto &evalue : result.get()) { + [resultValues addObject:[[ExecutorchRuntimeValue alloc] initWithEValue:evalue]]; + } + return resultValues; } @end diff --git a/extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/__tests__/ExecutorchRuntimeEngineTests.mm b/extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/__tests__/ExecutorchRuntimeEngineTests.mm index 610bddd51c9..e243535cdf5 100644 --- a/extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/__tests__/ExecutorchRuntimeEngineTests.mm +++ b/extension/apple/ExecutorchRuntimeBridge/ExecutorchRuntimeBridge/__tests__/ExecutorchRuntimeEngineTests.mm @@ -26,8 +26,7 @@ - (void)testInvalidModel XCTAssertNil(engine); XCTAssertNotNil(runtimeInitError); - XCTAssertEqual(runtimeInitError.code, 0); - XCTAssertEqualObjects(runtimeInitError.userInfo[NSDebugDescriptionErrorKey], @"34"); + XCTAssertEqual(runtimeInitError.code, 34); // 34 is the code for AccessFailed. }