Option to preserve bitwise accuracy of gradient checkpointed vs non-checkpointed dropout #14253


Closed

Conversation

@mcarilli (Collaborator) commented Nov 20, 2018

This issue was noticed, and fix proposed, by @raulpuric.

Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can result in the RNG state advancing more than it would without checkpointing, which can cause checkpoints that include dropout invocations to lose end-to-end bitwise accuracy as compared to non-checkpointed passes.

The present PR contains optional logic to juggle the RNG states such that checkpointed passes containing dropout achieve bitwise accuracy with non-checkpointed equivalents.** The user requests this behavior by supplying preserve_rng_state=True to torch.utils.checkpoint or torch.utils.checkpoint_sequential.

Currently, preserve_rng_state=True may incur a moderate performance hit because restoring MTGP states can be expensive. However, restoring Philox states is dirt cheap, so @syed-ahmed's RNG refactor, once merged, will make this option more or less free.

I'm a little wary of the def checkpoint(function, *args, preserve_rng_state=False): argument-passing method (specifically, putting a kwarg after a variable argument list). Python 3 seems happy with it.
Edit: It appears Python 2.7 is NOT happy with a kwarg after *args. preserve_rng_state also needs to be communicated in a way that doesn't break any existing usage. I'm open to suggestions (a global flag perhaps)?

**Batchnorm may still be an issue, but that's a battle for another day.
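The stash-and-restore idea is easy to demonstrate outside torch. Below is a minimal sketch using Python's built-in random module as a stand-in for the CPU/CUDA RNG, with run_segment as a hypothetical forward segment containing dropout; getstate/setstate play the role of torch.get_rng_state/torch.set_rng_state:

```python
import random

def run_segment():
    # Hypothetical forward segment containing dropout: it consumes RNG
    # state and produces output that depends on the random draws.
    return [random.random() for _ in range(3)]

# Non-checkpointed reference: a single forward pass.
random.seed(0)
reference = run_segment()

# Checkpointed pass: stash the RNG state before the forward, then restore
# it before the recomputation in backward, so the rerun draws the exact
# same random numbers.
random.seed(0)
fwd_rng_state = random.getstate()   # analogous to ctx.fwd_cpu_rng_state
first = run_segment()               # original forward
random.setstate(fwd_rng_state)      # restore before the backward rerun
recomputed = run_segment()          # recomputation during backward

assert first == reference           # forward unchanged by the stash
assert recomputed == reference      # rerun is bitwise identical
```

Without the setstate call, the rerun would continue from the advanced RNG state and diverge from the reference, which is exactly the bitwise mismatch described above.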

# rng states for those devices as well...but I see no easy way to
# handle such cases.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()


each checkpointed segment during backward. This can result in running
states like the RNG state used for dropout to be advanced more than
they would be without checkpointing, which can cause checkpoints that
include dropout invocations to lose end-to-end bitwise accuracy as


@soumith (Member) left a comment


lgtm, let's see if tests pass

torch.set_rng_state(ctx.fwd_cpu_rng_state)
if ctx.had_cuda_in_fwd:
current_cuda_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)




- def checkpoint(function, *args):
+ def checkpoint(function, *args, preserve_rng_state=False):


@facebook-github-bot (Contributor) left a comment


@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@apaszke (Contributor) left a comment


I realized there are some pretty fundamental design flaws in what you're doing (we might not be caching the right RNG state). Also, please don't use global state unless it's necessary...

ctx.had_cuda_in_fwd = False
if torch.cuda._initialized:
ctx.had_cuda_in_fwd = True
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()


# Global switch to toggle whether or not checkpointed passes stash and restore
# the RNG state. If True, any checkpoints making use of RNG should achieve deterministic
# output compared to non-checkpointed passes.
preserve_rng_state = True


@mcarilli (Collaborator, Author)

@apaszke I know we might not be preserving the right RNG states. I left some comments to that effect in the code. It's challenging because A. we can't know what devices will be used in the run_fn, and B. we also can't know if the user intends to initialize the cuda context on one or more devices within their run_fn.

To deal with A. I believe we have two choices:

  1. Stash and restore the RNG states for ALL devices visible to this process. Probably overly conservative/paranoid for the majority of cases.
  2. Stash and restore the RNG states for the current device.

I went with option 2 because it seemed like a reasonable compromise. I can add a warning to the docs that if the run_fn moves operations to other devices, bitwise accuracy is no longer guaranteed.

I have no idea how to deal with B., but B. seems like more of an edge case.

The global state was necessary (imo) because I couldn't figure out a way to mix *args with kwargs in a way that was both python 2.7-compliant and would not break existing code. I was prepared to experiment with other approaches but it appears you can live with the global flag. torch.backends.cudnn.deterministic is also exposed as a global flag so this seemed symmetric to me.

@apaszke (Contributor) commented Nov 26, 2018

Some commentary:

  1. Yes, I agree this is definitely too much, especially considering the overhead it adds.
  2. I understand, but this really is a bad compromise. Most PyTorch programs hardly ever change the current device; they simply specify device=... or use .to(...) to transfer some inputs to a different one, and then do operations on those tensors. This means that if someone uses more than a single device, you'll have a miss rate of almost 100%.

I think a better compromise would be to not use the current device to guide the heuristic, but any tensor input. That would be a reasonable solution I think, but it would still warrant some commentary in the docs.


Regarding your second issue, you should add **kwargs to the functions you're interested in, and then do something like this:

def my_fn(arg1, *args, **kwargs):
    # Python 2.7 rejects keyword-only arguments after *args, so pop the
    # known keywords (with a default) and reject whatever is left over.
    my_kwarg = kwargs.pop('my_kwarg', None)  # None stands in for the default
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))
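A complete runnable version of that pattern, applied to the checkpoint signature, might look like the following; the body is a dummy stand-in for the real checkpointing logic, and the True default is chosen here for illustration:

```python
def checkpoint(function, *args, **kwargs):
    # Python 2.7-compatible keyword-only option: pop the known kwarg and
    # reject anything else, since "*args, preserve_rng_state=False" is a
    # syntax error on 2.7.
    preserve_rng_state = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))
    # Dummy body standing in for the real checkpointing logic.
    return function(*args), preserve_rng_state

out, flag = checkpoint(lambda x, y: x + y, 1, 2)
assert out == 3 and flag is True

out, flag = checkpoint(lambda x: x, 5, preserve_rng_state=False)
assert out == 5 and flag is False
```

Existing positional-only call sites keep working unchanged, while a misspelled keyword fails loudly instead of being silently swallowed by **kwargs.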

facebook-github-bot pushed a commit that referenced this pull request Dec 11, 2018
…4518)

Summary:
This PR intends to address apaszke's concerns in #14253 (comment).  Preserving the rng state is now controlled by a kwarg rather than a global state, hopefully in a python 2.7-compatible way.

Additionally, the checkpointing function stashes and restores the RNG states of
1. devices associated with all input tensor args to run_fn as well as
2. the current device.

I could easily change this to only save and restore the RNG states associated with 1. alone.  This would simplify the logic to create a [deduplicated, ordered](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R37) list of devices considered active.

I'm wondering if the [get_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R32) and [set_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) functions are general enough to reside elsewhere (presumably torch/random.py).  I'm also wondering if the check on [torch.cuda._initialized](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) would be better placed within `get_device_states`.
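The deduplicated, ordered device list can be sketched in plain Python; get_devices_of and FakeTensor below are illustrative stand-ins (not the actual get_device_states helper), with integer device indices and None modelling CPU tensors:

```python
def get_devices_of(args):
    # Collect a deduplicated list of device indices from tensor-like
    # inputs, preserving first-seen order (a set would lose the ordering).
    seen = []
    for arg in args:
        dev = getattr(arg, 'device', None)
        if dev is not None and dev not in seen:
            seen.append(dev)
    return seen

class FakeTensor(object):
    # Illustrative stand-in for a tensor carrying a device index.
    def __init__(self, device):
        self.device = device

inputs = [FakeTensor(0), FakeTensor(1), FakeTensor(0), "not a tensor"]
assert get_devices_of(inputs) == [0, 1]  # deduplicated, order preserved
```

Non-tensor arguments and CPU tensors simply contribute nothing to the list, so only the CUDA devices actually touched by the inputs have their RNG states stashed.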
Pull Request resolved: #14518

Differential Revision: D13356210

Pulled By: ezyang

fbshipit-source-id: afa4cc21ce7862142d5cb1dec3750018df222039