<a href="https://colab.research.google.com/github/stemgene/Computer-Vision-Projects/blob/main/01_supplement_register_forward_hook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://www.bilibili.com/video/BV1YL4y1A7oY/?spm_id_from=333.788&vd_source=81884c519d60bbdad4b6fd87d340415f

This note takes a closer look at `layer.register_forward_hook(function)`, which is used in 01 extract image vector.

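The basic idea is that `function` is executed each time data finishes its forward pass through `layer`; PyTorch calls it with the module, its input, and its output. It is somewhat similar to the `map` function, in that the same function is applied at every layer it is registered on.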
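As a minimal sketch of that call order, the snippet below registers a hook on a single `torch.nn.Linear` layer. The names `calls`, `record_call`, and `linear` are made up for this illustration and do not come from the original notebook.

```python
import torch

calls = []  # filled by the hook; hypothetical name used only in this sketch

def record_call(module, inputs, output):
    # PyTorch passes the module, a tuple of positional inputs, and the output
    calls.append((type(module).__name__, tuple(inputs[0].shape), tuple(output.shape)))

linear = torch.nn.Linear(8, 2)
linear.register_forward_hook(record_call)

x = torch.randn(4, 8)
y = linear(x)  # the hook runs right after linear.forward(x) has produced y

print(calls)   # one entry per forward call
```

The input arrives as a tuple because a module can take several positional arguments, which is why the notebook's hook indexes `i[0]`.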

In [25]:
!pip install timm



In [26]:
import timm
import torch

In [27]:
timm.list_models("vgg*")

['vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn']

Define the hook function

In [28]:
def print_shape(m, i, o):
    # m: the module, i: tuple of positional inputs, o: the module's output
    print(m)
    print(i[0].shape, "==>", o.shape)

In [29]:
model_name = "vgg11"
model = timm.create_model(model_name, pretrained=True)

In [30]:
for layer in model.children():
    layer.register_forward_hook(print_shape)

In [31]:
batch_input = torch.randn(4, 3, 128, 128)

In [32]:
model(batch_input)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): ReLU(inplace=True)
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (12): ReLU(inplace=True)
  (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (14): ReLU(inplace=True)
  (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (16): Conv2d(512, 512, kernel_size=(3, 3), stride=

tensor([[-1.7719,  0.6477, -0.5525,  ...,  0.5022, -1.0627,  2.3289],
        [-1.9030,  0.5128,  0.0219,  ..., -0.0405, -1.1995,  2.8581],
        [-1.8407,  0.5190, -0.6980,  ..., -0.1646, -0.6779,  2.8847],
        [-2.0917,  0.5560,  0.5120,  ..., -0.9731, -1.5027,  2.2113]],
       grad_fn=<AddmmBackward0>)

The output shows that `model.children()` treats `Sequential`, `ConvMlp` (an MLP made of two fc layers), and `ClassifierHead` as three layers, and `print_shape` is called each time data has passed through one of them.

Strictly speaking, these three are better viewed as blocks, because modern deep learning models are rarely just a plain stack of layers.
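To see why `children()` stops at the block level, a small nested model is enough. The sketch below uses a toy model named `toy` (hypothetical, not the VGG from this notebook) and compares `children()`, which yields only direct submodules, with `modules()`, which walks the whole tree.

```python
import torch.nn as nn

# Toy model made of two blocks; hypothetical, only for illustration
toy = nn.Sequential(
    nn.Sequential(nn.Linear(8, 4), nn.ReLU()),  # block 0
    nn.Sequential(nn.Linear(4, 2)),             # block 1
)

direct = [type(m).__name__ for m in toy.children()]
everything = [type(m).__name__ for m in toy.modules()]

print(direct)      # only the two blocks
print(everything)  # the model itself, each block, and the layers inside
```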

To show the individual layers, we need a new function that walks the model layer by layer:

In [34]:
# Recursively flatten the model into its individual layers

def get_children(model: torch.nn.Module) -> list:
    """Return the leaf modules of `model` in the order they are registered."""
    children = list(model.children())
    if not children:
        # A module without children is itself a leaf layer
        return [model]
    flatten_children = []
    for child in children:
        # Descend into each child and collect its leaf layers
        flatten_children.extend(get_children(child))
    return flatten_children
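A possible alternative, sketched here under the assumption that only leaf modules are of interest, is to filter `named_modules()` for modules that have no children. This also yields each layer's dotted name. `leaf_layers` and `toy` are hypothetical names used only in this sketch.

```python
import torch.nn as nn

def leaf_layers(model: nn.Module):
    """Return (name, module) pairs for modules that have no children."""
    return [
        (name, module)
        for name, module in model.named_modules()
        if next(module.children(), None) is None
    ]

# Toy model, only for illustration
toy = nn.Sequential(
    nn.Sequential(nn.Linear(8, 4), nn.ReLU()),
    nn.Linear(4, 2),
)

for name, module in leaf_layers(toy):
    print(name, type(module).__name__)
```

Note that `named_modules()` reports a module instance only once by default, even if the same instance is reused in several places in the model.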

In [36]:
flatten_children = get_children(model)
for layer in flatten_children:
    layer.register_forward_hook(print_shape)

In [37]:
batch_input = torch.randn(4, 3, 128, 128)
model(batch_input)

Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
torch.Size([4, 3, 128, 128]) ==> torch.Size([4, 64, 128, 128])
ReLU(inplace=True)
torch.Size([4, 64, 128, 128]) ==> torch.Size([4, 64, 128, 128])
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
torch.Size([4, 64, 128, 128]) ==> torch.Size([4, 64, 64, 64])
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
torch.Size([4, 64, 64, 64]) ==> torch.Size([4, 128, 64, 64])
ReLU(inplace=True)
torch.Size([4, 128, 64, 64]) ==> torch.Size([4, 128, 64, 64])
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
torch.Size([4, 128, 64, 64]) ==> torch.Size([4, 128, 32, 32])
Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
torch.Size([4, 128, 32, 32]) ==> torch.Size([4, 256, 32, 32])
ReLU(inplace=True)
torch.Size([4, 256, 32, 32]) ==> torch.Size([4, 256, 32, 32])
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
torch.Size([4, 256, 32, 32]

tensor([[-1.8095,  0.2594,  0.2079,  ..., -1.2397, -1.4545,  3.0064],
        [-1.7435,  0.9886,  1.1593,  ..., -0.8438, -0.9814,  2.9578],
        [-1.5887,  0.5364, -0.1238,  ..., -0.4951, -1.0784,  2.6812],
        [-1.2460,  1.3619,  0.2186,  ...,  0.1429, -1.4827,  2.4170]],
       grad_fn=<AddmmBackward0>)

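The hook function is very flexible: it receives three arguments (the module, its input, and its output), and many different features can be built on top of them.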
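One such use, sketched below, is storing a layer's output instead of printing it, which is a common route to extracting a feature vector. The toy model and the names `features` and `save_output` are assumptions for this example, not code from the original notebook.

```python
import torch
import torch.nn as nn

features = {}  # layer name -> captured output

def save_output(name):
    # Build a hook that remembers which layer it belongs to
    def hook(module, inputs, output):
        features[name] = output.detach()  # keep the values, drop the autograd graph
    return hook

# Toy model, only for illustration
toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
toy[1].register_forward_hook(save_output("relu"))

out = toy(torch.randn(4, 8))
print(features["relu"].shape)
```

With `ReLU(inplace=True)`, as in the VGG printed above, a stored convolution output shares memory with the tensor that the ReLU then overwrites, so `output.detach().clone()` is probably the safer choice in that case.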

In [41]:
def print_shape(m, i, o):
    print(i[0].shape, "==>", o.shape)

In [42]:
model_name = "vgg11"
model = timm.create_model(model_name, pretrained=True)
flatten_children = get_children(model)
for layer in flatten_children:
    layer.register_forward_hook(print_shape)
batch_input = torch.randn(4, 3, 128, 128)
model(batch_input)

torch.Size([4, 3, 128, 128]) ==> torch.Size([4, 64, 128, 128])
torch.Size([4, 64, 128, 128]) ==> torch.Size([4, 64, 128, 128])
torch.Size([4, 64, 128, 128]) ==> torch.Size([4, 64, 64, 64])
torch.Size([4, 64, 64, 64]) ==> torch.Size([4, 128, 64, 64])
torch.Size([4, 128, 64, 64]) ==> torch.Size([4, 128, 64, 64])
torch.Size([4, 128, 64, 64]) ==> torch.Size([4, 128, 32, 32])
torch.Size([4, 128, 32, 32]) ==> torch.Size([4, 256, 32, 32])
torch.Size([4, 256, 32, 32]) ==> torch.Size([4, 256, 32, 32])
torch.Size([4, 256, 32, 32]) ==> torch.Size([4, 256, 32, 32])
torch.Size([4, 256, 32, 32]) ==> torch.Size([4, 256, 32, 32])
torch.Size([4, 256, 32, 32]) ==> torch.Size([4, 256, 16, 16])
torch.Size([4, 256, 16, 16]) ==> torch.Size([4, 512, 16, 16])
torch.Size([4, 512, 16, 16]) ==> torch.Size([4, 512, 16, 16])
torch.Size([4, 512, 16, 16]) ==> torch.Size([4, 512, 16, 16])
torch.Size([4, 512, 16, 16]) ==> torch.Size([4, 512, 16, 16])
torch.Size([4, 512, 16, 16]) ==> torch.Size([4, 512, 8, 8])
torch.Si

tensor([[-1.1220,  1.1860,  0.6997,  ...,  0.4043, -1.3233,  2.3906],
        [-2.0805,  0.5330, -0.2525,  ..., -0.1023, -0.7992,  3.3472],
        [-1.9768,  0.9057, -0.2673,  ..., -0.5380, -1.2889,  2.0675],
        [-2.2991,  0.8637, -0.4698,  ..., -0.2876, -1.4231,  3.3596]],
       grad_fn=<AddmmBackward0>)
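
The last cell re-creates the model before registering hooks again; on the old model, the hooks from earlier cells would still fire. A lighter option, sketched here, relies on the handle that `register_forward_hook` returns, which has a `remove()` method. `count_call`, `counter`, and the toy layer are hypothetical names for this sketch.

```python
import torch
import torch.nn as nn

counter = {"calls": 0}

def count_call(module, inputs, output):
    counter["calls"] += 1

layer = nn.Linear(8, 2)
handle = layer.register_forward_hook(count_call)  # returns a removable handle

layer(torch.randn(4, 8))  # hook fires
handle.remove()           # detach the hook from the layer
layer(torch.randn(4, 8))  # hook no longer fires

print(counter["calls"])
```

Keeping the handles in a list when registering hooks in a loop makes it possible to remove them all once the inspection is done.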