Merged
128 changes: 128 additions & 0 deletions kernels/aten/cpu/op__clone_dim_order.cpp
@@ -0,0 +1,128 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = executorch::aten::Tensor;
using SizesArrayRef = executorch::aten::ArrayRef<executorch::aten::SizesType>;
using DimOrderArrayRef =
executorch::aten::ArrayRef<executorch::aten::DimOrderType>;
using MemoryFormat = executorch::aten::MemoryFormat;

template <typename T>
using OptionalArrayRef = executorch::aten::OptionalArrayRef<T>;

template <typename T>
using Optional = std::optional<T>;

namespace {
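// Maps the optional dim_order to a MemoryFormat. For a 4-D tensor,
// dim_order {0, 1, 2, 3} corresponds to Contiguous and {0, 2, 3, 1} to
// ChannelsLast; any other dim_order is rejected by the argument checks below.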
Optional<MemoryFormat> get_memory_format(OptionalArrayRef<int64_t> dim_order) {
if (!dim_order.has_value()) {
return executorch::aten::nullopt;
}
if (is_contiguous_dim_order(
dim_order.value().data(), dim_order.value().size())) {
return MemoryFormat::Contiguous;
} else if (is_channels_last_dim_order(
dim_order.value().data(), dim_order.value().size())) {
return MemoryFormat::ChannelsLast;
} else {
ET_ASSERT_UNREACHABLE();
}
}

bool check__clone_dim_order_args(
const Tensor& input,
bool non_blocking,
executorch::aten::OptionalArrayRef<int64_t> dim_order,
Tensor& out) {
// Right now, only blocking data transfers are supported.
ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false);

// Ensure input and output dtype match
ET_LOG_AND_RETURN_IF_FALSE(input.scalar_type() == out.scalar_type());

// If dim_order is set, it must describe either the contiguous or the
// channels_last memory format.
if (dim_order.has_value()) {
executorch::aten::ArrayRef<int64_t> dim_order_ref = dim_order.value();

// The dim_order length must equal the input's number of dimensions.
ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim());

ET_LOG_AND_RETURN_IF_FALSE(
is_channels_last_dim_order(
dim_order.value().data(), dim_order.value().size()) ||
is_contiguous_dim_order(
dim_order.value().data(), dim_order.value().size()));

// The output ATen tensor must have strides matching the given dim_order.
const size_t kMaxNumOfDimensions = 16;
ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim());
// Check the rank up front so the stride computation below stays in bounds.
ET_LOG_AND_RETURN_IF_FALSE(out.dim() == dim_order_ref.size());
executorch::aten::StridesType target_strides[kMaxNumOfDimensions];
dim_order_to_stride_nocheck(
out.sizes().data(),
dim_order_ref.data(),
dim_order_ref.size(),
target_strides);
for (size_t i = 0; i < dim_order_ref.size(); i++) {
ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]);
}

} else { // dim_order is not set; preserve the input's dim order.

auto out_strides = out.strides();
auto input_strides = input.strides();
ET_LOG_AND_RETURN_IF_FALSE(input_strides.size() == out_strides.size());
for (size_t i = 0; i < input_strides.size(); i++) {
ET_LOG_AND_RETURN_IF_FALSE(input_strides[i] == out_strides[i]);
}
}
return true;
}
} // namespace

// _clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]?
// dim_order=None, Tensor(a!) out) -> Tensor(a!)
Tensor& _clone_dim_order_out(
KernelRuntimeContext& ctx,
const Tensor& self,
bool non_blocking,
OptionalArrayRef<int64_t> dim_order,
Tensor& out) {
// TODO(T181345875): enable sanity check in aten mode
ET_KERNEL_CHECK(
ctx,
check__clone_dim_order_args(self, non_blocking, dim_order, out),
InvalidArgument,
out);

Optional<MemoryFormat> memory_format = get_memory_format(dim_order);
at::clone_outf(self, memory_format, out);

return out;
}

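// Overload without a KernelRuntimeContext; it constructs a default context
// and forwards to the variant above.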
Tensor& _clone_dim_order_out(
const Tensor& self,
bool non_blocking,
OptionalArrayRef<int64_t> dim_order,
Tensor& out) {
KernelRuntimeContext ctx{};
return _clone_dim_order_out(ctx, self, non_blocking, dim_order, out);
}

} // namespace native
} // namespace executor
} // namespace torch
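
For reference, a minimal usage sketch of the new out-variant (not part of this diff); it assumes ExecuTorch's TensorFactory test helper:

// Hypothetical example, assuming the TensorFactory test utility.
using executorch::runtime::testing::TensorFactory;
using executorch::aten::ScalarType;

TensorFactory<ScalarType::Float> tf;
Tensor input = tf.make({2, 2}, {1.0, 2.0, 3.0, 4.0});
Tensor out = tf.zeros({2, 2});
// Passing nullopt for dim_order preserves the input's dim order.
_clone_dim_order_out(input, /*non_blocking=*/false, executorch::aten::nullopt, out);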
6 changes: 6 additions & 0 deletions kernels/aten/cpu/targets.bzl
@@ -18,6 +18,12 @@ _EDGE_DIALECT_OPS = (
"//executorch/kernels/aten/cpu/util:copy_ops_util",
],
),
op_target(
name = "op__clone_dim_order",
deps = [
"//executorch/kernels/aten/cpu/util:copy_ops_util",
],
),
)

def define_common_targets():
5 changes: 5 additions & 0 deletions kernels/aten/edge_dialect_aten_op.yaml
@@ -11,3 +11,8 @@
kernels:
- arg_meta: null
kernel_name: torch::executor::_to_dim_order_copy_out

- func: dim_order_ops::_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: torch::executor::_clone_dim_order_out
3 changes: 0 additions & 3 deletions kernels/test/op__clone_dim_order_test.cpp
@@ -7,9 +7,6 @@
*/

#include <cstdint>
#include <map>
#include <typeindex>
#include <variant>

#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator.
#include <executorch/kernels/test/TestUtil.h>
2 changes: 1 addition & 1 deletion kernels/test/targets.bzl
@@ -177,7 +177,7 @@ def define_common_targets():

_common_op_test("op__to_dim_order_copy_test", ["aten", "portable"])
_common_op_test("op__empty_dim_order_test", ["aten", "portable"])
_common_op_test("op__clone_dim_order_test", ["portable"])
_common_op_test("op__clone_dim_order_test", ["aten", "portable"])
_common_op_test("op_abs_test", ["aten", "portable"])
_common_op_test("op_acos_test", ["aten", "portable"])
_common_op_test("op_acosh_test", ["aten", "portable"])
6 changes: 4 additions & 2 deletions shim_et/xplat/executorch/kernels/test/util.bzl
@@ -21,11 +21,13 @@ def op_test(name, deps = [], kernel_name = "portable", use_kernel_prefix = False
if kernel_name == "aten":
generated_lib_and_op_deps = [
"//executorch/kernels/aten:generated_lib",
#TODO(T187390274): consolidate all aten ops into one target
"//executorch/kernels/aten/cpu:op__to_dim_order_copy_aten",
"//executorch/kernels/aten:generated_lib_headers",
"//executorch/kernels/test:supported_features_aten",
]

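# For dim_order ops, depend on the per-op aten target; e.g. op_root
# "op__clone_dim_order" resolves to
# "//executorch/kernels/aten/cpu:op__clone_dim_order_aten".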
if "dim_order" in op_root:
generated_lib_and_op_deps.append("//executorch/kernels/aten/cpu:" + op_root + "_aten")

else:
generated_lib_and_op_deps = [
"//executorch/kernels/{}/cpu:{}".format(kernel_name, op_root),