
Commit 7073e28

updates

committed
1 parent 3042c20 commit 7073e28

File tree

1 file changed (+42, -8 lines changed)


docs/amp.md

Lines changed: 42 additions & 8 deletions
@@ -1,11 +1,11 @@
# AMP (Automatic Mixed Precision) with Pytorch/XLA

Pytorch/XLA's AMP extends [Pytorch's AMP package](https://pytorch.org/docs/stable/amp.html) with support for automatic mixed precision on XLA:GPU and XLA:TPU devices.

AMP accelerates training and inference by executing certain operations in `float32` and other operations in a lower-precision datatype (`float16` or `bfloat16`, depending on hardware support).

This document describes how to use AMP on XLA devices, along with best practices.

## AMP for XLA:TPU
AMP on TPUs automatically casts operations to run in either `float32` or `bfloat16`, because TPUs natively support [bfloat16](https://cloud.google.com/tpu/docs/bfloat16). A simple TPU AMP example is shown below:

```
# Creates model and optimizer in default precision
@@ -25,14 +25,40 @@ for input, target in data:
loss.backward()
xm.optimizer_step(optimizer)
```
`autocast(xm.xla_device())` aliases `torch.amp.autocast('xla')` when the XLA device is a TPU. Alternatively, if a script is only used with TPUs, then `torch.amp.autocast('xla')` can be used directly.
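
For a concrete picture of the direct form, here is a minimal, self-contained sketch of a single training step on a TPU using `torch.amp.autocast('xla')`. The toy model, optimizer, loss, and batch are illustrative placeholders rather than the documented example, and running it assumes `torch_xla` is installed with an XLA:TPU device available:

```
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()                       # the XLA (TPU) device
model = torch.nn.Linear(128, 10).to(device)    # toy model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

# Toy batch; a real script would iterate over a data loader.
input = torch.randn(8, 128, device=device)
target = torch.randint(0, 10, (8,), device=device)

optimizer.zero_grad()
# The forward pass and loss computation run under autocast: eligible ops
# (e.g. the linear layer's matmul) execute in bfloat16, the rest in float32.
with torch.amp.autocast('xla'):
    output = model(input)
    loss = loss_fn(output, target)
# Backward and the optimizer step run outside the autocast region.
loss.backward()
xm.optimizer_step(optimizer)  # reduces gradients across replicas and calls optimizer.step()
```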
Please file an issue or submit a pull request if there is an operator that should be autocasted but is not included.

### Best Practices
1. `autocast` should wrap only the forward pass(es) and loss computation(s) of the network. Backward ops run in the same type that autocast used for the corresponding forward ops.
2. Do not set the `XLA_USE_BF16` flag when using AMP on TPUs. This will override the per-operator precision settings provided by AMP and cause all operators to execute in bfloat16.
3. Since TPUs use bfloat16 mixed precision, gradient scaling is not necessary.
4. Pytorch/XLA provides modified versions of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) that avoid the additional synchronization between device and host (see the sketch after this list).
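
As noted in item 4, the sync-free optimizers are intended as drop-in replacements. The sketch below assumes the linked `syncfree` package exposes an `Adam` variant with the usual `torch.optim.Adam` constructor signature; the model and batch are again illustrative placeholders:

```
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import syncfree   # assumed import path for the linked package

device = xm.xla_device()
model = torch.nn.Linear(128, 10).to(device)

# Assumption: syncfree.Adam mirrors torch.optim.Adam and avoids the extra
# device/host synchronization during the optimizer step.
optimizer = syncfree.Adam(model.parameters(), lr=1e-3)

input = torch.randn(8, 128, device=device)
target = torch.randint(0, 10, (8,), device=device)

optimizer.zero_grad()
with torch.amp.autocast('xla'):
    loss = torch.nn.functional.cross_entropy(model(input), target)
loss.backward()
xm.optimizer_step(optimizer)
```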
### Supported Operators
AMP on TPUs operates like Pytorch's AMP. The rules for how autocasting is applied are summarized below:
Only out-of-place ops and Tensor methods are eligible to be autocasted. In-place variants and calls that explicitly supply an out=... Tensor are allowed in autocast-enabled regions, but won’t go through autocasting. For example, in an autocast-enabled region a.addmm(b, c) can autocast, but a.addmm_(b, c) and a.addmm(b, c, out=d) cannot. For best performance and stability, prefer out-of-place ops in autocast-enabled regions.
Ops that run in float64 or in non-floating-point dtypes are not eligible and will run in those types whether or not autocast is enabled. Additionally, ops called with an explicit dtype=... argument are not eligible and will produce output that respects the dtype argument.
Ops not listed below do not go through autocasting. They run in the type defined by their inputs. Autocasting may still change the type in which unlisted ops run if they’re downstream from autocasted ops.
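
To make the eligibility rules concrete, the following sketch (assuming an XLA:TPU device is available) compares an out-of-place `addmm`, which autocasts to `bfloat16`, with its in-place and `out=` forms, which keep the input dtype:

```
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(4, 4, device=device)   # float32 inputs
b = torch.randn(4, 4, device=device)
c = torch.randn(4, 4, device=device)

with torch.amp.autocast('xla'):
    out = a.addmm(b, c)                # out-of-place, listed op: expected to run in bfloat16
    print(out.dtype)                   # torch.bfloat16

    a.addmm_(b, c)                     # in-place variant: not autocast, stays float32
    print(a.dtype)                     # torch.float32

    d = torch.empty(4, 4, device=device)
    torch.addmm(a, b, c, out=d)        # explicit out= tensor: not autocast
    print(d.dtype)                     # torch.float32
```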
**Ops that autocast to `bfloat16`:**
`__matmul__`, `addbmm`, `addmm`, `addmv`, `addr`, `baddbmm`, `bmm`, `conv1d`, `conv2d`, `conv3d`, `conv_transpose1d`, `conv_transpose2d`, `conv_transpose3d`, `linear`, `matmul`, `mm`, `relu`, `prelu`, `max_pool2d`
**Ops that autocast to `float32`:**
`batch_norm`, `log_softmax`, `binary_cross_entropy`, `binary_cross_entropy_with_logits`, `prod`, `cdist`, `trace`, `cholesky`, `inverse`, `reflection_pad`, `replication_pad`, `mse_loss`, `cosine_embedding_loss`, `nll_loss`, `multilabel_margin_loss`, `qr`, `svd`, `triangular_solve`, `linalg_svd`, `linalg_inv_ex`
**Ops that autocast to widest input type:**
`stack`, `cat`, `index_copy`
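
For example, under autocast `cat` returns the widest floating-point dtype among its inputs, so mixing a `bfloat16` result with a `float32` tensor yields `float32`. A short sketch, again assuming an XLA:TPU device:

```
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(4, 4, device=device)   # float32
y = torch.randn(4, 4, device=device)   # float32

with torch.amp.autocast('xla'):
    z = torch.mm(x, y)                 # listed above: runs in bfloat16
    mixed = torch.cat([z, x])          # inputs are bfloat16 and float32
    print(mixed.dtype)                 # torch.float32 -- the widest input type
```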

## AMP for XLA:GPU
AMP on XLA:GPU devices reuses Pytorch's AMP rules. See [Pytorch's AMP documentation](https://pytorch.org/docs/stable/amp.html) for CUDA-specific behavior. A simple CUDA AMP example is shown below:

```
# Creates model and optimizer in default precision
@@ -57,6 +83,14 @@ for input, target in data:
scaler.update()
```

`autocast(xm.xla_device())` aliases `torch.cuda.amp.autocast()` when the XLA device is a CUDA device. Alternatively, if a script is only used with CUDA devices, then `torch.cuda.amp.autocast` can be used directly.
### Best Practices
1. `autocast` should wrap only the forward pass(es) and loss computation(s) of the network. Backward ops run in the same type that autocast used for the corresponding forward ops.
2. Do not set the `XLA_USE_F16` flag when using AMP on CUDA devices. This will override the per-operator precision settings provided by AMP and cause all operators to execute in float16.
3. Use gradient scaling to prevent float16 gradients from underflowing (see the sketch after this list).
4. Pytorch/XLA provides modified versions of [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree) that avoid the additional synchronization between device and host.
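
To put item 3 into practice, a single XLA:GPU training step might look like the sketch below. It assumes `torch_xla.amp` exports a `GradScaler` with the usual scale/step/update interface (the `scaler` object that appears in the truncated example above) alongside `autocast`; the model and batch are placeholders, and a multi-replica script would additionally reduce gradients across replicas before stepping:

```
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast, GradScaler   # assumed exports; see lead-in

device = xm.xla_device()                       # an XLA:GPU device in this setting
model = torch.nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()                          # scales the loss to avoid float16 underflow

input = torch.randn(8, 128, device=device)
target = torch.randint(0, 10, (8,), device=device)

optimizer.zero_grad()
with autocast(device):                         # float16 autocasting on XLA:GPU
    loss = torch.nn.functional.cross_entropy(model(input), target)
scaler.scale(loss).backward()                  # backward runs on the scaled loss
scaler.step(optimizer)                         # unscales gradients, skips the step on inf/nan
scaler.update()                                # adjusts the scale factor for the next iteration
```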
## Examples
Our [mnist training script](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_amp.py) and [imagenet training script](https://github.com/pytorch/xla/blob/master/test/test_train_mp_imagenet_amp.py) demonstrate how AMP is used on both TPUs and GPUs.
