From fd97714c5a1f4e1410ec30a1d91906e8b353bdd4 Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Fri, 6 Oct 2023 13:40:56 -0400
Subject: [PATCH] Add set_checkpoint_debug_enabled that overrides local setting

[ghstack-poisoned]
---
 test/test_autograd.py     | 16 ++++++++++++++++
 torch/utils/checkpoint.py | 27 ++++++++++++++++++++++++++-
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/test/test_autograd.py b/test/test_autograd.py
index ad98d7031c809..a71020c4b386b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -5921,6 +5921,22 @@ def fn(x):
             out = checkpoint(fn, a, use_reentrant=False, debug=True)
             out.backward()
 
+        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
+
+        with self.assertRaisesRegex(RuntimeError, "You are seeing this error because you passed `debug=True` to checkpoint"):
+            with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):
+                out = checkpoint(fn, a, use_reentrant=False, debug=False)
+                out.backward()
+
+        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
+
+        with self.assertRaisesRegex(RuntimeError, "Recomputed values for the following tensors have different"):
+            with torch.utils.checkpoint.set_checkpoint_debug_enabled(False):
+                out = checkpoint(fn, a, use_reentrant=False, debug=True)
+                out.backward()
+
+
+
     def test_access_saved_tensor_twice_without_recomputation_works(self):
         count = [0]

diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index a20ba759c5875..fa79929879209 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -37,10 +37,35 @@
     "set_checkpoint_early_stop",
     "DefaultDeviceType",
     "context_fn_gen",
+    "set_checkpoint_debug_enabled",
 ]
 
 _DEFAULT_DETERMINISM_MODE = "default"
 
+_checkpoint_debug_enabled: Optional[bool] = None
+
+
+@contextlib.contextmanager
+def set_checkpoint_debug_enabled(enabled: Optional[bool]):
+    """
+    Context manager that sets whether checkpoint should print additional debug
+    information when running. See the ``debug`` flag for
+    :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that
+    when set, this context manager overrides the value of ``debug`` passed to
+    checkpoint. To defer to the local setting, pass ``None`` to this context.
+
+    Args:
+        enabled (bool): Whether checkpoint should print debug information.
+            Default is 'None'.
+    """
+    global _checkpoint_debug_enabled
+    try:
+        prev = _checkpoint_debug_enabled
+        _checkpoint_debug_enabled = enabled
+        yield
+    finally:
+        _checkpoint_debug_enabled = prev
+
 
 def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
     if isinstance(inputs, tuple):
@@ -1299,7 +1324,7 @@ def _checkpoint_without_reentrant_generator(
     """
     unpack_error_cb = None
 
-    if debug:
+    if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
        if context_fn != noop_context_fn:
            raise ValueError(
                "debug=True is incompatible with non-default context_fn"
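
Below is a minimal usage sketch (not part of the diff) showing how the new
context manager is intended to interact with the per-call ``debug`` flag once
this patch is applied. The toy ``fn`` and tensor shapes are illustrative only;
they do not come from the patch.

    import torch
    from torch.utils.checkpoint import checkpoint, set_checkpoint_debug_enabled

    def fn(x):
        # Deterministic function, so recomputation matches and no error is raised.
        return x.sin().cos()

    a = torch.randn(8, requires_grad=True)

    # Global override: collect debug recomputation metadata for every
    # non-reentrant checkpoint in this block, even though debug=False is
    # passed locally.
    with set_checkpoint_debug_enabled(True):
        out = checkpoint(fn, a, use_reentrant=False, debug=False)
        out.sum().backward()

    # Passing None defers back to the per-call `debug` argument.
    with set_checkpoint_debug_enabled(None):
        out = checkpoint(fn, a, use_reentrant=False, debug=True)
        out.sum().backward()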