Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,28 @@ __attribute__((deprecated("This API is experimental.")))
- (instancetype)initWithNativeInstance:(void *)nativeInstance
NS_DESIGNATED_INITIALIZER NS_SWIFT_UNAVAILABLE("");

/**
* Executes a block with a pointer to the tensor's immutable byte data.
*
* @param handler A block that receives:
* - a pointer to the data,
* - the total number of elements,
* - and the data type.
*/
- (void)bytesWithHandler:(void (^)(const void *pointer, NSInteger count, ExecuTorchDataType dataType))handler
NS_SWIFT_NAME(bytes(_:));

/**
* Executes a block with a pointer to the tensor's mutable byte data.
*
* @param handler A block that receives:
* - a mutable pointer to the data,
* - the total number of elements,
* - and the data type.
*/
- (void)mutableBytesWithHandler:(void (^)(void *pointer, NSInteger count, ExecuTorchDataType dataType))handler
NS_SWIFT_NAME(mutableBytes(_:));

+ (instancetype)new NS_UNAVAILABLE;
- (instancetype)init NS_UNAVAILABLE;

Expand Down
10 changes: 10 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ - (NSInteger)count {
return _tensor->numel();
}

- (void)bytesWithHandler:(void (^)(const void *pointer, NSInteger count, ExecuTorchDataType type))handler {
ET_CHECK(handler);
handler(_tensor->unsafeGetTensorImpl()->data(), self.count, self.dataType);
}

- (void)mutableBytesWithHandler:(void (^)(void *pointer, NSInteger count, ExecuTorchDataType dataType))handler {
ET_CHECK(handler);
handler(_tensor->unsafeGetTensorImpl()->mutable_data(), self.count, self.dataType);
}

@end

@implementation ExecuTorchTensor (BytesNoCopy)
Expand Down
42 changes: 42 additions & 0 deletions extension/apple/ExecuTorch/__tests__/TensorTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,45 @@ class TensorTest: XCTestCase {
let tensor = data.withUnsafeMutableBytes {
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 3], dataType: .float)
}
// Modify the original data to make sure the tensor does not copy the data.
data.indices.forEach { data[$0] += 1 }

XCTAssertEqual(tensor.dataType, .float)
XCTAssertEqual(tensor.shape, [2, 3])
XCTAssertEqual(tensor.strides, [3, 1])
XCTAssertEqual(tensor.dimensionOrder, [0, 1])
XCTAssertEqual(tensor.shapeDynamism, .dynamicBound)
XCTAssertEqual(tensor.count, 6)

tensor.bytes { pointer, count, dataType in
XCTAssertEqual(dataType, .float)
XCTAssertEqual(count, 6)
XCTAssertEqual(size(ofDataType: dataType), 4)
XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count)), data)
}
}

func testInitBytes() {
var data: [Double] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
let tensor = data.withUnsafeMutableBytes {
Tensor(bytes: $0.baseAddress!, shape: [2, 3], dataType: .double)
}
// Modify the original data to make sure the tensor copies the data.
data.indices.forEach { data[$0] += 1 }

XCTAssertEqual(tensor.dataType, .double)
XCTAssertEqual(tensor.shape, [2, 3])
XCTAssertEqual(tensor.strides, [3, 1])
XCTAssertEqual(tensor.dimensionOrder, [0, 1])
XCTAssertEqual(tensor.shapeDynamism, .dynamicBound)
XCTAssertEqual(tensor.count, 6)

tensor.bytes { pointer, count, dataType in
XCTAssertEqual(dataType, .double)
XCTAssertEqual(count, 6)
XCTAssertEqual(size(ofDataType: dataType), 8)
XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Double.self), count: count)).map { $0 + 1 }, data)
}
}

func testWithCustomStridesAndDimensionOrder() {
Expand All @@ -94,5 +114,27 @@ class TensorTest: XCTestCase {
XCTAssertEqual(tensor.strides, [1, 2])
XCTAssertEqual(tensor.dimensionOrder, [1, 0])
XCTAssertEqual(tensor.count, 4)

tensor.bytes { pointer, count, dataType in
XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count)), data)
}
}

func testMutableBytes() {
var data: [Int32] = [1, 2, 3, 4]
let tensor = data.withUnsafeMutableBytes {
Tensor(bytes: $0.baseAddress!, shape: [4], dataType: .int)
}
tensor.mutableBytes { pointer, count, dataType in
XCTAssertEqual(dataType, .int)
let buffer = pointer.assumingMemoryBound(to: Int32.self)
for i in 0..<count {
buffer[i] *= 2
}
}
tensor.bytes { pointer, count, dataType in
let updatedData = Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int32.self), count: count))
XCTAssertEqual(updatedData, [2, 4, 6, 8])
}
}
}
Loading