[BE] Make some simplifications to torch.utils.checkpoint logic #101193
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101193
Note: Links to docs will display an error until the doc builds have been completed.
✅ No Failures as of commit 7abeb86. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: 7e301a7d26a0b846857dccd96f68d02f48dc9162 Pull Request resolved: #101193
@@ -477,7 +474,7 @@ def forward(input):
 # - If there are multiple .grad()/.backward() calls, we would perform backward
 # on the recomputed graph even if early-stop is enabled (see the example below)
 #
-# [ Multiple backwards ]
+# [ retain_graph is False ]
Use a simpler example to demonstrate the same issue, since what really matters is that retain_graph=False, not that there are multiple backward calls.
-if holder.handles.get(gid, None) is None:
-    holder.handles[gid] = _Handle()
+assert holder.handles.get(gid, None) is None, "this is a bug, please file an issue"
+holder.handles[gid] = _Handle()
We can make this an assertion because (ignoring any multithreading issues) it is not possible to recompute for the same holder twice in the same backward: after we recompute the first time, we set is_recomputed=True.
torch/utils/checkpoint.py
Outdated
)
assert holder.handles[gid] in frame.recomputed[gid], "this is a bug, please file an issue"
Making this runtime error into an internal assert because we don't ever expect to trigger it.
If we were able to reach that runtime error, one of the two following must hold. To show that the runtime error cannot be reached, we show that both are false:
- the handle was never an entry in frame.recomputed[gid]
- the handle was an entry at some point, but the entry was cleared
(1) is false because, for the handle to be non-None, it must have been added to frame.recomputed[gid] at some point: during the outer pack_hook the handle is None, and we only set holder.handles[gid] to something non-None in the inner pack_hook, where we always subsequently set frame.recomputed[gid][holder.handles[gid]].
(2): if the handle was once an entry and that entry is now gone, then, since frame.recomputed[gid] is a weak key dictionary and we never explicitly remove entries, holder.handles[gid] must no longer be alive. That is false, given that the only time we set a non-None value to it is when it is None, and the only other value we ever assign to it is None.
The reason we once had this runtime error is that we previously created holder.handles[gid] in the outer pack_hook, so if the recompute pass did not recompute the tensor for a handle (e.g. because we saved fewer tensors in the backward pass), we would have errored out here.
Plain "assert" statements are not great; you can make this an explicit AssertionError if you want, though.
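For context, this is a general Python point (the function name below is illustrative, not from the PR): bare `assert` statements are stripped when the interpreter runs with `-O`, while an explicitly raised AssertionError always fires:

```python
def check_handle_unset(handles, gid):
    # Unlike `assert handles.get(gid) is None`, this check
    # survives `python -O`, which removes assert statements.
    if handles.get(gid, None) is not None:
        raise AssertionError("this is a bug, please file an issue")

check_handle_unset({}, 0)  # no handle yet: passes silently
try:
    check_handle_unset({0: object()}, 0)
except AssertionError as exc:
    print("raised:", exc)  # prints: raised: this is a bug, please file an issue
```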
The cleanup sounds good!
…ogic" [ghstack-poisoned]
ghstack-source-id: f3ecd097f12bbbcc641d9640407ad6a853dcff6e Pull Request resolved: #101193
…ogic" [ghstack-poisoned]
ghstack-source-id: 79674e085c7a52fd800746e430e8282566650927 Pull Request resolved: #101193
SGTM
…ogic" [ghstack-poisoned]
ghstack-source-id: 900a3ef6f08adf694d2ed91524c54f53fc966d05 Pull Request resolved: #101193
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.