[WIP] Improve debuggability of activation checkpointing
ghstack-source-id: bc0782e3977d0844f71ffabaa2002c370bd6f3f6
Pull Request resolved: #102241
soulitzer committed May 25, 2023
1 parent 2e2a746 commit 07647a8
Showing 1 changed file with 252 additions and 14 deletions.
266 changes: 252 additions & 14 deletions torch/utils/checkpoint.py
@@ -15,6 +15,7 @@
Tuple,
)
from weakref import ReferenceType
from torch.testing._internal.logging_tensor import LoggingTensorMode

import torch

@@ -290,6 +291,8 @@ def checkpoint(
*args,
use_reentrant: Optional[bool] = None,
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
determinism_check: str = "default",
debug: bool = False,
**kwargs
):
r"""Checkpoint a model or part of the model
@@ -382,6 +385,18 @@ def checkpoint(
context managers. The function and its recomputation will be run
under the first and second context managers respectively.
This argument is only supported if ``use_reentrant=False``.
determinism_check(str, optional): A string specifying the determinism
check to perform. By default it is set to ``"default"``, which
compares the shapes, dtypes, and devices of the recomputed tensors
against those of the saved tensors. To turn off this check, specify
``"none"``. Currently these are the only two supported values.
Please open an issue if you would like to see more determinism
checks. This argument is only supported if ``use_reentrant=False``;
if ``use_reentrant=True``, the determinism check is always disabled.
debug(bool, optional): If ``True``, error messages will also include
a trace of the operators run during the original forward computation
as well as the recomputation. This argument is only supported if
``use_reentrant=False``.
args: tuple containing inputs to the :attr:`function`
Returns:
@@ -405,16 +420,19 @@
)

if use_reentrant:
if context_fn is not noop_context_fn:
if context_fn is not noop_context_fn or debug is not False:
raise ValueError(
"Passing context_fn is only supported when use_reentrant=False."
"Passing `context_fn` or `debug` is only supported when "
"use_reentrant=False."
)
return CheckpointFunction.apply(function, preserve, *args)
else:
return _checkpoint_without_reentrant(
function,
preserve,
context_fn,
determinism_check,
debug,
*args,
**kwargs,
)
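For orientation, a minimal usage sketch of the two new arguments; the module, shapes, and names below are illustrative, not taken from this PR:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)  # any module or callable works here
inp = torch.randn(2, 16, requires_grad=True)

out = checkpoint(
    layer,
    inp,
    use_reentrant=False,          # both new arguments require this
    determinism_check="default",  # compare shape/dtype/device on recompute
    debug=True,                   # dump operator traces if a mismatch is found
)
out.sum().backward()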
@@ -717,7 +735,7 @@ def backward(ctx, *grad_outputs):


class _CheckpointFrame:
def __init__(self, recompute_fn, early_stop):
def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn):
self.recompute_fn = recompute_fn
self.input_saver = None
self.weak_holders: List[ReferenceType] = []
@@ -735,6 +753,171 @@ def __init__(self, recompute_fn, early_stop):
# See Rule 5
self.early_stop = early_stop

# Debugging
self.metadata_fn = metadata_fn
self.unpack_error_cb = unpack_error_cb
self.x_metadatas = []
self.forward_completed = False

def check_recomputed_tensors_match(self, gid):
# NOTE [ Error handling for checkpoint ]
#
# At a high level, we need to check that the tensors saved
# during the original forward match the tensors saved during
# recompute. This means handling 3 cases:
#
# 1. During recompute, more tensors were saved.
#
# Usually this case is hidden: with early stop enabled, the
# _StopRecomputationError fires before any extra tensors are
# saved. If early stop is not enabled, we detect this case in
# the pack hook once we run out of weak_holders and raise a
# nice error there. See the _recomputation_hook for details.

if len(self.weak_holders) != self.recomp_counter[gid]:
# 2. During recompute, fewer tensors were saved
#
# We know that every time we save a tensor during the original
# forward we append to weak_holders, and every time we save a
# tensor during recompute we increment recomp_counter.
raise CheckpointError(
"torch.utils.checkpoint: A different number of tensors was saved "
"during the original forward and recomputation.\n"
f"Number of tensors saved during forward: {len(self.weak_holders)}\n"
f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}"
)

# 3. During recompute, the same tensors were saved, but they
# have different metadata
nb_meta_different = []
for idx, weak_holder in enumerate(self.weak_holders):
holder = weak_holder()
if holder is None:
continue
# We've seen all holders since we iterate over them in order
# For every holder that is still alive now, it must've been
# alive when we saw it during recompute, therefore, the
# gid must be set.
_internal_assert(gid in holder.handles)
# We know this is the first unpack, so it couldn't have been set
# to None yet.
_internal_assert(holder.handles[gid] is not None)
# We always set these together in the recomputation hook
_internal_assert(holder.handles[gid] in self.recomputed[gid])
# see pack hook, x_metadata is 1:1 with weak_holders.
x_meta = self.x_metadatas[idx]
recomputed_x = self.recomputed[gid][holder.handles[gid]]
recomputed_meta = self.metadata_fn(recomputed_x)
if x_meta != recomputed_meta:
nb_meta_different.append((idx, x_meta, recomputed_meta))

if len(nb_meta_different) > 0:
mismatched_tensors = ""
for idx, x_meta, recomputed_meta in nb_meta_different:
mismatched_tensors += (
f"tensor at position {idx}:\n"
f"saved metadata: {x_meta}\n"
f"recomputed metadata: {recomputed_meta}\n"
)
raise CheckpointError(
"torch.utils.checkpoint: Recomputed values for the following tensors "
"have different metadata than during the forward pass.\n"
f"{mismatched_tensors}"
)
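As a hedged illustration of case 3, here is a contrived function that computes in a different dtype on its second invocation (the recomputation), which should trip this metadata check during backward; the counter trick exists purely to force non-determinism:

import torch
from torch.utils.checkpoint import checkpoint

state = {"calls": 0}

def fn(x):
    state["calls"] += 1
    if state["calls"] > 1:
        # Recomputation runs in float64, so the tensors it saves have a
        # different dtype than those saved during the original forward.
        x = x.double()
    return x.sin().cos()

inp = torch.randn(4, requires_grad=True)
out = checkpoint(fn, inp, use_reentrant=False)
out.sum().backward()  # expected: CheckpointError reporting dtype mismatches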


_checkpoint_error_template = """ \
An error happened while unpacking tensors; dumping logs of latest computation
because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`.
Scroll all the way down for guidance on how to navigate these logs.
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
| 1. Stack traces of the operators that ran in the original forward
+------------------------------------------------------------------------------+
{forward_traces}
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
| 2. Stack traces of the operators that ran during recomputation
+------------------------------------------------------------------------------+
{recompute_traces}
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
| 3. Traces of the operators in the original forward and recomputation
+------------------------------------------------------------------------------+
(Scroll up to correlate stack traces with each operation listed below. This
helps identify their source in the code.)
IMPORTANT: Differences in "detach" calls between the original forward and the
recomputation are expected. They are introduced by the checkpointing
mechanism and can be ignored.
Operations executed during the original forward:
{forward_ops}
Operations executed during recomputation:
{recompute_ops}
+------------------------------------------------------------------------------+
ERROR: Detected non-determinism while running activation checkpointing
You are seeing this error because you passed `debug=True` to checkpoint and
the tensors saved during the original forward differ from those saved during
recomputation. This can happen if different operators were run in the
original forward and in the recomputation.
To identify where the mismatch may be coming from, you can do the following:
1) Compare the operators run during the original forward and recomputation
to see where they differ. These operators are printed above in the order
they were executed.
2) Review the stack trace for each operator to locate its invocation source.
Each operator's stack trace is printed in execution order.
Note that the logs can be quite long. Here's how they are structured:
1. Original Forward Operator Stack Traces
2. Recomputation Operator Stack Traces
3. Operations in Original Forward and Recomputation
4. Error message <--- You are here
--------------------------------------------------------------------------------
"""

def _get_debug_context_and_cb() -> Tuple[Callable, Callable]:
logging_mode_fwd = LoggingTensorMode(collect_logs=True)
logging_mode_recompute = LoggingTensorMode(collect_logs=True)

def unpack_error_cb(e: CheckpointError):
raise CheckpointError(
_checkpoint_error_template.format(
forward_traces=logging_mode_fwd.str_traces(),
recompute_traces=logging_mode_recompute.str_traces(),
forward_ops=logging_mode_fwd.str_logs(),
recompute_ops=logging_mode_recompute.str_logs(),
)
) from e

def context_fn():
return logging_mode_fwd, logging_mode_recompute

return context_fn, unpack_error_cb
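The pair returned above plugs into the public context_fn hook of checkpoint: the first manager wraps the original forward and the second wraps the recomputation. A no-op sketch of a user-supplied context_fn, with illustrative names:

from contextlib import nullcontext

import torch
from torch.utils.checkpoint import checkpoint

def my_context_fn():
    # first manager wraps the original forward, second the recomputation
    return nullcontext(), nullcontext()

layer = torch.nn.Linear(8, 8)
inp = torch.randn(1, 8, requires_grad=True)
out = checkpoint(layer, inp, use_reentrant=False, context_fn=my_context_fn)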

def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]:
# These properties are fast to check and easy to understand
return {
"shape": x.shape,
"dtype": x.dtype,
"device": x.device
}

_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = {
"default": _default_meta_extractor,
"none": lambda _: None,
}
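For a quick sense of what the "default" check records per saved tensor (output shown for a CPU tensor):

t = torch.randn(2, 3, dtype=torch.float16)
_default_meta_extractor(t)
# {'shape': torch.Size([2, 3]), 'dtype': torch.float16, 'device': device(type='cpu')}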

class CheckpointError(Exception):
pass

# See Rule 5
class _StopRecomputationError(Exception):
@@ -750,9 +933,15 @@ def pack_hook(x):
target_frame.recomp_counter[gid] += 1

if recomp_idx >= len(target_frame.weak_holders):
# We run into this case when early stop is not enabled and do
# grad within checkpoint.
return x.detach()
if not target_frame.forward_completed:
# We run into this case when early stop is not enabled and
# backward/grad is invoked inside the checkpointed region.
return x.detach()
raise CheckpointError(
"torch.utils.checkpoint: trying to save more tensors during "
"recomputation than during the original forward pass."
)

holder = target_frame.weak_holders[recomp_idx]()

# This holder may have been cleared because someone may have called
@@ -779,10 +968,14 @@ def unpack_hook(x):

class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
def __init__(self, frame):
def pack_hook(_unused_x):
def pack_hook(x):
# See Rule 4 above
holder = _Holder()
frame.weak_holders.append(weakref.ref(holder))
# Save metadata to detect non-determinism
if frame.metadata_fn is not None:
with torch.no_grad():
frame.x_metadatas.append(frame.metadata_fn(x))
return holder

def unpack_hook(holder):
@@ -807,20 +1000,30 @@ def unpack_hook(holder):
except _StopRecomputationError:
pass
frame.is_recomputed[gid] = True
frame.check_recomputed_tensors_match(gid)

_internal_assert(gid in holder.handles)

if holder.handles[gid] is None:
raise RuntimeError(
"torch.utils.checkpoint: unpack is being triggered for a tensor that was either "
"never recomputed, or already unpacked once. If you are calling ctx.saved_tensors "
"in backward, make sure to do so only once. Otherwise please open an issue with "
"details on your use case."
raise CheckpointError(
"torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
"unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
"so only once. Otherwise please open an issue with details on your use case."
)
_internal_assert(holder.handles[gid] in frame.recomputed[gid])
ret = frame.recomputed[gid][holder.handles[gid]]
holder.handles[gid] = None
return ret

super().__init__(pack_hook, unpack_hook)
if frame.unpack_error_cb is not None:
def unpack_hook_with_error_cb(holder):
try:
return unpack_hook(holder)
except CheckpointError as e:
frame.unpack_error_cb(e)
super().__init__(pack_hook, unpack_hook_with_error_cb)
else:
super().__init__(pack_hook, unpack_hook)
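These hooks ride on the stock torch.autograd.graph.saved_tensors_hooks mechanism; a standalone sketch of how that works, independent of this PR:

import torch

def pack_hook(t):
    # Fires when autograd saves a tensor for backward; may return any handle.
    print("packed", tuple(t.shape))
    return t

def unpack_hook(handle):
    # Fires when backward retrieves the saved tensor from the handle.
    print("unpacked")
    return handle

x = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = x.sin()     # sin saves x for backward -> pack_hook fires
y.sum().backward()  # backward needs x -> unpack_hook fires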


# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
@@ -829,6 +1032,8 @@ def _checkpoint_without_reentrant(
fn,
preserve_rng_state=True,
context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
determinism_check: str = "default",
debug: bool = False,
*args,
**kwargs
):
@@ -845,9 +1050,36 @@
context_fn(Callable, optional): A callable returning a tuple of two
context managers. The function and its recomputation will be run
under the first and second context managers respectively.
determinism_check(str, optional): A string specifying the determinism
check to perform. By default it is set to ``"default"``, which
compares the shapes, dtypes, and devices of the recomputed tensors
against those of the saved tensors. To turn off this check, specify
``"none"``. Currently these are the only two supported values.
Please open an issue if you would like to see more determinism
checks.
debug(bool, optional): If ``True``, error messages will also include
a trace of the operators run during the original forward computation
as well as the recomputation.
*args: Arguments to pass in to the given ``function``.
**kwargs: Keyword arguments to pass into the given ``function``.
"""
unpack_error_cb = None

if debug:
if context_fn is not noop_context_fn:
raise ValueError(
"debug=True is incompatible with non-default context_fn"
)
context_fn, unpack_error_cb = _get_debug_context_and_cb()

if determinism_check in _allowed_determinism_checks_to_fns:
metadata_fn = _allowed_determinism_checks_to_fns[determinism_check]
else:
raise ValueError(
f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, "
f"but got {determinism_check}"
)

device = _infer_device_type(*args)
device_module = _get_device_module(device)
forward_context, recompute_context = context_fn()
Expand Down Expand Up @@ -886,7 +1118,12 @@ def recompute_fn(*inputs):
), torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:
fn(*args, **kwargs)

new_frame = _CheckpointFrame(recompute_fn, _enable_checkpoint_early_stop)
new_frame = _CheckpointFrame(
recompute_fn,
_enable_checkpoint_early_stop,
unpack_error_cb,
metadata_fn
)
dummy = torch.empty((0,), requires_grad=True)
new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

@@ -896,6 +1133,7 @@ def recompute_fn(*inputs):

with _checkpoint_hook(new_frame), forward_context:
ret = fn(*args, **kwargs)
new_frame.forward_completed = True

if device_module._initialized and preserve_rng_state and not had_device_in_fwd:
# Device was not initialized before running the forward, so we didn't
Expand Down

0 comments on commit 07647a8

Please sign in to comment.