From bc128e2f50d8008aff0c1acf3aa8e1756e3d0a55 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 19 Sep 2024 15:46:52 -0700 Subject: [PATCH] [ExecuTorch] Support bfloat16 in op_index_put Differential Revision: [D63057744](https://our.internmc.facebook.com/intern/diff/D63057744/) [ghstack-poisoned] --- kernels/portable/cpu/op_index_put.cpp | 4 ++-- kernels/test/op_index_put_test.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kernels/portable/cpu/op_index_put.cpp b/kernels/portable/cpu/op_index_put.cpp index e44d50f606c..33e67d207a9 100644 --- a/kernels/portable/cpu/op_index_put.cpp +++ b/kernels/portable/cpu/op_index_put.cpp @@ -53,7 +53,7 @@ Tensor& index_put_out( ET_KERNEL_CHECK( ctx, tensor_is_broadcastable_to(values, out), InvalidArgument, out); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() { apply_binary_elementwise_fn( [accumulate](const CTYPE val_in, const CTYPE val) { return accumulate ? val_in + val : val; @@ -120,7 +120,7 @@ Tensor& index_put_out( x_numel *= x_sizes[i]; } - ET_SWITCH_REALHB_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index_put.out", CTYPE, [&]() { const CTYPE* const values_data = values.const_data_ptr(); CTYPE* const out_data = out.mutable_data_ptr(); diff --git a/kernels/test/op_index_put_test.cpp b/kernels/test/op_index_put_test.cpp index b685edc6aaf..868c11600f4 100644 --- a/kernels/test/op_index_put_test.cpp +++ b/kernels/test/op_index_put_test.cpp @@ -707,7 +707,7 @@ TEST_F(OpIndexPutOutTest, AllDtypesSupportedForInput) { #define TEST_ENTRY(ctype, dtype) \ test_dtype(); - ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY }