diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift index 55920ce541f..708d68ba2bd 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift +++ b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift @@ -741,6 +741,14 @@ public final class Tensor: Equatable { Tensor(anyTensor.copy()) } + /// Returns a copy of the tensor, converted to the specified scalar type. + /// + /// - Parameter dataType: The target scalar type. + /// - Returns: A new tensor with the same shape and metadata but converted elements. + public func copy(to dataType: U.Type) -> Tensor { + Tensor(anyTensor.copy(to: U.dataType)) + } + /// Calls the closure with a typed, immutable buffer pointer over the tensor’s elements. /// /// - Parameter body: A closure that receives an `UnsafeBufferPointer` bound to the tensor’s data. diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h index 53d23258b7e..ec70cb9ceaa 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h @@ -168,6 +168,16 @@ __attribute__((objc_subclassing_restricted)) */ - (instancetype)copy; +/** + * Creates a deep copy of the tensor, potentially casting to a new data type. + * The new tensor will have its own copy of the data. + * + * @param dataType The desired data type for the new tensor. + * @return A new ExecuTorchTensor instance that is a duplicate (and possibly casted) of the current tensor. +*/ +- (instancetype)copyToDataType:(ExecuTorchDataType)dataType + NS_SWIFT_NAME(copy(to:)); + /** * Executes a block with a pointer to the tensor's immutable byte data. * diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index 24d355782c0..b1a24ef1f4f 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -147,6 +147,11 @@ - (instancetype)copyWithZone:(nullable NSZone *)zone { return [[ExecuTorchTensor allocWithZone:zone] initWithNativeInstance:&tensor]; } +- (instancetype)copyToDataType:(ExecuTorchDataType)dataType { + auto tensor = clone_tensor_ptr(_tensor, static_cast(dataType)); + return [[ExecuTorchTensor alloc] initWithNativeInstance:&tensor]; +} + - (void *)nativeInstance { return &_tensor; } diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index e2c330ac1c2..2a8d7f9b58a 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -182,6 +182,43 @@ class TensorTest: XCTestCase { XCTAssertEqual(tensor1.count, tensor2.count) } + func testCopyToSameDataType() { + let tensor1 = Tensor([1, 2, 3, 4], shape: [2, 2]) + let tensor2 = tensor1.copy(to: Float.self) + XCTAssertEqual(tensor2.dataType, .float) + XCTAssertEqual(tensor2.shape, [2, 2]) + XCTAssertEqual(tensor2.strides, tensor1.strides) + XCTAssertEqual(tensor2.dimensionOrder, tensor1.dimensionOrder) + XCTAssertEqual(tensor2.scalars(), [1, 2, 3, 4]) + } + + func testCopyToDifferentDataTypeKeepsSourceAlive() { + var data = [10.0, 20.0, 30.0, 40.0] + let tensor1 = data.withUnsafeMutableBytes { + Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2]) + } + let tensor2 = tensor1.copy(to: Float.self) + data[0] = 999.0 + XCTAssertEqual(tensor2.dataType, .float) + XCTAssertEqual(tensor2.shape, [2, 2]) + XCTAssertEqual(tensor2.scalars(), [10.0, 20.0, 30.0, 40.0]) + } + + func testCopyToPreservesShapeAndOrderOn2D() { + let tensor1 = Tensor( + [1, 2, 3, 4, 5, 6], + shape: [2, 3], + strides: [3, 1], + dimensionOrder: [0, 1] + ) + let tensor2 = tensor1.copy(to: Double.self) + XCTAssertEqual(tensor2.shape, [2, 3]) + XCTAssertEqual(tensor2.strides, [3, 1]) + XCTAssertEqual(tensor2.dimensionOrder, [0, 1]) + XCTAssertEqual(tensor2.count, 6) + XCTAssertEqual(tensor2.scalars(), [1, 2, 3, 4, 5, 6]) + } + func testResize() { var data: [Int] = [1, 2, 3, 4] let tensor = data.withUnsafeMutableBytes { diff --git a/extension/tensor/tensor_ptr.cpp b/extension/tensor/tensor_ptr.cpp index dab1a8ab176..bb76311bd67 100644 --- a/extension/tensor/tensor_ptr.cpp +++ b/extension/tensor/tensor_ptr.cpp @@ -164,7 +164,9 @@ TensorPtr make_tensor_ptr( [data = std::move(data)](void*) {}); } -TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor) { +TensorPtr clone_tensor_ptr( + const executorch::aten::Tensor& tensor, + executorch::aten::ScalarType type) { std::vector sizes( tensor.sizes().begin(), tensor.sizes().end()); std::vector dim_order{ @@ -178,23 +180,63 @@ TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor) { #ifndef USE_ATEN_LIB dynamism = tensor.shape_dynamism(); #endif // USE_ATEN_LIB - return tensor.const_data_ptr() - ? make_tensor_ptr( - std::move(sizes), - std::vector( - (uint8_t*)tensor.const_data_ptr(), - (uint8_t*)tensor.const_data_ptr() + tensor.nbytes()), - std::move(dim_order), - std::move(strides), - tensor.scalar_type(), - dynamism) - : make_tensor_ptr( - std::move(sizes), - nullptr, - std::move(dim_order), - std::move(strides), - tensor.scalar_type(), - dynamism); + const auto* tensor_data = tensor.const_data_ptr(); + if (!tensor_data) { + return make_tensor_ptr( + std::move(sizes), + nullptr, + std::move(dim_order), + std::move(strides), + type, + dynamism); + } + const auto tensor_type = tensor.scalar_type(); + if (tensor_type == type) { + return make_tensor_ptr( + std::move(sizes), + std::vector( + (uint8_t*)tensor_data, (uint8_t*)tensor_data + tensor.nbytes()), + std::move(dim_order), + std::move(strides), + tensor_type, + dynamism); + } + ET_CHECK_MSG( + runtime::canCast(tensor_type, type), + "Cannot cast tensor type to desired type."); + const auto tensor_numel = static_cast(tensor.numel()); + std::vector data(tensor_numel * aten::elementSize(type)); + + // Create a minimal context for error handling in ET_SWITCH + struct { + [[noreturn]] void fail(torch::executor::Error /* error */) { + ET_CHECK_MSG(false, "Unsupported dtype in clone_tensor_ptr"); + } + } ctx; + + ET_SWITCH_REALHBBF16_AND_UINT_TYPES( + tensor_type, ctx, "clone_tensor_ptr_from", CTYPE_FROM, [&] { + const CTYPE_FROM* tensor_data_ptr = + static_cast(tensor_data); + ET_SWITCH_REALHBBF16_AND_UINT_TYPES( + type, ctx, "clone_tensor_ptr_to", CTYPE_TO, [&] { + CTYPE_TO* data_ptr = reinterpret_cast(data.data()); + std::transform( + tensor_data_ptr, + tensor_data_ptr + tensor_numel, + data_ptr, + [](const CTYPE_FROM& val) { + return static_cast(val); + }); + }); + }); + return make_tensor_ptr( + std::move(sizes), + std::move(data), + std::move(dim_order), + std::move(strides), + type, + dynamism); } runtime::Error resize_tensor_ptr( diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index d8fad857cd2..41e21b14816 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -114,7 +114,7 @@ inline TensorPtr make_tensor_ptr( ET_CHECK_MSG( runtime::canCast(deduced_type, type), "Cannot cast deduced type to specified type."); - std::vector casted_data(data.size() * runtime::elementSize(type)); + std::vector casted_data(data.size() * aten::elementSize(type)); // Create a minimal context for error handling in ET_SWITCH struct { @@ -408,6 +408,21 @@ inline TensorPtr make_tensor_ptr( [tensor_ptr](void*) {}); } +/** + * Creates a TensorPtr that manages a new Tensor with the same properties + * as the given Tensor, but with a copy of the data owned by the returned + * TensorPtr, or nullptr if the original data is null. + * + * @param tensor The Tensor to clone. + * @param type The data type for the cloned tensor. The data will be cast + * from the source tensor's type. + * @return A new TensorPtr that manages a Tensor with the specified type + * and copied/cast data. + */ +TensorPtr clone_tensor_ptr( + const executorch::aten::Tensor& tensor, + executorch::aten::ScalarType type); + /** * Creates a TensorPtr that manages a new Tensor with the same properties * as the given Tensor, but with a copy of the data owned by the returned @@ -417,7 +432,25 @@ inline TensorPtr make_tensor_ptr( * @return A new TensorPtr that manages a Tensor with the same properties as the * original but with copied data. */ -TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor); +inline TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor) { + return clone_tensor_ptr(tensor, tensor.scalar_type()); +} + +/** + * Creates a new TensorPtr by cloning the given TensorPtr, copying the + * underlying data. + * + * @param tensor The TensorPtr to clone. + * @param type The data type for the cloned tensor. The data will be cast + * from the source tensor's type. + * @return A new TensorPtr that manages a Tensor with the specified type + * and copied/cast data. + */ +inline TensorPtr clone_tensor_ptr( + const TensorPtr& tensor, + executorch::aten::ScalarType type) { + return clone_tensor_ptr(*tensor, type); +} /** * Creates a new TensorPtr by cloning the given TensorPtr, copying the @@ -428,7 +461,7 @@ TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor); * original but with copied data. */ inline TensorPtr clone_tensor_ptr(const TensorPtr& tensor) { - return clone_tensor_ptr(*tensor); + return clone_tensor_ptr(*tensor, tensor->scalar_type()); } /** diff --git a/extension/tensor/test/tensor_ptr_test.cpp b/extension/tensor/test/tensor_ptr_test.cpp index 5e242e5eb02..b8e065481f6 100644 --- a/extension/tensor/test/tensor_ptr_test.cpp +++ b/extension/tensor/test/tensor_ptr_test.cpp @@ -571,6 +571,82 @@ TEST_F(TensorPtrTest, CloneTensorPtrFromExistingTensorInt32) { EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Int); } +TEST_F(TensorPtrTest, CloneTensorPtrCastInt32ToFloat) { + std::vector data = {1, 2, 3, 4}; + auto tensor = make_tensor_ptr({2, 2}, std::move(data)); + auto cloned_tensor = + clone_tensor_ptr(*tensor, executorch::aten::ScalarType::Float); + + EXPECT_EQ(cloned_tensor->dim(), 2); + EXPECT_EQ(cloned_tensor->size(0), 2); + EXPECT_EQ(cloned_tensor->size(1), 2); + EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Float); + auto ptr = cloned_tensor->const_data_ptr(); + EXPECT_FLOAT_EQ(ptr[0], 1.0f); + EXPECT_FLOAT_EQ(ptr[1], 2.0f); + EXPECT_FLOAT_EQ(ptr[2], 3.0f); + EXPECT_FLOAT_EQ(ptr[3], 4.0f); +} + +TEST_F(TensorPtrTest, CloneTensorPtrCastFloatToBFloat16) { + std::vector data = {1.0f, 2.0f, 3.5f}; + auto tensor = make_tensor_ptr({3}, std::move(data)); + auto cloned_tensor = + clone_tensor_ptr(*tensor, executorch::aten::ScalarType::BFloat16); + + EXPECT_EQ(cloned_tensor->dim(), 1); + EXPECT_EQ(cloned_tensor->size(0), 3); + EXPECT_EQ( + cloned_tensor->scalar_type(), executorch::aten::ScalarType::BFloat16); + auto ptr = cloned_tensor->const_data_ptr(); + EXPECT_NEAR(static_cast(ptr[0]), 1.0f, 0.01f); + EXPECT_NEAR(static_cast(ptr[1]), 2.0f, 0.01f); + EXPECT_NEAR(static_cast(ptr[2]), 3.5f, 0.01f); +} + +TEST_F(TensorPtrTest, CloneTensorPtrCastKeepsMetadata) { + std::vector data( + 6 * executorch::aten::elementSize(executorch::aten::ScalarType::Float)); + auto tensor = make_tensor_ptr({2, 3}, std::move(data)); + auto cloned_tensor = + clone_tensor_ptr(*tensor, executorch::aten::ScalarType::Float); + + EXPECT_EQ(cloned_tensor->dim(), 2); + EXPECT_EQ(cloned_tensor->size(0), 2); + EXPECT_EQ(cloned_tensor->size(1), 3); + EXPECT_EQ(cloned_tensor->strides()[0], 3); + EXPECT_EQ(cloned_tensor->strides()[1], 1); + EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Float); +} + +TEST_F(TensorPtrTest, CloneTensorPtrCastNullData) { + auto tensor = make_tensor_ptr( + {2, 2}, + nullptr, + {}, + {}, + executorch::aten::ScalarType::Float, + executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND); + auto cloned_tensor = + clone_tensor_ptr(*tensor, executorch::aten::ScalarType::Int); + + EXPECT_EQ(cloned_tensor->dim(), 2); + EXPECT_EQ(cloned_tensor->size(0), 2); + EXPECT_EQ(cloned_tensor->size(1), 2); + EXPECT_EQ(cloned_tensor->const_data_ptr(), nullptr); + EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Int); +} + +TEST_F(TensorPtrTest, CloneTensorPtrCastInvalidExpectDeath) { + std::vector data = {1.0f, 2.0f}; + auto tensor = make_tensor_ptr({2}, std::move(data)); + ET_EXPECT_DEATH( + { + auto _ = clone_tensor_ptr(*tensor, executorch::aten::ScalarType::Int); + }, + ""); +} + TEST_F(TensorPtrTest, MakeTensorPtrFromTensorPtrInt32) { std::vector data = {1, 2, 3, 4}; auto tensor = make_tensor_ptr({2, 2}, data);