Expose API to specify custom context manager for checkpoint #96783

Closed · wants to merge 1 commit
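For context, here is a minimal usage sketch of the API this diff adds. It is illustrative only: it assumes a PyTorch build that already includes this change, and the lambda and tensor are placeholders. `context_fn` returns a pair of context managers, the first wrapping the original forward pass and the second wrapping the recomputation, and it is only accepted together with `use_reentrant=False`.

```python
# Usage sketch for the new `context_fn` argument (assumes this PR is applied).
import contextlib
import torch
from torch.utils.checkpoint import checkpoint

def context_fn():
    # First context manager wraps the original forward pass,
    # the second wraps the recomputation triggered during backward.
    return contextlib.nullcontext(), contextlib.nullcontext()

x = torch.randn(4, requires_grad=True)
out = checkpoint(lambda t: t.sin(), x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()
```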
33 changes: 33 additions & 0 deletions test/test_autograd.py
@@ -5608,6 +5608,39 @@ def foo(x, y, z):
        out = checkpoint(foo, x, y, z, use_reentrant=False)
        out.sum().backward()

    def test_checkpointing_without_reentrant_with_context_fn(self):
        class VerboseTorchDispatchMode(TorchDispatchMode):
            def __init__(self):
                self.operators = []

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                self.operators.append(func.__name__)
                return func(*args, **kwargs)

        x = torch.tensor(1., requires_grad=True)
        verbose_mode = VerboseTorchDispatchMode()

        def context_fn():
            return verbose_mode, contextlib.nullcontext()
        out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
        self.assertEqual(verbose_mode.operators, ['sin.default'])

        verbose_mode.operators = []

        def context_fn():
            return contextlib.nullcontext(), verbose_mode
        out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
        out.backward()
        self.assertEqual(
            verbose_mode.operators,
            ['detach.default', 'detach.default', 'detach.default', 'detach.default', 'sin.default']
        )

        with self.assertRaisesRegex(Exception, "only supported when use_reentrant=False"):
            out = checkpoint(lambda x: x.sin(), x, use_reentrant=True, context_fn=context_fn)

    def test_access_saved_tensor_twice_without_recomputation_works(self):
        count = [0]

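The new test above exercises both slots of the returned tuple: a `TorchDispatchMode` passed as the first context manager sees only the original forward, while one passed as the second sees only the recomputation during backward. A standalone sketch of that second case (assumes this PR is applied; `LoggingMode` is an illustrative name, and `TorchDispatchMode` is imported from `torch.utils._python_dispatch`, as in the test file):

```python
# Sketch: record which ops run during recomputation by passing a dispatch
# mode as the *second* context manager. Assumes this PR is applied.
import contextlib
import torch
from torch.utils.checkpoint import checkpoint
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.operators = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.operators.append(func.__name__)
        return func(*args, **(kwargs or {}))

mode = LoggingMode()
x = torch.tensor(1., requires_grad=True)
out = checkpoint(lambda t: t.sin(), x, use_reentrant=False,
                 context_fn=lambda: (contextlib.nullcontext(), mode))
out.backward()
print(mode.operators)  # recomputed ops, e.g. several 'detach.default' plus 'sin.default'
```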
41 changes: 35 additions & 6 deletions torch/utils/checkpoint.py
@@ -2,15 +2,15 @@
import warnings
import weakref
from weakref import ReferenceType
from typing import Any, Iterable, List, Tuple, Dict, Optional, DefaultDict
from typing import Any, Callable, ContextManager, Iterable, List, Tuple, Dict, Optional, DefaultDict
from collections import defaultdict
import uuid
import contextlib

__all__ = [
    "checkpoint", "checkpoint_sequential", "CheckpointFunction",
    "check_backward_validity", "detach_variable", "get_device_states",
    "set_device_states",
    "set_device_states", "noop_context_fn"
]

def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
@@ -165,7 +165,17 @@ def backward(ctx, *args):
        return (None, None) + grads


def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
def noop_context_fn():
    return contextlib.nullcontext(), contextlib.nullcontext()


def checkpoint(
    function,
    *args,
    use_reentrant: bool = True,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    **kwargs
):
r"""Checkpoint a model or part of the model

Checkpointing works by trading compute for memory. Rather than storing all
@@ -239,6 +249,10 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
            keyword arguments input into the checkpointed function. Note that future
            versions of PyTorch will default to ``use_reentrant=False``.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
            This argument is only supported if ``use_reentrant=False``.
        args: tuple containing inputs to the :attr:`function`

    Returns:
@@ -250,11 +264,14 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

if use_reentrant:
if context_fn is not noop_context_fn:
raise ValueError("Passing context_fn is only supported when use_reentrant=False.")
return CheckpointFunction.apply(function, preserve, *args)
else:
return _checkpoint_without_reentrant(
function,
preserve,
context_fn,
*args,
**kwargs,
)
@@ -626,7 +643,13 @@ def unpack_hook(holder):

# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
# saving/restoring of global state is handled here.
def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
def _checkpoint_without_reentrant(
    fn,
    preserve_rng_state=True,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    *args,
    **kwargs
):
"""Checkpointining without re-entrant autograd
Args:
function: describes what to run in the forward pass of the model or
@@ -637,9 +660,13 @@ def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
        preserve_rng_state(bool, optional): Omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
    forward_context, recompute_context = context_fn()
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

@@ -669,7 +696,8 @@ def recompute_fn(*inputs):
                    set_device_states(fwd_gpu_devices, fwd_gpu_states)

            with torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**cpu_autocast_kwargs):
                 torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                 recompute_context:
                fn(*args, **kwargs)

    new_frame = _CheckpointFrame(recompute_fn)
@@ -680,7 +708,8 @@ def recompute_fn(*inputs):
    if new_frame.input_saver.grad_fn is None:
        return fn(*args, **kwargs)

    with _checkpoint_hook(new_frame):
    with _checkpoint_hook(new_frame), \
         forward_context:
        ret = fn(*args, **kwargs)

    if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
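To summarize the control flow in the diff above: `forward_context` wraps the first run of `fn` inside `_checkpoint_without_reentrant`, and `recompute_context` wraps the rerun performed by `recompute_fn` when saved tensors are unpacked during backward. A schematic sketch of that shape (illustrative only; `run_checkpointed` is a made-up helper, not PyTorch code):

```python
# Schematic of where the two context managers from `context_fn` apply.
# Illustrative only; this is not how PyTorch stores or replays activations.
import contextlib
from typing import Callable, ContextManager, Tuple

def run_checkpointed(
    fn: Callable[[], object],
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] =
        lambda: (contextlib.nullcontext(), contextlib.nullcontext()),
):
    forward_context, recompute_context = context_fn()

    with forward_context:
        result = fn()            # original forward pass

    def recompute():
        with recompute_context:  # rerun, as recompute_fn does during backward
            return fn()

    return result, recompute

result, recompute = run_checkpointed(lambda: 1 + 1)
assert result == recompute() == 2
```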