From 2a74da5a324ee31c84dfcf831fb85a2ef224da18 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Wed, 26 Mar 2025 18:09:14 -0700 Subject: [PATCH] Tensor accessorts to get raw data buffer. Summary: https://github.com/pytorch/executorch/issues/8366 Reviewed By: mergennachin Differential Revision: D71905971 --- .../ExecuTorch/Exported/ExecuTorchTensor.h | 22 ++++++++++ .../ExecuTorch/Exported/ExecuTorchTensor.mm | 10 +++++ .../ExecuTorch/__tests__/TensorTest.swift | 42 +++++++++++++++++++ 3 files changed, 74 insertions(+) diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h index 627f28092b6..3b2aa4db7b3 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h @@ -151,6 +151,28 @@ __attribute__((deprecated("This API is experimental."))) - (instancetype)initWithNativeInstance:(void *)nativeInstance NS_DESIGNATED_INITIALIZER NS_SWIFT_UNAVAILABLE(""); +/** + * Executes a block with a pointer to the tensor's immutable byte data. + * + * @param handler A block that receives: + * - a pointer to the data, + * - the total number of elements, + * - and the data type. + */ +- (void)bytesWithHandler:(void (^)(const void *pointer, NSInteger count, ExecuTorchDataType dataType))handler + NS_SWIFT_NAME(bytes(_:)); + +/** + * Executes a block with a pointer to the tensor's mutable byte data. + * + * @param handler A block that receives: + * - a mutable pointer to the data, + * - the total number of elements, + * - and the data type. + */ +- (void)mutableBytesWithHandler:(void (^)(void *pointer, NSInteger count, ExecuTorchDataType dataType))handler + NS_SWIFT_NAME(mutableBytes(_:)); + + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index 7f9a89742ea..d2c3cb08cea 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -81,6 +81,16 @@ - (NSInteger)count { return _tensor->numel(); } +- (void)bytesWithHandler:(void (^)(const void *pointer, NSInteger count, ExecuTorchDataType type))handler { + ET_CHECK(handler); + handler(_tensor->unsafeGetTensorImpl()->data(), self.count, self.dataType); +} + +- (void)mutableBytesWithHandler:(void (^)(void *pointer, NSInteger count, ExecuTorchDataType dataType))handler { + ET_CHECK(handler); + handler(_tensor->unsafeGetTensorImpl()->mutable_data(), self.count, self.dataType); +} + @end @implementation ExecuTorchTensor (BytesNoCopy) diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index 1d023abebe0..428b5ba0450 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -60,12 +60,22 @@ class TensorTest: XCTestCase { let tensor = data.withUnsafeMutableBytes { Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 3], dataType: .float) } + // Modify the original data to make sure the tensor does not copy the data. + data.indices.forEach { data[$0] += 1 } + XCTAssertEqual(tensor.dataType, .float) XCTAssertEqual(tensor.shape, [2, 3]) XCTAssertEqual(tensor.strides, [3, 1]) XCTAssertEqual(tensor.dimensionOrder, [0, 1]) XCTAssertEqual(tensor.shapeDynamism, .dynamicBound) XCTAssertEqual(tensor.count, 6) + + tensor.bytes { pointer, count, dataType in + XCTAssertEqual(dataType, .float) + XCTAssertEqual(count, 6) + XCTAssertEqual(size(ofDataType: dataType), 4) + XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count)), data) + } } func testInitBytes() { @@ -73,12 +83,22 @@ class TensorTest: XCTestCase { let tensor = data.withUnsafeMutableBytes { Tensor(bytes: $0.baseAddress!, shape: [2, 3], dataType: .double) } + // Modify the original data to make sure the tensor copies the data. + data.indices.forEach { data[$0] += 1 } + XCTAssertEqual(tensor.dataType, .double) XCTAssertEqual(tensor.shape, [2, 3]) XCTAssertEqual(tensor.strides, [3, 1]) XCTAssertEqual(tensor.dimensionOrder, [0, 1]) XCTAssertEqual(tensor.shapeDynamism, .dynamicBound) XCTAssertEqual(tensor.count, 6) + + tensor.bytes { pointer, count, dataType in + XCTAssertEqual(dataType, .double) + XCTAssertEqual(count, 6) + XCTAssertEqual(size(ofDataType: dataType), 8) + XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Double.self), count: count)).map { $0 + 1 }, data) + } } func testWithCustomStridesAndDimensionOrder() { @@ -94,5 +114,27 @@ class TensorTest: XCTestCase { XCTAssertEqual(tensor.strides, [1, 2]) XCTAssertEqual(tensor.dimensionOrder, [1, 0]) XCTAssertEqual(tensor.count, 4) + + tensor.bytes { pointer, count, dataType in + XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count)), data) + } + } + + func testMutableBytes() { + var data: [Int32] = [1, 2, 3, 4] + let tensor = data.withUnsafeMutableBytes { + Tensor(bytes: $0.baseAddress!, shape: [4], dataType: .int) + } + tensor.mutableBytes { pointer, count, dataType in + XCTAssertEqual(dataType, .int) + let buffer = pointer.assumingMemoryBound(to: Int32.self) + for i in 0..