Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDP] enable autograd in forward prefetching #116792

Closed
wants to merge 4 commits into from

Conversation

weifengpy
Copy link
Contributor

@weifengpy weifengpy commented Jan 4, 2024

problem
when prefetching for next forward, current forward may be annotated by
@torch.no_grad. param.grad_fn keeps being None during prefetching.
_post_backward_hook never gets triggered

repro
pytest test/distributed/fsdp/test_fsdp_freezing_weights.py

solution
this PR enabled autograd during prefetching (_use_unsharded_views), so
param.grad_fn are properly assigned for next forward

a longer-term fix would be moving _use_unsharded_views out of
_prefetch_handle and put it in _pre_forward_unshard

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225

weifengpy and others added 4 commits January 4, 2024 12:53
**problem**
when prefetching for next forward, current forward may be annotated by
`@torch.no_grad`. `param.grad_fn` keeps being None during prefetching.
`_post_backward_hook` never gets triggered

repro
```pytest test/distributed/fsdp/test_fsdp_freezing_weights.py```

**solution**
this PR enabled autograd during prefetching (`_use_unsharded_views`), so
`param.grad_fn` are properly assigned for next forward

a longer-term fix would be moving `_use_unsharded_views` out of
`_prefetch_handle` and put it in `_pre_forward_unshard`
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
[FSDP] enable autograd in forward prefetching
Copy link

pytorch-bot bot commented Jan 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116792

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e67952f with merge base 43fb1b6 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Jan 4, 2024
@github-actions github-actions bot added oncall: distributed Add this issue/PR to distributed oncall triage queue ciflow/inductor labels Jan 4, 2024
@weifengpy weifengpy requested a review from awgu January 4, 2024 21:56
Copy link
Contributor

@awgu awgu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the quick fix and unit testing!

@facebook-github-bot
Copy link
Contributor

@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@weifengpy weifengpy added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 4, 2024
@weifengpy
Copy link
Contributor Author

@pytorchmergebot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@atalman atalman added this to the 2.2.1 milestone Jan 17, 2024
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Feb 12, 2024
**problem**
when prefetching for next forward, current forward may be annotated by
`@torch.no_grad`. `param.grad_fn` keeps being None during prefetching.
`_post_backward_hook` never gets triggered

repro
```pytest test/distributed/fsdp/test_fsdp_freezing_weights.py```

**solution**
this PR enabled autograd during prefetching (`_use_unsharded_views`), so
`param.grad_fn` are properly assigned for next forward

a longer-term fix would be moving `_use_unsharded_views` out of
`_prefetch_handle` and put it in `_pre_forward_unshard`

Pull Request resolved: pytorch#116792
Approved by: https://github.com/awgu
atalman pushed a commit that referenced this pull request Feb 14, 2024
Co-authored-by: Wei (Will) Feng <134637289+weifengpy@users.noreply.github.com>
resolved: #116792
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants