From d740b2a1908676d4504eccaa91c00c50e7c5a96b Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 14 Feb 2023 19:08:16 +0000 Subject: [PATCH 1/2] Disable `test_operations.py` tests failing on TPU --- test/test_operations.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/test_operations.py b/test/test_operations.py index 0ed1829a941f..da316aae0b66 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -49,6 +49,13 @@ DeviceSupport = collections.namedtuple('DeviceSupport', ['num_devices']) +def _is_on_tpu(): + return 'XRT_TPU_CONFIG' in os.environ or pjrt.device_type() == 'TPU' + + +skipOnTpu = unittest.skipIf(_is_on_tpu(), 'Not supported on TPU') + + def _gen_tensor(*args, **kwargs): return torch.randn(*args, **kwargs) @@ -702,6 +709,7 @@ def test_index_put(self): vset = b.sum().item() self.assertEqual(a.sum().item(), 10.0 * vset + (4.0 - vset)) + @skipOnTpu def test_pow_integer_types(self): self.runAtenTest(torch.randint(10, (2, 2)), lambda x: torch.pow(x, 2)) self.runAtenTest(torch.randint(10, (2, 2)), lambda x: torch.pow(2, x)) @@ -709,6 +717,7 @@ def test_pow_integer_types(self): self.runAtenTest(torch.randint(10, (2, 2)), lambda x: x.pow_(2)) self.runAtenTest(torch.randint(10, (2, 2)), lambda x: x.pow_(x)) + @skipOnTpu def test_matmul_integer_types(self): # all variance of matmul: dot/mv/mm/bmm self.runAtenTest((torch.randint(10, (2,)), torch.randint(10, (2,))), @@ -723,11 +732,13 @@ def test_matmul_integer_types(self): self.runAtenTest((torch.randint(10, (10, 3, 4)), torch.randint(10, (4, 5))), lambda x, y: torch.matmul(x, y)) + @skipOnTpu def test_addmm_integer_types(self): self.runAtenTest((torch.randint(10, (2, 3)), torch.randint( 10, (2, 3)), torch.randint(10, (3, 3))), lambda x, y, z: torch.addmm(x, y, z)) + @skipOnTpu def test_baddmm_integer_types(self): self.runAtenTest( (torch.randint(10, (10, 3, 5)), torch.randint(10, (10, 3, 4)), From 9d1b351e5b4de86066c650cd93f3e7c0698c7314 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 14 Feb 2023 19:13:04 +0000 Subject: [PATCH 2/2] Add to TPU CI --- test/tpu/pjrt_test_job.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/tpu/pjrt_test_job.yaml b/test/tpu/pjrt_test_job.yaml index 0ab2e953f906..dda0e9533f90 100644 --- a/test/tpu/pjrt_test_job.yaml +++ b/test/tpu/pjrt_test_job.yaml @@ -27,6 +27,7 @@ spec: - bash - -c - | + python3 pytorch/xla/test/pjrt/test_operations.py -v python3 pytorch/xla/test/pjrt/test_experimental_pjrt_tpu.py volumeMounts: - mountPath: /dev/shm