diff --git a/recipes_source/distributed_checkpoint_recipe.rst b/recipes_source/distributed_checkpoint_recipe.rst
index c20dca6639c..118dc7e7794 100644
--- a/recipes_source/distributed_checkpoint_recipe.rst
+++ b/recipes_source/distributed_checkpoint_recipe.rst
@@ -1,7 +1,7 @@
 Getting Started with Distributed Checkpoint (DCP)
 =====================================================

-**Author**: `Iris Zhang `__, `Rodrigo Kumpera `__, `Chien-Chin Huang `__
+**Author**: `Iris Zhang `__, `Rodrigo Kumpera `__, `Chien-Chin Huang `__, `Lucas Pasqualin `__

 .. note::
    |edit| View and edit this tutorial in `github `__.
@@ -22,8 +22,12 @@ In this tutorial, we show how to use DCP APIs with a simple FSDP wrapped model.

 How DCP works
 --------------
-:func:`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel.
-In addition, checkpointing automatically handles fully-qualified-name (FQN) mappings across models and optimizers, enabling load-time resharding across differing cluster topologies.
+:func:`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel. You can use this module to save on any number of ranks in parallel,
+and then re-shard across differing cluster topologies at load time.
+
+Additionally, through the use of modules in :func:`torch.distributed.checkpoint.state_dict`,
+DCP offers support for gracefully handling ``state_dict`` generation and loading in distributed settings.
+This includes managing fully-qualified-name (FQN) mappings across models and optimizers, and setting default parameters for PyTorch-provided parallelisms.

 DCP is different from :func:`torch.save` and :func:`torch.load` in a few significant ways:

@@ -42,7 +46,7 @@ Here we use a toy model wrapped with FSDP for demonstration purposes. Similarly,

 Saving
 ~~~~~~
-Now, let’s create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.
+Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.

 .. code-block:: python

@@ -50,11 +54,12 @@ Now, let’s create a toy module, wrap it with FSDP, feed it with some dummy inp

     import torch
     import torch.distributed as dist
-    import torch.distributed.checkpoint as DCP
+    import torch.distributed.checkpoint as dcp
     import torch.multiprocessing as mp
     import torch.nn as nn

     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.checkpoint.state_dict import get_state_dict
     from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

     CHECKPOINT_DIR = "checkpoint"
@@ -99,20 +104,14 @@ Now, let’s create a toy module, wrap it with FSDP, feed it with some dummy inp
         model(torch.rand(8, 16, device="cuda")).sum().backward()
         optimizer.step()

-        # set FSDP StateDictType to SHARDED_STATE_DICT so we can use DCP to checkpoint sharded model state dict
-        # note that we do not support FSDP StateDictType.LOCAL_STATE_DICT
-        FSDP.set_state_dict_type(
-            model,
-            StateDictType.SHARDED_STATE_DICT,
-        )
+        # this line automatically manages FSDP FQNs and sets the default state dict type to FSDP.SHARDED_STATE_DICT
+        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
         state_dict = {
-            "model": model.state_dict(),
+            "model": model_state_dict,
+            "optimizer": optimizer_state_dict
         }
+        dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

-        DCP.save_state_dict(
-            state_dict=state_dict,
-            storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
-        )

         cleanup()
@@ -152,12 +151,12 @@ The reason that we need the ``state_dict`` prior to loading is:

     import torch
     import torch.distributed as dist
-    import torch.distributed.checkpoint as DCP
+    import torch.distributed.checkpoint as dcp
+    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
     import torch.multiprocessing as mp
     import torch.nn as nn

     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

     CHECKPOINT_DIR = "checkpoint"

@@ -194,21 +193,23 @@ The reason that we need the ``state_dict`` prior to loading is:
         model = ToyModel().to(rank)
         model = FSDP(model)

-        FSDP.set_state_dict_type(
-            model,
-            StateDictType.SHARDED_STATE_DICT,
-        )
-        # different from ``torch.load()``, DCP requires model state_dict prior to loading to get
-        # the allocated storage and sharding information.
+        # generates the state dict we will load into
+        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
         state_dict = {
-            "model": model.state_dict(),
+            "model": model_state_dict,
+            "optimizer": optimizer_state_dict
         }
-
-        DCP.load_state_dict(
+        dcp.load(
             state_dict=state_dict,
-            storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
+            checkpoint_id=CHECKPOINT_DIR,
+        )
+        # sets our state dicts on the model and optimizer, now that we've loaded the checkpoint
+        set_state_dict(
+            model,
+            optimizer,
+            model_state_dict=model_state_dict,
+            optim_state_dict=optimizer_state_dict
         )
-        model.load_state_dict(state_dict["model"])

         cleanup()
@@ -224,7 +225,8 @@ The reason that we need the ``state_dict`` prior to loading is:
         )

 If you would like to load the saved checkpoint into a non-FSDP wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP.
-By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data(SPMD) style. To load without a distributed setup, please set ``no_dist`` to ``True`` when loading with DCP.
+By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers
+the intent is to save or load in "non-distributed" style, meaning entirely in the current process.

 .. note::
    Distributed checkpoint support for Multi-Program Multi-Data is still under development.
@@ -259,11 +261,10 @@ By default, DCP saves and loads a distributed ``state_dict`` in Single Program M
         "model": model.state_dict(),
     }

-    # turn no_dist to be true to load in non-distributed setting
-    DCP.load_state_dict(
+    # since no process group is initialized, DCP will disable any collectives.
+    dcp.load(
         state_dict=state_dict,
-        storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
-        no_dist=True,
+        checkpoint_id=CHECKPOINT_DIR,
     )
     model.load_state_dict(state_dict["model"])

@@ -274,7 +275,9 @@ By default, DCP saves and loads a distributed ``state_dict`` in Single Program M

 Conclusion
 ----------
-In conclusion, we have learned how to use DCP's :func:`save_state_dict` and :func:`load_state_dict` APIs, as well as how they are different form :func:`torch.save` and :func:`torch.load`.
+In conclusion, we have learned how to use DCP's :func:`save` and :func:`load` APIs, as well as how they are different from :func:`torch.save` and :func:`torch.load`.
+Additionally, we've learned how to use :func:`get_state_dict` and :func:`set_state_dict` to automatically manage parallelism-specific FQNs and defaults during state dict
+generation and loading.

 For more information, please see the following:
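Taken together, the change above reduces checkpointing to a symmetric save/load pattern built on ``get_state_dict``/``set_state_dict`` plus ``dcp.save``/``dcp.load``. The following is a condensed sketch of that pattern, not part of the patch itself: the helper names are illustrative, and it assumes a process group is already initialized and ``model`` is already wrapped (for example with FSDP), as in the tutorial code in the diff.

.. code-block:: python

    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

    CHECKPOINT_DIR = "checkpoint"


    def save_app_state(model, optimizer):
        # get_state_dict manages FQN mappings and FSDP state dict defaults
        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
        state_dict = {"model": model_state_dict, "optimizer": optimizer_state_dict}
        dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)


    def load_app_state(model, optimizer):
        # DCP loads in place, so first build the state_dict it will load into
        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
        state_dict = {"model": model_state_dict, "optimizer": optimizer_state_dict}
        dcp.load(state_dict=state_dict, checkpoint_id=CHECKPOINT_DIR)
        # write the loaded values back onto the model and optimizer
        set_state_dict(
            model,
            optimizer,
            model_state_dict=model_state_dict,
            optim_state_dict=optimizer_state_dict,
        )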