diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h index 6e0fba77abd..152a74b7cb2 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h @@ -91,7 +91,7 @@ NSInteger ExecuTorchElementCountOfShape(NSArray *shape) */ NS_SWIFT_NAME(Tensor) __attribute__((deprecated("This API is experimental."))) -@interface ExecuTorchTensor : NSObject +@interface ExecuTorchTensor : NSObject /** * Pointer to the underlying native TensorPtr instance. @@ -160,6 +160,13 @@ __attribute__((deprecated("This API is experimental."))) - (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor NS_SWIFT_NAME(init(_:)); +/** + * Returns a copy of the tensor. + * + * @return A new ExecuTorchTensor instance that is a duplicate of the current tensor. + */ +- (instancetype)copy; + /** * Executes a block with a pointer to the tensor's immutable byte data. * diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index 4c97642ba38..912bc4f59d2 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -52,6 +52,15 @@ - (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor { return [self initWithNativeInstance:&tensor]; } +- (instancetype)copy { + return [self copyWithZone:nil]; +} + +- (instancetype)copyWithZone:(nullable NSZone *)zone { + auto tensor = clone_tensor_ptr(_tensor); + return [[ExecuTorchTensor allocWithZone:zone] initWithNativeInstance:&tensor]; +} + - (void *)nativeInstance { return &_tensor; } diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index 2d23bfaea73..fef9da87906 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -151,4 +151,18 @@ class TensorTest: XCTestCase { XCTAssertEqual(tensor2.dimensionOrder, tensor1.dimensionOrder) XCTAssertEqual(tensor2.count, tensor1.count) } + + func testCopy() { + var data: [Double] = [10.0, 20.0, 30.0, 40.0] + let tensor1 = data.withUnsafeMutableBytes { + Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2], dataType: .double) + } + let tensor2 = tensor1.copy() + + XCTAssertEqual(tensor1.dataType, tensor2.dataType) + XCTAssertEqual(tensor1.shape, tensor2.shape) + XCTAssertEqual(tensor1.strides, tensor2.strides) + XCTAssertEqual(tensor1.dimensionOrder, tensor2.dimensionOrder) + XCTAssertEqual(tensor1.count, tensor2.count) + } }