diff --git a/kernels/portable/cpu/op_select_copy.cpp b/kernels/portable/cpu/op_select_copy.cpp
index 274f149096b..140eea062de 100644
--- a/kernels/portable/cpu/op_select_copy.cpp
+++ b/kernels/portable/cpu/op_select_copy.cpp
@@ -8,7 +8,7 @@
 
 #include <cstring>
 
-#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+#include <executorch/kernels/portable/cpu/util/select_copy_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -23,64 +23,9 @@ Tensor& select_copy_int_out(
     int64_t dim,
     int64_t index,
     Tensor& out) {
-  (void)ctx;
-
-  ET_KERNEL_CHECK(
-      ctx,
-      check_select_copy_out_args(in, dim, index, out),
-      InvalidArgument,
-      out);
-
-  if (dim < 0) {
-    dim += nonzero_dim(in);
-  }
-
-  Tensor::SizesType target_sizes[kTensorDimensionLimit];
-  size_t target_ndim = 0;
-  get_select_copy_out_target_size(in, dim, target_sizes, &target_ndim);
-
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_tensor(out, {target_sizes, target_ndim}) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  // If the input is a empty tensor, no other operation could be done. We just
-  // return the output.
-  if (in.numel() == 0) {
-    return out;
-  }
-  // The code past this point assumes that the tensors are non-empty.
-
-  // Support python-style negative indexing
-  if (index < 0) {
-    index += in.size(dim);
-  }
-
-  size_t leading_dims = getLeadingDims(in, dim);
-  size_t trailing_dims = getTrailingDims(in, dim);
-  size_t dim_length = in.size(dim);
-
-  // Number of bytes to copy in the each memcpy operation
-  size_t copy_size_per_op = trailing_dims * out.element_size();
-
-  // Step between the src locations of two adjcant memcpy operations
-  size_t src_step_per_op = dim_length * trailing_dims * in.element_size();
-
-  // the start point of data need to be copied is the start point of overall
-  // data chunk plus the offset between the overall start point and the first
-  // data to be copied.
-  char* input_data = in.mutable_data_ptr<char>();
-
-  size_t start_offset = index * trailing_dims * in.element_size();
-  char* src = input_data + start_offset;
-
-  char* dest = out.mutable_data_ptr<char>();
-
-  for (size_t j = 0; j < leading_dims; ++j) {
-    memcpy(dest, src, copy_size_per_op);
-    src += src_step_per_op;
-    dest += copy_size_per_op;
+  Error err = torch::executor::select_copy_util(in, dim, index, out);
+  if (err != Error::Ok) {
+    ctx.fail(err);
   }
   return out;
 }
diff --git a/kernels/portable/cpu/targets.bzl b/kernels/portable/cpu/targets.bzl
index b80e343f174..c86ac2b9f6d 100644
--- a/kernels/portable/cpu/targets.bzl
+++ b/kernels/portable/cpu/targets.bzl
@@ -794,6 +794,7 @@ _ATEN_OPS = (
         name = "op_select_copy",
         deps = [
             "//executorch/kernels/portable/cpu/util:copy_ops_util",
+            "//executorch/kernels/portable/cpu/util:select_copy_util",
         ],
     ),
     op_target(
diff --git a/kernels/portable/cpu/util/select_copy_util.cpp b/kernels/portable/cpu/util/select_copy_util.cpp
new file mode 100644
index 00000000000..cf56b3e4ca2
--- /dev/null
+++ b/kernels/portable/cpu/util/select_copy_util.cpp
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cstring>
+
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include "executorch/kernels/portable/cpu/util/select_copy_util.h"
+
+namespace torch {
+namespace executor {
+
+using Tensor = exec_aten::Tensor;
+
+Error select_copy_util(
+    const Tensor& in,
+    int64_t dim,
+    int64_t index,
+    Tensor& out) {
+  if (!check_select_copy_out_args(in, dim, index, out)) {
+    return Error::InvalidArgument;
+  }
+
+  if (dim < 0) {
+    dim += nonzero_dim(in);
+  }
+
+  Tensor::SizesType target_sizes[kTensorDimensionLimit];
+  size_t target_ndim = 0;
+  get_select_copy_out_target_size(in, dim, target_sizes, &target_ndim);
+
+  if (resize_tensor(out, {target_sizes, target_ndim}) != Error::Ok) {
+    return Error::InvalidArgument;
+  }
+
+  // If the input is an empty tensor, there is nothing else to do; just
+  // report success.
+  if (in.numel() == 0) {
+    return Error::Ok;
+  }
+  // The code past this point assumes that the tensors are non-empty.
+
+  // Support python-style negative indexing
+  if (index < 0) {
+    index += in.size(dim);
+  }
+
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_dims = getTrailingDims(in, dim);
+  size_t dim_length = in.size(dim);
+
+  // Number of bytes to copy in each memcpy operation
+  size_t copy_size_per_op = trailing_dims * out.element_size();
+
+  // Step between the src locations of two adjacent memcpy operations
+  size_t src_step_per_op = dim_length * trailing_dims * in.element_size();
+
+  // The first byte to copy is the start of the overall data chunk plus the
+  // offset of the selected slice (index * trailing_dims elements) within
+  // the first leading block.
+  char* input_data = in.mutable_data_ptr<char>();
+
+  size_t start_offset = index * trailing_dims * in.element_size();
+  char* src = input_data + start_offset;
+
+  char* dest = out.mutable_data_ptr<char>();
+
+  for (size_t j = 0; j < leading_dims; ++j) {
+    memcpy(dest, src, copy_size_per_op);
+    src += src_step_per_op;
+    dest += copy_size_per_op;
+  }
+
+  return Error::Ok;
+}
+
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/cpu/util/select_copy_util.h b/kernels/portable/cpu/util/select_copy_util.h
new file mode 100644
index 00000000000..4129f9d523a
--- /dev/null
+++ b/kernels/portable/cpu/util/select_copy_util.h
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+
+Error select_copy_util(
+    const Tensor& in,
+    int64_t dim,
+    int64_t index,
+    Tensor& out);
+
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index f7ca5bce920..a846d07e59c 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -191,6 +191,18 @@ def define_common_targets():
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/quantized/..."],
     )
 
+    runtime.cxx_library(
+        name = "select_copy_util",
+        srcs = ["select_copy_util.cpp"],
+        exported_headers = ["select_copy_util.h"],
+        deps = [
+            ":copy_ops_util",
+            "//executorch/runtime/kernel:kernel_includes",
+            "//executorch/runtime/core/exec_aten/util:tensor_util",
+        ],
+        visibility = ["//executorch/kernels/portable/cpu/..."],
+    )
+
     # Utility functions that can be used by operators that perform reduction
     for aten_mode in [True, False]:
         suffix = "_aten" if aten_mode else ""
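
Reviewer note: to make the copy arithmetic in `select_copy_util` easier to check, here is a minimal, self-contained sketch (not part of the patch; all names are illustrative) of the same leading/trailing-dims memcpy scheme applied to a plain contiguous buffer:

```cpp
#include <cstddef>
#include <cstring>
#include <iostream>

// Illustrative only: select index `idx` along dim 1 of a contiguous 2x3x4
// float "tensor", using the same leading/trailing-dims memcpy scheme as
// select_copy_util.
int main() {
  const std::size_t sizes[3] = {2, 3, 4};
  const std::size_t dim = 1;
  const std::size_t idx = 2;

  float in[2 * 3 * 4];
  for (std::size_t i = 0; i < 24; ++i) {
    in[i] = static_cast<float>(i);
  }
  float out[2 * 4];  // output shape {2, 4}: `dim` is removed

  // leading_dims: product of sizes before `dim`; trailing_dims: after it.
  const std::size_t leading_dims = sizes[0];   // 2
  const std::size_t trailing_dims = sizes[2];  // 4
  const std::size_t dim_length = sizes[dim];   // 3

  const std::size_t copy_size_per_op = trailing_dims * sizeof(float);
  const std::size_t src_step_per_op = dim_length * trailing_dims;  // elements

  const float* src = in + idx * trailing_dims;  // first selected slice
  float* dest = out;
  for (std::size_t j = 0; j < leading_dims; ++j) {
    std::memcpy(dest, src, copy_size_per_op);
    src += src_step_per_op;
    dest += trailing_dims;
  }

  for (std::size_t i = 0; i < 8; ++i) {
    std::cout << out[i] << ' ';  // prints: 8 9 10 11 20 21 22 23
  }
  std::cout << '\n';
  return 0;
}
```

Each of the `leading_dims` memcpy calls copies one contiguous slice of `trailing_dims` elements, and `src` advances by a full `dim_length * trailing_dims` block between copies, which is exactly the loop in the new utility.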
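And a hedged sketch of how the new utility's error-code contract might be exercised from a gtest, assuming ExecuTorch's test-only `TensorFactory` helper (header path and API assumed from the repo's existing kernel tests; verify before use):

```cpp
#include <executorch/kernels/portable/cpu/util/select_copy_util.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <gtest/gtest.h>

using exec_aten::ScalarType;
using exec_aten::Tensor;
using torch::executor::Error;
using torch::executor::testing::TensorFactory;

TEST(SelectCopyUtilTest, SelectsMiddleColumn) {
  TensorFactory<ScalarType::Float> tf;

  // 2x3 input; selecting index 1 along dim 1 should yield {2., 5.}.
  Tensor in = tf.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
  Tensor out = tf.zeros({2});

  Error err =
      torch::executor::select_copy_util(in, /*dim=*/1, /*index=*/1, out);
  EXPECT_EQ(err, Error::Ok);
  EXPECT_EQ(out.const_data_ptr<float>()[0], 2.0f);
  EXPECT_EQ(out.const_data_ptr<float>()[1], 5.0f);
}
```

Unlike the kernel entry point, the utility reports failures through its `Error` return value rather than `ET_KERNEL_CHECK`, which is what lets `op_select_copy.cpp` forward errors via `ctx.fail(err)`.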