[FSDP] Propagate requires_grad attribute to unsharded params #109892
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109892
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 168c5d2 with merge base 92de1d3. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D49517155
This makes sense to me!
[FSDP] Propagate requires_grad attribute to unsharded params (#109892) force-pushed from 2a94937 to 9bc05e0 (compare)
Force-pushed from 9bc05e0 to 168c5d2 (compare)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed, the first few of them being: trunk / linux-focal-rocm5.6-py3.8 / test (default, 2, 3, linux.rocm.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…checkpoints

Summary: Pull Request resolved: #620

EMA can be configured to exclude frozen (`requires_grad=False`) parameters and buffers, reducing memory use and checkpoint size. However, `FULL_STATE_DICT` FSDP + EMA checkpoints construct an inner `EMAState` after unsharding the FSDP parameters, and this inner `EMAState` uses the default `include_frozen` and `include_buffers` settings, so checkpoints contain frozen parameters and buffers regardless of the configured settings.

This change propagates the `include_frozen` and `include_buffers` settings to the inner `EMAState` when gathering `FULL_STATE_DICT` FSDP EMA state. For frozen parameters it only takes effect together with a parallel fix to PyTorch FSDP that propagates `requires_grad` across parameter sharding/unsharding: pytorch/pytorch#109892.

Reviewed By: daveboat

Differential Revision: D49517178

fbshipit-source-id: 0fe159dcec9ec1f2c456ae2ee7798681e7536249
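A minimal sketch of the pattern this fix applies is below. Only the names `EMAState`, `include_frozen`, and `include_buffers` come from the commit message; the stand-in class, its `save_from` method, and the `gather_full_ema_state` helper are hypothetical, since the library's real API is not shown in this thread.

```python
# Hypothetical sketch of the fix described above. Only EMAState,
# include_frozen, and include_buffers are named in the commit message;
# this stand-in class and helper are illustrative, not the real API.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


class EMAState:
    """Stand-in EMA container honoring include_frozen/include_buffers."""

    def __init__(self, include_frozen=True, include_buffers=True):
        self.include_frozen = include_frozen
        self.include_buffers = include_buffers
        self._state = {}

    def save_from(self, model):
        # Skip frozen parameters unless configured to include them.
        for name, param in model.named_parameters():
            if self.include_frozen or param.requires_grad:
                self._state[name] = param.detach().clone()
        if self.include_buffers:
            for name, buf in model.named_buffers():
                self._state[name] = buf.detach().clone()

    def state_dict(self):
        return dict(self._state)


def gather_full_ema_state(outer, fsdp_model):
    """Gather FULL_STATE_DICT EMA state, propagating the outer settings
    to the inner EMAState instead of letting them default (the bug)."""
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        inner = EMAState(
            include_frozen=outer.include_frozen,    # previously defaulted
            include_buffers=outer.include_buffers,  # previously defaulted
        )
        inner.save_from(fsdp_model)
        return inner.state_dict()
```

The `param.requires_grad` check inside `save_from` is why the parallel FSDP fix matters: without pytorch/pytorch#109892, unsharded parameters reported `requires_grad=True`, so frozen parameters could never be filtered out of the gathered state.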
[FSDP] Propagate requires_grad attribute to unsharded params (#109892)

Summary: This preserves `requires_grad` in the case where all parameters within a `FlatParameter` have the same `requires_grad` value. Currently, unsharded parameters have `requires_grad=True` in some cases where the `FlatParameter` and all original parameters have `requires_grad=False`. This could be extended to support `FlatParameter`s with a mix of `requires_grad` states by extending `ParamInfo` to capture `requires_grad` for each parameter.

Test Plan: test added

Differential Revision: D49517155

Pull Request resolved: pytorch#109892

Approved by: https://github.com/awgu
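A minimal single-process sketch of the behavior being fixed (not the PR's actual test; it assumes a CPU `gloo` process group, which suffices for a toy check, whereas real FSDP runs typically use NCCL on GPUs):

```python
# Minimal sketch: freeze every original parameter, wrap in FSDP, then
# inspect the unsharded views. Before this PR those views could report
# requires_grad=True even though the FlatParameter and all original
# parameters had requires_grad=False; with it, they report False.
import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)  # toy CPU setup

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
for param in model.parameters():
    param.requires_grad = False  # all original parameters frozen

fsdp_model = FSDP(model)  # one FlatParameter with requires_grad=False

# summon_full_params materializes the unsharded original parameters.
with FSDP.summon_full_params(fsdp_model):
    for name, param in fsdp_model.named_parameters():
        print(name, param.requires_grad)  # expected: False everywhere

dist.destroy_process_group()
```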