diff --git a/docs/fsdp.md b/docs/fsdp.md index 152e038af8c..f9a49812e12 100644 --- a/docs/fsdp.md +++ b/docs/fsdp.md @@ -22,10 +22,26 @@ Notes: * The ZeRO-3 optimizer should be implemented via nested FSDP with `reshard_after_forward=True`. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` and `test/test_train_mp_imagenet_fsdp.py` for an example. * For large models that cannot fit into a single TPU memory or the host CPU memory, one should interleave submodule construction with inner FSDP wrapping. See [`FSDPViTModel`](https://github.com/ronghanghu/vit_10b_fsdp_example/blob/master/run_vit_training.py) for an example. * a simple wrapper `checkpoint_module` is provided (based on `torch_xla.utils.checkpoint.checkpoint` from https://github.com/pytorch/xla/pull/3524) to perform [gradient checkpointing](https://spell.ml/blog/gradient-checkpointing-pytorch-YGypLBAAACEAefHs) over a given `nn.Module` instance. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` and `test/test_train_mp_imagenet_fsdp.py` for an example. +* Auto-wrapping submodules: instead of manual nested FSDP wrapping, one can also specify an `auto_wrap_policy` argument to automatically wrap the submodules with inner FSDP. `size_based_auto_wrap_policy` in `torch_xla.distributed.fsdp.wrap` is an example `auto_wrap_policy` callable; this policy wraps layers whose number of parameters exceeds 100M. `transformer_auto_wrap_policy` in `torch_xla.distributed.fsdp.wrap` is an example `auto_wrap_policy` callable for transformer-like model architectures. + +For example, to automatically wrap all `torch.nn.Conv2d` submodules with inner FSDP, one can use: +```python3 +from functools import partial +import torch +from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy +auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d}) +``` + +Additionally, one can also specify an `auto_wrapper_callable` argument to use a custom callable wrapper for the submodules (the default wrapper is just the `XlaFullyShardedDataParallel` class itself). For example, one can use the following to apply gradient checkpointing (i.e. activation checkpointing/rematerialization) to each auto-wrapped submodule: +```python3 +from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel, checkpoint_module +auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel( + checkpoint_module(m), *args, **kwargs) +``` * When stepping the optimizer, directly call `optimizer.step` and do not call `xm.optimizer_step`. The latter reduces the gradient across ranks, which is not needed for FSDP (where the parameters are already sharded). * When saving model and optimizer checkpoints during training, each training process needs to save its own checkpoint of the (sharded) model and optimizer state dicts (use `master_only=False` and set different paths for each rank in `xm.save`). When resuming, it needs to load the checkpoint for the corresponding rank. * Please also save `model.get_shard_metadata()` along with `model.state_dict()` as follows and use `consolidate_sharded_model_checkpoints` to stitch the sharded model checkpoints together into a full model state dict. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` for an example. 
-``` +```python3 ckpt = { 'model': model.state_dict(), 'shard_metadata': model.get_shard_metadata(), @@ -86,12 +100,12 @@ python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \ --lr 0.4 --batch_size 128 --num_warmup_epochs 5 --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 --num_epochs 100 \ --use_nested_fsdp ``` -You can also add ` --use_gradient_checkpointing` (which needs to be used along with `--use_nested_fsdp`) to apply gradient checkpointing on the residual blocks. +You can also add ` --use_gradient_checkpointing` (which needs to be used along with `--use_nested_fsdp` or `--auto_wrap_policy`) to apply gradient checkpointing on the residual blocks. --- ### Example training scripts on TPU pod (with 10 billion parameters) -To train large models that cannot fit into a single TPU, one should use nested FSDP (wrapping sub-modules with inner FSDP when building the entire model) to implement the ZeRO-3 algorithm. +To train large models that cannot fit into a single TPU, one should either auto-wrap or manually wrap the submodules with inner FSDP when building the entire model, in order to implement the ZeRO-3 algorithm. Please see https://github.com/ronghanghu/vit_10b_fsdp_example for an example of sharded training of a Vision Transformer (ViT) model using this XLA FSDP PR. diff --git a/test/test_train_mp_imagenet_fsdp.py b/test/test_train_mp_imagenet_fsdp.py index f55db728a56..400e7eaabf3 100644 --- a/test/test_train_mp_imagenet_fsdp.py +++ b/test/test_train_mp_imagenet_fsdp.py @@ -1,4 +1,5 @@ import args_parse +from functools import partial SUPPORTED_MODELS = [ 'alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', @@ -38,6 +39,14 @@ '--flatten_parameters': { 'action': 'store_true', }, + '--auto_wrap_policy': { + 'choices': ['none', 'size_based', 'type_based'], + 'default': 'none', + }, + '--auto_wrap_min_num_params': { + 'type': int, + 'default': 1e6, + }, '--use_nested_fsdp': { 'action': 'store_true', }, @@ -54,8 +63,9 @@ '--shard_param_on_dim_0': { 'action': 'store_true', }, - '--pin_layout_in_collective_ops': { - 'action': 'store_true', + '--no_pin_layout_in_collective_ops': { + 'action': 'store_false', + 'dest': 'pin_layout_in_collective_ops', }, } @@ -89,6 +99,8 @@ import torch_xla.test.test_utils as test_utils from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module +from torch_xla.distributed.fsdp.wrap import (size_based_auto_wrap_policy, + transformer_auto_wrap_policy) DEFAULT_KWARGS = dict( batch_size=128, @@ -215,25 +227,58 @@ def train_imagenet(): device = xm.xla_device() model = get_model_property('model_fn')() - # Wrap the model with FSDP - # You may wrap all, a subset, or none of the sub-modules with inner FSDPs - # - to implement ZeRO-2, wrap none of the sub-modules - # - to implement ZeRO-3, wrap all of the sub-modules (nested FSDP) - # - you may wrap sub-modules at different granularity (e.g. at each resnet - # stage or each residual block or each conv layer). 
+ # Automatically wrap sub-modules with inner FSDP + auto_wrap_policy = None + auto_wrapper_callable = None + if FLAGS.auto_wrap_policy != "none": + if FLAGS.auto_wrap_policy == "size_based": + # auto-wrap all sub-modules with more than a certain number of parameters (default 1e6) + auto_wrap_policy = partial( + size_based_auto_wrap_policy, + min_num_params=FLAGS.auto_wrap_min_num_params) + elif FLAGS.auto_wrap_policy == "type_based": + # auto-wrap all sub-modules in torchvision ResNet's BasicBlock or Bottleneck + # or torchvision transformer's EncoderBlock as an example + # (transformer_auto_wrap_policy wraps all sub-modules in transformer_layer_cls) + auto_wrap_policy = partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + torchvision.models.resnet.BasicBlock, + torchvision.models.resnet.Bottleneck, + torchvision.models.vision_transformer.EncoderBlock, + }) + else: + raise Exception(f"Invalid auto-wrap policy: {FLAGS.auto_wrap_policy}") + if FLAGS.use_gradient_checkpointing: + # Apply gradient checkpointing to auto-wrapped sub-modules if specified + auto_wrapper_callable = lambda m, *args, **kwargs: FSDP( + checkpoint_module(m), *args, **kwargs) + fsdp_wrap = lambda m: FSDP( m, compute_dtype=getattr(torch, FLAGS.compute_dtype), fp32_reduce_scatter=FLAGS.fp32_reduce_scatter, flatten_parameters=FLAGS.flatten_parameters, shard_param_on_dim_0=FLAGS.shard_param_on_dim_0, - pin_layout_in_collective_ops=FLAGS.pin_layout_in_collective_ops) - # Apply gradient checkpointing to sub-modules if specified - grad_ckpt_wrap = checkpoint_module if FLAGS.use_gradient_checkpointing else ( - lambda x: x) + pin_layout_in_collective_ops=FLAGS.pin_layout_in_collective_ops, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable) + # Manually wrap sub-modules with inner FSDP (if not using auto-wrap) + # (in this case, the sub-modules should be wrapped before the base model) if FLAGS.use_nested_fsdp: + assert FLAGS.auto_wrap_policy == "none", \ + "--use_nested_fsdp is for manual nested wrapping and should only be" \ + " used without auto-wrapping" + # You may wrap all, a subset, or none of the sub-modules with inner FSDPs + # - to implement ZeRO-2, wrap none of the sub-modules + # - to implement ZeRO-3, wrap all of the sub-modules (nested FSDP) + # - you may wrap sub-modules at different granularity (e.g. at each resnet + # stage or each residual block or each conv layer). # Here we apply inner FSDP at the level of child modules for ZeRO-3, which # corresponds to different stages in resnet (i.e. Stage 1 to 5). + # Apply gradient checkpointing to nested-wrapped sub-modules if specified + grad_ckpt_wrap = checkpoint_module if FLAGS.use_gradient_checkpointing else ( + lambda x: x) for submodule_name, submodule in model.named_children(): if sum(p.numel() for p in submodule.parameters()) == 0: # Skip those submodules without parameters (i.e. 
no need to shard them) diff --git a/test/test_train_mp_mnist_fsdp_with_ckpt.py b/test/test_train_mp_mnist_fsdp_with_ckpt.py index 43203c0b20e..c9db2d5804e 100644 --- a/test/test_train_mp_mnist_fsdp_with_ckpt.py +++ b/test/test_train_mp_mnist_fsdp_with_ckpt.py @@ -1,9 +1,18 @@ import args_parse +from functools import partial MODEL_OPTS = { '--flatten_parameters': { 'action': 'store_true', }, + '--auto_wrap_policy': { + 'choices': ['none', 'size_based', 'type_based'], + 'default': 'none', + }, + '--auto_wrap_min_num_params': { + 'type': int, + 'default': 1000, + }, '--use_nested_fsdp': { 'action': 'store_true', }, @@ -28,8 +37,9 @@ '--shard_param_on_dim_0': { 'action': 'store_true', }, - '--pin_layout_in_collective_ops': { - 'action': 'store_true', + '--no_pin_layout_in_collective_ops': { + 'action': 'store_false', + 'dest': 'pin_layout_in_collective_ops', }, } @@ -64,6 +74,8 @@ consolidate_sharded_model_checkpoints, checkpoint_module, ) +from torch_xla.distributed.fsdp.wrap import (size_based_auto_wrap_policy, + transformer_auto_wrap_policy) class MNIST(nn.Module): @@ -153,19 +165,48 @@ def train_mnist(flags, **kwargs): device = xm.xla_device() model = MNIST() - # Wrap the model with FSDP + # Automatically wrap sub-modules with inner FSDP + auto_wrap_policy = None + auto_wrapper_callable = None + if flags.auto_wrap_policy != "none": + if flags.auto_wrap_policy == "size_based": + # auto-wrap all sub-modules with more than a certain number of parameters (default 1000) + # (in practice, one should set a larger min_num_params such as 1e8) + auto_wrap_policy = partial( + size_based_auto_wrap_policy, + min_num_params=flags.auto_wrap_min_num_params) + elif flags.auto_wrap_policy == "type_based": + # auto-wrap all nn.Conv2d and nn.Linear sub-modules as an example + # (transformer_auto_wrap_policy wraps all sub-modules in transformer_layer_cls) + auto_wrap_policy = partial( + transformer_auto_wrap_policy, + transformer_layer_cls={nn.Conv2d, nn.Linear}) + else: + raise Exception(f"Invalid auto-wrap policy: {flags.auto_wrap_policy}") + if flags.use_gradient_checkpointing: + # Apply gradient checkpointing to auto-wrapped sub-modules if specified + auto_wrapper_callable = lambda m, *args, **kwargs: FSDP( + checkpoint_module(m), *args, **kwargs) + fsdp_wrap = lambda m: FSDP( m, compute_dtype=getattr(torch, flags.compute_dtype), fp32_reduce_scatter=flags.fp32_reduce_scatter, flatten_parameters=flags.flatten_parameters, shard_param_on_dim_0=flags.shard_param_on_dim_0, - pin_layout_in_collective_ops=flags.pin_layout_in_collective_ops) - # Apply gradient checkpointing to sub-modules if specified - grad_ckpt_wrap = checkpoint_module if flags.use_gradient_checkpointing else ( - lambda x: x) + pin_layout_in_collective_ops=flags.pin_layout_in_collective_ops, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable) + # Manually wrap sub-modules with inner FSDP (if not using auto-wrap) + # (in this case, the sub-modules should be wrapped before the base model) if flags.use_nested_fsdp: + assert flags.auto_wrap_policy == "none", \ + "--use_nested_fsdp is for manual nested wrapping and should only be" \ + " used without auto-wrapping" # Wrap a few sub-modules with inner FSDP (to implement ZeRO-3) + # Apply gradient checkpointing to nested-wrapped sub-modules if specified + grad_ckpt_wrap = checkpoint_module if flags.use_gradient_checkpointing else ( + lambda x: x) # Note: wrap with `checkpoint_module` first BEFORE wrapping with FSDP model.conv1 = fsdp_wrap(grad_ckpt_wrap(model.conv1)) model.conv2 = 
fsdp_wrap(grad_ckpt_wrap(model.conv2)) diff --git a/torch_xla/distributed/fsdp/wrap.py b/torch_xla/distributed/fsdp/wrap.py new file mode 100644 index 00000000000..64d6099c272 --- /dev/null +++ b/torch_xla/distributed/fsdp/wrap.py @@ -0,0 +1,216 @@ +# This file is largely adapted from ``torch.distributed.fsdp.wrap`` in +# https://github.com/pytorch/pytorch/blob/v1.13.0/torch/distributed/fsdp/wrap.py + +from typing import Any, Callable, Set, Tuple, Optional, Type, cast + +import torch.nn as nn + + +def always_wrap_policy(*args, **kwargs) -> bool: + """ + A simple wrapper policy that always returns ``True``, + i.e. when passed as the `auto_wrap_policy` into FSDP, + this will result in all submodules being wrapped as + distinct FSDP instances. + """ + return True + + +def lambda_auto_wrap_policy(module: nn.Module, recurse: bool, + unwrapped_params: int, lambda_fn: Callable) -> bool: + """ + A convenient auto wrap policy to wrap submodules based on an arbitrary user + function. If ``lambda_fn(submodule) == True``, the submodule will be wrapped as + a ``wrapper_cls`` unit. + Returns whether a module should be wrapped during auto wrapping. + The first three parameters are required by :func:`recursive_wrap`. + Args: + module (nn.Module): + The module to be considered in this decision. + recurse (bool): + Indicate if this is called to make a decision on whether we + should recurse down a subgraph of the module structure. + If False, it means this function is called to make a decision + on whether we should wrap the said module. + unwrapped_params (int): + The number of parameters yet to be wrapped in this module. + lambda_fn (Callable[[nn.Module], bool]): + If this returns ``True``, this module will be wrapped by + ``wrapper_cls`` individually. + """ + if recurse: + # always recurse + return True + else: + # if not recursing, decide whether we should wrap for the leaf node or remainder + return lambda_fn(module) + + +def transformer_auto_wrap_policy( + module: nn.Module, + recurse: bool, + unwrapped_params: int, + transformer_layer_cls: Set[Type[nn.Module]], +) -> bool: + """ + A convenient auto wrap policy for transformer models. If the submodule + is an instance of one of the classes in transformer_layer_cls, it will be wrapped + as an FSDP unit. Otherwise, all the other remaining submodules are wrapped + by the outermost FSDP unit. + Returns whether a module should be wrapped during FSDP auto wrapping. + The first three parameters are required by :func:`recursive_wrap`. + Args: + module (nn.Module): + The module to be considered in this decision. + recurse (bool): + Indicate if this is called to make a decision on whether we + should recurse down a subgraph of the module structure. + If False, it means this function is called to make a decision + on whether we should wrap the said module. + unwrapped_params (int): + The number of parameters yet to be wrapped in this module. + transformer_layer_cls (Set[Type[nn.Module]]): + Submodules that are instances of one of the ``transformer_layer_cls`` classes + will be wrapped as separate FSDP units + """ + if recurse: + # always recurse + return True + else: + # if not recursing, decide whether we should wrap for the leaf node or remainder + return isinstance(module, tuple(transformer_layer_cls)) + + +def size_based_auto_wrap_policy( + module: nn.Module, + recurse: bool, + unwrapped_params: int, + # These are customizable for this policy function. 
+ min_num_params: int = int(1e8), + force_leaf_modules: Optional[Set[Type[nn.Module]]] = None, + exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None, +) -> bool: + """A size-based auto_wrap_policy function for the FSDP API. + Returns whether a module should be wrapped during FSDP auto wrapping. + The first three parameters are used by :func:`recursive_wrap`. + Args: + module (nn.Module): + The module to be considered in this decision. + recurse (bool): + Indicate if this is called to make a decision on whether we + should recurse down a subgraph of the module structure. + If False, it means this function is called to make a decision + on whether we should wrap the said module. + unwrapped_params (int): + The number of parameters yet to be wrapped in this module. + min_num_params (int): + Customizable policy input. It controls the size threshold + for how big a module must be to be wrapped. + force_leaf_modules (Set[Type[nn.Module]]): set of module types to + keep as leaves, i.e., their children will never be wrapped. + exclude_wrap_modules (Set[Type[nn.Module]]): + Customizable set of module types to be excluded from wrapping. + """ + force_leaf_modules = ( + size_based_auto_wrap_policy.FORCE_LEAF_MODULES + if force_leaf_modules is None else force_leaf_modules) + exclude_wrap_modules = ( + size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES + if exclude_wrap_modules is None else exclude_wrap_modules) + + is_large = unwrapped_params >= min_num_params + if recurse: + # We should recurse if the module is big enough but not in the force_leaf_modules list. + return is_large and not isinstance(module, tuple(force_leaf_modules)) + else: + # If we are not recursing, determine if we should wrap. + return is_large and not isinstance(module, tuple(exclude_wrap_modules)) + + +# Set these defaults as attributes of the size_based_auto_wrap_policy function so that they are easy to import. +size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = { + nn.ModuleList, nn.ModuleDict +} +size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} + + +def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module: + assert wrapper_cls is not None + if hasattr(module, '_wrap_overrides'): + # If module has a _wrap_overrides attribute, we force overriding the + # FSDP config with these attributes for this module. Currently this + # is only used to disable mixed precision for BatchNorm when + # auto_wrapping. + overrides = {**kwargs, **module._wrap_overrides} + return wrapper_cls(module, **overrides) + + return wrapper_cls(module, **kwargs) + + +def recursive_wrap(module: nn.Module, + auto_wrap_policy: Callable, + wrapper_cls: Callable, + ignored_modules: Set[nn.Module], + ignored_params: Set[nn.Parameter], + only_wrap_children: bool = False, + **kwargs: Any) -> Tuple[nn.Module, int]: + """ + Automatically wrap the child modules of *module* that meet the criteria + given by ``auto_wrap_policy`` with ``wrapper_cls``. + Args: + module (nn.Module): + module to recursively wrap + auto_wrap_policy (Callable): + A callable specifying a policy to recursively wrap layers with FSDP. + ignored_modules (Set[torch.nn.Module]): Modules to ignore when + wrapping. + ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when + wrapping; these should be the parameters contained in the modules + in ``ignored_modules``. + Returns: + (nn.Module, int): + The wrapped module and the number of parameters wrapped recursively. + """ + assert auto_wrap_policy is not None, "Must specify auto_wrap_policy." 
+ assert wrapper_cls is not None, "Must specify wrapper_cls" + # Make sure no child is already wrapped. + for _, child in module.named_modules(): + if child in ignored_modules: + continue + try: + assert not isinstance(child, cast(type, wrapper_cls)) + except TypeError: + # wrapper_cls is a function as opposed to a class type, just bypass the above check. + pass + + # We count all params, assuming none of them are already wrapped. + num_params = sum( + p.numel() for p in module.parameters() if p not in ignored_params) + + if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params): + total_wrapped_params = 0 + # Iterate through the children, recursively wrap if necessary + for name, child in module.named_children(): + if child in ignored_modules: + continue + wrapped_child, num_wrapped_params = recursive_wrap( + module=child, + auto_wrap_policy=auto_wrap_policy, + wrapper_cls=wrapper_cls, + ignored_modules=ignored_modules, + ignored_params=ignored_params, + **kwargs, + ) + setattr(module, name, wrapped_child) + # Keep track of how many parameters have been wrapped + total_wrapped_params += num_wrapped_params + # decide if we need to wrap the current module, based on the leftover + # parameters that have not been wrapped by any child + remainder = num_params - total_wrapped_params + if not only_wrap_children and auto_wrap_policy( + module=module, recurse=False, unwrapped_params=remainder): + # Leaf node or final wrapping of the remainder both happen here. + return _wrap(module, wrapper_cls, **kwargs), num_params + else: + return module, total_wrapped_params + return module, 0 diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py index 2b999893061..83b8676b462 100644 --- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py +++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py @@ -35,6 +35,7 @@ from .xla_flatten_params_wrapper import XlaFlattenParamsWrapper from .utils import dummy_all_gather, dummy_all_reduce, dummy_reduce_scatter, apply_xla_patch_to_nn_linear +from .wrap import recursive_wrap FLOAT_DTYPES = [torch.float32, torch.float16, torch.bfloat16] @@ -194,6 +195,49 @@ class XlaFullyShardedDataParallel(nn.Module): dimension (dim 0) *without* flattening them. This is a workaround for those compilers that may have trouble handling flattened parameters. This option has no effect if ``flatten_parameters`` is ``True``. + auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): + A callable specifying a policy to recursively wrap layers with FSDP. + Note that this policy currently only applies to the child modules of + the passed-in module. The remaining modules are always wrapped in + the returned FSDP root instance. + ``size_based_auto_wrap_policy`` in ``torch_xla.distributed.fsdp.wrap`` + is an example ``auto_wrap_policy`` callable; this policy wraps + layers whose number of parameters is larger than 100M. + ``transformer_auto_wrap_policy`` in ``torch_xla.distributed.fsdp.wrap`` + is an example ``auto_wrap_policy`` callable for transformer-like + model architectures. Users can supply a customized + ``auto_wrap_policy`` callable that should accept the following arguments: + ``module: nn.Module``, ``recurse: bool``, ``unwrapped_params: int``, + and return a ``bool`` specifying whether the passed-in ``module`` + should be wrapped (if ``recurse=False``) or whether we should recurse + down the subgraph of ``module`` children (if ``recurse=True``). 
+ Extra customized arguments can be added to the customized + ``auto_wrap_policy`` callable as well. It is a good practice to print + out the sharded model, check whether it is what the + application expects, and then adjust the policy accordingly. + Example:: + + def custom_auto_wrap_policy( + module: nn.Module, + recurse: bool, + unwrapped_params: int, + # These are customizable for this policy function. + min_num_params: int = int(1e8), + ) -> bool: + return unwrapped_params >= min_num_params + # Configure a custom min_num_params + auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=1e5) + + auto_wrapper_callable (Optional[Callable]): the wrapper class or callable + used by auto-wrapping when wrapping a submodule (the default is the + `XlaFullyShardedDataParallel` class itself). One can specify a different + callable as the wrapper. For example, activation checkpointing (rematerialization) + can be applied to each auto-wrapped submodule as follows: + + from torch_xla.distributed.fsdp import checkpoint_module + auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel( + checkpoint_module(m), *args, **kwargs) + """ def __init__( @@ -214,6 +258,8 @@ def __init__( sharding_world_size: Optional[int] = None, shard_param_on_dim_0: bool = False, pin_layout_in_collective_ops: bool = True, + auto_wrap_policy: Optional[Callable] = None, + auto_wrapper_callable: Optional[Callable] = None, _shard_size_multiple: int = 128, _use_xla_patched_linear: bool = True, _debug_dummy_forward_pass: bool = False, @@ -243,6 +289,46 @@ def __init__( "instead of using any of its submodules or its weights).") super().__init__() + + if auto_wrap_policy is not None: + auto_wrap_kwargs = { + "module": module, + "auto_wrap_policy": auto_wrap_policy, + "wrapper_cls": auto_wrapper_callable or XlaFullyShardedDataParallel, + "ignored_modules": [], + "ignored_params": [], + "only_wrap_children": True, # avoid double wrapping the root + } + fsdp_kwargs = dict( + reshard_after_forward=reshard_after_forward, + flatten_parameters=flatten_parameters, + execute_sharding_on_init=execute_sharding_on_init, + optimization_barrier_in_forward=optimization_barrier_in_forward, + optimization_barrier_in_backward=optimization_barrier_in_backward, + mark_step_on_finalization=mark_step_on_finalization, + disable_reshard_on_root=disable_reshard_on_root, + compute_dtype=compute_dtype, + buffer_dtype=buffer_dtype, + fp32_reduce_scatter=fp32_reduce_scatter, + sharding_groups=sharding_groups, + sharding_rank=sharding_rank, + sharding_world_size=sharding_world_size, + shard_param_on_dim_0=shard_param_on_dim_0, + pin_layout_in_collective_ops=pin_layout_in_collective_ops, + # `auto_wrap_policy` doesn't need to be specified in auto-wrapping + # `auto_wrapper_callable` doesn't need to be specified in auto-wrapping + _shard_size_multiple=_shard_size_multiple, + _use_xla_patched_linear=_use_xla_patched_linear, + _debug_dummy_forward_pass=_debug_dummy_forward_pass, + _debug_msg=_debug_msg, + _debug_print=_debug_print, + _debug_dummy_all_gather_op=_debug_dummy_all_gather_op, + _debug_dummy_all_reduce_op=_debug_dummy_all_reduce_op, + _debug_dummy_reduce_scatter_op=_debug_dummy_reduce_scatter_op, + _debug_dummy_optimization_barrier_op=_debug_dummy_optimization_barrier_op, + ) + self._auto_wrap(auto_wrap_kwargs, fsdp_kwargs) + self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward self.disable_reshard_on_root = disable_reshard_on_root self.flatten_parameters = flatten_parameters @@ -1457,6 
+1543,32 @@ def _flatten_and_pad_to_world_size(self, tensor: torch.Tensor, return tensor + def _auto_wrap( + self, + auto_wrap_kwargs: Dict[str, Any], + fsdp_kwargs: Dict[str, Any], + ) -> None: + """ + Recursively auto wraps the root module given by the key "module" in + ``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and + ``fsdp_kwargs``. + Precondition: ``auto_wrap_kwargs`` contains the arguments expected by + ``recursive_wrap()``, and its ``auto_wrap_policy`` is not ``None``. + ``fsdp_kwargs`` contains all FSDP arguments except ``module``. + """ + auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"] + root_module = auto_wrap_kwargs["module"] + assert auto_wrap_policy is not None + # For auto wrapping, submodules should not already be wrapped with FSDP + # since double wrapping is not supported + for module_name, module in root_module.named_modules(): + if isinstance(module, XlaFullyShardedDataParallel): + raise ValueError( + f"Expected {module_name} to NOT be FullyShardedDataParallel " + "if using an `auto_wrap_policy`") + + recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs) + def apply_to_tensors( fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]
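For reference, below is a minimal end-to-end sketch of how the `auto_wrap_policy` and `auto_wrapper_callable` arguments introduced in this patch are expected to be used together. It is not part of the patch: the two-layer toy model, the `1e4` parameter threshold, and the single-process run are illustrative assumptions only, and a working torch_xla TPU (or XLA device) setup is assumed.

```python3
# Illustrative sketch (not part of this patch): auto-wrap a toy model with
# XLA FSDP and run one training step. Model shape and threshold are arbitrary.
from functools import partial

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module
from torch_xla.distributed.fsdp.wrap import size_based_auto_wrap_policy

device = xm.xla_device()
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(),
                      nn.Linear(4096, 10)).to(device)

# Wrap every child module with more than 1e4 parameters as its own inner FSDP
# instance; the remainder is handled by the returned root FSDP instance.
auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=1e4)
# Optionally apply gradient checkpointing to each auto-wrapped submodule.
auto_wrapper_callable = lambda m, *args, **kwargs: FSDP(
    checkpoint_module(m), *args, **kwargs)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    auto_wrapper_callable=auto_wrapper_callable)
print(model)  # inspect the sharded model to verify the wrapping granularity

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss = model(torch.randn(8, 1024, device=device)).sum()
loss.backward()
optimizer.step()  # call optimizer.step() directly; do not use xm.optimizer_step()
xm.mark_step()
```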