Skip to content

Commit d7bd06b

Browse files
committed
Wrap around negative indices in the index method
1 parent b074e44 commit d7bd06b

File tree

3 files changed

+36
-11
lines changed

3 files changed

+36
-11
lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,7 +1259,7 @@ TEST_F(AtenXlaTensorTest, TestBroadcastTensors) {
12591259
TEST_F(AtenXlaTensorTest, TestOneIndex) {
12601260
at::Tensor params = at::rand({4, 3, 5, 6, 7}, at::TensorOptions(at::kFloat));
12611261
at::Tensor indices =
1262-
at::randint(0, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
1262+
at::randint(-3, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
12631263
at::Tensor result = at::index(params, {indices});
12641264
ForEachDevice([&](const Device& device) {
12651265
at::Tensor xla_params = bridge::CreateXlaTensor(params, device);
@@ -1272,10 +1272,10 @@ TEST_F(AtenXlaTensorTest, TestOneIndex) {
12721272
TEST_F(AtenXlaTensorTest, TestMultiIndexMiddleNull) {
12731273
at::Tensor params = at::rand({4, 3, 5, 6, 7}, at::TensorOptions(at::kFloat));
12741274
at::Tensor indices_0 =
1275-
at::randint(0, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
1275+
at::randint(-3, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
12761276
at::Tensor indices_null;
12771277
at::Tensor indices_1 =
1278-
at::randint(0, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
1278+
at::randint(-3, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
12791279
at::Tensor result = at::index(params, {indices_0, indices_null, indices_1});
12801280
ForEachDevice([&](const Device& device) {
12811281
at::Tensor xla_params = bridge::CreateXlaTensor(params, device);
@@ -1290,10 +1290,10 @@ TEST_F(AtenXlaTensorTest, TestMultiIndexMiddleNull) {
12901290
TEST_F(AtenXlaTensorTest, TestMultiIndexTailNull) {
12911291
at::Tensor params = at::rand({4, 3, 5, 6, 7}, at::TensorOptions(at::kFloat));
12921292
at::Tensor indices_0 =
1293-
at::randint(0, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
1293+
at::randint(-3, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
12941294
at::Tensor indices_null;
12951295
at::Tensor indices_1 =
1296-
at::randint(0, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
1296+
at::randint(-3, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
12971297
at::Tensor result = at::index(params, {indices_0, indices_1, indices_null});
12981298
ForEachDevice([&](const Device& device) {
12991299
at::Tensor xla_params = bridge::CreateXlaTensor(params, device);
@@ -1308,9 +1308,9 @@ TEST_F(AtenXlaTensorTest, TestMultiIndexTailNull) {
13081308
TEST_F(AtenXlaTensorTest, TestMultiIndexMiddleBroadcast) {
13091309
at::Tensor params = at::rand({4, 3, 5, 6, 7}, at::TensorOptions(at::kFloat));
13101310
at::Tensor indices_0 =
1311-
at::randint(0, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
1311+
at::randint(-3, 3, {2, 4, 3}, at::TensorOptions(at::kLong));
13121312
at::Tensor indices_1 =
1313-
at::randint(0, 3, {2, 1, 3}, at::TensorOptions(at::kLong));
1313+
at::randint(-3, 3, {2, 1, 3}, at::TensorOptions(at::kLong));
13141314
at::Tensor result = at::index(params, {indices_0, indices_1});
13151315
ForEachDevice([&](const Device& device) {
13161316
at::Tensor xla_params = bridge::CreateXlaTensor(params, device);
@@ -1325,9 +1325,9 @@ TEST_F(AtenXlaTensorTest, TestMultiIndexMiddleBroadcast) {
13251325
TEST_F(AtenXlaTensorTest, TestMultiIndexTailBroadcast) {
13261326
at::Tensor params = at::rand({4, 3, 5, 6, 7}, at::TensorOptions(at::kFloat));
13271327
at::Tensor indices_0 =
1328-
at::randint(0, 3, {2, 1, 3}, at::TensorOptions(at::kLong));
1328+
at::randint(-3, 3, {2, 1, 3}, at::TensorOptions(at::kLong));
13291329
at::Tensor indices_1 =
1330-
at::randint(0, 3, {2, 1}, at::TensorOptions(at::kLong));
1330+
at::randint(-3, 3, {2, 1}, at::TensorOptions(at::kLong));
13311331
at::Tensor result = at::index(params, {indices_0, indices_1});
13321332
ForEachDevice([&](const Device& device) {
13331333
at::Tensor xla_params = bridge::CreateXlaTensor(params, device);

torch_xla/csrc/tensor.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,10 +1045,11 @@ XLATensor XLATensor::expand(const XLATensor& input,
10451045
XLATensor XLATensor::index(
    const XLATensor& input,
    tensorflow::gtl::ArraySlice<const XLATensor> indices) {
  // Negative indices are wrapped into [0, dim_size) first, so the gather
  // below only ever sees canonical (non-negative) index values.
  std::vector<XLATensor> wrapped = WrapIndicesOnce(input, indices);
  xla::int64 index_rank = wrapped.front().shape().get().rank();
  // Stacking the per-dimension index tensors lets the whole multi-indexing
  // operation be dispatched with a single gather.
  XLATensor stacked_indices = XLATensor::stack(wrapped, index_rank);
  return Create(
      ir::ops::IndexOp(input.GetIrValue(), stacked_indices.GetIrValue()),
      input.GetDevice());
}
@@ -1869,6 +1870,24 @@ Device XLATensor::CommonDeviceForTensors(
18691870
return device;
18701871
}
18711872

1873+
std::vector<XLATensor> XLATensor::WrapIndicesOnce(
1874+
const XLATensor& input,
1875+
tensorflow::gtl::ArraySlice<const XLATensor> indices) {
1876+
std::vector<XLATensor> canonical_indices;
1877+
XLA_CHECK_LE(indices.size(), input.shape().get().rank());
1878+
for (size_t dim_idx = 0; dim_idx < indices.size(); ++dim_idx) {
1879+
const XLATensor& dim_index = indices[dim_idx];
1880+
int64_t dim_size = input.shape().get().dimensions(dim_idx);
1881+
XLATensor wrapped_dim_index =
1882+
Create(dim_index.GetIrValue() +
1883+
ir::ops::ScalarOp(at::Scalar(dim_size), dim_index.shape()),
1884+
input.GetDevice());
1885+
XLATensor wrap_cond = lt(indices[dim_idx], at::Scalar(int64_t(0)));
1886+
canonical_indices.push_back(where(wrap_cond, wrapped_dim_index, dim_index));
1887+
}
1888+
return canonical_indices;
1889+
}
1890+
18721891
xla::int64 XLATensor::GetCanonicalDimension(const XLATensor& input,
18731892
xla::int64 dim) {
18741893
return XlaHelpers::GetCanonicalDimensionIndex(dim,

torch_xla/csrc/tensor.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,12 @@ class XLATensor {
697697

698698
static xla::int64 GetNextTensorId();
699699

700+
// Wraps index tensors once into the [0, dim_size) interval, where dim_size is
701+
// the size of the current indexed dimension.
702+
static std::vector<XLATensor> WrapIndicesOnce(
703+
const XLATensor& input,
704+
tensorflow::gtl::ArraySlice<const XLATensor> indices);
705+
700706
static xla::int64 GetCanonicalDimension(const XLATensor& input,
701707
xla::int64 dim);
702708

0 commit comments

Comments
 (0)