[FSDP] Relax post-backward assert #89791

Closed · wants to merge 4 commits

Conversation

awgu (Contributor) commented Nov 28, 2022

Stack from ghstack (oldest at bottom):

This assert was accidentally made stricter when transitioning from per-FSDP-instance training state to per-handle training state. This PR relaxes it again, which should restore compatibility for some reentrant AC plus FSDP cases.
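For context, here is a minimal sketch of the relaxed behavior. The names (`HandleTrainingState`, `handle.training_state`, `post_backward_hook`) are illustrative stand-ins, not the actual FSDP internals: the idea is that a repeat invocation of the per-handle post-backward hook within the same backward pass becomes a no-op instead of tripping an assert.

```python
# Hypothetical sketch; enum and attribute names are illustrative, not the
# real FSDP implementation.
from enum import Enum, auto

class HandleTrainingState(Enum):
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()

def post_backward_hook(handle, *unused):
    # Before this PR (in spirit): assert the handle is still in BACKWARD_PRE.
    # With reentrant activation checkpointing the hook can fire a second time
    # for the same handle, so a repeated call is now treated as a no-op.
    if handle.training_state == HandleTrainingState.BACKWARD_POST:
        return
    handle.training_state = HandleTrainingState.BACKWARD_POST
    # ... reduce-scatter the gradient, free unsharded parameters, etc. ...
```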

pytorch-bot bot commented Nov 28, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89791

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a510d8d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot added the "release notes: distributed (fsdp)" label Nov 28, 2022
awgu added a commit that referenced this pull request Nov 28, 2022
ghstack-source-id: ee0fd21a02c0fc671ca900e2b57083a1ef60edd2
Pull Request resolved: #89791
@@ -482,9 +482,13 @@ def _post_backward_hook(
"FullyShardedDataParallel._post_backward_hook"
):
_assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
# For reentrant AC, the post-backward hook may run multiple times in
Review comment (Contributor) on this line:
nit: For reentrant AC multiple times
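For background on why a nested GraphTask can invoke this hook at all: FSDP-style code hangs its post-backward hook off each parameter's `AccumulateGrad` node, so anything that flushes pending `AccumulateGrad` executions also runs the hook. A minimal, self-contained sketch of that registration pattern (illustrative, not FSDP's exact code):

```python
import torch

# A plain parameter standing in for FSDP's flat parameter.
param = torch.nn.Parameter(torch.randn(4, 4))

def post_backward_hook(*unused):
    # FSDP would reduce-scatter gradients here; we just log the call.
    print("post-backward hook fired")

# expand_as creates a graph node whose next function is the parameter's
# AccumulateGrad node; registering a hook on that node is the classic way
# to run code right after the parameter's gradient is accumulated.
acc_grad = param.expand_as(param).grad_fn.next_functions[0][0]
acc_grad.register_hook(post_backward_hook)

(param * 2).sum().backward()  # the hook fires once the gradient is accumulated
```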

awgu added a commit that referenced this pull request Nov 28, 2022
ghstack-source-id: 90b4cb6e62e1dad7afccb625996cc797692c1db2
Pull Request resolved: #89791
awgu added a commit that referenced this pull request Nov 28, 2022
ghstack-source-id: 3d8a27dc946209d413b1dd5e6ed0d6815bb71721
Pull Request resolved: #89791
awgu added the "topic: improvements" label Nov 28, 2022
awgu (Contributor, Author) commented Nov 28, 2022

@pytorchbot rebase -s

pytorchmergebot (Collaborator) commented:

@pytorchbot successfully started a rebase job. Check the current status here

pytorchmergebot (Collaborator) commented:

Successfully rebased gh/awgu/218/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/89791)

pytorchmergebot pushed a commit that referenced this pull request Nov 28, 2022
ghstack-source-id: 268645d4e371a830a9857b981311dbfd6455d05a
Pull Request resolved: #89791
awgu added the "ciflow/trunk" label (trigger trunk jobs on your pull request) Nov 28, 2022
awgu (Contributor, Author) commented Nov 29, 2022

@pytorchbot merge

pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced debugging: check the merge workflow status here.

mrshenli added a commit that referenced this pull request Nov 29, 2022
When combining FSDP with reentrant checkpointing, the post-backward
hook might run twice and then hit [this
error](https://github.com/pytorch/pytorch/blob/e20ec44544c17d6d3d411f88b870e05043bda731/torch/distributed/fsdp/_runtime_utils.py#L487).
This is because reentrant backward uses nested autograd GraphTasks.
The inner GraphTask is not aware of the outer one and therefore
flushes pending `AccumulateGrad` invocations on exit, which in
turn triggers the post-backward hooks registered by FSDP. Later,
the outer GraphTask triggers them again, leading to the above
error.

PR #89791 relaxes the FSDP training state check, but we still run
into gradient value check failures occasionally. Therefore, this PR
only lands the non-reentrant test; the reentrant test can be enabled
once the accuracy issues are addressed.

[ghstack-poisoned]
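As a hedged illustration of the scenario described in the commit message above (an assumed setup, not the test from #89781): an FSDP-wrapped module whose forward uses reentrant activation checkpointing, so the recomputation's backward runs in a nested GraphTask. It assumes a process group has already been initialized (for example via `torch.distributed.init_process_group("nccl")`) and that a CUDA device is available.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)

    def forward(self, x):
        # use_reentrant=True recomputes the forward inside a nested backward
        # (GraphTask), which is what can flush AccumulateGrad, and with it
        # FSDP's post-backward hook, before the outer GraphTask finishes.
        return checkpoint(self.lin, x, use_reentrant=True)

# Assumes torch.distributed is already initialized.
model = FSDP(CheckpointedBlock(), device_id=torch.cuda.current_device())
x = torch.randn(2, 8, device="cuda", requires_grad=True)
model(x).sum().backward()  # the post-backward hook may run more than once here
```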
mrshenli added a commit that referenced this pull request Nov 29, 2022
ghstack-source-id: 8848c4cbf572c3a5acd8a9c2fd2b22539a65375f
Pull Request resolved: #89781
pytorchmergebot (Collaborator) commented:
Merge failed

Reason: 1 additional job has failed; the first few of them are: trunk

Details for Dev Infra team: raised by workflow job.

pytorchmergebot pushed a commit that referenced this pull request Nov 29, 2022
Pull Request resolved: #89781
Approved by: https://github.com/rohan-varma
awgu (Contributor, Author) commented Nov 29, 2022

@pytorchbot merge

pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced debugging: check the merge workflow status here.

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Pull Request resolved: pytorch#89781
Approved by: https://github.com/rohan-varma
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Pull Request resolved: pytorch#89791
Approved by: https://github.com/zhaojuanmao
facebook-github-bot deleted the gh/awgu/218/head branch June 8, 2023 15:28
Labels: ciflow/trunk · Merged · release notes: distributed (fsdp) · topic: improvements