[FSDP] Fix use_orig_params=True + AC #87413

Closed
awgu wants to merge 4 commits

Conversation

@awgu awgu (Contributor) commented Oct 20, 2022

Stack from ghstack:

Without this change, the post-backward hooks do not run when using reentrant activation checkpointing.

**Explanation**
FSDP registers the original parameters as plain `Tensor`s in the forward pass so that their ops are tracked by autograd to ensure proper gradient propagation into the `FlatParameter`s. FSDP registers the post-backward hooks in its pre-forward.
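To make the autograd connectivity concrete, here is a minimal standalone sketch (a toy illustration, not FSDP's actual code) of why plain `Tensor` views created with `split()`/`view()` keep the original parameters connected to a flat parameter:

```python
import torch

# A toy "flat parameter" backed by a single storage.
flat_param = torch.nn.Parameter(torch.randn(6))

# Expose the original parameters as plain Tensor views. Autograd tracks the
# split()/view() ops, so gradients flow back into flat_param.
weight, bias = flat_param.split([4, 2])
weight = weight.view(2, 2)

x = torch.randn(3, 2)
loss = (x @ weight + bias).sum()
loss.backward()

# The gradient lands on the flat parameter, not on the plain Tensor views.
print(flat_param.grad.shape)  # torch.Size([6])
```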

For `use_orig_params=True`, FSDP replaces the plain `Tensor`s with the sharded `nn.Parameter`s in the post-forward when resharding. This differs from `use_orig_params=False`, which keeps the plain `Tensor`s registered as attributes but frees their data, so accessing them between forward and backward raises an error. Before this PR, for `use_orig_params=True`, FSDP simply restored the unsharded original parameter data in the pre-backward to enable correct gradient computation. However, this does not suffice for reentrant activation checkpointing (AC), where the recomputed forward happens after FSDP's pre-backward and the ops in the recomputed forward must be tracked by autograd.
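The ordering problem can be reproduced without FSDP. The sketch below (a standalone illustration; the tensor hook is only a stand-in for FSDP's pre-backward hook, which FSDP registers on the forward output) shows that reentrant AC's recomputed forward runs after the pre-backward hook fires:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

events = []

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x):
        events.append("forward")
        return self.lin(x)

inner = Inner()
x = torch.randn(2, 4, requires_grad=True)

# Reentrant AC: the 1st forward runs under no_grad() and is replayed in backward.
out = checkpoint(inner, x, use_reentrant=True)
# Stand-in for FSDP's pre-backward hook on the forward output.
out.register_hook(lambda grad: events.append("pre-backward"))

out.sum().backward()
print(events)  # ['forward', 'pre-backward', 'forward']
```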

My initial solution was to simply have FSDP restore the original parameters as plain `Tensor`s again in the pre-backward so that they would be tracked by autograd exactly like in the normal forward. However, this does not seem to suffice in general: the `FlatParameter`'s `AccumulateGrad` object may change after the original pre-forward when performing a recomputed forward.

The new approach in this PR follows the `use_orig_params=False` way -- namely, it preserves the plain `Tensor` variables across forward and backward. I achieved this by saving the variables explicitly in the forward and restoring them in the pre-backward. I clear them in the post-backward to avoid dangling references (though I do not think this is strictly necessary).
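Here is a simplified sketch of that idea (hypothetical names and structure, not FSDP's actual internals): save the autograd-visible `Tensor` views in the forward, restore the same variables in the pre-backward, and clear them in the post-backward:

```python
import torch.nn as nn

class TensorViewStash:
    """Toy illustration of preserving plain Tensor views across forward/backward."""

    def __init__(self, module: nn.Module, param_names: list):
        self.module = module
        self.param_names = param_names
        self._saved_views = None

    def post_forward(self, unsharded_views: list):
        # Save the plain Tensor views that the forward used (FSDP would reshard here).
        self._saved_views = list(unsharded_views)

    def pre_backward(self):
        # Restore the *same* Tensor variables so a recomputed forward (reentrant AC)
        # reuses views that are already connected to the FlatParameter in autograd.
        for name, view in zip(self.param_names, self._saved_views):
            # nn.Module rejects assigning a plain Tensor over a registered
            # nn.Parameter, so deregister it first (FSDP has its own helpers).
            self.module._parameters.pop(name, None)
            setattr(self.module, name, view)

    def post_backward(self):
        # Drop the references; likely not strictly required, but avoids dangling views.
        self._saved_views = None
```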

An alternative approach I considered was using forward hooks. However, this does not change the order of operations across FSDP, checkpoint, and the wrapped module, so it does not work: as long as the order is `FSDP(checkpoint(module))`, registered hooks still run either before or after the checkpoint recomputation -- we cannot insert logic that runs inside the recomputation.
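For reference, a sketch of the wrapping order in question, assuming a default process group has already been initialized (e.g., launched via torchrun):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# FSDP(checkpoint(module)): the checkpoint wrapper sits inside the FSDP unit, so
# FSDP-registered hooks run either before or after the checkpoint recomputation.
module = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
wrapped = checkpoint_wrapper(module, checkpoint_impl=CheckpointImpl.REENTRANT)
model = FSDP(wrapped, use_orig_params=True)
```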

**Test Plan**
I augmented the existing reentrant checkpointing unit tests to also test `use_orig_params=True`. I also verified that the pycls model does not error (even with the new approach).

@pytorch-bot pytorch-bot bot commented Oct 20, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87413

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 87d791c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Oct 20, 2022
awgu added a commit that referenced this pull request Oct 20, 2022
ghstack-source-id: fffc87227f7726e85633919609c5526176eee796
Pull Request resolved: #87413
@rohan-varma rohan-varma (Member) left a comment

Can we have a unittest for this change?

awgu added a commit to awgu/pytorch that referenced this pull request Oct 24, 2022
ghstack-source-id: fffc87227f7726e85633919609c5526176eee796
Pull Request resolved: pytorch#87413
Without this change, the post-backward hooks do not run when using reentrant activation checkpointing. This is what happens when we play with `.data` fire.

The reason the hooks do not run is exactly why we need plain `Tensor` views in the forward pass normally, without AC: the original parameters must be connected in autograd to the `FlatParameter` via `split()` and `view()` so that their gradients propagate into the `FlatParameter`'s gradient. Reentrant AC runs the 1st forward pass under [`no_grad()`](https://github.com/pytorch/pytorch/blob/7e83f65ad502992a8d75c91eea2cf3de69bb0b7a/torch/utils/checkpoint.py#L106) and relies on the 2nd, recomputed forward pass to make the ops visible to autograd. Therefore, we need to use `Tensor` views in the backward pass as well, not `nn.Parameter` views.

[ghstack-poisoned]
awgu added a commit that referenced this pull request Oct 24, 2022
ghstack-source-id: 76cca85cd08d7d51ab6eb6eac8d82ef8f5328f23
Pull Request resolved: #87413
@awgu awgu requested review from rohan-varma and removed request for mingzhe09088 October 24, 2022 13:19
@awgu awgu added the topic: improvements topic category label Oct 24, 2022
awgu added a commit that referenced this pull request Oct 24, 2022
ghstack-source-id: 6a0c26fae8b2ac18b4396c0caa8865a2690bbc6e
Pull Request resolved: #87413
awgu added a commit that referenced this pull request Oct 24, 2022
ghstack-source-id: 92b0904f308f5ae1838cc623250f62ab0ae851a7
Pull Request resolved: #87413
awgu added a commit to awgu/pytorch that referenced this pull request Oct 24, 2022
ghstack-source-id: 6a0c26fae8b2ac18b4396c0caa8865a2690bbc6e
Pull Request resolved: pytorch#87413
@rohan-varma rohan-varma (Member) left a comment

Thanks for the fix and excellent debugging!

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 24, 2022
@awgu awgu (Contributor, Author) commented Oct 24, 2022

@pytorchbot merge

@pytorchmergebot pytorchmergebot (Collaborator) commented
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 24, 2022
Pull Request resolved: #87413
Approved by: https://github.com/rohan-varma
sgrigory pushed a commit to sgrigory/pytorch that referenced this pull request Oct 28, 2022
Pull Request resolved: pytorch#87413
Approved by: https://github.com/rohan-varma
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
Pull Request resolved: pytorch#87413
Approved by: https://github.com/rohan-varma
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Pull Request resolved: pytorch#87413
Approved by: https://github.com/rohan-varma
@facebook-github-bot facebook-github-bot deleted the gh/awgu/137/head branch June 8, 2023 15:23
Labels
ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · release notes: distributed (fsdp) (release notes category) · topic: improvements (topic category)

3 participants