[Traceable FSDP2] Ignore FSDP2 forward hook side-effects in AC; Support FSDP2 + AC #134997
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134997
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures)
As of commit 7652b37 with merge base 8d68a02:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
…rt FSDP2 + AC (pytorch#134997)

> Ignore FSDP2 forward hook side-effects in AC

Under AC, FSDP2 does not rely on forward hook to all-gather weights to do recomputation, instead it relies on pre-backward hook to do this job: https://github.com/pytorch/pytorch/blob/451eaf0ff247090ca5a9648fd1e17c3c011737e1/torch/distributed/_composable/fsdp/_fsdp_state.py#L219-L220

So when we use `speculate_subgraph` to trace the utils.checkpoint AC region, we don't actually need to worry about FSDP2 forward hook's side effects and can safely ignore it, because we are not and we don't expect to re-run the FSDP2 forward hook during backward recomputation.

----

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`

Pull Request resolved: pytorch#134997
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#135727
Under AC (activation checkpointing), FSDP2 does not rely on the forward hook to all-gather weights for recomputation; instead, it relies on the pre-backward hook to do this job:

`pytorch/torch/distributed/_composable/fsdp/_fsdp_state.py`, lines 219 to 220 in 451eaf0

So when we use `speculate_subgraph` to trace the `utils.checkpoint` AC region, we don't need to worry about the FSDP2 forward hook's side effects and can safely ignore them, because we do not re-run, and do not expect to re-run, the FSDP2 forward hook during backward recomputation.

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`
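
To make the hook ordering described above concrete, here is a toy, pure-Python sketch. It is not FSDP2 code and uses no PyTorch APIs: `ToyShardedLayer`, `TrainingState`, `run_step`, and the logged strings are all hypothetical names invented for illustration. It assumes, as the description states, that the pre-backward hook performs the all-gather and that the forward hook contributes nothing when the forward is recomputed during backward.

```python
# Toy model only: none of these names are PyTorch or FSDP2 APIs.
from enum import Enum, auto


class TrainingState(Enum):
    IDLE = auto()
    FORWARD = auto()
    PRE_BACKWARD = auto()


class ToyShardedLayer:
    """Stand-in for a sharded module; records every simulated all-gather."""

    def __init__(self):
        self.state = TrainingState.IDLE
        self.unsharded = False
        self.log = []

    def _all_gather(self, source):
        # Only gather (and log) if the weights are currently sharded.
        if not self.unsharded:
            self.unsharded = True
            self.log.append(f"all-gather from {source}")

    def _reshard(self):
        self.unsharded = False

    def pre_forward_hook(self):
        # Assumption: during recomputation the state is PRE_BACKWARD,
        # so the forward hook has no side effects.
        if self.state is TrainingState.PRE_BACKWARD:
            return
        self.state = TrainingState.FORWARD
        self._all_gather("pre-forward hook")

    def pre_backward_hook(self):
        self.state = TrainingState.PRE_BACKWARD
        self._all_gather("pre-backward hook")

    def forward(self, x):
        self.pre_forward_hook()
        assert self.unsharded, "weights must be gathered before compute"
        y = 2 * x
        if self.state is TrainingState.FORWARD:
            # Free the gathered weights after the original forward.
            self._reshard()
            self.state = TrainingState.IDLE
        return y


def run_step(layer, x):
    out = layer.forward(x)         # original forward pass
    layer.pre_backward_hook()      # backward begins: weights gathered here
    recomputed = layer.forward(x)  # checkpoint-style recomputation
    layer._reshard()
    layer.state = TrainingState.IDLE
    return out, recomputed


layer = ToyShardedLayer()
out, recomputed = run_step(layer, 3)
print(layer.log)
```

In this toy model the recomputed forward finds the weights already gathered by the pre-backward hook, which mirrors the reasoning for why a tracer can disregard the forward hook's side effects inside the checkpointed region.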
Stack from ghstack (oldest at bottom):
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec