In general, “hooks” are functions that automatically execute after a particular event.

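PyTorch lets you register multiple hooks on any torch.Tensor and any torch.nn.Module.
A hook is an operation that runs automatically before or after forward or backward, and is typically used for inspection and visualization.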


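1. Tensor hooks run automatically. The hook is called every time a gradient with respect to the Tensor is computed. The hook should have the signature `hook(grad) -> Tensor or None`: if it returns a tensor, that tensor replaces the gradient; if it returns None, the gradient is left unchanged.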

In [1]:
import torch
v = torch.tensor([1., 2., 3.], requires_grad=True)
v.register_hook(lambda grad: grad * 2)  # double the gradient
v.register_hook(lambda grad: print(grad))  # print the gradient
v.register_hook(lambda grad: grad / 2)  # halve the gradient again
v.register_hook(lambda grad: print(grad))  # print the gradient
loss = torch.sum(v ** 2)
loss.backward()


tensor([ 4.,  8., 12.])
tensor([2., 4., 6.])


In [2]:
v = torch.tensor([0., 0., 0.], requires_grad=True)
h = v.register_hook(lambda grad: grad ** 2)  # square the gradient
v.backward(torch.tensor([1., 2., 3.]))  # vector-Jacobian product with v = [1, 2, 3]
print(v.grad) 

h.remove()  # removes the hook

tensor([1., 4., 9.])
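
As a minimal sketch of what `h.remove()` changes, the cell below runs backward once with the squaring hook attached and once after removing it. The names `w`, `handle`, `with_hook` and `without_hook`, and the zeroing of `w.grad` between the two calls, are illustrative assumptions and not part of the cells above.

```python
import torch

w = torch.tensor([0., 0., 0.], requires_grad=True)
handle = w.register_hook(lambda grad: grad ** 2)  # square the incoming gradient

w.backward(torch.tensor([1., 2., 3.]))
with_hook = w.grad.clone()       # hook applied: [1., 4., 9.]

handle.remove()                  # detach the hook
w.grad.zero_()                   # gradients accumulate, so reset first
w.backward(torch.tensor([1., 2., 3.]))
without_hook = w.grad.clone()    # hook gone: [1., 2., 3.]

print(with_hook, without_hook)
```

Keeping the handle returned by `register_hook` is the only way to detach a hook later, so it is usually worth storing.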


PyTorch can compute a vector-Jacobian product:  
y = f(x)  
y.backward(v) -> $v^TJ$

v must have the same shape as y, the tensor on which backward() is called. The resulting $v^TJ$ has the same shape as the input x, and after the call it is what x.grad holds (added to any value already stored there).
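
A small worked example may make the shapes concrete. It assumes a linear map `y = A @ x`, so that the Jacobian is simply `A`; the matrix `A`, the vector `v` and the name `expected` are chosen for illustration only.

```python
import torch

A = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])  # y = A @ x, so J = A (shape 3x2)
x = torch.tensor([1., -1.], requires_grad=True)
y = A @ x                                         # shape (3,)

v = torch.tensor([1., 0., 2.])                    # same shape as y
y.backward(v)                                     # accumulates v^T J into x.grad

expected = v @ A                                  # v^T J computed by hand, shape (2,)
print(x.grad, expected)                           # both are tensor([11., 14.])
```

Here `v` matches the shape of `y`, while `x.grad` matches the shape of `x`.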



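After backward(), the computation graph is freed (unless retain_graph=True is passed), but the gradients stored in the .grad of leaf tensors are kept and accumulate across calls.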

In [3]:
inp = torch.eye(4, 5, requires_grad=True)
out = (inp+1).pow(2).t()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"First call\n{inp.grad}")
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")

First call
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])

Second call
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.]])

Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])
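
For contrast, here is a sketch of the same computation without `retain_graph=True`. It assumes the default behaviour in which the saved buffers are freed after the first backward pass; the names `first_grad` and `second_call_failed` are illustrative.

```python
import torch

inp = torch.eye(4, 5, requires_grad=True)
out = (inp + 1).pow(2).t()

out.backward(torch.ones_like(out))        # no retain_graph: saved buffers are freed
first_grad = inp.grad.clone()             # the gradient itself survives

try:
    out.backward(torch.ones_like(out))    # second pass through the freed graph
    second_call_failed = False
except RuntimeError as err:
    second_call_failed = True
    print(f"Second backward failed: {type(err).__name__}")
```

The gradient from the first call stays in `inp.grad`, but the graph can no longer be traversed.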


2. PyTorch provides more hooks for nn.Module, including forward pre-hooks (run before forward), forward hooks (run after forward), and backward hooks.

2.1  register_module_forward_pre_hook(hook) / register_forward_pre_hook

```
hook(module, input) -> None or modified input
```

If the hook returns a single value, it is wrapped into a tuple (unless that value is already a tuple).
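
A minimal sketch of a forward pre-hook that modifies the input, assuming an `nn.Identity` layer so the effect is easy to see. The function name `double_input` and the variables `hooked` and `plain` are hypothetical.

```python
import torch
import torch.nn as nn

layer = nn.Identity()

def double_input(module, args):
    # args is the tuple of positional arguments passed to the module
    return args[0] * 2           # a single value; PyTorch wraps it into a tuple

handle = layer.register_forward_pre_hook(double_input)
x = torch.tensor([1., 2., 3.])
hooked = layer(x)                # forward() receives 2 * x
handle.remove()
plain = layer(x)                 # forward() receives x again

print(hooked, plain)
```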

2.2 register_module_forward_hook / register_forward_hook
```
hook(module, input, output) -> None or modified output
```

The input contains only the positional arguments given to the module. Keyword arguments are passed only to forward(), not to the hook. The hook can modify the output. It can also modify the input in-place, but that has no effect on forward(), because the hook runs after forward() has been called.
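
As a sketch of a forward hook that replaces the output, the cell below clamps negative values to zero. It again assumes an `nn.Identity` layer; `clamp_output`, `hooked` and `plain` are illustrative names.

```python
import torch
import torch.nn as nn

layer = nn.Identity()

def clamp_output(module, args, output):
    # args holds the positional inputs; the returned tensor replaces the output
    return output.clamp(min=0)

handle = layer.register_forward_hook(clamp_output)
x = torch.tensor([-1., 2., -3.])
hooked = layer(x)      # negative entries are clamped to 0 by the hook
handle.remove()
plain = layer(x)       # unchanged once the hook is removed

print(hooked, plain)
```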


2.3 register_module_full_backward_hook / register_full_backward_hook
```
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
```

In [41]:
import torch
import torch.nn as nn
from IPython.display import clear_output
layer = nn.Linear(5, 3)
x = torch.randn((6,5),requires_grad=True)
layer.register_forward_pre_hook(lambda module, input: print(f'forward_pre_hook with {input[0].shape}'))
layer.register_forward_hook(lambda module, input, output: print(f'forward_hook with {input[0].shape} and output {output.shape}'))
layer.register_full_backward_hook(lambda module, grad_input, grad_output: print('backward_hook with grad_input ', grad_input, ' and grad_output ', grad_output))
output = layer(x)
output.sum().backward()
clear_output()
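
The cell above only prints the gradients. As a hedged sketch of the other use, a full backward hook can return a tuple that replaces `grad_input`. The hook name `halve_grad_input` and the hand-computed `expected` are assumptions made for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(5, 3)
x = torch.randn(6, 5, requires_grad=True)

def halve_grad_input(module, grad_input, grad_output):
    # grad_input: gradients w.r.t. the module inputs (here one tensor, 6x5)
    # grad_output: gradients w.r.t. the module outputs (here one tensor, 6x3)
    return tuple(g * 0.5 for g in grad_input)

handle = layer.register_full_backward_hook(halve_grad_input)
layer(x).sum().backward()
handle.remove()

# Without the hook, d(sum(x W^T + b)) / dx = ones(6, 3) @ W
expected = torch.ones(6, 3) @ layer.weight.detach()
print(torch.allclose(x.grad, 0.5 * expected))
```

The returned tuple must have the same length as `grad_input`; the hook changes only the gradient flowing to the module inputs, not the parameter gradients.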

In [11]:
class VerboseExecution(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

        # Register a hook for each layer
        for name, layer in self.model.named_children():
            layer.__name__ = name
            layer.register_forward_hook(
                lambda layer, _, output: print(f"{layer.__name__}: {output.shape}")
            )

    def forward(self, x):
        return self.model(x)

In [19]:
forward = []
class ForwardCheck(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

        # Register a hook for each layer
        for name, layer in self.model.named_children():
            layer.__name__ = name
            layer.register_forward_hook(self.forward_hook(name))

    def forward_hook(self, name):
        def hook(layer, input, output):
            print(f"{name},{output.mean().item()}")
            forward.append(output.mean().item())
        return hook

    def forward(self, x):
        return self.model(x)

In [8]:
import torch
import torchvision
from torchvision.models import resnet50

verbose_resnet = VerboseExecution(resnet50())
dummy_input = torch.ones(10, 3, 224, 224)

_ = verbose_resnet(dummy_input)

conv1: torch.Size([10, 64, 112, 112])
bn1: torch.Size([10, 64, 112, 112])
relu: torch.Size([10, 64, 112, 112])
maxpool: torch.Size([10, 64, 56, 56])
layer1: torch.Size([10, 256, 56, 56])
layer2: torch.Size([10, 512, 28, 28])
layer3: torch.Size([10, 1024, 14, 14])
layer4: torch.Size([10, 2048, 7, 7])
avgpool: torch.Size([10, 2048, 1, 1])
fc: torch.Size([10, 1000])


In [14]:
forward_resnet = ForwardCheck(resnet50())

_ = forward_resnet(dummy_input)

conv1,-0.059701159596443176
bn1,3.8469323726531e-08
relu,0.1515873670578003
maxpool,0.2705785632133484
layer1,0.6606165170669556
layer2,0.9451380372047424
layer3,1.3701601028442383
layer4,0.9949647188186646
avgpool,0.9949647784233093
fc,-0.013356885872781277
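
The wrappers above never remove their hooks and `ForwardCheck` writes to a global list. One possible alternative, sketched here under the assumption of a small `nn.Sequential` model instead of ResNet-50, is a context manager that keeps the handles and removes them on exit. The helper `record_output_means` is hypothetical.

```python
import contextlib
import torch
import torch.nn as nn

@contextlib.contextmanager
def record_output_means(model: nn.Module):
    """Record the mean output of each direct child; remove the hooks on exit."""
    means, handles = {}, []
    for name, layer in model.named_children():
        def hook(module, args, output, name=name):  # bind name per layer
            means[name] = output.mean().item()
        handles.append(layer.register_forward_hook(hook))
    try:
        yield means
    finally:
        for h in handles:
            h.remove()

toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
with record_output_means(toy) as means:
    toy(torch.ones(3, 4))
print(means)  # keys are the child names '0', '1', '2'
```

After the `with` block the model runs without any hook overhead.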


In [38]:
# Register a hook on every parameter tensor
def GradientCheck(model: nn.Module) -> nn.Module:
    for name, parameter in model.named_parameters():
        parameter.register_hook(lambda grad, name=name: print(f"{name},{grad.mean().item()}"))

    return model
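
The `name=name` default argument in the lambda above matters because Python closures look up loop variables when they are called, not when they are created. A small sketch of the difference, assuming a two-layer toy model; the helper `register_both` and the lists `late` and `bound` are illustrative.

```python
import torch
import torch.nn as nn

def register_both(model: nn.Module):
    late, bound = [], []
    for name, p in model.named_parameters():
        # Looks up `name` when the hook fires, i.e. after the loop has finished
        p.register_hook(lambda grad: late.append(name))
        # Captures the current value of `name` at registration time
        p.register_hook(lambda grad, name=name: bound.append(name))
    return late, bound

toy = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 1))
late, bound = register_both(toy)
toy(torch.ones(2, 4)).sum().backward()

print(sorted(set(late)))   # only the last parameter name
print(sorted(bound))       # every parameter name
```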

In [39]:
gradient_resnet = GradientCheck(resnet50())
loss = torch.nn.CrossEntropyLoss()
predict = gradient_resnet(dummy_input)
loss(predict, torch.zeros(10, dtype=torch.long)).backward()

fc.bias,-7.450580430390374e-11
fc.weight,2.0116568133499158e-10
layer4.2.bn3.weight,-4.083880776306614e-05
layer4.2.bn3.bias,-0.00010678460239432752
layer4.2.conv3.weight,-3.0424041597143514e-06
layer4.2.bn2.weight,-1.161606633104384e-08
layer4.2.bn2.bias,-1.0719409146986436e-06
layer4.2.conv2.weight,-1.3477792890626006e-05
layer4.2.bn1.weight,-2.9103830456733704e-11
layer4.2.bn1.bias,-2.8743081202264875e-05
layer4.2.conv1.weight,1.487585177528672e-05
layer4.1.bn3.weight,-7.001651829341426e-05
layer4.1.bn3.bias,-5.002642501494847e-05
layer4.1.conv3.weight,6.656162895524176e-06
layer4.1.bn2.weight,-2.3654138203710318e-08
layer4.1.bn2.bias,3.75801682821475e-05
layer4.1.conv2.weight,-8.846724085742608e-06
layer4.1.bn1.weight,8.585629984736443e-10
layer4.1.bn1.bias,0.00019211262406315655
layer4.1.conv1.weight,-1.7508104065200314e-05
layer4.0.downsample.1.weight,-4.5250169932842255e-05
layer4.0.downsample.1.bias,-2.2337619157042354e-05
layer4.0.downsample.0.weight,-2.9237884518806823e-05
la