[FSDP] enable autograd in forward prefetching #116792
Conversation
**problem**

When prefetching for the next forward, the current forward may be annotated with `@torch.no_grad`. In that case, `param.grad_fn` stays `None` during prefetching, so `_post_backward_hook` never gets triggered.

repro: `pytest test/distributed/fsdp/test_fsdp_freezing_weights.py`

**solution**

This PR enables autograd during prefetching (`_use_unsharded_views`), so `param.grad_fn` is properly assigned for the next forward. A longer-term fix would be to move `_use_unsharded_views` out of `_prefetch_handle` and into `_pre_forward_unshard`.
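For context, here is a minimal sketch (plain PyTorch, not the FSDP internals themselves) of the autograd behavior the fix relies on: views created while grad mode is disabled carry no `grad_fn`, so any grad_fn-based post-backward hook can never fire; re-enabling grad mode while creating the views restores the `grad_fn`.

```python
import torch

# Stand-in for an FSDP flat parameter; the real code views into a sharded flat param.
flat_param = torch.randn(8, requires_grad=True)

with torch.no_grad():
    # View created during a no_grad forward: no grad_fn is recorded,
    # so a hook attached via grad_fn would never trigger.
    view_no_grad = flat_param[:4]

    # Re-enabling grad mode while creating the view (what this PR does
    # during prefetching) records a grad_fn as usual.
    with torch.enable_grad():
        view_with_grad = flat_param[:4]

print(view_no_grad.grad_fn)    # None
print(view_with_grad.grad_fn)  # e.g. <SliceBackward0 object ...>
```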
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116792

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures as of commit e67952f with merge base 43fb1b6. This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM! Thanks for the quick fix and unit testing!
@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#116792 Approved by: https://github.com/awgu
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225