7 changes: 7 additions & 0 deletions test/pjrt/test_runtime_tpu.py
@@ -238,6 +238,13 @@ def test_execute_time_metric(self):
f"Expected exectue time of {i} to take more than "
f"{expected_time_seconds} seconds, got {v / 1e9} seconds")

@mock.patch('torch_xla._internal.tpu.get_worker_ips')
def test_master_ip_discovery(self, patched_get_worker_ips):
# A basic test to verify the non-SPMD codepath returns the correct IP. Two
# IPs are needed to avoid the short-circuit return of localhost.
patched_get_worker_ips.return_value = ['10.0.0.1', '10.0.0.2']
    self.assertEqual(xr.get_master_ip(), '10.0.0.1')


if __name__ == '__main__':
absltest.main()
9 changes: 9 additions & 0 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -477,6 +477,15 @@ def test_manager_async_step_tracking(self, tmpdir):
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))

@unittest.skipUnless(xr.device_type() == 'TPU',
'TPU required for worker IP discovery')
@unittest.mock.patch('torch_xla._internal.tpu.get_worker_ips')
def test_master_ip_discovery(self, patched_get_worker_ips):
# A basic test to verify the SPMD codepath returns the correct IP. Two IPs
# are needed to avoid the short-circuit return of localhost.
patched_get_worker_ips.return_value = ['10.0.0.1', '10.0.0.2']
    self.assertEqual(xr.get_master_ip(), '10.0.0.1')


if __name__ == '__main__':
test = unittest.main()
7 changes: 5 additions & 2 deletions torch_xla/_internal/rendezvous.py
@@ -28,7 +28,6 @@ def pjrt_rendezvous_handler(url: str,
) == 'TPU' else 'localhost'

master_port = xu.getenv_as('MASTER_PORT', int, 12355)
world_size = xr.world_size()
with _store_lock:
global _store
if not _store:
@@ -44,4 +43,8 @@ xr.process_count(),
xr.process_count(),
is_master=xr.process_index() == 0)

yield (_store, xr.global_ordinal(), world_size)
  # In SPMD, the world size and rank are determined by the process count and
  # index, while in multiprocess they are based on the device count and ordinal.
  world_size = xr.process_count() if xr.is_spmd() else xr.world_size()
  rank = xr.process_index() if xr.is_spmd() else xr.global_ordinal()
  yield (_store, rank, world_size)

Review comment (Contributor), on the new comment above:
In both cases the world size and rank should be determined by the process count and index -- this is needed because xr.world_size() returns 1 for a single-replica SPMD case. I think we could instead explain this, saying "In SPMD we use the process count and index directly, since xr.world_size() returns 1 without any replication group set." Or you could also add some more context about the CPU process group needed for SPMD?

Reply (Collaborator Author):
Similar comment to below - the device ordinal doesn't correspond to the process index. In the MP case, we want the process group to operate on device ordinals, while in SPMD we want it to operate over the process indices.

Review comment (Contributor, @yeounoh, Oct 26, 2023), on the rank assignment:
Is the process index local (per host) or global (across the hosts)?

Reply (Collaborator Author):
Yeah, with SPMD we have a single process per host. process_index gives the global index of the current host's process in the runtime, so it will be different on each host.
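As an aside on the branch above: the thread's point is that xr.world_size() reports 1 in single-replica SPMD, so the process count and index stand in for world size and rank there, while multiprocess keeps the device-ordinal-based values. Below is a minimal sketch restating that selection as a standalone helper, using only runtime calls already shown in this diff; the helper name is hypothetical and not part of the PR.

import torch_xla.runtime as xr


def rendezvous_rank_and_world_size():
  """Hypothetical restatement of the selection above, for illustration only."""
  if xr.is_spmd():
    # SPMD runs one process per host; xr.world_size() would report 1 without a
    # replication group, so describe the job by process index and count.
    return xr.process_index(), xr.process_count()
  # Multiprocess runs one process per device, keyed by global device ordinal.
  return xr.global_ordinal(), xr.world_size()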
30 changes: 29 additions & 1 deletion torch_xla/_internal/tpu.py
@@ -1,5 +1,6 @@
import functools
import glob
from ipaddress import ip_address
import operator
import os
import pathlib
@@ -10,6 +11,7 @@
import yaml

import torch
import torch_xla
import torch_xla.utils.utils as xu
import torch_xla.core.xla_env_vars as xenv
import torch_xla.core.xla_model as xm
@@ -268,23 +270,49 @@ def configure_topology(local_rank: int,


def discover_master_worker_ip(use_localhost: bool = True) -> str:
"""Find the IP of the TPU host with TPU:0.
"""Find the IP of the master TPU host.

In multiprocess, this is the host with TPU:0.
In SPMD mode, this is the host running process 0.

TPU device IDs are nondeterministic and independent from Cloud TPU worker IDs.

Args:
    use_localhost: if there is only one TPU host, return 'localhost' instead
of that host's internal IP.
"""
import torch_xla.runtime as xr
worker_ips = get_worker_ips()
if len(worker_ips) == 1:
return 'localhost'

tpu_env = get_tpu_env()
current_worker_id = int(tpu_env[xenv.WORKER_ID])
if xr.is_spmd():
return _spmd_find_master_ip(worker_ips[current_worker_id])

t = torch.tensor([current_worker_id], device=xm.xla_device())
xm.collective_broadcast([t])
xm.mark_step()

master_worker_id = int(t.cpu())
return worker_ips[master_worker_id]


def _spmd_find_master_ip(current_worker_ip: str) -> str:
  import torch_xla.runtime as xr
  import torch_xla.experimental.xla_sharding as xs
  from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
  ip_int = int(ip_address(current_worker_ip))
  n_dev = xr.global_runtime_device_count()
  local_ndev = len(torch_xla._XLAC._xla_get_runtime_devices())
  # Create a global (n_dev x 2) tensor containing all process indices and IPs,
  # and find the process 0 IP as the master IP.
  shard = torch.LongTensor([[xr.process_index(), ip_int]])
  op_sharding = xs.Mesh(range(n_dev), (n_dev, 1)).get_op_sharding((0, 1))
  global_tensor = from_cpu_shards([shard] * local_ndev, op_sharding).cpu()
  # Process 0 may not control device 0, so we must do a linear search.
  for proc, ip in global_tensor.tolist():
    if proc == 0:
      return str(ip_address(ip))
  raise RuntimeError('Could not find IP of host running process 0')

Review comment (Collaborator), on _spmd_find_master_ip:
This is very convoluted but I guess it works...

Reply (Collaborator Author):
Yeah, it's a hack around our lack of direct control over collectives... If you or @yeounoh have any other ideas, I'm definitely open to revisiting this.

Review comment (Contributor), on the shard construction:
Could we use the global device ordinal instead of the process index? It's the same thing, but I wonder if it's feasible -- that way your search for the IP associated with TPU:0 would be more consistent with the code.

Reply (Collaborator Author, @jonb377, Oct 26, 2023):
Unfortunately it's actually different - there's no guaranteed mapping between processes and devices, so process 0 may not control device 0, and we can't directly pull the IP out of the device 0 slot in the global tensor. In SPMD, I think it's better to rely on the process abstraction, since we never actually need the devices for collectives.

Reply (Contributor):
Thanks @jonb377, we are not looking for the IP associated with TPU:0, but with rank=0.
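A small note on the encoding used by _spmd_find_master_ip: the worker IP is packed into the int64 shard as a plain integer and recovered with the same ipaddress helpers once the global tensor is gathered. A self-contained round-trip illustration follows; the sample address is arbitrary.

from ipaddress import ip_address

# Pack an IPv4 address into an integer, as done before building the
# LongTensor shard, then recover the dotted-quad string from that integer.
ip_int = int(ip_address('10.0.0.1'))
assert ip_int == 167772161  # 10 * 2**24 + 1
assert str(ip_address(ip_int)) == '10.0.0.1'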
11 changes: 11 additions & 0 deletions torch_xla/runtime.py
@@ -242,3 +242,14 @@ def is_spmd():
"""Returns if SPMD is set for execution."""
# TODO(yeounoh) replace this when we fully deprecate the flag.
return xu.check_env_flag('XLA_USE_SPMD')


@requires_pjrt
def get_master_ip() -> str:
"""Retrieve the master worker IP for the runtime. This calls into
backend-specific discovery APIs.

Returns master worker's IP address as a string."""
if device_type() == 'TPU':
return tpu.discover_master_worker_ip()
raise RuntimeError(f'IP discovery not supported for device: {device_type()}')
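For context, here is a hedged sketch of how get_master_ip() might be wired into a CPU-side process group on a multi-host TPU pod: every host resolves the same master address and passes it to torch.distributed. The gloo backend, the port value, and the helper name below are illustrative assumptions, not something this PR prescribes.

import os

import torch.distributed as dist
import torch_xla.runtime as xr


def init_cpu_process_group():
  """Illustrative only: seed a gloo group with the discovered master IP."""
  os.environ.setdefault('MASTER_ADDR', xr.get_master_ip())
  os.environ.setdefault('MASTER_PORT', '12355')  # arbitrary example port
  dist.init_process_group(
      'gloo', rank=xr.process_index(), world_size=xr.process_count())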