From 1b2b3dc1970370c3797cb282bf2bebfb2566ba8c Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 19 Sep 2024 14:47:22 -0700 Subject: [PATCH] [ExecuTorch] Support bfloat16 in op_index Seems to block bfloat16 stories110M as exported by torchchat (and we should have op coverage for bfloat16 anyway). Differential Revision: [D63054001](https://our.internmc.facebook.com/intern/diff/D63054001/) [ghstack-poisoned] --- kernels/portable/cpu/op_index.cpp | 2 +- kernels/test/op_index_test.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kernels/portable/cpu/op_index.cpp b/kernels/portable/cpu/op_index.cpp index 780994cb75d..98f76a9e352 100644 --- a/kernels/portable/cpu/op_index.cpp +++ b/kernels/portable/cpu/op_index.cpp @@ -89,7 +89,7 @@ Tensor& index_Tensor_out( compute_dim_map(in, indices, dim_map, block_count == 1); compute_index_map(in, indices, ix_map); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() { const CTYPE* const in_data = in.const_data_ptr(); CTYPE* const out_data = out.mutable_data_ptr(); diff --git a/kernels/test/op_index_test.cpp b/kernels/test/op_index_test.cpp index 03a91005e83..35bd6a28da5 100644 --- a/kernels/test/op_index_test.cpp +++ b/kernels/test/op_index_test.cpp @@ -107,7 +107,7 @@ class OpIndexTensorOutTest : public OperatorTest { #define TEST_ENTRY(ctype, dtype) \ test_dtype(); - ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY }