diff --git a/kernels/portable/cpu/op_copy.cpp b/kernels/portable/cpu/op_copy.cpp index 764a50a5d20..86f2d5c62be 100644 --- a/kernels/portable/cpu/op_copy.cpp +++ b/kernels/portable/cpu/op_copy.cpp @@ -45,8 +45,8 @@ Tensor& copy_out( ScalarType in_type = in.scalar_type(); ScalarType src_type = src.scalar_type(); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() { - ET_SWITCH_REALHB_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() { apply_binary_elementwise_fn( [](const CTYPE val_in, const CTYPE_SRC val_src) { return convert(val_src); @@ -75,8 +75,8 @@ copy_(RuntimeContext& ctx, Tensor& in, const Tensor& src, bool non_blocking) { ScalarType in_type = in.scalar_type(); ScalarType src_type = src.scalar_type(); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "copy_", CTYPE, [&]() { - ET_SWITCH_REALHB_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy_", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() { apply_binary_elementwise_fn( [](const CTYPE val_in, const CTYPE_SRC val_src) { return convert(val_src); diff --git a/kernels/portable/cpu/op_mm.cpp b/kernels/portable/cpu/op_mm.cpp index 4a6a8f3cfdc..1241182e4a9 100644 --- a/kernels/portable/cpu/op_mm.cpp +++ b/kernels/portable/cpu/op_mm.cpp @@ -34,19 +34,20 @@ mm_out(RuntimeContext& ctx, const Tensor& in, const Tensor& mat2, Tensor& out) { ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); - ET_SWITCH_REAL_TYPES_AND(Half, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() { - size_t m = in.size(0); - size_t n = in.size(1); - size_t p = mat2.size(1); - - vec_matmul( - out.mutable_data_ptr(), - in.const_data_ptr(), - mat2.const_data_ptr(), - m, - n, - p); - }); + ET_SWITCH_REAL_TYPES_AND2( + Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() { + size_t m = in.size(0); + size_t n = in.size(1); + size_t p = mat2.size(1); + + vec_matmul( + out.mutable_data_ptr(), + in.const_data_ptr(), + mat2.const_data_ptr(), + m, + n, + p); + }); return out; } diff --git a/kernels/test/op_copy_test.cpp b/kernels/test/op_copy_test.cpp index 82332f85eb2..007b10a7636 100644 --- a/kernels/test/op_copy_test.cpp +++ b/kernels/test/op_copy_test.cpp @@ -125,13 +125,13 @@ class OpCopyInplaceTest : public OperatorTest { // regular test for copy.out TEST_F(OpCopyTest, AllRealDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } TEST_F(OpCopyTest, EmptyInputSupported) { #define TEST_ENTRY(ctype, dtype) test_empty_input(); - ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } diff --git a/kernels/test/op_mm_test.cpp b/kernels/test/op_mm_test.cpp index 70d4b5ff0f5..c05792523f2 100644 --- a/kernels/test/op_mm_test.cpp +++ b/kernels/test/op_mm_test.cpp @@ -81,7 +81,7 @@ TEST_F(OpMmOutTest, OutputDim) { /// zeros(). TEST_F(OpMmOutTest, AllDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES_AND(Half, TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY // TODO: Also add tests for half, complex, quantized, and other types. Easiest // way to do that would be to make TensorFactory support zeros() and ones()