### 注册前向钩子（forward前），Module.register_forward_pre_hook(hook) -> torch.utils.hooks.RemovableHandle
* 返回类型是torch.utils.hooks.RemovableHandle，使用handle.remove()删掉该钩子
* 调用模块forward前执行钩子!!!
* 钩子函数格式是：hook(module, input) -> None or modified input，返回值为None或者返回修改后的input
* 钩子（hook）是一种编程范式：不修改原有代码，而是在预先约定的执行点插入自定义逻辑

 -  假如你的模型有三个步骤才能走完。 你在第二个步骤后添加一个钩子。  模型运行时，会先运行前两个步骤，然后再执行钩子。再执行第三个步骤
      - 钩子可以把一些当前的数据带出来。实现可视化等功能。
      
 - return: 如果钩子返回了非 None 的值，该值将替换传给 forward 的 input（返回单个值时，PyTorch 会自动把它包装成单元素 tuple）

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
'''主要改变输入的信息。
主要功能。在forward前修改输入
    - 本来输出是 1,2,1,1
    - 增加hook结果输出时 1,2,3,3
'''
class Model(nn.Module):
    """Conv2d(1->3, k=3) + ReLU, then Conv2d(3->2, k=3) + ReLU.

    ``forward`` prints the input shape on entry and the output shape on exit,
    which makes it easy to see where a forward (pre-)hook runs relative to it.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3)

    def forward(self, x):
        print("==[forward]==\t Do forward, x.shape is", x.shape)
        features = F.relu(self.conv1(x))
        output = F.relu(self.conv2(features))
        print("==[forward]==\t End forward, output shape is", output.shape)
        return output
    
def before_hook2(module, x):
    """Forward pre-hook: log the incoming input, then swap it for a 7x7 zero image.

    Args:
        module: the module the hook is registered on; must have a ``conv1``
            attribute, which is printed.
        x: the tuple of positional inputs that ``module.forward`` would receive.

    Returns:
        A ``(1, 1, 7, 7)`` zero tensor. PyTorch substitutes a non-None return
        value for the forward input, so forward sees 7x7 instead of 5x5.
    """
    first_input = x[0]  # positional inputs arrive as a tuple
    print("==[Hook]==\t Pre Hook:", module, "\n==[Hook]==\t Input shape is:", first_input.shape)
    print(f"==[Hook]==\t ", module.conv1, sep="")
    replacement = torch.zeros(1, 1, 7, 7)
    return replacement

model2 = Model()
# The pre-hook is registered on this one instance: it runs before model2.forward
# only, not before other Model instances or the submodules conv1/conv2.
handle = model2.register_forward_pre_hook(before_hook2)
# Input shape (1, 1, 5, 5); without the hook, the two 3x3 convs would give (1, 2, 1, 1).
input1 = torch.zeros(1, 1, 5, 5)
try:
    model2(input1)
finally:
    handle.remove()  # always detach the hook, even if forward raises

# Conclusion: register_forward_pre_hook makes it easy to modify a model's input data.

==[Hook]==	 Pre Hook: Model(
  (conv1): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1))
) 
==[Hook]==	 Input shape is: torch.Size([1, 1, 5, 5])
==[Hook]==	 Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
==[forward]==	 Do forward, x.shape is torch.Size([1, 1, 7, 7])
==[forward]==	 End forward, output shape is torch.Size([1, 2, 3, 3])


### 注册前向钩子（forward后），Module.register_forward_hook(hook) -> torch.utils.hooks.RemovableHandle
* 返回类型是torch.utils.hooks.RemovableHandle，使用handle.remove()删掉该钩子
* 调用模块forward后执行钩子
* 钩子函数格式是：hook(module, input, output) -> None or modified output，返回值为None或者返回修改后的output
* 钩子（hook）是一种编程范式：不修改原有代码，而是在预先约定的执行点插入自定义逻辑

In [3]:
class Model3(nn.Module):
    """Conv2d(1->3, k=3) + ReLU, then Conv2d(3->2, k=3) + ReLU.

    ``forward`` prints the input shape on entry and the output shape on exit,
    so the printout shows that a forward hook fires after forward finishes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3)

    def forward(self, x):
        print("==[forward]==\t Do forward, x.shape is", x.shape)
        hidden = F.relu(self.conv1(x))
        result = F.relu(self.conv2(hidden))
        print("==[forward]==\t End forward, output shape is", result.shape)
        return result

def hook_function(module, x, output):
    """Forward hook: log the module and the input/output shapes, then replace the output.

    Args:
        module: the module the hook is registered on.
        x: tuple of positional inputs that were passed to ``module.forward``.
        output: the value ``module.forward`` returned.

    Returns:
        A ``(1, 1, 7, 7)`` zero tensor; a non-None return value from a forward
        hook replaces the module's output, so callers never see the real result
        (normally ``(1, 2, 1, 1)`` for Model3).
    """
    input_shape = x[0].shape
    output_shape = output.shape
    print(
        f"==[Hook]==\t this is register_forward_hook", module,
        f"\n==[Hook]==\t", "this is x\t", input_shape,
        f"\n==[Hook]==\t", "this is output\t", output_shape,
    )
    return torch.zeros(1, 1, 7, 7)
    
model3 = Model3()
x = torch.randn((1, 1, 5, 5))
# A forward hook runs after model3.forward; its non-None return replaces the output.
handle = model3.register_forward_hook(hook_function)
try:
    res = model3(x)
    print(f"==[res]==\t", res.shape)
finally:
    handle.remove()  # detach the hook even if forward raises

# Conclusion: register_forward_hook makes it easy to modify a model's output data.

==[forward]==	 Do forward, x.shape is torch.Size([1, 1, 5, 5])
==[forward]==	 End forward, output shape is torch.Size([1, 2, 1, 1])
==[Hook]==	 this is register_forward_hook Model3(
  (conv1): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(3, 2, kernel_size=(3, 3), stride=(1, 1))
) 
==[Hook]==	 this is x	 torch.Size([1, 1, 5, 5]) 
==[Hook]==	 this is output	 torch.Size([1, 2, 1, 1])
==[res]==	 torch.Size([1, 1, 7, 7])


#### 总结：
    - model.register_forward_pre_hook。  forward前的钩子return的值改的是input
    - model.register_forward_hook。      forward后的钩子return的值改的是output

### 直接替换forward函数

#### 1.0 nn层的forward替换

In [4]:
class Model3(nn.Module):
    """Two 3x3 convs (1->3->2 channels), each followed by functional ReLU.

    ``F.relu`` is a plain function, not an ``nn`` module, so patching an ``nn``
    layer's ``forward`` intercepts only conv1/conv2 in this model.
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, 3)  # in=1, out=3, kernel_size=3
        self.conv2 = nn.Conv2d(3, 2, 3)  # in=3, out=2, kernel_size=3
    def forward(self, x):
        print("==[forward]==\t Do forward, x.shape is", x.shape)
        # `x` is rebound at each step; this cell prints tensor id()s, which can
        # depend on when intermediates are freed, so keep this shape.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        print("==[forward]==\t End forward, output shape is", x.shape)
        return x

# 单独改torch.nn.Conv2d层的forward
def hook_function3(oldfunc):
    """Wrap an unbound ``forward`` so each call prints the module type and tensor ids.

    Args:
        oldfunc: the function being replaced, e.g. ``nn.Conv2d.forward``.

    Returns:
        ``inner_function``, a drop-in replacement for ``oldfunc``. It calls
        ``oldfunc`` with every argument it receives, then prints ``type(self)``
        and the ``id()`` of the first input and of the result. Finally it
        returns the result unchanged.
    """
    import functools

    @functools.wraps(oldfunc)  # keep the original __name__/__doc__ and expose __wrapped__
    def inner_function(self, x, *args, **kwargs):
        # Pass extra positional/keyword arguments through, so forwards that take
        # more than one input keep working once wrapped.
        res = oldfunc(self, x, *args, **kwargs)
        print(f"==[Hook]==\t type:{type(self)} \t input_id:{id(x)} \t output_id:{id(res)}")
        return res
    return inner_function


# Monkeypatch the class attribute itself: every nn.Conv2d instance (in any model)
# now runs inner_function, which calls the original forward and prints ids.
# NOTE(review): the patch is never undone; re-running this cell wraps the
# already-wrapped forward again, so each conv would print twice.
setattr(nn.Conv2d, "forward", hook_function3(nn.Conv2d.forward))

model4 = Model3()
model4.eval()
x = torch.randn(1, 1, 5, 5)
output = model4(x)

# Summary: this replaces the forward of one specific nn layer type.
# How: Python lets us swap a class's method implementation at runtime.
# This is more flexible than the two hook APIs above and avoids their limitations.

==[forward]==	 Do forward, x.shape is torch.Size([1, 1, 5, 5])
==[Hook]==	 type:<class 'torch.nn.modules.conv.Conv2d'> 	 input_id:140446614049952 	 output_id:140439854633904
==[Hook]==	 type:<class 'torch.nn.modules.conv.Conv2d'> 	 input_id:140439854633984 	 output_id:140439854633904
==[forward]==	 End forward, output shape is torch.Size([1, 2, 1, 1])


#### 2.0 relu也替换掉的例子
- 这里换了一个conv与relu分开的网络

In [1]:
import numpy as np
import torch.nn as nn
import torch

class Model5(nn.Module):
    """conv1 (3->3 channels, 1x1) -> ReLU -> conv2 (3->1 channels, 1x1).

    ReLU is an ``nn.ReLU`` submodule rather than ``F.relu``, so patching
    ``nn.ReLU.forward`` also intercepts the activation.
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, 1, 1)  # in=3, out=3, kernel_size=1, stride=1
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(3, 1, 1, 1)  # in=3, out=1, kernel_size=1, stride=1
    def forward(self, x):
        # Each step rebinds `x`, so the previous intermediate loses this local
        # reference as soon as the next layer returns. The id()-reuse printout
        # in this cell likely depends on that, so keep the rebinding pattern.
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x
        
