From ed4152eb88e5450fc2a3a0ac92974682156d1830 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Wed, 26 Mar 2025 18:22:40 -0700 Subject: [PATCH] Tensor constructor to create with an existing Tensor. Summary: https://github.com/pytorch/executorch/issues/8366 Reviewed By: mergennachin Differential Revision: D71906972 --- .../apple/ExecuTorch/Exported/ExecuTorchTensor.h | 9 +++++++++ .../apple/ExecuTorch/Exported/ExecuTorchTensor.mm | 8 ++++++++ .../apple/ExecuTorch/__tests__/TensorTest.swift | 14 ++++++++++++++ 3 files changed, 31 insertions(+) diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h index 3b2aa4db7b3..6e0fba77abd 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h @@ -151,6 +151,15 @@ __attribute__((deprecated("This API is experimental."))) - (instancetype)initWithNativeInstance:(void *)nativeInstance NS_DESIGNATED_INITIALIZER NS_SWIFT_UNAVAILABLE(""); +/** + * Creates a new tensor by copying an existing tensor. + * + * @param otherTensor The tensor instance to copy. + * @return A new ExecuTorchTensor instance that is a copy of otherTensor. + */ +- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor + NS_SWIFT_NAME(init(_:)); + /** * Executes a block with a pointer to the tensor's immutable byte data. * diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index d2c3cb08cea..4c97642ba38 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -44,6 +44,14 @@ - (instancetype)initWithNativeInstance:(void *)nativeInstance { return self; } +- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor { + ET_CHECK(otherTensor); + auto tensor = make_tensor_ptr( + **reinterpret_cast(otherTensor.nativeInstance) + ); + return [self initWithNativeInstance:&tensor]; +} + - (void *)nativeInstance { return &_tensor; } diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index 428b5ba0450..2d23bfaea73 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -137,4 +137,18 @@ class TensorTest: XCTestCase { XCTAssertEqual(updatedData, [2, 4, 6, 8]) } } + + func testInitWithTensor() { + var data: [Int] = [10, 20, 30, 40] + let tensor1 = data.withUnsafeMutableBytes { + Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2], dataType: .int) + } + let tensor2 = Tensor(tensor1) + + XCTAssertEqual(tensor2.dataType, tensor1.dataType) + XCTAssertEqual(tensor2.shape, tensor1.shape) + XCTAssertEqual(tensor2.strides, tensor1.strides) + XCTAssertEqual(tensor2.dimensionOrder, tensor1.dimensionOrder) + XCTAssertEqual(tensor2.count, tensor1.count) + } }