Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add set_checkpoint_debug_enabled that overrides local setting #110728

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/checkpoint.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ torch.utils.checkpoint
.. currentmodule:: torch.utils.checkpoint
.. autofunction:: checkpoint
.. autofunction:: checkpoint_sequential
.. autofunction:: set_checkpoint_debug_enabled
16 changes: 16 additions & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5921,6 +5921,22 @@ def fn(x):
out = checkpoint(fn, a, use_reentrant=False, debug=True)
out.backward()

fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)

with self.assertRaisesRegex(RuntimeError, "You are seeing this error because you passed `debug=True` to checkpoint"):
with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):
out = checkpoint(fn, a, use_reentrant=False, debug=False)
out.backward()

fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)

with self.assertRaisesRegex(RuntimeError, "Recomputed values for the following tensors have different"):
with torch.utils.checkpoint.set_checkpoint_debug_enabled(False):
out = checkpoint(fn, a, use_reentrant=False, debug=True)
out.backward()



def test_access_saved_tensor_twice_without_recomputation_works(self):
count = [0]

Expand Down
27 changes: 26 additions & 1 deletion torch/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,35 @@
"set_checkpoint_early_stop",
"DefaultDeviceType",
"context_fn_gen",
"set_checkpoint_debug_enabled",
]

_DEFAULT_DETERMINISM_MODE = "default"

_checkpoint_debug_enabled: Optional[bool] = None


@contextlib.contextmanager
def set_checkpoint_debug_enabled(enabled: Optional[bool]):
"""
Context manager that sets whether checkpoint should print additional debug
information when running. See the ``debug`` flag for
:func:`~torch.utils.checkpoint.checkpoint` for more information. Note that
when set, this context manager overrides the value of ``debug`` passed to
checkpoint. To defer to the local setting, pass ``None`` to this context.

Args:
enabled (bool): Whether checkpoint should print debug information.
Default is 'None'.
"""
global _checkpoint_debug_enabled
try:
prev = _checkpoint_debug_enabled
_checkpoint_debug_enabled = enabled
yield
finally:
_checkpoint_debug_enabled = prev


def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
if isinstance(inputs, tuple):
Expand Down Expand Up @@ -1300,7 +1325,7 @@ def _checkpoint_without_reentrant_generator(
"""
unpack_error_cb = None

if debug:
if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
if context_fn != noop_context_fn:
raise ValueError(
"debug=True is incompatible with non-default context_fn"
Expand Down