def hook_function(oldfunc):
    """Wrap an unbound ``forward`` so every call prints the module type and tensor ids.

    Args:
        oldfunc: the function being replaced, e.g. ``nn.Conv2d.forward``.

    Returns:
        ``inner_function``, which has the same calling convention as a module
        ``forward``. It calls ``oldfunc`` with all given arguments, then prints
        ``type(self)`` and the ``id()`` of the first input and of the result.
        Finally it returns the result unchanged.
    """
    import functools

    @functools.wraps(oldfunc)  # keep the original __name__/__doc__ and expose __wrapped__
    def inner_function(self, x, *args, **kwargs):
        # Pass extra positional/keyword arguments through, so forwards that take
        # more than one input still work when wrapped.
        res = oldfunc(self, x, *args, **kwargs)
        print(f"==[Hook]==\t {type(self)}, input id = {id(x)}, output id = {id(res)}")
        return res
    return inner_function
    
# Patch the class attributes, so every Conv2d and ReLU instance goes through the
# printing wrapper; the patch stays in place for the rest of the session.
nn.Conv2d.forward = hook_function(nn.Conv2d.forward)
nn.ReLU.forward = hook_function(nn.ReLU.forward)

model = Model5().eval()
# Named input_tensor (not `input`) so the builtin input() is not shadowed.
input_tensor = torch.zeros(1, 3, 3, 3)
a = model(input_tensor)

==[Hook]==	 <class 'torch.nn.modules.conv.Conv2d'>, input id = 139901407312928, output id = 139901407312448
==[Hook]==	 <class 'torch.nn.modules.activation.ReLU'>, input id = 139901407312448, output id = 139901407310768
==[Hook]==	 <class 'torch.nn.modules.conv.Conv2d'>, input id = 139901407310768, output id = 139901407312448


In [None]:
# Reading the printout below:
# - Lines 1 and 2 share an id because conv1's output object *is* the ReLU's
#   input (the same tensor is passed along). That is expected, not reuse.
# - Line 3 shows conv2's output with that same id. By then conv1's output had
#   no references left (Model5.forward rebinds `x`) and was freed, and the new
#   object apparently landed at the same address. In CPython, id() is the
#   object's address, and objects with non-overlapping lifetimes may share an id.
'''
==[Hook]==	 <class 'torch.nn.modules.conv.Conv2d'>, input id = 139901407312928, output id = 139901407312448
==[Hook]==	 <class 'torch.nn.modules.activation.ReLU'>, input id = 139901407312448, output id = 139901407310768
==[Hook]==	 <class 'torch.nn.modules.conv.Conv2d'>, input id = 139901407310768, output id = 139901407312448

id = 139901407312448 出现了多次。这是因为pytorch的内存复用造成的结果。
'''

#### 3.0 更新版
- 对程序进一步修改。
- 主要功能仍然是，将nn模块自带的forward函数，在功能(前向)不变的情况下，替换为自定义myforward函数。
- 使用带参装饰器，参数为想要改变的 nn 模块路径 + forward 组成的字符串。
  - 例如想改变 Conv2d 模块的 forward，就传参 "torch.nn.Conv2d.forward"
