diff --git a/extension/tensor/tensor_impl_ptr.h b/extension/tensor/tensor_impl_ptr.h index 3ccede79b1d..f336faf07b0 100644 --- a/extension/tensor/tensor_impl_ptr.h +++ b/extension/tensor/tensor_impl_ptr.h @@ -96,7 +96,7 @@ TensorImplPtr make_tensor_impl_ptr( exec_aten::TensorShapeDynamism::STATIC) { constexpr exec_aten::ScalarType scalar_type = runtime::CppTypeToScalarType::value; - auto raw_data_ptr = data.data(); + const auto raw_data_ptr = data.data(); auto data_ptr = std::make_shared>(std::move(data)); return make_tensor_impl_ptr( scalar_type, @@ -108,6 +108,40 @@ TensorImplPtr make_tensor_impl_ptr( [data_ptr = std::move(data_ptr)](void*) {}); } +/** + * Creates a TensorImplPtr that manages a newly created TensorImpl with the + * specified properties. + * + * This template overload is specialized for cases where the tensor data is + * provided as a vector. The scalar type is automatically deduced from the + * vector's data type. The deleter ensures that the data vector is properly + * managed and its lifetime is tied to the TensorImpl. + * + * @tparam T The C++ type of the tensor elements, deduced from the vector. + * @param data A vector containing the tensor's data. + * @param dynamism Specifies the mutability of the tensor's shape. + * @return A TensorImplPtr that manages the newly created TensorImpl. + */ +template +TensorImplPtr make_tensor_impl_ptr( + std::vector data, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::STATIC) { + constexpr exec_aten::ScalarType scalar_type = + runtime::CppTypeToScalarType::value; + std::vector sizes{exec_aten::SizesType(data.size())}; + const auto raw_data_ptr = data.data(); + auto data_ptr = std::make_shared>(std::move(data)); + return make_tensor_impl_ptr( + scalar_type, + std::move(sizes), + raw_data_ptr, + {0}, + {1}, + dynamism, + [data_ptr = std::move(data_ptr)](void*) {}); +} + /** * Creates a TensorImplPtr that manages a newly created TensorImpl with the * specified properties. diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index 18568876607..ef29d598b84 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -169,6 +169,27 @@ TensorPtr make_tensor_ptr( dynamism)); } +/** + * Creates a TensorPtr that manages a Tensor with the specified properties. + * + * This template overload is specialized for cases where the tensor data is + * provided as a vector. The scalar type is automatically deduced from the + * vector's data type. The deleter ensures that the data vector is properly + * managed and its lifetime is tied to the TensorImpl. + * + * @tparam T The C++ type of the tensor elements, deduced from the vector. + * @param data A vector containing the tensor's data. + * @param dynamism Specifies the mutability of the tensor's shape. + * @return A TensorPtr that manages the newly created TensorImpl. + */ +template +TensorPtr make_tensor_ptr( + std::vector data, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::STATIC) { + return make_tensor_ptr(make_tensor_impl_ptr(std::move(data), dynamism)); +} + /** * Creates a TensorPtr that manages a Tensor with the specified properties. * diff --git a/extension/tensor/test/tensor_impl_ptr_test.cpp b/extension/tensor/test/tensor_impl_ptr_test.cpp index 45d79f240af..09d55de3e8e 100644 --- a/extension/tensor/test/tensor_impl_ptr_test.cpp +++ b/extension/tensor/test/tensor_impl_ptr_test.cpp @@ -172,7 +172,7 @@ TEST_F(TensorImplPtrTest, TensorImplOwningData) { } TEST_F(TensorImplPtrTest, TensorImplOwningEmptyData) { - auto tensor_impl = make_tensor_impl_ptr({0, 5}, {}); + auto tensor_impl = make_tensor_impl_ptr({0, 5}, std::vector()); EXPECT_EQ(tensor_impl->dim(), 2); EXPECT_EQ(tensor_impl->size(0), 0); @@ -182,6 +182,74 @@ TEST_F(TensorImplPtrTest, TensorImplOwningEmptyData) { EXPECT_EQ(tensor_impl->data(), nullptr); } +TEST_F(TensorImplPtrTest, TensorImplDataOnlyDoubleType) { + std::vector data = {1.0, 2.0, 3.0, 4.0}; + auto tensor_impl = make_tensor_impl_ptr(std::move(data)); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 4); + EXPECT_EQ(tensor_impl->strides()[0], 1); + EXPECT_EQ(((double*)tensor_impl->data())[0], 1.0); + EXPECT_EQ(((double*)tensor_impl->data())[3], 4.0); +} + +TEST_F(TensorImplPtrTest, TensorImplDataOnlyInt32Type) { + std::vector data = {10, 20, 30, 40}; + auto tensor_impl = make_tensor_impl_ptr(std::move(data)); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 4); + EXPECT_EQ(tensor_impl->strides()[0], 1); + EXPECT_EQ(((int32_t*)tensor_impl->data())[0], 10); + EXPECT_EQ(((int32_t*)tensor_impl->data())[3], 40); +} + +TEST_F(TensorImplPtrTest, TensorImplDataOnlyInt64Type) { + std::vector data = {100, 200, 300, 400}; + auto tensor_impl = make_tensor_impl_ptr(std::move(data)); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 4); + EXPECT_EQ(tensor_impl->strides()[0], 1); + EXPECT_EQ(((int64_t*)tensor_impl->data())[0], 100); + EXPECT_EQ(((int64_t*)tensor_impl->data())[3], 400); +} + +TEST_F(TensorImplPtrTest, TensorImplDataOnlyUint8Type) { + std::vector data = {10, 20, 30, 40}; + auto tensor_impl = make_tensor_impl_ptr(std::move(data)); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 4); + EXPECT_EQ(tensor_impl->strides()[0], 1); + EXPECT_EQ(((uint8_t*)tensor_impl->data())[0], 10); + EXPECT_EQ(((uint8_t*)tensor_impl->data())[3], 40); +} + +TEST_F(TensorImplPtrTest, TensorImplAmbiguityWithMixedVectors) { + std::vector sizes = {2, 2}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + auto tensor_impl = make_tensor_impl_ptr(std::move(sizes), std::move(data)); + + EXPECT_EQ(tensor_impl->dim(), 2); + EXPECT_EQ(tensor_impl->size(0), 2); + EXPECT_EQ(tensor_impl->size(1), 2); + EXPECT_EQ(tensor_impl->strides()[0], 2); + EXPECT_EQ(tensor_impl->strides()[1], 1); + EXPECT_EQ(((float*)tensor_impl->data())[0], 1.0f); + EXPECT_EQ(((float*)tensor_impl->data())[3], 4.0f); + + auto tensor_impl2 = make_tensor_impl_ptr({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}); + + EXPECT_EQ(tensor_impl2->dim(), 2); + EXPECT_EQ(tensor_impl2->size(0), 2); + EXPECT_EQ(tensor_impl2->size(1), 2); + EXPECT_EQ(tensor_impl2->strides()[0], 2); + EXPECT_EQ(tensor_impl2->strides()[1], 1); + EXPECT_EQ(((float*)tensor_impl2->data())[0], 1.0f); + EXPECT_EQ(((float*)tensor_impl2->data())[3], 4.0f); +} + TEST_F(TensorImplPtrTest, SharedDataManagement) { auto data = std::make_shared>(100, 1.0f); auto tensor_impl1 = make_tensor_impl_ptr( diff --git a/extension/tensor/test/tensor_ptr_test.cpp b/extension/tensor/test/tensor_ptr_test.cpp index 1542824fb73..24aa20a8355 100644 --- a/extension/tensor/test/tensor_ptr_test.cpp +++ b/extension/tensor/test/tensor_ptr_test.cpp @@ -167,7 +167,7 @@ TEST_F(TensorPtrTest, TensorOwningData) { } TEST_F(TensorPtrTest, TensorOwningEmptyData) { - auto tensor = make_tensor_ptr({0, 5}, {}); + auto tensor = make_tensor_ptr({0, 5}, std::vector()); EXPECT_EQ(tensor->dim(), 2); EXPECT_EQ(tensor->size(0), 0); @@ -177,6 +177,74 @@ TEST_F(TensorPtrTest, TensorOwningEmptyData) { EXPECT_EQ(tensor->data_ptr(), nullptr); } +TEST_F(TensorPtrTest, TensorImplDataOnlyDoubleType) { + std::vector data = {1.0, 2.0, 3.0, 4.0}; + auto tensor = make_tensor_ptr(std::move(data)); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->strides()[0], 1); + EXPECT_EQ(tensor->const_data_ptr()[0], 1.0); + EXPECT_EQ(tensor->const_data_ptr()[3], 4.0); +} + +TEST_F(TensorPtrTest, TensorImplDataOnlyInt32Type) { + std::vector data = {10, 20, 30, 40}; + auto tensor = make_tensor_ptr(std::move(data)); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->strides()[0], 1); + EXPECT_EQ(tensor->const_data_ptr()[0], 10); + EXPECT_EQ(tensor->const_data_ptr()[3], 40); +} + +TEST_F(TensorPtrTest, TensorImplDataOnlyInt64Type) { + std::vector data = {100, 200, 300, 400}; + auto tensor = make_tensor_ptr(std::move(data)); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->strides()[0], 1); + EXPECT_EQ(tensor->const_data_ptr()[0], 100); + EXPECT_EQ(tensor->const_data_ptr()[3], 400); +} + +TEST_F(TensorPtrTest, TensorImplDataOnlyUint8Type) { + std::vector data = {10, 20, 30, 40}; + auto tensor = make_tensor_ptr(std::move(data)); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->strides()[0], 1); + EXPECT_EQ(tensor->const_data_ptr()[0], 10); + EXPECT_EQ(tensor->const_data_ptr()[3], 40); +} + +TEST_F(TensorPtrTest, TensorImplAmbiguityWithMixedVectors) { + std::vector sizes = {2, 2}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + auto tensor = make_tensor_ptr(std::move(sizes), std::move(data)); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 2); + EXPECT_EQ(tensor->strides()[0], 2); + EXPECT_EQ(tensor->strides()[1], 1); + EXPECT_EQ(tensor->const_data_ptr()[0], 1.0f); + EXPECT_EQ(tensor->const_data_ptr()[3], 4.0f); + + auto tensor2 = make_tensor_ptr({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}); + + EXPECT_EQ(tensor2->dim(), 2); + EXPECT_EQ(tensor2->size(0), 2); + EXPECT_EQ(tensor2->size(1), 2); + EXPECT_EQ(tensor2->strides()[0], 2); + EXPECT_EQ(tensor2->strides()[1], 1); + EXPECT_EQ(tensor2->const_data_ptr()[0], 1.0f); + EXPECT_EQ(tensor2->const_data_ptr()[3], 4.0f); +} + TEST_F(TensorPtrTest, TensorSharingImplModifiesSharedDataVector) { std::vector data = {1, 2, 3, 4, 5, 6};