@@ -79,9 +79,6 @@ def test_outbound_data_metrics(self):
7979
8080 def test_non_tensor_scalar (self ):
8181 sharding_spec = xs .ShardingSpec (self ._get_mesh ((1 , self .n_devices )), (0 , 1 ))
82- # TODO(JackCaoG)currently, execution will only happen if there is at least one
83- # tensor on non-spmd:0 device.
84- t1 = torch .randn (3 , 3 , device = xm .xla_device ())
8582 # tensor will have device as `SPMD:0` in c++
8683 xt1 = xm .send_cpu_data_to_device ([torch .randn (3 , 3 )],
8784 xm .xla_device (),
@@ -95,9 +92,6 @@ def test_non_tensor_scalar(self):
9592 def test_mark_step_on_virtual_device (self ):
9693 xm .mark_step ()
9794 sharding_spec = xs .ShardingSpec (self ._get_mesh ((1 , self .n_devices )), (0 , 1 ))
98- # TODO(JackCaoG)currently, execution will only happen if there is at least one
99- # tensor on non-spmd:0 device.
100- t1 = torch .randn (3 , 3 , device = xm .xla_device ())
10195 # tensor will have device as `SPMD:0` in c++
10296 xt1 = xm .send_cpu_data_to_device ([torch .randn (3 , 3 )],
10397 xm .xla_device (),
@@ -108,6 +102,63 @@ def test_mark_step_on_virtual_device(self):
108102 self .assertNotIn ('aten::div' ,
109103 torch_xla ._XLAC ._get_xla_tensor_debug_info (xt2 ))
110104
105+ def test_virtual_device_no_upload (self ):
106+ met .clear_all ()
107+ device = xm .xla_device ()
108+ t1 = torch .randn (5 , 5 ).to (device )
109+ t1_debug_info = torch_xla ._XLAC ._get_xla_tensor_debug_info (t1 )
110+ # t1's upload to device should be deferred
111+ self .assertIn ("Tensor on host: with size [5, 5]" , t1_debug_info )
112+ self .assertNotIn ("TransferToServerTime" , met .metric_names ())
113+ # t1 should be on SPMD device under spmd context
114+ self .assertIn ("Device: SPMD:0" , t1_debug_info )
115+ self .assertIn ("IR: None" , t1_debug_info )
116+ self .assertIn ("XLAData: None" , t1_debug_info )
117+
118+ def test_virtual_device_upload_after_mark_sharding (self ):
119+ met .clear_all ()
120+ partition_spec = (0 , 1 )
121+ device = xm .xla_device ()
122+ t1 = torch .randn (8 , 8 ).to (device )
123+ t1_debug_info = torch_xla ._XLAC ._get_xla_tensor_debug_info (t1 )
124+ self .assertIn ("Tensor on host: with size [8, 8]" , t1_debug_info )
125+ xs .mark_sharding (t1 , self ._get_mesh ((1 , self .n_devices )), partition_spec )
126+ t1_debug_info_new = torch_xla ._XLAC ._get_xla_tensor_debug_info (t1 )
127+ # tensor should be uploaded to device after mark_sharding
128+ self .assertIn ("Tensor on host: None" , t1_debug_info_new )
129+ self .assertIn ("xla::device_data" , t1_debug_info_new )
130+ self .assertIn ("XLAShardedData" , t1_debug_info_new )
131+ self .assertIn ("TransferToServerTime" , met .metric_names ())
132+
133+ def test_virtual_device_upload_after_tracing (self ):
134+ met .clear_all ()
135+ device = xm .xla_device ()
136+ t1 = torch .randn (8 , 8 ).to (device )
137+ t1_debug_info = torch_xla ._XLAC ._get_xla_tensor_debug_info (t1 )
138+ self .assertIn ("Tensor on host: with size [8, 8]" , t1_debug_info )
139+ t2 = t1 + t1
140+ t1_debug_info_new = torch_xla ._XLAC ._get_xla_tensor_debug_info (t1 )
141+ # tensor should be uploaded to device after being used as input to other op.
142+ self .assertIn ("Tensor on host: None" , t1_debug_info_new )
143+ self .assertIn ("xla::device_data" , t1_debug_info_new )
144+ self .assertIn ("TransferToServerTime" , met .metric_names ())
145+
146+ def test_virtual_device_upload_for_sharded_dataloader (self ):
147+ met .clear_counters ()
148+ device = xm .xla_device ()
149+ sharding_spec = xs .ShardingSpec (self ._get_mesh ((1 , self .n_devices )), (0 , 1 ))
150+ # tensor will have device as `SPMD:0` in c++
151+ t1 = xm .send_cpu_data_to_device ([torch .randn (8 , 8 )],
152+ device ,
153+ input_sharding = sharding_spec )[0 ]
154+ t1_debug_info = torch_xla ._XLAC ._get_xla_tensor_debug_info (t1 )
155+ self .assertIn ("Device: SPMD:0" , t1_debug_info )
156+ # tensor should be uploaded to device after send_cpu_data_to_device + sharding_spec
157+ self .assertIn ("Tensor on host: None" , t1_debug_info )
158+ self .assertIn ("xla::device_data" , t1_debug_info )
159+ self .assertIn ("XLAShardedData" , t1_debug_info )
160+ self .assertIn ("TransferToServerTime" , met .metric_names ())
161+
111162
112163if __name__ == '__main__' :
113164 test = unittest .main ()
0 commit comments