diff --git a/extension/tensor/tensor_ptr.cpp b/extension/tensor/tensor_ptr.cpp index 7a0aa997f02..647de8dbe57 100644 --- a/extension/tensor/tensor_ptr.cpp +++ b/extension/tensor/tensor_ptr.cpp @@ -13,6 +13,39 @@ namespace executorch { namespace extension { +TensorPtr clone_tensor_ptr(const exec_aten::Tensor& tensor) { + std::vector sizes( + tensor.sizes().begin(), tensor.sizes().end()); + std::vector dim_order{ +#ifndef USE_ATEN_LIB + tensor.dim_order().begin(), tensor.dim_order().end() +#endif // USE_ATEN_LIB + }; + std::vector strides( + tensor.strides().begin(), tensor.strides().end()); + auto dynamism = exec_aten::TensorShapeDynamism::DYNAMIC_BOUND; +#ifndef USE_ATEN_LIB + dynamism = tensor.shape_dynamism(); +#endif // USE_ATEN_LIB + return tensor.const_data_ptr() + ? make_tensor_ptr( + std::move(sizes), + std::vector( + (uint8_t*)tensor.const_data_ptr(), + (uint8_t*)tensor.const_data_ptr() + tensor.nbytes()), + std::move(dim_order), + std::move(strides), + tensor.scalar_type(), + dynamism) + : make_tensor_ptr( + std::move(sizes), + nullptr, + std::move(dim_order), + std::move(strides), + tensor.scalar_type(), + dynamism); +} + runtime::Error resize_tensor_ptr( TensorPtr& tensor, const std::vector& sizes) { diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index e8a97be0367..1ec73882573 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -397,34 +397,13 @@ inline TensorPtr make_tensor_ptr( /** * Creates a TensorPtr that manages a new Tensor with the same properties * as the given Tensor, but with a copy of the data owned by the returned - * TensorPtr. + * TensorPtr, or nullptr if the original data is null. * * @param tensor The Tensor to clone. * @return A new TensorPtr that manages a Tensor with the same properties as the * original but with copied data. */ -inline TensorPtr clone_tensor_ptr(const exec_aten::Tensor& tensor) { - return make_tensor_ptr(make_tensor_impl_ptr( - std::vector( - tensor.sizes().begin(), tensor.sizes().end()), - std::vector( - (uint8_t*)tensor.const_data_ptr(), - (uint8_t*)tensor.const_data_ptr() + tensor.nbytes()), -#ifndef USE_ATEN_LIB - std::vector( - tensor.dim_order().begin(), tensor.dim_order().end()), - std::vector( - tensor.strides().begin(), tensor.strides().end()), - tensor.scalar_type(), - tensor.shape_dynamism() -#else // USE_ATEN_LIB - {}, - std::vector( - tensor.strides().begin(), tensor.strides().end()), - tensor.scalar_type() -#endif // USE_ATEN_LIB - )); -} +TensorPtr clone_tensor_ptr(const exec_aten::Tensor& tensor); /** * Creates a new TensorPtr by cloning the given TensorPtr, copying the diff --git a/extension/tensor/test/tensor_ptr_test.cpp b/extension/tensor/test/tensor_ptr_test.cpp index 291d19e06b4..2473fc7ccd7 100644 --- a/extension/tensor/test/tensor_ptr_test.cpp +++ b/extension/tensor/test/tensor_ptr_test.cpp @@ -478,6 +478,17 @@ TEST_F(TensorPtrTest, CloneTensorPtrFromTensorPtrInt64) { EXPECT_EQ(cloned_tensor->scalar_type(), exec_aten::ScalarType::Long); } +TEST_F(TensorPtrTest, CloneTensorPtrFromTensorPtrNull) { + auto tensor = make_tensor_ptr({2, 2}, nullptr); + auto cloned_tensor = clone_tensor_ptr(tensor); + + EXPECT_EQ(cloned_tensor->dim(), tensor->dim()); + EXPECT_EQ(cloned_tensor->size(0), tensor->size(0)); + EXPECT_EQ(cloned_tensor->size(1), tensor->size(1)); + EXPECT_EQ(cloned_tensor->const_data_ptr(), tensor->const_data_ptr()); + EXPECT_EQ(cloned_tensor->const_data_ptr(), nullptr); +} + TEST_F(TensorPtrTest, TensorDataCastingFromIntToFloat) { std::vector int_data = {1, 2, 3, 4, 5, 6}; auto tensor = make_tensor_ptr(