diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift new file mode 100644 index 00000000000..f1da13e7ae1 --- /dev/null +++ b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift @@ -0,0 +1,90 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +@_exported import ExecuTorch + +/// A protocol that types conform to in order to be used as tensor element types. +/// Provides the mapping from the Swift type to the underlying `DataType`. +@available(*, deprecated, message: "This API is experimental.") +protocol Scalar { + /// The `DataType` corresponding to this scalar type. + static var dataType: DataType { get } +} + +@available(*, deprecated, message: "This API is experimental.") +extension UInt8: Scalar { static var dataType: DataType { .byte } } +@available(*, deprecated, message: "This API is experimental.") +extension Int8: Scalar { static var dataType: DataType { .char } } +@available(*, deprecated, message: "This API is experimental.") +extension Int16: Scalar { static var dataType: DataType { .short } } +@available(*, deprecated, message: "This API is experimental.") +extension Int32: Scalar { static var dataType: DataType { .int } } +@available(*, deprecated, message: "This API is experimental.") +extension Int64: Scalar { static var dataType: DataType { .long } } +@available(*, deprecated, message: "This API is experimental.") +extension Int: Scalar { static var dataType: DataType { .long } } +@available(macOS 11.0, *) +@available(*, deprecated, message: "This API is experimental.") +extension Float16: Scalar { static var dataType: DataType { .half } } +@available(*, deprecated, message: "This API is experimental.") +extension Float: Scalar { static var dataType: DataType { .float } } +@available(*, deprecated, message: "This API is experimental.") +extension Double: Scalar { static var dataType: DataType { .double } } +@available(*, deprecated, message: "This API is experimental.") +extension Bool: Scalar { static var dataType: DataType { .bool } } +@available(*, deprecated, message: "This API is experimental.") +extension UInt16: Scalar { static var dataType: DataType { .uInt16 } } +@available(*, deprecated, message: "This API is experimental.") +extension UInt32: Scalar { static var dataType: DataType { .uInt32 } } +@available(*, deprecated, message: "This API is experimental.") +extension UInt64: Scalar { static var dataType: DataType { .uInt64 } } +@available(*, deprecated, message: "This API is experimental.") +extension UInt: Scalar { static var dataType: DataType { .uInt64 } } + +@available(*, deprecated, message: "This API is experimental.") +extension Tensor { + /// Calls the closure with a typed, immutable buffer pointer over the tensor’s elements. + /// + /// - Parameter body: A closure that receives an `UnsafeBufferPointer` bound to the tensor’s data. + /// - Returns: The value returned by `body`. + /// - Throws: `Error(code: .invalidArgument)` if `T.dataType` doesn’t match the tensor’s `dataType`, + /// or any error thrown by `body`. + func withUnsafeBytes(_ body: (UnsafeBufferPointer) throws -> R) throws -> R { + guard dataType == T.dataType else { throw Error(code: .invalidArgument) } + var result: Result? + bytes { pointer, count, _ in + result = Result { try body( + UnsafeBufferPointer( + start: pointer.assumingMemoryBound(to: T.self), + count: count + ) + ) } + } + return try result!.get() + } + + /// Calls the closure with a typed, mutable buffer pointer over the tensor’s elements. + /// + /// - Parameter body: A closure that receives an `UnsafeMutableBufferPointer` bound to the tensor’s data. + /// - Returns: The value returned by `body`. + /// - Throws: `Error(code: .invalidArgument)` if `T.dataType` doesn’t match the tensor’s `dataType`, + /// or any error thrown by `body`. + func withUnsafeMutableBytes(_ body: (UnsafeMutableBufferPointer) throws -> R) throws -> R { + guard dataType == T.dataType else { throw Error(code: .invalidArgument) } + var result: Result? + mutableBytes { pointer, count, _ in + result = Result { try body( + UnsafeMutableBufferPointer( + start: pointer.assumingMemoryBound(to: T.self), + count: count + ) + ) } + } + return try result!.get() + } +} diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index 0233a30f780..01fe094ea7a 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -148,6 +148,54 @@ class TensorTest: XCTestCase { } } + func testWithUnsafeBytes() throws { + var data: [Float] = [1, 2, 3, 4, 5, 6] + let tensor = data.withUnsafeMutableBytes { + Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 3], dataType: .float) + } + let array: [Float] = try tensor.withUnsafeBytes { Array($0) } + XCTAssertEqual(array, data) + } + + func testWithUnsafeMutableBytes() throws { + var data = [1, 2, 3, 4] + let tensor = data.withUnsafeMutableBytes { + Tensor(bytes: $0.baseAddress!, shape: [4], dataType: .long) + } + try tensor.withUnsafeMutableBytes { (buffer: UnsafeMutableBufferPointer) in + for i in buffer.indices { + buffer[i] *= 2 + } + } + try tensor.withUnsafeBytes { buffer in + XCTAssertEqual(Array(buffer), [2, 4, 6, 8]) + } + } + + func testWithUnsafeBytesFloat16() throws { + var data: [Float16] = [1, 2, 3, 4, 5, 6] + let tensor = data.withUnsafeMutableBytes { + Tensor(bytesNoCopy: $0.baseAddress!, shape: [6], dataType: .half) + } + let array: [Float16] = try tensor.withUnsafeBytes { Array($0) } + XCTAssertEqual(array, data) + } + + func testWithUnsafeMutableBytesFloat16() throws { + var data: [Float16] = [1, 2, 3, 4] + let tensor = data.withUnsafeMutableBytes { buffer in + Tensor(bytes: buffer.baseAddress!, shape: [4], dataType: .half) + } + try tensor.withUnsafeMutableBytes { (buffer: UnsafeMutableBufferPointer) in + for i in buffer.indices { + buffer[i] *= 2 + } + } + try tensor.withUnsafeBytes { buffer in + XCTAssertEqual(Array(buffer), data.map { $0 * 2 }) + } + } + func testInitWithTensor() { var data: [Int] = [10, 20, 30, 40] let tensor1 = data.withUnsafeMutableBytes { @@ -618,7 +666,7 @@ class TensorTest: XCTestCase { } } } - + func testZeros() { let tensor = Tensor.zeros(shape: [2, 3], dataType: .double) XCTAssertEqual(tensor.shape, [2, 3]) diff --git a/scripts/build_apple_frameworks.sh b/scripts/build_apple_frameworks.sh index 7615ad204bf..d93a2397f94 100755 --- a/scripts/build_apple_frameworks.sh +++ b/scripts/build_apple_frameworks.sh @@ -33,7 +33,7 @@ libextension_data_loader.a,\ libextension_flat_tensor.a,\ libextension_module.a,\ libextension_tensor.a,\ -:$HEADERS_PATH" +:$HEADERS_PATH:ExecuTorch" FRAMEWORK_BACKEND_COREML="backend_coreml:\ libcoreml_util.a,\