@@ -23,40 +23,48 @@ def fn_simple_dynamo(self, x, y):
2323 return self .fn_simple (x , y )
2424
@dynamo.optimize('torchxla_trace_once')
def run_model_with_dynamo(self, model, data):
  """Invoke `model` on `data` under dynamo with the torchxla_trace_once backend.

  The decorator compiles the traced graph once and replays it on later calls;
  the wrapper itself just forwards the input to the model.
  """
  result = model(data)
  return result
2828
def test_simple_model(self):
  """Check fn_simple_dynamo against eager CPU results, and that the XLA trace
  is cached (no re-tracing) on subsequent calls, including with fresh inputs.
  """
  device = xm.xla_device()
  x = torch.tensor(100.0)
  y = torch.tensor(200.0)
  xla_x = x.to(device)
  xla_y = y.to(device)
  res_cpu = self.fn_simple(x, y)
  res_xla_dynamo = self.fn_simple_dynamo(xla_x, xla_y)
  self.assertIn('xla::add', met.counter_names())
  # torch.allclose returns a bool; it must be wrapped in an assertion,
  # otherwise the comparison result is silently discarded.
  self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
  # verify that tracing is skipped in following runs
  met.clear_counters()
  res_xla_dynamo_2 = self.fn_simple_dynamo(xla_x, xla_y)
  self.assertNotIn('xla::add', met.counter_names())
  self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo_2.cpu()))
  # verify that dynamo can handle different inputs
  res_xla_dynamo_3 = self.fn_simple_dynamo(xla_x + xla_y, xla_y * 3)
  res_cpu_3 = self.fn_simple(x + y, y * 3)
  # Fix: compare against res_cpu_3 — the CPU result for the *same* inputs —
  # not res_cpu; previously res_cpu_3 was computed but never used.
  self.assertTrue(torch.allclose(res_cpu_3, res_xla_dynamo_3.cpu()))
4548
4649 def test_resnet18 (self ):
50+ device = xm .xla_device ()
4751 batch_size = xu .getenv_as ('BATCH_SIZE' , int , defval = 4 )
4852 sample_count = xu .getenv_as ('SAMPLE_COUNT' , int , defval = 10 )
4953 loader = xu .SampleGenerator (
50- data = (torch .randn (batch_size , 3 , 224 ,
51- 224 ), torch .zeros (batch_size , dtype = torch .int64 )),
54+ data = (torch .randn (batch_size , 3 , 224 , 224 , device = device ),
55+ torch .zeros (batch_size , dtype = torch .int64 , device = device )),
5256 sample_count = sample_count )
53- model = torchvision .models .resnet18 ()
54- model .eval ()
57+ resnet18 = torchvision .models .resnet18 ()
58+ resnet18 .eval ()
59+ xla_resnet18 = torchvision .models .resnet18 ().to (device )
60+ xla_resnet18 .eval ()
5561 for data , _ in loader :
56- output = self .resetnet_18_dynamo (model , data )
57- torch .allclose (model (data ), output .cpu ())
58- self .assertEqual (met .metric_data ('CompileTime' )[0 ], 1 )
59- self .assertEqual (met .metric_data ('ExecuteTime' )[0 ], sample_count + 1 )
62+ output = self .run_model_with_dynamo (xla_resnet18 , data )
63+ torch .allclose (resnet18 (data .cpu ()), output .cpu ())
64+ # One graph for initial input data materialization. Another grpah for the
65+ # real model code.
66+ self .assertEqual (met .metric_data ('CompileTime' )[0 ], 2 )
67+ self .assertEqual (met .metric_data ('ExecuteTime' )[0 ], sample_count + 2 )
6068 self .assertEqual (
6169 met .metric_data ('RunCachedGraphInputData' )[0 ], sample_count )
6270 self .assertEqual (
0 commit comments