
[FSDP] Skip _use_sharded_views() for SHARD_GRAD_OP #98250

Closed · wants to merge 6 commits

Conversation

@awgu (Contributor) commented Apr 3, 2023

Stack from ghstack (oldest at bottom):

This PR has SHARD_GRAD_OP (and _HYBRID_SHARD_ZERO2) skip _use_sharded_views() in the post-forward reshard since the strategy does not free the unsharded flat parameter and can preserve the unsharded views. This saves nontrivial CPU overhead both in the post-forward reshard (_use_sharded_views()) and the pre-backward unshard (_use_unsharded_views()).

(Before) Pre-backward hook: 4.356 ms (profiler screenshot)
(After) Pre-backward hook: 1.044 ms (profiler screenshot)
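
For intuition, here is a minimal sketch of the control flow this PR changes. It is not the actual FSDP internals: `FakeHandle`, `ShardingStrategy`, and `NO_RESHARD_AFTER_FORWARD` are toy stand-ins, and only `_use_sharded_views()`/`_use_unsharded_views()` correspond to names from the PR.

```python
from enum import Enum, auto


class ShardingStrategy(Enum):
    FULL_SHARD = auto()           # ZeRO-3: frees the unsharded flat param after forward
    SHARD_GRAD_OP = auto()        # ZeRO-2: keeps the unsharded flat param after forward
    _HYBRID_SHARD_ZERO2 = auto()  # hybrid ZeRO-2 variant


# Strategies that do not free (reshard) the flat parameter after forward.
NO_RESHARD_AFTER_FORWARD = {
    ShardingStrategy.SHARD_GRAD_OP,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
}


class FakeHandle:
    """Toy stand-in for FSDP's flat-parameter handle with only the relevant state."""

    def __init__(self, strategy: ShardingStrategy) -> None:
        self.strategy = strategy
        self.skipped_use_sharded_views = False

    def post_forward_reshard(self) -> None:
        if self.strategy in NO_RESHARD_AFTER_FORWARD:
            # The unsharded flat parameter stays allocated, so the unsharded
            # views remain valid; skip the CPU-heavy switch to sharded views.
            self.skipped_use_sharded_views = True
            return
        self._use_sharded_views()

    def pre_backward_unshard(self) -> None:
        if self.skipped_use_sharded_views:
            # The views are still unsharded from forward, so there is
            # nothing to rebuild here either.
            return
        self._use_unsharded_views()

    def _use_sharded_views(self) -> None: ...
    def _use_unsharded_views(self) -> None: ...
```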

@pytorch-bot bot commented Apr 3, 2023

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98250

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5e3aacc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Apr 3, 2023
awgu added a commit that referenced this pull request Apr 3, 2023
ghstack-source-id: bfec3b5c298bd1fb4ade5843039205ed28c35b1b
Pull Request resolved: #98250
awgu added a commit that referenced this pull request Apr 3, 2023
ghstack-source-id: 6c1d4fcf0987271375074c1c70149c5223c23cef
Pull Request resolved: #98250
@rohan-varma (Member) left a comment


LGTM, thanks!

If we're concerned about blast radius here, feel free to gate with an env var before landing.
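
For reference, a gate like the one suggested could look roughly like the sketch below. The env-var name and helper functions are hypothetical and not part of this PR; `no_reshard_strategies` stands in for the PR's `NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES` constant.

```python
import os


def _skip_sharded_views_enabled() -> bool:
    # Default to the new (faster) behavior; setting the env var to "0"
    # would restore the old behavior if a regression surfaced.
    return os.environ.get("FSDP_SKIP_SHARDED_VIEWS", "1") == "1"


def _should_skip_use_sharded_views(strategy, no_reshard_strategies) -> bool:
    # Skip only for strategies that keep the unsharded flat parameter
    # alive after forward (e.g. SHARD_GRAD_OP).
    return _skip_sharded_views_enabled() and strategy in no_reshard_strategies
```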

```diff
@@ -1055,7 +1063,7 @@ def pre_unshard(self) -> bool:
             matches the dtype of the expected unsharded parameter.
         """
         ret = False
-        if self._use_orig_params:
+        if self._use_orig_params and not self._skipped_use_sharded_views:
```
Member: Does this mean the writeback feature gets disabled for ZeRO-2? Can we just call `_use_sharded_views()` here?

Contributor Author (@awgu): To clarify, we should not call `_use_sharded_views()` here because that would simply undo the skipping:

- The post-forward reshard skips `_use_sharded_views()`.
- The pre-backward unshard first calls `pre_unshard()`. If we call `_use_sharded_views()` in `pre_unshard()`, then we undo the skip.


I am adding back the writeback check, but raising an error if we detect a parameter change between forward and backward for `SHARD_GRAD_OP`.
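
Roughly, the restored check could behave like the sketch below. The helper names (`_flat_param_changed_since_forward`) and the error message are hypothetical; only the overall behavior is taken from the discussion above.

```python
def _pre_unshard_writeback_check(handle) -> None:
    if not handle._skipped_use_sharded_views:
        # Normal path: the sharded views are in place, so any user changes
        # to the original parameters can be written back to the flat param.
        handle._writeback_orig_params()
        return
    # SHARD_GRAD_OP / _HYBRID_SHARD_ZERO2 skipped the sharded views between
    # forward and backward, so a parameter change cannot be written back.
    if handle._flat_param_changed_since_forward():  # hypothetical check
        raise RuntimeError(
            "Detected a parameter change between forward and backward for "
            "SHARD_GRAD_OP, which is not supported when sharded views are skipped."
        )
```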

torch/distributed/fsdp/flat_param.py
```python
if (
    in_forward
    and self._sharding_strategy
    not in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES
```
Member: So we also skip calling `_use_sharded_grad_views()` in this PR?

Contributor Author (@awgu): Yes, I think we have to, unless we want to use the `.data` hack to bypass the shape check, since otherwise the parameters are unsharded while the gradients are sharded. This is a niche use case anyway, since it requires a gradient to already be present (either from actual gradient accumulation or from `zero_grad(set_to_none=False)`).

Contributor Author (@awgu): Upon more thought, I think we should skip it anyway because unsharded parameters with sharded gradients can be confusing. If the user wants to inspect the gradients, they can use `summon_full_params(with_grads=True)`.
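
For example, inspecting gradients that way might look like the following sketch, assuming `model` is already wrapped with FSDP using `ShardingStrategy.SHARD_GRAD_OP` and gradients have been produced:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Gather full (unsharded) parameters and gradients for inspection.
with FSDP.summon_full_params(model, with_grads=True):
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(name, param.grad.norm())
```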

awgu added a commit that referenced this pull request Apr 3, 2023
ghstack-source-id: 4b5eb453deb8e55f2051778546a10e69a2f1e7ce
Pull Request resolved: #98250
awgu added a commit to awgu/pytorch that referenced this pull request Apr 3, 2023
ghstack-source-id: 4b5eb453deb8e55f2051778546a10e69a2f1e7ce
Pull Request resolved: pytorch#98250
(Before) Pre-backward hook: 4.356 ms (profiler screenshot)
(After) Pre-backward hook: 0.483 ms (profiler screenshot)

The "after" value might increase slightly if I add back the `_writeback_orig_params()` check.
awgu added a commit that referenced this pull request Apr 3, 2023
ghstack-source-id: d33c3fe30e2a44229664dd819e7bc6ecca0c20a5
Pull Request resolved: #98250
awgu added a commit that referenced this pull request Apr 4, 2023
ghstack-source-id: c458765fefa941c218242e27f02d0c3307a6d31c
Pull Request resolved: #98250
awgu added a commit that referenced this pull request Apr 4, 2023
ghstack-source-id: 526f1c42f31b858228c66c8934857c8348ffa281
Pull Request resolved: #98250
@awgu awgu added ciflow/trunk Trigger trunk jobs on your pull request topic: improvements topic category labels Apr 4, 2023
@awgu awgu marked this pull request as ready for review April 4, 2023 13:34
@awgu (Contributor Author) commented Apr 4, 2023:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


ZainRizvi pushed a commit that referenced this pull request Apr 19, 2023
Pull Request resolved: #98250
Approved by: https://github.com/rohan-varma
@facebook-github-bot facebook-github-bot deleted the gh/awgu/377/head branch June 8, 2023 15:38
Labels: ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · release notes: distributed (fsdp) · topic: improvements