From 4516f9dc091cbeb6b1a4f7c8e905b52551673b33 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Tue, 11 Jun 2024 22:20:39 +0000 Subject: [PATCH 1/6] Fix tpu v3 --- test/spmd/test_xla_sharding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 2d710a7c7c1e..a8c06ca6bef5 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -1230,7 +1230,7 @@ def test_spmd_reduce_scatter(self): f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{0}}, to_apply=%AddComputation.3", hlo) - expected_x = torch.ones(2, 8) * 4 + expected_x = torch.ones(8 // self.n_devices, 8) * 4 self.assertTrue(torch.allclose(x.cpu(), expected_x)) @unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device") @@ -1250,7 +1250,7 @@ def test_spmd_reduce_scatter_canonical_index(self): f"reduce-scatter(f32[8,8]{{1,0}} %custom-call.2), channel_id=1, replica_groups={{{{{','.join([str(x) for x in self.device_ids])}}}}}, use_global_device_ids=true, dimensions={{1}}, to_apply=%AddComputation.3", hlo) - expected_x = torch.ones(8, 2) * 4 + expected_x = torch.ones(8, 8 // self.n_devices) * 4 self.assertTrue(torch.allclose(x.cpu(), expected_x)) From bb4f179640f9ff0764a9e3150218215cf061b30b Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Wed, 12 Jun 2024 00:47:34 +0000 Subject: [PATCH 2/6] skip tpu v2 --- test/spmd/test_xla_sharding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index a8c06ca6bef5..f972b3d4a3f8 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -1213,7 +1213,7 @@ def test_manual_sharding_api_e2e(self): self.assertEqual(xxx.shape, (8, 8)) self.assertTrue(torch.allclose(x.cpu() + 1, xxx.cpu())) - @unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device") + @unittest.skipIf(xr.device_type() != 'TPU' and tpu.version() < 4, "Only runs on TPUv4") def test_spmd_reduce_scatter(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) x = torch.ones(8, 8).to(xm.xla_device()) @@ -1233,7 +1233,7 @@ def test_spmd_reduce_scatter(self): expected_x = torch.ones(8 // self.n_devices, 8) * 4 self.assertTrue(torch.allclose(x.cpu(), expected_x)) - @unittest.skipIf(xr.device_type() != 'TPU', "Skip non-TPU device") + @unittest.skipIf(xr.device_type() != 'TPU' and tpu.version() < 4, "Only runs on TPUv4") def test_spmd_reduce_scatter_canonical_index(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) x = torch.ones(8, 8).to(xm.xla_device()) From 3415c6d08a47e02e9c079d16d9cf238d3bcc1fc2 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Wed, 12 Jun 2024 00:49:07 +0000 Subject: [PATCH 3/6] Fix linters --- test/spmd/test_xla_sharding.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index f972b3d4a3f8..8e3ca043133d 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -1213,7 +1213,8 @@ def test_manual_sharding_api_e2e(self): self.assertEqual(xxx.shape, (8, 8)) self.assertTrue(torch.allclose(x.cpu() + 1, xxx.cpu())) - @unittest.skipIf(xr.device_type() != 'TPU' and tpu.version() < 4, "Only runs on TPUv4") + @unittest.skipIf(xr.device_type() != 'TPU' and tpu.version() < 4, + "Only runs on TPUv4") def test_spmd_reduce_scatter(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) x = torch.ones(8, 8).to(xm.xla_device()) @@ -1233,7 +1234,8 @@ def test_spmd_reduce_scatter(self): expected_x = torch.ones(8 // self.n_devices, 8) * 4 self.assertTrue(torch.allclose(x.cpu(), expected_x)) - @unittest.skipIf(xr.device_type() != 'TPU' and tpu.version() < 4, "Only runs on TPUv4") + @unittest.skipIf(xr.device_type() != 'TPU' and tpu.version() < 4, + "Only runs on TPUv4") def test_spmd_reduce_scatter_canonical_index(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) x = torch.ones(8, 8).to(xm.xla_device()) From f92a4dd8308909886fc8f81fc6ff16b4e58747ef Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Wed, 12 Jun 2024 20:47:50 +0000 Subject: [PATCH 4/6] Fix --- test/spmd/test_xla_sharding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 8e3ca043133d..40d3304e6f08 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -1213,7 +1213,7 @@ def test_manual_sharding_api_e2e(self): self.assertEqual(xxx.shape, (8, 8)) self.assertTrue(torch.allclose(x.cpu() + 1, xxx.cpu())) - @unittest.skipIf(xr.device_type() != 'TPU' and tpu.version() < 4, + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "Only runs on TPUv4") def test_spmd_reduce_scatter(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) @@ -1234,7 +1234,7 @@ def test_spmd_reduce_scatter(self): expected_x = torch.ones(8 // self.n_devices, 8) * 4 self.assertTrue(torch.allclose(x.cpu(), expected_x)) - @unittest.skipIf(xr.device_type() != 'TPU' and tpu.version() < 4, + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "Only runs on TPUv4") def test_spmd_reduce_scatter_canonical_index(self): xs.set_global_mesh(self._get_mesh((1, self.n_devices))) From 658f68f049c3622142f3d05ca6b535ef85c571b8 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Wed, 12 Jun 2024 20:49:15 +0000 Subject: [PATCH 5/6] Add a debug print --- test/spmd/test_xla_sharding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 40d3304e6f08..7e3987a48cbd 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -1216,6 +1216,7 @@ def test_manual_sharding_api_e2e(self): @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "Only runs on TPUv4") def test_spmd_reduce_scatter(self): + print(tpu.version()) xs.set_global_mesh(self._get_mesh((1, self.n_devices))) x = torch.ones(8, 8).to(xm.xla_device()) From 6a15573dbff95ab40e95d07c17bf6705d4e72c98 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 14 Jun 2024 04:43:04 +0000 Subject: [PATCH 6/6] remove print --- test/spmd/test_xla_sharding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 7e3987a48cbd..40d3304e6f08 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -1216,7 +1216,6 @@ def test_manual_sharding_api_e2e(self): @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "Only runs on TPUv4") def test_spmd_reduce_scatter(self): - print(tpu.version()) xs.set_global_mesh(self._get_mesh((1, self.n_devices))) x = torch.ones(8, 8).to(xm.xla_device())