From b31fe11c85a4a159b0991b0f9db5ed169b635540 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Mon, 26 Aug 2024 22:39:41 -0700 Subject: [PATCH] Handle null data edge case in data_is_close testing util. (#4901) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/4901 Reviewed By: kirklandsign Differential Revision: D61783890 --- .../exec_aten/testing_util/tensor_util.cpp | 9 +++++++++ .../testing_util/test/tensor_util_test.cpp | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/runtime/core/exec_aten/testing_util/tensor_util.cpp b/runtime/core/exec_aten/testing_util/tensor_util.cpp index f0340d34ca2..03dffd208f0 100644 --- a/runtime/core/exec_aten/testing_util/tensor_util.cpp +++ b/runtime/core/exec_aten/testing_util/tensor_util.cpp @@ -41,6 +41,15 @@ bool data_is_close( size_t numel, double rtol, double atol) { + ET_CHECK_MSG( + numel == 0 || (a != nullptr && b != nullptr), + "Pointers must not be null when numel > 0: numel %zu, a 0x%p, b 0x%p", + numel, + a, + b); + if (a == b) { + return true; + } for (size_t i = 0; i < numel; i++) { const auto ai = a[i]; const auto bi = b[i]; diff --git a/runtime/core/exec_aten/testing_util/test/tensor_util_test.cpp b/runtime/core/exec_aten/testing_util/test/tensor_util_test.cpp index 6d4ce5a8532..948f6bc78f0 100644 --- a/runtime/core/exec_aten/testing_util/test/tensor_util_test.cpp +++ b/runtime/core/exec_aten/testing_util/test/tensor_util_test.cpp @@ -23,6 +23,7 @@ using namespace ::testing; using exec_aten::ScalarType; using exec_aten::Tensor; +using exec_aten::TensorImpl; using exec_aten::TensorList; using executorch::runtime::testing::IsCloseTo; using executorch::runtime::testing::IsDataCloseTo; @@ -826,4 +827,22 @@ TEST(TensorUtilTest, TensorStreamBool) { "ETensor(sizes={2, 2}, dtype=Bool, data={1, 0, 1, 0})"); } +TEST(TensorTest, TestZeroShapeTensorEquality) { + TensorImpl::SizesType sizes[2] = {2, 2}; + TensorImpl::StridesType strides[2] = {2, 1}; + TensorImpl::DimOrderType dim_order[2] = {0, 1}; + + TensorImpl t1(ScalarType::Float, 2, sizes, nullptr, dim_order, strides); + TensorImpl t2(ScalarType::Float, 2, sizes, nullptr, dim_order, strides); + + ET_EXPECT_DEATH({ EXPECT_TENSOR_EQ(Tensor(&t1), Tensor(&t2)); }, ""); + + float data[] = {1.0, 2.0, 3.0, 4.0}; + + t1.set_data(data); + t2.set_data(data); + + EXPECT_TENSOR_EQ(Tensor(&t1), Tensor(&t2)); +} + #endif // !USE_ATEN_LIB