-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FSDP] Fix FSDP.clip_grad_norm_()
for NO_SHARD
#88955
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88955
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 55e476f: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: 17065b44c859f14642d371a37198bdd556220e19 Pull Request resolved: #88955
@@ -1161,23 +1161,45 @@ def clip_grad_norm_( | |||
self._streams["unshard"], | |||
self._streams["pre_unshard"], | |||
) | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For NO_SHARD, wondering whether we can just call the default PyTorch clip_grad_norm_() and early return without all_reduce and following logics?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we could, but this version enables us to mix and match NO_SHARD
and FULL_SHARD
for different submodules. If the user knows the entire FSDP instance is only NO_SHARD
, then they can just use torch.nn.utils.clip_grad_norm_()
.
I think some of this logic in FSDP.clip_grad_norm_()
will be useful when thinking about how to write clip_grad_norm_()
for our composable APIs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can optionally do a check over all FSDP instances in the module tree, and if all are NO_SHARD
, then we can return nn.utils.clip_grad_norm_()
directly as you suggested. Maybe I can include that fast path in a follow-up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, that makes sense!
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR fixes `FSDP.clip_grad_norm_()` for `NO_SHARD`, which previously "double-counted" each gradient `world_size`-many times. This does not address any discrepancies between `FULL_SHARD` and DDP. (Note that the unit tests do show parity between `FULL_SHARD` and DDP when using `FSDP.clip_grad_norm_()` and `nn.utils.clip_grad_norm_()` respectively on one iteration.) The added unit test code path tests mixing nested FSDP instances with both `FULL_SHARD` and `NO_SHARD` to ensure that the `local_sharded_norm` and `local_nonsharded_norm` computations are interoperating correctly. I want to test non-FSDP root instance in the future, but this is BC breaking since we need to make `clip_grad_norm_()` a static method, which would require a different method call syntax (`FSDP.clip_grad_norm_(root_module, ...)` vs. `root_module.clip_grad_norm_(...)`). Pull Request resolved: pytorch#88955 Approved by: https://github.com/zhaojuanmao
Stack from ghstack:
FSDP.clip_grad_norm_()
forNO_SHARD
#88955 [FSDP] FixFSDP.clip_grad_norm_()
forNO_SHARD
ModuleWrapPolicy
for simplicity #88450 [FSDP] IntroduceModuleWrapPolicy
for simplicityThis PR fixes
FSDP.clip_grad_norm_()
forNO_SHARD
, which previously "double-counted" each gradientworld_size
-many times.This does not address any discrepancies between
FULL_SHARD
and DDP. (Note that the unit tests do show parity betweenFULL_SHARD
and DDP when usingFSDP.clip_grad_norm_()
andnn.utils.clip_grad_norm_()
respectively on one iteration.)The added unit test code path tests mixing nested FSDP instances with both
FULL_SHARD
andNO_SHARD
to ensure that thelocal_sharded_norm
andlocal_nonsharded_norm
computations are interoperating correctly. I want to test non-FSDP root instance in the future, but this is BC breaking since we need to makeclip_grad_norm_()
a static method, which would require a different method call syntax (FSDP.clip_grad_norm_(root_module, ...)
vs.root_module.clip_grad_norm_(...)
).