# Hooks
Reference: the PyTorch tutorial section on forward and backward function hooks:
https://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html#forward-and-backward-function-hooks

In [None]:
def _tensorinfo_size(obj):
    """Return a printable size for *obj*, recursing into tuples, lists and dicts.

    Objects with a callable ``size()`` method (such as tensors) report its
    result; other objects report an empty string.
    """
    # Duck-typed instead of isinstance(obj, torch.Tensor) so this helper does
    # not depend on torch being imported before the hook is defined.
    size = getattr(obj, "size", None)
    if callable(size):
        return size()
    if isinstance(obj, (list, tuple)):
        sizes = [_tensorinfo_size(item) for item in obj]
        return sizes if isinstance(obj, list) else tuple(sizes)
    if isinstance(obj, dict):
        return {key: _tensorinfo_size(value) for key, value in obj.items()}
    return ""


def tensorinfo_hook(module, input_, output):
    """
    Forward hook that prints the types and sizes of a module's input and output.

    PyTorch calls a forward hook as ``hook(module, input, output)`` after the
    module's forward has run; ``input_`` is the tuple of positional arguments
    of that call. Tensors are reported with their ``size()``. Tuples, lists
    and dicts (e.g. the ``(output, (h_n, c_n))`` returned by ``nn.LSTM``) are
    reported element-wise instead of crashing the forward pass. The hook
    returns ``None``, so the module's output is left unchanged.

    Example:

        >>> from torchvision.models import resnet18
        >>> model = resnet18(weights=None)  # pretrained=False on torchvision < 0.13
        >>> hook_fc = model.fc.register_forward_hook(tensorinfo_hook)
        >>> # model(torch.ones(1, 3, 224, 224))
        >>> hook_fc.remove()

    """
    print(f"Inside '{module.__class__.__name__}' forward")
    print(f"  input:     {str(type(input_)):<25}")
    # Forward hooks only receive positional arguments; a call made purely with
    # keyword arguments yields an empty tuple, so there may be no input[0].
    if len(input_) > 0:
        first = input_[0]
        print(f"  input[0]:  {str(type(first)):<25} {_tensorinfo_size(first)}")
    print(f"  output:    {str(type(output)):<25} {_tensorinfo_size(output)}")
    print()

In [None]:
import torch
import torch.nn as nn

In [None]:
# A tiny layer to attach the hook to: 1 input feature -> 3 output features.
m = nn.Linear(in_features=1, out_features=3)

In [None]:
# Keep the returned handle: calling .remove() on it detaches the hook again.
hook = m.register_forward_hook(hook=tensorinfo_hook)

In [None]:
# Run one forward pass so the registered hook fires; the trailing ';'
# keeps Jupyter from echoing the returned tensor.
m(torch.rand((1,)));

In [None]:
# Detach tensorinfo_hook from m; later calls to m no longer trigger its prints.
hook.remove()

## Exercise
- Write a context manager that registers a forward hook when entering a `with` block and removes it when leaving the block, even if an exception is raised inside it.