From 847bade53337afaa82759f9247a5a1a185417410 Mon Sep 17 00:00:00 2001
From: Prashant Rawat
Date: Tue, 29 Jul 2025 10:05:56 -0700
Subject: [PATCH] Extend cat op for complex dtype (#12894)

Summary:
Live translation needs a cat op that handles complex dtypes. The current
complex support enforces that input and output tensors have the same
dtype; mixed dtypes can be supported in the future.
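
As a usage sketch (not part of the diff below), the new check means every
complex input must match the output dtype exactly. The snippet follows the
test-harness conventions from op_cat_test.cpp; the complex ctype spelling
(c10::complex here), the context_ member, and the ET_EXPECT_KERNEL_FAILURE
macro are assumptions about the surrounding test framework:

    // Minimal sketch, assuming the OpCatOutTest fixture conventions below.
    TensorFactory<ScalarType::ComplexFloat> tf_cf;
    TensorFactory<ScalarType::ComplexDouble> tf_cd;
    using CF = c10::complex<float>; // assumed ctype for ComplexFloat

    // Same-dtype complex inputs take the new memcpy fast path.
    Tensor a = tf_cf.make({1, 2}, {CF(1.0f, 2.0f), CF(3.0f, 4.0f)});
    Tensor b = tf_cf.make({1, 2}, {CF(5.0f, 6.0f), CF(7.0f, 8.0f)});
    std::vector<Tensor> same = {a, b};
    Tensor out = tf_cf.full({2, 2}, CF(0.0f, 0.0f));
    op_cat_out(ArrayRef<Tensor>(same.data(), same.size()), /*dim=*/0, out);

    // A ComplexDouble input against a ComplexFloat output trips the new
    // ET_KERNEL_CHECK(out_type == in_type, ...) and the kernel fails.
    Tensor c = tf_cd.full({1, 2}, c10::complex<double>(9.0, 10.0));
    std::vector<Tensor> mixed = {a, c};
    ET_EXPECT_KERNEL_FAILURE(
        context_,
        op_cat_out(
            ArrayRef<Tensor>(mixed.data(), mixed.size()), /*dim=*/0, out));

The same-dtype restriction is what makes the memcpy path safe: matching
dtypes guarantee identical element layout, so no per-element conversion is
needed.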

Differential Revision: D78934592
---
 kernels/portable/cpu/op_cat.cpp | 61 +++++++++++++++++++++++++--------
 kernels/test/op_cat_test.cpp    | 59 +++++++++++++++++++++++++++++++
 2 files changed, 105 insertions(+), 15 deletions(-)

diff --git a/kernels/portable/cpu/op_cat.cpp b/kernels/portable/cpu/op_cat.cpp
index 04a7a58a99f..5b0a308bda5 100644
--- a/kernels/portable/cpu/op_cat.cpp
+++ b/kernels/portable/cpu/op_cat.cpp
@@ -56,27 +56,58 @@ Tensor& cat_out(
   const size_t ninputs = tensors.size();
 
   const auto out_type = out.scalar_type();
-  ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
-    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
-    for (size_t i = 0; i < outer; ++i) {
-      for (size_t j = 0; j < ninputs; ++j) {
-        const auto in_type = tensors[j].scalar_type();
-        ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
+  const bool out_is_complex =
+      executorch::runtime::isComplexType(out.scalar_type());
+
+  if (out_is_complex) {
+    // TODO: The current support for complex dtype enforces that input and
+    // output tensors have the same dtype. Support mixed dtypes in the future.
+    for (size_t i = 0; i < ninputs; ++i) {
+      const auto in_type = tensors[i].scalar_type();
+      ET_KERNEL_CHECK(ctx, out_type == in_type, InvalidArgument, out);
+    }
+    ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "cat.out", CTYPE, [&] {
+      CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
+      for (size_t i = 0; i < outer; ++i) {
+        for (size_t j = 0; j < ninputs; ++j) {
           if (tensors[j].numel() == 0) {
             return;
           }
           size_t inner = tensors[j].size(dim) * dim_stride;
-          const CTYPE_IN* const in_ptr =
-              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
-
-          for (size_t k = 0; k < inner; ++k) {
-            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
-          }
+          const CTYPE* const in_ptr =
+              tensors[j].const_data_ptr<CTYPE>() + i * inner;
+          memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE));
           out_ptr += inner;
-        });
+        }
       }
-    }
-  });
+    });
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
+      CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
+      for (size_t i = 0; i < outer; ++i) {
+        for (size_t j = 0; j < ninputs; ++j) {
+          const auto in_type = tensors[j].scalar_type();
+          ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
+            if (tensors[j].numel() == 0) {
+              return;
+            }
+            size_t inner = tensors[j].size(dim) * dim_stride;
+            const CTYPE_IN* const in_ptr =
+                tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
+
+            if (sizeof(CTYPE_IN) == sizeof(CTYPE_OUT)) {
+              memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE_IN));
+            } else {
+              for (size_t k = 0; k < inner; ++k) {
+                out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
+              }
+            }
+            out_ptr += inner;
+          });
+        }
+      }
+    });
+  }
 
   return out;
 }
diff --git a/kernels/test/op_cat_test.cpp b/kernels/test/op_cat_test.cpp
index 9bdccb13a3b..4ea131452c7 100644
--- a/kernels/test/op_cat_test.cpp
+++ b/kernels/test/op_cat_test.cpp
@@ -73,6 +73,58 @@ class OpCatOutTest : public OperatorTest {
         tf.make({2, 4}, {1.5, -2.0, 3.25, 10.0, 4.0, -5.5, 6.5, 20.0});
     EXPECT_TENSOR_EQ(out, expected);
   }
+
+  template <typename CTYPE, ScalarType DTYPE>
+  void test_complex_dtype() {
+    TensorFactory<DTYPE> tf;
+    Tensor x = tf.make(
+        {2, 2},
+        {CTYPE(0.01, 2.03),
+         CTYPE(4.05, 6.07),
+         CTYPE(0.11, 2.13),
+         CTYPE(4.15, 6.17)});
+    Tensor y = tf.make(
+        {2, 2},
+        {CTYPE(0.21, 2.23),
+         CTYPE(4.25, 6.27),
+         CTYPE(0.31, 2.33),
+         CTYPE(4.35, 6.37)});
+
+    std::vector<Tensor> inputs = {x, y};
+
+    // Concatenate along dim[0].
+    Tensor out_0 = tf.full({4, 2}, CTYPE{0, 0});
+    Tensor ret_0 = op_cat_out(
+        ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/0, out_0);
+    Tensor expected_0 = tf.make(
+        {4, 2},
+        {CTYPE(0.01, 2.03),
+         CTYPE(4.05, 6.07),
+         CTYPE(0.11, 2.13),
+         CTYPE(4.15, 6.17),
+         CTYPE(0.21, 2.23),
+         CTYPE(4.25, 6.27),
+         CTYPE(0.31, 2.33),
+         CTYPE(4.35, 6.37)});
+
+    EXPECT_TENSOR_EQ(out_0, expected_0);
+
+    // Concatenate along dim[1].
+    Tensor out_1 = tf.full({2, 4}, CTYPE{0, 0});
+    Tensor ret_1 = op_cat_out(
+        ArrayRef<Tensor>(inputs.data(), inputs.size()), /*dim=*/1, out_1);
+    Tensor expected_1 = tf.make(
+        {2, 4},
+        {CTYPE(0.01, 2.03),
+         CTYPE(4.05, 6.07),
+         CTYPE(0.21, 2.23),
+         CTYPE(4.25, 6.27),
+         CTYPE(0.11, 2.13),
+         CTYPE(4.15, 6.17),
+         CTYPE(0.31, 2.33),
+         CTYPE(4.35, 6.37)});
+    EXPECT_TENSOR_EQ(out_1, expected_1);
+  }
 };
 
 TEST_F(OpCatOutTest, SmokeDim1) {
@@ -133,6 +185,13 @@ TEST_F(OpCatOutTest, SixteenBitFloatSupport) {
   test_16bit_dtype<exec_aten::BFloat16, ScalarType::BFloat16>();
 }
 
+TEST_F(OpCatOutTest, ComplexSupport) {
+#define RUN_COMPLEX_TEST(ctype, dtype) \
+  test_complex_dtype<ctype, ScalarType::dtype>();
+  ET_FORALL_COMPLEXH_TYPES(RUN_COMPLEX_TEST);
+#undef RUN_COMPLEX_TEST
+}
+
 TEST_F(OpCatOutTest, NegativeDims) {
   TensorFactory<ScalarType::Float> tf;