38 changes: 17 additions & 21 deletions backends/cadence/hifi/operators/op_cat.cpp
@@ -126,34 +126,30 @@ Tensor& cat_out(
   const size_t outer = getLeadingDims(out, dim);
   const size_t dim_stride = getTrailingDims(out, dim);
   const size_t ninputs = tensors.size();
+  const size_t element_size = out.element_size();
+  char* out_ptr = static_cast<char*>(out.mutable_data_ptr());
 
-  const auto out_type = out.scalar_type();
-  ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
-    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
-    for (size_t i = 0; i < outer; ++i) {
-      for (size_t j = 0; j < ninputs; ++j) {
-        const auto in_type = tensors[j].scalar_type();
-        ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
-          if (tensors[j].numel() == 0) {
-            return;
-          }
-          size_t inner = tensors[j].size(dim) * dim_stride;
-          const CTYPE_IN* const in_ptr =
-              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
-
-          for (size_t k = 0; k < inner; ++k) {
-            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
-          }
-          out_ptr += inner;
-        });
-      }
-    }
-  });
+  for (size_t i = 0; i < outer; ++i) {
+    for (size_t j = 0; j < ninputs; ++j) {
+      if (tensors[j].numel() == 0) {
+        continue;
+      }
+      size_t inner_elements = tensors[j].size(dim) * dim_stride;
+      size_t contiguous_bytes = inner_elements * element_size;
+
+      const char* const in_ptr =
+          static_cast<const char*>(tensors[j].const_data_ptr()) +
+          i * contiguous_bytes;
+
+      std::memcpy(out_ptr, in_ptr, contiguous_bytes);
+      out_ptr += contiguous_bytes;
+    }
+  }
 
   return out;
 }
 
 } // namespace native
 } // namespace HiFi
 } // namespace impl
 } // namespace cadence
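
For reference, the change swaps the doubly nested ET_SWITCH_REALHB_TYPES dispatch and per-element static_cast for plain byte copies: for each of the outer leading slices, every input contributes one contiguous block of size(dim) * dim_stride elements, copied with a single std::memcpy sized from out.element_size(). This assumes the inputs already share the output's element size, an assumption the old per-element cast did not need. A minimal standalone sketch of the same decomposition, using hypothetical plain arrays rather than ExecuTorch tensors:

#include <cstddef>
#include <cstdio>
#include <cstring>

// Standalone illustration of the byte-copy strategy (not the kernel itself);
// the arrays and shapes below are hypothetical. Concatenate a {2, 2} and a
// {2, 1} "tensor" along dim 1 into a {2, 3} output.
int main() {
  float a[] = {1, 2, 3, 4};  // shape {2, 2}
  float b[] = {5, 6};        // shape {2, 1}
  float out[6] = {};         // shape {2, 3}

  const size_t outer = 2;       // product of dims before dim (2 rows)
  const size_t dim_stride = 1;  // product of dims after dim (none here)
  const size_t ninputs = 2;
  const size_t sizes_at_dim[] = {2, 1};
  const float* inputs[] = {a, b};
  const size_t element_size = sizeof(float);

  char* out_ptr = reinterpret_cast<char*>(out);
  for (size_t i = 0; i < outer; ++i) {
    for (size_t j = 0; j < ninputs; ++j) {
      // Each input owns one contiguous block per outer slice.
      const size_t contiguous_bytes =
          sizes_at_dim[j] * dim_stride * element_size;
      const char* in_ptr =
          reinterpret_cast<const char*>(inputs[j]) + i * contiguous_bytes;
      std::memcpy(out_ptr, in_ptr, contiguous_bytes);
      out_ptr += contiguous_bytes;
    }
  }

  for (float v : out) {
    std::printf("%g ", v);  // prints: 1 2 5 3 4 6
  }
  std::printf("\n");
  return 0;
}
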
6 changes: 6 additions & 0 deletions backends/cadence/hifi/operators/operators.h
@@ -122,6 +122,12 @@ void quantized_conv_per_tensor_out(
     bool channel_last,
     ::executorch::aten::Tensor& out);
 
+::executorch::aten::Tensor& cat_out(
+    ::executorch::runtime::KernelRuntimeContext& ctx,
+    ::executorch::aten::ArrayRef<::executorch::aten::Tensor> tensors,
+    int64_t dim,
+    ::executorch::aten::Tensor& out);
+
 } // namespace native
 } // namespace HiFi
 } // namespace impl
136 changes: 136 additions & 0 deletions backends/cadence/hifi/operators/tests/test_op_cat.cpp
@@ -0,0 +1,136 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <gtest/gtest.h>
#include <sys/times.h>
#include <xtensa/sim.h>

#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/platform/runtime.h>

#include <executorch/backends/cadence/hifi/operators/operators.h>

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {
namespace {

using ::executorch::aten::ArrayRef;
using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;
using ::executorch::aten::TensorImpl;
using ::executorch::runtime::Error;
using ::executorch::runtime::KernelRuntimeContext;
using ::executorch::runtime::runtime_init;
using ::executorch::runtime::testing::TensorFactory;

class HiFiCatTest : public OperatorTest {
public:
protected:
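  // Forwards to the HiFi cat_out kernel under test; context_ is the
  // KernelRuntimeContext member supplied by the OperatorTest fixture.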
Tensor& cat_out(ArrayRef<Tensor> tensors, int64_t dim, Tensor& out) {
return ::cadence::impl::HiFi::native::cat_out(context_, tensors, dim, out);
}
};

TEST_F(HiFiCatTest, FloatCatDim0Test) {
TensorFactory<ScalarType::Float> tf;
Tensor a = tf.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
Tensor b = tf.make({1, 3}, {7.0, 8.0, 9.0});
Tensor c = tf.make({2, 3}, {10.0, 11.0, 12.0, 13.0, 14.0, 15.0});

Tensor expected = tf.make(
{5, 3},
{1.0,
2.0,
3.0,
4.0,
5.0,
6.0,
7.0,
8.0,
9.0,
10.0,
11.0,
12.0,
13.0,
14.0,
15.0});

Tensor out = tf.zeros({5, 3});
std::vector<Tensor> tensors = {a, b, c};

cat_out(ArrayRef<Tensor>(tensors.data(), tensors.size()), 0, out);
EXPECT_TENSOR_EQ(out, expected);
}

TEST_F(HiFiCatTest, FloatCatDim1Test) {
TensorFactory<ScalarType::Float> tf;
Tensor a = tf.make({2, 2}, {1.0, 2.0, 3.0, 4.0});
Tensor b = tf.make({2, 1}, {5.0, 6.0});
Tensor c = tf.make({2, 3}, {7.0, 8.0, 9.0, 10.0, 11.0, 12.0});

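  // Along dim 1, each of the two output rows is that row's slice of a
  // (2 columns), then b (1 column), then c (3 columns):
  // row 0 -> {1, 2, 5, 7, 8, 9}, row 1 -> {3, 4, 6, 10, 11, 12}.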
Tensor expected = tf.make(
{2, 6}, {1.0, 2.0, 5.0, 7.0, 8.0, 9.0, 3.0, 4.0, 6.0, 10.0, 11.0, 12.0});

Tensor out = tf.zeros({2, 6});
std::vector<Tensor> tensors = {a, b, c};

cat_out(ArrayRef<Tensor>(tensors.data(), tensors.size()), 1, out);
EXPECT_TENSOR_EQ(out, expected);
}

TEST_F(HiFiCatTest, IntCatDim0Test) {
TensorFactory<ScalarType::Int> tf;
Tensor a = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
Tensor b = tf.make({1, 3}, {7, 8, 9});

Tensor expected = tf.make({3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});

Tensor out = tf.zeros({3, 3});
std::vector<Tensor> tensors = {a, b};
cat_out(ArrayRef<Tensor>(tensors.data(), tensors.size()), 0, out);
EXPECT_TENSOR_EQ(out, expected);
}

TEST_F(HiFiCatTest, SingleTensorTest) {
TensorFactory<ScalarType::Float> tf;
Tensor a = tf.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
Tensor expected = tf.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});

Tensor out = tf.zeros({2, 3});
std::vector<Tensor> tensors = {a};
cat_out(ArrayRef<Tensor>(tensors.data(), tensors.size()), 0, out);
EXPECT_TENSOR_EQ(out, expected);
}

TEST_F(HiFiCatTest, ThreeDimensionalCatTest) {
TensorFactory<ScalarType::Float> tf;
Tensor a = tf.make({2, 2, 2}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0});
Tensor b = tf.make({2, 2, 1}, {9.0, 10.0, 11.0, 12.0});

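  // Concatenating along the last dim interleaves per leading slice: with
  // outer = 2 * 2 = 4 slices, each output row of 3 is two elements from a
  // followed by one from b: {1, 2, 9}, {3, 4, 10}, {5, 6, 11}, {7, 8, 12}.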
Tensor expected = tf.make(
{2, 2, 3},
{1.0, 2.0, 9.0, 3.0, 4.0, 10.0, 5.0, 6.0, 11.0, 7.0, 8.0, 12.0});

Tensor out = tf.zeros({2, 2, 3});
std::vector<Tensor> tensors = {a, b};

cat_out(ArrayRef<Tensor>(tensors.data(), tensors.size()), 2, out);
EXPECT_TENSOR_EQ(out, expected);
}

} // namespace
} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence