@@ -1259,7 +1259,7 @@ TEST_F(AtenXlaTensorTest, TestBroadcastTensors) {
12591259TEST_F (AtenXlaTensorTest, TestOneIndex) {
12601260 at::Tensor params = at::rand ({4 , 3 , 5 , 6 , 7 }, at::TensorOptions (at::kFloat ));
12611261 at::Tensor indices =
1262- at::randint (0 , 3 , {2 , 4 , 3 }, at::TensorOptions (at::kLong ));
1262+ at::randint (- 3 , 3 , {2 , 4 , 3 }, at::TensorOptions (at::kLong ));
12631263 at::Tensor result = at::index (params, {indices});
12641264 ForEachDevice ([&](const Device& device) {
12651265 at::Tensor xla_params = bridge::CreateXlaTensor (params, device);
@@ -1272,10 +1272,10 @@ TEST_F(AtenXlaTensorTest, TestOneIndex) {
12721272TEST_F (AtenXlaTensorTest, TestMultiIndexMiddleNull) {
12731273 at::Tensor params = at::rand ({4 , 3 , 5 , 6 , 7 }, at::TensorOptions (at::kFloat ));
12741274 at::Tensor indices_0 =
1275- at::randint (0 , 3 , {2 , 4 , 3 }, at::TensorOptions (at::kLong ));
1275+ at::randint (- 3 , 3 , {2 , 4 , 3 }, at::TensorOptions (at::kLong ));
12761276 at::Tensor indices_null;
12771277 at::Tensor indices_1 =
1278- at::randint (0 , 3 , {2 , 4 , 3 }, at::TensorOptions (at::kLong ));
1278+ at::randint (- 3 , 3 , {2 , 4 , 3 }, at::TensorOptions (at::kLong ));
12791279 at::Tensor result = at::index (params, {indices_0, indices_null, indices_1});
12801280 ForEachDevice ([&](const Device& device) {
12811281 at::Tensor xla_params = bridge::CreateXlaTensor (params, device);
@@ -1290,10 +1290,10 @@ TEST_F(AtenXlaTensorTest, TestMultiIndexMiddleNull) {
12901290TEST_F (AtenXlaTensorTest, TestMultiIndexTailNull) {
12911291 at::Tensor params = at::rand ({4 , 3 , 5 , 6 , 7 }, at::TensorOptions (at::kFloat ));
12921292 at::Tensor indices_0 =
1293- at::randint (0 , 3 , {2 , 4 , 3 }, at::TensorOptions (at::kLong ));
1293+ at::randint (- 3 , 3 , {2 , 4 , 3 }, at::TensorOptions (at::kLong ));
12941294 at::Tensor indices_null;
12951295 at::Tensor indices_1 =
1296- at::randint (0 , 3 , {2 , 4 , 3 }, at::TensorOptions (at::kLong ));
1296+ at::randint (- 3 , 3 , {2 , 4 , 3 }, at::TensorOptions (at::kLong ));
12971297 at::Tensor result = at::index (params, {indices_0, indices_1, indices_null});
12981298 ForEachDevice ([&](const Device& device) {
12991299 at::Tensor xla_params = bridge::CreateXlaTensor (params, device);
@@ -1308,9 +1308,9 @@ TEST_F(AtenXlaTensorTest, TestMultiIndexTailNull) {
13081308TEST_F (AtenXlaTensorTest, TestMultiIndexMiddleBroadcast) {
13091309 at::Tensor params = at::rand ({4 , 3 , 5 , 6 , 7 }, at::TensorOptions (at::kFloat ));
13101310 at::Tensor indices_0 =
1311- at::randint (0 , 3 , {2 , 4 , 3 }, at::TensorOptions (at::kLong ));
1311+ at::randint (- 3 , 3 , {2 , 4 , 3 }, at::TensorOptions (at::kLong ));
13121312 at::Tensor indices_1 =
1313- at::randint (0 , 3 , {2 , 1 , 3 }, at::TensorOptions (at::kLong ));
1313+ at::randint (- 3 , 3 , {2 , 1 , 3 }, at::TensorOptions (at::kLong ));
13141314 at::Tensor result = at::index (params, {indices_0, indices_1});
13151315 ForEachDevice ([&](const Device& device) {
13161316 at::Tensor xla_params = bridge::CreateXlaTensor (params, device);
@@ -1325,9 +1325,9 @@ TEST_F(AtenXlaTensorTest, TestMultiIndexMiddleBroadcast) {
13251325TEST_F (AtenXlaTensorTest, TestMultiIndexTailBroadcast) {
13261326 at::Tensor params = at::rand ({4 , 3 , 5 , 6 , 7 }, at::TensorOptions (at::kFloat ));
13271327 at::Tensor indices_0 =
1328- at::randint (0 , 3 , {2 , 1 , 3 }, at::TensorOptions (at::kLong ));
1328+ at::randint (- 3 , 3 , {2 , 1 , 3 }, at::TensorOptions (at::kLong ));
13291329 at::Tensor indices_1 =
1330- at::randint (0 , 3 , {2 , 1 }, at::TensorOptions (at::kLong ));
1330+ at::randint (- 3 , 3 , {2 , 1 }, at::TensorOptions (at::kLong ));
13311331 at::Tensor result = at::index (params, {indices_0, indices_1});
13321332 ForEachDevice ([&](const Device& device) {
13331333 at::Tensor xla_params = bridge::CreateXlaTensor (params, device);
0 commit comments