Skip to content

Commit e51d28b

Browse files
authored
Make cpu tensor on XLA dynamo backend a warning instead of error (#5549)
1 parent f38e4a5 commit e51d28b

File tree

2 files changed

+49
-32
lines changed

2 files changed

+49
-32
lines changed

test/dynamo/test_dynamo.py

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
import torch._dynamo as dynamo
1313
import torchvision
1414
import unittest
15+
import warnings
16+
17+
torch_xla._XLAC._init_computation_client()
1518

1619
# Setup import folders.
1720
xla_test_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
@@ -58,36 +61,6 @@ def test_random_op_different_result_each_run(self):
5861
self.assertFalse(torch.allclose(dynamo_res_2, dynamo_res_3))
5962

6063

61-
class DynamErrorMessageTest(unittest.TestCase):
62-
63-
def test_cpu_tensor(self):
64-
device = xm.xla_device()
65-
input = torch.randn(4, 3, 224, 224)
66-
input_xla = input.clone().to(device)
67-
resnet18 = torchvision.models.resnet18()
68-
resnet18.eval()
69-
xla_resnet18 = torchvision.models.resnet18()
70-
xla_resnet18.to(device)
71-
xla_resnet18.eval()
72-
dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')
73-
dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla')
74-
# input on cpu and model weight on xla
75-
with self.assertRaises(Exception) as context:
76-
res = dynamo_resnet18(input)
77-
self.assertTrue(
78-
'found two different devices' in context.exception.__str__())
79-
# input on xla and model weight on cpu
80-
with self.assertRaises(Exception) as context:
81-
res = dynamo_resnet18_cpu(input_xla)
82-
self.assertTrue(
83-
'found two different devices' in context.exception.__str__())
84-
# input and model weight on cpu
85-
with self.assertRaises(Exception) as context:
86-
res = dynamo_resnet18_cpu(input)
87-
self.assertTrue(
88-
'please move all tensors to XLA device' in context.exception.__str__())
89-
90-
9164
class DynamoInferenceBasicTest(unittest.TestCase):
9265

9366
@classmethod
@@ -516,6 +489,47 @@ def test_resnet18(self):
516489
met.metric_data('RunCachedGraphOutputData')[0], sample_count * 3)
517490

518491

492+
class DynamErrorMessageTest(unittest.TestCase):
493+
494+
def test_mixed_cpu_tensor(self):
495+
device = xm.xla_device()
496+
input = torch.randn(4, 3, 224, 224)
497+
input_xla = input.clone().to(device)
498+
resnet18 = torchvision.models.resnet18()
499+
resnet18.eval()
500+
xla_resnet18 = torchvision.models.resnet18()
501+
xla_resnet18.to(device)
502+
xla_resnet18.eval()
503+
dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')
504+
dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla')
505+
# input on cpu and model weight on xla
506+
with self.assertRaises(Exception) as context:
507+
res = dynamo_resnet18(input)
508+
self.assertTrue(
509+
'found two different devices' in context.exception.__str__())
510+
# input on xla and model weight on cpu
511+
with self.assertRaises(Exception) as context:
512+
res = dynamo_resnet18_cpu(input_xla)
513+
self.assertTrue(
514+
'found two different devices' in context.exception.__str__())
515+
516+
def test_all_cpu_tensor(self):
517+
met.clear_all()
518+
input = torch.randn(4, 3, 224, 224)
519+
resnet18 = torchvision.models.resnet18()
520+
resnet18.eval()
521+
dynamo_resnet18_cpu = torch.compile(resnet18, backend='openxla')
522+
# input and model weight on cpu
523+
with warnings.catch_warnings(record=True) as w:
524+
res = dynamo_resnet18_cpu(input)
525+
# there should be 18 paramters + 1 input
526+
self.assertGreater(len(w), 15)
527+
self.assertIn('Found tensor with shape torch.Size', str(w[0].message))
528+
# no XLA operation should happens. Partitioner should offload all CPU
529+
# ops to CPU.
530+
self.assertEqual(len(met.counter_names()), 0)
531+
532+
519533
if __name__ == '__main__':
520534
test = unittest.main()
521535
sys.exit(0 if test.result.wasSuccessful() else 1)

torch_xla/core/dynamo_bridge.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import dataclasses
33
import operator
4+
import warnings
45

56
import functools
67
import itertools
@@ -457,8 +458,10 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
457458

458459
for xla_arg in xla_args:
459460
if xla_arg.device.type != 'xla':
460-
raise RuntimeError(
461-
'For openxla dynamo backend, please move all tensors to XLA device')
461+
warnings.warn(
462+
"Found tensor with shape " + str(xla_arg.size()) + " on " +
463+
str(xla_arg.device) +
464+
". Please move all tensors to xla device to execute on XLA device.")
462465

463466
cloned_args = [
464467
torch.clone(xla_arg) if isinstance(xla_arg, torch.Tensor) else xla_arg

0 commit comments

Comments
 (0)