Option to preserve bitwise accuracy of gradient checkpointed vs non-checkpointed dropout #14253
Conversation
torch/utils/checkpoint.py
Outdated
# rng states for those devices as well...but I see no easy way to
# handle such cases.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
torch/utils/checkpoint.py
Outdated
each checkpointed segment during backward. This can result in running
states like the RNG state used for dropout to be advanced more than
they would be without checkpointing, which can cause checkpoints that
include dropout invocations to lose end-to-end bitwise accuracy as
lgtm, let's see if tests pass
torch/utils/checkpoint.py
Outdated
torch.set_rng_state(ctx.fwd_cpu_rng_state)
if ctx.had_cuda_in_fwd:
    current_cuda_rng_state = torch.cuda.get_rng_state()
    torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
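For context, here is a minimal standalone sketch of the stash/restore pattern these hunks implement: capture the RNG state at forward time, rewind to it around the backward-time recomputation, and put the live state back afterwards. The helper name and signature below are hypothetical illustrations, not the PR's code.

```python
import torch

def rerun_with_forward_rng(run_fn, inputs, fwd_cpu_rng_state, fwd_cuda_rng_state=None):
    # Back up the RNG state that is live right now so we can put it back afterwards.
    cpu_rng_backup = torch.get_rng_state()
    cuda_rng_backup = torch.cuda.get_rng_state() if fwd_cuda_rng_state is not None else None
    try:
        # Rewind to the state captured during the original forward pass, so
        # dropout draws the same mask during the recomputation.
        torch.set_rng_state(fwd_cpu_rng_state)
        if fwd_cuda_rng_state is not None:
            torch.cuda.set_rng_state(fwd_cuda_rng_state)
        return run_fn(*inputs)
    finally:
        # Restore the live state so the rest of the program is unaffected.
        torch.set_rng_state(cpu_rng_backup)
        if cuda_rng_backup is not None:
            torch.cuda.set_rng_state(cuda_rng_backup)
```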
torch/utils/checkpoint.py
Outdated
- def checkpoint(function, *args):
+ def checkpoint(function, *args, preserve_rng_state=False):
@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
I realized there are some pretty fundamental design flaws in what you're doing (we might not be caching the right RNG state). Also, please don't use global state unless it's necessary...
ctx.had_cuda_in_fwd = False
if torch.cuda._initialized:
    ctx.had_cuda_in_fwd = True
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
# Global switch to toggle whether or not checkpointed passes stash and restore
# the RNG state. If True, any checkpoints making use of RNG should achieve deterministic
# output compared to non-checkpointed passes.
preserve_rng_state = True
@apaszke I know we might not be preserving the right RNG states. I left some comments to that effect in the code. It's challenging because A. we can't know what devices will be used in the forward pass, and B. … To deal with A. I believe we have two choices: …

I have no idea how to deal with B., but B. seems like more of an edge case. The global state was necessary (imo) because I couldn't figure out a way to mix *args with kwargs in a way that was both Python 2.7-compliant and would not break existing code. I was prepared to experiment with other approaches, but it appears you can live with the global flag.
Some commentary: I think a better compromise would be to not use the current device to guide the heuristic, but any tensor input. That would be a reasonable solution I think, but it would still warrant some commentary in the docs.

Regarding your second issue, you should add:

def my_fn(arg1, *args, **kwargs):
    my_kwarg = kwargs.pop('my_kwarg', <default value>)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
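To make that suggestion concrete, here is a sketch of how the kwargs.pop pattern could be applied to checkpoint itself. This is an illustration only: it assumes it lives inside torch/utils/checkpoint.py, where a CheckpointFunction autograd Function is already defined, and the merged signature may differ.

```python
def checkpoint(function, *args, **kwargs):
    # Pop the optional flag so the signature stays legal under Python 2.7,
    # which rejects a keyword argument written after *args.
    preserve_rng_state = kwargs.pop('preserve_rng_state', False)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))
    return CheckpointFunction.apply(function, preserve_rng_state, *args)
```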
…4518)

Summary: This PR intends to address apaszke's concerns in #14253 (comment). Preserving the RNG state is now controlled by a kwarg rather than a global state, hopefully in a Python 2.7-compatible way.

Additionally, the checkpointing function stashes and restores the RNG states of 1. devices associated with all input tensor args to run_fn, as well as 2. the current device. I could easily change this to only save and restore the RNG states associated with 1. alone. This would simplify the logic to create a [deduplicated, ordered](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R37) list of devices considered active.

I'm wondering if the [get_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R32) and [set_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) functions are general enough to reside elsewhere (presumably torch/random.py). I'm also wondering if the check on [torch.cuda._initialized](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) would be better placed within `get_device_states`.

Pull Request resolved: #14518
Differential Revision: D13356210
Pulled By: ezyang
fbshipit-source-id: afa4cc21ce7862142d5cb1dec3750018df222039
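For readers following along, here is a rough sketch of what per-device helpers like get_device_states and set_device_states could look like, based on the description above. It is an illustration, not the merged implementation, which may differ in detail.

```python
import torch

def get_device_states(*args):
    # Collect the distinct CUDA devices that appear among the tensor inputs,
    # then snapshot each device's RNG state.
    fwd_gpu_devices = sorted({arg.get_device() for arg in args
                              if torch.is_tensor(arg) and arg.is_cuda})
    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())
    return fwd_gpu_devices, fwd_gpu_states

def set_device_states(devices, states):
    # Restore each saved per-device RNG state before the recomputation runs.
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)
```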
This issue was noticed, and a fix proposed, by @raulpuric.
Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can result in the RNG state advancing more than it would without checkpointing, which can cause checkpoints that include dropout invocations to lose end-to-end bitwise accuracy as compared to non-checkpointed passes.
The present PR contains optional logic to juggle the RNG states such that checkpointed passes containing dropout achieve bitwise accuracy with non-checkpointed equivalents.** The user requests this behavior by supplying preserve_rng_state=True to torch.utils.checkpoint or torch.utils.checkpoint_sequential.

Currently, preserve_rng_state=True may incur a moderate performance hit because restoring MTGP states can be expensive. However, restoring Philox states is dirt cheap, so @syed-ahmed's RNG refactor, once merged, will make this option more or less free.

I'm a little wary of the def checkpoint(function, *args, preserve_rng_state=False): argument-passing method (specifically, putting a kwarg after a variable argument list). Python 3 seems happy with it.

Edit: It appears Python 2.7 is NOT happy with a kwarg after *args. preserve_rng_state also needs to be communicated in a way that doesn't break any existing usage. I'm open to suggestions (a global flag perhaps)?

**Batchnorm may still be an issue, but that's a battle for another day.
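As a usage illustration, a small script along these lines should show the effect. This is a sketch assuming the preserve_rng_state kwarg described above is available; exact defaults and warnings depend on the PyTorch version.

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(p=0.5))
x = torch.randn(4, 8)

def grads(preserve):
    # Run forward/backward from a fixed seed, checkpointing the dropout segment.
    torch.manual_seed(123)
    inp = x.clone().requires_grad_()
    checkpoint(layer, inp, preserve_rng_state=preserve).sum().backward()
    return inp.grad

# Non-checkpointed reference from the same seed.
torch.manual_seed(123)
ref = x.clone().requires_grad_()
layer(ref).sum().backward()

# With RNG preservation the recomputed dropout mask matches the forward one,
# so gradients should be bitwise identical to the reference; without it, the
# recomputation draws a fresh mask and the gradients typically differ.
print(torch.equal(ref.grad, grads(True)))   # expected: True
print(torch.equal(ref.grad, grads(False)))  # typically: False
```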