Commit ee72332

Cherry pick pjrt:// init method rename and doc updates (#5562)
* Change `pjrt://` init method to `xla://` (#5560)
* Update PJRT documentation for the 2.1 release (#5557)
  * Update PJRT documentation for the 2.1 release
  * clarify plugins
  * clarify PJRT doc
  * Update `pjrt://` to `xla://`
1 parent 2c07df9 commit ee72332

11 files changed: +101, -85 lines changed

README.md

Lines changed: 6 additions & 2 deletions
@@ -3,6 +3,10 @@
 <b>Current CI status:</b> ![GitHub Actions
 status](https://github.com/pytorch/xla/actions/workflows/build_and_test.yml/badge.svg)
 
+Note: PyTorch/XLA r2.1 will be the last release with XRT available as a legacy
+runtime. Our main release build will not include XRT, but it will be available
+in a separate package.
+
 PyTorch/XLA is a Python package that uses the [XLA deep learning
 compiler](https://www.tensorflow.org/xla) to connect the [PyTorch deep learning
 framework](https://pytorch.org/) and [Cloud
@@ -70,6 +74,7 @@ If you're using `DistributedDataParallel`, make the following changes:
 +import torch_xla.core.xla_model as xm
 +import torch_xla.distributed.parallel_loader as pl
 +import torch_xla.distributed.xla_multiprocessing as xmp
++import torch_xla.distributed.xla_backend
 
  def _mp_fn(rank, world_size):
    ...
@@ -78,7 +83,7 @@ If you're using `DistributedDataParallel`, make the following changes:
 -  os.environ['MASTER_PORT'] = '12355'
 -  dist.init_process_group("gloo", rank=rank, world_size=world_size)
 +  # Rank and world size are inferred from the XLA device runtime
-+  dist.init_process_group("xla", init_method='pjrt://')
++  dist.init_process_group("xla", init_method='xla://')
 +
 +  model.to(xm.xla_device())
 +  # `gradient_as_bucket_view=tpu` required for XLA
@@ -101,7 +106,6 @@ If you're using `DistributedDataParallel`, make the following changes:
 +  xmp.spawn(_mp_fn, args=())
 ```
 
-
 Additional information on PyTorch/XLA, including a description of its semantics
 and functions, is available at [PyTorch.org](http://pytorch.org/xla/). See the
 [API Guide](API_GUIDE.md) for best practices when writing networks that run on
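
For orientation, here is a minimal end-to-end sketch of the DDP flow the README diff above describes, assuming a PyTorch/XLA 2.1 install with `PJRT_DEVICE` set (for example `PJRT_DEVICE=TPU`). The toy model, random data, and step count are illustrative and not part of the commit.

```python
# Hedged sketch only: assembles the README's DDP snippet into one runnable file.
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend  # registers the `xla` backend and `xla://` init method


def _mp_fn(index):
  # Rank and world size are inferred from the XLA device runtime.
  dist.init_process_group('xla', init_method='xla://')

  device = xm.xla_device()
  model = nn.Linear(128, 10).to(device)
  # The docs pass gradient_as_bucket_view=True when wrapping models for XLA.
  ddp_model = DDP(model, gradient_as_bucket_view=True)

  optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
  loss_fn = nn.MSELoss()

  for _ in range(10):
    optimizer.zero_grad()
    data = torch.randn(16, 128, device=device)
    target = torch.randn(16, 10, device=device)
    loss = loss_fn(ddp_model(data), target)
    loss.backward()
    optimizer.step()
    xm.mark_step()  # flush the lazily traced graph to the device


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
```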

docs/pjrt.md

Lines changed: 34 additions & 26 deletions
@@ -1,18 +1,23 @@
-# PJRT Runtime (Beta)
+# PJRT Runtime
 
-_This document reflects the current state of PJRT support in current nightly
-builds_. See the [same document on the r2.0 branch](https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md)
-for the status in the latest stable release.
-
-The PyTorch/XLA team is currently migrating from the currently-supported XRT
-runtime to the [PJRT
+PyTorch/XLA has migrated from the TensorFlow-based XRT runtime to the [PJRT
 runtime](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/pjrt)
 used by [JAX](https://github.com/google/jax).
 
-PJRT is available for preview in PyTorch/XLA 2.0. **We are planning to make
-PJRT our officially supported runtime**, so we encourage all users to experiment
-with it. We aim to make PJRT stable in release 2.1, so if you encounter a bug
-with PJRT, please file an issue on GitHub with the `runtime` tag.
+If you encounter a bug with PJRT, please file an issue on GitHub with the
+`runtime` tag.
+
+_New features in PyTorch/XLA r2.1_:
+
+* PJRT is stable in PyTorch/XLA r2.1!
+* Public runtime APIs have moved from `torch_xla.experimental.pjrt` to
+  `torch_xla.runtime`.
+* The `pjrt://` init method has been renamed to `xla://`, and it is registered
+  by `torch_xla.distributed.xla_backend`.
+* The previous `torch_xla.experimental.*` names are still available in this
+  release for compatibility.
+* `torchrun` is now supported when using `init_method='xla://'`.
+* New plugins for XPU and Neuron via the PJRT C API.
 
 _New features in PyTorch/XLA r2.0_:
 
@@ -29,7 +34,7 @@ _New features in PyTorch/XLA r2.0_:
 ## TL;DR
 
 * To use the PJRT preview runtime, set the `PJRT_DEVICE` environment variable to
-  `CPU`, `TPU, or `GPU`
+  `CPU`, `TPU`, or `GPU`
 * In XRT, all distributed workloads are multiprocess, with one process per
   device. On TPU v2 and v3 in PJRT, workloads are multiprocess and multithreaded
   (4 processes with 2 threads each), so your workload should be thread-safe. See
@@ -45,7 +50,7 @@ _New features in PyTorch/XLA r2.0_:
   The global `torch` RNG is _not_ thread-safe, even if you set the same
   `torch.manual_seed` across replicas.
 * To use `torch.distributed`, import `torch_xla.experimental.pjrt_backend` and
-  use the `pjrt://` `init_method`.
+  use the `xla://` `init_method`.
 * These steps are optional for GPU and TPU v4.
 
 Sample diff from XRT to PJRT:
@@ -62,20 +67,19 @@ Sample diff from XRT to PJRT:
  import torch_xla.distributed.parallel_loader as pl
  import torch_xla.distributed.xla_backend
  import torch_xla.distributed.xla_multiprocessing as xmp
-+import torch_xla.experimental.pjrt_backend
-+import torch_xla.experimental.pjrt as pjrt
++import torch_xla.runtime as xr
 
 
  def _mp_fn(index):
    device = xm.xla_device()
 -  dist.init_process_group('xla', rank=xm.get_ordinal(), world_size=xm.xrt_world_size())
-+  dist.init_process_group('xla', init_method='pjrt://')
++  dist.init_process_group('xla', init_method='xla://')
 
    torch.manual_seed(42)
    model = nn.Linear(128, 10).to(device)
 
 +  # Optional for TPU v4 and GPU
-+  pjrt.broadcast_master_param(model)
++  xm.broadcast_master_param(model)
    model = DDP(model, gradient_as_bucket_view=True)
 
    loss_fn = nn.MSELoss()
@@ -286,15 +290,15 @@ without altering the XLA graph and/or synchronizing a subset of workers),
 consider using
 [`torch.distributed.barrier`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier)
 or
-`[torch.distributed.all_gather_object](https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object)`
+[`torch.distributed.all_gather_object`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object)
 with a `gloo` process group. If you are also using the `xla` `torch.distributed`
 backend, you can use `torch.new_group` to create a `gloo` subgroup. See [this
 example](https://pytorch.org/docs/stable/distributed.html#monitored-barrier)
 from the PyTorch documentation. Keep in mind these constraints:
 
 * `torch.distributed` is not fully supported on TPU v2/v3. Only a subset of
   operations with the `xla` backend are implemented, and `gloo` will likely not
-  work as expected in a multiprocessing context.
+  work as expected in a multithreaded context.
 * In our experiments, `gloo` does not scale well to thousands of TPU chips, so
   expect this alternative to be less reliable than using `xm.rendezvous` with
   PJRT at large scales.
@@ -305,7 +309,7 @@ _New in PyTorch/XLA r2.0_
 
 When using PJRT with `torch.distributed` and
 `[torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md)`
-we strongly recommend using the new `pjrt://` `init_method`, which automatically
+we strongly recommend using the new `xla://` `init_method`, which automatically
 finds the replica IDs, world size, and master IP by querying the runtime. For
 example:
 
@@ -316,12 +320,12 @@ import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_multiprocessing as xmp
 from torch_xla.experimental import pjrt
 
-# Required for `pjrt://` init_method
-import torch_xla.experimental.pjrt_backend
+# Required for `xla://` init_method and `xla` backend
+import torch_xla.distributed.xla_backend
 
 def _all_gather(index: int):
   # No need to pass in `rank` or `world_size`
-  dist.init_process_group('xla', init_method='pjrt://')
+  dist.init_process_group('xla', init_method='xla://')
 
   t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
   output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
@@ -334,10 +338,14 @@ if __name__ == '__main__':
   xmp.spawn(_all_gather)
 ```
 
-Note: Although the `pjrt://` init_method is not required on TPU v4, it is still
+Note: Although the `xla://` init_method is not required on TPU v4, it is still
 recommended. If you use `env://`, `MASTER_ADDR` must be set to IP host that has
-device 0, which is _not_ always worker 0. The `pjrt://` init_method finds this
-IP automatically and supports TPU v2/v3.
+device 0, which is _not_ always worker 0. The `xla://` init_method finds this
+IP automatically.
+
+Note: For TPU v2/v3, you still need to import
+`torch_xla.experimental.pjrt_backend`, as TPU v2/v3 support in
+`torch.distributed` is still experimental.
 
 For more information about using `DistributedDataParallel` on PyTorch/XLA, see
 [`ddp.md`](./ddp.md) on TPU V4. For an example that uses DDP and PJRT together,
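
The gloo-subgroup guidance in the pjrt.md hunk above is easier to follow with a concrete sketch. The example below is illustrative and not part of the commit; it assumes `dist.new_group` can reuse the store established by the `xla://` rendezvous, as the document suggests, and the gathered dictionary is purely a placeholder.

```python
# Illustrative sketch of the gloo-subgroup pattern described in docs/pjrt.md.
import torch.distributed as dist
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend  # registers the `xla` backend and `xla://`


def _mp_fn(index):
  # Device collectives go through the `xla` backend.
  dist.init_process_group('xla', init_method='xla://')

  # A separate gloo subgroup for host-side coordination (CPU tensors / objects).
  gloo_group = dist.new_group(ranks=list(range(dist.get_world_size())),
                              backend='gloo')

  # Exchange arbitrary Python objects without touching the XLA graph.
  outputs = [None] * dist.get_world_size()
  dist.all_gather_object(outputs, {'rank': dist.get_rank()}, group=gloo_group)

  if dist.get_rank() == 0:
    print(outputs)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
```

As the diff itself cautions, gloo is not expected to behave well in TPU v2/v3's multithreaded workers and scales poorly to very large slices, so `xm.rendezvous` remains the default recommendation.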

test/pjrt/test_ddp.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch_xla.core.xla_model as xm
-import torch_xla.experimental.pjrt_backend
+import torch_xla.distributed.xla_backend
 from torch_xla import runtime as xr
 from torch_xla._internal import pjrt, tpu
 
@@ -24,7 +24,7 @@ class TestPjRtDistributedDataParallel(parameterized.TestCase):
 
   @staticmethod
   def _ddp_init(index: int = ...):
-    dist.init_process_group('xla', init_method='pjrt://')
+    dist.init_process_group('xla', init_method='xla://')
     device = xm.xla_device()
     model = nn.Linear(10, 10).to(device)
     ddp_model = DDP(model)
@@ -41,7 +41,7 @@ def test_ddp_init_threaded(self):
   def test_ddp_correctness(self, use_large_net: bool):
     pjrt.run_multiprocess(
         util.ddp_correctness,
-        init_method='pjrt://',
+        init_method='xla://',
         use_large_net=use_large_net,
         debug=FLAGS.debug)

test/pjrt/test_torchrun.py

Lines changed: 2 additions & 2 deletions
@@ -3,15 +3,15 @@
 import torch
 import torch.distributed as dist
 import torch_xla.core.xla_model as xm
-import torch_xla.experimental.pjrt_backend
+import torch_xla.distributed.xla_backend
 import torch_xla.runtime as xr
 import torch_xla.utils.utils as xu
 
 
 class TestTorchrun(absltest.TestCase):
 
   def test_all_gather(self):
-    dist.init_process_group('xla', init_method='pjrt://')
+    dist.init_process_group('xla', init_method='xla://')
 
     dist_world_size = xu.getenv_as('WORLD_SIZE', int)
     devices_per_thread = xr.addressable_device_count()
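
Since this test exercises the newly supported `torchrun` path, a hedged sketch of a standalone script launched by `torchrun` (rather than `xmp.spawn`) may help; the file name, launch command, and process count in the comments are assumptions, not taken from the commit.

```python
# torchrun_allgather.py -- hypothetical example, launched along the lines of:
#   PJRT_DEVICE=TPU torchrun --nnodes=1 --nproc_per_node=4 torchrun_allgather.py
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the `xla` backend and `xla://`
import torch_xla.runtime as xr

# torchrun sets LOCAL_RANK/LOCAL_WORLD_SIZE; the `xla://` rendezvous handler
# added in this commit reads them, so no explicit rank or world_size is passed.
dist.init_process_group('xla', init_method='xla://')

t = torch.tensor([xr.global_ordinal()], dtype=torch.int32, device=xm.xla_device())
output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(output, t)
print([x.item() for x in output])
```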

test/test_torch_distributed_xla_backend.py

Lines changed: 1 addition & 2 deletions
@@ -10,7 +10,6 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_backend
-import torch_xla.experimental.pjrt_backend
 from torch_xla import runtime as xr
 
 
@@ -42,7 +41,7 @@ class XlaBackendTest(parameterized.TestCase):
   def setUpClass(cls):
     # Add no-op all-reduce ops to HLO
     os.environ['XLA_ALWAYS_ALLREDUCE'] = '1'
-    dist.init_process_group('xla', init_method='pjrt://')
+    dist.init_process_group('xla', init_method='xla://')
 
   def tearDown(self) -> None:
     # Purge all computations attached the device.

test/test_train_mp_imagenet.py

Lines changed: 2 additions & 3 deletions
@@ -31,7 +31,7 @@
     '--ddp': {
         'action': 'store_true',
     },
-    # Use pjrt:// init_method instead of env:// for `torch.distributed`.
+    # Use xla:// init_method instead of env:// for `torch.distributed`.
     # Required for DDP on TPU v2/v3 when using PJRT.
     '--pjrt_distributed': {
         'action': 'store_true',
@@ -180,8 +180,7 @@ def _train_update(device, step, loss, tracker, epoch, writer):
 
 def train_imagenet():
   if FLAGS.pjrt_distributed:
-    import torch_xla.experimental.pjrt_backend
-    dist.init_process_group('xla', init_method='pjrt://')
+    dist.init_process_group('xla', init_method='xla://')
   elif FLAGS.ddp:
     dist.init_process_group(
         'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())

test/test_train_mp_mnist.py

Lines changed: 1 addition & 2 deletions
@@ -77,8 +77,7 @@ def _train_update(device, step, loss, tracker, epoch, writer):
 
 def train_mnist(flags, **kwargs):
   if flags.pjrt_distributed:
-    import torch_xla.experimental.pjrt_backend
-    dist.init_process_group('xla', init_method='pjrt://')
+    dist.init_process_group('xla', init_method='xla://')
   elif flags.ddp:
     dist.init_process_group(
         'xla', world_size=xm.xrt_world_size(), rank=xm.get_ordinal())

torch_xla/_internal/rendezvous.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+import datetime
+import logging
+import threading
+
+import torch.distributed as dist
+from torch_xla.distributed import xla_backend
+from torch_xla import runtime as xr
+from torch_xla._internal import pjrt
+from torch_xla._internal import tpu
+import torch_xla.utils.utils as xu
+
+_store = None
+_store_lock = threading.Lock()
+
+
+def pjrt_rendezvous_handler(url: str,
+                            timeout: datetime.timedelta = ...,
+                            **kwargs):
+  # Assume `xmp.spawn` has not been called when using torchrun
+  if dist.is_torchelastic_launched():
+    local_world_size = xu.getenv_as('LOCAL_WORLD_SIZE', int)
+    local_rank = xu.getenv_as('LOCAL_RANK', int)
+    pjrt.initialize_multiprocess(local_rank, local_world_size)
+
+  master_ip = xu.getenv_as('MASTER_ADDR', str)
+  if not master_ip:
+    master_ip = tpu.discover_master_worker_ip() if xr.device_type(
+    ) == 'TPU' else 'localhost'
+
+  master_port = xu.getenv_as('MASTER_PORT', int, 12355)
+  world_size = xr.world_size()
+  with _store_lock:
+    global _store
+    if not _store:
+      if xu.getenv_as('TORCHELASTIC_USE_AGENT_STORE', str) == 'True':
+        attempt = xu.getenv_as('TORCHELASTIC_RESTART_COUNT', int, defval=0)
+        tcp_store = dist.TCPStore(
+            master_ip, master_port, xr.process_count(), is_master=False)
+        _store = dist.PrefixStore(f"/worker/attempt_{attempt}", tcp_store)
+      else:
+        _store = dist.TCPStore(
+            master_ip,
+            master_port,
+            xr.process_count(),
+            is_master=xr.process_index() == 0)
+
+  yield (_store, xr.global_ordinal(), world_size)
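
The new module above is a `torch.distributed` rendezvous handler: a generator that yields a `(store, rank, world_size)` triple for a given URL scheme. The sketch below illustrates only that contract, using a made-up `demo://` scheme and single-process values that are not part of the commit.

```python
# Minimal sketch of the rendezvous-handler protocol that pjrt_rendezvous_handler
# implements; the 'demo' scheme and single-process values are hypothetical.
import torch.distributed as dist


def demo_rendezvous_handler(url: str, **kwargs):
  # A handler yields (store, rank, world_size). pjrt_rendezvous_handler does the
  # same, but builds a TCPStore and queries torch_xla.runtime for rank/world size.
  store = dist.HashStore()
  yield (store, 0, 1)


dist.register_rendezvous_handler('demo', demo_rendezvous_handler)
# After registration, dist.init_process_group(..., init_method='demo://') would
# route through demo_rendezvous_handler, just as 'xla://' routes through
# pjrt_rendezvous_handler once torch_xla.distributed.xla_backend registers it.
```

The `xla_backend.py` change in the next section performs this registration for the `xla` scheme, and the final diff below keeps the legacy `pjrt` scheme pointing at the same handler for backward compatibility.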

torch_xla/distributed/xla_backend.py

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,7 @@
 import torch
 import torch.distributed as dist
 import torch_xla.core.xla_model as xm
+from torch_xla._internal import rendezvous
 import logging
 import os
 from torch._C._distributed_c10d import ProcessGroup
@@ -16,6 +17,8 @@ def _register_xla_backend():
 
 _register_xla_backend()
 
+dist.register_rendezvous_handler('xla', rendezvous.pjrt_rendezvous_handler)
+
 
 def _ret_work(ret):
   fut = torch.futures.Future()
Lines changed: 2 additions & 44 deletions
@@ -1,51 +1,9 @@
-import datetime
 import logging
-import threading
 
 import torch.distributed as dist
 from torch_xla.distributed import xla_backend
-from torch_xla import runtime as xr
-from torch_xla._internal import pjrt
+from torch_xla._internal import rendezvous
 from torch_xla._internal import tpu
-import torch_xla.utils.utils as xu
-
-_store = None
-_store_lock = threading.Lock()
-
-
-def _pjrt_rendezvous_handler(url: str,
-                             timeout: datetime.timedelta = ...,
-                             **kwargs):
-  # Assume `xmp.spawn` has not been called when using torchrun
-  if dist.is_torchelastic_launched():
-    local_world_size = xu.getenv_as('LOCAL_WORLD_SIZE', int)
-    local_rank = xu.getenv_as('LOCAL_RANK', int)
-    pjrt.initialize_multiprocess(local_rank, local_world_size)
-
-  master_ip = xu.getenv_as('MASTER_ADDR', str)
-  if not master_ip:
-    master_ip = tpu.discover_master_worker_ip() if xr.device_type(
-    ) == 'TPU' else 'localhost'
-
-  master_port = xu.getenv_as('MASTER_PORT', int, 12355)
-  world_size = xr.world_size()
-  with _store_lock:
-    global _store
-    if not _store:
-      if xu.getenv_as('TORCHELASTIC_USE_AGENT_STORE', str) == 'True':
-        attempt = xu.getenv_as('TORCHELASTIC_RESTART_COUNT', int, defval=0)
-        tcp_store = dist.TCPStore(
-            master_ip, master_port, xr.process_count(), is_master=False)
-        _store = dist.PrefixStore(f"/worker/attempt_{attempt}", tcp_store)
-      else:
-        _store = dist.TCPStore(
-            master_ip,
-            master_port,
-            xr.process_count(),
-            is_master=xr.process_index() == 0)
-
-  yield (_store, xr.global_ordinal(), world_size)
-
 
 if tpu.num_available_chips() > 0 and tpu.version() <= 3:
   from torch.testing._internal.distributed import multi_threaded_pg
@@ -54,4 +12,4 @@ def _pjrt_rendezvous_handler(url: str,
                   'and does not support torchrun.')
   multi_threaded_pg._install_threaded_pg()
 
-dist.register_rendezvous_handler('pjrt', _pjrt_rendezvous_handler)
+dist.register_rendezvous_handler('pjrt', rendezvous.pjrt_rendezvous_handler)