- 使用getattr，在函数内得到与装饰器参数一致的forward对象
- 使用setattr，真正做到修改forward函数
- 使用clone()，缓解tensor对象被回收后其id被新tensor复用的问题。但仍然留下一个问题：clone()并不能保证id总是唯一（原因见下文4.0）

In [None]:
# 为什么要避免复用？
# 主要和后面构建onnx有关。
# ==[Hook]==	 <class 'torch.nn.modules.conv.Conv2d'>, input id = 139901407312928, output id = 139901407312448
# ==[Hook]==	 <class 'torch.nn.modules.activation.ReLU'>, input id = 139901407312448, output id = 139901407310768
# ==[Hook]==	 <class 'torch.nn.modules.conv.Conv2d'>, input id = 139901407310768, output id = 139901407312448
# 从打印结果看，第一个Conv2d的output的id与第二个Conv2d的output的id一致。如果构建graph时用id来标识tensor，这两个不同的tensor就会被当成同一个节点，导致图出错。

In [1]:
import numpy as np
import torch.nn as nn
import torch

class Model(nn.Module):
    """conv1 (3->3 channels, 1x1) -> nn.ReLU -> conv2 (3->1 channels, 1x1).

    All three layers are nn modules, so patching ``nn.Conv2d.forward`` and
    ``nn.ReLU.forward`` intercepts every step of this forward.
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, 1, 1)  # in=3, out=3, kernel_size=1, stride=1
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(3, 1, 1, 1)  # in=3, out=1, kernel_size=1, stride=1
    def forward(self, x):
        # `x` is rebound at each step, so intermediates can be freed between
        # layers; the id() printout of this cell is sensitive to that.
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x
        
def hook_forward(fn):
    """Decorator factory that replaces a method (typically ``forward``) on a class.

    ``fn`` is a dotted path such as ``"torch.nn.Conv2d.forward"``. Everything
    before the last dot names the owner class; the last component names the
    method to replace. Applying the returned decorator to
    ``bind_func(self, x, y)`` patches the owner immediately. Every later call
    runs the original method, clones its output, calls ``bind_func`` with
    (module, first input, cloned output), and returns the clone.

    Args:
        fn: ``"<owner>.<method>"``. The first component of ``<owner>`` must be
            a global name of this module or an importable top-level package.

    Returns:
        The decorator. It returns ``bind_func`` unchanged, so the decorated
        name stays bound to the callback instead of becoming ``None``.

    Raises:
        ValueError: if ``fn`` has no ``"<owner>."`` prefix.
    """
    import functools
    import importlib

    owner_path, sep, fn_name = fn.rpartition(".")  # "torch.nn.Conv2d", ".", "forward"
    if not sep or not owner_path:
        raise ValueError(f"expected '<owner>.<method>', got {fn!r}")
    # Resolve the owner without eval(). Look the first component up in this
    # module's globals (as eval did) or import it, then walk the attributes.
    root, *attrs = owner_path.split(".")
    layer_name = globals()[root] if root in globals() else importlib.import_module(root)
    for attr in attrs:
        layer_name = getattr(layer_name, attr)  # e.g. torch -> torch.nn -> torch.nn.Conv2d
    oldfn = getattr(layer_name, fn_name)  # the original function, e.g. Conv2d.forward

    def make_hook(bind_func):

        @functools.wraps(oldfn)  # keep __name__/__doc__ of the replaced method
        def myforward(self, x, *args, **kwargs):
            # clone() gives each hooked output its own new tensor object, which
            # reduces id() reuse between layers. It does not guarantee unique
            # ids: the un-cloned result is freed right away and its id can be
            # recycled.
            y = oldfn(self, x, *args, **kwargs).clone()
            bind_func(self, x, y)
            return y

        setattr(layer_name, fn_name, myforward)  # replace the method on the class
        # Return the callback so the decorated name is not rebound to None.
        return bind_func

    return make_hook
    

@hook_forward("torch.nn.Conv2d.forward")
def symbolic_conv2d(self, x, y):
    """Print the Conv2d module's type and the raw id() of the input x and output y it was given."""
    ids = (id(x), id(y))
    print(f"{type(self)}, input id = {ids[0]}, output id = {ids[1]}")
    
@hook_forward("torch.nn.ReLU.forward")
def symbolic_relu(self, x, y):
    """Print the ReLU module's type and the raw id() of the input x and output y it was given."""
    ids = (id(x), id(y))
    print(f"{type(self)}, input id = {ids[0]}, output id = {ids[1]}")
    
model = Model()
model.eval()
input1 = torch.ones(1, 3, 3, 3)
a = model(input1)
a  # last expression: Jupyter displays the result tensor


# Listing the attributes of layer_name (a local inside hook_forward, so this
# loop only works when placed inside that function):
# for k, v in vars(layer_name).items():
#     print(k)
# __module__
# __doc__
# __init__
# _conv_forward
# forward

# How did the ids change?
# Apart from special cases, no id repeats any more.

<class 'torch.nn.modules.conv.Conv2d'>, input id = 139885180861632, output id = 139885180859472
<class 'torch.nn.modules.activation.ReLU'>, input id = 139885180859472, output id = 139885180859312
<class 'torch.nn.modules.conv.Conv2d'>, input id = 139885180859312, output id = 139885180860992


tensor([[[[0.3356, 0.3356, 0.3356],
          [0.3356, 0.3356, 0.3356],
          [0.3356, 0.3356, 0.3356]]]], grad_fn=<CloneBackward0>)

"""
执行顺序说明：
1.0 代码从上往下执行, 先导入各种模块, 定义Model类，定义hook_forward函数

2.0 再执行到带参数的@hook_forward装饰器。执行hook_forward前4行代码
    定义make_hook函数，返回make_hook的引用

3.0 定义symbolic_conv2d函数。被装饰的函数symbolic_conv2d定义好后，
    会把它作为参数传入上一步装饰器返回的make_hook函数并执行,
    即执行make_hook(symbolic_conv2d)；make_hook的返回值会成为symbolic_conv2d这个名字绑定的对象
    
4.0 在make_hook内部定义myforward函数（闭包中记住了oldfn和bind_func）, 利用setattr，用myforward替换掉
    <class 'torch.nn.modules.conv.Conv2d'>的forward属性，

ReLU的装饰过程重复2.0 ~ 4.0，即
5.0 再执行到带参数的@hook_forward装饰器。执行hook_forward前4行代码
    定义make_hook函数，返回make_hook的引用；随后执行make_hook(symbolic_relu)

6.0 再定义myforward函数, 利用setattr，用myforward替换掉
    <class 'torch.nn.modules.activation.ReLU'>的forward属性

7.0 实例化Model，执行__init__，置为eval模式，构造数据，执行model的前向forward

8.0 在执行self.conv1(x)时，实际执行的是hook_forward中的make_hook中的myforward：
    它先调用原来的Conv2d.forward，再把(self, x, y)交给symbolic_conv2d打印
"""

## 4.0 再改进版
- tensor的id什么时候会被复用？
    - 答案：当一个tensor对象没有任何引用时，它会被回收；它占用的内存随后可能分配给新创建的tensor，而CPython中id()就是对象的地址，所以新tensor的id可能与旧tensor相同
    - 举例：在forward函数中
    
    ![Alt text](./image.png)
    
        - 执行完self.conv1(x)后，x1的引用计数减1，此时引用计数不一定为0(外面可能会有别的引用)，x1不一定会被回收
        - 执行完self.relu(x)后, x2是真正的没有引用了。此时就可能被复用
        - 在执行self.conv2(x)时，等号左边的x6需要新的内存，
            - 此时x2已经没有引用、已被回收，它原来占用的内存可以重新分配
            - 于是x6可能恰好使用这块内存，id与x2相同。

- 如何不让他复用？
    - 解决方案是：一直保留对这些tensor的引用（例如把它们放进全局列表all_tensors），使其引用计数不为0、不被回收，id也就不会被复用。
    


In [None]:
import numpy as np
import torch.nn as nn
import torch

