diff --git a/kernels/portable/cpu/op_diagonal_copy.cpp b/kernels/portable/cpu/op_diagonal_copy.cpp index 8bb64b94d9b..6d923a6d904 100644 --- a/kernels/portable/cpu/op_diagonal_copy.cpp +++ b/kernels/portable/cpu/op_diagonal_copy.cpp @@ -98,7 +98,7 @@ Tensor& diagonal_copy_out( constexpr auto name = "diagonal_copy.out"; - ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] { + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] { diagonal_copy_impl(in, offset, dim1, dim2, out); }); diff --git a/kernels/test/op_diagonal_copy_test.cpp b/kernels/test/op_diagonal_copy_test.cpp index 5ad69066532..a878edd2e46 100644 --- a/kernels/test/op_diagonal_copy_test.cpp +++ b/kernels/test/op_diagonal_copy_test.cpp @@ -39,16 +39,23 @@ class OpDiagonalCopyOutTest : public ::testing::Test { // first. torch::executor::runtime_init(); } + + template + void test_2d_dtype() { + TensorFactory tf; + + Tensor input = tf.make({3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Tensor out = tf.zeros({2}); + Tensor out_expected = tf.make({2}, {5, 10}); + op_diagonal_copy_out(input, 1, 1, 0, out); + EXPECT_TENSOR_CLOSE(out, out_expected); + } }; TEST_F(OpDiagonalCopyOutTest, SmokeTest2D) { - TensorFactory tfFloat; - - Tensor input = tfFloat.make({3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - Tensor out = tfFloat.zeros({2}); - Tensor out_expected = tfFloat.make({2}, {5, 10}); - op_diagonal_copy_out(input, 1, 1, 0, out); - EXPECT_TENSOR_CLOSE(out, out_expected); +#define TEST_ENTRY(ctype, dtype) test_2d_dtype(); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY } TEST_F(OpDiagonalCopyOutTest, SmokeTest3D) {