Skip to content

Conversation

yf225
Copy link
Contributor

@yf225 yf225 commented Sep 3, 2024

Ignore FSDP2 forward hook side-effects in AC

Under AC, FSDP2 does not rely on forward hook to all-gather weights to do recomputation, instead it relies on pre-backward hook to do this job:

# When composing with module-hook-based activation checkpointing, the
# the pre-backward hook is responsible for the unshard

So when we use speculate_subgraph to trace the utils.checkpoint AC region, we don't actually need to worry about FSDP2 forward hook's side effects and can safely ignore it, because we are not and we don't expect to re-run the FSDP2 forward hook during backward recomputation.


Test commands:

  • pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor
  • pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor

Stack from ghstack (oldest at bottom):

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec

Copy link

pytorch-bot bot commented Sep 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134997

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 7652b37 with merge base 8d68a02 (image):

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor module: inductor oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category labels Sep 3, 2024
… usage with graph intermediate"

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@yf225 yf225 marked this pull request as draft September 3, 2024 22:26
@yf225 yf225 changed the title [WIP][Traceable FSDP2] Replace unsharded_param graph input usage with graph intermediate [WIP][Traceable FSDP2] Replace unsharded_param graph input usage with graph intermediate; FSDP2 + AC support Sep 3, 2024
… usage with graph intermediate; FSDP2 + AC support"

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
… usage with graph intermediate; FSDP2 + AC support"

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
… usage with graph intermediate; FSDP2 + AC support"

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
… usage with graph intermediate; FSDP2 + AC support"

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
… usage with graph intermediate; FSDP2 + AC support"

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
yf225 added a commit that referenced this pull request Sep 4, 2024
… graph intermediate; FSDP2 + AC support

ghstack-source-id: 54f5553
Pull Request resolved: #134997
… usage with graph intermediate; FSDP2 + AC support"

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
yf225 added a commit that referenced this pull request Sep 4, 2024
… graph intermediate; FSDP2 + AC support

ghstack-source-id: 09fed2e
Pull Request resolved: #134997
… usage with graph intermediate; FSDP2 + AC support"

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`




cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
… usage with graph intermediate; FSDP2 + AC support"

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`




cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
… usage with graph intermediate; FSDP2 + AC support"

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`




cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
…n AC; Support FSDP2 + AC"

> Ignore FSDP2 forward hook side-effects in AC

Under AC, FSDP2 does not rely on forward hook to all-gather weights to do recomputation, instead it relies on pre-backward hook to do this job: 
https://github.com/pytorch/pytorch/blob/451eaf0ff247090ca5a9648fd1e17c3c011737e1/torch/distributed/_composable/fsdp/_fsdp_state.py#L219-L220

So when we use `speculate_subgraph` to trace the utils.checkpoint AC region, we don't actually need to worry about FSDP2 forward hook's side effects and can safely ignore it, because we are not and we don't expect to re-run the FSDP2 forward hook during backward recomputation.

----

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`




cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
…n AC; Support FSDP2 + AC"

> Ignore FSDP2 forward hook side-effects in AC

Under AC, FSDP2 does not rely on forward hook to all-gather weights to do recomputation, instead it relies on pre-backward hook to do this job: 
https://github.com/pytorch/pytorch/blob/451eaf0ff247090ca5a9648fd1e17c3c011737e1/torch/distributed/_composable/fsdp/_fsdp_state.py#L219-L220

So when we use `speculate_subgraph` to trace the utils.checkpoint AC region, we don't actually need to worry about FSDP2 forward hook's side effects and can safely ignore it, because we are not and we don't expect to re-run the FSDP2 forward hook during backward recomputation.

----

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`




cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
yf225 added a commit that referenced this pull request Sep 13, 2024
ghstack-source-id: e5c5a72
Pull Request resolved: #134997
…n AC; Support FSDP2 + AC"

> Ignore FSDP2 forward hook side-effects in AC

Under AC, FSDP2 does not rely on forward hook to all-gather weights to do recomputation, instead it relies on pre-backward hook to do this job: 
https://github.com/pytorch/pytorch/blob/451eaf0ff247090ca5a9648fd1e17c3c011737e1/torch/distributed/_composable/fsdp/_fsdp_state.py#L219-L220

So when we use `speculate_subgraph` to trace the utils.checkpoint AC region, we don't actually need to worry about FSDP2 forward hook's side effects and can safely ignore it, because we are not and we don't expect to re-run the FSDP2 forward hook during backward recomputation.

----

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`




cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
…n AC; Support FSDP2 + AC"

> Ignore FSDP2 forward hook side-effects in AC

