diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index b71585ef9dd..1350fc090b0 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -259,6 +259,8 @@ - op: mul.Scalar_out +- op: narrow_copy.out + - op: native_batch_norm.out - op: native_group_norm.out diff --git a/kernels/portable/cpu/op_narrow_copy.cpp b/kernels/portable/cpu/op_narrow_copy.cpp new file mode 100644 index 00000000000..0c21ec5b901 --- /dev/null +++ b/kernels/portable/cpu/op_narrow_copy.cpp @@ -0,0 +1,62 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using Tensor = exec_aten::Tensor; + +Tensor& narrow_copy_out( + RuntimeContext& ctx, + const Tensor& in, + int64_t dim, + int64_t start, + int64_t length, + Tensor& out) { + (void)ctx; + + ET_KERNEL_CHECK( + ctx, + check_narrow_copy_args(in, dim, start, length, out), + InvalidArgument, + out); + + if (dim < 0) { + dim += in.dim(); + } + + // start may be negative (and is accepted by check_narrow_copy_args); convert it to a non-negative index before it is used as a raw offset by compute_slice. + if (start < 0) { + start += in.size(dim); + } + + // @lint-ignore CLANGTIDY facebook-hte-CArray + Tensor::SizesType target_sizes[kTensorDimensionLimit]; + size_t target_ndim = 0; + get_narrow_copy_out_target_size(in, dim, length, target_sizes, &target_ndim); + ET_KERNEL_CHECK( + ctx, + resize_tensor(out, {target_sizes, target_ndim}) == Error::Ok, + InvalidArgument, + out); + + if (length != 0) { + compute_slice(in, dim, start, length, 1, out); + } + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/op_slice_copy.cpp b/kernels/portable/cpu/op_slice_copy.cpp index d56bdcd864f..41a76567906 100644 --- a/kernels/portable/cpu/op_slice_copy.cpp +++ b/kernels/portable/cpu/op_slice_copy.cpp @@ -6,8 +6,7 @@ * LICENSE file in the root directory of this source tree. */ -#include -#include +#include #include #include @@ -41,40 +40,20 @@ Tensor& slice_copy_Tensor_out( // available) int64_t start = start_val.has_value() ?
start_val.value() : 0; - int64_t num_values = adjust_slice_indices(in.size(dim), &start, &end, step); + int64_t length = adjust_slice_indices(in.size(dim), &start, &end, step); + // @lint-ignore CLANGTIDY facebook-hte-CArray Tensor::SizesType target_sizes[kTensorDimensionLimit]; size_t target_ndim = 0; - get_slice_copy_out_target_size( - in, dim, num_values, target_sizes, &target_ndim); + get_slice_copy_out_target_size(in, dim, length, target_sizes, &target_ndim); ET_KERNEL_CHECK( ctx, resize_tensor(out, {target_sizes, target_ndim}) == Error::Ok, InvalidArgument, out); - size_t dim_length = in.size(dim); + compute_slice(in, dim, start, length, step, out); - size_t leading_dims = getLeadingDims(in, dim); - size_t trailing_dims = getTrailingDims(in, dim); - - if (trailing_dims == 0) { - return out; - } - - size_t length_per_step = trailing_dims * in.element_size(); - - const char* input_data = in.const_data_ptr(); - char* dest = out.mutable_data_ptr(); - - for (int i = 0; i < leading_dims; i++) { - const char* src = input_data + (i * dim_length + start) * length_per_step; - for (int j = 0; j < num_values; j++) { - memcpy(dest, src, length_per_step); - src += step * length_per_step; - dest += length_per_step; - } - } return out; } diff --git a/kernels/portable/cpu/op_slice_scatter.cpp b/kernels/portable/cpu/op_slice_scatter.cpp index 367b626696f..a1f9ce4d921 100644 --- a/kernels/portable/cpu/op_slice_scatter.cpp +++ b/kernels/portable/cpu/op_slice_scatter.cpp @@ -9,7 +9,7 @@ #include #include -#include +#include #include namespace torch { diff --git a/kernels/portable/cpu/util/copy_ops_util.cpp b/kernels/portable/cpu/util/copy_ops_util.cpp index bcd72d96a3b..61c07d71a4b 100644 --- a/kernels/portable/cpu/util/copy_ops_util.cpp +++ b/kernels/portable/cpu/util/copy_ops_util.cpp @@ -411,33 +411,6 @@ void get_select_copy_out_target_size( } } -bool check_slice_copy_args( - const Tensor& in, - int64_t dim, - int64_t step, - Tensor& out) { - ET_LOG_AND_RETURN_IF_FALSE(in.dim() > 0); - ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out)); - ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim)); - ET_LOG_MSG_AND_RETURN_IF_FALSE( - step > 0, "slice step must be greater than zero"); - return true; -} - -void get_slice_copy_out_target_size( - const Tensor& in, - int64_t dim, - int64_t num_values, - exec_aten::SizesType* out_sizes, - size_t* out_ndim) { - *out_ndim = in.dim(); - - for (size_t d = 0; d < in.dim(); ++d) { - out_sizes[d] = in.size(d); - } - out_sizes[dim] = num_values; -} - bool check_split_with_sizes_copy_args( const Tensor& in, exec_aten::ArrayRef split_sizes, diff --git a/kernels/portable/cpu/util/copy_ops_util.h b/kernels/portable/cpu/util/copy_ops_util.h index ef0fc9579bd..91c62e707e9 100644 --- a/kernels/portable/cpu/util/copy_ops_util.h +++ b/kernels/portable/cpu/util/copy_ops_util.h @@ -136,19 +136,6 @@ void get_select_copy_out_target_size( exec_aten::SizesType* out_sizes, size_t* out_ndim); -bool check_slice_copy_args( - const Tensor& in, - int64_t dim, - int64_t step, - Tensor& out); - -void get_slice_copy_out_target_size( - const Tensor& in, - int64_t dim, - int64_t num_values, - exec_aten::SizesType* out_sizes, - size_t* out_ndim); - bool check_split_with_sizes_copy_args( const Tensor& in, exec_aten::ArrayRef split_sizes, diff --git a/kernels/portable/cpu/util/index_util.cpp b/kernels/portable/cpu/util/index_util.cpp index b1c9696fd62..39c556fa01c 100644 --- a/kernels/portable/cpu/util/index_util.cpp +++ b/kernels/portable/cpu/util/index_util.cpp @@ -261,82 +261,5 @@ bool 
check_select_scatter_args( return true; } -bool check_slice_scatter_args( - const Tensor& input, - const Tensor& src, - int64_t dim, - int64_t num_values, - int64_t step, - Tensor output) { - ET_LOG_AND_RETURN_IF_FALSE(input.dim() > 0); - - // Check dim. The dim planed to be selected on shall exist in input - ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(dim, input.dim())); - - // Input and output tensors should be the same shape and dtype - ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_shape_and_dtype(input, output)); - - // The input.dim() shall equal to src.dim() - ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_rank(input, src)); - - // Check step. Step must be greater than zero - ET_LOG_MSG_AND_RETURN_IF_FALSE( - step > 0, "slice step must be greater than zero"); - - // The size of src tensor should follow these rules: - // - src.size(i) shall equal to input.size(i) if i != dim, - // - src.size(dim) shall equal to num_values - for (size_t d = 0; d < input.dim() - 1; d++) { - if (d != dim) { - ET_LOG_AND_RETURN_IF_FALSE( - tensors_have_same_size_at_dims(input, d, src, d)); - } else { - ET_LOG_MSG_AND_RETURN_IF_FALSE( - src.size(d) == num_values, - "input.size(%zu) %zd != num_values %" PRId64 " | dim = %" PRId64 ")", - d, - input.size(d), - num_values, - dim); - } - } - - return true; -} - -int64_t adjust_slice_indices( - int64_t dim_length, - int64_t* start, - int64_t* end, - int64_t step) { - int64_t num_values = 0; - - // Update start and end index - // First convert it to c++ style from python style if needed. - // The start index is using python style E.g., for the shape {2, 3, 4}, - // dim = -1 would refer to dim[2], dim = -2 would refer to dim[1], and so on. - *start = *start < 0 ? *start + dim_length : *start; - *end = *end < 0 ? *end + dim_length : *end; - // Second, if start or end still negative, which means user want to start or - // end slicing from very beginning, so set it to zero - *start = *start < 0 ? 0 : *start; - *end = *end < 0 ? 0 : *end; - // Last, if start or end larger than maximum value (dim_length - 1), indicates - // user want to start slicing after end or slicing until the end, so update it - // to dim_length - *start = *start > dim_length ? dim_length : *start; - *end = *end > dim_length ? dim_length : *end; - - if (*start >= dim_length || *end <= 0 || *start >= *end) { - // Set num_values to 0 if interval [start, end) is non-exist or do not - // overlap with [0, dim_length) - num_values = 0; - } else { - // Update num_values to min(max_num_values, num_values) - num_values = (*end - 1 - *start) / step + 1; - } - return num_values; -} - } // namespace executor } // namespace torch diff --git a/kernels/portable/cpu/util/index_util.h b/kernels/portable/cpu/util/index_util.h index 73d264a748c..0ee430c9726 100644 --- a/kernels/portable/cpu/util/index_util.h +++ b/kernels/portable/cpu/util/index_util.h @@ -64,19 +64,5 @@ bool check_select_scatter_args( int64_t index, Tensor& output); -bool check_slice_scatter_args( - const Tensor& input, - const Tensor& src, - int64_t dim, - int64_t num_values, - int64_t step, - Tensor output); - -int64_t adjust_slice_indices( - int64_t dim_length, - int64_t* start, - int64_t* end, - int64_t step); - } // namespace executor } // namespace torch diff --git a/kernels/portable/cpu/util/slice_util.cpp b/kernels/portable/cpu/util/slice_util.cpp new file mode 100644 index 00000000000..b9f5260e626 --- /dev/null +++ b/kernels/portable/cpu/util/slice_util.cpp @@ -0,0 +1,182 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace torch { +namespace executor { + +using Tensor = exec_aten::Tensor; + +bool check_narrow_copy_args( + const Tensor& in, + int64_t dim, + int64_t start, + int64_t length, + Tensor& out) { + ET_LOG_AND_RETURN_IF_FALSE(in.dim() > 0); + ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out)); + ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim)); + ET_LOG_MSG_AND_RETURN_IF_FALSE(length >= 0, "length must be non-negative"); + ET_LOG_AND_RETURN_IF_FALSE(start >= -in.size(dim)); + ET_LOG_AND_RETURN_IF_FALSE(start <= in.size(dim)); + if (start < 0) { + start += in.size(dim); + } + ET_LOG_AND_RETURN_IF_FALSE(start + length <= in.size(dim)); + return true; +} + +void get_narrow_copy_out_target_size( + const Tensor& in, + int64_t dim, + int64_t length, + exec_aten::SizesType* out_sizes, + size_t* out_ndim) { + *out_ndim = in.dim(); + + for (size_t d = 0; d < in.dim(); ++d) { + out_sizes[d] = in.size(d); + } + out_sizes[dim] = length; +} + +bool check_slice_copy_args( + const Tensor& in, + int64_t dim, + int64_t step, + Tensor& out) { + ET_LOG_AND_RETURN_IF_FALSE(in.dim() > 0); + ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out)); + ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim)); + ET_LOG_MSG_AND_RETURN_IF_FALSE( + step > 0, "slice step must be greater than zero"); + return true; +} + +void get_slice_copy_out_target_size( + const Tensor& in, + int64_t dim, + int64_t length, + exec_aten::SizesType* out_sizes, + size_t* out_ndim) { + get_narrow_copy_out_target_size(in, dim, length, out_sizes, out_ndim); +} + +bool check_slice_scatter_args( + const Tensor& input, + const Tensor& src, + int64_t dim, + int64_t num_values, + int64_t step, + Tensor output) { + ET_LOG_AND_RETURN_IF_FALSE(input.dim() > 0); + + // Check dim. The dim planed to be selected on shall exist in input + ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(dim, input.dim())); + + // Input and output tensors should be the same shape and dtype + ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_shape_and_dtype(input, output)); + + // The input.dim() shall equal to src.dim() + ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_rank(input, src)); + + // Check step. Step must be greater than zero + ET_LOG_MSG_AND_RETURN_IF_FALSE( + step > 0, "slice step must be greater than zero"); + + // The size of src tensor should follow these rules: + // - src.size(i) shall equal to input.size(i) if i != dim, + // - src.size(dim) shall equal to num_values + for (size_t d = 0; d < input.dim() - 1; d++) { + if (d != dim) { + ET_LOG_AND_RETURN_IF_FALSE( + tensors_have_same_size_at_dims(input, d, src, d)); + } else { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + src.size(d) == num_values, + "input.size(%zu) %zd != num_values %" PRId64 " | dim = %" PRId64 ")", + d, + input.size(d), + num_values, + dim); + } + } + + return true; +} + +int64_t adjust_slice_indices( + int64_t dim_length, + int64_t* start, + int64_t* end, + int64_t step) { + int64_t num_values = 0; + + // Update start and end index + // First convert it to c++ style from python style if needed. + // The start index is using python style E.g., for the shape {2, 3, 4}, + // dim = -1 would refer to dim[2], dim = -2 would refer to dim[1], and so on. + *start = *start < 0 ? *start + dim_length : *start; + *end = *end < 0 ?
*end + dim_length : *end; + // Second, if start or end still negative, which means user want to start or + // end slicing from very beginning, so set it to zero + *start = *start < 0 ? 0 : *start; + *end = *end < 0 ? 0 : *end; + // Last, if start or end larger than maximum value (dim_length - 1), indicates + // user want to start slicing after end or slicing until the end, so update it + // to dim_length + *start = *start > dim_length ? dim_length : *start; + *end = *end > dim_length ? dim_length : *end; + + if (*start >= dim_length || *end <= 0 || *start >= *end) { + // Set num_values to 0 if interval [start, end) is non-exist or do not + // overlap with [0, dim_length) + num_values = 0; + } else { + // Update num_values to min(max_num_values, num_values) + num_values = (*end - 1 - *start) / step + 1; + } + return num_values; +} + +void compute_slice( + const Tensor& in, + int64_t dim, + int64_t start, + int64_t length, + int64_t step, + Tensor& out) { + size_t dim_length = in.size(dim); + + size_t leading_dims = getLeadingDims(in, dim); + size_t trailing_dims = getTrailingDims(in, dim); + + if (trailing_dims == 0) { + return; + } + + size_t length_per_step = trailing_dims * in.element_size(); + + const char* input_data = in.const_data_ptr(); + char* dest = out.mutable_data_ptr(); + + for (int i = 0; i < leading_dims; i++) { + const char* src = input_data + (i * dim_length + start) * length_per_step; + for (int j = 0; j < length; j++) { + memcpy(dest, src, length_per_step); + src += step * length_per_step; + dest += length_per_step; + } + } +} + +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/util/slice_util.h b/kernels/portable/cpu/util/slice_util.h new file mode 100644 index 00000000000..734f0dd3c6d --- /dev/null +++ b/kernels/portable/cpu/util/slice_util.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace torch { +namespace executor { + +bool check_narrow_copy_args( + const Tensor& in, + int64_t dim, + int64_t start, + int64_t length, + Tensor& out); + +void get_narrow_copy_out_target_size( + const Tensor& in, + int64_t dim, + int64_t length, + exec_aten::SizesType* out_sizes, + size_t* out_ndim); + +bool check_slice_copy_args( + const Tensor& in, + int64_t dim, + int64_t step, + Tensor& out); + +void get_slice_copy_out_target_size( + const Tensor& in, + int64_t dim, + int64_t num_values, + exec_aten::SizesType* out_sizes, + size_t* out_ndim); + +bool check_slice_scatter_args( + const Tensor& input, + const Tensor& src, + int64_t dim, + int64_t num_values, + int64_t step, + Tensor output); + +int64_t adjust_slice_indices( + int64_t dim_length, + int64_t* start, + int64_t* end, + int64_t step); + +void compute_slice( + const Tensor& in, + int64_t dim, + int64_t start, + int64_t length, + int64_t step, + Tensor& out); + +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl index bd55b4da304..3961add0fd7 100644 --- a/kernels/portable/cpu/util/targets.bzl +++ b/kernels/portable/cpu/util/targets.bzl @@ -29,6 +29,7 @@ def define_common_targets(): "//executorch/kernels/portable/cpu/util:distance_util", "//executorch/kernels/portable/cpu/util:select_copy_util", "//executorch/kernels/portable/cpu/util:advanced_index_util", + "//executorch/kernels/portable/cpu/util:slice_util", ], visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"], ) @@ -226,6 +227,16 @@ def define_common_targets(): visibility = ["//executorch/kernels/portable/cpu/..."], ) + runtime.cxx_library( + name = "slice_util", + srcs = ["slice_util.cpp"], + exported_headers = ["slice_util.h"], + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = ["//executorch/kernels/portable/cpu/..."], + ) + # Utility functions that can be used by operators that perform reduction for aten_mode in [True, False]: suffix = "_aten" if aten_mode else "" diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index 69e0334051c..5136ea0a12f 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -587,6 +587,11 @@ - arg_meta: null kernel_name: torch::executor::mul_scalar_out +- op: narrow_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::narrow_copy_out + - op: native_group_norm.out kernels: - arg_meta: null diff --git a/kernels/test/op_narrow_copy_test.cpp b/kernels/test/op_narrow_copy_test.cpp new file mode 100644 index 00000000000..e453e46500a --- /dev/null +++ b/kernels/test/op_narrow_copy_test.cpp @@ -0,0 +1,197 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include // Declares the operator +#include +#include +#include +#include + +#include + +using namespace ::testing; +using exec_aten::ArrayRef; +using exec_aten::optional; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::testing::TensorFactory; + +class OpNarrowCopyOutTest : public OperatorTest { + protected: + Tensor& op_narrow_copy_out( + const Tensor& in, + int64_t dim, + int64_t start, + int64_t length, + Tensor& out) { + return torch::executor::aten::narrow_copy_outf( + context_, in, dim, start, length, out); + } + + template + void test_dtype() { + TensorFactory tf; + + // clang-format off + Tensor input = tf.make( + /*sizes=*/{3, 4}, + /*data=*/{ + 1, 2, 3, 4, // [0, :] + 5, 6, 7, 8, // [1, :] + 9, 10, 11, 12, // [2, :] + }); + + Tensor expected = tf.make( + /*sizes=*/{2, 4}, + /*data=*/{ + 1, 2, 3, 4, // [0, :] + 5, 6, 7, 8, // [1, :] + }); + // clang-format on + + Tensor out = tf.zeros({2, 4}); + Tensor ret = + op_narrow_copy_out(input, /*dim=*/0, /*start=*/0, /*length=*/2, out); + + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_EQ(out, expected); + } +}; + +TEST_F(OpNarrowCopyOutTest, AllDtypesSupported) { +#define TEST_ENTRY(ctype, dtype) test_dtype(); + ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY); +#undef TEST_ENTRY +} + +TEST_F(OpNarrowCopyOutTest, EmptyInputSupported) { + TensorFactory tf; + + Tensor input = tf.ones({1, 0, 1}); + Tensor out = tf.zeros({1, 0, 1}); + + Tensor expect = tf.ones({1, 0, 1}); + + Tensor ret = + op_narrow_copy_out(input, /*dim=*/0, /*start=*/0, /*length=*/1, out); + EXPECT_TENSOR_EQ(ret, out); + EXPECT_TENSOR_EQ(ret, expect); + + ret = op_narrow_copy_out(input, /*dim=*/1, /*start=*/0, /*length=*/0, out); + EXPECT_TENSOR_EQ(ret, out); + EXPECT_TENSOR_EQ(ret, expect); + + ret = op_narrow_copy_out(input, /*dim=*/2, /*start=*/0, /*length=*/1, out); + EXPECT_TENSOR_EQ(ret, out); + EXPECT_TENSOR_EQ(ret, expect); +} + +TEST_F(OpNarrowCopyOutTest, ZeroLengthSupported) { + TensorFactory tf; + + Tensor input = tf.ones({2, 3}); + Tensor out = tf.ones({2, 0}); + + Tensor expect = tf.ones({2, 0}); + + Tensor ret = + op_narrow_copy_out(input, /*dim=*/1, /*start=*/1, /*length=*/0, out); + EXPECT_TENSOR_EQ(ret, out); + EXPECT_TENSOR_EQ(ret, expect); + + ret = op_narrow_copy_out(input, /*dim=*/1, /*start=*/-1, /*length=*/0, out); + EXPECT_TENSOR_EQ(ret, out); + EXPECT_TENSOR_EQ(ret, expect); +} + +TEST_F(OpNarrowCopyOutTest, ZeroDimInputDies) { + TensorFactory tf; + + Tensor input = tf.ones({}); + Tensor out = tf.ones({}); + + // The operation shall die whatever the end is. 
+ ET_EXPECT_KERNEL_FAILURE( + context_, + op_narrow_copy_out(input, /*dim=*/0, /*start=*/0, /*length=*/0, out)); + ET_EXPECT_KERNEL_FAILURE( + context_, + op_narrow_copy_out(input, /*dim=*/0, /*start=*/1, /*length=*/1, out)); +} + +TEST_F(OpNarrowCopyOutTest, InvalidStart) { + TensorFactory tf; + + Tensor input = tf.ones({2, 3}); + Tensor out = tf.ones({2, 3}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_narrow_copy_out(input, /*dim=*/0, /*start=*/-3, /*length=*/0, out)); + ET_EXPECT_KERNEL_FAILURE( + context_, + op_narrow_copy_out(input, /*dim=*/1, /*start=*/4, /*length=*/0, out)); +} + +TEST_F(OpNarrowCopyOutTest, InvalidStartLengthCombination) { + TensorFactory tf; + + Tensor input = tf.ones({2, 3}); + Tensor out = tf.ones({2, 3}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_narrow_copy_out(input, /*dim=*/0, /*start=*/0, /*length=*/3, out)); + ET_EXPECT_KERNEL_FAILURE( + context_, + op_narrow_copy_out(input, /*dim=*/1, /*start=*/-1, /*length=*/2, out)); +} + +TEST_F(OpNarrowCopyOutTest, NegativeLengthDies) { + TensorFactory tf; + + Tensor input = tf.ones({1, 1, 1}); + Tensor out = tf.zeros({1, 1, 1}); + + // Some invalid length values. + const std::vector invalid_lengths = {-3, -2, -1}; + for (int64_t length : invalid_lengths) { + ET_EXPECT_KERNEL_FAILURE( + context_, + op_narrow_copy_out( + input, /*dim=*/0, /*start=*/0, /*length=*/length, out)); + } +} + +TEST_F(OpNarrowCopyOutTest, DimOutOfBoundDies) { + TensorFactory tf; + + Tensor input = tf.ones({1, 1, 1}); + Tensor out = tf.zeros({1, 1, 1}); + + // Some invalid dim values. + const std::vector invalid_dims = {3, 4, 5, -4, -5, -6}; + for (int64_t dim : invalid_dims) { + ET_EXPECT_KERNEL_FAILURE( + context_, + op_narrow_copy_out(input, dim, /*start=*/0, /*length=*/1, out)); + } +} + +TEST_F(OpNarrowCopyOutTest, MismatchedDtypesDies) { + TensorFactory tf_int; + TensorFactory tf_float; + Tensor input = tf_int.zeros({1, 2, 2}); + + // Size is compatible to the output, but a mismatched dtype. 
+ Tensor out = tf_float.ones({1, 2, 2}); + + ET_EXPECT_KERNEL_FAILURE( + context_, + op_narrow_copy_out(input, /*dim=*/0, /*start=*/0, /*length=*/1, out)); +} diff --git a/kernels/test/op_slice_copy_test.cpp b/kernels/test/op_slice_copy_test.cpp index 4c04e4bf51c..9aaf6f18dbc 100644 --- a/kernels/test/op_slice_copy_test.cpp +++ b/kernels/test/op_slice_copy_test.cpp @@ -475,6 +475,25 @@ TEST_F(OpSliceCopyTensorOutTest, EmptySizeInputDies) { input, /*dim=*/0, /*start=*/0, /*end=*/1, /*step=*/1, out)); } +TEST_F(OpSliceCopyTensorOutTest, ZeroLengthSupported) { + TensorFactory tf; + + Tensor input = tf.ones({2, 3}); + Tensor out = tf.ones({2, 0}); + + Tensor expect = tf.ones({2, 0}); + + Tensor ret = op_slice_copy_tensor_out( + input, /*dim=*/1, /*start=*/1, /*end=*/1, /*step=*/1, out); + EXPECT_TENSOR_EQ(ret, out); + EXPECT_TENSOR_EQ(ret, expect); + + ret = op_slice_copy_tensor_out( + input, /*dim=*/1, /*start=*/-1, /*end=*/-1, /*step=*/1, out); + EXPECT_TENSOR_EQ(ret, out); + EXPECT_TENSOR_EQ(ret, expect); +} + TEST_F(OpSliceCopyTensorOutTest, NonPostiveStepsDies) { TensorFactory tf; diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index 749a221f9c0..7ae17c5237a 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -246,6 +246,7 @@ def define_common_targets(): _common_op_test("op_minimum_test", ["aten", "portable"]) _common_op_test("op_mm_test", ["aten", "portable"]) _common_op_test("op_mul_test", ["aten", "portable", "optimized"]) + _common_op_test("op_narrow_copy_test", ["aten", "portable"]) _common_op_test("op_native_batch_norm_test", ["aten", "portable"]) _common_op_test("op_native_group_norm_test", ["aten", "portable"]) _common_op_test("op_native_layer_norm_test", ["aten", "portable", "optimized"]) diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index b56f40c0215..ef8f936571c 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -808,6 +808,12 @@ ATEN_OPS = ( ":scalar_utils", ], ), + op_target( + name = "op_narrow_copy", + deps = [ + "//executorch/kernels/portable/cpu/util:slice_util", + ], + ), op_target( name = "op_native_batch_norm", deps = [ @@ -1042,14 +1048,13 @@ ATEN_OPS = ( op_target( name = "op_slice_copy", deps = [ - "//executorch/kernels/portable/cpu/util:copy_ops_util", - "//executorch/kernels/portable/cpu/util:index_util", + "//executorch/kernels/portable/cpu/util:slice_util", ], ), op_target( name = "op_slice_scatter", deps = [ - "//executorch/kernels/portable/cpu/util:index_util", + "//executorch/kernels/portable/cpu/util:slice_util", ], ), op_target(