@@ -68,6 +68,25 @@ def test_simple_model(self):
     res_cpu_3 = self.fn_simple(x + y, y * 3)
     self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu()))
 
+  def test_simple_model_with_different_input_shape(self):
+    met.clear_counters()
+    device = xm.xla_device()
+    xla_x = torch.randn(5, 5).to(device)
+    xla_y = torch.randn(5, 5).to(device)
+    xla_z = torch.randn(10, 10).to(device)
+    self.fn_simple_dynamo(xla_x, xla_x)
+    compile_count = met.metric_data('CompileTime')[0]
+    # Executing with an input of the same shape should not trigger an
+    # additional compilation.
+    self.fn_simple_dynamo(xla_y, xla_y)
+    self.assertEqual(met.metric_data('CompileTime')[0], compile_count)
+    # Give `fn_simple_dynamo` an input with a different shape; we expect
+    # dynamo to recognize this as a different graph and let XLA retrace/recompile.
+    res_xla_dynamo_3 = self.fn_simple_dynamo(xla_z, xla_z)
+    self.assertEqual(met.metric_data('CompileTime')[0], compile_count + 1)
+    self.assertTrue(
+        torch.allclose(res_xla_dynamo_3.cpu(),
+                       self.fn_simple(xla_z.cpu(), xla_z.cpu())))
+
   def test_resnet18(self):
     device = xm.xla_device()
     batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
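
Note for readers of this hunk: the new test relies on `self.fn_simple` and `self.fn_simple_dynamo`, which are defined elsewhere in the test class and are not shown here. The sketch below is only an illustration of how that setup typically looks, assuming a simple eager function wrapped with torch.compile and the dynamo/XLA backend; the class name, the body of fn_simple, and the backend string are assumptions, not contents of this diff (the backend name also differs across torch_xla releases, e.g. 'torchxla_trace_once' in older versions vs. 'openxla' in newer ones).

    # Illustrative sketch only, not the actual file contents.
    import unittest

    import torch
    import torch_xla.core.xla_model as xm   # provides xm.xla_device()
    import torch_xla.debug.metrics as met   # provides met.metric_data('CompileTime')


    class DynamoInferenceBasicTest(unittest.TestCase):  # hypothetical class name

      def fn_simple(self, x, y):
        # Eager reference implementation; the test compares against it on CPU.
        return torch.cos(x) + torch.sin(y)

      def setUp(self):
        # Wrap the eager function with dynamo so calls on XLA tensors go through
        # the XLA backend; the backend string here is an assumption.
        self.fn_simple_dynamo = torch.compile(self.fn_simple, backend='openxla')

With this kind of wiring, each distinct input shape seen by `fn_simple_dynamo` yields a new dynamo graph that XLA traces and compiles, which is why the test expects the 'CompileTime' counter to stay flat for the second 5x5 input and to increase by exactly one for the 10x10 input.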