diff --git a/kernels/portable/cpu/op_mul.cpp b/kernels/portable/cpu/op_mul.cpp index 8fc4f9d4593..34e7e085687 100644 --- a/kernels/portable/cpu/op_mul.cpp +++ b/kernels/portable/cpu/op_mul.cpp @@ -123,7 +123,11 @@ Tensor& mul_scalar_out( ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); - ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_realhbbf16_type(out), + InvalidArgument, + out); ScalarType a_type = a.scalar_type(); ScalarType b_type = utils::get_scalar_dtype(b); diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp index 84a7e8dedc4..f8205ea601e 100644 --- a/kernels/test/op_mul_test.cpp +++ b/kernels/test/op_mul_test.cpp @@ -586,3 +586,29 @@ TEST_F(OpMulScalarOutTest, OptimizedSanityCheck) { // Check that it matches the expected output. EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {2.6, 4.2, 9.2, 16.4})); } + +TEST_F(OpMulScalarOutTest, HalfSanityCheck) { + TensorFactory tf; + + const std::vector sizes = {2, 2}; + + Tensor out = tf.zeros(sizes); + + op_mul_scalar_out(tf.make(sizes, {1.3, 2.1, 4.6, 8.2}), 2.0, out); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {2.6, 4.2, 9.2, 16.4})); +} + +TEST_F(OpMulScalarOutTest, BFloat16SanityCheck) { + TensorFactory tf; + + const std::vector sizes = {2, 2}; + + Tensor out = tf.zeros(sizes); + + op_mul_scalar_out(tf.make(sizes, {1.3, 2.1, 4.6, 8.2}), 2.0, out); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {2.6, 4.2, 9.2, 16.4})); +}