From 71e0873408dba9118acfe9d75136b0614b9a3d75 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 3 Sep 2025 11:00:45 -0700 Subject: [PATCH] Fix reduction over dim list for empty input (#13833) Summary: Allow empty input in `MapReduceOverDimListPlan` Avoid integer division by zero in `parallel_for_each_reduce_over_dim_list_output_index` Add empty input tests for ops: any, mean, sum & var Differential Revision: D81383049 --- kernels/portable/cpu/op_any.cpp | 2 +- kernels/portable/cpu/op_mean.cpp | 9 +++++--- kernels/portable/cpu/op_var.cpp | 2 +- kernels/portable/cpu/util/reduce_util.h | 13 ++++++++---- kernels/test/op_any_test.cpp | 28 +++++++++++++++++++++++++ kernels/test/op_mean_test.cpp | 27 ++++++++++++++++++++++++ kernels/test/op_sum_test.cpp | 27 ++++++++++++++++++++++++ kernels/test/op_var_test.cpp | 27 ++++++++++++++++++++++++ 8 files changed, 126 insertions(+), 9 deletions(-) diff --git a/kernels/portable/cpu/op_any.cpp b/kernels/portable/cpu/op_any.cpp index 8be0993767d..0f3a36b6ba7 100644 --- a/kernels/portable/cpu/op_any.cpp +++ b/kernels/portable/cpu/op_any.cpp @@ -105,7 +105,7 @@ Tensor& any_dims_out( in, dim_list, out, [&](const auto begin, const auto end) { for (const auto out_ix : c10::irange(begin, end)) { bool any = false; - if (in_not_empty) { + if (plan.has_value()) { any = plan->execute( [](CTYPE_IN v) { return static_cast(v); }, [](bool outv, bool acc) { return acc || outv; }, diff --git a/kernels/portable/cpu/op_mean.cpp b/kernels/portable/cpu/op_mean.cpp index 738fa98c9eb..63c78968751 100644 --- a/kernels/portable/cpu/op_mean.cpp +++ b/kernels/portable/cpu/op_mean.cpp @@ -45,7 +45,10 @@ Tensor& mean_dim_out( InvalidArgument, out); - MapReduceOverDimListPlan plan(in, dim_list); + std::optional plan; + if (in.numel() > 0) { + plan.emplace(in, dim_list); + } // @lint-ignore CLANGTIDY facebook-hte-CArray static constexpr const char op_name[] = "mean.out"; ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] { @@ -56,8 +59,8 @@ Tensor& mean_dim_out( in, dim_list, out, [&](const auto begin, const auto end) { for (const auto out_ix : c10::irange(begin, end)) { CTYPE_OUT sum = 0; - if (in.numel() > 0) { - sum = plan.execute( + if (plan.has_value()) { + sum = plan->execute( [](CTYPE_IN v) { return static_cast(v); }, [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; }, out_ix); diff --git a/kernels/portable/cpu/op_var.cpp b/kernels/portable/cpu/op_var.cpp index a95b3a9f167..fcaa79a54fe 100644 --- a/kernels/portable/cpu/op_var.cpp +++ b/kernels/portable/cpu/op_var.cpp @@ -32,7 +32,7 @@ void compute_variance( for (const auto out_ix : c10::irange(out.numel())) { out_data[out_ix] = NAN; } - } else { + } else if (in.numel() > 0) { MapReduceOverDimListPlan plan(in, dim_list); const bool success = parallel_for_each_reduce_over_dim_list_output_index( in, dim_list, out, [&](const auto begin, const auto end) { diff --git a/kernels/portable/cpu/util/reduce_util.h b/kernels/portable/cpu/util/reduce_util.h index 7d24ae7bda2..51981328c4f 100644 --- a/kernels/portable/cpu/util/reduce_util.h +++ b/kernels/portable/cpu/util/reduce_util.h @@ -543,6 +543,9 @@ class MapReduceOverDimListPlan { const MapOp& map_fun, const ReduceOp& reduce_fun, const size_t out_ix) const { + ET_CHECK_MSG( + plan_.get_input_tensor().numel() > 0, "Input tensor must be nonempty"); + const size_t init_index = get_init_index(plan_.get_input_tensor(), plan_.get_dim_list(), out_ix); @@ -834,10 +837,12 @@ template const Func& func) { #ifdef ET_USE_THREADPOOL const ssize_t reduction_size = get_reduced_dim_product(in, dim_list); - const auto grain_size = std::max( - static_cast(1), - static_cast(executorch::extension::internal::GRAIN_SIZE) / - reduction_size); + const auto grain_size = reduction_size == 0 + ? 1 + : std::max( + static_cast(1), + static_cast(executorch::extension::internal::GRAIN_SIZE) / + reduction_size); #else // ET_USE_THREADPOOL const auto grain_size = 1; #endif // ET_USE_THREADPOOL diff --git a/kernels/test/op_any_test.cpp b/kernels/test/op_any_test.cpp index fc815ea8508..7261be1b822 100644 --- a/kernels/test/op_any_test.cpp +++ b/kernels/test/op_any_test.cpp @@ -148,3 +148,31 @@ TEST_F(OpAnyOutTest, SmokeTest) { op_any_out(self, dim, keepdim, out); EXPECT_TENSOR_CLOSE(out, out_expected); } + +TEST_F(OpAnyOutTest, EmptyInput) { + TensorFactory tf; + TensorFactory tfBool; + + Tensor x = tf.make({2, 0, 3}, {}); + optional> dim_list = ArrayRef{}; + Tensor out = tfBool.make({2, 0, 3}, {}); + + op_any_dims_out(x, dim_list, /*keepdim=*/true, out); + EXPECT_TENSOR_CLOSE(out, tfBool.zeros({2, 0, 3})); + + out = tfBool.ones({2, 0, 3}); + op_any_dims_out(x, dim_list, /*keepdim=*/false, out); + EXPECT_TENSOR_CLOSE(out, tfBool.zeros({2, 0, 3})); + + int64_t dims1[1] = {1}; + dim_list = ArrayRef{dims1, 1}; + out = tfBool.ones({2, 3}); + op_any_dims_out(x, dim_list, /*keepdim=*/false, out); + EXPECT_TENSOR_CLOSE(out, tfBool.zeros({2, 3})); + + int64_t dims2[1] = {2}; + dim_list = ArrayRef{dims2, 1}; + out = tfBool.make({2, 0, 1}, {}); + op_any_dims_out(x, dim_list, /*keepdim=*/true, out); + EXPECT_TENSOR_CLOSE(out, tfBool.make({2, 0, 1}, {})); +} diff --git a/kernels/test/op_mean_test.cpp b/kernels/test/op_mean_test.cpp index 47702be82cb..65d21b45518 100644 --- a/kernels/test/op_mean_test.cpp +++ b/kernels/test/op_mean_test.cpp @@ -551,3 +551,30 @@ TEST_F(OpMeanOutTest, DTypeOutFloatNAN) { Tensor ret = op_mean_dtype_out(x, ScalarType::Float, out); EXPECT_TENSOR_CLOSE(out, expected_result); } + +TEST_F(OpMeanOutTest, EmptyInput) { + TensorFactory tf; + + Tensor x = tf.make({2, 0, 3}, {}); + optional dtype = ScalarType::Float; + optional> dim_list = ArrayRef{}; + Tensor out = tf.zeros({1, 1, 1}); + op_mean_out(x, dim_list, /*keepdim=*/true, dtype, out); + EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 1}, {NAN})); + + out = tf.zeros({}); + op_mean_out(x, dim_list, /*keepdim=*/false, dtype, out); + EXPECT_TENSOR_CLOSE(out, tf.make({}, {NAN})); + + int64_t dims1[1] = {1}; + dim_list = ArrayRef{dims1, 1}; + out = tf.zeros({2, 3}); + op_mean_out(x, dim_list, /*keepdim=*/false, dtype, out); + EXPECT_TENSOR_CLOSE(out, tf.make({2, 3}, {NAN, NAN, NAN, NAN, NAN, NAN})); + + int64_t dims2[1] = {2}; + dim_list = ArrayRef{dims2, 1}; + out = tf.make({2, 0, 1}, {}); + op_mean_out(x, dim_list, /*keepdim=*/true, dtype, out); + EXPECT_TENSOR_CLOSE(out, tf.make({2, 0, 1}, {})); +} diff --git a/kernels/test/op_sum_test.cpp b/kernels/test/op_sum_test.cpp index 748e5427b1d..58624c2a110 100644 --- a/kernels/test/op_sum_test.cpp +++ b/kernels/test/op_sum_test.cpp @@ -490,3 +490,30 @@ TEST_F(OpSumOutTest, InfinityAndNANTest) { })); // clang-format on } + +TEST_F(OpSumOutTest, EmptyInput) { + TensorFactory tf; + + Tensor x = tf.make({2, 0, 3}, {}); + optional dtype = ScalarType::Float; + optional> dim_list = ArrayRef{}; + Tensor out = tf.ones({1, 1, 1}); + op_sum_intlist_out(x, dim_list, /*keepdim=*/true, dtype, out); + EXPECT_TENSOR_CLOSE(out, tf.zeros({1, 1, 1})); + + out = tf.ones({}); + op_sum_intlist_out(x, dim_list, /*keepdim=*/false, dtype, out); + EXPECT_TENSOR_CLOSE(out, tf.zeros({})); + + int64_t dims1[1] = {1}; + dim_list = ArrayRef{dims1, 1}; + out = tf.ones({2, 3}); + op_sum_intlist_out(x, dim_list, /*keepdim=*/false, dtype, out); + EXPECT_TENSOR_CLOSE(out, tf.zeros({2, 3})); + + int64_t dims2[1] = {2}; + dim_list = ArrayRef{dims2, 1}; + out = tf.make({2, 0, 1}, {}); + op_sum_intlist_out(x, dim_list, /*keepdim=*/true, dtype, out); + EXPECT_TENSOR_CLOSE(out, tf.make({2, 0, 1}, {})); +} diff --git a/kernels/test/op_var_test.cpp b/kernels/test/op_var_test.cpp index f2bd3acccf3..bfa73bfe15c 100644 --- a/kernels/test/op_var_test.cpp +++ b/kernels/test/op_var_test.cpp @@ -468,3 +468,30 @@ TEST_F(OpVarCorrectionOutTest, SmokeTest) { ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } + +TEST_F(OpVarOutTest, EmptyInput) { + TensorFactory tf; + + Tensor x = tf.make({2, 0, 3}, {}); + bool unbiased = true; + optional> dim_list = ArrayRef{}; + Tensor out = tf.zeros({1, 1, 1}); + op_var_out(x, dim_list, unbiased, /*keepdim=*/true, out); + EXPECT_TENSOR_CLOSE(out, tf.make({1, 1, 1}, {NAN})); + + out = tf.zeros({}); + op_var_out(x, dim_list, unbiased, /*keepdim=*/false, out); + EXPECT_TENSOR_CLOSE(out, tf.make({}, {NAN})); + + int64_t dims1[1] = {1}; + dim_list = ArrayRef{dims1, 1}; + out = tf.zeros({2, 3}); + op_var_out(x, dim_list, unbiased, /*keepdim=*/false, out); + EXPECT_TENSOR_CLOSE(out, tf.make({2, 3}, {NAN, NAN, NAN, NAN, NAN, NAN})); + + int64_t dims2[1] = {2}; + dim_list = ArrayRef{dims2, 1}; + out = tf.make({2, 0, 1}, {}); + op_var_out(x, dim_list, unbiased, /*keepdim=*/true, out); + EXPECT_TENSOR_CLOSE(out, tf.make({2, 0, 1}, {})); +}