diff --git a/kernels/aten/cpu/op__clone_dim_order.cpp b/kernels/aten/cpu/op__clone_dim_order.cpp
new file mode 100644
index 00000000000..5e6c35d64f9
--- /dev/null
+++ b/kernels/aten/cpu/op__clone_dim_order.cpp
@@ -0,0 +1,128 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = executorch::aten::Tensor;
+using SizesArrayRef = executorch::aten::ArrayRef<executorch::aten::SizesType>;
+using DimOrderArrayRef =
+    executorch::aten::ArrayRef<executorch::aten::DimOrderType>;
+using MemoryFormat = executorch::aten::MemoryFormat;
+
+template <typename T>
+using OptionalArrayRef = executorch::aten::OptionalArrayRef<T>;
+
+template <typename T>
+using Optional = std::optional<T>;
+
+namespace {
+Optional<MemoryFormat> get_memory_format(OptionalArrayRef<int64_t> dim_order) {
+  if (!dim_order.has_value()) {
+    return executorch::aten::nullopt;
+  }
+  if (is_contiguous_dim_order(
+          dim_order.value().data(), dim_order.value().size())) {
+    return MemoryFormat::Contiguous;
+  } else if (is_channels_last_dim_order(
+                 dim_order.value().data(), dim_order.value().size())) {
+    return MemoryFormat::ChannelsLast;
+  } else {
+    ET_ASSERT_UNREACHABLE();
+  }
+}
+
+bool check__clone_dim_order_args(
+    const Tensor& input,
+    bool non_blocking,
+    executorch::aten::OptionalArrayRef<int64_t> dim_order,
+    Tensor& out) {
+  // Right now we only support blocking data transfer.
+  ET_LOG_AND_RETURN_IF_FALSE(non_blocking == false);
+
+  // Ensure the input and output dtypes match.
+  ET_LOG_AND_RETURN_IF_FALSE(input.scalar_type() == out.scalar_type());
+
+  // If dim_order is set, the target dim order must describe either the
+  // contiguous or the channels_last memory format.
+  if (dim_order.has_value()) {
+    executorch::aten::ArrayRef<int64_t> dim_order_ref = dim_order.value();
+
+    // The dim_order size must equal the input rank.
+    ET_LOG_AND_RETURN_IF_FALSE(dim_order_ref.size() == input.dim());
+
+    ET_LOG_AND_RETURN_IF_FALSE(
+        is_channels_last_dim_order(
+            dim_order.value().data(), dim_order.value().size()) ||
+        is_contiguous_dim_order(
+            dim_order.value().data(), dim_order.value().size()));
+
+    // The out ATen tensor must have the strides implied by dim_order.
+    const size_t kMaxNumOfDimensions = 16;
+    ET_LOG_AND_RETURN_IF_FALSE(kMaxNumOfDimensions >= out.dim());
+    executorch::aten::StridesType target_strides[kMaxNumOfDimensions];
+    dim_order_to_stride_nocheck(
+        out.sizes().data(),
+        dim_order_ref.data(),
+        dim_order_ref.size(),
+        target_strides);
+    ET_LOG_AND_RETURN_IF_FALSE(out.dim() == dim_order_ref.size());
+    for (size_t i = 0; i < dim_order_ref.size(); i++) {
+      ET_LOG_AND_RETURN_IF_FALSE(target_strides[i] == out.strides()[i]);
+    }
+
+  } else { // dim_order is not set; preserve the dim order of the input.
+
+    auto out_strides = out.strides();
+    auto input_strides = input.strides();
+    ET_LOG_AND_RETURN_IF_FALSE(input_strides.size() == out_strides.size());
+    for (size_t i = 0; i < input_strides.size(); i++) {
+      ET_LOG_AND_RETURN_IF_FALSE(input_strides[i] == out_strides[i]);
+    }
+  }
+  return true;
+}
+} // namespace
+
+// _clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]?
+// dim_order=None, Tensor(a!) out) -> Tensor(a!)
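+//
+// ATen-mode note: the kernel below validates the arguments, maps the optional
+// dim_order to a MemoryFormat (Contiguous or ChannelsLast) via
+// get_memory_format(), and dispatches to at::clone_outf; when dim_order is
+// unset, the input's dim order is preserved.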
+Tensor& _clone_dim_order_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& self,
+    bool non_blocking,
+    OptionalArrayRef<int64_t> dim_order,
+    Tensor& out) {
+  // TODO(T181345875): enable sanity check in aten mode
+  ET_KERNEL_CHECK(
+      ctx,
+      check__clone_dim_order_args(self, non_blocking, dim_order, out),
+      InvalidArgument,
+      out);
+
+  Optional<MemoryFormat> memory_format = get_memory_format(dim_order);
+  at::clone_outf(self, memory_format, out);
+
+  return out;
+}
+
+Tensor& _clone_dim_order_out(
+    const Tensor& self,
+    bool non_blocking,
+    OptionalArrayRef<int64_t> dim_order,
+    Tensor& out) {
+  KernelRuntimeContext ctx{};
+  return _clone_dim_order_out(ctx, self, non_blocking, dim_order, out);
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/aten/cpu/targets.bzl b/kernels/aten/cpu/targets.bzl
index bb7083c1f01..e39bbdd144d 100644
--- a/kernels/aten/cpu/targets.bzl
+++ b/kernels/aten/cpu/targets.bzl
@@ -18,6 +18,12 @@ _EDGE_DIALECT_OPS = (
             "//executorch/kernels/aten/cpu/util:copy_ops_util",
         ],
     ),
+    op_target(
+        name = "op__clone_dim_order",
+        deps = [
+            "//executorch/kernels/aten/cpu/util:copy_ops_util",
+        ],
+    ),
 )
 
 def define_common_targets():
diff --git a/kernels/aten/edge_dialect_aten_op.yaml b/kernels/aten/edge_dialect_aten_op.yaml
index d9de3f6dded..1a74b3c71d1 100644
--- a/kernels/aten/edge_dialect_aten_op.yaml
+++ b/kernels/aten/edge_dialect_aten_op.yaml
@@ -11,3 +11,8 @@
   kernels:
     - arg_meta: null
      kernel_name: torch::executor::_to_dim_order_copy_out
+
+- func: dim_order_ops::_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::_clone_dim_order_out
diff --git a/kernels/test/op__clone_dim_order_test.cpp b/kernels/test/op__clone_dim_order_test.cpp
index d999897cdf3..f009ce1b195 100644
--- a/kernels/test/op__clone_dim_order_test.cpp
+++ b/kernels/test/op__clone_dim_order_test.cpp
@@ -7,9 +7,6 @@
  */
 
 #include
-#include
-#include
-#include
 #include
 // Declares the operator.
 #include
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index a4e681a7be1..7478f190185 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -177,7 +177,7 @@ def define_common_targets():
     _common_op_test("op__to_dim_order_copy_test", ["aten", "portable"])
     _common_op_test("op__empty_dim_order_test", ["aten", "portable"])
-    _common_op_test("op__clone_dim_order_test", ["portable"])
+    _common_op_test("op__clone_dim_order_test", ["aten", "portable"])
     _common_op_test("op_abs_test", ["aten", "portable"])
     _common_op_test("op_acos_test", ["aten", "portable"])
     _common_op_test("op_acosh_test", ["aten", "portable"])
diff --git a/shim_et/xplat/executorch/kernels/test/util.bzl b/shim_et/xplat/executorch/kernels/test/util.bzl
index cefb4fae6f0..0c702d12a18 100644
--- a/shim_et/xplat/executorch/kernels/test/util.bzl
+++ b/shim_et/xplat/executorch/kernels/test/util.bzl
@@ -21,11 +21,13 @@ def op_test(name, deps = [], kernel_name = "portable", use_kernel_prefix = False
     if kernel_name == "aten":
         generated_lib_and_op_deps = [
             "//executorch/kernels/aten:generated_lib",
-            #TODO(T187390274): consolidate all aten ops into one target
-            "//executorch/kernels/aten/cpu:op__to_dim_order_copy_aten",
             "//executorch/kernels/aten:generated_lib_headers",
             "//executorch/kernels/test:supported_features_aten",
         ]
+
+        if "dim_order" in op_root:
+            generated_lib_and_op_deps.append("//executorch/kernels/aten/cpu:" + op_root + "_aten")
+
     else:
         generated_lib_and_op_deps = [
             "//executorch/kernels/{}/cpu:{}".format(kernel_name, op_root),
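
For reference, here is a minimal usage sketch (not part of the diff) of the context-free overload added above, assuming an ATen-mode build where executorch::aten::Tensor aliases at::Tensor and the kernel's declaration is visible (e.g. via the aten generated_lib headers). The test name, shapes, and use of at::rand/at::empty/at::allclose are illustrative assumptions rather than existing ExecuTorch test code:

// Hypothetical sketch: exercising the ATen-mode kernel from a gtest case.
#include <vector>

#include <ATen/ATen.h>
#include <gtest/gtest.h>

TEST(CloneDimOrderAtenSketch, ContiguousToChannelsLast) {
  at::Tensor self = at::rand({2, 3, 4, 5}); // contiguous NCHW input
  // Channels-last dim order for a 4-D tensor is (N, H, W, C) = (0, 2, 3, 1).
  std::vector<int64_t> dim_order = {0, 2, 3, 1};
  // out must already carry the strides implied by dim_order; the argument
  // check in the kernel verifies this before cloning.
  at::Tensor out =
      at::empty({2, 3, 4, 5}, self.options(), at::MemoryFormat::ChannelsLast);

  torch::executor::native::_clone_dim_order_out(
      self, /*non_blocking=*/false, dim_order, out);

  // Values are preserved; only the physical layout of out differs.
  EXPECT_TRUE(at::allclose(self, out));
}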