Skip to content

Commit dfb8036

Browse files
authored
Make sure dynamo give us different graph when input shape changes (#5177)
* Make sure dynamo give us different graph when input shape changes * fix typo
1 parent 84e7756 commit dfb8036

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

test/dynamo/test_dynamo.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,25 @@ def test_simple_model(self):
6868
res_cpu_3 = self.fn_simple(x + y, y * 3)
6969
self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu()))
7070

71+
def test_simple_model_with_different_input_shape(self):
72+
met.clear_counters()
73+
device = xm.xla_device()
74+
xla_x = torch.randn(5, 5).to(device)
75+
xla_y = torch.randn(5, 5).to(device)
76+
xla_z = torch.randn(10, 10).to(device)
77+
self.fn_simple_dynamo(xla_x, xla_x)
78+
compile_count = met.metric_data('CompileTime')[0]
79+
# Execute with input with same shape should not trigger additional compilation
80+
self.fn_simple_dynamo(xla_y, xla_y)
81+
self.assertEqual(met.metric_data('CompileTime')[0], compile_count)
82+
# Give `fn_simple_dynamo` an input with different shappe, we expect
83+
# dynamo to recognize this is a different graph and let XLA to retrace/recompile
84+
res_xla_dynamo_3 = self.fn_simple_dynamo(xla_z, xla_z)
85+
self.assertEqual(met.metric_data('CompileTime')[0], compile_count + 1)
86+
self.assertTrue(
87+
torch.allclose(res_xla_dynamo_3.cpu(),
88+
self.fn_simple(xla_z.cpu(), xla_z.cpu())))
89+
7190
def test_resnet18(self):
7291
device = xm.xla_device()
7392
batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)

0 commit comments

Comments
 (0)