From 89e01d132dba4607cbbfd8afd5dd5a1bfc04cf13 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 21 Jan 2025 13:45:24 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- kernels/portable/cpu/op_fill.cpp | 16 ++++++++-------- kernels/test/op_fill_test.cpp | 4 ++-- runtime/core/exec_aten/util/tensor_util.h | 9 ++++++--- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/kernels/portable/cpu/op_fill.cpp b/kernels/portable/cpu/op_fill.cpp index 50ef1359612..3992d308818 100644 --- a/kernels/portable/cpu/op_fill.cpp +++ b/kernels/portable/cpu/op_fill.cpp @@ -42,7 +42,7 @@ Tensor& fill_scalar_out( out, "Failed to resize output tensor."); - ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Scalar_out", CTYPE_A, [&] { CTYPE_A b_casted; ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "fill.Scalar_out", CTYPE_B, [&] { CTYPE_B b_val; @@ -87,14 +87,14 @@ Tensor& fill_tensor_out( out, "Failed to resize output tensor."); - ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "fill.Tensor_out", CTYPE_A, [&] { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "fill.Tensor_out", CTYPE_A, [&] { CTYPE_A b_casted; - ET_SWITCH_REAL_TYPES_AND( - Bool, b_type, ctx, "fill.Tensor_out", CTYPE_B, [&] { - CTYPE_B b_val; - extract_scalar_tensor(b, &b_val); - b_casted = static_cast(b_val); - }); + ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "fill.Tensor_out", CTYPE_B, [&] { + CTYPE_B b_val; + ET_DCHECK_MSG( + extract_scalar_tensor(b, &b_val), "extract_scalar_tensor failed!"); + b_casted = static_cast(b_val); + }); apply_unary_map_fn( [b_casted](const CTYPE_A val_a) { return b_casted; }, diff --git a/kernels/test/op_fill_test.cpp b/kernels/test/op_fill_test.cpp index 0f6be0ecc61..a16cbed66cc 100644 --- a/kernels/test/op_fill_test.cpp +++ b/kernels/test/op_fill_test.cpp @@ -92,7 +92,7 @@ class OpFillTest : public OperatorTest { TEST_FILL_OUT(test_fill_scalar_out, DTYPE); \ } -ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_SCALAR_INPUT_SUPPORT_TEST) +ET_FORALL_REALHBBF16_TYPES(GENERATE_SCALAR_INPUT_SUPPORT_TEST) // Create input support tests for tensor variant. #define GENERATE_TENSOR_INPUT_SUPPORT_TEST(_, DTYPE) \ @@ -100,7 +100,7 @@ ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_SCALAR_INPUT_SUPPORT_TEST) TEST_FILL_OUT(test_fill_tensor_out, DTYPE); \ } -ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_TENSOR_INPUT_SUPPORT_TEST) +ET_FORALL_REALHBBF16_TYPES(GENERATE_TENSOR_INPUT_SUPPORT_TEST) TEST_F(OpFillTest, MismatchedOtherPropertiesDies) { TensorFactory tf; diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h index eb57f3e099c..6fdc1bc2936 100644 --- a/runtime/core/exec_aten/util/tensor_util.h +++ b/runtime/core/exec_aten/util/tensor_util.h @@ -1062,8 +1062,11 @@ bool extract_scalar_tensor(executorch::aten::Tensor tensor, INT_T* out_val) { */ template < typename FLOAT_T, - typename std::enable_if::value, bool>:: - type = true> + typename std::enable_if< + std::is_floating_point_v || + std::is_same_v || + std::is_same_v, + bool>::type = true> bool extract_scalar_tensor(executorch::aten::Tensor tensor, FLOAT_T* out_val) { if (tensor.numel() != 1) { return false; @@ -1083,7 +1086,7 @@ bool extract_scalar_tensor(executorch::aten::Tensor tensor, FLOAT_T* out_val) { } switch (tensor.scalar_type()) { - ET_FORALL_REAL_TYPES(CASE_REAL_DTYPE); + ET_FORALL_REALHBF16_TYPES(CASE_REAL_DTYPE); default: return false; }