diff --git a/kernels/portable/cpu/op_cat.cpp b/kernels/portable/cpu/op_cat.cpp index 566937caf1b..26f277a8514 100644 --- a/kernels/portable/cpu/op_cat.cpp +++ b/kernels/portable/cpu/op_cat.cpp @@ -56,12 +56,12 @@ Tensor& cat_out( const size_t ninputs = tensors.size(); const auto out_type = out.scalar_type(); - ET_SWITCH_REALHB_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] { CTYPE_OUT* out_ptr = out.mutable_data_ptr(); for (size_t i = 0; i < outer; ++i) { for (size_t j = 0; j < ninputs; ++j) { const auto in_type = tensors[j].scalar_type(); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] { if (tensors[j].numel() == 0) { return; } diff --git a/kernels/test/op_cat_test.cpp b/kernels/test/op_cat_test.cpp index 04d649b98f1..cf11e32db9d 100644 --- a/kernels/test/op_cat_test.cpp +++ b/kernels/test/op_cat_test.cpp @@ -53,6 +53,26 @@ class OpCatOutTest : public OperatorTest { EXPECT_TENSOR_EQ(out, expected); } + + template + void test_16bit_dtype() { + TensorFactory tf; + + Tensor x = tf.make({2, 3}, {1.5, -2.0, 3.25, 4.0, -5.5, 6.5}); + Tensor y = tf.make({2, 1}, {10.0, 20.0}); + + std::vector inputs = {x, y}; + + Tensor out = tf.zeros({2, 4}); + + // Concatenate along dim[1]. + Tensor ret = op_cat_out( + ArrayRef(inputs.data(), inputs.size()), /*dim=*/1, out); + + Tensor expected = + tf.make({2, 4}, {1.5, -2.0, 3.25, 10.0, 4.0, -5.5, 6.5, 20.0}); + EXPECT_TENSOR_EQ(out, expected); + } }; TEST_F(OpCatOutTest, SmokeDim1) { @@ -105,26 +125,12 @@ TEST_F(OpCatOutTest, SmokeDim1) { EXPECT_TENSOR_EQ(out, expected); } -TEST_F(OpCatOutTest, HalfSupport) { +TEST_F(OpCatOutTest, SixteenBitFloatSupport) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { - GTEST_SKIP() << "Test Half support only for ExecuTorch mode"; + GTEST_SKIP() << "Test Half/BF16 support only for ExecuTorch mode"; } - TensorFactory tf; - - Tensor x = tf.make({2, 3}, {1.5, -2.0, 3.25, 4.0, -5.5, 6.5}); - Tensor y = tf.make({2, 1}, {10.0, 20.0}); - - std::vector inputs = {x, y}; - - Tensor out = tf.zeros({2, 4}); - - // Concatenate along dim[1]. - Tensor ret = op_cat_out( - ArrayRef(inputs.data(), inputs.size()), /*dim=*/1, out); - - Tensor expected = - tf.make({2, 4}, {1.5, -2.0, 3.25, 10.0, 4.0, -5.5, 6.5, 20.0}); - EXPECT_TENSOR_EQ(out, expected); + test_16bit_dtype(); + test_16bit_dtype(); } TEST_F(OpCatOutTest, NegativeDims) {