# Hooks
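**1. What is a hook?**
A hook is a special <font color=blue>function</font> that runs before or after the forward or backward method of an autograd.Function is called, in order to carry out a predefined task.\
**2. What can a hook be attached to?**
A hook can be <font color=blue>attached to **a tensor or a module**</font>; this is called registering it on a tensor or an nn.Module. Typical use cases:\
· <font color=green>On a tensor, hooks mainly control how information is packed/unpacked as it is passed from the forward pass to the backward pass.</font>\
· <font color=green>On a module, hooks are useful for model visualization, debugging, gradient checking, and so on.</font> \
**3. What kinds of hooks are there?**
PyTorch provides two kinds of hooks: forward hooks and backward hooks. A hook must first be registered at the place where it should run. Forward hooks can run at two points: a <font color=lightblue>forward pre-hook</font> runs before the forward method, and a <font color=lightblue>forward hook</font> runs after the forward method has finished. The <font color=lightblue>backward hook</font> covered here comes in one form, which runs after the backward method has finished (recent PyTorch releases also offer a backward pre-hook through `register_full_backward_pre_hook`).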
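
To make the three registration points concrete, here is a minimal sketch that attaches one hook of each kind to a single `nn.Linear` layer and records the order in which they fire. The `calls` list and the layer sizes are illustrative choices, not part of the original notebook.

```python
import torch
import torch.nn as nn

calls = []  # hypothetical helper: records the order in which the hooks fire

layer = nn.Linear(4, 2)
# forward pre-hook: runs before layer.forward
layer.register_forward_pre_hook(lambda module, args: calls.append("forward_pre"))
# forward hook: runs after layer.forward
layer.register_forward_hook(lambda module, args, output: calls.append("forward"))
# backward hook: runs once the gradients w.r.t. the layer inputs are computed
layer.register_full_backward_hook(
    lambda module, grad_input, grad_output: calls.append("backward")
)

x = torch.randn(3, 4, requires_grad=True)
layer(x).sum().backward()
print(calls)  # ['forward_pre', 'forward', 'backward']
```

All three hooks return `None` here, so they only observe the computation and leave inputs, outputs, and gradients unchanged.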

In [1]:
import torch
import torchviz
import torch.nn as nn
import torch.nn.functional as F

## 1. Hooks for (autograd saved) tensors
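Hooks registered on a tensor are backward hooks only.\
**signature of hook function:** <font color=red>hook(grad) -> Tensor or None </font>\
**Typical uses:**
1. <font color=green>Change how a tensor's gradient is computed.</font> Once a hook modifies x.grad for some tensor x, the gradients of the tensors that come before x in the DAG (those that receive their gradient through x) change accordingly via the chain rule. Without a tensor hook, you would have to wait for backpropagation to finish and then manually adjust the gradient of x and of every tensor before it.
2. <font color=green>Inspect the grad of an intermediate tensor without using extra memory.</font> Without a tensor hook you would have to call intermediate_tensor.retain_grad(), which keeps the gradient in memory.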

In [2]:
## Example 1: use a hook to change how a tensor's gradient (e.g. x.grad) is computed

# define the hook function
def func(grad):
    return grad * 2

In [3]:
# 1.without hook
torch.manual_seed(2)
w = torch.randn(5, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.ones((3, 5))
y = x @ w + b
y.retain_grad()
loss = (y ** 2).sum()
loss.retain_grad()

loss.backward()
print(w.grad)
print(b.grad)
print(loss.grad)

tensor([-0.9064, -0.9064, -0.9064, -0.9064, -0.9064])
tensor([-1.8891,  0.5235,  0.4592])
tensor(1.)


In [4]:
# 2.with hook for loss
torch.manual_seed(2)
w = torch.randn(5, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.ones((3, 5))
y = x @ w + b
y.retain_grad()
loss = (y ** 2).sum()
loss.retain_grad()

#  register hook
loss.register_hook(func)

loss.backward()

# loss's gradient is doubled and the doubling propagates backward, so the gradients of w and b are doubled too
print(w.grad)
print(b.grad)
print(y.grad)
print(loss.grad)

tensor([-1.8128, -1.8128, -1.8128, -1.8128, -1.8128])
tensor([-3.7783,  1.0471,  0.9184])
tensor([-3.7783,  1.0471,  0.9184])
tensor(2.)


In [5]:
# 3.with hook for w
torch.manual_seed(2)
w = torch.randn(5, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.ones((3, 5))
y = x @ w + b
y.retain_grad()
loss = (y ** 2).sum()
loss.retain_grad()

#  register hook for w
w.register_hook(func)

loss.backward()

# only w's gradient is doubled; w is a leaf, so the gradients of loss, y and b are unaffected
print(w.grad)
print(b.grad)
print(y.grad)
print(loss.grad)

tensor([-1.8128, -1.8128, -1.8128, -1.8128, -1.8128])
tensor([-1.8891,  0.5235,  0.4592])
tensor([-1.8891,  0.5235,  0.4592])
tensor(1.)


In [6]:
## Example 2: use a tensor hook to inspect the grad of an intermediate tensor without extra memory
#  define the hook
def func(grad):
    print(grad)
    return grad

torch.manual_seed(2)
w = torch.randn(5, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.ones((3, 5))
y = x @ w + b
loss = (y ** 2).sum()
loss.retain_grad()

#  register a hook on y; y is the intermediate tensor here
#  without a hook, y.retain_grad() would be needed to inspect the grad after backprop, which uses memory
y.register_hook(func)

loss.backward()

tensor([-1.8891,  0.5235,  0.4592])
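
As a cross-check on the printed gradient, the sketch below captures the grad of an intermediate tensor with a hook and compares it with the analytic value: for `loss = (y ** 2).sum()` the gradient with respect to `y` is `2 * y`. The `captured` dict is a hypothetical helper; note that storing the gradient keeps it alive in memory, whereas a hook that only prints does not.

```python
import torch

captured = {}  # hypothetical helper used to keep the gradient for inspection

torch.manual_seed(2)
w = torch.randn(5, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.ones(3, 5)
y = x @ w + b  # intermediate (non-leaf) tensor

# dict.update returns None, so the hook leaves the gradient unchanged
y.register_hook(lambda grad: captured.update(y=grad))

loss = (y ** 2).sum()
loss.backward()

print(y.is_leaf)  # False: y.grad is not populated unless retain_grad() is called
print(torch.allclose(captured["y"], 2 * y.detach()))  # True
```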


In [7]:
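## Example 3: gradient clipping. Register a hook on every parameter of a module to change its gradient directly
def grad_clipper(model, val):
    for parameter in model.parameters():
        # clamp every gradient element to the range [-val, val]
        parameter.register_hook(lambda grad: grad.clamp(-val, val))
    return model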
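
A quick way to try hook-based clipping is to feed a small model inputs large enough to produce big gradients and then check the bound. This is a minimal sketch: `add_clipping_hooks` is a hypothetical stand-in that clamps to `[-val, val]`, and the model and input scale are arbitrary.

```python
import torch
import torch.nn as nn

def add_clipping_hooks(model, val):
    """Hypothetical helper: clamp every parameter gradient to [-val, val]."""
    for parameter in model.parameters():
        # return a new tensor instead of modifying grad in place
        parameter.register_hook(lambda grad: grad.clamp(-val, val))
    return model

torch.manual_seed(0)
model = add_clipping_hooks(nn.Linear(4, 1), val=0.1)

x = 100 * torch.randn(8, 4)  # large inputs give large raw gradients
model(x).sum().backward()

max_abs = max(p.grad.abs().max().item() for p in model.parameters())
print(max_abs <= 0.1 + 1e-6)  # True: no gradient element exceeds the bound
```

The hooks clip each element independently, which differs from norm-based clipping such as `torch.nn.utils.clip_grad_norm_`.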

In [8]:
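## Example 4: tensor hooks inside a module
#  Requirements:
#    1. set the gradient of the linear layer's bias to 0
#    2. make every gradient the conv layer receives from downstream non-negative

# ----------------  original model  ----------------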
class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 2, stride=2)
        self.relu = nn.ReLU()
        self.flatten = lambda x: x.view(-1)
        self.fc = nn.Linear(160, 5)
    
    def forward(self, x):
        x = self.relu(self.conv(x))         
        return self.fc(self.flatten(x))

torch.manual_seed(2)
x = torch.randn(1, 3, 8, 8)
net = TestNet()
out = net(x)
loss = (1 - out).mean()
loss.backward()

print(out.shape)
# bias_grad in linear layer should be -1 / out.numel() = -0.2 for every element
print('bias_grad in linear layer:', net.fc.bias.grad)

torch.Size([5])
bias_grad in linear layer: tensor([-0.2000, -0.2000, -0.2000, -0.2000, -0.2000])


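#### Execution order of the backward pass after backward hooks are added to tensors inside a module
1. Starting from the root, the backward pass runs according to the chain rule.
2. When a hook is reached, its ops are applied to the grad of the tensor it is registered on. If several hooks are registered on the same tensor, they run in the order in which they appear in the module, not in reverse as the chain rule would suggest. For example:\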
fc layer backward method     -> \
flatten layer                -> \
register for clamp           -> \
register for print shape     -> \
register for check gradient  -> \
relu layer backward method   -> \
conv layer backward method
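
The registration-order rule can be checked in isolation with two hooks on one tensor. In this minimal sketch (the names `double`, `add_one` and `seen` are illustrative), each hook receives the gradient returned by the previous one, so the final result reveals the order: doubling first gives `1 * 2 + 1 = 3`, whereas the reverse order would give `(1 + 1) * 2 = 4`.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x * 1.0  # intermediate tensor that carries the hooks
seen = []    # hypothetical helper: records which hook ran and what it received

def double(grad):
    seen.append(("double", grad.clone()))
    return grad * 2

def add_one(grad):
    seen.append(("add_one", grad.clone()))
    return grad + 1

y.register_hook(double)   # registered first, runs first
y.register_hook(add_one)  # registered second, receives the doubled grad

y.sum().backward()
print([name for name, _ in seen])  # ['double', 'add_one']
print(x.grad)                      # tensor([3., 3., 3.])
```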

In [9]:
# ----------------  with hooks  ----------------
class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 2, stride=2)
        self.relu = nn.ReLU()
        self.flatten = lambda x: x.view(-1)
        self.fc = nn.Linear(160, 5)
    
    def forward(self, x):
        x = self.relu(self.conv(x))
        
        # hook: force x.grad >= 0
        # the x handled here is the output of the relu layer
        # relu's local gradient is either 0 or 1, so clamping the grad of x after relu is enough
        x.register_hook(lambda grad: torch.clamp(grad, min=0))
        
        x.register_hook(lambda grad: print('size of output of relu layer:', grad.shape))
        # add one more hook to check whether any element of x.grad is negative
        # registered here, it acts on the grad of x after the relu op
        x.register_hook(lambda grad: print('any grad < 0?', \
                                           bool((grad < 0).any())))
        x = self.flatten(x)
        # the x handled here is the output of the flatten layer
        x.register_hook(lambda grad: print('size of output of flatten layer:', grad.shape))

        y = self.fc(x)
        return y    
    
torch.manual_seed(2)
x = torch.randn(1, 3, 8, 8)
net = TestNet()

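# Outside the model, register a hook on a parameter tensor: here the gradient of the linear layer's bias is set to 0
# This way hooks can be added without changing the module definition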
for name, param in net.named_parameters():
    if 'fc' in name and 'bias' in name:
        param.register_hook(lambda grad: torch.zeros_like(grad))

out = net(x)
loss = (1 - out).mean()
loss.backward()

# confirm that the bias grad has been set to all zeros
print('bias_grad in linear layer:', net.fc.bias.grad)

size of output of flatten layer: torch.Size([160])
size of output of relu layer: torch.Size([1, 10, 4, 4])
any grad < 0? False
bias_grad in linear layer: tensor([0., 0., 0., 0., 0.])


## 2. Hooks for nn.Module objects
Hooks on an nn.Module can be forward hooks or backward hooks.\
**signature of hook function:** 
1. for backward hook: <font color=red>hook(module, grad_input, grad_output) -> tuple(Tensor) or None </font>
2. for forward hook: <font color=red>hook(module, input, output) -> None or modified output </font> 
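
The two signatures above can be exercised with a short sketch. It assumes `register_forward_hook` and `register_full_backward_hook` as the registration calls; the layer sizes and the helper name `halve_grad_input` are illustrative. A forward hook that returns a value replaces the module output, and a backward hook that returns a tuple replaces `grad_input`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Forward hook: the returned tensor replaces the layer's output
fwd_layer = nn.Linear(4, 2)
fwd_layer.register_forward_hook(lambda module, args, output: output.clamp(min=0))
out = fwd_layer(torch.randn(3, 4))
print(bool((out >= 0).all()))  # True

# Backward hook: the returned tuple replaces grad_input
def halve_grad_input(module, grad_input, grad_output):
    return tuple(g * 0.5 for g in grad_input)

bwd_layer = nn.Linear(4, 2)
bwd_layer.register_full_backward_hook(halve_grad_input)
x = torch.randn(3, 4, requires_grad=True)
bwd_layer(x).sum().backward()

# Without the hook, every row of x.grad would equal weight.sum(dim=0)
expected = 0.5 * bwd_layer.weight.detach().sum(dim=0).expand(3, 4)
print(torch.allclose(x.grad, expected))  # True
```

The backward hook only rescales the gradient flowing to the layer's input; the gradients of `weight` and `bias` are left as computed.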

In [10]:
## Example 1: use a hook to inspect an activation value inside a module

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 2)
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Linear(8 * 4 * 4, 1)

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = self.pool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

In [11]:
# define a hook that captures the activation output, i.e. the feature
features = {}
def hook_func(module, inputs, output):
    features['feature'] = output.detach()

In [12]:
net = Net()
# register the hook on the pooling layer
net.pool.register_forward_hook(hook_func)

x = torch.randn(1, 3, 10, 10)
output = net(x)
print(features['feature'].shape)

torch.Size([1, 8, 4, 4])


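### 2.1 Typical use cases and usage patterns
#### Scenario 1. Wrap a module and register hooks through the wrapper to print the information you need
**Advantages:**
1. Easier debugging: no need to add and remove print statements by hand
2. Works not only on custom modules but also on PyTorch built-in modules and third-party modules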

In [13]:
## Scenario 1, example 1: use hooks on ResNet18 to print model information

#  wrap the model in a wrapper class
class VerboseNet(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        
        # register a hook for each layer
        for name, layer in self.model.named_children():
            # nn.Module instances have no __name__ attribute by default (only the class has one),
            # so assign it explicitly so that the hook can print each layer's name
            layer.__name__ = name
            layer.register_forward_hook(
                lambda layer, _, output: print(f"new --->: {layer.__name__}: {output.shape}")
            )
    def forward(self, x: torch.Tensor):
        return self.model(x)

In [14]:
from torchvision.models import resnet18, ResNet18_Weights

verbose_resnet = VerboseNet(resnet18(weights=ResNet18_Weights.DEFAULT))
dummy_input = torch.ones(10, 3, 28, 28)

# assign the result so that Jupyter does not print the return value
_ = verbose_resnet(dummy_input)

new --->: conv1: torch.Size([10, 64, 14, 14])
new --->: bn1: torch.Size([10, 64, 14, 14])
new --->: relu: torch.Size([10, 64, 14, 14])
new --->: maxpool: torch.Size([10, 64, 7, 7])
new --->: layer1: torch.Size([10, 64, 7, 7])
new --->: layer2: torch.Size([10, 128, 4, 4])
new --->: layer3: torch.Size([10, 256, 2, 2])
new --->: layer4: torch.Size([10, 512, 1, 1])
new --->: avgpool: torch.Size([10, 512, 1, 1])
new --->: fc: torch.Size([10, 1000])
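
Hooks registered this way stay attached until they are removed. Each `register_*_hook` call returns a handle whose `remove()` method detaches the hook, so logging can be switched off without touching the model. The sketch below assumes a small `nn.Sequential` stand-in instead of ResNet18, and the `log` list is a hypothetical helper.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
log = []  # hypothetical helper: collects (layer name, output shape) pairs

# name=name binds the current layer name to each hook at definition time
handles = [
    layer.register_forward_hook(
        lambda module, args, output, name=name: log.append((name, tuple(output.shape)))
    )
    for name, layer in model.named_children()
]

model(torch.ones(5, 4))
n_before = len(log)
print(log)  # [('0', (5, 8)), ('1', (5, 8)), ('2', (5, 2))]

for handle in handles:
    handle.remove()  # detach every hook

model(torch.ones(5, 4))
n_after = len(log)
print(n_before, n_after)  # 3 3: nothing is logged once the hooks are removed
```

Binding the name through a default argument avoids setting an attribute on each layer; which of the two is preferable is a matter of taste.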


In [15]:
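## Scenario 1, example 2: feature extraction
#  Note: when doing transfer learning on a pretrained model, you may want to inspect the model's features
#  Hooks make this possible without changing the model itself

## again using a wrapper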
class FeatureExt(nn.Module):
    def __init__(self, model, layers):
        super().__init__()
        self.model = model
        self.layers = layers
        self._features = {layer: torch.empty(0) for layer in layers}
        
        for layer_i in layers:
            # build a dict from the (name, sub-module) pairs of named_modules and index it with layer_i
            # [explained in the next cell]
            layer = dict([*self.model.named_modules()])[layer_i]
            
            # once layer_i has been resolved to a layer, register the hook on it:
            layer.register_forward_hook(self.save_outputs_hook(layer_i))
         
    # define the hook function:
    def save_outputs_hook(self, layer_i):
        def fn(_, __, output):  # the underscores differ in length because argument names must be unique
            # take this layer's output and store it in the feature dict
            self._features[layer_i] = output
        return fn
        
    def forward(self, x):
        _ = self.model(x)
        return self._features  # the returned dict holds the activation outputs of the requested layers

In [16]:
res18 = resnet18(weights=ResNet18_Weights.DEFAULT)
## Notes on res18.named_modules():
#  1. res18.named_modules() returns a generator
#     each item it yields is a tuple of the form ('layer_name', module)
#  2. the '*' unpacks the generator into individual tuples: *res18.named_modules()
#  3. calling dict() on a list of tuples turns them into key-value pairs that can be indexed by key
#     the unpacked tuples are first collected into a list, hence dict([...])

#  ----- print sample layers to inspect the network structure  -----
# print([*res18.named_modules()][0]) # top level of the nested structure, describes the whole network
# print('-' * 40)
# print([*res18.named_modules()][5]) # 2nd level of the nested structure, 'layer1'
# print('-' * 40)
# print([*res18.named_modules()][6]) # 3rd level of the nested structure, a sub-module of 'layer1'
# print('-' * 40)
# print([*res18.named_modules()][7]) # 4th level of the nested structure, ...
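
The name-to-module lookup is easier to see on a tiny nested model than on ResNet18. The following sketch uses an illustrative two-level `nn.Sequential` (the names `block` and `head` are made up for this example) and lists the keys that `named_modules()` produces.

```python
import torch.nn as nn

# small nested model, used here only to illustrate the naming scheme
model = nn.Sequential()
model.add_module("block", nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
model.add_module("head", nn.Linear(4, 2))

# dict() accepts any iterable of (key, value) pairs, including a generator
modules = dict(model.named_modules())

print(list(modules))  # ['', 'block', 'block.0', 'block.1', 'head']
print(type(modules["block.0"]).__name__)  # Linear
```

The empty key `''` is the root module itself, and nested sub-modules get dotted names, which is why a string such as `'layer4'` can be used to pick a layer out of a ResNet.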

In [17]:
res18_feats = FeatureExt(res18, layers=['layer4', 'avgpool'])

feats = res18_feats(dummy_input)
for name, output in feats.items():
    print(name, '->', output.shape)

layer4 -> torch.Size([10, 512, 1, 1])
avgpool -> torch.Size([10, 512, 1, 1])


#### Scenario 2. 