Remove activation checkpointing tag to get correct FQNs (pytorch#124698)
Fixes pytorch#124546

When setting `use_orig_params = False` and using activation checkpointing, the FQN mapping retrieved by the `_get_fqns` function is incorrect because the prefix added to the name of each activation-checkpointed module, `_checkpoint_wrapped_module`, can still be present. I think this is an edge case in `_get_fqns` that was not addressed by the earlier commit pytorch#118119.

Without the change, the list of object names for an activation-checkpointed module wrapped with FSDP (and `use_orig_params=False`) can look like:
```
['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_checkpoint_wrapped_module', '_flat_param']
```
This incorrectly returns just one FQN, `{'model.transformer.blocks.0._flat_param'}`, when the FQNs of all of the transformer block's parameters should be returned.

With the change, the list of object names will now have `_checkpoint_wrapped_module` removed:
```
['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_flat_param']
```
The FQNs are then correctly retrieved and returned by `_get_fqns` once [this condition](https://github.com/pytorch/pytorch/blob/ea61c9cb299b6dfebc57dc9d8821c34321d568ab/torch/distributed/checkpoint/state_dict.py#L168) is satisfied. The correct FQNs are:
```
{'model.transformer.blocks.0.attn.Wqkv.bias', 'model.transformer.blocks.0.ffn.up_proj.bias',
'model.transformer.blocks.0.attn.out_proj.weight', 'model.transformer.blocks.0.norm_2.weight',
'model.transformer.blocks.0.ffn.down_proj.weight', 'model.transformer.blocks.0.attn.Wqkv.weight',
'model.transformer.blocks.0.norm_2.bias', 'model.transformer.blocks.0.ffn.up_proj.weight',
'model.transformer.blocks.0.ffn.down_proj.bias', 'model.transformer.blocks.0.norm_1.bias',
'model.transformer.blocks.0.norm_1.weight', 'model.transformer.blocks.0.attn.out_proj.bias'}
```
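
As a quick standalone illustration (plain Python, not the actual torch internals; it assumes the `_CHECKPOINT_PREFIX` constant is the wrapper attribute name plus a trailing dot), stripping the prefix from the dotted name is what turns the first list above into the second:

```python
# Hypothetical value for illustration; the real constant lives in
# torch.distributed.algorithms._checkpoint.checkpoint_wrapper.
_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module."

name = (
    "model._fsdp_wrapped_module.transformer.blocks.0."
    "_fsdp_wrapped_module._checkpoint_wrapped_module._flat_param"
)

print(name.split("."))
# ['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0',
#  '_fsdp_wrapped_module', '_checkpoint_wrapped_module', '_flat_param']

cleaned = name.replace(_CHECKPOINT_PREFIX, "")
print(cleaned.split("."))
# ['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0',
#  '_fsdp_wrapped_module', '_flat_param']
```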

Pull Request resolved: pytorch#124698
Approved by: https://github.com/Skylion007
snarayan21 authored and pytorchmergebot committed Apr 24, 2024
1 parent 997d52c commit 0c572eb
Showing 1 changed file with 4 additions and 3 deletions: torch/distributed/checkpoint/state_dict.py
```diff
@@ -152,8 +152,11 @@ def _get_fqns(
     Returns:
         The canonical FQNs based on the model traversal.
     """
+
+    # Remove the checkpoint prefix, if it exists.
+    name = name.replace(_CHECKPOINT_PREFIX, "")
     if "." not in name:
-        return {name.replace(_CHECKPOINT_PREFIX, "")}
+        return {name}
 
     obj_names = name.split(".")
     fqn_obj_names = []
@@ -170,8 +173,6 @@ def _get_fqns(
             flat_param = getattr(curr_obj, FLAT_PARAM)
             if prefix:
                 prefix = f"{prefix}."
-            # FSDP already handles removal of checkpoint prefix, so we can return
-            # directly
             return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
         curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
         if curr_obj_name != FSDP_WRAPPED_MODULE:
```
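
To see why the linked condition is now satisfied, here is a small self-contained sketch (the constants and the helper are assumptions for this illustration, simplified from the real traversal): the per-parameter expansion branch is only reached when `_flat_param` immediately follows the FSDP wrapper name in the dotted path, which holds only once the checkpoint tag is stripped.

```python
# Simplified stand-in for the adjacency check in _get_fqns (illustrative only).
FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
FLAT_PARAM = "_flat_param"
_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module."

def reaches_flat_param_branch(name: str) -> bool:
    """Return True if '_flat_param' directly follows an FSDP wrapper name."""
    obj_names = name.split(".")
    return any(
        obj_name == FSDP_WRAPPED_MODULE
        and i < len(obj_names) - 1
        and obj_names[i + 1] == FLAT_PARAM
        for i, obj_name in enumerate(obj_names)
    )

name = (
    "model._fsdp_wrapped_module.transformer.blocks.0."
    "_fsdp_wrapped_module._checkpoint_wrapped_module._flat_param"
)
print(reaches_flat_param_branch(name))                                  # False
print(reaches_flat_param_branch(name.replace(_CHECKPOINT_PREFIX, "")))  # True
```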
