diff --git a/docs/source/using-executorch-ios.md b/docs/source/using-executorch-ios.md index 3e01f0d4688..263f58a7dd0 100644 --- a/docs/source/using-executorch-ios.md +++ b/docs/source/using-executorch-ios.md @@ -243,7 +243,7 @@ let imageBuffer: UnsafeMutableRawPointer = ... // Existing image buffer let inputTensor = Tensor(&imageBuffer, shape: [1, 3, 224, 224]) // Execute the 'forward' method with the given input tensor and get an output tensor back. -let outputTensor: Tensor = try module.forward(inputTensor)[0].tensor()! +let outputTensor: Tensor = try module.forward(inputTensor)! // Copy the tensor data into logits array for easier access. let logits = outputTensor.scalars() diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorch+Module.swift b/extension/apple/ExecuTorch/Exported/ExecuTorch+Module.swift index cf7414c4552..599a990b64c 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorch+Module.swift +++ b/extension/apple/ExecuTorch/Exported/ExecuTorch+Module.swift @@ -93,3 +93,86 @@ public extension Module { try forward(inputs) } } + +@available(*, deprecated, message: "This API is experimental.") +public extension Module { + /// Executes a specific method and decodes the outputs into `Output` generic type. + /// + /// - Parameters: + /// - method: The name of the method to execute. + /// - inputs: An array of `ValueConvertible` inputs. + /// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch. + /// - Throws: An error if loading, execution or result conversion fails. + func execute(_ method: String, _ inputs: [ValueConvertible]) throws -> Output { + try Output(__executeMethod(method, withInputs: inputs.map { $0.asValue() })) + } + + /// Executes a specific method with variadic inputs and decodes into `Output` generic type. + /// + /// - Parameters: + /// - method: The name of the method to execute. + /// - inputs: A variadic list of `ValueConvertible` inputs. + /// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch. + /// - Throws: An error if loading, execution or result conversion fails. + func execute(_ method: String, _ inputs: ValueConvertible...) throws -> Output { + try execute(method, inputs) + } + + /// Executes a specific method with a single input and decodes into `Output` generic type. + /// + /// - Parameters: + /// - method: The name of the method to execute. + /// - input: A single `ValueConvertible` input. + /// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch. + /// - Throws: An error if loading, execution or result conversion fails. + func execute(_ method: String, _ input: ValueConvertible) throws -> Output { + try execute(method, [input]) + } + + /// Executes a specific method with no inputs and decodes into `Output` generic type. + /// + /// - Parameter method: The name of the method to execute. + /// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch. + /// - Throws: An error if loading, execution or result conversion fails. + func execute(_ method: String) throws -> Output { + try execute(method, []) + } + + /// Executes the "forward" method and decodes into `Output` generic type. + /// + /// - Parameters: + /// - inputs: An array of `ValueConvertible` inputs to pass to "forward". + /// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch. + /// - Throws: An error if loading, execution or result conversion fails. + func forward(_ inputs: [ValueConvertible]) throws -> Output { + try execute("forward", inputs) + } + + /// Executes the "forward" method with variadic inputs and decodes into `Output` generic type. + /// + /// - Parameters: + /// - inputs: A variadic list of `ValueConvertible` inputs. + /// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch. + /// - Throws: An error if loading, execution or result conversion fails. + func forward(_ inputs: ValueConvertible...) throws -> Output { + try forward(inputs) + } + + /// Executes the "forward" method with a single input and decodes into `Output` generic type. + /// + /// - Parameters: + /// - input: A single `ValueConvertible` to pass to "forward". + /// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch. + /// - Throws: An error if loading, execution or result conversion fails. + func forward(_ input: ValueConvertible) throws -> Output { + try forward([input]) + } + + /// Executes the "forward" method with no inputs and decodes into `Output` generic type. + /// + /// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch. + /// - Throws: An error if loading, execution or result conversion fails. + func forward() throws -> Output { + try execute("forward") + } +} diff --git a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift index 0aaeaefbcd3..a35247f9bce 100644 --- a/extension/apple/ExecuTorch/__tests__/ModuleTest.swift +++ b/extension/apple/ExecuTorch/__tests__/ModuleTest.swift @@ -81,7 +81,34 @@ class ModuleTest: XCTestCase { XCTAssertEqual(outputs4?.first?.tensor(), Tensor([Float(5)])) } - func testmethodMetadata() throws { + func testForwardReturnConversion() throws { + guard let modelPath = resourceBundle.path(forResource: "add", ofType: "pte") else { + XCTFail("Couldn't find the model file") + return + } + let module = Module(filePath: modelPath) + let inputs: [Tensor] = [Tensor([1]), Tensor([1])] + + let outputValues: [Value] = try module.forward(inputs) + XCTAssertEqual(outputValues, [Value(Tensor([2]))]) + + let outputValue: Value = try module.forward(inputs) + XCTAssertEqual(outputValue, Value(Tensor([2]))) + + let outputTensors: [Tensor] = try module.forward(inputs) + XCTAssertEqual(outputTensors, [Tensor([2])]) + + let outputTensor: Tensor = try module.forward(Tensor([1]), Tensor([1])) + XCTAssertEqual(outputTensor, Tensor([2])) + + let scalars = (try module.forward(Tensor([1]), Tensor([1])) as Tensor).scalars() + XCTAssertEqual(scalars, [2]) + + let scalars2 = try Tensor(module.forward(Tensor([1]), Tensor([1]))).scalars() + XCTAssertEqual(scalars2, [2]) + } + + func testMethodMetadata() throws { guard let modelPath = resourceBundle.path(forResource: "add", ofType: "pte") else { XCTFail("Couldn't find the model file") return