kernels/portable/cpu/op_add.cpp (72 changes: 0 additions & 72 deletions)
@@ -15,7 +15,6 @@
namespace torch {
namespace executor {
namespace native {
namespace impl {

Tensor& add_out(
KernelRuntimeContext& ctx,
@@ -152,77 +151,6 @@ Tensor& add_scalar_out(
return out;
}

} // namespace impl

Tensor& add_out(
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
const Scalar& alpha,
Tensor& out) {
return impl::add_out(ctx, a, b, alpha, out);
}

Tensor& add_scalar_out(
KernelRuntimeContext& ctx,
const Tensor& a,
const Scalar& b,
const Scalar& alpha,
Tensor& out) {
return impl::add_scalar_out(ctx, a, b, alpha, out);
}

namespace utils {

Tensor& add_out(
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
const Scalar& alpha,
Tensor& out) {
return impl::add_out(ctx, a, b, alpha, out);
}

Tensor& add_scalar_out(
KernelRuntimeContext& ctx,
const Tensor& a,
const Scalar& b,
const Scalar& alpha,
Tensor& out) {
return impl::add_scalar_out(ctx, a, b, alpha, out);
}

std::tuple<
Error,
std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
size_t>
add_out_shape(const Tensor& a, const Tensor& b, ET_UNUSED const Scalar& alpha) {
std::array<executorch::aten::SizesType, kTensorDimensionLimit> out_sizes{};
size_t out_dim = 0;

Error err = get_broadcast_target_size(
a, b, out_sizes.data(), kTensorDimensionLimit, &out_dim);

return std::make_tuple(err, out_sizes, out_dim);
}

std::tuple<
Error,
std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
size_t>
add_scalar_out_shape(
const Tensor& a,
ET_UNUSED const Scalar& b,
ET_UNUSED const Scalar& alpha) {
std::array<executorch::aten::SizesType, kTensorDimensionLimit> out_sizes{};
size_t out_dim = a.dim();

std::copy(a.sizes().begin(), a.sizes().end(), out_sizes.begin());

return std::make_tuple(Error::Ok, out_sizes, out_dim);
}

} // namespace utils
} // namespace native
} // namespace executor
} // namespace torch
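
Reviewer note: the utils::add_out_shape helper removed above computes the broadcasted output shape through get_broadcast_target_size. For reference, here is a minimal standalone sketch of the right-aligned broadcasting rule it applies; the function name, the int size type, and the fixed dimension limit are illustrative stand-ins, not the ExecuTorch API.

#include <array>
#include <cstddef>

// Sketch of right-aligned shape broadcasting under a fixed dimension limit.
constexpr size_t kDimLimit = 16;  // stand-in for kTensorDimensionLimit

bool broadcast_shapes(
    const int* a, size_t a_dim,
    const int* b, size_t b_dim,
    std::array<int, kDimLimit>* out_sizes, size_t* out_dim) {
  *out_dim = a_dim > b_dim ? a_dim : b_dim;
  for (size_t i = 0; i < *out_dim; ++i) {
    // Compare from the trailing dimension; a missing dimension acts as 1.
    const int sa = i < a_dim ? a[a_dim - 1 - i] : 1;
    const int sb = i < b_dim ? b[b_dim - 1 - i] : 1;
    if (sa != sb && sa != 1 && sb != 1) {
      return false;  // shapes are not broadcast-compatible
    }
    (*out_sizes)[*out_dim - 1 - i] = sa == 1 ? sb : sa;
  }
  return true;
}

// e.g. a = {2, 1, 4} and b = {3, 4} broadcast to {2, 3, 4}.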
kernels/portable/cpu/op_add.h (65 changes: 0 additions & 65 deletions)

This file was deleted.

kernels/portable/cpu/op_stack.cpp (125 changes: 2 additions & 123 deletions)
@@ -6,142 +6,21 @@
* LICENSE file in the root directory of this source tree.
*/

#include <cstring>

#include <c10/util/irange.h>
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/kernels/portable/cpu/util/stack_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
namespace impl {

using Tensor = executorch::aten::Tensor;

Tensor& stack_out(
KernelRuntimeContext& ctx,
executorch::aten::ArrayRef<Tensor> tensors,
int64_t dim,
Tensor& out) {
(void)ctx;

if (dim < 0) {
dim += out.dim();
}

ET_KERNEL_CHECK(
ctx, check_stack_args(tensors, dim, out), InvalidArgument, out);

for (size_t i = 0; i < tensors.size(); ++i) {
ET_KERNEL_CHECK(
ctx,
tensors_have_same_dim_order(tensors[i], out),
InvalidArgument,
out);
}

ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out);

Tensor::SizesType expected_out_size[kTensorDimensionLimit];
size_t expected_out_dim = 0;
get_stack_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);
ET_KERNEL_CHECK(
ctx,
resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
InvalidArgument,
out);

const size_t outer = getLeadingDims(out, dim);
const size_t inner = getTrailingDims(out, dim);
const size_t ninputs = tensors.size();

const auto out_type = out.scalar_type();
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "stack.out", CTYPE_OUT, [&] {
CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
for (size_t i = 0; i < outer; ++i) {
for (size_t j = 0; j < ninputs; ++j) {
const auto in_type = tensors[j].scalar_type();
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "stack.out", CTYPE_IN, [&] {
const CTYPE_IN* const in_ptr =
tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;

for (size_t k = 0; k < inner; ++k) {
out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
}
out_ptr += inner;
});
}
}
});

return out;
}

} // namespace impl

Tensor& stack_out(
KernelRuntimeContext& ctx,
executorch::aten::ArrayRef<Tensor> tensors,
int64_t dim,
Tensor& out) {
return impl::stack_out(ctx, tensors, dim, out);
}

namespace utils {

Tensor& stack_out(
KernelRuntimeContext& ctx,
executorch::aten::ArrayRef<Tensor> tensors,
int64_t dim,
Tensor& out) {
return impl::stack_out(ctx, tensors, dim, out);
}

std::tuple<
Error,
std::array<executorch::aten::SizesType, kTensorDimensionLimit>,
size_t>
stack_out_shape(executorch::aten::ArrayRef<Tensor> tensors, int64_t dim) {
std::array<executorch::aten::SizesType, kTensorDimensionLimit> out_sizes{};
size_t out_dim = 0;

// Check if tensors array is empty
if (tensors.size() == 0) {
return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
}

// Normalize negative dimension
int64_t normalized_dim = dim;
if (normalized_dim < 0) {
normalized_dim += tensors[0].dim() + 1;
}

// Check if dimension is valid
if (normalized_dim < 0 || normalized_dim > tensors[0].dim()) {
return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
}

// Check that all tensors have the same shape
for (size_t i = 1; i < tensors.size(); ++i) {
if (tensors[i].dim() != tensors[0].dim()) {
return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
}
for (const auto d : c10::irange(tensors[0].dim())) {
if (tensors[i].size(d) != tensors[0].size(d)) {
return std::make_tuple(Error::InvalidArgument, out_sizes, out_dim);
}
}
}

// Compute output shape using the existing utility
::torch::executor::get_stack_out_target_size(
tensors, normalized_dim, out_sizes.data(), &out_dim);

return std::make_tuple(Error::Ok, out_sizes, out_dim);
return utils::stack_out_impl(ctx, tensors, dim, out);
}

} // namespace utils
} // namespace native
} // namespace executor
} // namespace torch
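
Reviewer note: the deleted impl::stack_out body copies inputs by viewing the output as [outer, ninputs, inner], where outer is the product of the sizes before dim and inner the product of the sizes from dim onward in each input. Below is a minimal single-dtype sketch of that copy order; it is float-only and unchecked, unlike the dtype-dispatched kernel above.

#include <cstddef>
#include <vector>

// Interleaves ninputs equal-shape inputs into out: for each of the `outer`
// leading positions, write one contiguous run of `inner` elements per input.
void stack_copy(
    const std::vector<const float*>& inputs,  // ninputs equal-shape tensors
    float* out,
    size_t outer,
    size_t inner) {
  for (size_t i = 0; i < outer; ++i) {
    for (size_t j = 0; j < inputs.size(); ++j) {
      const float* in_ptr = inputs[j] + i * inner;
      for (size_t k = 0; k < inner; ++k) {
        *out++ = in_ptr[k];
      }
    }
  }
}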
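
Likewise, the removed stack_out_shape reduces to one rule: stacking N same-shape tensors at a (normalized) dim inserts a new axis of size N at that position. A sketch under the same illustrative types and limit as the earlier note; names here are placeholders, not the ExecuTorch definitions.

#include <array>
#include <cstddef>

constexpr size_t kStackDimLimit = 16;  // stand-in for kTensorDimensionLimit

// Inserts an axis of size ninputs at position dim (0 <= dim <= in_dim).
void stack_shape(
    const int* in_sizes, size_t in_dim,  // shape shared by all inputs
    size_t ninputs, size_t dim,          // dim already normalized
    std::array<int, kStackDimLimit>* out_sizes, size_t* out_dim) {
  *out_dim = in_dim + 1;
  for (size_t d = 0; d < *out_dim; ++d) {
    if (d < dim) {
      (*out_sizes)[d] = in_sizes[d];
    } else if (d == dim) {
      (*out_sizes)[d] = static_cast<int>(ninputs);  // the new stack axis
    } else {
      (*out_sizes)[d] = in_sizes[d - 1];
    }
  }
}

// e.g. three {2, 4} tensors stacked at dim = 1 yield {2, 3, 4}.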