diff --git a/test/pjrt/test_runtime_tpu.py b/test/pjrt/test_runtime_tpu.py
index 174c7313a796..d7a7918bc9cc 100644
--- a/test/pjrt/test_runtime_tpu.py
+++ b/test/pjrt/test_runtime_tpu.py
@@ -238,6 +238,13 @@ def test_execute_time_metric(self):
           f"Expected exectue time of {i} to take more than "
           f"{expected_time_seconds} seconds, got {v / 1e9} seconds")
 
+  @mock.patch('torch_xla._internal.tpu.get_worker_ips')
+  def test_master_ip_discovery(self, patched_get_worker_ips):
+    # A basic test to verify the non-SPMD codepath returns the correct IP. Two
+    # IPs are needed to avoid the short-circuit return of localhost.
+    patched_get_worker_ips.return_value = ['10.0.0.1', '10.0.0.2']
+    self.assertEqual(xr.get_master_ip(), '10.0.0.1')
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index 37c0224a0f4c..76ea6b71672d 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -477,6 +477,15 @@ def test_manager_async_step_tracking(self, tmpdir):
             torch.allclose(v, new_state_dict[k])
             for k, v in state_dict.items()))
 
+  @unittest.skipUnless(xr.device_type() == 'TPU',
+                       'TPU required for worker IP discovery')
+  @unittest.mock.patch('torch_xla._internal.tpu.get_worker_ips')
+  def test_master_ip_discovery(self, patched_get_worker_ips):
+    # A basic test to verify the SPMD codepath returns the correct IP. Two IPs
+    # are needed to avoid the short-circuit return of localhost.
+    patched_get_worker_ips.return_value = ['10.0.0.1', '10.0.0.2']
+    self.assertEqual(xr.get_master_ip(), '10.0.0.1')
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/_internal/rendezvous.py b/torch_xla/_internal/rendezvous.py
index 8a95ce1024cb..26bbae300a1a 100644
--- a/torch_xla/_internal/rendezvous.py
+++ b/torch_xla/_internal/rendezvous.py
@@ -28,7 +28,6 @@ def pjrt_rendezvous_handler(url: str,
     ) == 'TPU' else 'localhost'
 
   master_port = xu.getenv_as('MASTER_PORT', int, 12355)
-  world_size = xr.world_size()
   with _store_lock:
     global _store
     if not _store:
@@ -44,4 +43,8 @@ def pjrt_rendezvous_handler(url: str,
         xr.process_count(),
         is_master=xr.process_index() == 0)
 
-  yield (_store, xr.global_ordinal(), world_size)
+  # In SPMD, the world size and rank are determined by the process count and
+  # index, while in multiprocess they are based on the device count and ordinal.
+  world_size = xr.process_count() if xr.is_spmd() else xr.world_size()
+  rank = xr.process_index() if xr.is_spmd() else xr.global_ordinal()
+  yield (_store, rank, world_size)
diff --git a/torch_xla/_internal/tpu.py b/torch_xla/_internal/tpu.py
index 89fbca4dcbc7..ec4e1626ec36 100644
--- a/torch_xla/_internal/tpu.py
+++ b/torch_xla/_internal/tpu.py
@@ -1,5 +1,6 @@
 import functools
 import glob
+from ipaddress import ip_address
 import operator
 import os
 import pathlib
@@ -10,6 +11,7 @@
 import yaml
 
 import torch
+import torch_xla
 import torch_xla.utils.utils as xu
 import torch_xla.core.xla_env_vars as xenv
 import torch_xla.core.xla_model as xm
@@ -268,7 +270,10 @@ def configure_topology(local_rank: int,
 
 
 def discover_master_worker_ip(use_localhost: bool = True) -> str:
-  """Find the IP of the TPU host with TPU:0.
+  """Find the IP of the master TPU host.
+
+  In multiprocess, this is the host with TPU:0.
+  In SPMD mode, this is the host running process 0.
 
   TPU device IDs are nondeterministic and independent from Cloud TPU worker
   IDs.
@@ -276,15 +281,38 @@ def discover_master_worker_ip(use_localhost: bool = True) -> str:
     use_localhost: if there is only one TPU host, return 'localhost` instead
       of that host's internal IP.
   """
+  import torch_xla.runtime as xr
   worker_ips = get_worker_ips()
   if len(worker_ips) == 1:
     return 'localhost'
 
   tpu_env = get_tpu_env()
   current_worker_id = int(tpu_env[xenv.WORKER_ID])
+  if xr.is_spmd():
+    return _spmd_find_master_ip(worker_ips[current_worker_id])
+
   t = torch.tensor([current_worker_id], device=xm.xla_device())
   xm.collective_broadcast([t])
   xm.mark_step()
 
   master_worker_id = int(t.cpu())
   return worker_ips[master_worker_id]
+
+
+def _spmd_find_master_ip(current_worker_ip: str) -> str:
+  import torch_xla.runtime as xr
+  import torch_xla.experimental.xla_sharding as xs
+  from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
+  ip_int = int(ip_address(current_worker_ip))
+  n_dev = xr.global_runtime_device_count()
+  local_ndev = len(torch_xla._XLAC._xla_get_runtime_devices())
+  # Create a global (n_dev x 2) tensor containing all process indices and IPs,
+  # and find the process 0 IP as the master IP.
+  shard = torch.LongTensor([[xr.process_index(), ip_int]])
+  op_sharding = xs.Mesh(range(n_dev), (n_dev, 1)).get_op_sharding((0, 1))
+  global_tensor = from_cpu_shards([shard] * local_ndev, op_sharding).cpu()
+  # Process 0 may not control device 0, so we must do a linear search.
+  for proc, ip in global_tensor.tolist():
+    if proc == 0:
+      return str(ip_address(ip))
+  raise RuntimeError('Could not find IP of host running process 0')
diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index 649c538bd334..4f4834c805e8 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -242,3 +242,14 @@ def is_spmd():
   """Returns if SPMD is set for execution."""
   # TODO(yeounoh) replace this when we fully deprecate the flag.
   return xu.check_env_flag('XLA_USE_SPMD')
+
+
+@requires_pjrt
+def get_master_ip() -> str:
+  """Retrieve the master worker IP for the runtime. This calls into
+  backend-specific discovery APIs.
+
+  Returns master worker's IP address as a string."""
+  if device_type() == 'TPU':
+    return tpu.discover_master_worker_ip()
+  raise RuntimeError(f'IP discovery not supported for device: {device_type()}')
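Note on _spmd_find_master_ip: the worker IP travels through the sharded tensor as a plain
integer, using the stdlib ipaddress round-trip between dotted-quad strings and ints. A minimal
illustration (the address below is an arbitrary example, not taken from this change):

  from ipaddress import ip_address

  # An IPv4 address converts losslessly to a Python int and back, so it can be
  # carried inside a torch.LongTensor shard and decoded on the receiving side.
  ip_int = int(ip_address('10.0.0.1'))   # 167772161
  assert str(ip_address(ip_int)) == '10.0.0.1'

Only IPv4 values are guaranteed to fit in a signed 64-bit tensor element; a full 128-bit IPv6
integer would overflow a LongTensor.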
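For reference, a hedged sketch of how the new xr.get_master_ip() and the updated rendezvous
handler could be wired together when initializing a process group on a TPU pod. The helper
init_distributed and the explicit port below are illustrative assumptions; MASTER_ADDR,
MASTER_PORT, and the xla:// init method follow the conventions already used in rendezvous.py.

  import os

  import torch.distributed as dist
  import torch_xla.distributed.xla_backend  # registers the 'xla' process group backend
  import torch_xla.runtime as xr


  def init_distributed():
    # Resolve the master worker up front: under SPMD this is the host running
    # process 0, in multiprocess the host owning TPU:0, and 'localhost' when
    # there is only a single worker.
    os.environ.setdefault('MASTER_ADDR', xr.get_master_ip())
    os.environ.setdefault('MASTER_PORT', '12355')  # assumed port; matches the rendezvous.py default
    # pjrt_rendezvous_handler now reports process_count()/process_index() as
    # the world size and rank when SPMD is enabled.
    dist.init_process_group('xla', init_method='xla://')

Setting MASTER_ADDR explicitly is likely optional here: when it is unset, the handler appears to
fall back to the same TPU discovery (the `== 'TPU' else 'localhost'` branch visible in the first
rendezvous.py hunk).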