diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index 18d631d8b6fba..4102a26897cf8 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -3,7 +3,9 @@ TestCase, run_tests, skipIfTorchDynamo, - IS_WINDOWS + IS_WINDOWS, + parametrize as parametrize_test, + instantiate_parametrized_tests ) from torch.testing._internal.common_nn import NNTestCase, _create_basic_net @@ -18,7 +20,7 @@ from tempfile import NamedTemporaryFile import weakref import pickle -from collections import OrderedDict +from collections import OrderedDict, namedtuple import math import warnings @@ -32,16 +34,21 @@ def __init__(self) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: return self.seq2(self.seq1(x)) +ToyNamedTuple = namedtuple("ToyNamedTuple", "content") class ToyModel(nn.Module): - def __init__(self) -> None: + def __init__(self, with_named_tuple=False) -> None: super().__init__() self.net1 = Net() self.net2 = Net() + self.with_named_tuple = with_named_tuple def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.net2(self.net1(x)) - + res = self.net2(self.net1(x)) + if self.with_named_tuple: + return ToyNamedTuple(res) + else: + return (res,) def forward_hook( self: TestCase, @@ -178,9 +185,10 @@ def __exit__(self, *args, **kwargs): class TestModuleHooks(TestCase): @skipIfTorchDynamo("Dynamo does not yet capture hooks") - def test_forward_hooks(self): + @parametrize_test("named_tuple", (True, False)) + def test_forward_hooks(self, named_tuple): fired_hooks: List[int] = [] - model = ToyModel() + model = ToyModel(named_tuple) x = torch.randn(10, 10) hook = partial(forward_hook, self, fired_hooks, model.net1.seq2) model.net1.seq2.register_forward_hook(partial(hook, 0)) @@ -193,15 +201,17 @@ def test_forward_hooks(self): self.assertEqual(fired_hooks, []) out = model(x) self.assertEqual(fired_hooks, expected) - out.sum().backward() + self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple) + out[0].sum().backward() self.assertEqual(fired_hooks, expected) - model(x).sum().backward() + model(x)[0].sum().backward() self.assertEqual(fired_hooks, expected + expected) @skipIfTorchDynamo("Dynamo does not yet capture hooks") - def test_forward_pre_hooks(self): + @parametrize_test("named_tuple", (True, False)) + def test_forward_pre_hooks(self, named_tuple): fired_hooks: List[int] = [] - model = ToyModel() + model = ToyModel(named_tuple) x = torch.randn(10, 10) hook = partial(forward_pre_hook, self, fired_hooks, model.net2.seq1) model.net2.seq1.register_forward_pre_hook( @@ -218,15 +228,17 @@ def test_forward_pre_hooks(self): self.assertEqual(fired_hooks, []) out = model(x) self.assertEqual(fired_hooks, expected) - out.sum().backward() + self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple) + out[0].sum().backward() self.assertEqual(fired_hooks, expected) - model(x).sum().backward() + model(x)[0].sum().backward() self.assertEqual(fired_hooks, expected + expected) @skipIfTorchDynamo("Dynamo does not yet capture hooks") - def test_full_backward_hooks(self): + @parametrize_test("named_tuple", (True, False)) + def test_full_backward_hooks(self, named_tuple): fired_hooks: List[int] = [] - model = ToyModel() + model = ToyModel(named_tuple) x = torch.randn(10, 10) hook = partial(full_backward_hook, self, fired_hooks, model.net1) model.net1.register_full_backward_hook(partial(hook, 0)) @@ -239,15 +251,17 @@ def test_full_backward_hooks(self): self.assertEqual(fired_hooks, []) out = model(x) self.assertEqual(fired_hooks, []) - out.sum().backward() + self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple) + out[0].sum().backward() self.assertEqual(fired_hooks, expected) - model(x).sum().backward() + model(x)[0].sum().backward() self.assertEqual(fired_hooks, expected + expected) @skipIfTorchDynamo("Dynamo does not yet capture hooks") - def test_full_backward_pre_hooks(self): + @parametrize_test("named_tuple", (True, False)) + def test_full_backward_pre_hooks(self, named_tuple): fired_hooks: List[int] = [] - model = ToyModel() + model = ToyModel(named_tuple) x = torch.randn(10, 10) hook = partial(full_backward_pre_hook, self, fired_hooks, model.net1) model.net1.register_full_backward_pre_hook( @@ -264,9 +278,10 @@ def test_full_backward_pre_hooks(self): self.assertEqual(fired_hooks, []) out = model(x) self.assertEqual(fired_hooks, []) - out.sum().backward() + self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple) + out[0].sum().backward() self.assertEqual(fired_hooks, expected) - model(x).sum().backward() + model(x)[0].sum().backward() self.assertEqual(fired_hooks, expected + expected) # Backward pre hook can affect subsequent gradient computation @@ -283,9 +298,10 @@ def fn(_unused_module, grad_output): self.assertEqual(a.grad, torch.zeros_like(a)) @skipIfTorchDynamo("Dynamo does not yet capture hooks") - def test_mixed_hooks(self): + @parametrize_test("named_tuple", (True, False)) + def test_mixed_hooks(self, named_tuple): fired_hooks: List[int] = [] - model = ToyModel() + model = ToyModel(named_tuple) x = torch.randn(10, 10) model.register_forward_pre_hook( partial(forward_pre_hook, self, fired_hooks, model, 0) @@ -303,9 +319,10 @@ def test_mixed_hooks(self): self.assertEqual(fired_hooks, []) out = model(x) self.assertEqual(fired_hooks, [0, 1]) - out.sum().backward() + self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple) + out[0].sum().backward() self.assertEqual(fired_hooks, [0, 1, 2, 3]) - model(x).sum().backward() + model(x)[0].sum().backward() self.assertEqual(fired_hooks, [0, 1, 2, 3, 0, 1, 2, 3]) @skipIfTorchDynamo("Dynamo does not yet capture hooks") @@ -1490,6 +1507,7 @@ def parameter_registration_hook(module, name, parameter): finally: handle.remove() +instantiate_parametrized_tests(TestModuleHooks) if __name__ == "__main__": run_tests() diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py index e0ab3618242af..d6f8776b54155 100644 --- a/torch/utils/hooks.py +++ b/torch/utils/hooks.py @@ -181,7 +181,11 @@ def _apply_on_tensors(self, fn, args): for idx, val in zip(tensors_idx, new_tensors): arg_list[idx] = val - return tuple(arg_list), tensors_idx + if type(args) is tuple: + out = tuple(arg_list) + else: + out = type(args)(*arg_list) + return out, tensors_idx def setup_input_hook(self, args): def fn(grad_fn):