class Model(nn.Module):
    """conv1 (3->3 channels, 1x1) -> nn.ReLU -> conv2 (3->1 channels, 1x1).

    All three layers are nn modules, so the patched ``nn.Conv2d.forward`` and
    ``nn.ReLU.forward`` see every step of this forward.
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, 1, 1)  # in=3, out=3, kernel_size=1, stride=1
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(3, 1, 1, 1)  # in=3, out=1, kernel_size=1, stride=1
    def forward(self, x):
        # Rebinding `x` drops this function's reference to each intermediate;
        # in this cell the hooks keep them alive via all_tensors instead.
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        return x
        
def hook_forward(fn):
    """Decorator factory that replaces a method (typically ``forward``) on a class.

    ``fn`` is a dotted path such as ``"torch.nn.Conv2d.forward"``. Applying the
    returned decorator to ``bind_func(self, x, y)`` patches the owner class
    immediately. Every later call runs the original method and appends ``x``
    and ``y`` to the module-level list ``all_tensors``, which must exist before
    the first call. It then calls ``bind_func`` and returns ``y``.

    Args:
        fn: ``"<owner>.<method>"``. The first component of ``<owner>`` must be
            a global name of this module or an importable top-level package.

    Returns:
        The decorator. It returns ``bind_func`` unchanged, so the decorated
        name stays bound to the callback instead of becoming ``None``.

    Raises:
        ValueError: if ``fn`` has no ``"<owner>."`` prefix.
    """
    import functools
    import importlib

    owner_path, sep, func_name = fn.rpartition(".")
    if not sep or not owner_path:
        raise ValueError(f"expected '<owner>.<method>', got {fn!r}")
    # Resolve the owner without eval(). Look the first component up in this
    # module's globals (as eval did) or import it, then walk the attributes.
    root, *attrs = owner_path.split(".")
    nnmodule_name = globals()[root] if root in globals() else importlib.import_module(root)
    for attr in attrs:
        nnmodule_name = getattr(nnmodule_name, attr)
    oldfunc = getattr(nnmodule_name, func_name)

    def make_hook(bind_func):
        @functools.wraps(oldfunc)  # keep __name__/__doc__ of the replaced method
        def newforward(self, x, *args, **kwargs):
            y = oldfunc(self, x, *args, **kwargs)
            # Keep x and y referenced from the global list so CPython cannot
            # free them and hand their id() to a later tensor. The list is only
            # mutated, never rebound, so no `global` statement is needed.
            all_tensors.extend([x, y])
            bind_func(self, x, y)
            return y
        setattr(nnmodule_name, func_name, newforward)
        # Return the callback so the decorated name is not rebound to None.
        return bind_func
    return make_hook

def get_obj_id(obj):
    """Return a small sequential label (0, 1, 2, ...) for ``obj``, keyed by ``id(obj)``.

    Labels are stored in the module-level dict ``objmap``, which must exist
    before the first call. An object seen for the first time gets the current
    size of ``objmap`` as its label; later calls return the same label. Because
    the key is ``id(obj)``, a new object that reuses a freed object's id
    inherits the old label.
    """
    # setdefault evaluates len(objmap) before inserting, matching "first seen -> next index".
    return objmap.setdefault(id(obj), len(objmap))
        
    
@hook_forward("torch.nn.Conv2d.forward")
def symbolic_conv2d(self, x, y):
    """Print the Conv2d module's type with sequential labels for its input and output.

    Labels come from get_obj_id; x is labelled before y.
    """
    input_label = get_obj_id(x)
    output_label = get_obj_id(y)
    print(f"{type(self)}, input id = {input_label}, output id = {output_label}")
    
@hook_forward("torch.nn.ReLU.forward")
def symbolic_relu(self, x, y):
    """Print the ReLU module's type with sequential labels for its input and output.

    Labels come from get_obj_id; x is labelled before y.
    """
    input_label = get_obj_id(x)
    output_label = get_obj_id(y)
    print(f"{type(self)}, input id = {input_label}, output id = {output_label}")
    
# Module-level state read by the hooks: all_tensors keeps every hooked
# input/output referenced, so their ids cannot be recycled. objmap maps
# id(obj) -> sequential label. Both must exist before the forward pass.
objmap = {}
all_tensors = []

model = Model()
model.eval()
input1 = torch.ones(1, 3, 3, 3)
a = model(input1)
a


'''
<class 'torch.nn.modules.conv.Conv2d'>, input id = 0, output id = 1
<class 'torch.nn.modules.activation.ReLU'>, input id = 1, output id = 2
<class 'torch.nn.modules.conv.Conv2d'>, input id = 2, output id = 3
[3]:
tensor([[[[-1.1552, -1.1552, -1.1552],
          [-1.1552, -1.1552, -1.1552],
          [-1.1552, -1.1552, -1.1552]]]], grad_fn=<ConvolutionBackward0>)
'''