Updates dcp tutorial with recent updates to api including save, load, and distributed state dict #2832
@@ -1,7 +1,7 @@
Getting Started with Distributed Checkpoint (DCP)
=====================================================

- **Author**: `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__
+ **Author**: `Iris Zhang <https://github.com/wz337>`__, `Rodrigo Kumpera <https://github.com/kumpera>`__, `Chien-Chin Huang <https://github.com/fegin>`__, `Lucas Pasqualin <https://github.com/lucasllc>`__

.. note::
   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst>`__.
@@ -22,8 +22,12 @@ In this tutorial, we show how to use DCP APIs with a simple FSDP wrapped model.
How DCP works
--------------

- :func:`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel.
- In addition, checkpointing automatically handles fully-qualified-name (FQN) mappings across models and optimizers, enabling load-time resharding across differing cluster topologies.
+ :func:`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel. You can use this module to save on any number of ranks in parallel,
+ and then re-shard across differing cluster topologies at load time.
+
+ Additionally, through the use of modules in :func:`torch.distributed.checkpoint.state_dict`,
+ DCP offers support for gracefully handling ``state_dict`` generation and loading in distributed settings.
+ This includes managing fully-qualified-name (FQN) mappings across models and optimizers, and setting default parameters for PyTorch provided parallelisms.
Review comment: Optimizer initialization is a good point, but I was referencing setting the state dict type for FSDP, since users don't have to do this if they use distributed state dict.

DCP is different from :func:`torch.save` and :func:`torch.load` in a few significant ways:
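Taken together, these APIs reduce checkpointing to a short round trip. The following is a minimal single-process sketch (not part of this diff); it relies on the fallback described later in the tutorial, namely that DCP runs entirely in the current process when no process group is initialized:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

    CHECKPOINT_DIR = "checkpoint"

    model = nn.Linear(16, 8)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # take one step so the optimizer has state worth saving
    model(torch.rand(4, 16)).sum().backward()
    optimizer.step()

    # gather FQN-consistent state dicts for the model and optimizer, then save them
    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
    dcp.save({"model": model_state_dict, "optimizer": optimizer_state_dict}, checkpoint_id=CHECKPOINT_DIR)

    # loading: allocate state dicts first, load into them, then push them back onto the objects
    model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
    dcp.load({"model": model_state_dict, "optimizer": optimizer_state_dict}, checkpoint_id=CHECKPOINT_DIR)
    set_state_dict(model, optimizer, model_state_dict=model_state_dict, optim_state_dict=optimizer_state_dict)

The distributed, FSDP-wrapped version of this flow is what the hunks below walk through.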
@@ -42,19 +46,20 @@ Here we use a toy model wrapped with FSDP for demonstration purposes. Similarly,
Saving
~~~~~~

- Now, let’s create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.
+ Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.

.. code-block:: python

import os

import torch
import torch.distributed as dist
- import torch.distributed.checkpoint as DCP
+ import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from torch.distributed.checkpoint.state_dict import get_state_dict
- from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

CHECKPOINT_DIR = "checkpoint"
@@ -99,20 +104,14 @@ Now, let’s create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.
model(torch.rand(8, 16, device="cuda")).sum().backward()
optimizer.step()

- # set FSDP StateDictType to SHARDED_STATE_DICT so we can use DCP to checkpoint sharded model state dict
- # note that we do not support FSDP StateDictType.LOCAL_STATE_DICT
- FSDP.set_state_dict_type(
-     model,
-     StateDictType.SHARDED_STATE_DICT,
- )
+ # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
+ model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
state_dict = {
-     "model": model.state_dict(),
+     "model": model_state_dict,
+     "optimizer": optimizer_state_dict
}
+ dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

- DCP.save_state_dict(
-     state_dict=state_dict,
-     storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
- )

cleanup()
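The hunk above elides the process-group setup and the launcher that the tutorial wraps around this code. A rough sketch of that boilerplate follows; the helper names, the ``nccl`` backend, and the env-var rendezvous are assumptions rather than part of this diff:

.. code-block:: python

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def setup(rank, world_size):
        # assumed single-node rendezvous; adjust MASTER_ADDR/MASTER_PORT for your cluster
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

    def cleanup():
        dist.destroy_process_group()

    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        # run_fsdp_checkpoint_save_example is an assumed name for the function that calls
        # setup(), runs the saving code shown above, and then calls cleanup()
        mp.spawn(run_fsdp_checkpoint_save_example, args=(world_size,), nprocs=world_size, join=True)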
@@ -152,12 +151,12 @@ The reason that we need the ``state_dict`` prior to loading is:

import torch
import torch.distributed as dist
- import torch.distributed.checkpoint as DCP
+ import torch.distributed.checkpoint as dcp
+ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
- from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

CHECKPOINT_DIR = "checkpoint"
@@ -194,21 +193,23 @@ The reason that we need the ``state_dict`` prior to loading is:
model = ToyModel().to(rank)
model = FSDP(model)

- FSDP.set_state_dict_type(
-     model,
-     StateDictType.SHARDED_STATE_DICT,
- )
# different from ``torch.load()``, DCP requires model state_dict prior to loading to get
# the allocated storage and sharding information.
+ # generates the state dict we will load into
+ model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
state_dict = {
-     "model": model.state_dict(),
+     "model": model_state_dict,
+     "optimizer": optimizer_state_dict
}

- DCP.load_state_dict(
+ dcp.load(
    state_dict=state_dict,
-     storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
+     checkpoint_id=CHECKPOINT_DIR,
)
+ # sets our state dicts on the model and optimizer, now that we've loaded
+ set_state_dict(
+     model,
+     optimizer,
+     model_state_dict=model_state_dict,
+     optim_state_dict=optimizer_state_dict
+ )
- model.load_state_dict(state_dict["model"])

cleanup()
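Because DCP re-shards at load time, the checkpoint written above on one topology can be loaded on another. A rough sketch of launching the loading code with a different number of ranks (``run_fsdp_checkpoint_load_example`` is an assumed name for the function that wraps the loading code above):

.. code-block:: python

    import torch.multiprocessing as mp

    if __name__ == "__main__":
        # e.g. the checkpoint may have been written from 8 ranks; loading with 2 still works,
        # because DCP re-shards the saved tensors to match the current model's sharding
        load_world_size = 2
        mp.spawn(
            run_fsdp_checkpoint_load_example,  # assumed wrapper around the loading code above
            args=(load_world_size,),
            nprocs=load_world_size,
            join=True,
        )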
@@ -224,7 +225,8 @@ The reason that we need the ``state_dict`` prior to loading is:
)

If you would like to load the saved checkpoint into a non-FSDP wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP.
- By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data(SPMD) style. To load without a distributed setup, please set ``no_dist`` to ``True`` when loading with DCP.
+ By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers
+ the intent is to save or load in "non-distributed" style, meaning entirely in the current process.

.. note::
   Distributed checkpoint support for Multi-Program Multi-Data is still under development.
@@ -259,11 +261,10 @@ By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data (SPMD) style.
    "model": model.state_dict(),
}

- # turn no_dist to be true to load in non-distributed setting
- DCP.load_state_dict(
+ # since no process group is initialized, DCP will disable any collectives.
+ dcp.load(
    state_dict=state_dict,
-     storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
-     no_dist=True,
+     checkpoint_id=CHECKPOINT_DIR,
)
model.load_state_dict(state_dict["model"])
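For reference, the full non-distributed path is short end to end. A minimal sketch (the model construction is elided by the hunk above; ``ToyModel`` here is assumed to be the same toy model class used in the saving example):

.. code-block:: python

    import torch.distributed.checkpoint as dcp

    CHECKPOINT_DIR = "checkpoint"

    # build the same architecture that was checkpointed, but without FSDP and in a single process
    model = ToyModel()  # assumed: the toy model class from the saving example

    # allocate a state_dict to load into, then read the sharded checkpoint into it;
    # with no process group initialized, DCP runs entirely in the current process
    state_dict = {
        "model": model.state_dict(),
    }
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["model"])
    model.eval()  # e.g. ready for inference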
@@ -274,7 +275,9 @@ By default, DCP saves and loads a distributed ``state_dict`` in Single Program Multiple Data (SPMD) style.

Conclusion
----------
- In conclusion, we have learned how to use DCP's :func:`save_state_dict` and :func:`load_state_dict` APIs, as well as how they are different form :func:`torch.save` and :func:`torch.load`.
+ In conclusion, we have learned how to use DCP's :func:`save` and :func:`load` APIs, as well as how they are different from :func:`torch.save` and :func:`torch.load`.
+ Additionally, we've learned how to use :func:`get_state_dict` and :func:`set_state_dict` to automatically manage parallelism-specific FQNs and defaults during state dict
+ generation and loading.

For more information, please see the following:
Review comment: Should we also add "and different parallelisms"?