From 13ae2dd0fbd82811d2a1341a741b7d1a6289cfdc Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Wed, 5 Oct 2022 21:47:21 +0000 Subject: [PATCH 1/2] CheckpointSequential support non-reentrant [ghstack-poisoned] --- test/test_utils.py | 63 ++++++++++++++++++++------------------- torch/utils/checkpoint.py | 14 +++++++-- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 3fe597b6826f0..5ce036ca56ba4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -52,37 +52,38 @@ def _check_checkpoint_sequential( num_chunks, input, ): - - # not checkpointed - out = model(input) - out_not_checkpointed = out.detach().clone() - model.zero_grad() - out.sum().backward() - grad_not_checkpointed = { - name: param.grad.detach().clone() - for name, param in model.named_parameters() - } - input_grad_not_checkpointed = input.grad.detach().clone() - for model_to_compare in module_lists_to_compare: - # checkpointed model by passing list of modules - detached = input.detach() - detached.requires_grad = True - - # pass list of modules to checkpoint - out = checkpoint_sequential(model_to_compare, num_chunks, detached) - out_checkpointed = out.detach().clone() - model.zero_grad() - out.sum().backward() - grad_checkpointed = { - name: param.grad.detach().clone() - for name, param in model.named_parameters() - } - input_grad_checkpointed = detached.grad.detach().clone() - # compare outputs as well as the gradients of input and parameters - self.assertEqual(out_checkpointed, out_not_checkpointed) - self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed) - for name in grad_checkpointed: - self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name]) + for use_reentrant in [True, False]: + with self.subTest(use_reentrant=use_reentrant): + # not checkpointed + out = model(input) + out_not_checkpointed = out.detach().clone() + model.zero_grad() + out.sum().backward() + grad_not_checkpointed = { + name: param.grad.detach().clone() + for name, param in model.named_parameters() + } + input_grad_not_checkpointed = input.grad.detach().clone() + for model_to_compare in module_lists_to_compare: + # checkpointed model by passing list of modules + detached = input.detach() + detached.requires_grad = True + + # pass list of modules to checkpoint + out = checkpoint_sequential(model_to_compare, num_chunks, detached, use_reentrant=use_reentrant) + out_checkpointed = out.detach().clone() + model.zero_grad() + out.sum().backward() + grad_checkpointed = { + name: param.grad.detach().clone() + for name, param in model.named_parameters() + } + input_grad_checkpointed = detached.grad.detach().clone() + # compare outputs as well as the gradients of input and parameters + self.assertEqual(out_checkpointed, out_not_checkpointed) + self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed) + for name in grad_checkpointed: + self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name]) # Test whether checkpoint is being triggered or not. For this, we check # the number of times forward pass happens diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index d28cf4a1c3ac6..dc3ff35543ee9 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -226,7 +226,7 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs): preserve_rng_state(bool, optional): Omit stashing and restoring the RNG state during each checkpoint. 
Default: ``True`` - use_reentrant(bool, optional): Use checkpointing + use_reentrant(bool, optional): Use (the default) checkpointing implementation that requires re-entrant autograd. If ``use_reentrant=False`` is specified, ``checkpoint`` will use an implementation that does not require re-entrant autograd. This @@ -256,7 +256,7 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs): ) -def checkpoint_sequential(functions, segments, input, **kwargs): +def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs): r"""A helper function for checkpointing sequential models. Sequential models execute a list of modules/functions in order @@ -290,6 +290,14 @@ def checkpoint_sequential(functions, segments, input, **kwargs): preserve_rng_state(bool, optional): Omit stashing and restoring the RNG state during each checkpoint. Default: ``True`` + use_reentrant(bool, optional): Use (the default) checkpointing + implementation that requires re-entrant autograd. + If ``use_reentrant=False`` is specified, ``checkpoint`` will use an + implementation that does not require re-entrant autograd. This + allows ``checkpoint`` to support additional functionality, such as + working as expected with ``torch.autograd.grad`` and support for + keyword arguments input into the checkpointed function. + Default: ``True`` Returns: Output of running :attr:`functions` sequentially on :attr:`*inputs` @@ -320,7 +328,7 @@ def forward(input): for start in range(0, segment_size * (segments - 1), segment_size): end = start + segment_size - 1 input = checkpoint(run_function(start, end, functions), input, - preserve_rng_state=preserve) + use_reentrant=use_reentrant, preserve_rng_state=preserve) return run_function(end + 1, len(functions) - 1, functions)(input) def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs): From 2f57ef037076570fd845e87b1a4c6c44fe66b3cf Mon Sep 17 00:00:00 2001 From: Rohan Varma Date: Wed, 5 Oct 2022 21:58:12 +0000 Subject: [PATCH 2/2] Update on "CheckpointSequential support non-reentrant" Closes https://github.com/pytorch/pytorch/issues/86328 Adds `use_reentrant` argument to `checkpoint_sequential`. 
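
For reference, a minimal usage sketch of the new argument (not part of the patch; the model, layer sizes, and segment count below are made up for illustration):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A small sequential model; the sizes are arbitrary and only for illustration.
model = nn.Sequential(nn.Linear(20, 20), nn.ReLU(), nn.Linear(20, 20), nn.ReLU())
x = torch.randn(4, 20, requires_grad=True)

# Split the forward pass into 2 checkpointed segments and opt into the
# non-reentrant checkpoint implementation enabled by this PR.
out = checkpoint_sequential(model, 2, x, use_reentrant=False)

# The non-reentrant path works with torch.autograd.grad as well as backward().
(grad_x,) = torch.autograd.grad(out.sum(), x)
```

With `use_reentrant=True` (the default) the behavior of `checkpoint_sequential` is unchanged.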
[ghstack-poisoned] --- test/test_utils.py | 51 ++++++++++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 5ce036ca56ba4..de215da919654 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -100,18 +100,20 @@ def forward(self, input_var): return input_var # checkpointed - modules = [Net() for _ in range(10)] - for m in modules: - self.assertEqual(m.counter, 0) - input_var = torch.randn(3, 4, requires_grad=True) - out = checkpoint_sequential(modules, 2, input_var) - for m in modules: - self.assertEqual(m.counter, 1) - out.sum().backward() - for m in modules[:(len(modules) // 2)]: - self.assertEqual(m.counter, 2) - for m in modules[(len(modules) // 2):]: - self.assertEqual(m.counter, 1) + for use_reentrant in [True, False]: + with self.subTest(use_reentrant=use_reentrant): + modules = [Net() for _ in range(10)] + for m in modules: + self.assertEqual(m.counter, 0) + input_var = torch.randn(3, 4, requires_grad=True) + out = checkpoint_sequential(modules, 2, input_var, use_reentrant=use_reentrant) + for m in modules: + self.assertEqual(m.counter, 1) + out.sum().backward() + for m in modules[:(len(modules) // 2)]: + self.assertEqual(m.counter, 2) + for m in modules[(len(modules) // 2):]: + self.assertEqual(m.counter, 1) def test_checkpoint_valid(self): model = nn.Sequential( @@ -133,6 +135,18 @@ def test_checkpoint_valid(self): torch.autograd.grad( outputs=[out], grad_outputs=[torch.ones(1, 5)], inputs=[input_var], create_graph=True ) + # works with use_reentrant=False, and grads are the same + out = model(input_var) + grads_no_checkpoint = torch.autograd.grad( + outputs=[out], grad_outputs=[torch.ones(1, 5)], inputs=[input_var], create_graph=True, + ) + out_checkpoint = checkpoint_sequential(modules, chunks, input_var, use_reentrant=False) + # check outputs are the same + self.assertEqual(out_checkpoint, out) + grads_checkpoint = torch.autograd.grad( + outputs=[out_checkpoint], grad_outputs=[torch.ones(1, 5)], inputs=[input_var], create_graph=True, + ) + self.assertEqual(grads_no_checkpoint, grads_checkpoint) def test_checkpoint(self): model = nn.Sequential( @@ -193,8 +207,10 @@ def forward(self, a, b): a = torch.randn(1, 100, requires_grad=True) b = torch.randn(1, 100, requires_grad=True) - with self.assertRaises(TypeError): - checkpoint_sequential(model, 1, a, b) # type: ignore[call-arg] + for use_reentrant in [True, False]: + with self.subTest(use_reentrant=use_reentrant): + with self.assertRaises(TypeError): + checkpoint_sequential(model, 1, a, b) # type: ignore[call-arg] def test_checkpoint_sequential_deprecated_no_args(self): class Noop(nn.Module): @@ -202,9 +218,10 @@ def forward(self): pass model = nn.Sequential(Noop()) - - with self.assertRaises(TypeError): - checkpoint_sequential(model, 1) # type: ignore[call-arg] + for use_reentrant in [True, False]: + with self.subTest(use_reentrant=use_reentrant): + with self.assertRaises(TypeError): + checkpoint_sequential(model, 1) # type: ignore[call-arg] def test_checkpoint_rng_cpu(self): for _ in range(5):