Support SPMD through the xla:// init_method #5706
Changes to pjrt_rendezvous_handler:

```diff
@@ -28,7 +28,6 @@ def pjrt_rendezvous_handler(url: str,
     ) == 'TPU' else 'localhost'
 
   master_port = xu.getenv_as('MASTER_PORT', int, 12355)
-  world_size = xr.world_size()
   with _store_lock:
     global _store
     if not _store:
@@ -44,4 +43,8 @@ def pjrt_rendezvous_handler(url: str,
           xr.process_count(),
           is_master=xr.process_index() == 0)
 
-  yield (_store, xr.global_ordinal(), world_size)
+  # In SPMD, the world size and rank are determined by the process count and
+  # index, while in multiprocess they are based on the device count and ordinal.
+  world_size = xr.process_count() if xr.is_spmd() else xr.world_size()
+  rank = xr.process_index() if xr.is_spmd() else xr.global_ordinal()
+  yield (_store, rank, world_size)
```

Review discussion on the new rank computation:

Reviewer: Is the process idx local (per host) or global (across the hosts)?

Author: Yeah, with SPMD we have a single process per host. process_index gives the global index of the current host's process in the runtime, so it will be different on each host.
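For context, a minimal launch sketch of what this change enables (not part of the PR): under SPMD, a CPU process group can be initialized through the xla:// rendezvous, with rank and world size taken from the host process. The gloo backend and the xla_backend import are assumptions here (SPMD typically needs a CPU process group, and the import is assumed to register the xla:// handler); the XLA_USE_SPMD flag matches the is_spmd() check in the runtime diff below.

```python
import os
import torch.distributed as dist
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend  # assumed to register the xla:// rendezvous handler

os.environ['XLA_USE_SPMD'] = '1'  # xr.is_spmd() checks this flag

# One process per host under SPMD; the rendezvous yields the host process's
# index as the rank and the process count as the world size.
dist.init_process_group('gloo', init_method='xla://')

assert dist.get_rank() == xr.process_index()
assert dist.get_world_size() == xr.process_count()
```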
Changes to TPU master worker IP discovery (discover_master_worker_ip):

```diff
@@ -1,5 +1,6 @@
 import functools
 import glob
+from ipaddress import ip_address
 import operator
 import os
 import pathlib
@@ -10,6 +11,7 @@
 import yaml
 
 import torch
+import torch_xla
 import torch_xla.utils.utils as xu
 import torch_xla.core.xla_env_vars as xenv
 import torch_xla.core.xla_model as xm
@@ -268,23 +270,49 @@ def configure_topology(local_rank: int,
 
 
 def discover_master_worker_ip(use_localhost: bool = True) -> str:
-  """Find the IP of the TPU host with TPU:0.
+  """Find the IP of the master TPU host.
+
+  In multiprocess, this is the host with TPU:0.
+  In SPMD mode, this is the host running process 0.
 
   TPU device IDs are nondeterministic and independent from Cloud TPU worker IDs.
 
   Args:
     use_localhost: if there is only one TPU host, return 'localhost` instead
       of that host's internal IP.
   """
+  import torch_xla.runtime as xr
   worker_ips = get_worker_ips()
   if len(worker_ips) == 1:
     return 'localhost'
 
   tpu_env = get_tpu_env()
   current_worker_id = int(tpu_env[xenv.WORKER_ID])
+  if xr.is_spmd():
+    return _spmd_find_master_ip(worker_ips[current_worker_id])
+
   t = torch.tensor([current_worker_id], device=xm.xla_device())
   xm.collective_broadcast([t])
   xm.mark_step()
 
   master_worker_id = int(t.cpu())
   return worker_ips[master_worker_id]
 
 
+def _spmd_find_master_ip(current_worker_ip: str) -> str:
+  import torch_xla.runtime as xr
+  import torch_xla.experimental.xla_sharding as xs
+  from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
+  ip_int = int(ip_address(current_worker_ip))
+  n_dev = xr.global_runtime_device_count()
+  local_ndev = len(torch_xla._XLAC._xla_get_runtime_devices())
+  # Create a global (n_dev x 2) tensor containing all process indices and IPs,
+  # and find the process 0 IP as the master IP.
+  shard = torch.LongTensor([[xr.process_index(), ip_int]])
+  op_sharding = xs.Mesh(range(n_dev), (n_dev, 1)).get_op_sharding((0, 1))
+  global_tensor = from_cpu_shards([shard] * local_ndev, op_sharding).cpu()
+  # Process 0 may not control device 0, so we must do a linear search.
+  for proc, ip in global_tensor.tolist():
+    if proc == 0:
+      return str(ip_address(ip))
+  raise RuntimeError('Could not find IP of host running process 0')
```

Review discussion on _spmd_find_master_ip:

Reviewer: This is very convoluted but I guess it works...

Author: Yeah, it's a hack around our lack of direct control over collectives... If you or @yeounoh have any other ideas, I'm definitely open to revisiting this.

Review discussion on building the shard from the process index:

Reviewer: Could we use global device ordinal instead of the process idx? It's the same thing, but wonder if it's feasible -- that way your search for the IP associated with TPU0 would be more consistent with the code?

Author: Unfortunately it's actually different - there's no guaranteed mapping between processes and devices, so process 0 may not control device 0 and we can't directly pull the IP out of the device 0 slot in the global tensor. In SPMD, I think it's better to rely on the process abstraction, since we never actually need the devices for collectives.

Reviewer: Thanks @jonb377, we are not looking for the IP associated with TPU0, but with the rank=0.
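As a small standalone illustration (not from the PR): the IP-to-integer round trip that _spmd_find_master_ip relies on to pack each worker's address into a LongTensor shard and recover it after the global gather. The address below is made up.

```python
from ipaddress import ip_address

worker_ip = '10.130.0.25'            # hypothetical internal worker IP
ip_int = int(ip_address(worker_ip))  # pack: an IPv4 address fits in a signed 64-bit LongTensor element
assert str(ip_address(ip_int)) == worker_ip  # unpack on the reading side
```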
Changes to the runtime API (get_master_ip):

```diff
@@ -242,3 +242,14 @@ def is_spmd():
   """Returns if SPMD is set for execution."""
   # TODO(yeounoh) replace this when we fully deprecate the flag.
   return xu.check_env_flag('XLA_USE_SPMD')
+
+
+@requires_pjrt
+def get_master_ip() -> str:
+  """Retrieve the master worker IP for the runtime. This calls into
+  backend-specific discovery APIs.
+
+  Returns master worker's IP address as a string."""
+  if device_type() == 'TPU':
+    return tpu.discover_master_worker_ip()
+  raise RuntimeError(f'IP discovery not supported for device: {device_type()}')
```

Review comment on get_master_ip:

Reviewer: cc @will-cromar
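A usage sketch (not from the PR) of the new runtime API: a rendezvous or launcher layer can prefer an explicit MASTER_ADDR and fall back to runtime discovery. The helper name below is hypothetical; xr.get_master_ip() is the API added above and raises for non-TPU backends at this point.

```python
import os
import torch_xla.runtime as xr

def resolve_master_addr() -> str:
  # Respect an explicit override; otherwise ask the runtime which host runs
  # process 0 (TPU-only discovery in this PR).
  return os.environ.get('MASTER_ADDR') or xr.get_master_ip()
```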
Review discussion on the rank/world-size change in pjrt_rendezvous_handler:

Reviewer: In both cases the world size and rank should be determined by the process count and index -- this is needed due to the fact that xr.world_size() returns 1 for a single-replica SPMD case. I think we could instead explain this, saying that "In SPMD we use the process count and index directly, since xr.world_size() returns 1 without any replication group set." Or you can also add some more in the context of the CPU process group need for SPMD?

Author: Similar comment to the discussion above - the device ordinal doesn't correspond to the process index. In the MP case, we want the process group to operate on device ordinals, while in SPMD we want it to be over the process indices.