[FSDP] Skip `_use_sharded_views()` for `SHARD_GRAD_OP` #98250
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98250
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 5e3aacc.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: bfec3b5c298bd1fb4ade5843039205ed28c35b1b Pull Request resolved: #98250
ghstack-source-id: 6c1d4fcf0987271375074c1c70149c5223c23cef Pull Request resolved: #98250
LGTM, thanks!
If we're concerned about blast radius here, feel free to gate with an env var before landing.
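For reference, such a gate could be a one-line environment-variable check; the sketch below is illustrative only (the variable name `FSDP_SKIP_SHARDED_VIEWS` and the helper are hypothetical and not part of this PR):

```python
import os

# Hypothetical opt-out switch (not part of this PR): default to the new
# behavior, but allow disabling it via FSDP_SKIP_SHARDED_VIEWS=0.
_SKIP_SHARDED_VIEWS_ENABLED = os.environ.get("FSDP_SKIP_SHARDED_VIEWS", "1") == "1"


def should_skip_sharded_views(in_forward: bool, keeps_unsharded_flat_param: bool) -> bool:
    """Return True if the post-forward reshard may skip `_use_sharded_views()`."""
    return _SKIP_SHARDED_VIEWS_ENABLED and in_forward and keeps_unsharded_flat_param
```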
torch/distributed/fsdp/flat_param.py (Outdated)
@@ -1055,7 +1063,7 @@ def pre_unshard(self) -> bool:
             matches the dtype of the expected unsharded parameter.
             """
             ret = False
-            if self._use_orig_params:
+            if self._use_orig_params and not self._skipped_use_sharded_views:
Does this mean the writeback feature gets disabled for ZeRO-2? Can we just call `_use_sharded_views()` here?
To clarify, we should not call `_use_sharded_views()` here because that will simply undo the skipping. The post-forward reshard skips `_use_sharded_views()`. The pre-backward unshard first calls `pre_unshard()`; if we call `_use_sharded_views()` in `pre_unshard()`, then we undo the skip.
I am adding back the writeback check but raising an error if we detect a change between forward and backward for `SHARD_GRAD_OP`.
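The shape of that check might look roughly like the following; this is only a hedged sketch (the helper name and its arguments are illustrative, not the actual `_writeback_orig_params()` code):

```python
import torch


def check_no_param_change_between_forward_and_backward(
    orig_param: torch.nn.Parameter, unsharded_view: torch.Tensor
) -> None:
    """Illustrative only: raise if an original parameter was swapped out or
    resized between forward and backward while sharded views were skipped
    (e.g. for SHARD_GRAD_OP), since writeback cannot be applied in that case."""
    if (
        orig_param.shape != unsharded_view.shape
        or orig_param.data_ptr() != unsharded_view.data_ptr()
    ):
        raise RuntimeError(
            "Detected a change to an original parameter between forward and "
            "backward; this is not supported when `_use_sharded_views()` is "
            "skipped after forward."
        )
```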
torch/distributed/fsdp/flat_param.py (Outdated)
if (
    in_forward
    and self._sharding_strategy
    not in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
So do we also skip calling `_use_sharded_grad_views()` in this PR?
Yes, I think we have to, unless we want to use the `.data` hack to bypass the shape check, since otherwise the parameters are unsharded while the gradients are sharded. This is a niche use case anyway since a gradient must be accumulated (either actually accumulated or via `zero_grad(set_to_none=False)`).
Upon more thought, I think we should skip it anyway because unsharded parameters with sharded gradients can be confusing. If the user wants to inspect the gradients, they can use `summon_full_params(with_grads=True)`.
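For completeness, a usage sketch of that suggestion (it assumes `model` is an FSDP-wrapped module constructed with `use_orig_params=True` and that a backward pass has already populated gradients):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def print_unsharded_grad_norms(model: FSDP) -> None:
    # Unshard parameters *and* gradients so they can be inspected together;
    # with_grads=True requires use_orig_params=True on the FSDP constructor.
    with FSDP.summon_full_params(model, with_grads=True):
        for name, param in model.named_parameters():
            if param.grad is not None:
                print(f"{name}: grad norm = {param.grad.norm().item():.4f}")
```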
ghstack-source-id: 4b5eb453deb8e55f2051778546a10e69a2f1e7ce Pull Request resolved: #98250
(Before) Pre-backward hook: 4.356 ms
<img width="812" alt="Screenshot 2023-04-03 at 6 32 19 PM" src="https://user-images.githubusercontent.com/31054793/229641309-778cf1f9-4b5b-42ec-b2d8-0a1e6e7ce330.png">
(After) Pre-backward hook: 0.483 ms
<img width="1025" alt="Screenshot 2023-04-03 at 6 32 25 PM" src="https://user-images.githubusercontent.com/31054793/229641301-971d3c60-a4f1-4561-bb33-7aa07c42d0bd.png">
The "after" value might increase slightly if I add back the `_writeback_orig_params()` check.
ghstack-source-id: d33c3fe30e2a44229664dd819e7bc6ecca0c20a5 Pull Request resolved: #98250
ghstack-source-id: c458765fefa941c218242e27f02d0c3307a6d31c Pull Request resolved: #98250
ghstack-source-id: 526f1c42f31b858228c66c8934857c8348ffa281 Pull Request resolved: #98250
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR has `SHARD_GRAD_OP` (and `_HYBRID_SHARD_ZERO2`) skip `_use_sharded_views()` in the post-forward reshard since the strategy does not free the unsharded flat parameter and can preserve the unsharded views. This saves nontrivial CPU overhead both in the post-forward reshard (`_use_sharded_views()`) and the pre-backward unshard (`_use_unsharded_views()`).
<details>
<summary>(Before) Pre-backward hook: 4.356 ms</summary>
<img width="812" alt="Screenshot 2023-04-03 at 6 32 19 PM" src="https://user-images.githubusercontent.com/31054793/229641309-778cf1f9-4b5b-42ec-b2d8-0a1e6e7ce330.png">
</details>
<details>
<summary>(After) Pre-backward hook: 1.044 ms</summary>

![Screenshot 2023-04-04 at 9 05 53 AM](https://user-images.githubusercontent.com/31054793/229800917-9580ce6b-3721-469a-9212-f0cbfd8cbb52.png)
</details>

Pull Request resolved: #98250
Approved by: https://github.com/rohan-varma
Stack from ghstack (oldest at bottom):

- `requires_grad_mask` #98299
- Skip `_use_sharded_views()` for `SHARD_GRAD_OP` #98250
- `requires_grad` for `use_orig_params=True` #98221

This PR has `SHARD_GRAD_OP` (and `_HYBRID_SHARD_ZERO2`) skip `_use_sharded_views()` in the post-forward reshard since the strategy does not free the unsharded flat parameter and can preserve the unsharded views. This saves nontrivial CPU overhead both in the post-forward reshard (`_use_sharded_views()`) and the pre-backward unshard (`_use_unsharded_views()`).

(Before) Pre-backward hook: 4.356 ms
(After) Pre-backward hook: 1.044 ms
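To make the description concrete, here is a minimal sketch of the control flow being described. It is illustrative only: the handle methods and the strategy set mirror the names referenced in the diffs above, but this is not the actual `flat_param.py` code.

```python
from torch.distributed.fsdp import ShardingStrategy

# Strategies that keep the unsharded flat parameter alive after forward
# (illustrative stand-in for NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES).
STRATEGIES_KEEPING_UNSHARDED_FLAT_PARAM = {
    ShardingStrategy.SHARD_GRAD_OP,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
}


def post_forward_reshard(handle) -> None:
    """Sketch: skip `_use_sharded_views()` when the unsharded views survive."""
    if handle._sharding_strategy in STRATEGIES_KEEPING_UNSHARDED_FLAT_PARAM:
        # The unsharded flat parameter is not freed, so the original parameters
        # can keep pointing at the unsharded views until the pre-backward
        # unshard, avoiding the per-parameter view re-creation cost.
        handle._skipped_use_sharded_views = True
        return
    handle._free_unsharded_flat_param()
    handle._use_sharded_views()
```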