diff --git a/kernels/portable/cpu/op_mm.cpp b/kernels/portable/cpu/op_mm.cpp index 4a6a8f3cfdc..1241182e4a9 100644 --- a/kernels/portable/cpu/op_mm.cpp +++ b/kernels/portable/cpu/op_mm.cpp @@ -34,19 +34,20 @@ mm_out(RuntimeContext& ctx, const Tensor& in, const Tensor& mat2, Tensor& out) { ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); - ET_SWITCH_REAL_TYPES_AND(Half, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() { - size_t m = in.size(0); - size_t n = in.size(1); - size_t p = mat2.size(1); - - vec_matmul( - out.mutable_data_ptr(), - in.const_data_ptr(), - mat2.const_data_ptr(), - m, - n, - p); - }); + ET_SWITCH_REAL_TYPES_AND2( + Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() { + size_t m = in.size(0); + size_t n = in.size(1); + size_t p = mat2.size(1); + + vec_matmul( + out.mutable_data_ptr(), + in.const_data_ptr(), + mat2.const_data_ptr(), + m, + n, + p); + }); return out; } diff --git a/kernels/test/op_mm_test.cpp b/kernels/test/op_mm_test.cpp index 70d4b5ff0f5..c05792523f2 100644 --- a/kernels/test/op_mm_test.cpp +++ b/kernels/test/op_mm_test.cpp @@ -81,7 +81,7 @@ TEST_F(OpMmOutTest, OutputDim) { /// zeros(). TEST_F(OpMmOutTest, AllDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES_AND(Half, TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY // TODO: Also add tests for half, complex, quantized, and other types. Easiest // way to do that would be to make TensorFactory support zeros() and ones()