<a href="https://colab.research.google.com/github/xuwangfmc/dlbook/blob/main/NetworkPruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NetworkPruning
   现有许多工作为了追求高精度，通过会设计参数量与复杂度较高的模型，但是这种模型是存在大量信息冗余的。网络剪枝就是针对这一现象，将大量的冗余信息给去掉。
网络剪枝主要分为Non-structured Pruning(Wight)和Structured Pruning(Neuron)两种。前者比较灵活，可任意裁各Filter内的参数，参数减少空间较大，精度降低较少，但是这种任意裁剪的方式不利于硬件实现，因为硬件通常是直接进行矩阵运算，额外的复杂操作会使得硬件难以优化；Stuctured Pruning是去掉整个neuron，即整个Filter，参数减少空间较小，精度下降较多，但是硬件支持良好。
![NetworkPruning.png](https://s2.loli.net/2022/01/22/aoyRDLGxrf8b67z.png)

该教程主要介绍如何使用Pytorch自带的剪枝库对模型进行剪枝，以及实际运用中如何对模型进行剪枝。

## 实验流程


**裁剪单个Module**




在进行剪枝之前，先构建LeNet模型。

In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__() 
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

In [2]:
module = model.conv1
print(list(module.named_parameters()))

[('weight', Parameter containing:
tensor([[[[-0.3138,  0.2117, -0.3179],
          [-0.1768,  0.0209, -0.3088],
          [ 0.1922,  0.1970, -0.3047]]],


        [[[ 0.3054,  0.2523, -0.0929],
          [ 0.3087,  0.3270, -0.1836],
          [ 0.1698,  0.2943,  0.3102]]],


        [[[ 0.0307,  0.2844, -0.0396],
          [ 0.2084, -0.1057,  0.0373],
          [ 0.1042, -0.2758, -0.1900]]],


        [[[ 0.1137, -0.1251,  0.2503],
          [ 0.1387, -0.0317, -0.1184],
          [-0.2504,  0.1400,  0.2995]]],


        [[[-0.2898, -0.3211, -0.3013],
          [ 0.1735, -0.1554,  0.0403],
          [ 0.2877, -0.2073,  0.0639]]],


        [[[-0.2566, -0.1833,  0.2502],
          [-0.2506,  0.1780,  0.2193],
          [-0.2474, -0.0557, -0.1214]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.2063, -0.0221,  0.0029,  0.3152,  0.2007,  0.1731], device='cuda:0',
       requires_grad=True))]


 named_parameters是torch.nn.Module类提供的获取可学习参数的方法,将会返回一个list，其中包括有Weights和bias的初始值。

如果我们想要裁剪一个Module，首先我们需要选取一个pruning的方案，目前torch.nn.utils.prune中已经支持  
- RandomUnstructured  
- L1Unstructured  
- RandomStructured  
- LnStructured  
- CustomFromMask 
 
也可以通过继承BasePruningMethod来自定义pruning的方法。

In [4]:
prune.random_unstructured(module,name="weight",amount=0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

然后我们指定module以及需要pruning的参数的name,最后使用合适的参数，指定pruning的参数。在上述代码中，我们将随机裁剪30%的连接（conv1中weights参数30%的连接）。其中name用于指定module中的某个parameter，amount用于执行需要裁剪连接的比例（0.0到1.0）或者直接给定一个绝对值。

In [5]:
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([-0.2063, -0.0221,  0.0029,  0.3152,  0.2007,  0.1731], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.3138,  0.2117, -0.3179],
          [-0.1768,  0.0209, -0.3088],
          [ 0.1922,  0.1970, -0.3047]]],


        [[[ 0.3054,  0.2523, -0.0929],
          [ 0.3087,  0.3270, -0.1836],
          [ 0.1698,  0.2943,  0.3102]]],


        [[[ 0.0307,  0.2844, -0.0396],
          [ 0.2084, -0.1057,  0.0373],
          [ 0.1042, -0.2758, -0.1900]]],


        [[[ 0.1137, -0.1251,  0.2503],
          [ 0.1387, -0.0317, -0.1184],
          [-0.2504,  0.1400,  0.2995]]],


        [[[-0.2898, -0.3211, -0.3013],
          [ 0.1735, -0.1554,  0.0403],
          [ 0.2877, -0.2073,  0.0639]]],


        [[[-0.2566, -0.1833,  0.2502],
          [-0.2506,  0.1780,  0.2193],
          [-0.2474, -0.0557, -0.1214]]]], device='cuda:0', requires_grad=True))]


这时候我们会看到weight_orig，和之前打印的数值是没有变化的，但是weights的参数不见了。

In [6]:
print(list(module.named_buffers()))

[('weight_mask', tensor([[[[0., 0., 1.],
          [0., 0., 0.],
          [1., 0., 1.]]],


        [[[1., 1., 0.],
          [0., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 0.],
          [0., 1., 0.],
          [1., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 0., 0.],
          [0., 0., 0.]]]], device='cuda:0'))]


此时会产生一个weight_mask的掩码，本身不会直接作用于模型，会产生一个weight的属性，这时候原module是不存在weight的parameter,仅仅是一个attribute.

In [7]:
print(module.weight)

tensor([[[[-0.0000,  0.0000, -0.3179],
          [-0.0000,  0.0000, -0.0000],
          [ 0.1922,  0.0000, -0.3047]]],


        [[[ 0.3054,  0.2523, -0.0000],
          [ 0.0000,  0.3270, -0.1836],
          [ 0.1698,  0.0000,  0.3102]]],


        [[[ 0.0307,  0.0000, -0.0000],
          [ 0.2084, -0.1057,  0.0373],
          [ 0.1042, -0.2758, -0.1900]]],


        [[[ 0.1137, -0.0000,  0.0000],
          [ 0.0000, -0.0317, -0.0000],
          [-0.2504,  0.0000,  0.0000]]],


        [[[-0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.1554,  0.0403],
          [ 0.2877, -0.2073,  0.0639]]],


        [[[-0.2566, -0.0000,  0.2502],
          [-0.2506,  0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)


最后，使用pytorch的forward_pre_hooks会在每次forward之前应用这个pruning操作，需要指出的是当module被裁剪之后，它的每一个paramter都需要一个forward_pre_hooks来标识将被裁剪。当前我们只进行了conv1模块的weight裁剪，所以以下命令将只能看到一个hook。

In [8]:
print(module._forward_pre_hooks)

OrderedDict([(1, <torch.nn.utils.prune.PruningContainer object at 0x7fa8e3721350>)])


同样，我们还可以对conv1的bias进行L1unstructured的裁剪，和上述类似。

**迭代裁剪**

单个module中的parameters是可以多次裁剪的，无非就是顺序的组合不同的mask和调用不同的pruning方法，结果是一致的，我们可以通过调用PruningContainer的compute_mask方法来实现在旧mask之上添加新的mask的逻辑。此时50%的kernel参数会被设置成0.

In [9]:
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)
print(module.weight)

tensor([[[[-0.0000,  0.0000, -0.3179],
          [-0.0000,  0.0000, -0.0000],
          [ 0.1922,  0.0000, -0.3047]]],


        [[[ 0.3054,  0.2523, -0.0000],
          [ 0.0000,  0.3270, -0.1836],
          [ 0.1698,  0.0000,  0.3102]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[-0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]],


        [[[-0.2566, -0.0000,  0.2502],
          [-0.2506,  0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)


这时候hook就会变成torch.nn.utils.prune.PruningContainer的类型，将会存储应用在weights参数上的所有prune操作。

In [10]:
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container

[<torch.nn.utils.prune.RandomUnstructured object at 0x7fa8e3765b90>, <torch.nn.utils.prune.RandomUnstructured object at 0x7fa8e3721c10>, <torch.nn.utils.prune.LnStructured object at 0x7fa8e3724490>]


**序列化裁剪后的模型**
所有的裁剪后的tensor都是存储在state_dict当中，这就非常便于我们做模型的序列化以及save操作。

In [11]:
print(model.state_dict().keys())

odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])


接下来我们会想，如何将pruning操作永久的作用于模型，而不保存类似weight_orig以及weight_mask 这样的Tensor，同时移除forward_pre_hook.
prune中提供了remove操作, 需要注意的是，remove并不能undo裁剪的操作，使得什么都没发生过一样，仅仅是永久化，重新将weight赋值给module的源tensor.

In [12]:
prune.remove(module, 'weight')
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([-0.2063, -0.0221,  0.0029,  0.3152,  0.2007,  0.1731], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.0000,  0.0000, -0.3179],
          [-0.0000,  0.0000, -0.0000],
          [ 0.1922,  0.0000, -0.3047]]],


        [[[ 0.3054,  0.2523, -0.0000],
          [ 0.0000,  0.3270, -0.1836],
          [ 0.1698,  0.0000,  0.3102]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[-0.0000, -0.0000, -0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]],


        [[[-0.2566, -0.0000,  0.2502],
          [-0.2506,  0.0000,  0.0000],
          [-0.0000, -0.0000, -0.0000]]]], device='cuda:0', requires_grad=True))]


这时候我们会发现直接weight就是裁剪后的值，而weight_orig不见了。如果希望裁剪模型中的多个参数，可以遍历module然后重复上述操作即可。

**全局剪枝**

相比之前的操作仅仅作用到指定的module，指定的参数，global pruning更加强大，可以通过如下配置来实现。

In [13]:
model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

这就非常方便，因为日常使用中我们往往追求一个全局的最终的一个效果，而不大关注特定的module的稀疏程度。

In [14]:
print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

Sparsity in conv1.weight: 3.70%
Sparsity in conv2.weight: 8.91%
Sparsity in fc1.weight: 21.98%
Sparsity in fc2.weight: 12.42%
Sparsity in fc3.weight: 10.24%
Global sparsity: 20.00%


## 实战案例

步骤1：加载数据集

In [15]:
# Download dataset
!gdown --id '1O6pFYd9aw1cZbry-NXk3k3tTXLVgssIg' --output food-11.zip
# Unzip the files
!unzip -q food-11.zip

Downloading...
From: https://drive.google.com/uc?id=1O6pFYd9aw1cZbry-NXk3k3tTXLVgssIg
To: /content/food-11.zip
100% 277M/277M [00:05<00:00, 50.5MB/s]


步骤2：加载StudentNet

In [16]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
# 运行之前的Architecture_Design文件
!gdown --id '1-sSaAOk3vnmfZv8F4Vo_YhdrHkTNn7bh' --output "Architecture_Design.ipynb"
%run "Architecture_Design.ipynb"

Downloading...
From: https://drive.google.com/uc?id=1-sSaAOk3vnmfZv8F4Vo_YhdrHkTNn7bh
To: /content/Architecture_Design.ipynb
  0% 0.00/7.62k [00:00<?, ?B/s]100% 7.62k/7.62k [00:00<00:00, 11.3MB/s]


步骤3：执行代码

In [17]:
import torchvision.models as models
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
# 加载训练好的StudentNet参数
!gdown --id '12wtIa0WVRcpboQzhgRUJOpcXe23tgWUL' --output student_custom_small.bin
class FoodDataset(torch.utils.data.Dataset):

    def __init__(self, dir_path, transform, cuda=False):
        self.cuda = cuda
        self.transform = transform
        self.x = []
        self.y = []
        img_names = sorted(os.listdir(dir_path))
        for img_name in img_names:  # glob返回匹配到的所有文件的路径
            img_path = os.path.join(dir_path, img_name)
            label = int(img_name.split("_")[0])

            image = Image.open(img_path)
            # Get File Descriptor
            image_fp = image.fp
            image.load()
            # Close File Descriptor (or it'll reach OPEN_MAX)
            image_fp.close()

            self.x.append(image)
            self.y.append(label)

    def __getitem__(self, idx):
        image = self.transform(self.x[idx])
        label = torch.torch.tensor(self.y[idx], dtype=torch.int64)
        if self.cuda:
            image = image.cuda()
            label = label.cuda()
        return image, label
    
    def __len__(self):
        return len(self.x)


trainTransform = transforms.Compose([
    transforms.RandomCrop(256, pad_if_needed=True, padding_mode='symmetric'),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])
testTransform = transforms.Compose([
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])


def get_dataloader(dir_path='/data',mode='training', batch_size=32, cuda=False):

    assert mode in ['training', 'testing', 'validation']

    dataset = FoodDataset(
        f'{dir_path}',
        transform=trainTransform if mode == 'training' else testTransform, cuda=cuda)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(mode == 'training'))

    return dataloader

def network_slimming(old_model, new_model):
    old_params = old_model.state_dict()
    new_params = new_model.state_dict()

    # 只保留每一层中的部分卷积核
    selected_idx = []
    for i in range(8):  # 只对模型中CNN部分(8个Sequential)进行剪枝
        gamma = old_params[f'cnn.{i}.1.weight']
        new_dim = len(new_params[f'cnn.{i}.1.weight'])
        ranking = torch.argsort(gamma, descending=True)
        selected_idx.append(ranking[:new_dim])

    now_processing = 1  # 当前在处理哪一个Sequential，索引为0的Sequential不需处理
    for param_name, weights in old_params.items():
        # 如果是CNN层，则根据gamma仅复制部分参数；如果是FC层或者该参数只有一个数字(例如batchnorm的tracenum等等)就直接全部复制
        if param_name.startswith('cnn') and weights.size() != torch.Size([]) and now_processing != len(selected_idx):
            # 当处理到Pointwise Convolution时，则代表正在处理的Sequential已处理完毕
            if param_name.startswith(f'cnn.{now_processing}.3'):
                now_processing += 1

            # Pointwise Convolution的参数会受前一个Sequential和后一个Sequential剪枝情况的影响，因此需要特别处理
            if param_name.endswith('3.weight'):
                # 不需要删除最后一个Sequential中的Pointwise卷积核
                if len(selected_idx) == now_processing:
                    # selected_idx[now_processing-1]指当前Sequential中保留的通道的索引
                    new_params[param_name] = weights[:,selected_idx[now_processing-1]]
                # 除了最后一个Sequential，每个Sequential中卷积核的数量(输出通道数)都要和后一个Sequential匹配。
                else:
                    # Pointwise Convolution中Conv2d(x,y,1)的weight的形状是(y,x,1,1)
                    # selected_idx[now_processing]指后一个Sequential中保留的通道的索引
                    # selected_idx[now_processing-1]指当前Sequential中保留的通道的索引
                    new_params[param_name] = weights[selected_idx[now_processing]][:,selected_idx[now_processing-1]]
            else:
                new_params[param_name] = weights[selected_idx[now_processing]]
        else:
            new_params[param_name] = weights
    
    # 返回新模型
    new_model.load_state_dict(new_params)
    return new_model


