Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/using-executorch-ios.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ let imageBuffer: UnsafeMutableRawPointer = ... // Existing image buffer
let inputTensor = Tensor<Float>(&imageBuffer, shape: [1, 3, 224, 224])

// Execute the 'forward' method with the given input tensor and get an output tensor back.
let outputTensor: Tensor<Float> = try module.forward(inputTensor)[0].tensor()!
let outputTensor: Tensor<Float> = try module.forward(inputTensor)!

// Copy the tensor data into logits array for easier access.
let logits = outputTensor.scalars()
Expand Down
83 changes: 83 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorch+Module.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,86 @@ public extension Module {
try forward(inputs)
}
}

@available(*, deprecated, message: "This API is experimental.")
public extension Module {
/// Executes a specific method and decodes the outputs into `Output` generic type.
///
/// - Parameters:
/// - method: The name of the method to execute.
/// - inputs: An array of `ValueConvertible` inputs.
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
/// - Throws: An error if loading, execution or result conversion fails.
func execute<Output: ValueSequenceConstructible>(_ method: String, _ inputs: [ValueConvertible]) throws -> Output {
try Output(__executeMethod(method, withInputs: inputs.map { $0.asValue() }))
}

/// Executes a specific method with variadic inputs and decodes into `Output` generic type.
///
/// - Parameters:
/// - method: The name of the method to execute.
/// - inputs: A variadic list of `ValueConvertible` inputs.
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
/// - Throws: An error if loading, execution or result conversion fails.
func execute<Output: ValueSequenceConstructible>(_ method: String, _ inputs: ValueConvertible...) throws -> Output {
try execute(method, inputs)
}

/// Executes a specific method with a single input and decodes into `Output` generic type.
///
/// - Parameters:
/// - method: The name of the method to execute.
/// - input: A single `ValueConvertible` input.
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
/// - Throws: An error if loading, execution or result conversion fails.
func execute<Output: ValueSequenceConstructible>(_ method: String, _ input: ValueConvertible) throws -> Output {
try execute(method, [input])
}

/// Executes a specific method with no inputs and decodes into `Output` generic type.
///
/// - Parameter method: The name of the method to execute.
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
/// - Throws: An error if loading, execution or result conversion fails.
func execute<Output: ValueSequenceConstructible>(_ method: String) throws -> Output {
try execute(method, [])
}

/// Executes the "forward" method and decodes into `Output` generic type.
///
/// - Parameters:
/// - inputs: An array of `ValueConvertible` inputs to pass to "forward".
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
/// - Throws: An error if loading, execution or result conversion fails.
func forward<Output: ValueSequenceConstructible>(_ inputs: [ValueConvertible]) throws -> Output {
try execute("forward", inputs)
}

/// Executes the "forward" method with variadic inputs and decodes into `Output` generic type.
///
/// - Parameters:
/// - inputs: A variadic list of `ValueConvertible` inputs.
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
/// - Throws: An error if loading, execution or result conversion fails.
func forward<Output: ValueSequenceConstructible>(_ inputs: ValueConvertible...) throws -> Output {
try forward(inputs)
}

/// Executes the "forward" method with a single input and decodes into `Output` generic type.
///
/// - Parameters:
/// - input: A single `ValueConvertible` to pass to "forward".
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
/// - Throws: An error if loading, execution or result conversion fails.
func forward<Output: ValueSequenceConstructible>(_ input: ValueConvertible) throws -> Output {
try forward([input])
}

/// Executes the "forward" method with no inputs and decodes into `Output` generic type.
///
/// - Returns: An instance of `Output` decoded from the returned `[Value]`, or `nil` on mismatch.
/// - Throws: An error if loading, execution or result conversion fails.
func forward<Output: ValueSequenceConstructible>() throws -> Output {
try execute("forward")
}
}
29 changes: 28 additions & 1 deletion extension/apple/ExecuTorch/__tests__/ModuleTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,34 @@ class ModuleTest: XCTestCase {
XCTAssertEqual(outputs4?.first?.tensor(), Tensor([Float(5)]))
}

func testmethodMetadata() throws {
func testForwardReturnConversion() throws {
guard let modelPath = resourceBundle.path(forResource: "add", ofType: "pte") else {
XCTFail("Couldn't find the model file")
return
}
let module = Module(filePath: modelPath)
let inputs: [Tensor<Float>] = [Tensor([1]), Tensor([1])]

let outputValues: [Value] = try module.forward(inputs)
XCTAssertEqual(outputValues, [Value(Tensor<Float>([2]))])

let outputValue: Value = try module.forward(inputs)
XCTAssertEqual(outputValue, Value(Tensor<Float>([2])))

let outputTensors: [Tensor<Float>] = try module.forward(inputs)
XCTAssertEqual(outputTensors, [Tensor([2])])

let outputTensor: Tensor<Float> = try module.forward(Tensor<Float>([1]), Tensor<Float>([1]))
XCTAssertEqual(outputTensor, Tensor([2]))

let scalars = (try module.forward(Tensor<Float>([1]), Tensor<Float>([1])) as Tensor<Float>).scalars()
XCTAssertEqual(scalars, [2])

let scalars2 = try Tensor<Float>(module.forward(Tensor<Float>([1]), Tensor<Float>([1]))).scalars()
XCTAssertEqual(scalars2, [2])
}

func testMethodMetadata() throws {
guard let modelPath = resourceBundle.path(forResource: "add", ofType: "pte") else {
XCTFail("Couldn't find the model file")
return
Expand Down
Loading