diff --git a/extension/tensor/test/tensor_impl_ptr_test.cpp b/extension/tensor/test/tensor_impl_ptr_test.cpp index 1330dfa60f5..f7fd062c462 100644 --- a/extension/tensor/test/tensor_impl_ptr_test.cpp +++ b/extension/tensor/test/tensor_impl_ptr_test.cpp @@ -23,6 +23,29 @@ class TensorImplPtrTest : public ::testing::Test { } }; +TEST_F(TensorImplPtrTest, ScalarTensorCreation) { + float scalar_data = 3.14f; + auto tensor_impl = + make_tensor_impl_ptr(exec_aten::ScalarType::Float, {}, &scalar_data); + + EXPECT_EQ(tensor_impl->numel(), 1); + EXPECT_EQ(tensor_impl->dim(), 0); + EXPECT_EQ(tensor_impl->sizes().size(), 0); + EXPECT_EQ(tensor_impl->strides().size(), 0); + EXPECT_EQ((float*)tensor_impl->data(), &scalar_data); + EXPECT_EQ(((float*)tensor_impl->data())[0], 3.14f); +} + +TEST_F(TensorImplPtrTest, ScalarTensorOwningData) { + auto tensor_impl = make_tensor_impl_ptr({}, {3.14f}); + + EXPECT_EQ(tensor_impl->numel(), 1); + EXPECT_EQ(tensor_impl->dim(), 0); + EXPECT_EQ(tensor_impl->sizes().size(), 0); + EXPECT_EQ(tensor_impl->strides().size(), 0); + EXPECT_EQ(((float*)tensor_impl->data())[0], 3.14f); +} + TEST_F(TensorImplPtrTest, TensorImplCreation) { float data[20] = {2}; auto tensor_impl = make_tensor_impl_ptr( @@ -34,8 +57,8 @@ TEST_F(TensorImplPtrTest, TensorImplCreation) { EXPECT_EQ(tensor_impl->strides()[0], 5); EXPECT_EQ(tensor_impl->strides()[1], 1); EXPECT_EQ(tensor_impl->data(), data); - EXPECT_EQ(tensor_impl->mutable_data(), data); - EXPECT_EQ(((float*)tensor_impl->mutable_data())[0], 2); + EXPECT_EQ(tensor_impl->data(), data); + EXPECT_EQ(((float*)tensor_impl->data())[0], 2); } TEST_F(TensorImplPtrTest, TensorImplSharedOwnership) { diff --git a/extension/tensor/test/tensor_ptr_test.cpp b/extension/tensor/test/tensor_ptr_test.cpp index 3f5e7ff58e2..d5582630494 100644 --- a/extension/tensor/test/tensor_ptr_test.cpp +++ b/extension/tensor/test/tensor_ptr_test.cpp @@ -22,6 +22,28 @@ class TensorPtrTest : public ::testing::Test { } }; +TEST_F(TensorPtrTest, ScalarTensorCreation) { + float scalar_data = 3.14f; + auto tensor = make_tensor_ptr(exec_aten::ScalarType::Float, {}, &scalar_data); + + EXPECT_EQ(tensor->numel(), 1); + EXPECT_EQ(tensor->dim(), 0); + EXPECT_EQ(tensor->sizes().size(), 0); + EXPECT_EQ(tensor->strides().size(), 0); + EXPECT_EQ(tensor->const_data_ptr(), &scalar_data); + EXPECT_EQ(tensor->const_data_ptr()[0], 3.14f); +} + +TEST_F(TensorPtrTest, ScalarTensorOwningData) { + auto tensor = make_tensor_ptr({}, {3.14f}); + + EXPECT_EQ(tensor->numel(), 1); + EXPECT_EQ(tensor->dim(), 0); + EXPECT_EQ(tensor->sizes().size(), 0); + EXPECT_EQ(tensor->strides().size(), 0); + EXPECT_EQ(tensor->const_data_ptr()[0], 3.14f); +} + TEST_F(TensorPtrTest, CreateTensorWithStridesAndDimOrder) { float data[20] = {2}; auto tensor = make_tensor_ptr(