Skip to content
Permalink
Browse files

Deprecate variadic inputs of checkpoint_sequential (#21006)

Summary:
I've reported inconsistency between `checkpoint_sequential` and `nn.Sequential` at #19260. Both should provide the same input signature but they don't. I think the consistency is important and I agree with apaszke that `nn.Sequential`'s semantics should be kept instead of `checkpoint_sequential`.

I would like `checkpoint_sequential` to raise a `TypeError` on variadic arguments starting with PyTorch 1.2.0. For now, though, it is enough to emit a `DeprecationWarning`. I've discussed this approach with soumith.

Please review this pull request. Any comments are welcome.
Pull Request resolved: #21006

Differential Revision: D15530801

Pulled By: soumith

fbshipit-source-id: 0ceb2cc6a17dcc547d0d00ebaf9df8603be53183
  • Loading branch information...
sublee authored and facebook-github-bot committed May 29, 2019
1 parent d23d04f commit ffdce79078438c8722eccc87d9d68396559b97e4
Showing with 85 additions and 2 deletions.
  1. +42 −2 test/common_utils.py
  2. +28 −0 test/test_utils.py
  3. +15 −0 torch/utils/checkpoint.py
@@ -594,7 +594,7 @@ def assertWarns(self, callable, msg=''):
r"""
Test if :attr:`callable` raises a warning.
"""
with warnings.catch_warnings(record=True) as ws:
with self._reset_warning_registry(), warnings.catch_warnings(record=True) as ws:
warnings.simplefilter("always") # allow any warning to be raised
callable()
self.assertTrue(len(ws) > 0, msg)
@@ -604,13 +604,53 @@ def assertWarnsRegex(self, callable, regex, msg=''):
Test if :attr:`callable` raises any warning with message that contains
the regex pattern :attr:`regex`.
"""
with warnings.catch_warnings(record=True) as ws:
with self._reset_warning_registry(), warnings.catch_warnings(record=True) as ws:
warnings.simplefilter("always") # allow any warning to be raised
callable()
self.assertTrue(len(ws) > 0, msg)
found = any(re.search(regex, str(w.message)) is not None for w in ws)
self.assertTrue(found, msg)

@contextmanager
def _reset_warning_registry(self):
r"""
warnings.catch_warnings() in Python 2 misses already registered
warnings. We need to manually clear the existing warning registries to
ensure catching warnings in a scope.
"""
# Python 3 has no problem.
if sys.version_info >= (3,):
yield
return

# Backup and clear all existing warning registries.
backup = {}
for name, mod in list(sys.modules.items()):
try:
reg = mod.__warningregistry__
except AttributeError:
continue
else:
backup[name] = reg.copy()
reg.clear()

yield

# Restore backed up warning registries.
for name, reg_orig in backup.items():
try:
mod = sys.modules[name]
except KeyError:
continue

try:
reg = mod.__warningregistry__
except AttributeError:
mod.__warningregistry__ = reg_orig
else:
reg.clear()
reg.update(reg_orig)

def assertExpected(self, s, subname=None):
r"""
Test that a string matches the recorded contents of a file
@@ -198,6 +198,34 @@ def forward(self, *inputs):
torch.randn(1, 60, requires_grad=True)
)

def test_checkpoint_sequential_deprecated_multiple_args(self):
    """Passing more than one input to checkpoint_sequential must emit a
    deprecation warning (its signature is being aligned with
    nn.Sequential, which takes a single input)."""
    class Pair(nn.Module):
        def forward(self, first, second):
            return first, second

    model = nn.Sequential(Pair())
    first = torch.randn(1, 100, requires_grad=True)
    second = torch.randn(1, 100, requires_grad=True)

    self.assertWarnsRegex(
        lambda: checkpoint_sequential(model, 1, first, second),
        'deprecated',
        'checkpoint_sequential with multiple args should be deprecated',
    )

def test_checkpoint_sequential_deprecated_no_args(self):
    """Calling checkpoint_sequential with no input at all must likewise
    emit a deprecation warning."""
    class DoNothing(nn.Module):
        def forward(self):
            pass

    model = nn.Sequential(DoNothing())

    self.assertWarnsRegex(
        lambda: checkpoint_sequential(model, 1),
        'deprecated',
        'checkpoint_sequential with no args should be deprecated',
    )

def test_checkpoint_rng_cpu(self):
for _ in range(5):
inp = torch.randn(20000, device='cpu').requires_grad_()
@@ -155,6 +155,9 @@ def checkpoint(function, *args, **kwargs):
return CheckpointFunction.apply(function, preserve, *args)


# TODO(sublee): When releasing PyTorch 1.3,
# fix the function signature to not accept variadic arguments.
# See also: https://github.com/pytorch/pytorch/issues/19260
def checkpoint_sequential(functions, segments, *inputs, **kwargs):
r"""A helper function for checkpointing sequential models.
@@ -196,6 +199,18 @@ def checkpoint_sequential(functions, segments, *inputs, **kwargs):
if kwargs:
raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

# To accept variadic arguments is not consistent with nn.Sequential.
# This interface will be changed at PyTorch 1.3.
# See also: https://github.com/pytorch/pytorch/issues/19260
if not inputs:
warnings.warn('Giving no input to checkpoint_sequential has been deprecated, '
'a TypeError will be raised after PyTorch 1.3',
DeprecationWarning)
elif len(inputs) > 1:
warnings.warn('multiple inputs to checkpoint_sequential has been deprecated, '
'a TypeError will be raised after PyTorch 1.3',
DeprecationWarning)

def run_function(start, end, functions):
def forward(*inputs):
for j in range(start, end + 1):

0 comments on commit ffdce79

Please sign in to comment.
You can’t perform that action at this time.