diff --git a/backends/cadence/hifi/operators/op_cat.cpp b/backends/cadence/hifi/operators/op_cat.cpp index 8ad52753de3..d4fd51871ce 100644 --- a/backends/cadence/hifi/operators/op_cat.cpp +++ b/backends/cadence/hifi/operators/op_cat.cpp @@ -126,29 +126,25 @@ Tensor& cat_out( const size_t outer = getLeadingDims(out, dim); const size_t dim_stride = getTrailingDims(out, dim); const size_t ninputs = tensors.size(); + const size_t element_size = out.element_size(); + char* out_ptr = static_cast(out.mutable_data_ptr()); - const auto out_type = out.scalar_type(); - ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] { - CTYPE_OUT* out_ptr = out.mutable_data_ptr(); - for (size_t i = 0; i < outer; ++i) { - for (size_t j = 0; j < ninputs; ++j) { - const auto in_type = tensors[j].scalar_type(); - ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] { - if (tensors[j].numel() == 0) { - return; - } - size_t inner = tensors[j].size(dim) * dim_stride; - const CTYPE_IN* const in_ptr = - tensors[j].const_data_ptr() + i * inner; - - for (size_t k = 0; k < inner; ++k) { - out_ptr[k] = static_cast(in_ptr[k]); - } - out_ptr += inner; - }); + for (size_t i = 0; i < outer; ++i) { + for (size_t j = 0; j < ninputs; ++j) { + if (tensors[j].numel() == 0) { + continue; } + size_t inner_elements = tensors[j].size(dim) * dim_stride; + size_t contiguous_bytes = inner_elements * element_size; + + const char* const in_ptr = + static_cast(tensors[j].const_data_ptr()) + + i * contiguous_bytes; + + std::memcpy(out_ptr, in_ptr, contiguous_bytes); + out_ptr += contiguous_bytes; } - }); + } return out; } @@ -156,4 +152,4 @@ Tensor& cat_out( } // namespace native } // namespace HiFi } // namespace impl -} // namespace cadence \ No newline at end of file +} // namespace cadence diff --git a/backends/cadence/hifi/operators/operators.h b/backends/cadence/hifi/operators/operators.h index 85a71dd5092..1321945c5e1 100644 --- a/backends/cadence/hifi/operators/operators.h +++ b/backends/cadence/hifi/operators/operators.h @@ -122,6 +122,12 @@ void quantized_conv_per_tensor_out( bool channel_last, ::executorch::aten::Tensor& out); +::executorch::aten::Tensor& cat_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + ::executorch::aten::ArrayRef<::executorch::aten::Tensor> tensors, + int64_t dim, + ::executorch::aten::Tensor& out); + } // namespace native } // namespace HiFi } // namespace impl diff --git a/backends/cadence/hifi/operators/tests/test_op_cat.cpp b/backends/cadence/hifi/operators/tests/test_op_cat.cpp new file mode 100644 index 00000000000..2f012ed6c81 --- /dev/null +++ b/backends/cadence/hifi/operators/tests/test_op_cat.cpp @@ -0,0 +1,136 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { +namespace { + +using ::executorch::aten::ArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::aten::TensorImpl; +using ::executorch::runtime::Error; +using ::executorch::runtime::KernelRuntimeContext; +using ::executorch::runtime::runtime_init; +using ::executorch::runtime::testing::TensorFactory; + +class HiFiCatTest : public OperatorTest { + public: + protected: + Tensor& cat_out(ArrayRef tensors, int64_t dim, Tensor& out) { + return ::cadence::impl::HiFi::native::cat_out(context_, tensors, dim, out); + } +}; + +TEST_F(HiFiCatTest, FloatCatDim0Test) { + TensorFactory tf; + Tensor a = tf.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}); + Tensor b = tf.make({1, 3}, {7.0, 8.0, 9.0}); + Tensor c = tf.make({2, 3}, {10.0, 11.0, 12.0, 13.0, 14.0, 15.0}); + + Tensor expected = tf.make( + {5, 3}, + {1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 10.0, + 11.0, + 12.0, + 13.0, + 14.0, + 15.0}); + + Tensor out = tf.zeros({5, 3}); + std::vector tensors = {a, b, c}; + + cat_out(ArrayRef(tensors.data(), tensors.size()), 0, out); + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(HiFiCatTest, FloatCatDim1Test) { + TensorFactory tf; + Tensor a = tf.make({2, 2}, {1.0, 2.0, 3.0, 4.0}); + Tensor b = tf.make({2, 1}, {5.0, 6.0}); + Tensor c = tf.make({2, 3}, {7.0, 8.0, 9.0, 10.0, 11.0, 12.0}); + + Tensor expected = tf.make( + {2, 6}, {1.0, 2.0, 5.0, 7.0, 8.0, 9.0, 3.0, 4.0, 6.0, 10.0, 11.0, 12.0}); + + Tensor out = tf.zeros({2, 6}); + std::vector tensors = {a, b, c}; + + cat_out(ArrayRef(tensors.data(), tensors.size()), 1, out); + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(HiFiCatTest, IntCatDim0Test) { + TensorFactory tf; + Tensor a = tf.make({2, 3}, {1, 2, 3, 4, 5, 6}); + Tensor b = tf.make({1, 3}, {7, 8, 9}); + + Tensor expected = tf.make({3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + + Tensor out = tf.zeros({3, 3}); + std::vector tensors = {a, b}; + cat_out(ArrayRef(tensors.data(), tensors.size()), 0, out); + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(HiFiCatTest, SingleTensorTest) { + TensorFactory tf; + Tensor a = tf.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}); + Tensor expected = tf.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}); + + Tensor out = tf.zeros({2, 3}); + std::vector tensors = {a}; + cat_out(ArrayRef(tensors.data(), tensors.size()), 0, out); + EXPECT_TENSOR_EQ(out, expected); +} + +TEST_F(HiFiCatTest, ThreeDimensionalCatTest) { + TensorFactory tf; + Tensor a = tf.make({2, 2, 2}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}); + Tensor b = tf.make({2, 2, 1}, {9.0, 10.0, 11.0, 12.0}); + + Tensor expected = tf.make( + {2, 2, 3}, + {1.0, 2.0, 9.0, 3.0, 4.0, 10.0, 5.0, 6.0, 11.0, 7.0, 8.0, 12.0}); + + Tensor out = tf.zeros({2, 2, 3}); + std::vector tensors = {a, b}; + + cat_out(ArrayRef(tensors.data(), tensors.size()), 2, out); + EXPECT_TENSOR_EQ(out, expected); +} + +} // namespace +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence