In [1]:
import torch
import torchviz
import torch.nn as nn
import torch.nn.functional as F

# Hooks
官方文档：
   - hooks for saved tensors
   - autograd mechanics中的Hooks for saved tensors
   - modules中的module hooks部分

#### 基本概念
- **什么是hook**
  - hook是一个特殊函数，它在autograd.Function的forward或backward method被调用前，或者计算结束后被执行。用来在常规的forward/backward pass中执行任意的code。在不调整forward/backward method定义，从而不会影响autograd machine的执行逻辑的条件下，完成hook Function中指定的‘支线任务’。

- **hooks的类型**:
  - 按照接入autograd过程的位置，可以分为：
    - forward hook: 在forward method计算开始前或计算完成后执行
    - backward hook：在backward method计算开始前或计算完成后执行
  - 按照作用对象，可以分为3种，其中两种都用在tensor对象上：
    - 作用在tensor上的tensor.register_hook()，torch.Tensor.register_post_accumulate_grad_hook()，torch.autograd.graph.Node.register_hook()，torch.autograd.graph.Node.register_prehook()
    - 专门用在graph中saved_tensor上的torch.autograd.graph.saved_tensors_hooks()
    - 作用在module上的torch.nn.module.register_module_(forward_/backward_)hook

- **hook的使用对象**：tensor或者nn.Module
  - **on tensor**：用于tensor的hook只有backward hook
  - **on nn.Module：** 既有forward又有backward hook

- **用法：针对各自特定的对象，做inspect或者modify**
  1. <font color=blue>**saved_tensor上的hook**</font>功能明确，一般用于将forward method向backward method传递的tensor转移到GPU之外的存储介质上，以节省memory space
  2. <font color=blue>**tensor.register_hook(hookfunc)**</font>只有backward hook，在每次register的tensor在backward pass中计算出gradient值之后被执行。作用的对象是tensor.grad。
  3. module类的hook既有forward hook，又有backward hook
     1. <font color=blue>**module forward hook**</font>定义方式是hook_name(module, input, output)。可见作用于module的input和output上。
     2. <font color=blue>**module backward hook**</font>定义方式是hook_name(module, grad_input, grad_output)。可见作用于module的grad_input和grad_output上。
- 具体场景有：
  1. 查看gradient或者forward计算的intermediate result，比如activation的信息(value, shape)用来做模型分析。比如：
     - debug，比如看gradient有没有vanishing/exploding
     - feature visualization：用intermediate activations的值
     - get feature map statistics：查看activation的mean, std, max/min
  2. gradient manipulation。这个可以用来实现：
     - gradient clipping
     - custom regularization：用自定义的方式限制weight的值

## I. Hooks for generic tensors
参考资料：pytorch101系列5，understanding hooks https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/

#### 1. 打印intermediate value的grad信息

In [2]:
#  定义hook function
def printgrad(grad):
    """Tensor backward hook that prints the gradient it receives.

    Returns None, so autograd keeps using the gradient unchanged.
    """
    print(grad)

In [3]:
# Without a hook, retain_grad() has to be called on y to read its gradient
# afterwards, which takes extra memory.
torch.manual_seed(2)
w = torch.randn(5, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.ones((3, 5))
y = x @ w + b
loss = (y ** 2).sum()

# y is an intermediate (non-leaf) result, so its .grad is only kept on request
y.retain_grad()

loss.backward()
print(y.grad)

tensor([-1.8891,  0.5235,  0.4592])


In [4]:
# With a hook, retain_grad() is not needed: the hook receives y's gradient
# during the backward pass and prints it.
torch.manual_seed(2)
w = torch.randn(5, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.ones((3, 5))
y = x @ w + b
loss = (y ** 2).sum()

y.register_hook(printgrad)

loss.backward()

tensor([-1.8891,  0.5235,  0.4592])


#### 2. 修改gradient计算方式
- 使用hook修改gradient的计算方式后，autograd能自动将修改后的值向前传递

In [5]:
#  定义hook function
def doublegrad(grad):
    """Tensor backward hook that doubles the gradient.

    The returned tensor replaces the original gradient for the rest of the
    backward pass.
    """
    doubled = grad * 2
    return doubled

In [6]:
# Baseline: the unmodified computation, with a hook that only prints y's gradient
torch.manual_seed(2)
w = torch.randn(5, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.ones((3, 5))
y = x @ w + b
loss = (y ** 2).sum()

y.register_hook(printgrad)

loss.backward()
# Reference value of w.grad to compare against in the following cells
print('w:', w.grad)

tensor([-1.8891,  0.5235,  0.4592])
w: tensor([-0.9064, -0.9064, -0.9064, -0.9064, -0.9064])


In [7]:
# Without a hook: modify y.grad directly after retain_grad().
# The modified value is not propagated; the gradients of the tensors that
# precede y in the forward pass (here w) stay unchanged.
torch.manual_seed(2)
w = torch.randn(5, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.ones((3, 5))
y = x @ w + b
loss = (y ** 2).sum()

y.retain_grad()

loss.backward()

# The backward pass has already finished, so this only edits the stored y.grad
y.grad *= 2
print('y:', y.grad)
print('w:', w.grad)

y: tensor([-3.7783,  1.0471,  0.9184])
w: tensor([-0.9064, -0.9064, -0.9064, -0.9064, -0.9064])


In [8]:
# With a hook instead, the hook's return value replaces y's gradient for the
# remainder of the autograd pass.
torch.manual_seed(2)
w = torch.randn(5, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.ones((3, 5))
y = x @ w + b
loss = (y ** 2).sum()

# Hooks on the same tensor run in registration order, so printgrad sees the
# already-doubled gradient.
y.register_hook(doublegrad) # hook for gradient manipulation
y.register_hook(printgrad) # hook for print

loss.backward()

print('w:', w.grad)

tensor([-3.7783,  1.0471,  0.9184])
w: tensor([-1.8128, -1.8128, -1.8128, -1.8128, -1.8128])


#### 3. 对module中指定的tensor做gradient clipping

In [9]:
# ----------------  原模型  ----------------
class TestNet(nn.Module):
    """Small Conv -> ReLU -> flatten -> Linear network used as the hook-free baseline.

    The flatten step is ``view(-1)``, which also folds the batch dimension into
    the feature vector, so the ``Linear(160, 5)`` layer only accepts inputs
    whose conv output holds exactly 160 values in total.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 2, stride=2)
        self.relu = nn.ReLU()
        self.flatten = lambda t: t.view(-1)
        self.fc = nn.Linear(160, 5)

    def forward(self, x):
        """Return the 1-D output of the linear layer for input ``x``."""
        activated = self.relu(self.conv(x))
        flat = self.flatten(activated)
        return self.fc(flat)

# Run the baseline network once and inspect the gradient of the fc bias
torch.manual_seed(2)
x = torch.randn(1, 3, 8, 8)
net = TestNet()
out = net(x)
loss = (1 - out).mean()
loss.backward()

print(out.shape)
# loss is the mean of (1 - out), so each bias gradient should be
# -1 / out.numel()
print('bias_grad in linear layer:', net.fc.bias.grad)

torch.Size([5])
bias_grad in linear layer: tensor([-0.2000, -0.2000, -0.2000, -0.2000, -0.2000])


- <font color=blue>**给tensor加入backward hook后，Backward pass的执行顺序**</font>
  1. 从root开始按照chainrule执行backward pass
  2. 遇到hook后，对指定tensor的grad执行相应的ops，如果同一位置有多个hook，按它们register的先后顺序执行操作，而不是像chainrule那样反向操作。如下例：\
     - fc layer backward method     -> 
     - flatten layer                ->   
     - register for clamp           ->
     - register for print shape     -> 
     - register for check gradient  -> 
     - relu layer backward method   -> 
     - conv layer backward method

In [10]:
#    1. 将linear层中bias的梯度改为0
#    2. conv layer从downstream拿到的gradient大小都不小于0

# ----------------  加入hook  ----------------
class TestNet(nn.Module):
    """Conv -> ReLU -> flatten -> Linear network with tensor hooks on the ReLU output.

    Three backward hooks are attached to the ReLU output on every forward
    call: one clamps the gradient to be non-negative, two print diagnostics.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 2, stride=2)
        self.relu = nn.ReLU()
        self.flatten = lambda t: t.view(-1)
        self.fc = nn.Linear(160, 5)

    def forward(self, x):
        """Run the network and register the hooks on the ReLU output."""

        def keep_non_negative(grad):
            # Replacement gradient with every negative entry raised to 0
            return torch.clamp(grad, min=0)

        def report_shape(grad):
            print('size of relu output:', grad.shape)

        def report_negatives(grad):
            print('any grad < 0?', bool((grad < 0).any()))

        activated = self.relu(self.conv(x))

        # The hooks act on the gradient of the ReLU *output*. ReLU's own local
        # gradient is only 0 or 1, so a non-negative gradient here is still
        # non-negative after passing through the ReLU mask on its way to conv.
        # Hooks run in registration order: clamp first, then the two reports,
        # so the reports see the clamped gradient.
        activated.register_hook(keep_non_negative)
        activated.register_hook(report_shape)
        activated.register_hook(report_negatives)

        flat = self.flatten(activated)
        return self.fc(flat)
    
torch.manual_seed(2)
x = torch.randn(1, 3, 8, 8)
net = TestNet()

# Register a hook on a parameter tensor from outside the model: here the
# gradient of the linear layer's bias is replaced by zeros.
# This way hooks can be added without touching the module definition.
for name, param in net.named_parameters():
    if 'fc' in name and 'bias' in name:
        # zeros_like keeps the dtype and device of the incoming gradient;
        # torch.zeros(grad.shape) would always build a default-dtype CPU tensor.
        param.register_hook(lambda grad: torch.zeros_like(grad))

out = net(x)
loss = (1 - out).mean()
loss.backward()

# Confirm that the bias gradient was replaced by all zeros
print('bias_grad in linear layer:', net.fc.bias.grad)

size of relu output: torch.Size([1, 10, 4, 4])
any grad < 0? False
bias_grad in linear layer: tensor([0., 0., 0., 0., 0.])


#### Whether a particular hook will be fired
1. **torch.Tensor.register_hook()** 在每次tensor的梯度计算后执行。
   - <font color=orange>注，不要求tensor的grad_fn执行。比如，当Tensor被作为inputs参数传给torch.autograd.grad()时, 它的grad_fn不会被执行, 但它上面登记的register_hook会被执行。</font>说明：
     - 通常情况下，grad_fn is used to propagate gradients backwards through the computation graph.但是当Tensor被作为inputs参数传给torch.autograd.grad()时, 它在计算图上会被当做leaf node。这就意味着它的grad_fn参数指向的backward method不会被执行来进一步向前传梯度。
2. **torch.Tensor.register_post_accumulate_grad_hook()** 在完成梯度累积操作后执行。
   - 用torch.Tensor.register_hook()登记的hook在每次gradients计算的时候都要执行, 但用torch.Tensor.register_post_accumulate_grad_hook()登记的hook只在backward pass的最后，完成梯度累积后执行一次。因此，post-accumulate-grad hooks的作用对象只能是leaf Tensors. 登记在非leaf tensor上会报错。

#### The order in which the different hooks are fired
1. hooks registered to Tensor are executed
2. pre-hooks registered to Node are executed (if Node is executed).
3. the .grad field is updated for Tensors that retain_grad
4. Node is executed (subject to rules above)
5. for leaf Tensors that have .grad accumulated, post-accumulate-grad hooks are executed
6. post-hooks registered to Node are executed (if Node is executed)

If multiple hooks of the same type are registered on the same Tensor or Node they are executed in the order in which they are registered. 

## II. Hooks for nn.Module objects

- <font color=orange>hook函数的返回值将代替原input/output用在剩余的forward/backward computation中. </font>
- 所以，这些hooks可以：
  1. execute arbitrary code along the regular module forward/backward.
  2. 修改指定的inputs/outputs而不用改变module的forward() function.
- <font color=red>用于nn.Module的hook要注意module中可能涉及多个nn.Function，因此对应多个forward/backward calls，hook的作用会涉及多个call，如果不清楚有的pytorch自定义的module是如何工作的，hook的设置会容易弄错，使用的时候要仔细。</font>

### II.1 register hooks
#### II.1.1 register forward hooks
- forward hook在forward pass中被调用，具体又有两种执行位置，**pre_hook**在forward method之前执行；**hook**在forward method执行完之后执行。
- 下面1和2只对当前hook所register上的module有效。3和4是global hook，也就是installed for all modules。
1. <font color=green>**register_forward_pre_hook(hook_func_name)**</font>
   - 对应的hook function定义形式：<font color=orange>forward_pre_hook(m, inputs)</font>
2. <font color=green>**register_forward_hook(hook_func_name)**</font>
   - 对应的hook function定义形式：<font color=orange>forward_hook(m, inputs, output)</font>
3. <font color=green>**register_module_forward_pre_hook(hook_func_name)**</font>
4. <font color=green>**register_module_forward_hook(hook_func_name)**</font>：

#### II.1.2 register backward hooks
- register_full_backward_pre_hook登记的hook在该module的backward计算开始之前执行，作用对象只能是grad_outputs。register_full_backward_hook登记的hook在该module的backward计算完成之后执行，可以同时拿到grad_inputs和grad_outputs，并可以返回新的grad_inputs用于后续计算。<font color=red>注意命名是相对forward而言的：grad_outputs是对module的outputs的梯度，也就是backward计算的输入；grad_inputs是对module的inputs的梯度，也就是backward计算的结果。</font>
- 下面1和2只对当前hook所register上的module有效。3和4是global hook，也就是installed for all modules。

1. <font color=green>**register_full_backward_pre_hook(hook_func_name)**</font>
   - 对应的hook function定义形式：<font color=orange>backward_hook(m, grad_outputs)</font>
2. <font color=green>**register_full_backward_hook(hook_func_name)**</font>：这个是原register_backward_hook()
   - 对应的hook function定义形式：<font color=orange>backward_hook(m, grad_inputs, grad_outputs)</font>
3. <font color=green>**register_module_full_backward_hook(hook_func_name)**</font>
4. <font color=green>**register_module_full_backward_pre_hook(hook_func_name)**</font>

### II.2 用法举例

In [11]:
### 例1：定义不同的hooks，看他们的工作方式

## Create the module and the input
# Seed *before* building the module: nn.Linear draws its initial weights from
# the global RNG, so seeding afterwards would leave m (and every output
# printed below) different on each run.
torch.manual_seed(1)
m = nn.Linear(3, 3)
x = torch.randn(2, 3, requires_grad=True)

In [12]:
## ---------- 定义foreward hook function ----------
#  1. Runs before the forward pass: inspect or adjust the inputs
def forward_pre_hook(m, inputs):
    """Forward pre-hook that shifts the first positional input by one.

    ``inputs`` arrives wrapped in a tuple; the returned tensor is used in
    place of the original input.
    """
    first = inputs[0]
    shifted = first + 1.
    return shifted

#  2. Runs after the forward pass: inspect inputs/outputs or adjust the output
def forward_hook(m, inputs, output):
    """Forward hook that adds the first input to the output (ResNet-style residual).

    ``inputs`` arrives wrapped in a tuple while ``output`` is passed as-is.
    """
    skip = inputs[0]
    return output + skip

In [13]:
## Output before any hook is registered
print(format(m(x)))

tensor([[-0.3552,  0.0191,  0.7350],
        [-0.3277,  0.2751,  0.4299]], grad_fn=<AddmmBackward0>)


In [14]:
# After registering forward_pre_hook the output changes.
# Keep the returned handle so the hook can be removed later.
forward_pre_hook_handle = m.register_forward_pre_hook(forward_pre_hook)
print(format(m(x)))

tensor([[-0.6156, -0.1592,  1.2549],
        [-0.5881,  0.0968,  0.9498]], grad_fn=<AddmmBackward0>)


In [15]:
# After additionally registering forward_hook the output changes again
# (the pre-hook from the previous cell is still active).
forward_hook_handle = m.register_forward_hook(forward_hook)
print(format(m(x)))

tensor([[1.0457, 1.1077, 2.3166],
        [1.0332, 0.6449, 1.7837]], grad_fn=<AddBackward0>)


In [16]:
# After removing both hooks the output matches the value from before they
# were registered.
forward_pre_hook_handle.remove()
forward_hook_handle.remove()

print(format(m(x)))

tensor([[-0.3552,  0.0191,  0.7350],
        [-0.3277,  0.2751,  0.4299]], grad_fn=<AddmmBackward0>)


In [17]:
## ---------- Define the backward hook function ----------
#  Purpose:
#     1. inspect grad_inputs / grad_outputs
#     2. replace the grad_inputs used in the rest of the backward pass
#  Note: grad_inputs / grad_outputs both arrive wrapped in tuples

def backward_hook(m, grad_inputs, grad_outputs):
    """Full backward hook that overwrites every input gradient with 42.

    Args:
        m: the module the hook is registered on (unused).
        grad_inputs: tuple of gradients w.r.t. the module inputs; an entry may
            be None when that input receives no gradient.
        grad_outputs: tuple of gradients w.r.t. the module outputs (unused).

    Returns:
        A tuple with one entry per element of ``grad_inputs``: a tensor of the
        same shape/dtype/device filled with 42, or None where the original
        entry was None.
    """
    # None entries are passed through untouched; ones_like(None) would raise.
    return tuple(
        None if gi is None else torch.ones_like(gi) * 42.
        for gi in grad_inputs
    )

## Output without any backward hook registered: the reference value of x.grad
m(x).sum().backward()
print(format(x.grad))

tensor([[ 0.1376,  0.0587, -0.1150],
        [ 0.1376,  0.0587, -0.1150]])


In [18]:
# Clear the accumulated gradients before running backpropagation again
m.zero_grad()
x.grad.zero_()

# Output with the backward hook registered: x.grad is what the hook returned
backward_hook_handle = m.register_full_backward_hook(backward_hook)
m(x).sum().backward()
print(format(x.grad))

tensor([[42., 42., 42.],
        [42., 42., 42.]])


In [19]:
# Output after removing the backward hook: x.grad is back to the reference value
backward_hook_handle.remove()

# Reset accumulated gradients so the new backward pass starts from zero
m.zero_grad()
x.grad.zero_()
m(x).sum().backward()
print(format(x.grad))

tensor([[ 0.1376,  0.0587, -0.1150],
        [ 0.1376,  0.0587, -0.1150]])


In [20]:
## 例2：用forward_hook查看module中activation value的例子

class Net(nn.Module):
    """Conv -> ReLU -> adaptive average pool -> Linear network producing one value per sample."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 2)
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Linear(8 * 4 * 4, 1)

    def forward(self, x):
        """Return a (batch, 1) tensor for the image batch ``x``."""
        activated = F.relu(self.conv(x))
        pooled = self.pool(activated)
        # Keep the batch dimension and flatten everything else for the linear layer
        flat = pooled.view(pooled.shape[0], -1)
        return self.fc(flat)

# Forward hook that captures a layer's activation, i.e. the feature
features = {}


def hook_func(model, input, output):
    """Store a detached copy of ``output`` under ``features['feature']``.

    Each call overwrites the previously stored activation.
    """
    features.update(feature=output.detach())

In [21]:
net = Net()
# Register the hook on the pooling layer
net.pool.register_forward_hook(hook_func)

x= torch.randn(1,3,10,10)
# The forward pass fires the hook, which fills features['feature']
output = net(x)
print(features['feature'].shape)

torch.Size([1, 8, 4, 4])


### II.3 典型应用场景
参考资料：Blog, how to Use PyTorch Hooks https://medium.com/the-dl/how-to-use-pytorch-hooks-5041d777f904
#### 场景1. 在模型训练过程中打印所需信息
**1. 方法：** 给module加wrapper，在wrapper module上加hooks\
**2. 优点：** \
（1）方便debug，避免手动增加和删除print的麻烦 \
（2）不仅可以在自定义module，还可以在pytorch自带的module和第三方module上使用

In [22]:
## 场景1：在ResNet18上用hooks来打印model信息

#  给model加一个wrapper class
class VerboseNet(nn.Module):
    """Wrapper that prints the output shape of every direct child of ``model``.

    Only the immediate children (``named_children``) get a hook; nested
    sub-modules are not reported individually.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

        for child_name, child in self.model.named_children():
            # Module instances carry no __name__ of their own, so tag each
            # child with its registered name for the hook to print.
            child.__name__ = child_name
            child.register_forward_hook(self._report_shape)

    @staticmethod
    def _report_shape(layer, _, output):
        """Forward hook: print the layer's tagged name and its output shape."""
        print(f"new --->: {layer.__name__}: {output.shape}")

    def forward(self, x: torch.Tensor):
        """Delegate to the wrapped model; the hooks print as a side effect."""
        return self.model(x)

In [42]:
from torchvision.models import resnet18, ResNet18_Weights

# Wrap a pretrained ResNet18 so every top-level layer reports its output shape
verbose_resnet = VerboseNet(resnet18(weights=ResNet18_Weights.DEFAULT))
dummy_input = torch.ones(10, 3, 28, 28)

# Assign the result so Jupyter does not echo the model output
_ = verbose_resnet(dummy_input)

new --->: conv1: torch.Size([10, 64, 14, 14])
new --->: bn1: torch.Size([10, 64, 14, 14])
new --->: relu: torch.Size([10, 64, 14, 14])
new --->: maxpool: torch.Size([10, 64, 7, 7])
new --->: layer1: torch.Size([10, 64, 7, 7])
new --->: layer2: torch.Size([10, 128, 4, 4])
new --->: layer3: torch.Size([10, 256, 2, 2])
new --->: layer4: torch.Size([10, 512, 1, 1])
new --->: avgpool: torch.Size([10, 512, 1, 1])
new --->: fc: torch.Size([10, 1000])


#### 场景2. Feature extraction
**1. 方法：** 给module加wrapper，在wrapper module上加hooks\
**2. 优点：** 在一个预训练好的模型上做transfer learning时，可能想查看该模型某些layers上得到的feature。hook可以在不改变模型本身的条件下实现这一需求

In [24]:
## 场景2：对model中指定layer做Feature extraction

class FeatureExt(nn.Module):
    """Wrapper that records the outputs of selected sub-modules of ``model``.

    Args:
        model: the module to wrap.
        layers: iterable of sub-module names as produced by
            ``model.named_modules()`` (e.g. ``'layer4'``).

    Raises:
        KeyError: if a name in ``layers`` is not a sub-module of ``model``.
    """

    def __init__(self, model, layers):
        super().__init__()
        self.model = model
        self.layers = layers
        # Placeholder entries until the first forward pass fills them in
        self._features = {layer: torch.empty(0) for layer in layers}

        # Build the name -> module lookup once instead of once per layer
        modules_by_name = dict(self.model.named_modules())
        for layer_i in layers:
            modules_by_name[layer_i].register_forward_hook(
                self.save_outputs_hook(layer_i)
            )

    def save_outputs_hook(self, layer_i):
        """Return a forward hook that stores the layer output under ``layer_i``."""

        def fn(_, __, output):  # distinct placeholder names: arguments must not repeat
            # The output is stored as-is (not detached or copied)
            self._features[layer_i] = output

        return fn

    def forward(self, x):
        """Run the wrapped model and return the dict of captured outputs.

        The model's own output is discarded. The returned dict is the
        internal one, so a later forward call overwrites its entries.
        """
        _ = self.model(x)
        return self._features

In [25]:
res18 = resnet18(weights=ResNet18_Weights.DEFAULT)
## Notes on resnet.named_modules:
#  1. resnet18.named_modules is a generator;
#     each item it yields is a tuple of the form ('layer_name', module)
#  2. A leading '*' unpacks the generator into separate tuples: *res18.named_modules()
#  3. Applying dict() to a list of tuples turns them into key-value pairs that
#     can be indexed by key; the unpacked tuples are first collected into a
#     list, hence dict([...])

#  ----- Print sample layers to inspect the network structure  -----
# print([*res18.named_modules()][0]) # top level of the nesting: the whole network
# print('-' * 40)
# print([*res18.named_modules()][5]) # 2nd level of the nesting: 'layer1'
# print('-' * 40)
# print([*res18.named_modules()][6]) # 3rd level of the nesting: a sub-module of 'layer1'
# print('-' * 40)
# print([*res18.named_modules()][7]) # 4th level of the nesting, ...

In [26]:
# Capture the outputs of 'layer4' and 'avgpool' while running the model
res18_feats = FeatureExt(res18, layers=['layer4', 'avgpool'])

# dummy_input is the tensor created in the VerboseNet cell above
fests = res18_feats(dummy_input)
for name, output in fests.items():
    print(name, '->', output.shape)

layer4 -> torch.Size([10, 512, 1, 1])
avgpool -> torch.Size([10, 512, 1, 1])


#### 场景3. 用tensor hook实现对module中parameters的gradient clipping

In [27]:
def clip_grad(model, floor, cap):
    """Register a gradient-clipping hook on every trainable parameter of ``model``.

    Args:
        model: module whose parameters should have their gradients clipped.
        floor: lower bound applied to every gradient element.
        cap: upper bound applied to every gradient element.

    Returns:
        The same ``model`` instance, so the call can be chained.
    """

    def _clamp(grad):
        # A tensor hook must not modify its argument; return a clamped copy,
        # which autograd then uses in place of the original gradient.
        return grad.clamp(floor, cap)

    for parameter in model.parameters():
        # register_hook raises on tensors that do not require grad (e.g. the
        # frozen layers of a fine-tuned model), so skip those.
        if parameter.requires_grad:
            parameter.register_hook(_clamp)
    return model

class AffineReluAffine(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Apply fc1, the ReLU and fc2 in sequence."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

model = AffineReluAffine(100, 200, 5)
    
# Apply gradient clipping with floor = -1.0 and cap = 1.0
# clip_grad returns the model it was given, so both names refer to one object
clipped_dummy_model = clip_grad(model, floor=-1, cap=1)

torch.manual_seed(20)
input_tensor = torch.randn(1, 100)
output = clipped_dummy_model(input_tensor)
    
loss = output.sum()
# The backward pass fires the clipping hooks on every parameter gradient
loss.backward()

In [28]:
# Inspect the bias gradient of fc2 after clipping to [-1, 1]
print(clipped_dummy_model.fc2.bias.grad[:5])

tensor([1., 1., 1., 1., 1.])


In [29]:
# Apply the hook to an off-the-shelf model
clipped_resnet = clip_grad(resnet18(), -0.1, 0.1)
pred = clipped_resnet(dummy_input)
# NOTE(review): pred is the raw model output; if it contains negative values,
# log() yields NaN for them and the loss value itself becomes NaN. Confirm
# this loss is only meant to produce some gradients for the demo.
loss = pred.log().mean()
loss.backward()

In [30]:
# First five bias gradients of the final fc layer, clipped to [-0.1, 0.1]
print(clipped_resnet.fc.bias.grad[:5])

tensor([ 0.0259,  0.0355,  0.1000, -0.1000,  0.1000])


## III. hooks for saved tensors

**典型应用：** 改变saved tensor pack/unpack方式，<font color=blue>将forward pass中要保存的tensor存到cpu或者disk上，节省GPU memory</font>

**方法：** 
1. 定义两个hook function，一个实现pack，另一个unpack 
   1. 定义pack_func(tensor): 只接受1个tensor作为argument，返回任意python类型。 
   2. 定义unpack_hook_func(output_of_pack_func)：只接受pack_func的返回值作为input argument，返回一个tensor用于后续的backward pass。该返回值的value要跟pack_func中的input的value相同，以达到原本想要从forward向backward传递信息的目的。
2. register 上述pack/unpack hooks 

**基本原则：** 
1. unpack_func(pack_func(tensor)) = tensor
2. pack_func的input不能做in-place modify
3. pack_func(tensor)的output可以是tensor或者任意python type object 
4. pack_func和unpack_func单独作用于每个saved tensor

**执行过程：** 
1. 每次forward pass执行过程中，pack func在对应operation存储信息的时候被调用，其output会代替原本module中定义的pack func输出的output tensor而被存储。
2. 在backward pass中按照chainrule执行到对应operation的backward method的之前，unpack func会被调用，它用pack func的output作为唯一的input来计算一个new tensor。这个new tensor会作为backward method的input之一而被使用。

**涉及的3个tensor object的关系：**
- pack(x)将x打包成另一个object，unpack输出的x同样也是新的tensor。如果pack(x)=x，那么他们虽然是三个不同的object，比如tensor，但share底层data memory。如果pack(x)不等于x，则x和unpack输出的x会share底层data memory。

In [43]:
## 例1：saved_tensor_hooks工作过程
#  定义hook func
def pack(x):
    """Pack hook: log the tensor being saved and store it unchanged."""
    print('packing:', x)
    packed = x
    return packed
def unpack(x):
    """Unpack hook: log the stored object and hand it back unchanged."""
    print('unpacking:', x)
    restored = x
    return restored

a = torch.ones(5, requires_grad=True)
b = torch.ones(5, requires_grad=True) * 2

# Register the pack/unpack hooks: they apply to tensors saved inside this block
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = a * b

# What the printed output shows:
# 1. Two tensors are packed, one per operand of a * b. In the recorded output
#    the tensor of 2s (b) is packed first and the tensor of 1s (a) second.
# 2. b is not a leaf: it is the output of torch.ones(...) * 2, which is why it
#    prints with grad_fn=<MulBackward0> while a prints with requires_grad=True.
# torchviz.make_dot(y, params={'a':a, 'b':b})

packing: tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
packing: tensor([1., 1., 1., 1., 1.], requires_grad=True)


In [32]:
# The backward pass needs the saved operands, which triggers the unpack hook
y.sum().backward()

unpacking: tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
unpacking: tensor([1., 1., 1., 1., 1.], requires_grad=True)


In [33]:
## 例2.1：自定义pack和unpack规则：改变tensor大小后恢复
#  pack/unpack func满足“unpack(pack(x)) = x”的规则即可

# ---> 例2.1和例2.2没有实际意义，只是展示pack和unpack的自定义能力

def pack(x):
    """Pack hook: store the tensor scaled by 4 (and log the scaled value)."""
    scaled = x * 4
    print('packing:', scaled)
    return scaled
def unpack(x):
    """Unpack hook: undo the scaling applied by ``pack`` (and log the result)."""
    restored = x / 4
    print('unpacking:', restored)
    return restored
# pack/unpack func满足“unpack(pack(x)) = x”的规则

torch.manual_seed(3)
x = torch.randn(3, requires_grad=True)
print('x =', x)
# x ** 2 saves x for its backward; the saved copy goes through pack here
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x ** 2
    
# backward unpacks the stored value; d(x^2)/dx must still equal 2x
y.sum().backward()
assert(x.grad.equal(x * 2))

x = tensor([0.8033, 0.1748, 0.0890], requires_grad=True)
packing: tensor([3.2131, 0.6993, 0.3559])
unpacking: tensor([0.8033, 0.1748, 0.0890])


In [34]:
## 例2.2：自定义pack和unpack规则：保存index of a list

# Module-level list holding every packed tensor; entries are never removed
storage = []


def pack(x):
    """Pack hook: append the tensor to ``storage`` and return its index."""
    index = len(storage)
    storage.append(x)
    return index

def unpack(ind):
    """Unpack hook: look up the tensor stored at index ``ind`` in ``storage``."""
    saved = storage[ind]
    return saved

torch.manual_seed(3)
x = torch.randn(3, requires_grad=True)
# The graph stores only the integer index returned by pack
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x ** 2
    
# unpack turns the index back into the tensor; the gradient must equal 2x
y.sum().backward()
assert(x.grad.equal(x * 2))

#### 用pack/unpack将tensor存到GPU之外的地方
1. 这是GPU memory和time的trade off。官方样例中，用A100GPU测试，把ResNet152(batch size=256)存到cpu上，可以将gpu内存使用量从48G降低到5G,但是耗时增加6倍
2. 一种折中是只把部分layer的tensor传到cpu或者其他位置。方法是，define a special nn.Module，用来wraps module and save its tensors to cpu

In [35]:
## 例3：saving tensor to cpu

#  1.人工手写
def pack(x):
    """Pack hook: remember the tensor's device and keep a CPU copy of it."""
    origin = x.device
    return origin, x.cpu()

def unpack(package):
    """Unpack hook: move the stored tensor back to the device recorded by ``pack``."""
    target_device, cpu_tensor = package
    return cpu_tensor.to(target_device)

torch.manual_seed(3)
x = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x ** 2
    
y.sum().backward()
# assert(x.grad.equal(x * 2))
torch.allclose(x.grad, (2 * x))  # x.grad matches 2x within allclose's default tolerances

True

In [36]:
#  2.pytorch已经实现了上述功能

class Model(nn.Module):
    """Element-wise scaling by a learnable 5-vector, with saved tensors offloaded via save_on_cpu."""

    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(5))

    def forward(self, x):
        """Return ``self.w * x``; tensors saved for backward go through save_on_cpu."""
        offload = torch.autograd.graph.save_on_cpu(pin_memory=True)
        with offload:
            scaled = self.w * x
        return scaled

# Forward and backward through the model that uses save_on_cpu internally
x = torch.randn(5)
model = Model()
loss = model(x).sum()
loss.backward()

In [37]:
#  3.module wrapper

class SaveToCpu(nn.Module):
    """Wrapper that runs ``module`` inside a ``save_on_cpu`` context.

    Only the tensors saved during the wrapped module's forward pass are
    affected; the rest of the surrounding model is untouched.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        """Call the wrapped module with the given arguments and return its result."""
        offload = torch.autograd.graph.save_on_cpu(pin_memory=True)
        with offload:
            result = self.module(*args, **kwargs)
        return result
        
# Only the middle linear layer is wrapped, so only its saved tensors go
# through save_on_cpu
model = nn.Sequential(
    nn.Linear(10, 100), 
    nn.ReLU(), 
    SaveToCpu(nn.Linear(100, 100)), 
    nn.ReLU(), 
    nn.Linear(100, 10),
)

x = torch.randn(10)
loss = model(x).sum()
loss.backward()

In [38]:
## 例4：saving tensor to disk

## 错误的方式

import uuid
import os
import tempfile
# Keep a reference to the TemporaryDirectory object alongside its path; the
# hooks below write their files into tmp_dir
tmp_dir_obj = tempfile.TemporaryDirectory()
tmp_dir = tmp_dir_obj.name

def pack_hook(tensor):
    """Pack hook: save ``tensor`` to a uniquely named file in ``tmp_dir`` and return the path."""
    filename = str(uuid.uuid4())
    path = os.path.join(tmp_dir, filename)
    torch.save(tensor, path)
    return path

def unpack_hook(name):
    """Unpack hook (deliberately flawed demo): load the tensor, then delete its file.

    Because the file is removed on the first call, a second unpack of the
    same saved tensor cannot succeed.
    """
    restored = torch.load(name)
    os.remove(name)  # removing here means unpack cannot be called a second time
    return restored

x = torch.ones(5, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = x.pow(2)
print(y.grad_fn._saved_self)      # every access triggers one unpack
try:
    print(y.grad_fn._saved_self)  # the second access fails: the file is already gone
    print("Double access succeeded!")
except Exception:
    # Catch ordinary errors only; a bare except would also swallow
    # KeyboardInterrupt / SystemExit.
    print("Double access failed!")

tensor([1., 1., 1., 1., 1.], requires_grad=True)
Double access failed!


  tensor = torch.load(name)


In [39]:
## 正确的方式：利用pytorch自动释放saved data的机制
#  pytorch自动释放不再需要的object，即这里的SelfDeletingTempFile object

class SelfDeletingTempFile():
    """Handle for a uniquely named file in ``tmp_dir`` that is deleted with the handle.

    The constructor only picks the path; the caller is expected to write the
    file. When the object is garbage-collected the file is removed.
    """

    def __init__(self):
        self.name = os.path.join(tmp_dir, str(uuid.uuid4()))

    def __del__(self):
        # The file may never have been written (e.g. the save failed) or may
        # already be gone; a missing file is not an error during cleanup.
        try:
            os.remove(self.name)
        except FileNotFoundError:
            pass

def pack_hook(tensor):
    """Pack hook: write ``tensor`` to a self-deleting temp file and return its handle."""
    handle = SelfDeletingTempFile()
    torch.save(tensor, handle.name)
    return handle

def unpack_hook(temp_file):
    """Unpack hook: load the tensor back from the handle's file (the file is kept)."""
    path = temp_file.name
    return torch.load(path)

In [40]:
# Only tensors with at least this many elements are moved to disk
SAVE_ON_DISK_THRESHOLD = 1000

def pack_hook(x):
    """Pack hook: keep small tensors in memory, write large ones to a temp file.

    Args:
        x: the tensor autograd is about to save.

    Returns:
        ``x`` itself when it has fewer than ``SAVE_ON_DISK_THRESHOLD``
        elements, otherwise a ``SelfDeletingTempFile`` handle whose file
        holds the tensor.
    """
    if x.numel() < SAVE_ON_DISK_THRESHOLD:
        return x
    temp_file = SelfDeletingTempFile()
    # Save the hook's own argument (the original referenced an undefined name)
    torch.save(x, temp_file.name)
    return temp_file

def unpack_hook(tensor_or_sctf):
    """Unpack hook matching the threshold-based ``pack_hook``.

    Tensors that were kept in memory are returned as-is; anything else is
    treated as a temp-file handle and loaded from its ``name`` path.
    """
    if not isinstance(tensor_or_sctf, torch.Tensor):
        return torch.load(tensor_or_sctf.name)
    return tensor_or_sctf

class SaveToDisk(nn.Module):
    """Wrapper that runs ``module`` with the disk-offloading saved-tensor hooks active.

    The hooks are installed inside ``forward`` itself, so they are active in
    whichever thread executes the wrapped module.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        """Call the wrapped module under ``saved_tensors_hooks(pack_hook, unpack_hook)``."""
        hooks = torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook)
        with hooks:
            result = self.module(*args, **kwargs)
        return result

# Wrap the hook-installing module, so the hooks are registered inside the
# forward call that DataParallel dispatches
net = nn.DataParallel(SaveToDisk(Model()))

上面例子方式定义的Hook是thread-local的，当与DataParallel一起用的时候要用上例中的方式，不能定义成下面的方式：

In [41]:
# net = nn.DataParallel(model)
# with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
#     output = net(input)