diff --git a/test/test_operations.py b/test/test_operations.py
index 0ed1829a941f..da316aae0b66 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -49,6 +49,13 @@
 DeviceSupport = collections.namedtuple('DeviceSupport', ['num_devices'])
 
 
+def _is_on_tpu():
+  return 'XRT_TPU_CONFIG' in os.environ or pjrt.device_type() == 'TPU'
+
+
+skipOnTpu = unittest.skipIf(_is_on_tpu(), 'Not supported on TPU')
+
+
 def _gen_tensor(*args, **kwargs):
   return torch.randn(*args, **kwargs)
 
@@ -702,6 +709,7 @@ def test_index_put(self):
     vset = b.sum().item()
     self.assertEqual(a.sum().item(), 10.0 * vset + (4.0 - vset))
 
+  @skipOnTpu
   def test_pow_integer_types(self):
     self.runAtenTest(torch.randint(10, (2, 2)), lambda x: torch.pow(x, 2))
     self.runAtenTest(torch.randint(10, (2, 2)), lambda x: torch.pow(2, x))
@@ -709,6 +717,7 @@ def test_pow_integer_types(self):
     self.runAtenTest(torch.randint(10, (2, 2)), lambda x: x.pow_(2))
     self.runAtenTest(torch.randint(10, (2, 2)), lambda x: x.pow_(x))
 
+  @skipOnTpu
   def test_matmul_integer_types(self):
     # all variance of matmul: dot/mv/mm/bmm
     self.runAtenTest((torch.randint(10, (2,)), torch.randint(10, (2,))),
@@ -723,11 +732,13 @@ def test_matmul_integer_types(self):
     self.runAtenTest((torch.randint(10, (10, 3, 4)), torch.randint(10, (4, 5))),
                      lambda x, y: torch.matmul(x, y))
 
+  @skipOnTpu
   def test_addmm_integer_types(self):
     self.runAtenTest((torch.randint(10, (2, 3)), torch.randint(
         10, (2, 3)), torch.randint(10, (3, 3))),
                      lambda x, y, z: torch.addmm(x, y, z))
 
+  @skipOnTpu
   def test_baddmm_integer_types(self):
     self.runAtenTest(
         (torch.randint(10, (10, 3, 5)), torch.randint(10, (10, 3, 4)),
diff --git a/test/tpu/pjrt_test_job.yaml b/test/tpu/pjrt_test_job.yaml
index 0ab2e953f906..dda0e9533f90 100644
--- a/test/tpu/pjrt_test_job.yaml
+++ b/test/tpu/pjrt_test_job.yaml
@@ -27,6 +27,7 @@ spec:
         - bash
         - -c
         - |
+          python3 pytorch/xla/test/pjrt/test_operations.py -v
           python3 pytorch/xla/test/pjrt/test_experimental_pjrt_tpu.py
         volumeMounts:
         - mountPath: /dev/shm
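Context for reviewers: the first hunk builds a module-level skipOnTpu decorator from unittest.skipIf, so the integer-dtype tests above are reported as skipped on TPU rather than failing. Below is a minimal, self-contained sketch of that pattern, not the patched file: it checks only the XRT_TPU_CONFIG environment variable and omits the pjrt.device_type() call (which requires torch_xla), and the test class and its body are illustrative placeholders.

import os
import unittest


def _is_on_tpu():
  # Simplified stand-in for the helper added in the patch; the real version
  # also consults pjrt.device_type() from torch_xla, omitted here so the
  # sketch runs without a TPU runtime installed.
  return 'XRT_TPU_CONFIG' in os.environ


# Evaluated once at import time; matching tests are marked skipped up front.
skipOnTpu = unittest.skipIf(_is_on_tpu(), 'Not supported on TPU')


class IntegerOpsTest(unittest.TestCase):  # illustrative class, not in the patch

  @skipOnTpu
  def test_pow_integer_types(self):
    # Placeholder body; the real test exercises torch.pow on integer tensors.
    self.assertEqual(3 ** 2, 9)


if __name__ == '__main__':
  unittest.main()

Because unittest.skipIf evaluates its condition when the module is imported, the decorated tests still appear in the report (as skipped) on TPU, keeping the overall test count stable across device types.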