Under AC, FSDP2 does not rely on forward hook to all-gather weights to do recomputation, instead it relies on pre-backward hook to do this job: 
https://github.com/pytorch/pytorch/blob/451eaf0ff247090ca5a9648fd1e17c3c011737e1/torch/distributed/_composable/fsdp/_fsdp_state.py#L219-L220

So when we use `speculate_subgraph` to trace the utils.checkpoint AC region, we don't actually need to worry about FSDP2 forward hook's side effects and can safely ignore it, because we are not and we don't expect to re-run the FSDP2 forward hook during backward recomputation.

----

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`




cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
…n AC; Support FSDP2 + AC"

> Ignore FSDP2 forward hook side-effects in AC

Under AC, FSDP2 does not rely on forward hook to all-gather weights to do recomputation, instead it relies on pre-backward hook to do this job: 
https://github.com/pytorch/pytorch/blob/451eaf0ff247090ca5a9648fd1e17c3c011737e1/torch/distributed/_composable/fsdp/_fsdp_state.py#L219-L220

So when we use `speculate_subgraph` to trace the utils.checkpoint AC region, we don't actually need to worry about FSDP2 forward hook's side effects and can safely ignore it, because we are not and we don't expect to re-run the FSDP2 forward hook during backward recomputation.

----

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`




cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
yf225 added a commit that referenced this pull request Sep 14, 2024
ghstack-source-id: f01327d
Pull Request resolved: #134997
@yf225
Copy link
Contributor Author

yf225 commented Sep 14, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 14, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

The merge job was canceled or timed out. This most often happen if two merge requests were issued for the same PR, or if merge job was waiting for more than 6 hours for tests to finish. In later case, please do not hesitate to reissue the merge command
For more information see pytorch-bot wiki.

…n AC; Support FSDP2 + AC"

> Ignore FSDP2 forward hook side-effects in AC

Under AC, FSDP2 does not rely on forward hook to all-gather weights to do recomputation, instead it relies on pre-backward hook to do this job: 
https://github.com/pytorch/pytorch/blob/451eaf0ff247090ca5a9648fd1e17c3c011737e1/torch/distributed/_composable/fsdp/_fsdp_state.py#L219-L220

So when we use `speculate_subgraph` to trace the utils.checkpoint AC region, we don't actually need to worry about FSDP2 forward hook's side effects and can safely ignore it, because we are not and we don't expect to re-run the FSDP2 forward hook during backward recomputation.

----

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`




cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
…n AC; Support FSDP2 + AC"

> Ignore FSDP2 forward hook side-effects in AC

Under AC, FSDP2 does not rely on forward hook to all-gather weights to do recomputation, instead it relies on pre-backward hook to do this job: 
https://github.com/pytorch/pytorch/blob/451eaf0ff247090ca5a9648fd1e17c3c011737e1/torch/distributed/_composable/fsdp/_fsdp_state.py#L219-L220

So when we use `speculate_subgraph` to trace the utils.checkpoint AC region, we don't actually need to worry about FSDP2 forward hook's side effects and can safely ignore it, because we are not and we don't expect to re-run the FSDP2 forward hook during backward recomputation.

----

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`




cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang rec

[ghstack-poisoned]
yf225 added a commit that referenced this pull request Sep 14, 2024
ghstack-source-id: 2972445
Pull Request resolved: #134997
@yf225 yf225 removed the merging label Sep 14, 2024
@yf225
Copy link
Contributor Author

yf225 commented Sep 14, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
…rt FSDP2 + AC (pytorch#134997)

> Ignore FSDP2 forward hook side-effects in AC

Under AC, FSDP2 does not rely on forward hook to all-gather weights to do recomputation, instead it relies on pre-backward hook to do this job:
https://github.com/pytorch/pytorch/blob/451eaf0ff247090ca5a9648fd1e17c3c011737e1/torch/distributed/_composable/fsdp/_fsdp_state.py#L219-L220

So when we use `speculate_subgraph` to trace the utils.checkpoint AC region, we don't actually need to worry about FSDP2 forward hook's side effects and can safely ignore it, because we are not and we don't expect to re-run the FSDP2 forward hook during backward recomputation.

----

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_nested_fully_shard_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_backend_inductor`

Pull Request resolved: pytorch#134997
Approved by: https://github.com/zou3519
ghstack dependencies: pytorch#135727
@github-actions github-actions bot deleted the gh/yf225/116/head branch October 15, 2024 02:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request keep-going Don't stop on first failure, keep running tests until the end Merged module: dynamo module: inductor oncall: distributed Add this issue/PR to distributed oncall triage queue topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants