Commit 8b7122a
Make sure namedtuples are preserved when adding hooks
albanD committed Oct 30, 2023
1 parent 219763c commit 8b7122a
Showing 2 changed files with 48 additions and 26 deletions.
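
For context, here is a minimal sketch of the behavior this commit addresses (an illustrative repro, not part of the commit; the module and names are hypothetical): registering a full backward hook makes nn.Module wrap the forward output, and before this change the wrapped output was rebuilt as a plain tuple, silently dropping a namedtuple return type.

# Hypothetical repro; module and names are illustrative, not from the commit.
import torch
import torch.nn as nn
from collections import namedtuple

Out = namedtuple("Out", "content")

class M(nn.Module):
    def forward(self, x):
        return Out(x * 2)

m = M()
m.register_full_backward_hook(lambda mod, grad_in, grad_out: None)
out = m(torch.randn(2, requires_grad=True))
# Before this commit the hook machinery rebuilt the output as a plain
# tuple; after it, the namedtuple type survives.
assert isinstance(out, Out)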
68 changes: 43 additions & 25 deletions test/nn/test_module_hooks.py
@@ -3,7 +3,9 @@
     TestCase,
     run_tests,
     skipIfTorchDynamo,
-    IS_WINDOWS
+    IS_WINDOWS,
+    parametrize as parametrize_test,
+    instantiate_parametrized_tests
 )
 from torch.testing._internal.common_nn import NNTestCase, _create_basic_net

@@ -18,7 +20,7 @@
 from tempfile import NamedTemporaryFile
 import weakref
 import pickle
-from collections import OrderedDict
+from collections import OrderedDict, namedtuple
 import math
 import warnings

@@ -32,16 +34,21 @@ def __init__(self) -> None:
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.seq2(self.seq1(x))
 
+ToyNamedTuple = namedtuple("ToyNamedTuple", "content")
+
 class ToyModel(nn.Module):
-    def __init__(self) -> None:
+    def __init__(self, with_named_tuple=False) -> None:
         super().__init__()
         self.net1 = Net()
         self.net2 = Net()
+        self.with_named_tuple = with_named_tuple
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.net2(self.net1(x))
-
+        res = self.net2(self.net1(x))
+        if self.with_named_tuple:
+            return ToyNamedTuple(res)
+        else:
+            return (res,)
 
 def forward_hook(
     self: TestCase,
@@ -178,9 +185,10 @@ def __exit__(self, *args, **kwargs):
 
 class TestModuleHooks(TestCase):
     @skipIfTorchDynamo("Dynamo does not yet capture hooks")
-    def test_forward_hooks(self):
+    @parametrize_test("named_tuple", (True, False))
+    def test_forward_hooks(self, named_tuple):
         fired_hooks: List[int] = []
-        model = ToyModel()
+        model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         hook = partial(forward_hook, self, fired_hooks, model.net1.seq2)
         model.net1.seq2.register_forward_hook(partial(hook, 0))
@@ -193,15 +201,17 @@ def test_forward_hooks(self):
         self.assertEqual(fired_hooks, [])
         out = model(x)
         self.assertEqual(fired_hooks, expected)
-        out.sum().backward()
+        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
+        out[0].sum().backward()
         self.assertEqual(fired_hooks, expected)
-        model(x).sum().backward()
+        model(x)[0].sum().backward()
         self.assertEqual(fired_hooks, expected + expected)
 
     @skipIfTorchDynamo("Dynamo does not yet capture hooks")
-    def test_forward_pre_hooks(self):
+    @parametrize_test("named_tuple", (True, False))
+    def test_forward_pre_hooks(self, named_tuple):
         fired_hooks: List[int] = []
-        model = ToyModel()
+        model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         hook = partial(forward_pre_hook, self, fired_hooks, model.net2.seq1)
         model.net2.seq1.register_forward_pre_hook(
@@ -218,15 +228,17 @@ def test_forward_pre_hooks(self):
         self.assertEqual(fired_hooks, [])
         out = model(x)
         self.assertEqual(fired_hooks, expected)
-        out.sum().backward()
+        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
+        out[0].sum().backward()
         self.assertEqual(fired_hooks, expected)
-        model(x).sum().backward()
+        model(x)[0].sum().backward()
         self.assertEqual(fired_hooks, expected + expected)
 
     @skipIfTorchDynamo("Dynamo does not yet capture hooks")
-    def test_full_backward_hooks(self):
+    @parametrize_test("named_tuple", (True, False))
+    def test_full_backward_hooks(self, named_tuple):
         fired_hooks: List[int] = []
-        model = ToyModel()
+        model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         hook = partial(full_backward_hook, self, fired_hooks, model.net1)
         model.net1.register_full_backward_hook(partial(hook, 0))
@@ -239,15 +251,17 @@ def test_full_backward_hooks(self):
         self.assertEqual(fired_hooks, [])
         out = model(x)
         self.assertEqual(fired_hooks, [])
-        out.sum().backward()
+        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
+        out[0].sum().backward()
         self.assertEqual(fired_hooks, expected)
-        model(x).sum().backward()
+        model(x)[0].sum().backward()
         self.assertEqual(fired_hooks, expected + expected)
 
     @skipIfTorchDynamo("Dynamo does not yet capture hooks")
-    def test_full_backward_pre_hooks(self):
+    @parametrize_test("named_tuple", (True, False))
+    def test_full_backward_pre_hooks(self, named_tuple):
         fired_hooks: List[int] = []
-        model = ToyModel()
+        model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         hook = partial(full_backward_pre_hook, self, fired_hooks, model.net1)
         model.net1.register_full_backward_pre_hook(
@@ -264,9 +278,10 @@ def test_full_backward_pre_hooks(self):
         self.assertEqual(fired_hooks, [])
         out = model(x)
         self.assertEqual(fired_hooks, [])
-        out.sum().backward()
+        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
+        out[0].sum().backward()
         self.assertEqual(fired_hooks, expected)
-        model(x).sum().backward()
+        model(x)[0].sum().backward()
         self.assertEqual(fired_hooks, expected + expected)
 
     # Backward pre hook can affect subsequent gradient computation
@@ -283,9 +298,10 @@ def fn(_unused_module, grad_output):
         self.assertEqual(a.grad, torch.zeros_like(a))
 
     @skipIfTorchDynamo("Dynamo does not yet capture hooks")
-    def test_mixed_hooks(self):
+    @parametrize_test("named_tuple", (True, False))
+    def test_mixed_hooks(self, named_tuple):
         fired_hooks: List[int] = []
-        model = ToyModel()
+        model = ToyModel(named_tuple)
         x = torch.randn(10, 10)
         model.register_forward_pre_hook(
             partial(forward_pre_hook, self, fired_hooks, model, 0)
@@ -303,9 +319,10 @@ def test_mixed_hooks(self):
         self.assertEqual(fired_hooks, [])
         out = model(x)
         self.assertEqual(fired_hooks, [0, 1])
-        out.sum().backward()
+        self.assertIsInstance(out, ToyNamedTuple if named_tuple else tuple)
+        out[0].sum().backward()
         self.assertEqual(fired_hooks, [0, 1, 2, 3])
-        model(x).sum().backward()
+        model(x)[0].sum().backward()
         self.assertEqual(fired_hooks, [0, 1, 2, 3, 0, 1, 2, 3])
 
     @skipIfTorchDynamo("Dynamo does not yet capture hooks")
@@ -1490,6 +1507,7 @@ def parameter_registration_hook(module, name, parameter):
         finally:
             handle.remove()
 
+instantiate_parametrized_tests(TestModuleHooks)
 
 if __name__ == "__main__":
     run_tests()
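
Aside: parametrize (imported above as parametrize_test) and instantiate_parametrized_tests come from torch.testing._internal.common_utils. A standalone sketch of the pattern, assuming only that the decorator records parameter values and instantiation expands them into one generated test method per value (the class and test names below are illustrative, not from the PyTorch suite):

# Toy illustration of the parametrization pattern used in the diff above.
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    parametrize as parametrize_test,
    instantiate_parametrized_tests,
)

class TestToy(TestCase):
    @parametrize_test("named_tuple", (True, False))
    def test_flag(self, named_tuple):
        # Expands into one test per value, e.g. test_flag_named_tuple_True.
        self.assertIn(named_tuple, (True, False))

instantiate_parametrized_tests(TestToy)

if __name__ == "__main__":
    run_tests()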
6 changes: 5 additions & 1 deletion torch/utils/hooks.py
@@ -181,7 +181,11 @@ def _apply_on_tensors(self, fn, args):
         for idx, val in zip(tensors_idx, new_tensors):
             arg_list[idx] = val
 
-        return tuple(arg_list), tensors_idx
+        if type(args) is tuple:
+            out = tuple(arg_list)
+        else:
+            out = type(args)(*arg_list)
+        return out, tensors_idx
 
     def setup_input_hook(self, args):
         def fn(grad_fn):
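
The two-branch rebuild matters because tuple() consumes a single iterable while a namedtuple constructor takes one positional argument per field. A standalone sketch of the rule (Point and rebuild are illustrative, not the PyTorch source):

# Standalone illustration of the rebuild rule in the change above.
from collections import namedtuple

Point = namedtuple("Point", "x y")

def rebuild(args, arg_list):
    if type(args) is tuple:
        return tuple(arg_list)    # plain tuple: pass the whole iterable
    return type(args)(*arg_list)  # namedtuple: one argument per field

assert rebuild((1, 2), [10, 20]) == (10, 20)
assert rebuild(Point(1, 2), [10, 20]) == Point(x=10, y=20)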
