|
12 | 12 | import torch._dynamo as dynamo |
13 | 13 | import torchvision |
14 | 14 | import unittest |
| 15 | +import warnings |
| 16 | + |
| 17 | +torch_xla._XLAC._init_computation_client() |
15 | 18 |
|
16 | 19 | # Setup import folders. |
17 | 20 | xla_test_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) |
@@ -58,36 +61,6 @@ def test_random_op_different_result_each_run(self): |
58 | 61 | self.assertFalse(torch.allclose(dynamo_res_2, dynamo_res_3)) |
59 | 62 |
|
60 | 63 |
|
61 | | -class DynamErrorMessageTest(unittest.TestCase): |
62 | | - |
63 | | - def test_cpu_tensor(self): |
64 | | - device = xm.xla_device() |
65 | | - input = torch.randn(4, 3, 224, 224) |
66 | | - input_xla = input.clone().to(device) |
67 | | - resnet18 = torchvision.models.resnet18() |
68 | | - resnet18.eval() |
69 | | - xla_resnet18 = torchvision.models.resnet18() |
70 | | - xla_resnet18.to(device) |
71 | | - xla_resnet18.eval() |
72 | | - dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla') |
73 | | - dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla') |
74 | | - # input on cpu and model weight on xla |
75 | | - with self.assertRaises(Exception) as context: |
76 | | - res = dynamo_resnet18(input) |
77 | | - self.assertTrue( |
78 | | - 'found two different devices' in context.exception.__str__()) |
79 | | - # input on xla and model weight on cpu |
80 | | - with self.assertRaises(Exception) as context: |
81 | | - res = dynamo_resnet18_cpu(input_xla) |
82 | | - self.assertTrue( |
83 | | - 'found two different devices' in context.exception.__str__()) |
84 | | - # input and model weight on cpu |
85 | | - with self.assertRaises(Exception) as context: |
86 | | - res = dynamo_resnet18_cpu(input) |
87 | | - self.assertTrue( |
88 | | - 'please move all tensors to XLA device' in context.exception.__str__()) |
89 | | - |
90 | | - |
91 | 64 | class DynamoInferenceBasicTest(unittest.TestCase): |
92 | 65 |
|
93 | 66 | @classmethod |
@@ -516,6 +489,47 @@ def test_resnet18(self): |
516 | 489 | met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3) |
517 | 490 |
|
518 | 491 |
|
| 492 | +class DynamErrorMessageTest(unittest.TestCase): |
| 493 | + |
| 494 | + def test_mixed_cpu_tensor(self): |
| 495 | + device = xm.xla_device() |
| 496 | + input = torch.randn(4, 3, 224, 224) |
| 497 | + input_xla = input.clone().to(device) |
| 498 | + resnet18 = torchvision.models.resnet18() |
| 499 | + resnet18.eval() |
| 500 | + xla_resnet18 = torchvision.models.resnet18() |
| 501 | + xla_resnet18.to(device) |
| 502 | + xla_resnet18.eval() |
| 503 | + dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla') |
| 504 | + dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla') |
| 505 | + # input on cpu and model weight on xla |
| 506 | + with self.assertRaises(Exception) as context: |
| 507 | + res = dynamo_resnet18(input) |
| 508 | + self.assertTrue( |
| 509 | + 'found two different devices' in context.exception.__str__()) |
| 510 | + # input on xla and model weight on cpu |
| 511 | + with self.assertRaises(Exception) as context: |
| 512 | + res = dynamo_resnet18_cpu(input_xla) |
| 513 | + self.assertTrue( |
| 514 | + 'found two different devices' in context.exception.__str__()) |
| 515 | + |
| 516 | + def test_all_cpu_tensor(self): |
| 517 | + met.clear_all() |
| 518 | + input = torch.randn(4, 3, 224, 224) |
| 519 | + resnet18 = torchvision.models.resnet18() |
| 520 | + resnet18.eval() |
| 521 | + dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla') |
| 522 | + # input and model weight on cpu |
| 523 | + with warnings.catch_warnings(record=True) as w: |
| 524 | + res = dynamo_resnet18_cpu(input) |
| 525 | + # there should be 18 paramters + 1 input |
| 526 | + self.assertGreater(len(w), 15) |
| 527 | + self.assertIn('Found tensor with shape torch.Size', str(w[0].message)) |
| 528 | + # no XLA operation should happens. Partitioner should offload all CPU |
| 529 | + # ops to CPU. |
| 530 | + self.assertEqual(len(met.counter_names()), 0) |
| 531 | + |
| 532 | + |
519 | 533 | if __name__ == '__main__': |
520 | 534 | test = unittest.main() |
521 | 535 | sys.exit(0 if test.result.wasSuccessful() else 1) |
0 commit comments