diff --git a/kernels/portable/cpu/op_masked_fill.cpp b/kernels/portable/cpu/op_masked_fill.cpp index 643d5293ed4..b3192b95c2f 100644 --- a/kernels/portable/cpu/op_masked_fill.cpp +++ b/kernels/portable/cpu/op_masked_fill.cpp @@ -42,8 +42,8 @@ Tensor& masked_fill_scalar_out( ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(in, mask, out), InvalidArgument, out); - ET_SWITCH_REAL_TYPES_AND( - Bool, in_type, ctx, "masked_fill.Scalar_out", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES( + in_type, ctx, "masked_fill.Scalar_out", CTYPE, [&]() { ET_SWITCH_REAL_TYPES_AND( Bool, val_type, ctx, "masked_fill.Scalar_out", CTYPE_VAL, [&]() { CTYPE_VAL value_v; diff --git a/kernels/test/op_masked_fill_test.cpp b/kernels/test/op_masked_fill_test.cpp index d7ed8256400..0c08c2b7815 100644 --- a/kernels/test/op_masked_fill_test.cpp +++ b/kernels/test/op_masked_fill_test.cpp @@ -114,8 +114,11 @@ TEST_F(OpMaskedFillTest, IntTensorFloatAlphaDies) { tf.ones(sizes), tf.ones(sizes), /*alpha=*/.7, out)); } -TEST_F(OpMaskedFillTest, FloatTensors) { - test_floating_point_masked_fill_scalar_out(); +TEST_F(OpMaskedFillTest, FloatingPointTensors) { +#define TEST_ENTRY(ctype, dtype) \ + test_floating_point_masked_fill_scalar_out(); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY } TEST_F(OpMaskedFillTest, DoubleTensors) {