Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ NS_ASSUME_NONNULL_BEGIN
modelMethodName:(NSString *)modelMethodName
error:(NSError * _Nullable * _Nullable)error NS_DESIGNATED_INITIALIZER;

- (nullable NSArray<ExecutorchRuntimeValue *> *)infer:(NSArray<ExecutorchRuntimeValue *> *)input
- (nullable NSArray<ExecutorchRuntimeValue *> *)infer:(NSArray<ExecutorchRuntimeValue *> *)values
error:(NSError * _Nullable * _Nullable)error NS_SWIFT_NAME(infer(input:));

@end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,6 @@

#import <executorch/extension/module/module.h>

static int kInitFailed = 0;
static int kInferenceFailed = 1;

static auto NSStringToString(NSString *string) -> std::string
{
const char *cStr = [string cStringUsingEncoding:NSUTF8StringEncoding];
if (cStr) {
return cStr;
}

NSData *data = [string dataUsingEncoding:NSUTF8StringEncoding allowLossyConversion:NO];
return {reinterpret_cast<const char *>([data bytes]), [data length]};
}

static auto StringToNSString(const std::string &string) -> NSString *
{
CFStringRef cfString = CFStringCreateWithBytes(
kCFAllocatorDefault,
reinterpret_cast<const UInt8 *>(string.c_str()),
string.size(),
kCFStringEncodingUTF8,
false
);
return (__bridge_transfer NSString *)cfString;
}

@implementation ExecutorchRuntimeEngine
{
NSString *_modelPath;
Expand All @@ -48,66 +22,47 @@ @implementation ExecutorchRuntimeEngine

- (instancetype)initWithModelPath:(NSString *)modelPath
modelMethodName:(NSString *)modelMethodName
error:(NSError * _Nullable * _Nullable)error
error:(NSError **)error
{
if (self = [super init]) {
_modelPath = modelPath;
_modelMethodName = modelMethodName;
try {
_module = std::make_unique<torch::executor::Module>(NSStringToString(modelPath));
const auto e = _module->load_method(NSStringToString(modelMethodName));
if (e != executorch::runtime::Error::Ok) {
if (error) {
*error = [NSError errorWithDomain:@"ExecutorchRuntimeEngine"
code:kInitFailed
userInfo:@{NSDebugDescriptionErrorKey : StringToNSString(std::to_string(static_cast<uint32_t>(e)))}];
}
return nil;
}
} catch (...) {
_module = std::make_unique<torch::executor::Module>(modelPath.UTF8String);
const auto e = _module->load_method(modelMethodName.UTF8String);
if (e != executorch::runtime::Error::Ok) {
if (error) {
*error = [NSError errorWithDomain:@"ExecutorchRuntimeEngine"
code:kInitFailed
userInfo:@{NSDebugDescriptionErrorKey : @"Unknown error"}];
code:(NSInteger)e
userInfo:nil];
}
return nil;
}
}
return self;
}

- (nullable NSArray<ExecutorchRuntimeValue *> *)infer:(NSArray<ExecutorchRuntimeValue *> *)input
error:(NSError * _Nullable * _Nullable)error
- (nullable NSArray<ExecutorchRuntimeValue *> *)infer:(NSArray<ExecutorchRuntimeValue *> *)values
error:(NSError **)error
{
try {
std::vector<torch::executor::EValue> inputEValues;
inputEValues.reserve(input.count);
for (ExecutorchRuntimeValue *inputValue in input) {
inputEValues.push_back([inputValue getBackedValue]);
}
const auto result = _module->execute(NSStringToString(_modelMethodName), inputEValues);
if (!result.ok()) {
const auto executorchError = static_cast<uint32_t>(result.error());
if (error) {
*error = [NSError errorWithDomain:@"ExecutorchRuntimeEngine"
code:kInferenceFailed
userInfo:@{NSDebugDescriptionErrorKey : StringToNSString(std::to_string(executorchError))}];
}
return nil;
}
NSMutableArray<ExecutorchRuntimeValue *> *const resultValues = [NSMutableArray new];
for (const auto &evalue : result.get()) {
[resultValues addObject:[[ExecutorchRuntimeValue alloc] initWithEValue:evalue]];
}
return resultValues;
} catch (...) {
std::vector<torch::executor::EValue> inputEValues;
inputEValues.reserve(values.count);
for (ExecutorchRuntimeValue *inputValue in values) {
inputEValues.push_back([inputValue getBackedValue]);
}
const auto result = _module->execute(_modelMethodName.UTF8String, inputEValues);
if (!result.ok()) {
if (error) {
*error = [NSError errorWithDomain:@"ExecutorchRuntimeEngine"
code:kInferenceFailed
userInfo:@{NSDebugDescriptionErrorKey : @"Unknown error"}];
code:(NSInteger)result.error()
userInfo:nil];
}
return nil;
}
NSMutableArray<ExecutorchRuntimeValue *> *const resultValues = [NSMutableArray new];
for (const auto &evalue : result.get()) {
[resultValues addObject:[[ExecutorchRuntimeValue alloc] initWithEValue:evalue]];
}
return resultValues;
}

@end
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ - (void)testInvalidModel
XCTAssertNil(engine);
XCTAssertNotNil(runtimeInitError);

XCTAssertEqual(runtimeInitError.code, 0);
XCTAssertEqualObjects(runtimeInitError.userInfo[NSDebugDescriptionErrorKey], @"34");
XCTAssertEqual(runtimeInitError.code, 34);
// 34 is the code for AccessFailed.
}

Expand Down
Loading