diff --git a/kernels/portable/cpu/op_cat.cpp b/kernels/portable/cpu/op_cat.cpp
index 04a7a58a99f..5b0a308bda5 100644
--- a/kernels/portable/cpu/op_cat.cpp
+++ b/kernels/portable/cpu/op_cat.cpp
@@ -56,27 +56,60 @@ Tensor& cat_out(
   const size_t ninputs = tensors.size();
 
   const auto out_type = out.scalar_type();
-  ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
-    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
-    for (size_t i = 0; i < outer; ++i) {
-      for (size_t j = 0; j < ninputs; ++j) {
-        const auto in_type = tensors[j].scalar_type();
-        ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
+  const bool out_is_complex =
+      executorch::runtime::isComplexType(out.scalar_type());
+
+  if (out_is_complex) {
+    // TODO: The current support for complex dtype enforces that input and
+    // output tensors have the same dtype. Support mixed dtypes in the future.
+    for (size_t i = 0; i < ninputs; ++i) {
+      const auto in_type = tensors[i].scalar_type();
+      ET_KERNEL_CHECK(ctx, out_type == in_type, InvalidArgument, out);
+    }
+    ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "cat.out", CTYPE, [&] {
+      CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
+      for (size_t i = 0; i < outer; ++i) {
+        for (size_t j = 0; j < ninputs; ++j) {
           if (tensors[j].numel() == 0) {
-            return;
+            continue;
           }
           size_t inner = tensors[j].size(dim) * dim_stride;
-          const CTYPE_IN* const in_ptr =
-              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
-
-          for (size_t k = 0; k < inner; ++k) {
-            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
-          }
+          const CTYPE* const in_ptr =
+              tensors[j].const_data_ptr<CTYPE>() + i * inner;
+          memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE));
           out_ptr += inner;
-        });
+        }
       }
-    }
-  });
+    });
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
+      CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
+      for (size_t i = 0; i < outer; ++i) {
+        for (size_t j = 0; j < ninputs; ++j) {
+          const auto in_type = tensors[j].scalar_type();
+          ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
+            if (tensors[j].numel() == 0) {
+              return;
+            }
+            size_t inner = tensors[j].size(dim) * dim_stride;
+            const CTYPE_IN* const in_ptr =
+                tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
+
+            // memcpy is bit-exact only when input and output dtypes match;
+            // equal sizes alone (e.g. float vs. int32_t) are not enough.
+            if (in_type == out_type) {
+              memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE_IN));
+            } else {
+              for (size_t k = 0; k < inner; ++k) {
+                out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
+              }
+            }
+            out_ptr += inner;
+          });
+        }
+      }
+    });
+  }
 
   return out;
 }
diff --git a/kernels/test/op_cat_test.cpp b/kernels/test/op_cat_test.cpp
index 9bdccb13a3b..4ea131452c7 100644
--- a/kernels/test/op_cat_test.cpp
+++ b/kernels/test/op_cat_test.cpp
@@ -73,6 +73,58 @@ class OpCatOutTest : public OperatorTest {
         tf.make({2, 4}, {1.5, -2.0, 3.25, 10.0, 4.0, -5.5, 6.5, 20.0});
     EXPECT_TENSOR_EQ(out, expected);
   }
+
+  template <typename CTYPE, ScalarType DTYPE>
+  void test_complex_dtype() {
+    TensorFactory<DTYPE> tf;
+    Tensor x = tf.make(
+        {2, 2},
+        {CTYPE(0.01, 2.03),
+         CTYPE(4.05, 6.07),
+         CTYPE(0.11, 2.13),
+         CTYPE(4.15, 6.17)});
+    Tensor y = tf.make(
+        {2, 2},
+        {CTYPE(0.21, 2.23),
+         CTYPE(4.25, 6.27),
+         CTYPE(0.31, 2.33),
+         CTYPE(4.35, 6.37)});
+
+    std::vector<Tensor> inputs = {x, y};
+
+    // Concatenate along dim[0].
+    Tensor out_0 = tf.full({4, 2}, CTYPE{0, 0});
+    Tensor ret_0 = op_cat_out(
+        ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/0, out_0);
+    Tensor expected_0 = tf.make(
+        {4, 2},
+        {CTYPE(0.01, 2.03),
+         CTYPE(4.05, 6.07),
+         CTYPE(0.11, 2.13),
+         CTYPE(4.15, 6.17),
+         CTYPE(0.21, 2.23),
+         CTYPE(4.25, 6.27),
+         CTYPE(0.31, 2.33),
+         CTYPE(4.35, 6.37)});
+
+    EXPECT_TENSOR_EQ(out_0, expected_0);
+
+    // Concatenate along dim[1].
+    Tensor out_1 = tf.full({2, 4}, CTYPE{0, 0});
+    Tensor ret_1 = op_cat_out(
+        ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/1, out_1);
+    Tensor expected_1 = tf.make(
+        {2, 4},
+        {CTYPE(0.01, 2.03),
+         CTYPE(4.05, 6.07),
+         CTYPE(0.21, 2.23),
+         CTYPE(4.25, 6.27),
+         CTYPE(0.11, 2.13),
+         CTYPE(4.15, 6.17),
+         CTYPE(0.31, 2.33),
+         CTYPE(4.35, 6.37)});
+    EXPECT_TENSOR_EQ(out_1, expected_1);
+  }
 };
 
 TEST_F(OpCatOutTest, SmokeDim1) {
@@ -133,6 +185,13 @@ TEST_F(OpCatOutTest, SixteenBitFloatSupport) {
   test_16bit_dtype<BFloat16, ScalarType::BFloat16>();
 }
 
+TEST_F(OpCatOutTest, ComplexSupport) {
+#define RUN_COMPLEX_TEST(ctype, dtype) \
+  test_complex_dtype<ctype, ScalarType::dtype>();
+  ET_FORALL_COMPLEXH_TYPES(RUN_COMPLEX_TEST);
+#undef RUN_COMPLEX_TEST
+}
+
 TEST_F(OpCatOutTest, NegativeDims) {
   TensorFactory<ScalarType::Int> tf;
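Note for reviewers: the complex branch copies each contiguous slab with a single memcpy, relying on the up-front check that every input dtype equals the output dtype. Below is a minimal standalone sketch of that pointer walk, not ExecuTorch code: it substitutes std::complex<float> for the runtime's complex ctype and hand-computes outer/inner for the 2x2, dim=1 case exercised by test_complex_dtype. It reproduces the interleaving that expected_1 encodes.

#include <complex>
#include <cstddef>
#include <cstring>
#include <iostream>

int main() {
  using C = std::complex<float>;
  // Row-major 2x2 inputs, same values as the new test.
  C x[4] = {{0.01f, 2.03f}, {4.05f, 6.07f}, {0.11f, 2.13f}, {4.15f, 6.17f}};
  C y[4] = {{0.21f, 2.23f}, {4.25f, 6.27f}, {0.31f, 2.33f}, {4.35f, 6.37f}};
  const C* inputs[2] = {x, y};

  // For dim=1 on a 2x2 tensor: outer = 2 (product of dims before the cat
  // dim), dim_stride = 1 (product of dims after it), so each input
  // contributes inner = size(dim) * dim_stride = 2 elements per outer step.
  const size_t outer = 2, inner = 2, ninputs = 2;

  C out[8];
  C* out_ptr = out;
  for (size_t i = 0; i < outer; ++i) {
    for (size_t j = 0; j < ninputs; ++j) {
      // Same-dtype fast path: one memcpy per contiguous slab, as in the
      // complex branch of the kernel.
      std::memcpy(out_ptr, inputs[j] + i * inner, inner * sizeof(C));
      out_ptr += inner;
    }
  }

  // Prints x-row0, y-row0, x-row1, y-row1: the order expected_1 asserts.
  for (const C& c : out) {
    std::cout << "(" << c.real() << ", " << c.imag() << ") ";
  }
  std::cout << "\n";
  return 0;
}

The same walk covers the dim=0 case in the test: outer becomes 1 and each input contributes a single slab of numel() elements.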