diff --git a/extension/tensor/tensor_ptr.cpp b/extension/tensor/tensor_ptr.cpp index 08ba6d70a8d..dab1a8ab176 100644 --- a/extension/tensor/tensor_ptr.cpp +++ b/extension/tensor/tensor_ptr.cpp @@ -148,10 +148,10 @@ TensorPtr make_tensor_ptr( executorch::aten::ScalarType type, executorch::aten::TensorShapeDynamism dynamism) { ET_CHECK_MSG( - data.size() >= + data.size() == executorch::aten::compute_numel(sizes.data(), sizes.size()) * executorch::aten::elementSize(type), - "Data size is smaller than required by sizes and scalar type."); + "Data size does not match tensor size."); auto data_ptr = data.data(); return make_tensor_ptr( std::move(sizes), diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index 59690de9f26..4753ec296da 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -106,6 +106,10 @@ inline TensorPtr make_tensor_ptr( executorch::aten::ScalarType type = deduced_type, executorch::aten::TensorShapeDynamism dynamism = executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND) { + ET_CHECK_MSG( + data.size() == + executorch::aten::compute_numel(sizes.data(), sizes.size()), + "Data size does not match tensor size."); if (type != deduced_type) { ET_CHECK_MSG( runtime::canCast(deduced_type, type), diff --git a/extension/tensor/test/tensor_ptr_test.cpp b/extension/tensor/test/tensor_ptr_test.cpp index 99c4f1b0d1a..6c98db52d41 100644 --- a/extension/tensor/test/tensor_ptr_test.cpp +++ b/extension/tensor/test/tensor_ptr_test.cpp @@ -784,16 +784,30 @@ TEST_F(TensorPtrTest, TensorUint8BufferTooSmallExpectDeath) { { auto tensor = make_tensor_ptr({2, 2}, std::move(data)); }, ""); } -TEST_F(TensorPtrTest, TensorUint8BufferTooLarge) { +TEST_F(TensorPtrTest, TensorUint8BufferTooLargeExpectDeath) { std::vector data( - 4 * executorch::aten::elementSize(executorch::aten::ScalarType::Float)); - auto tensor = make_tensor_ptr({2, 2}, std::move(data)); + 5 * executorch::aten::elementSize(executorch::aten::ScalarType::Float)); + ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 2}, std::move(data)); }, ""); +} - EXPECT_EQ(tensor->dim(), 2); - EXPECT_EQ(tensor->size(0), 2); - EXPECT_EQ(tensor->size(1), 2); - EXPECT_EQ(tensor->strides()[0], 2); - EXPECT_EQ(tensor->strides()[1], 1); +TEST_F(TensorPtrTest, VectorFloatTooSmallExpectDeath) { + std::vector data(9, 1.f); + ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 5}, std::move(data)); }, ""); +} + +TEST_F(TensorPtrTest, VectorFloatTooLargeExpectDeath) { + std::vector data(11, 1.f); + ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 5}, std::move(data)); }, ""); +} + +TEST_F(TensorPtrTest, VectorIntToFloatCastTooSmallExpectDeath) { + std::vector data(9, 1); + ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 5}, std::move(data)); }, ""); +} + +TEST_F(TensorPtrTest, VectorIntToFloatCastTooLargeExpectDeath) { + std::vector data(11, 1); + ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 5}, std::move(data)); }, ""); } TEST_F(TensorPtrTest, StridesAndDimOrderMustMatchSizes) {