From adc8a1259b55f1179ca0e7893a7a3b7735f4bc2b Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Wed, 26 Mar 2025 20:30:37 -0700 Subject: [PATCH] Overloads for Module execute API. Summary: https://github.com/pytorch/executorch/issues/8363 Reviewed By: mergennachin Differential Revision: D71921054 --- .../ExecuTorch/Exported/ExecuTorchModule.h | 43 +++++++++++++++++++ .../ExecuTorch/Exported/ExecuTorchModule.mm | 27 ++++++++++++ .../ExecuTorch/__tests__/ModuleTest.swift | 2 +- 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h index 9229c60512a..34789071caa 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.h @@ -130,6 +130,49 @@ __attribute__((deprecated("This API is experimental."))) error:(NSError **)error NS_SWIFT_NAME(execute(_:_:)); +/** + * Executes a specific method with the provided single input value. + * + * The method is loaded on demand if not already loaded. + * + * @param methodName A string representing the method name. + * @param value An ExecuTorchValue object representing the input. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error. + */ +- (nullable NSArray *)executeMethod:(NSString *)methodName + withInput:(ExecuTorchValue *)value + error:(NSError **)error + NS_SWIFT_NAME(execute(_:_:)); + +/** + * Executes a specific method with no input values. + * + * The method is loaded on demand if not already loaded. + * + * @param methodName A string representing the method name. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error. + */ +- (nullable NSArray *)executeMethod:(NSString *)methodName + error:(NSError **)error + NS_SWIFT_NAME(execute(_:)); + +/** + * Executes a specific method with the provided input tensors. + * + * The method is loaded on demand if not already loaded. + * + * @param methodName A string representing the method name. + * @param tensors An NSArray of ExecuTorchTensor objects representing the inputs. + * @param error A pointer to an NSError pointer that is set if an error occurs. + * @return An NSArray of ExecuTorchValue objects representing the outputs, or nil in case of an error. + */ +- (nullable NSArray *)executeMethod:(NSString *)methodName + withTensors:(NSArray *)tensors + error:(NSError **)error + NS_SWIFT_NAME(execute(_:_:)); + + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm index 5142a969c8f..243ab3c159b 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchModule.mm @@ -139,4 +139,31 @@ - (BOOL)isMethodLoaded:(NSString *)methodName { return outputs; } +- (nullable NSArray *)executeMethod:(NSString *)methodName + withInput:(ExecuTorchValue *)value + error:(NSError **)error { + return [self executeMethod:methodName + withInputs:@[value] + error:error]; +} + +- (nullable NSArray *)executeMethod:(NSString *)methodName + error:(NSError **)error { + return [self executeMethod:methodName + withInputs:@[] + error:error]; +} + +- (nullable NSArray *)executeMethod:(NSString *)methodName + withTensors:(NSArray *)tensors + error:(NSError **)error { + NSMutableArray *values = [NSMutableArray arrayWithCapacity:tensors.count]; + for (ExecuTorchTensor *tensor in tensors) { + [values addObject:[ExecuTorchValue valueWithTensor:tensor]]; + } + return [self executeMethod:methodName + withInputs:values + error:error]; +} + @end diff --git a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift index feaa0f19826..87e35d510ce 100644 --- a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift +++ b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift @@ -63,7 +63,7 @@ class ModuleTest: XCTestCase { let inputTensor = inputData.withUnsafeMutableBytes { Tensor(bytesNoCopy: $0.baseAddress!, shape:[1], dataType: .float) } - let inputs = [Value(inputTensor), Value(inputTensor)] + let inputs = [inputTensor, inputTensor] var outputs: [Value]? XCTAssertNoThrow(outputs = try module.execute("forward", inputs)) var outputData: [Float] = [2.0]