if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
```

There are three differences between this multi-device snippet and the previous
single device snippet. Let's go over them one by one (a condensed sketch tying the three together follows the list).

- `xmp.spawn()`
  - Creates the processes that each run an XLA device.
  - Each process is only able to access the device assigned to it. For example, on a TPU v4-8 four processes are spawned, and each process owns one TPU device.
  - Note that if you print `xm.xla_device()` in each process, you will see `xla:0` in all of them. This is because each process can only see one device; it does not mean multi-processing is not functioning. The only exception is the PJRT runtime on TPU v2 and TPU v3, where there are `#devices/2` processes and each process has 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
- `MpDeviceLoader`
  - Loads the training data onto each device.
  - `MpDeviceLoader` can wrap a torch dataloader. It preloads the data onto the device and overlaps the data loading with device execution to improve performance.
  - `MpDeviceLoader` also calls `xm.mark_step` for you after every `batches_per_execution` (default: 1) batches yielded.
- `xm.optimizer_step(optimizer)`
  - Consolidates the gradients between devices and issues the XLA device step computation.
  - It is essentially `all_reduce_gradients` + `optimizer.step()` + `mark_step`, and it returns the reduced loss.
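
Putting the three pieces together, a condensed sketch of the pattern might look like the following (the linear model, fake dataset, and hyperparameters here are placeholders for illustration, not part of the original example):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  device = xm.xla_device()  # the single device this process owns
  model = torch.nn.Linear(10, 2).to(device)  # placeholder model
  optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
  loss_fn = torch.nn.CrossEntropyLoss()

  # Placeholder dataset; a real job would build its own dataloader here.
  dataset = torch.utils.data.TensorDataset(
      torch.randn(640, 10), torch.randint(0, 2, (640,)))
  loader = torch.utils.data.DataLoader(dataset, batch_size=32)
  # Preloads batches onto the device and calls xm.mark_step for you.
  mp_loader = pl.MpDeviceLoader(loader, device)

  for data, target in mp_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    # all-reduce of the gradients across devices + optimizer.step().
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
```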

The model definition, optimizer definition and training loop remain the same.

See the
[full multiprocessing example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py)
for more on training a network on multiple XLA devices with multi-processing.

### Running on TPU Pods
Multi-host setup differs a lot between accelerators. This doc will cover the device-independent bits of multi-host training and will use TPU + the PJRT runtime (currently available on the 1.13 and 2.x releases) as an example.

Before you begin, please take a look at our user guide [here](https://cloud.google.com/tpu/docs/run-calculation-pytorch), which explains some Google Cloud basics such as how to use the `gcloud` command and how to set up your project. You can also check [here](https://cloud.google.com/tpu/docs/how-to) for all Cloud TPU how-to guides. This doc will focus on the PyTorch/XLA perspective of the setup.

Let's assume you have the MNIST example from the above section in a `train_mnist_xla.py`. For a single-host, multi-device training run, you would ssh into the TPU VM and run a command like

```
PJRT_DEVICE=TPU python3 train_mnist_xla.py
```

Now, in order to run the same model on a TPU v4-16 (which has 2 hosts, each with 4 TPU devices), you will need to:
- Make sure each host can access the training script and training data. This is usually done by using the `gcloud scp` command or the `gcloud ssh` command to copy the training scripts to all hosts (see the sketch of a copy command right after this list).
- Run the same training command on all hosts at the same time.
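
For the first step, one way to copy the script to every host is a command along these lines (a sketch, assuming the same `$USER-pjrt`, `$ZONE`, and `$PROJECT` values as the ssh command below):

```
gcloud alpha compute tpus tpu-vm scp train_mnist_xla.py $USER-pjrt: --zone=$ZONE --project=$PROJECT --worker=all
```

For the second step: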

```
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"
```

The above `gcloud ssh` command will ssh into all hosts in the TPU VM pod and run the same command at the same time.

> **NOTE:** You need to run the above `gcloud` command outside of the TPU VM.

The model code and training script are the same for multi-process training and multi-host training. PyTorch/XLA and the underlying infrastructure will make sure each device is aware of the global topology and of each device's local and global ordinal. Cross-device communication will happen across all devices instead of just the local devices.
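
For example, each process can inspect its place in the global topology with the standard `torch_xla.core.xla_model` helpers (a small sketch; the helpers are the public API, but the script itself is just for illustration):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  # The global ordinal is unique across all hosts in the pod; the local
  # ordinal is unique only within one host. World size counts every device.
  print(f'device={xm.xla_device()} '
        f'global_ordinal={xm.get_ordinal()} '
        f'local_ordinal={xm.get_local_ordinal()} '
        f'world_size={xm.xrt_world_size()}')

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
```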

For more details regarding the PJRT runtime and how to run it on a pod, please refer to this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpu). For more information about PyTorch/XLA on TPU pods and a complete guide to running a ResNet-50 with fake data on a TPU pod, please refer to this [guide](https://cloud.google.com/tpu/docs/pytorch-pods).

## XLA Tensor Deep Dive

Using XLA tensors and devices requires changing only a few lines of code. But