Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BE] Make some simplifications to torch.utils.checkpoint logic #101193

Closed
wants to merge 4 commits into from

Conversation

@pytorch-bot
Copy link

pytorch-bot bot commented May 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101193

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7abeb86:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

soulitzer added a commit that referenced this pull request May 11, 2023
ghstack-source-id: 7e301a7d26a0b846857dccd96f68d02f48dc9162
Pull Request resolved: #101193
@@ -477,7 +474,7 @@ def forward(input):
# - If there are multiple .grad()/.backward() calls, we would perform backward
# on the recomputed graph even if early-stop is enabled (see the example below)
#
# [ Multiple backwards ]
# [ retain_graph is False ]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use a simpler example to demonstrate the same issue since what really matters is that retain_graph=False, not that we have multiple backwards

if holder.handles.get(gid, None) is None:
holder.handles[gid] = _Handle()
assert holder.handles.get(gid, None) is None, "this is a bug, please file an issue"
holder.handles[gid] = _Handle()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make this assertion because (ignoring any multithreading issues), it is not possible to compute for the same holder twice in the same backward, since after we recompute for the first time, we would've set is_recomputed=True.

)
assert holder.handles[gid] in frame.recomputed[gid], "this is a bug, please file an issue"
Copy link
Contributor Author

@soulitzer soulitzer May 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making this runtime error into an internal assert because we don't ever expect to trigger that runtime error.

Suppose if we were able to reach that runtime error, one of the two following must hold. In order to show taht this runtime error will not be reached, we will go on to show that both are false:

  1. the handle was never an entry to frame.recomputed[gid]
  2. the handle was an entry at some point but the entry was cleared

(1) is false because for handle to not be None, it must've been added to frame.recompute[gid] at some point. During the outer pack_hook, the handle is None. We only set holder.handles[gid] to something non-None in the inner pack_hook where we always subsequently set recomputed[holder.handles[gid]
(2) if handle was once an entry, and then that entry is now gone - since frame.recompute[gid] is a weak key dictionary and we never explicitly remove any entries, this implies that holder.handles[gid] is no longer alive, which is false, given that the only time we set a non-None value to it is if it is None, and the only value we set to it is None.

The reason we once had this runtime error was because we previously create holder.handles[gid] in the outer pack_hook, so if recompute pass does not recompute handle, e.g. because we saved fewer tensors in the backward pass, we would've errored out here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pain "assert" are not great, you can make this an AssertionError if you want though.

@soulitzer soulitzer requested a review from albanD May 11, 2023 15:35
@soulitzer soulitzer added the topic: not user facing topic category label May 11, 2023
@pytorch pytorch deleted a comment from github-actions bot May 11, 2023
@soulitzer soulitzer changed the title Make some simplifications to torch.utils.checkpoint logic [BE] Make some simplifications to torch.utils.checkpoint logic May 11, 2023
Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cleanup sounds good!

torch/utils/checkpoint.py Outdated Show resolved Hide resolved
)
assert holder.handles[gid] in frame.recomputed[gid], "this is a bug, please file an issue"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pain "assert" are not great, you can make this an AssertionError if you want though.

@soulitzer soulitzer requested a review from albanD May 11, 2023 21:40
soulitzer added a commit that referenced this pull request May 11, 2023
ghstack-source-id: f3ecd097f12bbbcc641d9640407ad6a853dcff6e
Pull Request resolved: #101193
soulitzer added a commit that referenced this pull request May 11, 2023
ghstack-source-id: 79674e085c7a52fd800746e430e8282566650927
Pull Request resolved: #101193
Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM

@soulitzer soulitzer added the ciflow/trunk Trigger trunk jobs on your pull request label May 11, 2023
soulitzer added a commit that referenced this pull request May 11, 2023
ghstack-source-id: 900a3ef6f08adf694d2ed91524c54f53fc966d05
Pull Request resolved: #101193
@soulitzer
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged topic: not user facing topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants