[FSDP] Fix `use_orig_params=True` + AC #87413

Conversation
Can we have a unittest for this change?
Without this change, the post-backward hooks do not run when using reentrant activation checkpointing. This is what happens when we play with `.data` fire. The reason the hooks do not run is exactly why we normally need plain `Tensor` views in the forward pass without AC: the original parameters must be connected in autograd to the `FlatParameter` via `split()` and `view()` so that gradients propagate through to the `FlatParameter`'s gradient. Reentrant AC runs the 1st forward pass with [`no_grad()`](https://github.com/pytorch/pytorch/blob/7e83f65ad502992a8d75c91eea2cf3de69bb0b7a/torch/utils/checkpoint.py#L106) and relies on the 2nd, recomputed forward pass to be visible to autograd. Therefore, we need to use `Tensor` views in the backward pass as well, not `nn.Parameter` views.
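To make the connectivity concrete, here is a minimal standalone sketch (not FSDP's actual code; `flat_param`, `weight`, and `bias` are illustrative names) showing how gradients computed through plain `Tensor` views accumulate on the flat parameter:

```python
import torch

# Standalone sketch: a "flat parameter" whose plain-Tensor views are used in
# the forward so that autograd routes gradients back to the flat parameter.
flat_param = torch.nn.Parameter(torch.randn(6))

# split() + view() return plain Tensors that stay connected to flat_param
# in the autograd graph -- the connectivity described above.
weight, bias = flat_param.split([4, 2])
weight = weight.view(2, 2)
bias = bias.view(2)

x = torch.randn(3, 2)
loss = (x @ weight + bias).sum()
loss.backward()

# The views are non-leaf tensors and hold no .grad of their own;
# the gradient accumulates on flat_param.
print(flat_param.grad.shape)  # torch.Size([6])
```

If the views were detached from the graph (e.g., by going through `.data`), `flat_param.grad` would stay `None`, which is the failure mode described above.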
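The `no_grad()` first pass is easy to observe directly. A minimal sketch, assuming a PyTorch version where `use_reentrant` is an explicit argument to `checkpoint`:

```python
import torch
from torch.utils.checkpoint import checkpoint

lin = torch.nn.Linear(2, 2)

def fn(x):
    # Under reentrant checkpointing this prints False in the 1st forward
    # (run under no_grad) and True in the recomputed forward during backward.
    print("grad enabled:", torch.is_grad_enabled())
    return lin(x)

x = torch.randn(1, 2, requires_grad=True)
out = checkpoint(fn, x, use_reentrant=True)
out.sum().backward()  # triggers the recomputation; autograd tracks it now
```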
Thanks for the fix and excellent debugging!
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Without this change, the post-backward hooks do not run when using reentrant activation checkpointing.

**Explanation**

FSDP registers the original parameters as plain `Tensor`s in the forward pass so that their ops are tracked by autograd to ensure proper gradient propagation into the `FlatParameter`s. FSDP registers the post-backward hooks in its pre-forward.

For `use_orig_params=True`, FSDP replaces the plain `Tensor`s with the sharded `nn.Parameter`s in the post-forward when resharding. This differs from `use_orig_params=False`, which keeps the plain `Tensor`s registered as attributes, except that their data are freed, meaning that accessing them between forward and backward errors. Before this PR, for `use_orig_params=True`, FSDP simply restored the unsharded original parameter data in the pre-backward to enable correct gradient computation. However, this does not suffice for reentrant activation checkpointing (AC), where the recomputed forward happens after FSDP's pre-backward and the ops in the recomputed forward must be tracked by autograd.

My initial solution was to simply have FSDP restore the original parameters as plain `Tensor`s again in the pre-backward so that they would be tracked by autograd exactly like in the normal forward. However, this does not suffice in general: the `FlatParameter`'s `AccumulateGrad` object may change after the original pre-forward when performing a recomputed forward.

The new approach in this PR follows the `use_orig_params=False` way -- namely, it preserves the plain `Tensor` variables across forward and backward. I achieved this by saving the variables explicitly in the forward and restoring them in the pre-backward. I clear them in the post-backward to avoid dangling references (though I do not think this is strictly necessary).

An alternative approach I considered is using forward hooks. However, this does not change the order of operations across FSDP, checkpoint, and the wrapped module, so it does not work: as long as the order is FSDP(checkpoint(module)), registered hooks still run either before or after the checkpoint recomputation -- we cannot insert logic to run inside the recomputation.

**Test Plan**

I augmented the existing reentrant checkpointing unit tests to also test `use_orig_params=True`. I also verified that the pycls model does not error (even with the new approach).

Pull Request resolved: #87413
Approved by: https://github.com/rohan-varma
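To illustrate the ordering constraint behind the rejected forward-hook alternative, here is a standalone stand-in for the FSDP(checkpoint(module)) structure (the `Outer` class is hypothetical, not FSDP itself):

```python
import torch
from torch.utils.checkpoint import checkpoint

events = []

class Outer(torch.nn.Module):
    """Stand-in for the FSDP wrapper around a checkpointed submodule."""
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(2, 2)

    def forward(self, x):
        events.append("outer pre-forward")   # where FSDP unshards / registers hooks
        out = checkpoint(self.lin, x, use_reentrant=True)
        events.append("outer post-forward")  # where FSDP reshards
        return out

m = Outer()
m(torch.randn(1, 2, requires_grad=True)).sum().backward()

# The recomputed forward ran inside backward(), strictly after both wrapper
# events: the wrapper can act before (pre-backward) or after (post-backward)
# the recomputation, but never within it.
print(events)  # ['outer pre-forward', 'outer post-forward']
```

This is the constraint that motivates saving the plain `Tensor` variables in the forward and restoring them in the pre-backward, rather than trying to hook into the recomputation itself.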
Stack from ghstack:

- #87413 [FSDP] Fix `use_orig_params=True` + AC
- #87308 [FSDP][2/N] Fix grad zero vs. `None` edge case
- #87314 [FSDP][1/N] Update `summon_full_params(with_grads)` `None` gradient