
[DCP] Removes Checkpoint Wrapped Prefix from state dict fqns #118119

Closed · wants to merge 6 commits

Conversation

@LucasLLC (Contributor) commented Jan 23, 2024

Fixes #117399

Soliciting some early feedback here.

Do we happen to know if there are already tests that cover this case, or would it make sense to add some? @fegin, @wz337

Edit: Added tests

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225

pytorch-bot commented Jan 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118119

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3bbd358 with merge base abe3c55:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions added the oncall: distributed label Jan 23, 2024
@fegin (Contributor) commented Jan 23, 2024

You can check https://github.com/pytorch/pytorch/blob/main/test/distributed/fsdp/test_fsdp_optim_state.py#L620. But remember NOT to use FSDP to wrap the model, as FSDP will handle the prefix for you, so you won't be able to test this logic if FSDP is used.
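
For reference, a minimal sketch of such a test, assuming the private helpers `checkpoint_wrapper` and `_get_fqns` at their current import paths (the toy model is hypothetical); the submodule is activation-checkpointed directly, without FSDP, so the prefix-stripping logic is actually exercised:

```
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
# Sketch only: _get_fqns is a private helper and may change.
from torch.distributed.checkpoint.state_dict import _get_fqns

# Activation-checkpoint a submodule directly (no FSDP), so nothing
# strips the _checkpoint_wrapped_module prefix for us.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
model[0] = checkpoint_wrapper(model[0])

# Raw parameter names still carry the wrapper prefix ...
assert any("_checkpoint_wrapped_module" in n for n, _ in model.named_parameters())

# ... but the FQNs computed by _get_fqns should not.
for name, _ in model.named_parameters():
    for fqn in _get_fqns(model, name):
        assert "_checkpoint_wrapped_module" not in fqn
```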

@LucasLLC marked this pull request as ready for review January 25, 2024 16:26
@fegin (Contributor) left a comment


Added one comment. Everything else LGTM.

Comment on lines 462 to 464
for model, name in model_names:
    for fqn in _get_fqns(model, name):
        self.assertNotIn(_CHECKPOINT_WRAPPED_MODULE, fqn)

We should just call get_state_dict() and compare its keys with those of the original model without activation checkpointing (you can deepcopy the original model).
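
A minimal sketch of that comparison, assuming the `get_state_dict` API from `torch.distributed.checkpoint.state_dict` and a hypothetical toy model:

```
import copy

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
from torch.distributed.checkpoint.state_dict import get_state_dict

# Deep-copy the model before wrapping, keeping a reference that never
# sees activation checkpointing.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
reference = copy.deepcopy(model)
model[0] = checkpoint_wrapper(model[0])

# get_state_dict returns (model_state_dict, optimizer_state_dict); no
# optimizers are passed here, so only the model state dict matters.
wrapped_sd, _ = get_state_dict(model, optimizers=[])
reference_sd, _ = get_state_dict(reference, optimizers=[])

# The wrapped model's keys should match the unwrapped reference exactly.
assert sorted(wrapped_sd) == sorted(reference_sd)
```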

@LucasLLC

@pytorchbot merge

@pytorch-bot added the ciflow/trunk label Jan 26, 2024
@pytorchmergebot

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for Dev Infra team (raised by workflow job)

@LucasLLC

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced debugging: check the merge workflow status here.

@pytorchmergebot

Merge failed

Reason: 1 job has failed: Check mergeability and dependencies for ghstack prs / ghstack-mergeability-check

Details for Dev Infra team (raised by workflow job)

@LucasLLC

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced debugging: check the merge workflow status here.

jeffdaily pushed a commit to ROCm/pytorch that referenced this pull request Feb 8, 2024
…#118119)

Fixes pytorch#117399

~~Soliciting some early feedback here.~~

~~Do we happen to know if there are already tests that cover this case, or would it make sense to add some? @fegin, @wz337~~

Edit: Added tests

Pull Request resolved: pytorch#118119
Approved by: https://github.com/fegin
@github-actions deleted the remove_checkpoint_wrapped_from_fqn branch February 29, 2024 02:08
pytorchmergebot pushed a commit that referenced this pull request Apr 24, 2024
Fixes #124546

When `use_orig_params = False` is set and activation checkpointing is used, the FQN mapping retrieved by the `_get_fqns` function is incorrect because `_checkpoint_wrapped_module`, the prefix added to the name of each activation-checkpointed module, can still be present. I think this is an edge case in `_get_fqns` that was not addressed by the previous commit #118119.

Without the change, the list of object names for an activation checkpointed module with FSDP (and `use_orig_params=False`) can be something like:
```
['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_checkpoint_wrapped_module', '_flat_param']
```
This incorrectly returns just one FQN, `{'model.transformer.blocks.0._flat_param'}`, when the FQNs of all of the transformer block's parameters should be returned.

With the change, the list of object names will now have `_checkpoint_wrapped_module` removed:
```
['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_flat_param']
```
And the FQNs are correctly retrieved and returned in `_get_fqns` when [this condition](https://github.com/pytorch/pytorch/blob/ea61c9cb299b6dfebc57dc9d8821c34321d568ab/torch/distributed/checkpoint/state_dict.py#L168) is satisfied. The correct FQNs are:
```
{'model.transformer.blocks.0.attn.Wqkv.bias', 'model.transformer.blocks.0.ffn.up_proj.bias',
'model.transformer.blocks.0.attn.out_proj.weight', 'model.transformer.blocks.0.norm_2.weight',
'model.transformer.blocks.0.ffn.down_proj.weight', 'model.transformer.blocks.0.attn.Wqkv.weight',
'model.transformer.blocks.0.norm_2.bias', 'model.transformer.blocks.0.ffn.up_proj.weight',
'model.transformer.blocks.0.ffn.down_proj.bias', 'model.transformer.blocks.0.norm_1.bias',
'model.transformer.blocks.0.norm_1.weight', 'model.transformer.blocks.0.attn.out_proj.bias'}
```

Pull Request resolved: #124698
Approved by: https://github.com/Skylion007
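
To make the commit's fix concrete, here is a simplified stand-in (hypothetical, not the actual `_get_fqns` implementation) that drops both wrapper names from an object-name path:

```
# Hypothetical, simplified stand-in for the prefix handling in _get_fqns.
FSDP_PREFIX = "_fsdp_wrapped_module"
CHECKPOINT_PREFIX = "_checkpoint_wrapped_module"

def clean_fqn(obj_names):
    # Drop both wrapper names so only user-visible module names remain.
    return ".".join(
        name for name in obj_names if name not in (FSDP_PREFIX, CHECKPOINT_PREFIX)
    )

names = [
    "model", "_fsdp_wrapped_module", "transformer", "blocks", "0",
    "_fsdp_wrapped_module", "_checkpoint_wrapped_module", "_flat_param",
]
print(clean_fqn(names))  # -> model.transformer.blocks.0._flat_param
```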
pytorchmergebot pushed a commit to xuhancn/pytorch that referenced this pull request Apr 24, 2024
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
mvpatel2000 pushed a commit to mvpatel2000/pytorch that referenced this pull request May 17, 2024
atalman pushed a commit that referenced this pull request May 24, 2024
…26559)


Co-authored-by: Saaketh <narayan.saaketh@gmail.com>
Labels

ciflow/trunk · Merged · module: distributed_checkpoint · oncall: distributed · topic: not user facing

Successfully merging this pull request may close these issues:

CHECKPOINT_PREFIX is not stripped when non-root module is activation checkpointed