
[FSDP] Propagate requires_grad attribute to unsharded params #109892

Closed · wants to merge 1 commit

Conversation

@edpizzi (Contributor) commented Sep 22, 2023

Summary:
This preserves `requires_grad` in the case where all parameters within a `FlatParameter` have the same `requires_grad` value.

Currently, unsharded parameters have `requires_grad=True` in some cases where the `FlatParameter` and all original parameters have `requires_grad=False`.

This could be extended to support `FlatParameter`s with a mix of `requires_grad` states by extending `ParamInfo` to capture `requires_grad` for each parameter.

Test Plan: test added

Differential Revision: D49517155
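For illustration, here is a minimal, self-contained sketch of the behavior this PR targets. It is not the actual FSDP internals; the helper `unshard_views`, its arguments, and the flat-parameter layout are illustrative assumptions. The only point is that the unsharded per-parameter views inherit `requires_grad` from the `FlatParameter` instead of defaulting to `requires_grad=True`.

```python
import torch
import torch.nn as nn

def unshard_views(flat_param: nn.Parameter, shapes, numels):
    """Split a flat parameter into per-parameter views and propagate requires_grad.

    Illustrative sketch only; not the real FSDP unshard path.
    """
    views = []
    offset = 0
    for shape, numel in zip(shapes, numels):
        view = flat_param.detach()[offset:offset + numel].view(shape)
        # Propagate requires_grad from the FlatParameter; without this, the
        # views come back with the default requires_grad=True even when every
        # original parameter was frozen.
        view.requires_grad_(flat_param.requires_grad)
        views.append(view)
        offset += numel
    return views

# Example: two frozen parameters flattened into one "FlatParameter".
p1, p2 = torch.randn(2, 3), torch.randn(4)
flat = nn.Parameter(torch.cat([p1.flatten(), p2.flatten()]), requires_grad=False)
for v in unshard_views(flat, [p1.shape, p2.shape], [p1.numel(), p2.numel()]):
    assert not v.requires_grad  # frozen state is preserved after "unsharding"
```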

@pytorch-bot (bot) commented Sep 22, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109892

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 168c5d2 with merge base 92de1d3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla (bot) commented Sep 22, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: edpizzi / name: Ed Pizzi (168c5d2)

@pytorch-bot added the `release notes: distributed (fsdp)` label Sep 22, 2023
@facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D49517155

@awgu (Contributor) left a comment

This makes sense to me!

edpizzi added a commit to edpizzi/pytorch that referenced this pull request Sep 22, 2023
…#109892)

Summary:

This preserves `requires_grad` in the case where all parameters within a `FlatParameter` have the same `requires_grad` value.

Currently, unsharded parameters have `requires_grad=True` in some cases where the `FlatParameter` and all original parameters have `requires_grad=False`.

This could be extended to support `FlatParameters` with a mix of `requires_grad` states by extending `ParamInfo` to capture `requires_grad` for each parameter.

Test Plan: test added

Reviewed By: awgu

Differential Revision: D49517155
@facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D49517155


@fegin added the `ciflow/trunk` label Sep 23, 2023
@edpizzi (Contributor, Author) commented Sep 23, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented:
Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-focal-rocm5.6-py3.8 / test (default, 2, 3, linux.rocm.gpu)

Details for Dev Infra team · Raised by workflow job

@edpizzi (Contributor, Author) commented Sep 24, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

edpizzi added a commit to edpizzi/d2go that referenced this pull request Sep 25, 2023
…checkpoints

Summary:
EMA can be configured to exclude frozen (`requires_grad=False`) parameters and buffers, reducing memory use and checkpoint size.

However, `FULL_STATE_DICT` FSDP + EMA checkpoints construct an inner `EMAState` after unsharding FSDP parameters. This inner `EMAState` uses the default `include_frozen` and `include_buffers` settings, resulting in checkpoints that contain frozen parameters and buffers regardless of the configured settings.

Propagate `include_frozen` and `include_buffers` settings to the inner `EMAState` when gathering `FULL_STATE_DICT` FSDP EMA state.

This change only takes effect for frozen parameters in combination with a parallel fix to PyTorch FSDP that propagates `requires_grad` across parameter sharding/unsharding: pytorch/pytorch#109892.

Differential Revision: D49517178

fbshipit-source-id: fefc9bb898a93c1746dad02007ca5b3384b0fcd6
facebook-github-bot pushed a commit to facebookresearch/d2go that referenced this pull request Sep 25, 2023
…checkpoints

Summary:
Pull Request resolved: #620

EMA can be configured to exclude frozen (`requires_grad=False`) parameters and buffers, reducing memory use and checkpoint size.

However, `FULL_STATE_DICT` FSDP + EMA checkpoints construct an inner `EMAState` after unsharding FSDP parameters. This inner `EMAState` uses the default `include_frozen` and `include_buffers` settings, resulting in checkpoints that contain frozen parameters and buffers regardless of the configured settings.

Propagate `include_frozen` and `include_buffers` settings to the inner `EMAState` when gathering `FULL_STATE_DICT` FSDP EMA state.

This change only takes effect for frozen parameters in combination with a parallel fix to PyTorch FSDP that propagates `requires_grad` across parameter sharding/unsharding: pytorch/pytorch#109892.

Reviewed By: daveboat

Differential Revision: D49517178

fbshipit-source-id: 0fe159dcec9ec1f2c456ae2ee7798681e7536249
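For context, a hypothetical, minimal sketch of the d2go-side pattern described in the commits above; `ToyEMAState` and `gather_full_ema_state` are illustrative stand-ins, not the real d2go `EMAState` API. It shows the `include_frozen` and `include_buffers` settings being propagated to the inner state object rather than falling back to defaults, which only excludes frozen parameters correctly once the FSDP fix in this PR preserves `requires_grad=False` on unsharded parameters.

```python
import torch.nn as nn

class ToyEMAState:
    """Stand-in for an EMA state container (illustrative only, not the d2go API)."""
    def __init__(self, include_frozen: bool = True, include_buffers: bool = True):
        self.include_frozen = include_frozen
        self.include_buffers = include_buffers
        self.state = {}

    def load_from(self, model: nn.Module):
        # Skip frozen parameters unless include_frozen is set; this relies on
        # unsharded FSDP parameters keeping requires_grad=False (the fix above).
        for name, p in model.named_parameters():
            if self.include_frozen or p.requires_grad:
                self.state[name] = p.detach().clone()
        if self.include_buffers:
            for name, b in model.named_buffers():
                self.state[name] = b.detach().clone()

def gather_full_ema_state(outer: ToyEMAState, unsharded_model: nn.Module) -> dict:
    # The fix in spirit: build the inner state with the outer settings instead
    # of the defaults, so frozen parameters and buffers stay out of the checkpoint.
    inner = ToyEMAState(include_frozen=outer.include_frozen,
                        include_buffers=outer.include_buffers)
    inner.load_from(unsharded_model)
    return inner.state

# Usage with a toy model whose parameters are all frozen:
model = nn.Linear(4, 2)
for p in model.parameters():
    p.requires_grad_(False)
outer = ToyEMAState(include_frozen=False, include_buffers=False)
print(gather_full_ema_state(outer, model))  # {} -- frozen params excluded
```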
@edpizzi edpizzi deleted the export-D49517155 branch September 25, 2023 21:27
ringohoffman pushed a commit to ringohoffman/pytorch that referenced this pull request Sep 27, 2023
…#109892)

Summary:
This preserves `requires_grad` in the case where all parameters within a `FlatParameter` have the same `requires_grad` value.

Currently, unsharded parameters have `requires_grad=True` in some cases where the `FlatParameter` and all original parameters have `requires_grad=False`.

This could be extended to support `FlatParameters` with a mix of `requires_grad` states by extending `ParamInfo` to capture `requires_grad` for each parameter.

Test Plan: test added

Differential Revision: D49517155

Pull Request resolved: pytorch#109892
Approved by: https://github.com/awgu
Labels: ciflow/trunk (Trigger trunk jobs on your pull request), fb-exported, Merged, release notes: distributed (fsdp)

5 participants