Expose API to specify custom context manager for checkpoint
[ghstack-poisoned]
soulitzer committed Mar 14, 2023
1 parent fd5b2d2 commit f80e657
Showing 2 changed files with 68 additions and 6 deletions.
test/test_autograd.py (33 additions, 0 deletions)
@@ -5608,6 +5608,39 @@ def foo(x, y, z):
         out = checkpoint(foo, x, y, z, use_reentrant=False)
         out.sum().backward()

+    def test_checkpointing_without_reentrant_with_context_fn(self):
+        class VerboseTorchDispatchMode(TorchDispatchMode):
+            def __init__(self):
+                self.operators = []
+
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+                self.operators.append(func.__name__)
+                return func(*args, **kwargs)
+
+        x = torch.tensor(1., requires_grad=True)
+        verbose_mode = VerboseTorchDispatchMode()
+
+        def context_fn():
+            return verbose_mode, contextlib.nullcontext()
+        out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
+        self.assertEqual(verbose_mode.operators, ['sin.default'])
+
+        verbose_mode.operators = []
+
+        def context_fn():
+            return contextlib.nullcontext(), verbose_mode
+        out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
+        out.backward()
+        self.assertEqual(
+            verbose_mode.operators,
+            ['detach.default', 'detach.default', 'detach.default', 'detach.default', 'sin.default']
+        )
+
+        with self.assertRaisesRegex(Exception, "only supported when use_reentrant=False"):
+            out = checkpoint(lambda x: x.sin(), x, use_reentrant=True, context_fn=context_fn)
+
     def test_access_saved_tensor_twice_without_recomputation_works(self):
         count = [0]

torch/utils/checkpoint.py (35 additions, 6 deletions)
@@ -2,15 +2,15 @@
 import warnings
 import weakref
 from weakref import ReferenceType
-from typing import Any, Iterable, List, Tuple, Dict, Optional, DefaultDict
+from typing import Any, Callable, ContextManager, Iterable, List, Tuple, Dict, Optional, DefaultDict
 from collections import defaultdict
 import uuid
 import contextlib

 __all__ = [
     "checkpoint", "checkpoint_sequential", "CheckpointFunction",
     "check_backward_validity", "detach_variable", "get_device_states",
-    "set_device_states",
+    "set_device_states", "noop_context_fn"
 ]

 def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
@@ -165,7 +165,17 @@ def backward(ctx, *args):
         return (None, None) + grads


-def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
+def noop_context_fn():
+    return contextlib.nullcontext(), contextlib.nullcontext()
+
+
+def checkpoint(
+    function,
+    *args,
+    use_reentrant: bool = True,
+    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
+    **kwargs
+):
     r"""Checkpoint a model or part of the model

     Checkpointing works by trading compute for memory. Rather than storing all
@@ -239,6 +249,10 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
             keyword arguments input into the checkpointed function. Note that future
             versions of PyTorch will default to ``use_reentrant=False``.
             Default: ``True``
+        context_fn(Callable, optional): A callable returning a tuple of two
+            context managers. The function and its recomputation will be run
+            under the first and second context managers respectively.
+            This argument is only supported if ``use_reentrant=False``.
         args: tuple containing inputs to the :attr:`function`

     Returns:
@@ -250,11 +264,14 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
         raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

     if use_reentrant:
+        if context_fn is not noop_context_fn:
+            raise ValueError("Passing context_fn is only supported when use_reentrant=False.")
         return CheckpointFunction.apply(function, preserve, *args)
     else:
         return _checkpoint_without_reentrant(
             function,
             preserve,
+            context_fn,
             *args,
             **kwargs,
         )
@@ -626,7 +643,13 @@ def unpack_hook(holder):

 # NB: this helper wraps fn before calling checkpoint_impl. kwargs and
 # saving/restoring of global state is handled here.
-def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
+def _checkpoint_without_reentrant(
+    fn,
+    preserve_rng_state=True,
+    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
+    *args,
+    **kwargs
+):
     """Checkpointing without re-entrant autograd
     Args:
         function: describes what to run in the forward pass of the model or
@@ -637,9 +660,13 @@ def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
         preserve_rng_state(bool, optional): Omit stashing and restoring
             the RNG state during each checkpoint.
             Default: ``True``
+        context_fn(Callable, optional): A callable returning a tuple of two
+            context managers. The function and its recomputation will be run
+            under the first and second context managers respectively.
         *args: Arguments to pass in to the given ``function``.
         **kwargs: Keyword arguments to pass into the given ``function``.
     """
+    forward_context, recompute_context = context_fn()
     # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
     gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

@@ -669,7 +696,8 @@ def recompute_fn(*inputs):
             set_device_states(fwd_gpu_devices, fwd_gpu_states)

         with torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
-             torch.cpu.amp.autocast(**cpu_autocast_kwargs):
+             torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
+             recompute_context:
             fn(*args, **kwargs)

     new_frame = _CheckpointFrame(recompute_fn)
@@ -680,7 +708,8 @@ def recompute_fn(*inputs):
     if new_frame.input_saver.grad_fn is None:
         return fn(*args, **kwargs)

-    with _checkpoint_hook(new_frame):
+    with _checkpoint_hook(new_frame), \
+            forward_context:
        ret = fn(*args, **kwargs)

     if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
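For reference, a minimal usage sketch of the new `context_fn` argument, outside the diff itself. It mirrors what the new test exercises: the first context manager wraps the original forward, the second wraps the recomputation during backward. The class name `OpLogger`, the tensor shape, and the lambda are illustrative only and are not part of this commit.

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.checkpoint import checkpoint

class OpLogger(TorchDispatchMode):
    # Records the name of every op dispatched while the mode is active.
    def __init__(self):
        self.ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(func.__name__)
        return func(*args, **(kwargs or {}))

forward_log, recompute_log = OpLogger(), OpLogger()

def context_fn():
    # First manager wraps the original forward; second wraps the recomputation.
    return forward_log, recompute_log

x = torch.randn(8, requires_grad=True)
out = checkpoint(lambda t: t.sin().cos(), x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()        # recomputation runs here, under recompute_log
print(forward_log.ops)      # ops seen during the original forward
print(recompute_log.ops)    # ops seen again during recomputation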