def run_epoch(dataloader, new_net, optimizer, criterion, update=True):
    total_num, total_hit, total_loss = 0, 0, 0
    for now_step, batch_data in enumerate(dataloader):
        # 清空 optimizer
        optimizer.zero_grad()
        # 获取数据
        inputs, labels = batch_data
            
        logits = new_net(inputs)
        loss = criterion(logits, labels)
        if update:
          loss.backward()
          optimizer.step()

        total_hit += torch.sum(torch.argmax(logits, dim=1) == labels).item()
        total_num += len(inputs)
        total_loss += loss.item() * len(inputs)

    return total_loss / total_num, total_hit / total_num


# config
batch_size = 1
cuda = True if torch.cuda.is_available() else False
prune_count = 3
prune_rate = 0.95
finetune_epochs = 3


if __name__ == '__main__':
    # 加载数据
    train_dataloader = get_dataloader('training', 'training', batch_size, cuda)
    valid_dataloader = get_dataloader('validation', 'validation', batch_size, cuda)
    print('Data Loaded')

    # 加载网络
    old_net = StudentNet()
    if cuda:
        old_net = old_net.cuda()
    old_net.load_state_dict(torch.load('./student_custom_small.bin'))

    # 开始剪枝并finetune：独立剪枝prune_count次，每次剪枝的剪枝率按prune_rate逐渐增大，剪枝后微调finetune_epochs个epoch
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(old_net.parameters(), lr=1e-3)

    now_width_mult = 1
    for i in range(prune_count):
        now_width_mult *= prune_rate # 增大剪枝率
        new_net = StudentNet(width_mult=now_width_mult)
        if cuda:
            new_net = new_net.cuda()
        new_net = network_slimming(old_net, new_net)
        now_best_acc = 0
        for epoch in range(finetune_epochs):
            new_net.train()
            train_loss, train_acc = run_epoch(train_dataloader, new_net, optimizer, criterion, update=True)
            new_net.eval()
            valid_loss, valid_acc = run_epoch(valid_dataloader, new_net, optimizer, criterion, update=False)
            # 每次剪枝时存下最好的model
            if valid_acc > now_best_acc:
                now_best_acc = valid_acc
                torch.save(new_net.state_dict(), f'./pruned_{now_width_mult}_student_model.bin')
            print('rate {:6.4f} epoch {:>3d}: train loss: {:6.4f}, acc {:6.4f} valid loss: {:6.4f}, acc {:6.4f}'.format(now_width_mult, 
                epoch, train_loss, train_acc, valid_loss, valid_acc))


Downloading...
From: https://drive.google.com/uc?id=12wtIa0WVRcpboQzhgRUJOpcXe23tgWUL
To: /content/student_custom_small.bin
  0% 0.00/1.05M [00:00<?, ?B/s]100% 1.05M/1.05M [00:00<00:00, 16.5MB/s]
Data Loaded
rate 0.9500 epoch   0: train loss: 4.3217, acc 0.0990 valid loss: 1.9473, acc 0.6807
rate 0.9500 epoch   1: train loss: 4.3158, acc 0.0994 valid loss: 2.0738, acc 0.6715
rate 0.9500 epoch   2: train loss: 4.3032, acc 0.0986 valid loss: 2.3332, acc 0.6615
rate 0.9025 epoch   0: train loss: 4.4357, acc 0.0990 valid loss: 2.5495, acc 0.5940
rate 0.9025 epoch   1: train loss: 4.4380, acc 0.1011 valid loss: 2.0737, acc 0.6487
rate 0.9025 epoch   2: train loss: 4.4153, acc 0.1036 valid loss: 2.1265, acc 0.6122
rate 0.8574 epoch   0: train loss: 4.4611, acc 0.0944 valid loss: 2.2042, acc 0.5967
rate 0.8574 epoch   1: train loss: 4.4358, acc 0.0899 valid loss: 1.8182, acc 0.6715
rate 0.8574 epoch   2: train loss: 4.4884, acc 0.0928 valid loss: 1.8172, acc 0.6998
