From 70a695ac609921b48b4aa37680cc1aadef6fa6a8 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Thu, 22 May 2025 12:01:16 -0700 Subject: [PATCH 1/4] Create a native Tensor swift extension --- .../Exported/ExecuTorch+Tensor.swift | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift new file mode 100644 index 00000000000..f1da13e7ae1 --- /dev/null +++ b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift @@ -0,0 +1,90 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +@_exported import ExecuTorch + +/// A protocol that types conform to in order to be used as tensor element types. +/// Provides the mapping from the Swift type to the underlying `DataType`. +@available(*, deprecated, message: "This API is experimental.") +protocol Scalar { + /// The `DataType` corresponding to this scalar type. + static var dataType: DataType { get } +} + +@available(*, deprecated, message: "This API is experimental.") +extension UInt8: Scalar { static var dataType: DataType { .byte } } +@available(*, deprecated, message: "This API is experimental.") +extension Int8: Scalar { static var dataType: DataType { .char } } +@available(*, deprecated, message: "This API is experimental.") +extension Int16: Scalar { static var dataType: DataType { .short } } +@available(*, deprecated, message: "This API is experimental.") +extension Int32: Scalar { static var dataType: DataType { .int } } +@available(*, deprecated, message: "This API is experimental.") +extension Int64: Scalar { static var dataType: DataType { .long } } +@available(*, deprecated, message: "This API is experimental.") +extension Int: Scalar { static var dataType: DataType { .long } } +@available(macOS 11.0, *) +@available(*, deprecated, message: "This API is experimental.") +extension Float16: Scalar { static var dataType: DataType { .half } } +@available(*, deprecated, message: "This API is experimental.") +extension Float: Scalar { static var dataType: DataType { .float } } +@available(*, deprecated, message: "This API is experimental.") +extension Double: Scalar { static var dataType: DataType { .double } } +@available(*, deprecated, message: "This API is experimental.") +extension Bool: Scalar { static var dataType: DataType { .bool } } +@available(*, deprecated, message: "This API is experimental.") +extension UInt16: Scalar { static var dataType: DataType { .uInt16 } } +@available(*, deprecated, message: "This API is experimental.") +extension UInt32: Scalar { static var dataType: DataType { .uInt32 } } +@available(*, deprecated, message: "This API is experimental.") +extension UInt64: Scalar { static var dataType: DataType { .uInt64 } } +@available(*, deprecated, message: "This API is experimental.") +extension UInt: Scalar { static var dataType: DataType { .uInt64 } } + +@available(*, deprecated, message: "This API is experimental.") +extension Tensor { + /// Calls the closure with a typed, immutable buffer pointer over the tensor’s elements. + /// + /// - Parameter body: A closure that receives an `UnsafeBufferPointer` bound to the tensor’s data. + /// - Returns: The value returned by `body`. + /// - Throws: `Error(code: .invalidArgument)` if `T.dataType` doesn’t match the tensor’s `dataType`, + /// or any error thrown by `body`. + func withUnsafeBytes(_ body: (UnsafeBufferPointer) throws -> R) throws -> R { + guard dataType == T.dataType else { throw Error(code: .invalidArgument) } + var result: Result? + bytes { pointer, count, _ in + result = Result { try body( + UnsafeBufferPointer( + start: pointer.assumingMemoryBound(to: T.self), + count: count + ) + ) } + } + return try result!.get() + } + + /// Calls the closure with a typed, mutable buffer pointer over the tensor’s elements. + /// + /// - Parameter body: A closure that receives an `UnsafeMutableBufferPointer` bound to the tensor’s data. + /// - Returns: The value returned by `body`. + /// - Throws: `Error(code: .invalidArgument)` if `T.dataType` doesn’t match the tensor’s `dataType`, + /// or any error thrown by `body`. + func withUnsafeMutableBytes(_ body: (UnsafeMutableBufferPointer) throws -> R) throws -> R { + guard dataType == T.dataType else { throw Error(code: .invalidArgument) } + var result: Result? + mutableBytes { pointer, count, _ in + result = Result { try body( + UnsafeMutableBufferPointer( + start: pointer.assumingMemoryBound(to: T.self), + count: count + ) + ) } + } + return try result!.get() + } +} From 129aac6e391d3f5168ea1b7526cd3cbe2c8cba0b Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Thu, 22 May 2025 12:02:14 -0700 Subject: [PATCH 2/4] Update TensorTest.swift --- .../ExecuTorch/__tests__/TensorTest.swift | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index 0233a30f780..d282ac9251b 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -148,6 +148,54 @@ class TensorTest: XCTestCase { } } + func testWithUnsafeBytes() throws { + var data: [Float] = [1, 2, 3, 4, 5, 6] + let tensor = data.withUnsafeMutableBytes { + Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 3], dataType: .float) + } + let array: [Float] = try tensor.withUnsafeBytes { Array($0) } + XCTAssertEqual(array, data) + } + + func testWithUnsafeMutableBytes() throws { + var data = [1, 2, 3, 4] + let tensor = data.withUnsafeMutableBytes { + Tensor(bytes: $0.baseAddress!, shape: [4], dataType: .long) + } + try tensor.withUnsafeMutableBytes { (buffer: UnsafeMutableBufferPointer) in + for i in buffer.indices { + buffer[i] *= 2 + } + } + try tensor.withUnsafeBytes { buffer in + XCTAssertEqual(Array(buffer), [2, 4, 6, 8]) + } + } + + func testWithUnsafeBytesFloat16() throws { + var data: [Float16] = [1, 2, 3, 4, 5, 6] + let tensor = data.withUnsafeMutableBytes { + Tensor(bytesNoCopy: $0.baseAddress!, shape: [6], dataType: .half) + } + let array: [Float16] = try tensor.withUnsafeBytes { Array($0) } + XCTAssertEqual(array, data) + } + + func testWithUnsafeMutableBytesFloat16() throws { + var data: [Float16] = [1, 2, 3, 4] + let tensor = data.withUnsafeMutableBytes { buf in + Tensor(bytes: buf.baseAddress!, shape: [4], dataType: .half) + } + try tensor.withUnsafeMutableBytes { (buffer: UnsafeMutableBufferPointer) in + for i in buffer.indices { + buffer[i] *= 2 + } + } + try tensor.withUnsafeBytes { buffer in + XCTAssertEqual(Array(buffer), data.map { $0 * 2 }) + } + } + func testInitWithTensor() { var data: [Int] = [10, 20, 30, 40] let tensor1 = data.withUnsafeMutableBytes { @@ -618,7 +666,7 @@ class TensorTest: XCTestCase { } } } - + func testZeros() { let tensor = Tensor.zeros(shape: [2, 3], dataType: .double) XCTAssertEqual(tensor.shape, [2, 3]) From 1b7de7b8e08bee5d2ca2b9b1dc3498d74c2a9039 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Thu, 22 May 2025 12:02:55 -0700 Subject: [PATCH 3/4] Update build_apple_frameworks.sh --- scripts/build_apple_frameworks.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/build_apple_frameworks.sh b/scripts/build_apple_frameworks.sh index 7615ad204bf..d93a2397f94 100755 --- a/scripts/build_apple_frameworks.sh +++ b/scripts/build_apple_frameworks.sh @@ -33,7 +33,7 @@ libextension_data_loader.a,\ libextension_flat_tensor.a,\ libextension_module.a,\ libextension_tensor.a,\ -:$HEADERS_PATH" +:$HEADERS_PATH:ExecuTorch" FRAMEWORK_BACKEND_COREML="backend_coreml:\ libcoreml_util.a,\ From 38e45326f7cb57337594e2ecf0c9413b99f06599 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Thu, 22 May 2025 12:06:38 -0700 Subject: [PATCH 4/4] Update TensorTest.swift --- extension/apple/ExecuTorch/__tests__/TensorTest.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index d282ac9251b..01fe094ea7a 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -183,8 +183,8 @@ class TensorTest: XCTestCase { func testWithUnsafeMutableBytesFloat16() throws { var data: [Float16] = [1, 2, 3, 4] - let tensor = data.withUnsafeMutableBytes { buf in - Tensor(bytes: buf.baseAddress!, shape: [4], dataType: .half) + let tensor = data.withUnsafeMutableBytes { buffer in + Tensor(bytes: buffer.baseAddress!, shape: [4], dataType: .half) } try tensor.withUnsafeMutableBytes { (buffer: UnsafeMutableBufferPointer) in for i in buffer.indices {