[FSDP] Relax post-backward assert #89791
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89791
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit a510d8d: This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: ee0fd21a02c0fc671ca900e2b57083a1ef60edd2 Pull Request resolved: #89791
@@ -482,9 +482,13 @@ def _post_backward_hook(
    "FullyShardedDataParallel._post_backward_hook"
):
    _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
    # For reentrant AC, the post-backward hook may run multiple times in
nit: For reentrant AC multiple times
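For readers without the full diff in front of them, here is a minimal sketch of the idea behind relaxing this assert; it is not the actual change to torch/distributed/fsdp/_runtime_utils.py. `TrainingState`, `state`, and `handle` mirror names from the snippet above, while `_DummyState`, `_DummyHandle`, and `_post_backward_called` are stand-ins invented for illustration.

```python
# Minimal sketch, not the real FSDP runtime code. It illustrates relaxing the
# training-state assert so that a second invocation of the post-backward hook
# (as triggered by reentrant activation checkpointing) does not fail.
from enum import Enum, auto


class TrainingState(Enum):
    IDLE = auto()
    FORWARD_BACKWARD = auto()


class _DummyState:
    # Stand-in for the FSDP state object referenced in the snippet above.
    training_state = TrainingState.FORWARD_BACKWARD


class _DummyHandle:
    # Assumed bookkeeping flag for illustration, not necessarily the real one.
    _post_backward_called = False


def _post_backward_hook(state, handle, *unused):
    # Relaxed check: with reentrant AC, the inner autograd GraphTask may fire
    # this hook once and the outer GraphTask may fire it again, so instead of
    # asserting a single expected state, accept a wider set and make the hook
    # idempotent.
    assert state.training_state in (
        TrainingState.FORWARD_BACKWARD,
        TrainingState.IDLE,
    ), f"unexpected training state {state.training_state}"
    if handle._post_backward_called:
        return  # second invocation: nothing left to do
    handle._post_backward_called = True
    # ... gradient reduce-scatter / accumulation would run here ...


# Calling the hook twice, as nested GraphTasks would, no longer raises.
state, handle = _DummyState(), _DummyHandle()
_post_backward_hook(state, handle)
_post_backward_hook(state, handle)
```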
This assert was accidentally made stricter when transitioning from per-FSDP-instance training state to per-handle training state. This PR relaxes it again, which should restore compatibility for some reentrant AC plus FSDP cases. [ghstack-poisoned]
ghstack-source-id: 90b4cb6e62e1dad7afccb625996cc797692c1db2 Pull Request resolved: #89791
This assert was accidentally made stricter when transitioning from per-FSDP-instance training state to per-handle training state. This PR relaxes it again, which should restore compatibility for some reentrant AC plus FSDP cases. [ghstack-poisoned]
ghstack-source-id: 3d8a27dc946209d413b1dd5e6ed0d6815bb71721 Pull Request resolved: #89791
@pytorchbot rebase -s
@pytorchbot successfully started a rebase job. Check the current status here
This assert was accidentally made stricter when transitioning from per-FSDP-instance training state to per-handle training state. This PR relaxes it again, which should restore compatibility for some reentrant AC plus FSDP cases. [ghstack-poisoned]
Successfully rebased
ghstack-source-id: 268645d4e371a830a9857b981311dbfd6455d05a Pull Request resolved: #89791
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
When combining FSDP with reentrant checkpointing, the post-backward hook might run twice and then hit [this error](https://github.com/pytorch/pytorch/blob/e20ec44544c17d6d3d411f88b870e05043bda731/torch/distributed/fsdp/_runtime_utils.py#L487). This is because reentrant backward uses nested autograd GraphTasks. The inner GraphTask is not aware of the outer one and therefore flushes pending `AccumulateGrad` invocations on exit, which in turn triggers the post-backward hooks registered by FSDP. Later, the outer GraphTask triggers them again, leading to the above error. PR #89791 relaxes the FSDP training state check, but we still run into grad value check failures occasionally. Therefore, this PR only lands the non-reentrant test; we can enable the reentrant test once the accuracy issues are addressed. A usage sketch follows below. [ghstack-poisoned]
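To make the described combination concrete, here is a hedged sketch of FSDP plus non-reentrant activation checkpointing, the variant whose test this PR lands; with `use_reentrant=True` the nested-GraphTask behavior above can fire FSDP's post-backward hooks early. The single-process gloo setup and the toy model are assumptions for illustration, not the actual test code.

```python
# Hedged sketch: FSDP + non-reentrant activation checkpointing on a toy model.
# Not the test from this PR; the single-process gloo init and the model shape
# are illustrative assumptions.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.checkpoint import checkpoint


class ToyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
        self.outer = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Non-reentrant checkpointing: recomputation does not spawn a nested
        # autograd GraphTask, so FSDP's post-backward hook fires once per
        # parameter instead of being flushed early by an inner GraphTask.
        x = checkpoint(self.inner, x, use_reentrant=False)
        return self.outer(x)


def main() -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = FSDP(ToyNet())
    out = model(torch.randn(4, 8))
    out.sum().backward()  # completes without hitting the post-backward assert

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```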
When combining FSDP with reentrant checkpointing, the post-backward hook might run twice and then hit [this error](https://github.com/pytorch/pytorch/blob/e20ec44544c17d6d3d411f88b870e05043bda731/torch/distributed/fsdp/_runtime_utils.py#L487). This is because reentrant backward uses nested autograd GraphTasks. The inner GraphTask is not aware of the outer one and therefore flushes pending `AccumulateGrad` invocations on exit, which in turn triggers the post-backward hooks registered by FSDP. Later, the outer GraphTask triggers them again, leading to the above error. PR #89791 relaxes the FSDP training state check, but we still run into grad value check failures occasionally. Therefore, this PR only lands the non-reentrant test; we can enable the reentrant test once the accuracy issues are addressed. ghstack-source-id: 8848c4cbf572c3a5acd8a9c2fd2b22539a65375f Pull Request resolved: #89781
Merge failed. Reason: 1 additional job has failed, first few of them are: trunk. Details for Dev Infra team: Raised by workflow job
When combining FSDP with reentrant checkpointing, the post-backward hook might run twice and then hit [this error](https://github.com/pytorch/pytorch/blob/e20ec44544c17d6d3d411f88b870e05043bda731/torch/distributed/fsdp/_runtime_utils.py#L487). This is because reentrant backward uses nested autograd GraphTasks. The inner GraphTask is not aware of the outer one and therefore flushes pending `AccumulateGrad` invocations on exit, which in turn triggers the post-backward hooks registered by FSDP. Later, the outer GraphTask triggers them again, leading to the above error. PR #89791 relaxes the FSDP training state check, but we still run into grad value check failures occasionally. Therefore, this PR only lands the non-reentrant test; we can enable the reentrant test once the accuracy issues are addressed. Pull Request resolved: #89781 Approved by: https://github.com/rohan-varma
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This assert was accidentally made stricter when transitioning from per-FSDP-instance training state to per-handle training state. This PR relaxes it again, which should restore compatibility for some reentrant AC plus FSDP cases. Pull Request resolved: pytorch#89791 Approved by: https://github.com/zhaojuanmao
Stack from ghstack (oldest at bottom):
This assert was accidentally made stricter when transitioning from per-FSDP-instance training state to per-handle training state. This PR relaxes it again, which should restore compatibility for some reentrant AC plus FSDP cases.
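As an illustration of the distinction the description draws, here is a purely illustrative sketch contrasting per-FSDP-instance training state with per-handle training state; the class names and the `"layer0.flat_param"` key are stand-ins, not FSDP's actual data structures.

```python
# Purely illustrative: one training state per FSDP instance vs one per
# flat-parameter handle. Names are stand-ins, not FSDP's real internals.
from enum import Enum, auto
from typing import Dict


class TrainingState(Enum):
    IDLE = auto()
    FORWARD_BACKWARD = auto()


class PerInstanceState:
    # Coarse: the whole FSDP instance shares one state, so a repeated
    # post-backward invocation for one parameter still sees FORWARD_BACKWARD.
    def __init__(self) -> None:
        self.training_state = TrainingState.FORWARD_BACKWARD


class PerHandleState:
    # Fine-grained: each handle tracks its own state, so a handle that has
    # already finished post-backward is back to IDLE, and a strict assert on
    # it fails when reentrant AC triggers the hook a second time.
    def __init__(self) -> None:
        self.training_state: Dict[str, TrainingState] = {}


instance_state = PerInstanceState()
handle_state = PerHandleState()
handle_state.training_state["layer0.flat_param"] = TrainingState.IDLE
print(instance_state.training_state)
print(handle_state.training_state)
```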