
Commit df651ff

snarayan21 authored and Peter Y. Yeh committed
Remove activation checkpointing tag to get correct FQNs (pytorch#124698)
Fixes pytorch#124546

When `use_orig_params = False` is set and activation checkpointing is used, the FQN mapping retrieved by the `_get_fqns` function is incorrect because the prefix added to the name of each activation-checkpointed module, `_checkpoint_wrapped_module`, can still be present. I think this is an edge case with the `_get_fqns` function that was not addressed by the previous commit pytorch#118119.

Without the change, the list of object names for an activation-checkpointed module with FSDP (and `use_orig_params=False`) can look like this:

```
['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_checkpoint_wrapped_module', '_flat_param']
```

This incorrectly returns just one FQN, `{'model.transformer.blocks.0._flat_param'}`, when all the FQNs of the parameters of the transformer block should be returned.

With the change, the list of object names has `_checkpoint_wrapped_module` removed:

```
['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_flat_param']
```

and the FQNs are correctly retrieved and returned by `_get_fqns` when [this condition](https://github.com/pytorch/pytorch/blob/ea61c9cb299b6dfebc57dc9d8821c34321d568ab/torch/distributed/checkpoint/state_dict.py#L168) is satisfied. The correct FQNs are:

```
{'model.transformer.blocks.0.attn.Wqkv.bias', 'model.transformer.blocks.0.ffn.up_proj.bias', 'model.transformer.blocks.0.attn.out_proj.weight', 'model.transformer.blocks.0.norm_2.weight', 'model.transformer.blocks.0.ffn.down_proj.weight', 'model.transformer.blocks.0.attn.Wqkv.weight', 'model.transformer.blocks.0.norm_2.bias', 'model.transformer.blocks.0.ffn.up_proj.weight', 'model.transformer.blocks.0.ffn.down_proj.bias', 'model.transformer.blocks.0.norm_1.bias', 'model.transformer.blocks.0.norm_1.weight', 'model.transformer.blocks.0.attn.out_proj.bias'}
```

Pull Request resolved: pytorch#124698
Approved by: https://github.com/Skylion007
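The core of the fix is a string-level cleanup before the dotted name is split into object names. Below is a minimal, self-contained Python sketch of that idea; the constant values and the `split_object_names` helper are illustrative stand-ins, not the private helpers in `torch/distributed/checkpoint/state_dict.py`.

```python
# Illustrative values; assumed to mirror the prefixes used by the FSDP and
# activation checkpointing wrappers (note the trailing dot on the AC prefix).
_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module."
FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"


def split_object_names(name: str) -> list[str]:
    """Strip the activation-checkpoint prefix, then split the dotted name."""
    # With the fix, the prefix is removed up front, so the FSDP-specific
    # traversal later on sees a clean path ending in '_flat_param'.
    return name.replace(_CHECKPOINT_PREFIX, "").split(".")


name = (
    "model._fsdp_wrapped_module.transformer.blocks.0."
    "_fsdp_wrapped_module._checkpoint_wrapped_module._flat_param"
)
print(split_object_names(name))
# ['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0',
#  '_fsdp_wrapped_module', '_flat_param']
```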
1 parent 9ae9be2 commit df651ff

File tree

1 file changed: +4 -3 lines


torch/distributed/checkpoint/state_dict.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -152,8 +152,11 @@ def _get_fqns(
     Returns:
         The canonical FQNs based on the model traversal.
     """
+
+    # Remove the checkpoint prefix, if it exists.
+    name = name.replace(_CHECKPOINT_PREFIX, "")
     if "." not in name:
-        return {name.replace(_CHECKPOINT_PREFIX, "")}
+        return {name}
 
     obj_names = name.split(".")
     fqn_obj_names = []
@@ -170,8 +173,6 @@ def _get_fqns(
                 flat_param = getattr(curr_obj, FLAT_PARAM)
                 if prefix:
                     prefix = f"{prefix}."
-                # FSDP already handles removal of checkpoint prefix, so we can return
-                # directly
                 return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
             curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
             if curr_obj_name != FSDP_WRAPPED_MODULE:
```
