diff --git a/extension/tensor/tensor_ptr.cpp b/extension/tensor/tensor_ptr.cpp index 8a35e83a526..08ba6d70a8d 100644 --- a/extension/tensor/tensor_ptr.cpp +++ b/extension/tensor/tensor_ptr.cpp @@ -80,15 +80,27 @@ TensorPtr make_tensor_ptr( } } std::vector computed_strides(dim); + auto error = runtime::dim_order_to_stride( sizes.data(), dim_order.data(), dim, computed_strides.data()); ET_CHECK_MSG(error == runtime::Error::Ok, "Failed to compute strides."); if (!strides.empty()) { - ET_CHECK_MSG(computed_strides == strides, "Invalid strides provided."); - } else { - strides = std::move(computed_strides); + for (size_t i = 0; i < dim; i++) { + ET_CHECK_MSG( + strides[i] == computed_strides[i] || sizes[i] == 1, + "invalid strides for dim %zu: %" ET_PRI_SIZES_AND_STRIDES + "!= %" ET_PRI_SIZES_AND_STRIDES + " while its size is %" ET_PRI_SIZES_AND_STRIDES " != 1", + i, + strides[i], + computed_strides[i], + sizes[i]); + } } + + strides = std::move(computed_strides); + #ifndef USE_ATEN_LIB executorch::aten::TensorImpl tensor_impl( type, diff --git a/extension/tensor/test/tensor_ptr_maker_test.cpp b/extension/tensor/test/tensor_ptr_maker_test.cpp index e17d18229df..5988ecd3a04 100644 --- a/extension/tensor/test/tensor_ptr_maker_test.cpp +++ b/extension/tensor/test/tensor_ptr_maker_test.cpp @@ -11,6 +11,7 @@ #include #include +#include using namespace ::executorch::extension; using namespace ::executorch::runtime; @@ -113,6 +114,31 @@ TEST_F(TensorPtrMakerTest, CreateTensorUsingFromBlobWithStrides) { EXPECT_EQ(tensor->const_data_ptr()[0], 3); } +TEST_F(TensorPtrMakerTest, CreateTensorUsingFromBlobWithLegalStrides) { + float data[20] = {3}; + auto tensor = from_blob(data, {1, 2, 2}, {10, 2, 1}); + + EXPECT_EQ(tensor->dim(), 3); + EXPECT_EQ(tensor->size(0), 1); + EXPECT_EQ(tensor->size(1), 2); + EXPECT_EQ(tensor->size(2), 2); + + // recalculated stride[0]t o 2 to meet ET's requirement while maintain the + // same behavior as original tensor since size[0] == 1 + EXPECT_EQ(tensor->strides()[0], 4); + EXPECT_EQ(tensor->strides()[1], 2); + EXPECT_EQ(tensor->strides()[2], 1); + EXPECT_EQ(tensor->const_data_ptr(), data); + EXPECT_EQ(tensor->const_data_ptr()[0], 3); +} + +TEST_F(TensorPtrMakerTest, FailedCreateTensorUsingFromBlobWithIllegalStrides) { + float data[20] = {3}; + ET_EXPECT_DEATH( + from_blob(data, {2, 2, 2}, {10, 2, 1}), + "invalid strides for dim 0: 10!= 4 while its size is 2 != 1"); +} + TEST_F(TensorPtrMakerTest, TensorMakerConversionOperator) { float data[20] = {2}; TensorPtr tensor =