From 36310efee0704af46a02f3d25b9c621030b72f06 Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Wed, 5 Jul 2023 18:23:30 +0000 Subject: [PATCH 1/3] Fix the error where mark_step does not materialize tensors on SPMD:0 --- test/spmd/test_xla_virtual_device.py | 16 ++++++++++++++++ torch_xla/csrc/xla_graph_executor.cpp | 8 +++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py index 081340f15f6d..37160e64dd06 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -88,6 +88,22 @@ def test_non_tensor_scalar(self): xt2 = xt1 / 0.5 torch.allclose(xt2.cpu(), xt1.cpu() / 0.5) + def test_mark_step_on_virtual_device(self): + xm.mark_step() + sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) + # TODO(JackCaoG)currently, execution will only happen if there is at least on + # tensor on non-spmd:0 device. + t1 = torch.randn(3, 3, device=xm.xla_device()) + # tensor will have device as `SPMD:0` in c++ + xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)], + xm.xla_device(), + input_sharding=sharding_spec)[0] + xt2 = xt1 / 0.5 + xm.mark_step() + # after mark_step, xt2 should be materalized + self.assertNotIn('aten::div', + torch_xla._XLAC._get_xla_tensor_debug_info(xt2)) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index db0f46018f2d..3a492fa30e97 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -81,6 +81,7 @@ auto XLAGraphExecutor::DeviceContextArena::Get() -> DeviceContextArena* { std::vector XLAGraphExecutor::DeviceContextArena::GetLiveTensors( const torch::lazy::BackendDevice* device) { std::vector tensors; + torch::lazy::BackendDevice virtual_device = GetVirtualDevice(); auto fn = [&](DeviceContext* devctx) { std::lock_guard lock(devctx->lock); for (auto& uid_wptr : 
devctx->tensors_data) { @@ -92,6 +93,8 @@ std::vector XLAGraphExecutor::DeviceContextArena::GetLiveTensors( } }; ForAllDeviceContexts(fn, device); + // TODO(JackCaoG): all tensors should be on spmd:0 in SPMD mode. + ForAllDeviceContexts(fn, &virtual_device); return tensors; } @@ -502,7 +505,10 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors( tsl::profiler::TraceMeLevel::kInfo); runtime::util::Unique unique_device; for (size_t i = 0; i < tensors.size(); ++i) { - unique_device.set(tensors[i]->GetDevice()); + // TODO(JackCaoG): all tensors should be on spmd:0 in SPMD mode. + if (tensors[i]->GetDevice().toString() != "SPMD:0") { + unique_device.set(tensors[i]->GetDevice()); + } } SyncTensorCollection coll; if (!unique_device) { From 54c8b46c2e800a8313116255a7776b7c65e4bffa Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Wed, 5 Jul 2023 18:39:28 +0000 Subject: [PATCH 2/3] typo --- test/spmd/test_xla_virtual_device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py index 37160e64dd06..81f03121ee53 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -91,7 +91,7 @@ def test_non_tensor_scalar(self): def test_mark_step_on_virtual_device(self): xm.mark_step() sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) - # TODO(JackCaoG)currently, execution will only happen if there is at least on + # TODO(JackCaoG)currently, execution will only happen if there is at least one # tensor on non-spmd:0 device. 
t1 = torch.randn(3, 3, device=xm.xla_device()) # tensor will have device as `SPMD:0` in c++ From f87555030a400d13f79838aff41a30540f1385dc Mon Sep 17 00:00:00 2001 From: JackCaoG Date: Wed, 5 Jul 2023 21:50:37 +0000 Subject: [PATCH 3/3] fix test_non_tensor_scalar --- test/spmd/test_xla_virtual_device.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/spmd/test_xla_virtual_device.py b/test/spmd/test_xla_virtual_device.py index 81f03121ee53..99dbd4f9015c 100644 --- a/test/spmd/test_xla_virtual_device.py +++ b/test/spmd/test_xla_virtual_device.py @@ -79,6 +79,9 @@ def test_outbound_data_metrics(self): def test_non_tensor_scalar(self): sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1)) + # TODO(JackCaoG)currently, execution will only happen if there is at least one + # tensor on non-spmd:0 device. + t1 = torch.randn(3, 3, device=xm.xla_device()) # tensor will have device as `SPMD:0` in c++ xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)], xm.xla_device(), @@ -86,6 +89,7 @@ def test_non_tensor_scalar(self): # we will transfer 0.5 as a device_data to the 'SPMD:0' device, need to make sure # that virtual device can handle this case. xt2 = xt1 / 0.5 + xm.mark_step(wait=True) torch.allclose(xt2.cpu(), xt1.cpu() / 0.5) def test_mark_step_on_virtual_device(self): @@ -99,7 +103,7 @@ def test_mark_step_on_virtual_device(self): xm.xla_device(), input_sharding=sharding_spec)[0] xt2 = xt1 / 0.5 - xm.mark_step() + xm.mark_step(wait=True) # after mark_step, xt2 should be materalized self.assertNotIn('aten::div', torch_xla._XLAC._get_xla_tensor_debug_info(xt2))