From 0c572ebd506e81b84c74515e81539753093c9bfe Mon Sep 17 00:00:00 2001 From: Saaketh Date: Wed, 24 Apr 2024 16:47:50 +0000 Subject: [PATCH] Remove activation checkpointing tag to get correct FQNs (#124698) Fixes #124546 When setting `use_orig_params = False` and using activation checkpointing, the FQN mapping as retrieved by the `_get_fqns` function is incorrect because the prefix that is added to the name of each activation checkpointed module, `_checkpoint_wrapped_module`, can still be present. I think this is an edge case with the `_get_fqns` function that was not addressed by this previous commit #118119. Without the change, the list of object names for an activation checkpointed module with FSDP (and `use_orig_params=False`) can be something like: ``` ['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_checkpoint_wrapped_module', '_flat_param'] ``` Which will incorrectly return just one FQN, `{'model.transformer.blocks.0._flat_param'}`, when all the FQNs of the parameters of the transformer block should be returned. With the change, the list of object names will now have `_checkpoint_wrapped_module` removed: ``` ['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_flat_param'] ``` And the FQNs are correctly retrieved and returned in `_get_fqns` when [this condition](https://github.com/pytorch/pytorch/blob/ea61c9cb299b6dfebc57dc9d8821c34321d568ab/torch/distributed/checkpoint/state_dict.py#L168) is satisfied. The correct FQNs are: ``` {'model.transformer.blocks.0.attn.Wqkv.bias', 'model.transformer.blocks.0.ffn.up_proj.bias', 'model.transformer.blocks.0.attn.out_proj.weight', 'model.transformer.blocks.0.norm_2.weight', 'model.transformer.blocks.0.ffn.down_proj.weight', 'model.transformer.blocks.0.attn.Wqkv.weight', 'model.transformer.blocks.0.norm_2.bias', 'model.transformer.blocks.0.ffn.up_proj.weight', 'model.transformer.blocks.0.ffn.down_proj.bias', 'model.transformer.blocks.0.norm_1.bias', 'model.transformer.blocks.0.norm_1.weight', 'model.transformer.blocks.0.attn.out_proj.bias'} ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/124698 Approved by: https://github.com/Skylion007 --- torch/distributed/checkpoint/state_dict.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index a8f8216057a82..5215eb436590f 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -152,8 +152,11 @@ def _get_fqns( Returns: The canonical FQNs based on the model traversal. """ + + # Remove the checkpoint prefix, if it exists. + name = name.replace(_CHECKPOINT_PREFIX, "") if "." not in name: - return {name.replace(_CHECKPOINT_PREFIX, "")} + return {name} obj_names = name.split(".") fqn_obj_names = [] @@ -170,8 +173,6 @@ def _get_fqns( flat_param = getattr(curr_obj, FLAT_PARAM) if prefix: prefix = f"{prefix}." - # FSDP already handles removal of checkpoint prefix, so we can return - # directly return {f"{prefix}{fqn}" for fqn in flat_param._fqns} curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE) if curr_obj_name != FSDP_WRAPPED_MODULE: