[FSDP] Support unfreezing params for reshard-only hook #104186
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/104186
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 2b5c8f1. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks for the timely and adroit (as always!) fix/enhancement @awgu! 🎉 🚀
thanks for the quick fix!
for param in seq[i * 2].parameters(recurse=False):
    param.requires_grad = False
return seq
…
for param in seq[i * 2].parameters(recurse=True):
Isn't `.parameters(recurse=True)` the default?
Yep, this `recurse=True` is unnecessary. I will leave it to avoid re-triggering CI.
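For reference, here is a quick standalone check (not part of the PR; the toy `seq` module is illustrative) confirming that `recurse=True` is the default for `nn.Module.parameters()`:

```python
import torch.nn as nn

seq = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4)))

# parameters() recurses into submodules by default, so passing
# recurse=True explicitly is redundant.
assert list(seq.parameters()) == list(seq.parameters(recurse=True))

# recurse=False yields only the module's direct parameters; a Sequential
# owns none itself (they all live in its child Linear modules).
assert list(seq.parameters(recurse=False)) == []
```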
def _reset_flat_param_grad_info_if_needed(self):
    """
    When ``use_orig_params=True``:
    (1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the …
Does this mean we need a unit test to ensure that `flat_param.grad` is set to `None` appropriately?
This is the existing behavior, which we already have unit tests for in `test_fsdp_use_orig_params.py`.
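For readers following along, here is a minimal, hypothetical sketch of the behavior the docstring describes; it is not FSDP's actual implementation, and `reset_flat_param_grad_info_if_needed`, `flat_param`, and `orig_params` are stand-in names:

```python
import torch
import torch.nn as nn

def reset_flat_param_grad_info_if_needed(flat_param, orig_params):
    # (1) If *all* original parameters have .grad set to None, drop the
    # flat parameter's gradient as well so its memory can be freed.
    if all(p.grad is None for p in orig_params):
        flat_param.grad = None
    # The rename in this PR reflects an added responsibility (assumed here
    # to use any()): propagate the original parameters' requires_grad
    # to the flat parameter.
    flat_param.requires_grad = any(p.requires_grad for p in orig_params)

# Toy usage: two frozen originals imply the flat parameter needs no grad.
flat_param = torch.zeros(8)
orig_params = [nn.Parameter(torch.zeros(4), requires_grad=False) for _ in range(2)]
reset_flat_param_grad_info_if_needed(flat_param, orig_params)
assert flat_param.grad is None and not flat_param.requires_grad
```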
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed: trunk / win-vs2019-cuda11.8-py3 / build. Raised by workflow job; details for the Dev Infra team.
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):

This fixes #104148 (unfreezing parameters after `n` steps).

- Adds support for the reshard-only post-backward hook in the `requires_grad=False` case.
- Makes `already_resharded` correct for `SHARD_GRAD_OP`.
- Renames `_clear_grads_if_needed()` to `_reset_flat_param_grad_info_if_needed()` to additionally include propagating the original parameters' `requires_grad` to the flat parameter.

A sketch of the unfreeze-after-`n`-steps pattern appears after this list.
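For context, the following sketch shows the unfreeze-after-`n`-steps pattern from #104148 that this stack fixes. It is illustrative, not the PR's test code: it assumes launch via torchrun with a NCCL process group, and the toy model, dimensions, and `unfreeze_step` value are made up.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="nccl")  # torchrun provides the env vars
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

seq = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8)).cuda()
for param in seq[0].parameters():
    param.requires_grad = False  # start training with the first layer frozen

# use_orig_params=True keeps the original parameters visible, so their
# requires_grad flags must be propagated to the underlying flat parameter.
model = FSDP(seq, use_orig_params=True)
optim = torch.optim.SGD(model.parameters(), lr=1e-2)  # None grads are skipped

unfreeze_step = 2  # illustrative "n"
for step in range(4):
    if step == unfreeze_step:
        for param in model.parameters():
            param.requires_grad = True  # unfreeze mid-training
    loss = model(torch.randn(4, 8, device="cuda")).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
```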