Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDP] Fix FSDP.clip_grad_norm_() for NO_SHARD #88955

Closed
wants to merge 1 commit into from

Conversation

awgu
Copy link
Contributor

@awgu awgu commented Nov 12, 2022

Stack from ghstack:

This PR fixes FSDP.clip_grad_norm_() for NO_SHARD, which previously "double-counted" each gradient world_size-many times.

This does not address any discrepancies between FULL_SHARD and DDP. (Note that the unit tests do show parity between FULL_SHARD and DDP when using FSDP.clip_grad_norm_() and nn.utils.clip_grad_norm_() respectively on one iteration.)

The added unit test code path tests mixing nested FSDP instances with both FULL_SHARD and NO_SHARD to ensure that the local_sharded_norm and local_nonsharded_norm computations are interoperating correctly. I want to test non-FSDP root instance in the future, but this is BC breaking since we need to make clip_grad_norm_() a static method, which would require a different method call syntax (FSDP.clip_grad_norm_(root_module, ...) vs. root_module.clip_grad_norm_(...)).

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 12, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88955

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 55e476f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

awgu added a commit that referenced this pull request Nov 12, 2022
ghstack-source-id: 17065b44c859f14642d371a37198bdd556220e19
Pull Request resolved: #88955
@awgu awgu added the topic: improvements topic category label Nov 12, 2022
@@ -1161,23 +1161,45 @@ def clip_grad_norm_(
self._streams["unshard"],
self._streams["pre_unshard"],
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For NO_SHARD, wondering whether we can just call the default PyTorch clip_grad_norm_() and early return without all_reduce and following logics?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we could, but this version enables us to mix and match NO_SHARD and FULL_SHARD for different submodules. If the user knows the entire FSDP instance is only NO_SHARD, then they can just use torch.nn.utils.clip_grad_norm_().

I think some of this logic in FSDP.clip_grad_norm_() will be useful when thinking about how to write clip_grad_norm_() for our composable APIs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can optionally do a check over all FSDP instances in the module tree, and if all are NO_SHARD, then we can return nn.utils.clip_grad_norm_() directly as you suggested. Maybe I can include that fast path in a follow-up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, that makes sense!

@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 12, 2022
@awgu
Copy link
Contributor Author

awgu commented Nov 13, 2022

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
This PR fixes `FSDP.clip_grad_norm_()` for `NO_SHARD`, which previously "double-counted" each gradient `world_size`-many times.

This does not address any discrepancies between `FULL_SHARD` and DDP. (Note that the unit tests do show parity between `FULL_SHARD` and DDP when using `FSDP.clip_grad_norm_()` and `nn.utils.clip_grad_norm_()` respectively on one iteration.)

The added unit test code path tests mixing nested FSDP instances with both `FULL_SHARD` and `NO_SHARD` to ensure that the `local_sharded_norm` and `local_nonsharded_norm` computations are interoperating correctly. I want to test non-FSDP root instance in the future, but this is BC breaking since we need to make `clip_grad_norm_()` a static method, which would require a different method call syntax (`FSDP.clip_grad_norm_(root_module, ...)` vs. `root_module.clip_grad_norm_(...)`).
Pull Request resolved: pytorch#88955
Approved by: https://github.com/zhaojuanmao
@facebook-github-bot facebook-github-bot deleted the gh/awgu/199/head branch June 8, 2023 15:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (fsdp) release notes category topic: improvements topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants