diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index f28cfb48b36..98208bd17a9 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -321,6 +321,8 @@ - op: scalar_tensor.out +- op: scatter.src_out + - op: scatter.value_out - op: scatter_add.out diff --git a/kernels/portable/cpu/op_scatter.cpp b/kernels/portable/cpu/op_scatter.cpp index 9696ab4f14d..0a2fee9a61e 100644 --- a/kernels/portable/cpu/op_scatter.cpp +++ b/kernels/portable/cpu/op_scatter.cpp @@ -23,6 +23,46 @@ using ScalarType = exec_aten::ScalarType; namespace { +template +void scatter_src_helper( + const Tensor& in, + int64_t dim, + const Tensor& index, + const Tensor& src, + Tensor& out) { + const CTYPE* in_data = in.const_data_ptr(); + const long* index_data = index.const_data_ptr(); + const CTYPE* src_data = src.const_data_ptr(); + CTYPE* out_data = out.mutable_data_ptr(); + + memcpy(out_data, in_data, in.nbytes()); + + if (dim < 0) { + dim += nonzero_dim(in); + } + + for (size_t ix = 0; ix < index.numel(); ++ix) { + // @lint-ignore CLANGTIDY facebook-hte-CArray + size_t ix_coord[kTensorDimensionLimit]; + indexToCoordinate(index, ix, ix_coord); + + size_t src_ix = coordinateToIndex(src, ix_coord); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + size_t out_coord[kTensorDimensionLimit]; + for (size_t i = 0; i < out.dim(); ++i) { + if (i == dim) { + out_coord[i] = index_data[ix]; + } else { + out_coord[i] = ix_coord[i]; + } + } + size_t out_ix = coordinateToIndex(out, out_coord); + + out_data[out_ix] = src_data[src_ix]; + } +} + template void scatter_value_helper( const Tensor& in, @@ -36,15 +76,16 @@ void scatter_value_helper( memcpy(out_data, in_data, in.nbytes()); - if (index.dim() == 0) { - out_data[index_data[0]] = static_cast(val); - return; + if (dim < 0) { + dim += nonzero_dim(in); } for (size_t ix = 0; ix < index.numel(); ++ix) { + // @lint-ignore CLANGTIDY facebook-hte-CArray size_t ix_coord[kTensorDimensionLimit]; indexToCoordinate(index, ix, ix_coord); + // @lint-ignore CLANGTIDY facebook-hte-CArray size_t out_coord[kTensorDimensionLimit]; for (size_t i = 0; i < out.dim(); ++i) { if (i == dim) { @@ -61,6 +102,36 @@ void scatter_value_helper( } // namespace +Tensor& scatter_src_out( + RuntimeContext& context, + const Tensor& in, + int64_t dim, + const Tensor& index, + const Tensor& src, + Tensor& out) { + (void)context; + + ET_KERNEL_CHECK( + context, + check_scatter_src_args(in, dim, index, src, out), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + context, + resize_tensor(out, in.sizes()) == Error::Ok, + InvalidArgument, + out); + + constexpr auto name = "scatter.src_out"; + + ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { + scatter_src_helper(in, dim, index, src, out); + }); + + return out; +} + Tensor& scatter_value_out( RuntimeContext& ctx, const Tensor& in, @@ -79,10 +150,6 @@ Tensor& scatter_value_out( ET_KERNEL_CHECK( ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); - if (dim < 0) { - dim += nonzero_dim(in); - } - ScalarType val_type = utils::get_scalar_dtype(value); constexpr auto name = "scatter.value_out"; diff --git a/kernels/portable/cpu/util/index_util.cpp b/kernels/portable/cpu/util/index_util.cpp index ca9900773a1..b1c9696fd62 100644 --- a/kernels/portable/cpu/util/index_util.cpp +++ b/kernels/portable/cpu/util/index_util.cpp @@ -191,6 +191,15 @@ bool check_scatter_add_args( return true; } +bool check_scatter_src_args( + const Tensor& self, + int64_t dim, + const Tensor& index, + const Tensor& src, + Tensor& out) { + return check_scatter_add_args(self, dim, index, src, out); +} + bool check_scatter_value_args( const Tensor& self, int64_t dim, diff --git a/kernels/portable/cpu/util/index_util.h b/kernels/portable/cpu/util/index_util.h index ae6654be52b..73d264a748c 100644 --- a/kernels/portable/cpu/util/index_util.h +++ b/kernels/portable/cpu/util/index_util.h @@ -43,6 +43,13 @@ bool check_scatter_add_args( const Tensor& src, Tensor& out); +bool check_scatter_src_args( + const Tensor& self, + int64_t dim, + const Tensor& index, + const Tensor& src, + Tensor& out); + bool check_scatter_value_args( const Tensor& self, int64_t dim, diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml index 21258329aa8..d3da47c48d6 100644 --- a/kernels/portable/functions.yaml +++ b/kernels/portable/functions.yaml @@ -737,6 +737,11 @@ - arg_meta: null kernel_name: torch::executor::scalar_tensor_out +- op: scatter.src_out + kernels: + - arg_meta: null + kernel_name: torch::executor::scatter_src_out + - op: scatter.value_out kernels: - arg_meta: null diff --git a/kernels/test/op_scatter_test.cpp b/kernels/test/op_scatter_test.cpp index 2335c839d00..83c112a8c34 100644 --- a/kernels/test/op_scatter_test.cpp +++ b/kernels/test/op_scatter_test.cpp @@ -22,6 +22,189 @@ using exec_aten::ScalarType; using exec_aten::Tensor; using torch::executor::testing::TensorFactory; +class OpScatterSrcOutTest : public OperatorTest { + protected: + Tensor& op_scatter_src_out( + const Tensor& self, + int64_t dim, + const Tensor& index, + const Tensor& src, + Tensor& out) { + return torch::executor::aten::scatter_outf( + context_, self, dim, index, src, out); + } + + // Common testing for the operator + template + void test_scatter_src_out() { + TensorFactory tf_index; + TensorFactory tf_data; + const std::vector sizes = {3, 5}; + // clang-format off + Tensor src = tf_data.make( + /*sizes=*/{2, 5}, + { + 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10 + }); + // clang-format on + Tensor in = tf_data.zeros(sizes); + Tensor out = tf_data.zeros(sizes); + // clang-format off + Tensor index = tf_index.make( + /*sizes=*/{2, 3}, + { + 0, 1, 2, + 0, 1, 2 + }); + // clang-format on + + // Valid input should give the expected output + op_scatter_src_out(in, 0, index, src, out); + // clang-format off + EXPECT_TENSOR_EQ( + out, tf_data.make( + sizes, + { + 6, 0, 0, 0, 0, + 0, 7, 0, 0, 0, + 0, 0, 8, 0, 0 + })); + // clang-format on + + // Valid input should give the expected output + op_scatter_src_out(in, 1, index, src, out); + // clang-format off + EXPECT_TENSOR_EQ( + out, tf_data.make(sizes, + { + 1, 2, 3, 0, 0, + 6, 7, 8, 0, 0, + 0, 0, 0, 0, 0 + })); + + src = tf_data.make( + /*sizes=*/{2, 3, 3}, + { + // [0, :, :] + 1, 2, 3, + 4, 5, 6, + 7, 8, 9, + + // [1, :, :] + 10, 11, 12, + 13, 14, 15, + 16, 17, 18 + }); + // clang-format on + in = tf_data.ones(/*sizes=*/{2, 3, 3}); + out = tf_data.zeros(/*sizes=*/{2, 3, 3}); + // clang-format off + index = tf_index.make( + /*sizes=*/{1, 3, 2}, + { + 0, 1, + 1, 2, + 0, 2 + }); + // clang-format on + + op_scatter_src_out(in, 1, index, src, out); + // clang-format off + EXPECT_TENSOR_EQ( + out, + tf_data.make( + /*sizes=*/{2, 3, 3}, + { + // [0, :, :] + 7, 1, 1, + 4, 2, 1, + 1, 8, 1, + + // [1, :, :] + 1, 1, 1, + 1, 1, 1, + 1, 1, 1 + })); + // clang-format on + + out = tf_data.zeros(/*sizes=*/{2, 3, 3}); + op_scatter_src_out(in, 2, index, src, out); + // clang-format off + EXPECT_TENSOR_EQ( + out, + tf_data.make( + /*sizes=*/{2, 3, 3}, + { + // [0, :, :] + 1, 2, 1, + 1, 4, 5, + 7, 1, 8, + + // [1, :, :] + 1, 1, 1, + 1, 1, 1, + 1, 1, 1 + })); + // clang-format on + } + + // Invalid dimensions + template + void test_scatter_src_out_invalid_dim() { + TensorFactory tf_index; + TensorFactory tf_data; + const std::vector sizes = {3, 5}; + // clang-format off + Tensor src = tf_data.make(/*sizes=*/{2, 5}, + { + 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10 + }); + Tensor index = tf_index.make(/*sizes=*/{2, 3}, + { + 0, 1, 2, + 0, 1, 2 + }); + // clang-format on + Tensor self = tf_data.zeros(sizes); + Tensor out = tf_data.zeros(sizes); + + // Invalid dim should die + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_src_out(self, -3, index, src, out)); + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_src_out(self, 2, index, src, out)); + + // Self, index and src hsould have same number of dimensions + src = tf_data.zeros(/*sizes=*/{2, 2, 2}); + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_src_out(self, 0, index, src, out)); + + src = tf_data.zeros(/*sizes=*/{5, 5}); + index = tf_index.zeros(/*sizes=*/{2, 2, 2}); + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_src_out(self, 0, index, src, out)); + + // Size of dimension of index should be smaller than the size of that + // dimension of src + index = tf_index.zeros(/*sizes=*/{4, 6}); + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_src_out(self, 0, index, src, out)); + + // Size of dimension of index should be smaller than the size of that + // dimension of self if dimension != dim + index = tf_index.zeros(/*sizes=*/{4, 5}); + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_src_out(self, 1, index, src, out)); + + // Index out of bound for self in dim + index = tf_index.make(/*sizes=*/{2, 3}, {0, 1, 3, 0, 1, 3}); + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_src_out(self, 0, index, src, out)); + } +}; + class OpScatterValueOutTest : public OperatorTest { protected: Tensor& op_scatter_value_out( @@ -183,6 +366,19 @@ class OpScatterValueOutTest : public OperatorTest { } }; +TEST_F(OpScatterSrcOutTest, AllValidInputOutputSupport) { +#define TEST_ENTRY(CTYPE, DTYPE) test_scatter_src_out(); + ET_FORALL_REAL_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} + +TEST_F(OpScatterSrcOutTest, InvalidDimensionsDies) { +#define TEST_ENTRY(CTYPE, DTYPE) \ + test_scatter_src_out_invalid_dim(); + ET_FORALL_REAL_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} + TEST_F(OpScatterValueOutTest, AllValidInputOutputSupport) { #define TEST_ENTRY(CTYPE, DTYPE) test_scatter_value_out(); ET_FORALL_REAL_TYPES(TEST_ENTRY); @@ -360,3 +556,99 @@ TEST_F(OpScatterValueOutTest, InvalidOneDimInputAndZeroDimIndex) { ET_EXPECT_KERNEL_FAILURE( context_, op_scatter_value_out(self, 0, index, value, out)); } + +TEST_F(OpScatterSrcOutTest, EmptyIndex) { + TensorFactory tf_index; + TensorFactory tf_data; + + Tensor self = tf_data.ones({2, 5}); + Tensor index = tf_index.zeros({2, 0, 3}); + Tensor src = tf_data.ones({1, 1, 4}); + Tensor out = tf_data.zeros({2, 5}); + op_scatter_src_out(self, 0, index, src, out); + EXPECT_TENSOR_CLOSE(out, tf_data.ones({2, 5})); +} + +TEST_F(OpScatterSrcOutTest, ValidZeroDim) { + TensorFactory tf_index; + TensorFactory tf_data; + + Tensor self = tf_data.make({}, {3.14}); + Tensor index = tf_index.zeros({}); + Tensor src = tf_data.make({}, {5}); + Tensor out = tf_data.zeros({}); + op_scatter_src_out(self, 0, index, src, out); + EXPECT_TENSOR_CLOSE(out, tf_data.make({}, {5})); +} + +TEST_F(OpScatterSrcOutTest, InvalidZeroDimInput) { + TensorFactory tf_index; + TensorFactory tf_data; + + Tensor self = tf_data.ones({}); + Tensor index = tf_index.make({2, 3}, {0, 0, 0, 0, 0, 0}); + Tensor src = tf_data.make({}, {5}); + Tensor out = tf_data.zeros({}); + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_src_out(self, 0, index, src, out)); +} + +TEST_F(OpScatterSrcOutTest, InvalidZeroDimIndex) { + TensorFactory tf_index; + TensorFactory tf_data; + + Tensor self = tf_data.make({2, 3}, {1, 2, 3, 4, 5, 6}); + Tensor index = tf_index.make({}, {2}); + Tensor src = tf_data.make({}, {5}); + Tensor out = tf_data.zeros({2, 3}); + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_src_out(self, 1, index, src, out)); +} + +TEST_F(OpScatterSrcOutTest, ValidZeroDimInputAndOneDimIndex) { + TensorFactory tf_index; + TensorFactory tf_data; + + Tensor self = tf_data.make({}, {3.14}); + Tensor index = tf_index.make({3}, {0, 0, 0}); + Tensor src = tf_data.make({3}, {5, 5, 5}); + Tensor out = tf_data.make({}, {2.71}); + op_scatter_src_out(self, 0, index, src, out); + EXPECT_TENSOR_CLOSE(out, tf_data.make({}, {5})); +} + +TEST_F(OpScatterSrcOutTest, ValidOneDimInputAndZeroDimIndex) { + TensorFactory tf_index; + TensorFactory tf_data; + + Tensor self = tf_data.make({3}, {10, 20, 30}); + Tensor index = tf_index.make({}, {2}); + Tensor src = tf_data.make({}, {5}); + Tensor out = tf_data.make({3}, {1729, 1729, 1729}); + op_scatter_src_out(self, 0, index, src, out); + EXPECT_TENSOR_CLOSE(out, tf_data.make({3}, {10, 20, 5})); +} + +TEST_F(OpScatterSrcOutTest, InvalidZeroDimInputAndOneDimIndex) { + TensorFactory tf_index; + TensorFactory tf_data; + + Tensor self = tf_data.make({}, {3.14}); + Tensor index = tf_index.make({3}, {10, 100, 1000}); + Tensor src = tf_data.make({}, {5}); + Tensor out = tf_data.make({}, {2.71}); + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_src_out(self, 0, index, src, out)); +} + +TEST_F(OpScatterSrcOutTest, InvalidOneDimInputAndZeroDimIndex) { + TensorFactory tf_index; + TensorFactory tf_data; + + Tensor self = tf_data.make({3}, {10, 20, 30}); + Tensor index = tf_index.make({}, {100}); + Tensor src = tf_data.make({}, {5}); + Tensor out = tf_data.make({3}, {1729, 1729, 1729}); + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_src_out(self, 0, index, src, out)); +}