Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4140,10 +4140,9 @@ TEST_F(AtenXlaTensorTest, TestScatter) {
torch::Tensor xla_d = torch::scatter(xla_a, dim, xla_c, xla_b);
AllClose(d, xla_d);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_", cpp_test::GetIgnoredCounters());
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestScatterR1) {
Expand All @@ -4162,7 +4161,7 @@ TEST_F(AtenXlaTensorTest, TestScatterR1) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestScatterR3) {
Expand All @@ -4186,7 +4185,7 @@ TEST_F(AtenXlaTensorTest, TestScatterR3) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestScatterBiggerSource) {
Expand All @@ -4207,10 +4206,10 @@ TEST_F(AtenXlaTensorTest, TestScatterBiggerSource) {
torch::Tensor xla_d = torch::scatter(xla_a, dim, xla_c, xla_b);
AllClose(d, xla_d);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_", cpp_test::GetIgnoredCounters());
}

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestScatterScalar) {
Expand All @@ -4230,10 +4229,34 @@ TEST_F(AtenXlaTensorTest, TestScatterScalar) {
torch::Tensor xla_d = torch::scatter(xla_a, dim, xla_c, b);
AllClose(d, xla_d);
});
}

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_", cpp_test::GetIgnoredCounters());
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestScatterReduceAdd) {
torch::Tensor a = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::rand({3, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor c = torch::empty({3, 5}, torch::TensorOptions(torch::kLong));
for (int dim = 0; dim < 2; ++dim) {
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 5; j++) {
c[i][j] = (i + j) % c.sizes()[dim];
}
}
torch::Tensor d = torch::scatter(a, dim, c, b, "add");
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
torch::Tensor xla_b = CopyToDevice(b, device);
torch::Tensor xla_c = CopyToDevice(c, device);
torch::Tensor xla_d = torch::scatter(xla_a, dim, xla_c, xla_b, "add");
AllClose(d, xla_d);
});
}

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestScatterAdd) {
Expand All @@ -4254,10 +4277,10 @@ TEST_F(AtenXlaTensorTest, TestScatterAdd) {
torch::Tensor xla_d = torch::scatter_add(xla_a, dim, xla_c, xla_b);
AllClose(d, xla_d);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_add_", cpp_test::GetIgnoredCounters());
}

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_add_out", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestScatterAddInPlace) {
Expand Down Expand Up @@ -6182,7 +6205,7 @@ TEST_F(AtenXlaTensorTest, TestOneHot) {
// TODO: PT one_hot impl employs item() which could be eliminated.
ExpectCounterNotChanged("aten::(?!_local_scalar_dense).*",
cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::scatter_out", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestTranspose) {
Expand Down
84 changes: 74 additions & 10 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2430,21 +2430,85 @@ at::Tensor rsub(const at::Tensor& self, const at::Scalar& other,
XLATensor::rsub(bridge::GetXlaTensor(self), other, alpha));
}

at::Tensor& scatter_(at::Tensor& self, int64_t dim, const at::Tensor& index,
const at::Tensor& src) {
XLA_FN_COUNTER("xla::");
at::Tensor& scatter_reduce_out_helper(const at::Tensor& self, int64_t dim,
const at::Tensor& index,
const at::Tensor& src,
c10::optional<c10::string_view> reduce,
at::Tensor& out) {
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::scatter_(self_tensor, dim, bridge::GetXlaTensor(index),
bridge::GetXlaTensor(src));
return self;
XLATensor out_tensor = bridge::GetXlaTensor(out);
if (!reduce.has_value()) {
XLATensor::scatter_out(out_tensor, self_tensor, dim,
bridge::GetXlaTensor(index),
bridge::GetXlaTensor(src));
return out;
} else if (*reduce == "add") {
XLATensor::scatter_add_out(out_tensor, self_tensor, dim,
bridge::GetXlaTensor(index),
bridge::GetXlaTensor(src));
} else {
// TODO: implement scatter_mul
return AtenXlaTypeDefault::scatter_out(self, dim, index, src, *reduce, out);
}
return out;
}

at::Tensor& scatter_(at::Tensor& self, int64_t dim, const at::Tensor& index,
const at::Scalar& value) {
at::Tensor& scatter_reduce_out_helper(const at::Tensor& self, int64_t dim,
const at::Tensor& index,
const at::Scalar& value,
c10::optional<c10::string_view> reduce,
at::Tensor& out) {
XLA_FN_COUNTER("xla::");
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::scatter_(self_tensor, dim, bridge::GetXlaTensor(index), value);
return self;
XLATensor out_tensor = bridge::GetXlaTensor(out);
if (!reduce.has_value()) {
XLATensor::scatter_out(out_tensor, self_tensor, dim,
bridge::GetXlaTensor(index), value);
return out;
} else if (*reduce == "add") {
XLATensor::scatter_add_out(out_tensor, self_tensor, dim,
bridge::GetXlaTensor(index), value);
} else {
// TODO: implement scatter_mul
return AtenXlaTypeDefault::scatter_out(self, dim, index, value, *reduce,
out);
}
return out;
}

at::Tensor& scatter_out(const at::Tensor& self, int64_t dim,
const at::Tensor& index, const at::Tensor& src,
at::Tensor& out) {
XLA_FN_COUNTER("xla::");
return scatter_reduce_out_helper(self, dim, index, src, c10::nullopt, out);
}

at::Tensor& scatter_out(const at::Tensor& self, int64_t dim,
const at::Tensor& index, const at::Scalar& value,
at::Tensor& out) {
XLA_FN_COUNTER("xla::");
return scatter_reduce_out_helper(self, dim, index, value, c10::nullopt, out);
}

at::Tensor& scatter_out(const at::Tensor& self, int64_t dim,
const at::Tensor& index, const at::Tensor& src,
c10::string_view reduce, at::Tensor& out) {
XLA_FN_COUNTER("xla::");
return scatter_reduce_out_helper(self, dim, index, src, reduce, out);
}

at::Tensor& scatter_out(const at::Tensor& self, int64_t dim,
const at::Tensor& index, const at::Scalar& value,
c10::string_view reduce, at::Tensor& out) {
XLA_FN_COUNTER("xla::");
return scatter_reduce_out_helper(self, dim, index, value, reduce, out);
}

at::Tensor& scatter_add_out(const at::Tensor& self, int64_t dim,
const at::Tensor& index, const at::Tensor& src,
at::Tensor& out) {
XLA_FN_COUNTER("xla::");
return scatter_reduce_out_helper(self, dim, index, src, "add", out);
}

at::Tensor& scatter_add_(at::Tensor& self, int64_t dim, const at::Tensor& index,
Expand Down
16 changes: 12 additions & 4 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -918,13 +918,21 @@ class XLATensor {

static void copy_(XLATensor& input, XLATensor& src);

static void scatter_(XLATensor& input, xla::int64 dim, const XLATensor& index,
const XLATensor& src);
static void scatter_(XLATensor& input, xla::int64 dim, const XLATensor& index,
const at::Scalar& value);
static void scatter_out(XLATensor& out, const XLATensor& input,
xla::int64 dim, const XLATensor& index,
const XLATensor& src);
static void scatter_out(XLATensor& out, const XLATensor& input,
xla::int64 dim, const XLATensor& index,
const at::Scalar& value);

static void scatter_add_(XLATensor& input, xla::int64 dim,
const XLATensor& index, const XLATensor& src);
static void scatter_add_out(XLATensor& out, const XLATensor& input,
xla::int64 dim, const XLATensor& index,
const XLATensor& src);
static void scatter_add_out(XLATensor& out, const XLATensor& input,
xla::int64 dim, const XLATensor& index,
const at::Scalar& value);

static XLATensor select(const XLATensor& input, xla::int64 dim,
xla::int64 index);
Expand Down
32 changes: 26 additions & 6 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2147,18 +2147,20 @@ void XLATensor::copy_(XLATensor& input, XLATensor& src) {
}
}

void XLATensor::scatter_(XLATensor& input, xla::int64 dim,
const XLATensor& index, const XLATensor& src) {
input.SetIrValue(ir::MakeNode<ir::ops::Scatter>(
void XLATensor::scatter_out(XLATensor& out, const XLATensor& input,
xla::int64 dim, const XLATensor& index,
const XLATensor& src) {
out.SetIrValue(ir::MakeNode<ir::ops::Scatter>(
input.GetIrValue(), index.GetIrValue(), src.GetIrValue(),
XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
}

void XLATensor::scatter_(XLATensor& input, xla::int64 dim,
const XLATensor& index, const at::Scalar& value) {
void XLATensor::scatter_out(XLATensor& out, const XLATensor& input,
xla::int64 dim, const XLATensor& index,
const at::Scalar& value) {
ir::Value constant =
GetIrValueForScalar(value, input.shape(), input.GetDevice());
input.SetIrValue(ir::MakeNode<ir::ops::Scatter>(
out.SetIrValue(ir::MakeNode<ir::ops::Scatter>(
input.GetIrValue(), index.GetIrValue(), constant,
XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
}
Expand All @@ -2170,6 +2172,24 @@ void XLATensor::scatter_add_(XLATensor& input, xla::int64 dim,
XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
}

void XLATensor::scatter_add_out(XLATensor& out, const XLATensor& input,
xla::int64 dim, const XLATensor& index,
const XLATensor& src) {
out.SetIrValue(ir::MakeNode<ir::ops::ScatterAdd>(
input.GetIrValue(), index.GetIrValue(), src.GetIrValue(),
XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
}

void XLATensor::scatter_add_out(XLATensor& out, const XLATensor& input,
xla::int64 dim, const XLATensor& index,
const at::Scalar& value) {
ir::Value constant =
GetIrValueForScalar(value, input.shape(), input.GetDevice());
out.SetIrValue(ir::MakeNode<ir::ops::ScatterAdd>(
input.GetIrValue(), index.GetIrValue(), constant,
XlaHelpers::GetCanonicalDimensionIndex(dim, input.shape().get().rank())));
}

XLATensor XLATensor::select(const XLATensor& input, xla::int64 dim,
xla::int64 index) {
return tensor_ops::Select(input, dim, index);
Expand Down
7 changes: 5 additions & 2 deletions xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,12 @@ supported:
- index_add_
- index_fill_.int_Scalar
- index_fill_.int_Tensor
- scatter_.src
- scatter_.value
- scatter.src_out
- scatter.value_out
- scatter.reduce_out
- scatter.value_reduce_out
- scatter_add_
- scatter_add.out
- bitwise_and.Tensor
- bitwise_and.Scalar
- bitwise_or.Tensor_out
Expand Down