120 changes: 91 additions & 29 deletions docs/pjrt.md
@@ -1,11 +1,13 @@
# Experimental PjRt Runtime Support

_This document reflects the current state of PjRt support in nightly
builds_. See the [same document on the r1.13 branch](https://github.com/pytorch/xla/blob/r1.13/docs/pjrt.md)
for the status in the latest stable release.

The PyTorch/XLA team is migrating from the currently-supported XRT
runtime to the [PjRt
runtime](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/pjrt)
used by [JAX](https://github.com/google/jax).

PjRt is available as an _experimental preview_ in PyTorch/XLA r1.13. The
PyTorch/XLA team will provide limited support on a best-effort basis during this
@@ -31,7 +33,7 @@ like this:
PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data
```

### TPU

To create a new TPU with PyTorch/XLA r1.13 installed:

@@ -64,6 +66,26 @@ gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
```

#### Docker

You can also use Docker to run your workload in a container with PyTorch/XLA
preinstalled:

```
export DOCKER_IMAGE=gcr.io/...

# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"

# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"
```

Note that `docker run` requires privileged access to the host (`--privileged`)
to expose the TPU device to the container. Docker on TPU pods is only supported
with host networking (`--net=host`) at this time. See the [Cloud TPU documentation](https://cloud.google.com/tpu/docs/run-in-container)
for more information.

### GPU

Coming soon in a future release!
@@ -100,6 +122,47 @@ for more information about TPU architecture.
([`gcloud compute tpus tpu-vm scp`](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)) and run the code on each host in
parallel (e.g. [`gcloud compute tpus tpu-vm ssh --worker=all
--command="PJRT_DEVICE=TPU python run.py"`](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh))
- `xm.rendezvous` has been reimplemented using XLA-native collective
communication to enhance stability on large TPU pods. See below for more
details.

### Changes to `xm.rendezvous`

_New in PyTorch/XLA r1.14 (nightly only)_

In practice, we found that running a single mesh master process was unreliable
on TPU pods with thousands of chips due to the number of inbound connections to
worker 0. A single client process timing out could cause a failure and force the
entire workload to restart.

Thus, we have reimplemented `xm.rendezvous` with native XLA collective
communication, which is much more stable and well-tested on large TPU pods. This
imposes two new constraints compared to the XRT implementation (a usage sketch
follows the list below):

- Because the payload has to become part of the XLA graph, `xm.mark_step` is
called both before and after the data is transferred. Calling `xm.rendezvous`
in the middle of model code may force an unwanted compilation.
- Because XLA does not permit collective operations to run on a subset of
workers, all workers must participate in the `rendezvous`.
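
As a minimal sketch of how these constraints affect usage (the `epoch_sync` tag
and the helper function are illustrative assumptions, and the return value is
assumed to be the list of payloads gathered from all workers, as in the
XRT-era API), call `xm.rendezvous` between epochs rather than inside the traced
model code:

```
import torch_xla.core.xla_model as xm

def sync_epoch(epoch):
  # Call rendezvous between epochs, outside of the per-step model code, so the
  # implicit xm.mark_step() before/after the payload transfer does not force
  # an unwanted compilation mid-step.
  # Every worker must reach this call: all workers participate in the collective.
  payloads = xm.rendezvous('epoch_sync', str(epoch).encode())
  return payloads
```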

If you require the old behavior of `xm.rendezvous` (i.e. communicating data
without altering the XLA graph and/or synchronizing a subset of workers),
consider using [`torch.distributed.barrier`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier)
or [`torch.distributed.all_gather_object`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object)
with a `gloo` process group. If you are also using the `xla` `torch.distributed`
backend, you can use `torch.distributed.new_group` to create a `gloo` subgroup. See [this
example](https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier)
from the PyTorch documentation, and the sketch at the end of this section. Keep
in mind these constraints:

- `torch.distributed` is not fully supported on TPU v2/v3 at this time. Only a
subset of operations with the `xla` backend are tested, and `gloo` will likely
not work as expected in a multiprocessing context.
- In our experiments, `gloo` does not scale well to thousands of chips, so
expect this alternative to be less reliable than using `xm.rendezvous` with
PJRT.

Note: the PyTorch/XLA r1.13 implementation of `xm.rendezvous` uses `gloo` and
has both of the above constraints.
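
As a rough sketch of the `gloo` alternative described above (assuming
`torch.distributed` has already been initialized with the `xla` backend and
that `MASTER_ADDR`/`MASTER_PORT` are set so `gloo` can rendezvous over TCP; the
helper names are hypothetical):

```
import torch.distributed as dist

def make_gloo_group():
  # Create a gloo subgroup spanning all ranks, alongside the default xla group.
  return dist.new_group(ranks=list(range(dist.get_world_size())), backend='gloo')

def sync_and_gather(gloo_group, obj):
  # Synchronize all workers without adding anything to the XLA graph.
  dist.barrier(group=gloo_group)
  # Exchange arbitrary picklable Python objects across workers.
  outputs = [None] * dist.get_world_size()
  dist.all_gather_object(outputs, obj, group=gloo_group)
  return outputs
```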

## TPUs v2/v3 vs v4

@@ -108,36 +171,18 @@ v2/v3, one TPU chip is represented to PyTorch as _two_ devices. It is not
possible to access the same TPU chip from multiple processes, so workloads must
be able to handle two devices per process. The easiest way to handle this is to
spawn two threads per process on TPU v2/v3, which is done automatically by
`xmp.spawn` when using PjRt. With multiple threads per process, multiple
replicas will share global state, causing the following known issues:

- Threads will share the same `torch` random seed used for parameter
initialization. If you relied on each process having the same random seed for
deterministic parameter initialization, you will have to synchronize module
parameters via collective broadcasting instead (e.g.
`pjrt.broadcast_master_param(model)`). See [`test_train_mp_imagenet.py`](../test/test_train_mp_imagenet.py)
for an example.
- `torch.distributed` uses a global process group and does not support
multi-threading, so the `xla` `torch.distributed` backend does not fully
support TPU v2/v3 with PJRT at this time.

## PjRt and DDP

@@ -148,4 +193,21 @@ run the DDP script as usual but with `PJRT_DEVICE=TPU`. Here is a full example:
PJRT_DEVICE=TPU MASTER_ADDR=localhost MASTER_PORT=6000 python xla/test/test_train_mp_mnist.py --ddp --fake_data --num_epochs 1
```

### Experimental PjRt DDP implementation

_New in PyTorch/XLA r1.14 (nightly only)_

Due to `torch.distributed`'s limitations on multithreading,
`torch.nn.parallel.DistributedDataParallel` does not support TPU v2/v3 with
PJRT. Thus, we have provided an alternative implementation of DDP that is
optimized for TPUs and supports TPU v2 and v3 in
[`torch_xla.experimental.pjrt.DistributedDataParallel`](../torch_xla/experimental/pjrt.py).

All of PjRt is in an experimental preview state, but consider this DDP
implementation to be _especially_ unstable. The behavior may change
significantly over time, it may produce incorrect results, or it may be
removed entirely. If you encounter any issues, please report them on GitHub with
the `runtime` and `ddp` tags.

See [`test_train_mp_imagenet.py`](../test/test_train_mp_imagenet.py) for an
example of drop-in usage.
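
As a rough sketch of what drop-in usage could look like (the toy model, random
data, and single training step below are illustrative assumptions, and the
exact interface of the experimental class may change):

```
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.experimental import pjrt

def _mp_fn(index):
  device = xm.xla_device()
  model = nn.Linear(16, 2).to(device)

  # Wrap with the experimental, TPU-optimized DDP instead of
  # torch.nn.parallel.DistributedDataParallel.
  ddp_model = pjrt.DistributedDataParallel(model)
  optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

  # One toy training step with random data.
  data = torch.randn(8, 16, device=device)
  target = torch.randn(8, 2, device=device)
  optimizer.zero_grad()
  loss = nn.functional.mse_loss(ddp_model(data), target)
  loss.backward()
  optimizer.step()  # gradient reduction is assumed to be handled by the wrapper
  xm.mark_step()

if __name__ == '__main__':
  xmp.spawn(_mp_fn)
```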
7 changes: 7 additions & 0 deletions test/test_train_mp_imagenet.py
@@ -1,3 +1,4 @@
from torch_xla.experimental import pjrt
import args_parse

SUPPORTED_MODELS = [
@@ -199,6 +200,12 @@ def train_imagenet():

device = xm.xla_device()
model = get_model_property('model_fn')().to(device)

# Initialization is nondeterministic with multiple threads in PjRt.
# Synchronize model parameters across replicas manually.
if pjrt.using_pjrt():
pjrt.broadcast_master_param(model)

if FLAGS.ddp:
model = DDP(model, gradient_as_bucket_view=True, broadcast_buffers=False)
