From 07d5d12974f582020480fce7e0f5604765a27bf1 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 11 Sep 2025 06:28:01 -0700 Subject: [PATCH] Fix BFloat16 support for op_index (#14167) Summary: Previously, the portable op index didn't support BFloat16 when all the indices were null. This PR fixes this, and adds comprehensive testing for this edge case. Reviewed By: SS-JIA Differential Revision: D82134506 --- kernels/portable/cpu/op_index.cpp | 2 +- kernels/test/op_index_test.cpp | 48 ++++++++++++++++++++++++------- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/kernels/portable/cpu/op_index.cpp b/kernels/portable/cpu/op_index.cpp index e0ca951de85..8fbf903400a 100644 --- a/kernels/portable/cpu/op_index.cpp +++ b/kernels/portable/cpu/op_index.cpp @@ -213,7 +213,7 @@ Tensor& index_Tensor_out( if (block_count == 0) { ET_KERNEL_CHECK( ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() { const CTYPE* const in_data = in.const_data_ptr(); CTYPE* const out_data = out.mutable_data_ptr(); memcpy(out_data, in_data, in.nbytes()); diff --git a/kernels/test/op_index_test.cpp b/kernels/test/op_index_test.cpp index 9f1f8e3e9f7..f3e1d9081c0 100644 --- a/kernels/test/op_index_test.cpp +++ b/kernels/test/op_index_test.cpp @@ -109,6 +109,36 @@ class OpIndexTensorOutTest : public OperatorTest { ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY + } + + template + void test_indices_with_only_null_tensors_supported() { + TensorFactory tf; + + Tensor x = tf.make({2, 3}, {1, 2, 3, 4, 5, 6}); + Tensor out = tf.zeros({2, 3}); + + std::array, 1> indices1 = {optional()}; + op_index_tensor_out(x, indices1, out); + EXPECT_TENSOR_EQ(out, x); + + out = tf.zeros({2, 3}); + std::array, 2> indices2 = { + optional(), std::optional()}; + op_index_tensor_out(x, indices2, out); + EXPECT_TENSOR_EQ(out, x); + } + + /** + * Test indices with only null tensors for all input data types + */ + void test_indices_with_only_null_tensors_enumerate_in_types() { +#define TEST_ENTRY(ctype, dtype) \ + test_indices_with_only_null_tensors_supported(); + + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); + #undef TEST_ENTRY } @@ -405,21 +435,19 @@ TEST_F(OpIndexTensorOutTest, IndicesWithOnlyNullTensorsSupported) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { GTEST_SKIP() << "ATen kernel test fails"; } - TensorFactory tf; + test_indices_with_only_null_tensors_enumerate_in_types(); +} +TEST_F(OpIndexTensorOutTest, TooManyNullIndices) { + TensorFactory tf; Tensor x = tf.make({2, 3}, {1., 2., 3., 4., 5., 6.}); - std::array, 1> indices0 = {optional()}; - run_test_cases(x, indices0, x); - - std::array, 2> indices1 = { - optional(), std::optional()}; - run_test_cases(x, indices1, x); - - std::array, 3> indices2 = { + std::array, 3> indices = { optional(), std::optional(), std::optional()}; Tensor out = tf.ones({2, 3}); ET_EXPECT_KERNEL_FAILURE_WITH_MSG( - context_, op_index_tensor_out(x, indices2, out), ""); + context_, + op_index_tensor_out(x, indices, out), + "Indexing too many dimensions"); } TEST_F(OpIndexTensorOutTest, EmptyIndicesSupported) {