diff --git a/docs/source/checkpoint.rst b/docs/source/checkpoint.rst
index 4215f06c744f9..e003d2460b558 100644
--- a/docs/source/checkpoint.rst
+++ b/docs/source/checkpoint.rst
@@ -34,3 +34,4 @@ torch.utils.checkpoint
 .. currentmodule:: torch.utils.checkpoint
 .. autofunction:: checkpoint
 .. autofunction:: checkpoint_sequential
+.. autofunction:: set_checkpoint_debug_enabled
diff --git a/test/test_autograd.py b/test/test_autograd.py
index ad98d7031c809..a71020c4b386b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5921,6 +5921,22 @@ def fn(x):
             out = checkpoint(fn, a, use_reentrant=False, debug=True)
             out.backward()
 
+        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
+
+        with self.assertRaisesRegex(RuntimeError, "You are seeing this error because you passed `debug=True` to checkpoint"):
+            with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):
+                out = checkpoint(fn, a, use_reentrant=False, debug=False)
+                out.backward()
+
+        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
+
+        with self.assertRaisesRegex(RuntimeError, "Recomputed values for the following tensors have different"):
+            with torch.utils.checkpoint.set_checkpoint_debug_enabled(False):
+                out = checkpoint(fn, a, use_reentrant=False, debug=True)
+                out.backward()
+
+
+
     def test_access_saved_tensor_twice_without_recomputation_works(self):
         count = [0]
 
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index fb49829f2c27f..bd0cb7fea7dac 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -37,10 +37,35 @@
     "set_checkpoint_early_stop",
     "DefaultDeviceType",
     "context_fn_gen",
+    "set_checkpoint_debug_enabled",
 ]
 
 _DEFAULT_DETERMINISM_MODE = "default"
 
+_checkpoint_debug_enabled: Optional[bool] = None
+
+
+@contextlib.contextmanager
+def set_checkpoint_debug_enabled(enabled: Optional[bool]):
+    """
+    Context manager that sets whether checkpoint should print additional debug
+    information when running. See the ``debug`` flag for
+    :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that
+    when set, this context manager overrides the value of ``debug`` passed to
+    checkpoint. To defer to the local setting, pass ``None`` to this context.
+
+    Args:
+        enabled (bool): Whether checkpoint should print debug information.
+            Default is 'None'.
+    """
+    global _checkpoint_debug_enabled
+    try:
+        prev = _checkpoint_debug_enabled
+        _checkpoint_debug_enabled = enabled
+        yield
+    finally:
+        _checkpoint_debug_enabled = prev
+
 
 def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
     if isinstance(inputs, tuple):
@@ -1300,7 +1325,7 @@ def _checkpoint_without_reentrant_generator(
     """
     unpack_error_cb = None
 
-    if debug:
+    if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
         if context_fn != noop_context_fn:
             raise ValueError(
                 "debug=True is incompatible with non-default context_fn"
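
For context, here is a minimal usage sketch of the context manager this patch adds (the toy module, tensor shapes, and the `run` helper are illustrative only and not part of the change):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint, set_checkpoint_debug_enabled

    layer = nn.Linear(8, 8)

    def run(x):
        # Any checkpointable callable works here; two matmuls keep it simple.
        return layer(layer(x))

    x = torch.randn(2, 8, requires_grad=True)

    # The global setting overrides the per-call `debug` argument: debug
    # metadata is collected for every non-reentrant checkpoint in this block
    # even though debug=False is passed locally.
    with set_checkpoint_debug_enabled(True):
        out = checkpoint(run, x, use_reentrant=False, debug=False)
        out.sum().backward()

    # Passing None defers to the per-call `debug` argument again.
    with set_checkpoint_debug_enabled(None):
        out = checkpoint(run, x, use_reentrant=False, debug=True)
        out.sum().backward()

The implementation follows the same save/restore pattern as the existing `set_checkpoint_early_stop` context manager in this file: a module-level flag is set on entry and restored to its previous value in a `finally` block, so nesting and exceptions behave as expected.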