diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index c93042ca124..62012645b0c 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -4140,10 +4140,9 @@ TEST_F(AtenXlaTensorTest, TestScatter) {
       torch::Tensor xla_d = torch::scatter(xla_a, dim, xla_c, xla_b);
       AllClose(d, xla_d);
     });
-
-    ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-    ExpectCounterChanged("xla::scatter_", cpp_test::GetIgnoredCounters());
   }
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestScatterR1) {
@@ -4162,7 +4161,7 @@ TEST_F(AtenXlaTensorTest, TestScatterR1) {
   });
 
   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::scatter_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestScatterR3) {
@@ -4186,7 +4185,7 @@ TEST_F(AtenXlaTensorTest, TestScatterR3) {
   });
 
   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::scatter_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestScatterBiggerSource) {
@@ -4207,10 +4206,10 @@ TEST_F(AtenXlaTensorTest, TestScatterBiggerSource) {
       torch::Tensor xla_d = torch::scatter(xla_a, dim, xla_c, xla_b);
       AllClose(d, xla_d);
     });
-
-    ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-    ExpectCounterChanged("xla::scatter_", cpp_test::GetIgnoredCounters());
   }
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestScatterScalar) {
@@ -4230,10 +4229,34 @@ TEST_F(AtenXlaTensorTest, TestScatterScalar) {
       torch::Tensor xla_d = torch::scatter(xla_a, dim, xla_c, b);
       AllClose(d, xla_d);
     });
+  }
 
-    ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-    ExpectCounterChanged("xla::scatter_", cpp_test::GetIgnoredCounters());
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
+}
+
+TEST_F(AtenXlaTensorTest, TestScatterReduceAdd) {
+  torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong));
+  for (int dim = 0; dim < 2; ++dim) {
+    for (int i = 0; i < 3; i++) {
+      for (int j = 0; j < 5; j++) {
+        c[i][j] = (i + j) % c.sizes()[dim];
+      }
+    }
+    torch::Tensor d = torch::scatter(a, dim, c, b, "add");
+    ForEachDevice([&](const torch::Device& device) {
+      torch::Tensor xla_a = CopyToDevice(a, device);
+      torch::Tensor xla_b = CopyToDevice(b, device);
+      torch::Tensor xla_c = CopyToDevice(c, device);
+      torch::Tensor xla_d = torch::scatter(xla_a, dim, xla_c, xla_b, "add");
+      AllClose(d, xla_d);
+    });
   }
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestScatterAdd) {
@@ -4254,10 +4277,10 @@ TEST_F(AtenXlaTensorTest, TestScatterAdd) {
       torch::Tensor xla_d = torch::scatter_add(xla_a, dim, xla_c, xla_b);
       AllClose(d, xla_d);
     });
-
-    ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-    ExpectCounterChanged("xla::scatter_add_", cpp_test::GetIgnoredCounters());
   }
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::scatter_add_out", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestScatterAddInPlace) {
@@ -6182,7 +6205,7 @@ TEST_F(AtenXlaTensorTest, TestOneHot) {
   // TODO: PT one_hot impl employs item() which could be eliminated.
   ExpectCounterNotChanged("aten::(?!_local_scalar_dense).*",
                           cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::scatter_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestTranspose) {
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 7d6205c4712..7030c1cdf5a 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2430,21 +2430,85 @@ at::Tensor rsub(const at::Tensor& self, const at::Scalar& other,
       XLATensor::rsub(bridge::GetXlaTensor(self), other, alpha));
 }
 
-at::Tensor& scatter_(at::Tensor& self, int64_t dim, const at::Tensor& index,
-                     const at::Tensor& src) {
-  XLA_FN_COUNTER("xla::");
+at::Tensor& scatter_reduce_out_helper(const at::Tensor& self, int64_t dim,
+                                      const at::Tensor& index,
+                                      const at::Tensor& src,
+                                      c10::optional<c10::string_view> reduce,
+                                      at::Tensor& out) {
   XLATensor self_tensor = bridge::GetXlaTensor(self);
-  XLATensor::scatter_(self_tensor, dim, bridge::GetXlaTensor(index),
-                      bridge::GetXlaTensor(src));
-  return self;
+  XLATensor out_tensor = bridge::GetXlaTensor(out);
+  if (!reduce.has_value()) {
+    XLATensor::scatter_out(out_tensor, self_tensor, dim,
+                           bridge::GetXlaTensor(index),
+                           bridge::GetXlaTensor(src));
+    return out;
+  } else if (*reduce == "add") {
+    XLATensor::scatter_add_out(out_tensor, self_tensor, dim,
+                               bridge::GetXlaTensor(index),
+                               bridge::GetXlaTensor(src));
+  } else {
+    // TODO: implement scatter_mul
+    return AtenXlaTypeDefault::scatter_out(self, dim, index, src, *reduce, out);
+  }
+  return out;
 }
 
-at::Tensor& scatter_(at::Tensor& self, int64_t dim, const at::Tensor& index,
-                     const at::Scalar& value) {
+at::Tensor& scatter_reduce_out_helper(const at::Tensor& self, int64_t dim,
+                                      const at::Tensor& index,
+                                      const at::Scalar& value,
+                                      c10::optional<c10::string_view> reduce,
+                                      at::Tensor& out) {
   XLA_FN_COUNTER("xla::");
   XLATensor self_tensor = bridge::GetXlaTensor(self);
-  XLATensor::scatter_(self_tensor, dim, bridge::GetXlaTensor(index), value);
-  return self;
+  XLATensor out_tensor = bridge::GetXlaTensor(out);
+  if (!reduce.has_value()) {
+    XLATensor::scatter_out(out_tensor, self_tensor, dim,
+                           bridge::GetXlaTensor(index), value);
+    return out;
+  } else if (*reduce == "add") {
+    XLATensor::scatter_add_out(out_tensor, self_tensor, dim,
+                               bridge::GetXlaTensor(index), value);
+  } else {
+    // TODO: implement scatter_mul
+    return AtenXlaTypeDefault::scatter_out(self, dim, index, value, *reduce,
+                                           out);
+  }
+  return out;
+}
+
+at::Tensor& scatter_out(const at::Tensor& self, int64_t dim,
+                        const at::Tensor& index, const at::Tensor& src,
+                        at::Tensor& out) {
+  XLA_FN_COUNTER("xla::");
+  return scatter_reduce_out_helper(self, dim, index, src, c10::nullopt, out);
+}
+
+at::Tensor& scatter_out(const at::Tensor& self, int64_t dim,
+                        const at::Tensor& index, const at::Scalar& value,
+                        at::Tensor& out) {
+  XLA_FN_COUNTER("xla::");
+  return scatter_reduce_out_helper(self, dim, index, value, c10::nullopt, out);
+}
+
+at::Tensor& scatter_out(const at::Tensor& self, int64_t dim,
+                        const at::Tensor& index, const at::Tensor& src,
+                        c10::string_view reduce, at::Tensor& out) {
+  XLA_FN_COUNTER("xla::");
+  return scatter_reduce_out_helper(self, dim, index, src, reduce, out);
+}
+
+at::Tensor& scatter_out(const at::Tensor& self, int64_t dim,
+                        const at::Tensor& index, const at::Scalar& value,
+                        c10::string_view reduce, at::Tensor& out) {
+  XLA_FN_COUNTER("xla::");
+  return scatter_reduce_out_helper(self, dim, index, value, reduce, out);
+}
+
+at::Tensor& scatter_add_out(const at::Tensor& self, int64_t dim,
+                            const at::Tensor& index, const at::Tensor& src,
+                            at::Tensor& out) {
+  XLA_FN_COUNTER("xla::");
+  return scatter_reduce_out_helper(self, dim, index, src, "add", out);
 }
 
 at::Tensor& scatter_add_(at::Tensor& self, int64_t dim, const at::Tensor& index,
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 408e0ab0a77..b2947881e3c 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -918,13 +918,21 @@ class XLATensor {
 
   static void copy_(XLATensor& input, XLATensor& src);
 
-  static void scatter_(XLATensor& input, xla::int64 dim, const XLATensor& index,
-                       const XLATensor& src);
-  static void scatter_(XLATensor& input, xla::int64 dim, const XLATensor& index,
-                       const at::Scalar& value);
+  static void scatter_out(XLATensor& out, const XLATensor& input,
+                          xla::int64 dim, const XLATensor& index,
+                          const XLATensor& src);
+  static void scatter_out(XLATensor& out, const XLATensor& input,
+                          xla::int64 dim, const XLATensor& index,
+                          const at::Scalar& value);
 
   static void scatter_add_(XLATensor& input, xla::int64 dim,
                            const XLATensor& index, const XLATensor& src);
+  static void scatter_add_out(XLATensor& out, const XLATensor& input,
+                              xla::int64 dim, const XLATensor& index,
+                              const XLATensor& src);
+  static void scatter_add_out(XLATensor& out, const XLATensor& input,
+                              xla::int64 dim, const XLATensor& index,
+                              const at::Scalar& value);
 
   static XLATensor select(const XLATensor& input, xla::int64 dim,
                           xla::int64 index);
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 6082ae9467d..c4b512f0ff7 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -2147,18 +2147,20 @@ void XLATensor::copy_(XLATensor& input, XLATensor& src) {
   }
 }
 
-void XLATensor::scatter_(XLATensor& input, xla::int64 dim,
-                         const XLATensor& index, const XLATensor& src) {
-  input.SetIrValue(ir::MakeNode<ir::ops::Scatter>(
+void XLATensor::scatter_out(XLATensor& out, const XLATensor& input,
+                            xla::int64 dim, const XLATensor& index,
+                            const XLATensor& src) {
+  out.SetIrValue(ir::MakeNode<ir::ops::Scatter>(
       input.GetIrValue(), index.GetIrValue(), src.GetIrValue(),
       XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
 }
 
-void XLATensor::scatter_(XLATensor& input, xla::int64 dim,
-                         const XLATensor& index, const at::Scalar& value) {
+void XLATensor::scatter_out(XLATensor& out, const XLATensor& input,
+                            xla::int64 dim, const XLATensor& index,
+                            const at::Scalar& value) {
   ir::Value constant =
       GetIrValueForScalar(value, input.shape(), input.GetDevice());
-  input.SetIrValue(ir::MakeNode<ir::ops::Scatter>(
+  out.SetIrValue(ir::MakeNode<ir::ops::Scatter>(
       input.GetIrValue(), index.GetIrValue(), constant,
       XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
 }
@@ -2170,6 +2172,24 @@ void XLATensor::scatter_add_(XLATensor& input, xla::int64 dim,
       XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
 }
 
+void XLATensor::scatter_add_out(XLATensor& out, const XLATensor& input,
+                                xla::int64 dim, const XLATensor& index,
+                                const XLATensor& src) {
+  out.SetIrValue(ir::MakeNode<ir::ops::ScatterAdd>(
+      input.GetIrValue(), index.GetIrValue(), src.GetIrValue(),
+      XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
+}
+
+void XLATensor::scatter_add_out(XLATensor& out, const XLATensor& input,
+                                xla::int64 dim, const XLATensor& index,
+                                const at::Scalar& value) {
+  ir::Value constant =
+      GetIrValueForScalar(value, input.shape(), input.GetDevice());
+  out.SetIrValue(ir::MakeNode<ir::ops::ScatterAdd>(
+      input.GetIrValue(), index.GetIrValue(), constant,
+      XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
+}
+
 XLATensor XLATensor::select(const XLATensor& input, xla::int64 dim,
                             xla::int64 index) {
   return tensor_ops::Select(input, dim, index);
diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml
index a513572ff6a..1ec405b5924 100644
--- a/xla_native_functions.yaml
+++ b/xla_native_functions.yaml
@@ -170,9 +170,12 @@ supported:
 - index_add_
 - index_fill_.int_Scalar
 - index_fill_.int_Tensor
-- scatter_.src
-- scatter_.value
+- scatter.src_out
+- scatter.value_out
+- scatter.reduce_out
+- scatter.value_reduce_out
 - scatter_add_
+- scatter_add.out
 - bitwise_and.Tensor
 - bitwise_and.Scalar
 - bitwise_or.Tensor_out
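Usage sketch (reviewer note, not part of the patch): the snippet below mirrors what TestScatterReduceAdd exercises. The standalone main() harness, the <iostream> include, and the CPU tensors are assumptions for illustration; on XLA tensors these same torch::scatter calls are what dispatch through the new scatter.src_out / scatter.reduce_out entries and record the xla::scatter_out counter checked in the tests.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Inputs shaped like the tests above.
  torch::Tensor self = torch::rand({3, 5}, torch::kFloat);
  torch::Tensor src = torch::rand({3, 5}, torch::kFloat);
  torch::Tensor index = torch::zeros({3, 5}, torch::kLong);

  // Plain scatter: on an XLA tensor this now routes through the functional
  // scatter_out path (reduce == c10::nullopt in scatter_reduce_out_helper).
  torch::Tensor plain = torch::scatter(self, /*dim=*/0, index, src);

  // String-reduce scatter: reduce == "add" dispatches to
  // XLATensor::scatter_add_out; any other reduce mode (e.g. "multiply")
  // falls back to AtenXlaTypeDefault::scatter_out until scatter_mul is
  // implemented, per the TODO in the helper.
  torch::Tensor added = torch::scatter(self, /*dim=*/0, index, src, "add");

  std::cout << plain.sizes() << " " << added.sizes() << std::endl;
  return 0;
}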