diff --git a/kernels/portable/cpu/op_gather.cpp b/kernels/portable/cpu/op_gather.cpp index b7d257ae3d9..b221b450752 100644 --- a/kernels/portable/cpu/op_gather.cpp +++ b/kernels/portable/cpu/op_gather.cpp @@ -86,7 +86,7 @@ Tensor& gather_out( constexpr auto name = "gather.out"; - ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { gather_helper(in, index, out, dim); }); diff --git a/kernels/test/CMakeLists.txt b/kernels/test/CMakeLists.txt index f1f51e73fbb..4b4607e4bff 100644 --- a/kernels/test/CMakeLists.txt +++ b/kernels/test/CMakeLists.txt @@ -139,6 +139,7 @@ set(all_test_sources "op_fmod_test.cpp" "op_full_like_test.cpp" "op_full_test.cpp" + "op_gather_test.cpp" "op_ge_test.cpp" "op_gelu_test.cpp" "op_glu_test.cpp" diff --git a/kernels/test/op_gather_test.cpp b/kernels/test/op_gather_test.cpp index 9d637560eda..24d3b740d20 100644 --- a/kernels/test/op_gather_test.cpp +++ b/kernels/test/op_gather_test.cpp @@ -194,7 +194,7 @@ class OpGatherOutTest : public OperatorTest { TEST_F(OpGatherOutTest, AllValidInputOutputSupport) { #define TEST_ENTRY(CTYPE, DTYPE) test_gather_out(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY }