-# PJRT Runtime (Beta)
+# PJRT Runtime
 
-_This document reflects the current state of PJRT support in current nightly
-builds_. See the [same document on the r2.0 branch](https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md)
-for the status in the latest stable release.
-
-The PyTorch/XLA team is currently migrating from the currently-supported XRT
-runtime to the [PJRT
+PyTorch/XLA has migrated from the TensorFlow-based XRT runtime to the [PJRT
 runtime](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/pjrt)
 used by [JAX](https://github.com/google/jax).
 
-PJRT is available for preview in PyTorch/XLA 2.0. **We are planning to make
-PJRT our officially supported runtime**, so we encourage all users to experiment
-with it. We aim to make PJRT stable in release 2.1, so if you encounter a bug
-with PJRT, please file an issue on GitHub with the `runtime` tag.
+If you encounter a bug with PJRT, please file an issue on GitHub with the
+`runtime` tag.
+
+_New features in PyTorch/XLA r2.1_:
+
+* PJRT is stable in PyTorch/XLA r2.1!
+* Public runtime APIs have moved from `torch_xla.experimental.pjrt` to
+  `torch_xla.runtime`.
+* The `pjrt://` init method has been renamed to `xla://`, and it is registered
+  by `torch_xla.distributed.xla_backend`.
+* The previous `torch_xla.experimental.*` names are still available in this
+  release for compatibility.
+* `torchrun` is now supported when using `init_method='xla://'`.
+* New plugins for XPU and Neuron via the PJRT C API.
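+
+For example, helpers that previously lived under `torch_xla.experimental.pjrt`
+are now imported from `torch_xla.runtime`. A minimal sketch (the helpers shown
+here are illustrative, not an exhaustive list; check the API reference for the
+full surface):
+
+```python
+import torch_xla.runtime as xr
+
+xr.device_type()     # e.g. 'TPU', 'GPU', or 'CPU', depending on PJRT_DEVICE
+xr.world_size()      # total number of participating devices
+xr.global_ordinal()  # this replica's global index
+```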
 
 _New features in PyTorch/XLA r2.0_:
 
@@ -29,7 +34,7 @@ _New features in PyTorch/XLA r2.0_:
 ## TL;DR
 
 * To use the PJRT preview runtime, set the `PJRT_DEVICE` environment variable to
-  `CPU`, `TPU, or `GPU`
+  `CPU`, `TPU`, or `GPU`
 * In XRT, all distributed workloads are multiprocess, with one process per
   device. On TPU v2 and v3 in PJRT, workloads are multiprocess and multithreaded
   (4 processes with 2 threads each), so your workload should be thread-safe. See
@@ -45,7 +50,7 @@ _New features in PyTorch/XLA r2.0_:
     The global `torch` RNG is _not_ thread-safe, even if you set the same
     `torch.manual_seed` across replicas.
   * To use `torch.distributed`, import `torch_xla.experimental.pjrt_backend` and
-    use the `pjrt://` `init_method`.
+    use the `xla://` `init_method`.
   * These steps are optional for GPU and TPU v4.
 
 Sample diff from XRT to PJRT:
@@ -62,20 +67,19 @@ Sample diff from XRT to PJRT:
  import torch_xla.distributed.parallel_loader as pl
  import torch_xla.distributed.xla_backend
  import torch_xla.distributed.xla_multiprocessing as xmp
-+import torch_xla.experimental.pjrt_backend
-+import torch_xla.experimental.pjrt as pjrt
++import torch_xla.runtime as xr
 
 
  def _mp_fn(index):
    device = xm.xla_device()
 -  dist.init_process_group('xla', rank=xm.get_ordinal(), world_size=xm.xrt_world_size())
-+  dist.init_process_group('xla', init_method='pjrt://')
++  dist.init_process_group('xla', init_method='xla://')
 
    torch.manual_seed(42)
    model = nn.Linear(128, 10).to(device)
 
 +  # Optional for TPU v4 and GPU
-+  pjrt.broadcast_master_param(model)
++  xm.broadcast_master_param(model)
    model = DDP(model, gradient_as_bucket_view=True)
 
    loss_fn = nn.MSELoss()
@@ -286,15 +290,15 @@ without altering the XLA graph and/or synchronizing a subset of workers),
 consider using
 [`torch.distributed.barrier`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier)
 or
-`[torch.distributed.all_gather_object](https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object)`
+[`torch.distributed.all_gather_object`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object)
 with a `gloo` process group. If you are also using the `xla` `torch.distributed`
 backend, you can use `torch.distributed.new_group` to create a `gloo` subgroup. See [this
 example](https://pytorch.org/docs/stable/distributed.html#monitored-barrier)
 from the PyTorch documentation. Keep in mind these constraints:
 
 * `torch.distributed` is not fully supported on TPU v2/v3. Only a subset of
   operations with the `xla` backend are implemented, and `gloo` will likely not
-  work as expected in a multiprocessing context.
+  work as expected in a multithreaded context.
 * In our experiments, `gloo` does not scale well to thousands of TPU chips, so
   expect this alternative to be less reliable than using `xm.rendezvous` with
   PJRT at large scales.
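+
+As a rough sketch of the `gloo` subgroup approach described above (group
+membership and the timeout value here are illustrative, not prescriptive):
+
+```python
+import datetime
+
+import torch.distributed as dist
+
+# Assumes dist.init_process_group('xla', init_method='xla://') has already run.
+# Create a gloo subgroup spanning all ranks for CPU-side coordination.
+gloo_group = dist.new_group(ranks=list(range(dist.get_world_size())),
+                            backend='gloo')
+
+# Block until every worker reaches this point, without touching the XLA graph.
+dist.monitored_barrier(group=gloo_group,
+                       timeout=datetime.timedelta(minutes=1))
+
+# Exchange small Python objects between workers over gloo.
+outputs = [None] * dist.get_world_size()
+dist.all_gather_object(outputs, {'rank': dist.get_rank()}, group=gloo_group)
+```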
@@ -305,7 +309,7 @@ _New in PyTorch/XLA r2.0_
 
 When using PJRT with `torch.distributed` and
 [`torch.nn.parallel.DistributedDataParallel`](https://github.com/pytorch/xla/blob/master/docs/ddp.md)
-we strongly recommend using the new `pjrt://` `init_method`, which automatically
+we strongly recommend using the new `xla://` `init_method`, which automatically
 finds the replica IDs, world size, and master IP by querying the runtime. For
 example:
 
@@ -316,12 +320,12 @@ import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_multiprocessing as xmp
 from torch_xla.experimental import pjrt
 
-# Required for `pjrt://` init_method
-import torch_xla.experimental.pjrt_backend
+# Required for `xla://` init_method and `xla` backend
+import torch_xla.distributed.xla_backend
 
 def _all_gather(index: int):
   # No need to pass in `rank` or `world_size`
-  dist.init_process_group('xla', init_method='pjrt://')
+  dist.init_process_group('xla', init_method='xla://')
 
   t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
   output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
@@ -334,10 +338,14 @@ if __name__ == '__main__':
   xmp.spawn(_all_gather)
 ```
 
-Note: Although the `pjrt://` init_method is not required on TPU v4, it is still
+Note: Although the `xla://` init_method is not required on TPU v4, it is still
 recommended. If you use `env://`, `MASTER_ADDR` must be set to the IP of the host that has
-device 0, which is _not_ always worker 0. The `pjrt://` init_method finds this
-IP automatically and supports TPU v2/v3.
+device 0, which is _not_ always worker 0. The `xla://` init_method finds this
+IP automatically.
+
+Note: For TPU v2/v3, you still need to import
+`torch_xla.experimental.pjrt_backend`, as TPU v2/v3 support in
+`torch.distributed` is still experimental.
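+
+For example, when launching with `torchrun`, the extra import for TPU v2/v3
+sits alongside the backend import (a sketch based on the example above; only
+the `pjrt_backend` line is specific to TPU v2/v3):
+
+```python
+import torch.distributed as dist
+import torch_xla.distributed.xla_backend
+import torch_xla.experimental.pjrt_backend  # needed on TPU v2/v3 only
+
+dist.init_process_group('xla', init_method='xla://')
+```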
 
 For more information about using `DistributedDataParallel` on PyTorch/XLA, see
 [`ddp.md`](./ddp.md) on TPU V4. For an example that uses DDP and PJRT together,