diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py
index 081340f15f6d..99dbd4f9015c 100644
--- a/test/spmd/test_xla_virtual_device.py
+++ b/test/spmd/test_xla_virtual_device.py
@@ -79,6 +79,9 @@ def test_outbound_data_metrics(self):
 
   def test_non_tensor_scalar(self):
     sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1))
+    # TODO(JackCaoG): currently, execution will only happen if there is at least one
+    # tensor on a non-SPMD:0 device.
+    t1 = torch.randn(3, 3, device=xm.xla_device())
     # tensor will have device as `SPMD:0` in c++
     xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)],
                                      xm.xla_device(),
@@ -86,8 +89,25 @@ def test_non_tensor_scalar(self):
     # we will transfer 0.5 as a device_data to the 'SPMD:0' device, need to make sure
     # that virtual device can handle this case.
     xt2 = xt1 / 0.5
+    xm.mark_step(wait=True)
     torch.allclose(xt2.cpu(), xt1.cpu() / 0.5)
 
+  def test_mark_step_on_virtual_device(self):
+    xm.mark_step()
+    sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1))
+    # TODO(JackCaoG): currently, execution will only happen if there is at least one
+    # tensor on a non-SPMD:0 device.
+    t1 = torch.randn(3, 3, device=xm.xla_device())
+    # tensor will have device as `SPMD:0` in c++
+    xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)],
+                                     xm.xla_device(),
+                                     input_sharding=sharding_spec)[0]
+    xt2 = xt1 / 0.5
+    xm.mark_step(wait=True)
+    # after mark_step, xt2 should be materialized
+    self.assertNotIn('aten::div',
+                     torch_xla._XLAC._get_xla_tensor_debug_info(xt2))
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index db0f46018f2d..3a492fa30e97 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -81,6 +81,7 @@ auto XLAGraphExecutor::DeviceContextArena::Get() -> DeviceContextArena* {
 std::vector<XLATensorPtr> XLAGraphExecutor::DeviceContextArena::GetLiveTensors(
     const torch::lazy::BackendDevice* device) {
   std::vector<XLATensorPtr> tensors;
+  torch::lazy::BackendDevice virtual_device = GetVirtualDevice();
   auto fn = [&](DeviceContext* devctx) {
     std::lock_guard<std::mutex> lock(devctx->lock);
     for (auto& uid_wptr : devctx->tensors_data) {
@@ -92,6 +93,8 @@ std::vector<XLATensorPtr> XLAGraphExecutor::DeviceContextArena::GetLiveTensors(
     }
   };
   ForAllDeviceContexts(fn, device);
+  // TODO(JackCaoG): all tensors should be on SPMD:0 in SPMD mode.
+  ForAllDeviceContexts(fn, &virtual_device);
   return tensors;
 }
 
@@ -502,7 +505,10 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
       tsl::profiler::TraceMeLevel::kInfo);
   runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
   for (size_t i = 0; i < tensors.size(); ++i) {
-    unique_device.set(tensors[i]->GetDevice());
+    // TODO(JackCaoG): all tensors should be on SPMD:0 in SPMD mode.
+    if (tensors[i]->GetDevice().toString() != "SPMD:0") {
+      unique_device.set(tensors[i]->GetDevice());
+    }
   }
   SyncTensorCollection coll;
   if (!unique_device) {
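
For context, the snippet below replays what the new `test_mark_step_on_virtual_device` test exercises outside the unittest harness: transfer a CPU tensor with an input sharding (in C++ it lands on the `SPMD:0` virtual device), divide it by a scalar, and check that `mark_step(wait=True)` materializes the result. This is a minimal sketch, not part of the PR; it assumes the sharding helpers are importable as `torch_xla.experimental.xla_sharding`, that `torch_xla.runtime.global_runtime_device_count()` is available, and that SPMD mode is enabled (e.g. `XLA_USE_SPMD=1`). Adjust the imports to match the torch_xla version in use.

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs

# Build a (1, N) device mesh and a sharding spec that shards dim 1 across devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices), ('x', 'y'))
sharding_spec = xs.ShardingSpec(mesh, (0, 1))

# Workaround noted in the test's TODO: keep at least one tensor on the regular
# (non-SPMD:0) device so that mark_step actually triggers an execution.
t1 = torch.randn(3, 3, device=xm.xla_device())

# Transfer a CPU tensor with an input sharding; in C++ it is placed on `SPMD:0`.
xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)],
                                 xm.xla_device(),
                                 input_sharding=sharding_spec)[0]
xt2 = xt1 / 0.5

xm.mark_step(wait=True)

# After mark_step the division should be materialized, i.e. no pending
# `aten::div` left in the tensor's debug info.
assert 'aten::div' not in torch_xla._XLAC._get_xla_tensor_debug_info(xt2)
```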