Commit cf7995d

Jack cao g/r20 backportdoc (#4752)
* Update API GUIDE to include multi host training and add some colors (#4706)
* Update API GUIDE to include multi host training and add some colors
* address review comments
* Update README (#4734)
* Update README
* update user guide section title
* Add public readme for torchdynamo (#4744)
* Add public readme for torchdynamo
* Update index file
1 parent 8c7394c commit cf7995d

File tree

4 files changed: +152 −18 lines changed


API_GUIDE.md

Lines changed: 40 additions & 5 deletions
@@ -135,11 +135,19 @@ if __name__ == '__main__':
 ```
 
 There are three differences between this multi-device snippet and the previous
-single device snippet:
-
-- `xmp.spawn()` creates the processes that each run an XLA device.
-- `MpDeviceLoader` loads the training data onto each device.
-- `xm.optimizer_step(optimizer)` consolidates the gradients between cores and issues the XLA device step computation.
+single device snippet. Let's go over them one by one.
+
+- `xmp.spawn()`
+  - Creates the processes that each run an XLA device.
+  - Each process can only access the device assigned to it. For example, on a TPU v4-8 there will be 4 processes spawned, each owning one TPU device.
+  - Note that if you print `xm.xla_device()` in each process you will see `xla:0` everywhere. This is because each process can only see one device; it does not mean multi-processing is not functioning. The only exception is the PJRT runtime on TPU v2 and TPU v3, where there are `#devices/2` processes and each process has 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
+- `MpDeviceLoader`
+  - Loads the training data onto each device.
+  - `MpDeviceLoader` can wrap a torch dataloader. It preloads the data to the device and overlaps the data loading with device execution to improve performance.
+  - `MpDeviceLoader` also calls `xm.mark_step` for you after every `batches_per_execution` (default: 1) batches it yields.
+- `xm.optimizer_step(optimizer)`
+  - Consolidates the gradients between devices and issues the XLA device step computation.
+  - It is essentially `all_reduce_gradients` + `optimizer.step()` + `mark_step`, and it returns the reduced loss.
 
 The model definition, optimizer definition and training loop remain the same.
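To make the three APIs in the hunk above concrete, here is a minimal, hedged sketch of a multi-process training script (not part of this commit): the toy linear model, random tensor dataset, and hyperparameters are illustrative assumptions; only the `xmp.spawn` / `MpDeviceLoader` / `xm.optimizer_step` flow follows the documented pattern.

```python
# Illustrative sketch only -- the model, data, and hyperparameters are made up.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()  # each spawned process sees exactly one XLA device
    model = torch.nn.Linear(10, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()

    dataset = torch.utils.data.TensorDataset(
        torch.randn(64, 10), torch.randint(0, 2, (64,)))
    loader = torch.utils.data.DataLoader(dataset, batch_size=8)
    # MpDeviceLoader wraps the torch DataLoader, preloads batches onto the
    # device, and calls xm.mark_step every `batches_per_execution` batches.
    mp_loader = pl.MpDeviceLoader(loader, device)

    for data, target in mp_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(data), target)
        loss.backward()
        # consolidates gradients across devices, then steps the optimizer
        xm.optimizer_step(optimizer)


if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())
```

Launched with `PJRT_DEVICE=TPU python3 <script>.py`, this would start the worker processes described above and keep the replicas in sync through `xm.optimizer_step`.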

@@ -152,6 +160,33 @@ See the
 [full multiprocessing example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py)
 for more on training a network on multiple XLA devices with multi-processing.
 
+### Running on TPU Pods
+Multi-host setups differ from accelerator to accelerator. This doc covers the device-independent parts of multi-host training and uses TPU + PJRT runtime (currently available on the 1.13 and 2.x releases) as an example.
+
+Before you begin, please take a look at our user guide [here](https://cloud.google.com/tpu/docs/run-calculation-pytorch), which explains Google Cloud basics such as how to use the `gcloud` command and how to set up your project. You can also check [here](https://cloud.google.com/tpu/docs/how-to) for all Cloud TPU how-to guides. This doc focuses on the PyTorch/XLA side of the setup.
+
+Let's assume you have the mnist example from the above section in `train_mnist_xla.py`. For single-host multi-device training, you would ssh to the TPU VM and run a command like
+
+```
+PJRT_DEVICE=TPU python3 train_mnist_xla.py
+```
+
+Now, in order to run the same model on a TPU v4-16 (which has 2 hosts, each with 4 TPU devices), you will need to
+- Make sure each host can access the training script and training data. This is usually done by using the `gcloud scp` or `gcloud ssh` command to copy the training script to all hosts.
+- Run the same training command on all hosts at the same time.
+
+```
+gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"
+```
+
+The above `gcloud ssh` command will ssh to all hosts in the TPU VM pod and run the same command at the same time.
+
+> **NOTE:** You need to run the above `gcloud` command outside of the TPU VM.
+
+The model code and training script are the same for multi-process training and multi-host training. PyTorch/XLA and the underlying infrastructure will make sure each device is aware of the global topology and of its local and global ordinal. Cross-device communication will happen across all devices instead of only local devices.
+
+For more details about the PJRT runtime and how to run it on a pod, please refer to this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpu). For more information about PyTorch/XLA on TPU pods and a complete guide to running a resnet50 with fake data on a TPU pod, please refer to this [guide](https://cloud.google.com/tpu/docs/pytorch-pods).
+
 ## XLA Tensor Deep Dive
 
 Using XLA tensors and devices requires changing only a few lines of code. But
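As a hedged illustration of the "local and global ordinal" point in the pod section above (this sketch is not part of the commit), the snippet below just prints each process's view of the topology using the standard `torch_xla.core.xla_model` helpers; the launch setup is assumed to be the same `xmp.spawn` pattern as before.

```python
# Illustrative sketch only: print each process's view of the pod topology.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    # Global ordinal is unique across all hosts; local ordinal is unique per host.
    print(f"device={xm.xla_device()} "
          f"global_ordinal={xm.get_ordinal()} "
          f"local_ordinal={xm.get_local_ordinal()} "
          f"world_size={xm.xrt_world_size()}")


if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())
```

Run with the same `--worker=all` `gcloud` command shown above, every process reports a distinct global ordinal while local ordinals repeat on each host.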

README.md

Lines changed: 20 additions & 12 deletions
@@ -24,20 +24,32 @@ running on Cloud TPUs and learn how to use Cloud TPUs as PyTorch devices:
 
 The rest of this README covers:
 
+* [User Guide & Best Practices](#user-guide--best-practices)
 * [Running PyTorch on Cloud TPUs and GPU](#running-pytorchxla-on-cloud-tpu-and-gpu)
 Google Cloud also runs networks faster than Google Colab.
 * [Available docker images and wheels](#available-docker-images-and-wheels)
-* [API & Best Practices](#api--best-practices)
 * [Performance Profiling and Auto-Metrics Analysis](#performance-profiling-and-auto-metrics-analysis)
 * [Troubleshooting](#troubleshooting)
 * [Providing Feedback](#providing-feedback)
 * [Building and Contributing to PyTorch/XLA](#contributing)
+* [Additional Reads](#additional-reads)
 
 
 
 Additional information on PyTorch/XLA, including a description of its
 semantics and functions, is available at [PyTorch.org](http://pytorch.org/xla/).
 
+## User Guide & Best Practices
+
+Our comprehensive user guides are available at:
+
+[Documentation for the latest release](https://pytorch.org/xla)
+
+[Documentation for master branch](https://pytorch.org/xla/master)
+
+See the [API Guide](API_GUIDE.md) for best practices when writing networks that
+run on XLA devices (TPU, GPU, CPU and ...)
+
 ## Running PyTorch/XLA on Cloud TPU and GPU
 
 * [Running on a single Cloud TPU](#running-on-a-single-cloud-tpu-vm)
@@ -144,17 +156,6 @@ pip3 install torch_xla[tpuvm]
 
 This is only required on Cloud TPU VMs.
 
-## API & Best Practices
-
-In general PyTorch/XLA follows PyTorch APIs, some additional torch_xla specific APIs are available at:
-
-[Documentation for the latest release](https://pytorch.org/xla)
-
-[Documentation for master branch](https://pytorch.org/xla/master)
-
-See the [API Guide](API_GUIDE.md) for best practices when writing networks that
-run on Cloud TPUs and Cloud TPU Pods.
-
 ## Performance Profiling and Auto-Metrics Analysis
 
 With PyTorch/XLA we provide a set of performance profiling tooling and auto-metrics analysis which you can check the following resources:
@@ -181,3 +182,10 @@ See the [contribution guide](CONTRIBUTING.md).
 
 ## Disclaimer
 This repository is jointly operated and maintained by Google, Facebook and a number of individual contributors listed in the [CONTRIBUTORS](https://github.com/pytorch/xla/graphs/contributors) file. For questions directed at Facebook, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open up an issue in this repository [here](https://github.com/pytorch/xla/issues).
+
+## Additional Reads
+You can find additional useful reading materials in
+* [Performance debugging on Cloud TPU VM](https://cloud.google.com/blog/topics/developers-practitioners/pytorchxla-performance-debugging-tpu-vm-part-1)
+* [Lazy tensor intro](https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/)
+* [Scaling deep learning workloads with PyTorch / XLA and Cloud TPU VM](https://cloud.google.com/blog/topics/developers-practitioners/scaling-deep-learning-workloads-pytorch-xla-and-cloud-tpu-vm)
+* [Scaling PyTorch models on Cloud TPUs with FSDP](https://pytorch.org/blog/scaling-pytorch-models-on-cloud-tpus-with-fsdp/)

docs/dynamo.md

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+## TorchDynamo (torch.compile) integration in PyTorch XLA
+
+TorchDynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. It provides a clean API for compiler backends to hook into, and its biggest feature is to dynamically modify Python bytecode right before it is executed. In the pytorch/xla 2.0 release, PyTorch/XLA provides an experimental backend for TorchDynamo for both inference and training.
+
+The way the XLA bridge works is that Dynamo provides a TorchFX graph when it recognizes a model pattern, and PyTorch/XLA uses the existing Lazy Tensor technology to compile the FX graph and return the compiled function.
+
+### Inference
+Here is a small code example of running resnet18 with `torch.compile`
+
+```python
+import torch
+import torchvision
+import torch_xla.core.xla_model as xm
+
+def eval_model(loader):
+  device = xm.xla_device()
+  xla_resnet18 = torchvision.models.resnet18().to(device)
+  xla_resnet18.eval()
+  dynamo_resnet18 = torch.compile(
+      xla_resnet18, backend='torchxla_trace_once')
+  for data, _ in loader:
+    output = dynamo_resnet18(data)
+```
+> **NOTE:** The inference backend name `torchxla_trace_once` is subject to change.
+
+With `torch.compile`, PyTorch/XLA only traces the resnet18 model once at init time and executes the compiled binary every time `dynamo_resnet18` is invoked, instead of tracing the model on every iteration. Note that Dynamo currently does not support fallback, so if there is an op that cannot be traced by XLA it will error out. We will fix this issue in the upcoming 2.1 release. Here is an inference speed analysis comparing Dynamo and Lazy using torch bench on a Cloud TPU v4-8:
+
+| model | Speed up |
+| --- | ----------- |
+| resnet18 | 1.768 |
+| resnet50 | 1.61 |
+| resnext50_32x4d | 1.328 |
+| alexnet | 1.261 |
+| mobilenet_v2 | 2.017 |
+| mnasnet1_0 | 1.686 |
+| vgg16 | 1.155 |
+| BERT_pytorch | 3.502 |
+| squeezenet1_1 | 1.674 |
+| timm_vision_transformer | 3.138 |
+| average | 1.9139 |
+
+### Training
+PyTorch/XLA also supports Dynamo for training, but it is very experimental and we are working with the PyTorch Compiler team to iterate on the implementation. In the 2.0 release it only supports the forward and backward pass, not the optimizer step. Here is an example of training a resnet18 with `torch.compile`
+
+```python
+import torch
+import torchvision
+import torch_xla.core.xla_model as xm
+
+def train_model(model, data, target):
+  loss_fn = torch.nn.CrossEntropyLoss()
+  pred = model(data)
+  loss = loss_fn(pred, target)
+  loss.backward()
+  return pred
+
+def train_model_main(loader):
+  device = xm.xla_device()
+  xla_resnet18 = torchvision.models.resnet18().to(device)
+  xla_resnet18.train()
+  dynamo_train_model = torch.compile(
+      train_model, backend='aot_torchxla_trace_once')
+  for data, target in loader:
+    output = dynamo_train_model(xla_resnet18, data, target)
+```
+
+> **NOTE:** The backend we use here is `aot_torchxla_trace_once` (subject to change), not `torchxla_trace_once`.
+
+We expect to extract and execute 3 graphs per training step, instead of the 1 graph per training step you get with Lazy tensors. Here is a training speed analysis comparing Dynamo and Lazy using torch bench on a Cloud TPU v4-8:
+
+| model | Speed up |
+| --- | ----------- |
+| resnet50 | 0.937 |
+| resnet18 | 1.003 |
+| BERT_pytorch | 1.869 |
+| resnext50_32x4d | 1.139 |
+| alexnet | 0.802 |
+| mobilenet_v2 | 0.672 |
+| mnasnet1_0 | 0.967 |
+| vgg16 | 0.742 |
+| timm_vision_transformer | 1.69 |
+| squeezenet1_1 | 0.958 |
+| average | 1.0779 |
+
+> **NOTE:** We run each model's fwd and bwd for a single step and then collect the e2e time. In the real world we run many steps per training job, which can easily hide the tracing cost behind the execution (since it is async). Lazy Tensor will have much better performance in that scenario.
+
+We are currently working on the optimizer support; it will be available in the nightly builds soon but won't be in the 2.0 release.
+
+### Take away
+TorchDynamo provides a really promising way for a compiler backend to hide the complexity from the user and easily retrieve the modeling code in a graph format. Compared with PyTorch/XLA's traditional Lazy Tensor way of extracting the graph, TorchDynamo can skip the graph tracing for every iteration, hence providing a much better inference response time. However, TorchDynamo does not trace the communication ops (like `all_reduce` and `all_gather`) yet, and it provides separate graphs for the forward and the backward pass, which hurts XLA performance. These feature gaps compared to Lazy Tensor make it less efficient in real-world training use cases, especially since the tracing cost can be overlapped with the execution in training. The PyTorch/XLA team will keep investing in TorchDynamo and work with upstream to mature the training story.

docs/source/index.rst

Lines changed: 2 additions & 1 deletion
@@ -98,6 +98,7 @@ test
 
 .. mdinclude:: ../../TROUBLESHOOTING.md
 .. mdinclude:: ../pjrt.md
-.. mdinclude:: ../ddp.md
+.. mdinclude:: ../dynamo.md
 .. mdinclude:: ../fsdp.md
+.. mdinclude:: ../ddp.md
 .. mdinclude:: ../gpu.md
