20 changes: 17 additions & 3 deletions docs/fsdp.md
@@ -22,10 +22,24 @@ Notes:
* The ZeRO-3 optimizer should be implemented via nested FSDP with `reshard_after_forward=True`. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` and `test/test_train_mp_imagenet_fsdp.py` for an example.
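
For example, a minimal nested-wrapping sketch (using plain `nn.Linear` layers purely for illustration) looks like the following:
```python3
import torch.nn as nn
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# Wrap the inner submodules first (resharding their parameters after forward),
# then wrap the outer model: this gives ZeRO-3 style sharding.
inner = [FSDP(nn.Linear(1024, 1024), reshard_after_forward=True) for _ in range(4)]
model = FSDP(nn.Sequential(*inner), reshard_after_forward=True)
```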
* For large models that cannot fit into a single TPU memory or the host CPU memory, one should interleave submodule construction with inner FSDP wrapping. See [`FSDPViTModel`](https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py) for an example.
* A simple wrapper `checkpoint_module` is provided (based on `torch_xla.utils.checkpoint.checkpoint` from https://github.com/pytorch/xla/pull/3524) to perform [gradient checkpointing](https://spell.ml/blog/gradient-checkpointing-pytorch-YGypLBAAACEAefHs) over a given `nn.Module` instance. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` and `test/test_train_mp_imagenet_fsdp.py` for examples.
* Auto-wrapping submodules: instead of manually nesting FSDP wrapping, one can specify an `auto_wrap_policy` argument to automatically wrap submodules with inner FSDP. `size_based_auto_wrap_policy` in `torch_xla.distributed.fsdp.wrap` is an example of an `auto_wrap_policy` callable; this policy wraps layers with more than 100M parameters. `transformer_auto_wrap_policy` in `torch_xla.distributed.fsdp.wrap` is an example of an `auto_wrap_policy` callable for transformer-like model architectures.

For example, to automatically wrap all `torch.nn.Conv2d` submodules with inner FSDP, one can use:
```python3
import torch
from functools import partial
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

auto_wrap_policy = partial(
    transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})
```
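
Similarly, to auto-wrap submodules by parameter count, one can use `size_based_auto_wrap_policy` (a sketch; the `int(1e8)` threshold is only an illustrative value):
```python3
from functools import partial
from torch_xla.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap any submodule with more than min_num_params parameters in its own inner FSDP
auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=int(1e8))
```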

Additionally, one can specify an `auto_wrapper_callable` argument to use a custom callable wrapper for the submodules (the default wrapper is just the `XlaFullyShardedDataParallel` class itself). For example, one can use the following to apply gradient checkpointing (i.e. activation checkpointing/rematerialization) to each auto-wrapped submodule.
```python3
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel, checkpoint_module

auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
    checkpoint_module(m), *args, **kwargs)
```
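
Both arguments are then passed to the outermost FSDP wrapper, e.g. (a sketch, where `model` is the module being wrapped):
```python3
model = XlaFullyShardedDataParallel(
    model,
    auto_wrap_policy=auto_wrap_policy,
    auto_wrapper_callable=auto_wrapper_callable)
```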
* When stepping the optimizer, directly call `optimizer.step` and do not call `xm.optimizer_step`. The latter reduces the gradient across ranks, which is not needed for FSDP (where the parameters are already sharded).
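
For example, a training step might look like the following (a sketch; `loader`, `loss_fn`, and `optimizer` are placeholders for your own data loader, loss, and optimizer):
```python3
import torch_xla.core.xla_model as xm

for data, target in loader:
  optimizer.zero_grad()
  loss = loss_fn(model(data), target)
  loss.backward()
  optimizer.step()  # plain optimizer.step(); no xm.optimizer_step and no extra gradient all-reduce
  xm.mark_step()  # may be omitted if a parallel loader already marks the step
```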
* When saving model and optimizer checkpoints during training, each training process needs to save its own checkpoint of the (sharded) model and optimizer state dicts (use `master_only=False` and set different paths for each rank in `xm.save`). When resuming, it needs to load the checkpoint for the corresponding rank.
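
For example (a sketch, assuming `ckpt` is the checkpoint dict shown below; the path layout is only an illustration):
```python3
import torch_xla.core.xla_model as xm

rank, world_size = xm.get_ordinal(), xm.xrt_world_size()
ckpt_path = f'/tmp/fsdp_ckpt_rank-{rank:08d}-of-{world_size:08d}.pth'
xm.save(ckpt, ckpt_path, master_only=False)  # every rank saves its own shard
```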
* Please also save `model.get_shard_metadata()` along with `model.state_dict()` as follows and use `consolidate_sharded_model_checkpoints` to stitch the sharded model checkpoints together into a full model state dict. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` for an example.
```python3
ckpt = {
'model': model.state_dict(),
'shard_metadata': model.get_shard_metadata(),
@@ -86,12 +100,12 @@ python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
--lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \
--use_nested_fsdp
```
You can also add `--use_gradient_checkpointing` (which needs to be used along with `--use_nested_fsdp` or `--auto_wrap_policy`) to apply gradient checkpointing on the residual blocks.

---

### Example training scripts on TPU pod (with 10 billion parameters)

To train large models that cannot fit into a single TPU, one should auto-wrap or manually wrap the submodules with inner FSDP when building the entire model, in order to implement the ZeRO-3 algorithm.

Please see https://github.com/ronghanghu/vit_10b_fsdp_example for an example of sharded training of a Vision Transformer (ViT) model using this XLA FSDP PR.
69 changes: 57 additions & 12 deletions test/test_train_mp_imagenet_fsdp.py
@@ -1,4 +1,5 @@
import args_parse
from functools import partial

SUPPORTED_MODELS = [
'alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
@@ -38,6 +39,14 @@
'--flatten_parameters': {
'action': 'store_true',
},
'--auto_wrap_policy': {
'choices': ['none', 'size_based', 'type_based'],
'default': 'none',
},
'--auto_wrap_min_num_params': {
'type': int,
'default': 1e6,
},
'--use_nested_fsdp': {
'action': 'store_true',
},
@@ -54,8 +63,9 @@
'--shard_param_on_dim_0': {
'action': 'store_true',
},
    '--no_pin_layout_in_collective_ops': {
        'action': 'store_false',
        'dest': 'pin_layout_in_collective_ops',
},
}

@@ -89,6 +99,8 @@
import torch_xla.test.test_utils as test_utils

from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module
from torch_xla.distributed.fsdp.wrap import (size_based_auto_wrap_policy,
transformer_auto_wrap_policy)

DEFAULT_KWARGS = dict(
batch_size=128,
@@ -215,25 +227,58 @@ def train_imagenet():

device = xm.xla_device()
model = get_model_property('model_fn')()
  # Automatically wrap sub-modules with inner FSDP
auto_wrap_policy = None
auto_wrapper_callable = None
if FLAGS.auto_wrap_policy != "none":
if FLAGS.auto_wrap_policy == "size_based":
      # auto-wrap all sub-modules with more than min_num_params parameters (default 1e6)
auto_wrap_policy = partial(
size_based_auto_wrap_policy,
min_num_params=FLAGS.auto_wrap_min_num_params)
elif FLAGS.auto_wrap_policy == "type_based":
# auto-wrap all sub-modules in torchvision ResNet's BasicBlock or Bottleneck
# or torchvision transformer's EncoderBlock as an example
# (transformer_auto_wrap_policy wraps all sub-modules in transformer_layer_cls)
auto_wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
torchvision.models.resnet.BasicBlock,
torchvision.models.resnet.Bottleneck,
torchvision.models.vision_transformer.EncoderBlock,
})
else:
raise Exception(f"Invalid auto-wrap policy: {FLAGS.auto_wrap_policy}")
if FLAGS.use_gradient_checkpointing:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
auto_wrapper_callable = lambda m, *args, **kwargs: FSDP(
checkpoint_module(m), *args, **kwargs)

fsdp_wrap = lambda m: FSDP(
m,
compute_dtype=getattr(torch, FLAGS.compute_dtype),
fp32_reduce_scatter=FLAGS.fp32_reduce_scatter,
flatten_parameters=FLAGS.flatten_parameters,
shard_param_on_dim_0=FLAGS.shard_param_on_dim_0,
      pin_layout_in_collective_ops=FLAGS.pin_layout_in_collective_ops,
      auto_wrap_policy=auto_wrap_policy,
      auto_wrapper_callable=auto_wrapper_callable)
# Manually wrapping sub-modules with inner FSDP (if not using auto-wrap)
# (in this case, the sub-modules should be wrapped before the base model)
if FLAGS.use_nested_fsdp:
assert FLAGS.auto_wrap_policy == "none", \
"--use_nested_fsdp is for manual nested wrapping should only be used" \
" without auto-wrapping"
# You may wrap all, a subset, or none of the sub-modules with inner FSDPs
# - to implement ZeRO-2, wrap none of the sub-modules
# - to implement ZeRO-3, wrap all of the sub-modules (nested FSDP)
# - you may wrap sub-modules at different granularity (e.g. at each resnet
# stage or each residual block or each conv layer).
# Here we apply inner FSDP at the level of child modules for ZeRO-3, which
# corresponds to different stages in resnet (i.e. Stage 1 to 5).
# Apply gradient checkpointing to nested-wrapped sub-modules if specified
grad_ckpt_wrap = checkpoint_module if FLAGS.use_gradient_checkpointing else (
lambda x: x)
for submodule_name, submodule in model.named_children():
if sum(p.numel() for p in submodule.parameters()) == 0:
# Skip those submodules without parameters (i.e. no need to shard them)
55 changes: 48 additions & 7 deletions test/test_train_mp_mnist_fsdp_with_ckpt.py
@@ -1,9 +1,18 @@
import args_parse
from functools import partial

MODEL_OPTS = {
'--flatten_parameters': {
'action': 'store_true',
},
'--auto_wrap_policy': {
'choices': ['none', 'size_based', 'type_based'],
'default': 'none',
},
'--auto_wrap_min_num_params': {
'type': int,
'default': 1000,
},
'--use_nested_fsdp': {
'action': 'store_true',
},
@@ -28,8 +37,9 @@
'--shard_param_on_dim_0': {
'action': 'store_true',
},
    '--no_pin_layout_in_collective_ops': {
        'action': 'store_false',
        'dest': 'pin_layout_in_collective_ops',
},
}

@@ -64,6 +74,8 @@
consolidate_sharded_model_checkpoints,
checkpoint_module,
)
from torch_xla.distributed.fsdp.wrap import (size_based_auto_wrap_policy,
transformer_auto_wrap_policy)


class MNIST(nn.Module):
@@ -153,19 +165,48 @@ def train_mnist(flags, **kwargs):

device = xm.xla_device()
model = MNIST()
  # Automatically wrap sub-modules with inner FSDP
auto_wrap_policy = None
auto_wrapper_callable = None
if flags.auto_wrap_policy != "none":
if flags.auto_wrap_policy == "size_based":
      # auto-wrap all sub-modules with more than min_num_params parameters (default 1000)
# (in practice, one should set a larger min_num_params such as 1e8)
auto_wrap_policy = partial(
size_based_auto_wrap_policy,
min_num_params=flags.auto_wrap_min_num_params)
elif flags.auto_wrap_policy == "type_based":
# auto-wrap all nn.Conv2d and nn.Linear sub-modules as an example
# (transformer_auto_wrap_policy wraps all sub-modules in transformer_layer_cls)
auto_wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={nn.Conv2d, nn.Linear})
else:
raise Exception(f"Invalid auto-wrap policy: {flags.auto_wrap_policy}")
if flags.use_gradient_checkpointing:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
auto_wrapper_callable = lambda m, *args, **kwargs: FSDP(
checkpoint_module(m), *args, **kwargs)

fsdp_wrap = lambda m: FSDP(
m,
compute_dtype=getattr(torch, flags.compute_dtype),
fp32_reduce_scatter=flags.fp32_reduce_scatter,
flatten_parameters=flags.flatten_parameters,
shard_param_on_dim_0=flags.shard_param_on_dim_0,
      pin_layout_in_collective_ops=flags.pin_layout_in_collective_ops,
      auto_wrap_policy=auto_wrap_policy,
      auto_wrapper_callable=auto_wrapper_callable)
# Manually wrapping sub-modules with inner FSDP (if not using auto-wrap)
# (in this case, the sub-modules should be wrapped before the base model)
if flags.use_nested_fsdp:
assert flags.auto_wrap_policy == "none", \
"--use_nested_fsdp is for manual nested wrapping should only be used" \
" without auto-wrapping"
# Wrap a few sub-modules with inner FSDP (to implement ZeRO-3)
# Apply gradient checkpointing to nested-wrapped sub-modules if specified
grad_ckpt_wrap = checkpoint_module if flags.use_gradient_checkpointing else (
lambda x: x)
# Note: wrap with `checkpoint_module` first BEFORE wrapping with FSDP
model.conv1 = fsdp_wrap(grad_ckpt_wrap(model.conv1))
model.conv2 = fsdp_wrap(grad_ckpt_wrap(model.conv2))