## RepVGG

- Original paper: [RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/pdf/2101.03697.pdf)

<img src="../../images/repvgg.png" width="50%">

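As shown in the figure above, this residual block has three branches: a `3 x 3` convolution, a `1 x 1` `point-wise` convolution, and the original input itself (identity).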

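The author converts the `1 x 1` convolution into an equivalent `3 x 3` convolution and also expresses the original input as a `3 x 3` convolution. The three branches can then be merged into a single `3 x 3` convolution, which is what the network runs at inference time. In the paper each branch also carries a BatchNorm layer, which is folded into that branch's kernel and bias before merging; BN is omitted in this note.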

## Point-wise convolution

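A `point-wise` (1 x 1) convolution is, in principle, no different from a fully connected layer (`DNN`). Because its kernel covers a single pixel, it gives up the local spatial correlation that a regular convolution captures. It only does `channel mix`: at each pixel it computes a weighted sum over that pixel's channels. Mathematically it is an `MLP` layer applied independently at every spatial position, with the same weights shared across positions, so it only fuses information across channels.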
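The MLP claim can be checked directly. The sketch below uses arbitrary sizes of 3 input and 5 output channels. It copies a 1 x 1 kernel into an `nn.Linear` and applies that layer to each pixel's channel vector. The result matches the convolution.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

in_channels, out_channels = 3, 5  # arbitrary sizes for illustration
x = torch.randn(2, in_channels, 4, 4)  # (N, C, H, W)

pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

# A Linear layer that reuses the 1x1 kernel: weight (out, in, 1, 1) -> (out, in)
linear = nn.Linear(in_channels, out_channels)
with torch.no_grad():
    linear.weight.copy_(pointwise.weight.view(out_channels, in_channels))
    linear.bias.copy_(pointwise.bias)

# Move channels last so Linear acts on each pixel's channel vector, then move back
y_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
y_conv = pointwise(x)

print(torch.allclose(y_conv, y_linear, atol=1e-5))  # True
```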

## Depth-wise convolution

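A `depth-wise` convolution is obtained by setting the convolution's `groups` to a value greater than `1`. Strictly speaking, "depth-wise" means `groups` equals the number of input channels, as in the example below; other values greater than 1 give a general grouped convolution. Let's look at the weight shape of a convolution layer once `groups` is set: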

```python
import torch.nn as nn

conv_layer = nn.Conv2d(2, 4, 3, padding="same", groups=2)
conv_layer.weight.size()
```

```
torch.Size([4, 1, 3, 3])
```

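As you can see, the weight shape is `[4, 1, 3, 3]` rather than `[4, 2, 3, 3]`. This is because the convolution with `2` input channels and `4` output channels is split into two `groups`, each with `1` input channel and `2` output channels.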

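A group with `1` input channel and `2` output channels has a `weight` of shape `[2, 1, 3, 3]`. Concatenating two of them along the output dimension gives `[4, 1, 3, 3]`.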
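The group split can be verified directly. The sketch below uses the same layer as above with a random 5 x 5 input; the input size is arbitrary. It slices the `[4, 1, 3, 3]` weight into two `[2, 1, 3, 3]` halves and runs each half on its own input channel with `F.conv2d`. Concatenating the two outputs reproduces the grouped convolution.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

conv_layer = nn.Conv2d(2, 4, 3, padding="same", groups=2)
x = torch.randn(1, 2, 5, 5)  # arbitrary spatial size

# Split the [4, 1, 3, 3] weight (and the bias) into two groups of [2, 1, 3, 3]
w_g0, w_g1 = conv_layer.weight.chunk(2, dim=0)
b_g0, b_g1 = conv_layer.bias.chunk(2, dim=0)

# Group 0 sees only input channel 0, group 1 only input channel 1
y_g0 = F.conv2d(x[:, 0:1], w_g0, b_g0, padding=1)
y_g1 = F.conv2d(x[:, 1:2], w_g1, b_g1, padding=1)
y_split = torch.cat([y_g0, y_g1], dim=1)  # back to 4 output channels

print(torch.allclose(conv_layer(x), y_split, atol=1e-6))  # True
```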

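A `depth-wise` convolution reduces parameters and computation, but each output channel sees only its own group of input channels, so it does not mix information across all `channel`s.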
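The saving can be quantified by counting parameters. For the same spatial size, the weight count (excluding biases) equals the multiply-adds per output pixel, so it also tracks computation. The sketch uses an arbitrary choice of 32 channels. The last line adds a point-wise conv after the depth-wise one. This is the depthwise separable pattern (used, for example, in MobileNet), which restores channel mixing at low cost.

```python
import torch.nn as nn

def num_params(module):
    return sum(p.numel() for p in module.parameters())

c = 32  # input and output channels (arbitrary choice for illustration)
standard = nn.Conv2d(c, c, 3, padding=1)
depthwise = nn.Conv2d(c, c, 3, padding=1, groups=c)
pointwise = nn.Conv2d(c, c, 1)

print(num_params(standard))                           # 9248 (32*32*9 + 32)
print(num_params(depthwise))                          # 320  (32*1*9 + 32)
print(num_params(depthwise) + num_params(pointwise))  # 1376 (320 + 32*32 + 32)
```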

### Original (multi-branch) form

```python
import torch
import torch.nn.functional as F
import torch.nn as nn

in_channels = 2
out_channels = 2  # must equal in_channels for the identity branch
kernel_size = 3
h = 9
w = 9
x = torch.ones(1, in_channels, h, w)  # (batch size = 1, C, H, W)

conv_2d = nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")  # branch 1: 3 x 3 conv
conv_2d_pointwise = nn.Conv2d(in_channels, out_channels, 1)  # branch 2: 1 x 1 (point-wise) conv
result1 = conv_2d(x) + conv_2d_pointwise(x) + x  # branch 3: identity
```

```python
result1
```

```
tensor([[[[-0.6874, -0.2034, -0.2034, -0.2034, -0.2034, -0.2034, -0.2034,
           -0.2034, -0.1217],
          [-0.5610, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140,
           -0.2140, -0.2840],
          [-0.5610, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140,
           -0.2140, -0.2840],
          [-0.5610, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140,
           -0.2140, -0.2840],
          [-0.5610, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140,
           -0.2140, -0.2840],
          [-0.5610, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140,
           -0.2140, -0.2840],
          [-0.5610, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140,
           -0.2140, -0.2840],
          [-0.5610, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140, -0.2140,
           -0.2140, -0.2840],
          [-0.5213, -0.2980, -0.2980, -0.2980, -0.2980, -0.2980, -0.2980,
           -0.2980, -0.3724]],

         [[-0.1721, -0.5914, -0.5914, -0.5914, -0.5914, -0.591
...
```

```python
result1.size()
```

```
torch.Size([1, 2, 9, 9])
```

The output still has batch size 1 and has `out_channels` (2) channels. The spatial size remains 9 x 9 because of `padding="same"`.

### Operator fusion

Rewrite both the point-wise convolution and `x` itself as `3 x 3` convolutions, then merge the three convolutions into one.
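Fusion works because convolution is linear in its kernel. The notation below is introduced here and is not taken from the paper:

- $\mathrm{conv}_{3\times3}(x; W)$ is a `3 x 3` convolution with padding 1.
- $\mathrm{pad}(W_1)$ zero-pads a `1 x 1` kernel to `3 x 3`.
- $I$ is the identity kernel: slice $(o, i)$ has a 1 at its center when $o = i$ and is zero elsewhere.

With this notation the three branches collapse into one. This requires `in_channels == out_channels`:

$$
\begin{aligned}
y &= \mathrm{conv}_{3\times3}(x; W_3) + b_3 + \mathrm{conv}_{1\times1}(x; W_1) + b_1 + x \\
  &= \mathrm{conv}_{3\times3}(x; W_3) + \mathrm{conv}_{3\times3}\big(x; \mathrm{pad}(W_1)\big) + \mathrm{conv}_{3\times3}(x; I) + b_3 + b_1 \\
  &= \mathrm{conv}_{3\times3}\big(x; W_3 + \mathrm{pad}(W_1) + I\big) + (b_3 + b_1)
\end{aligned}
$$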

```python
conv_2d_pointwise.weight.size()
```

```
torch.Size([2, 2, 1, 1])
```

The weight of `conv_2d_pointwise` starts with shape `[2, 2, 1, 1]`. To turn it into `[2, 2, 3, 3]`, pad it with zeros on the top, bottom, left, and right.
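The padding trick can be checked on its own with random weights; the sizes are chosen to match this example. A `1 x 1` kernel zero-padded to `3 x 3` and applied with padding 1 gives the same output as the original `1 x 1` kernel with no padding. This holds because only the center tap is nonzero.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

w_1x1 = torch.randn(2, 2, 1, 1)
b = torch.randn(2)
x = torch.randn(1, 2, 9, 9)

# Zero-pad the last two dims (W, then H) by 1 on each side: [2, 2, 1, 1] -> [2, 2, 3, 3]
w_3x3 = F.pad(w_1x1, [1, 1, 1, 1])

y_1x1 = F.conv2d(x, w_1x1, b)             # 1x1 kernel, no padding
y_3x3 = F.conv2d(x, w_3x3, b, padding=1)  # 3x3 kernel, padding 1 keeps 9x9

print(w_3x3.shape)                              # torch.Size([2, 2, 3, 3])
print(torch.allclose(y_1x1, y_3x3, atol=1e-5))  # True
```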

```python
# Zero-pad the 1 x 1 kernel to 3 x 3: [2, 2, 1, 1] -> [2, 2, 3, 3]
pointwise_to_conv_weight = F.pad(conv_2d_pointwise.weight, [1, 1, 1, 1, 0, 0, 0, 0])
conv_2d_for_pointwise = nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
conv_2d_for_pointwise.weight = nn.Parameter(pointwise_to_conv_weight)
conv_2d_for_pointwise.bias = nn.Parameter(conv_2d_pointwise.bias.data)  # bias is unchanged
```

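Next, turn `x` itself into a `3 x 3` convolution.

This convolution must take neither neighboring pixels nor other channels into account.

The weight has shape `[2, 2, 3, 3]`. To pick up a pixel's own value, we need a `[3, 3]` kernel with 1 in the center and 0 elsewhere. There are four `[3, 3]` slices, one per (output channel, input channel) pair. Only the first and the fourth, where the output channel equals the input channel, use this kernel; the other two are all zeros.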
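The same kernel can be built for any channel count. The identity map is a `1 x 1` convolution whose weight is the identity matrix, so padding `torch.eye` reuses the trick from the point-wise branch. In the sketch below, `identity_kernel` is a hypothetical helper name, and it assumes `in_channels == out_channels`.

```python
import torch
import torch.nn.functional as F

def identity_kernel(channels, kernel_size=3):
    # Hypothetical helper: a [C, C, k, k] kernel whose slice (o, i) is the
    # "center 1, zeros elsewhere" kernel when o == i, and all zeros otherwise.
    pad = kernel_size // 2
    eye = torch.eye(channels).view(channels, channels, 1, 1)  # identity as a 1x1 conv
    return F.pad(eye, [pad, pad, pad, pad])

w = identity_kernel(2)
print(w.shape)  # torch.Size([2, 2, 3, 3])

x = torch.randn(1, 2, 9, 9)
y = F.conv2d(x, w, padding=1)
print(torch.allclose(y, x, atol=1e-6))  # True
```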

```python
# A 3 x 3 slice of all zeros, and a "star" slice: 1 in the center, 0 elsewhere
zeros = torch.unsqueeze(torch.zeros(kernel_size, kernel_size), 0)  # [1, 3, 3]
stars = torch.unsqueeze(F.pad(torch.ones(1, 1), [1, 1, 1, 1]), 0)  # [1, 3, 3]
# Output channel 0 reads only input channel 0; output channel 1 reads only input channel 1
stars_zeros = torch.unsqueeze(torch.cat([stars, zeros], 0), 0)  # [1, 2, 3, 3]
zeros_stars = torch.unsqueeze(torch.cat([zeros, stars], 0), 0)  # [1, 2, 3, 3]
identity_to_conv_weight = torch.cat([stars_zeros, zeros_stars], 0)  # [2, 2, 3, 3]
identity_to_conv_bias = torch.zeros([out_channels])  # the identity branch has no bias

conv_2d_for_idntity = nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
conv_2d_for_idntity.weight = nn.Parameter(identity_to_conv_weight)
conv_2d_for_idntity.bias = nn.Parameter(identity_to_conv_bias)

result2 = conv_2d(x) + conv_2d_for_pointwise(x) + conv_2d_for_idntity(x)
```

```python
print(torch.all(torch.isclose(result1, result2)))
```

```
tensor(True)
```


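Finally, fuse the three operators into one. Since convolution is linear, it suffices to add their weights and their biases: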

```python
conv_2d_for_fusion = nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
conv_2d_for_fusion.weight = nn.Parameter(conv_2d.weight.data + conv_2d_for_pointwise.weight.data +
                                         conv_2d_for_idntity.weight.data)
conv_2d_for_fusion.bias = nn.Parameter(conv_2d.bias.data + conv_2d_for_pointwise.bias.data +
                                       conv_2d_for_idntity.bias.data)

result3 = conv_2d_for_fusion(x)
print(torch.all(torch.isclose(result2, result3)))
```

```
tensor(True)
```
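The walkthrough above is hard-coded for 2 channels and checked on a constant input. The sketch below packages the same fusion as a small helper and checks it on random input with 4 channels. `fuse_branches` is a hypothetical name, not an API from the paper or PyTorch. The helper assumes `in_channels == out_channels`, stride 1, padding 1, and no BatchNorm.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fuse_branches(conv3x3, conv1x1, channels):
    # Hypothetical helper: merge 3x3 conv + 1x1 conv + identity into one 3x3 conv.
    # Assumes both convs map `channels` -> `channels` and conv3x3 uses padding 1.
    w_1x1 = F.pad(conv1x1.weight.data, [1, 1, 1, 1])  # 1x1 kernel padded to 3x3
    w_id = F.pad(torch.eye(channels).view(channels, channels, 1, 1), [1, 1, 1, 1])
    fused = nn.Conv2d(channels, channels, 3, padding=1)
    with torch.no_grad():
        fused.weight.copy_(conv3x3.weight + w_1x1 + w_id)
        fused.bias.copy_(conv3x3.bias + conv1x1.bias)  # identity adds no bias
    return fused

torch.manual_seed(0)
c = 4
conv3x3 = nn.Conv2d(c, c, 3, padding=1)
conv1x1 = nn.Conv2d(c, c, 1)
fused = fuse_branches(conv3x3, conv1x1, c)

x = torch.randn(2, c, 16, 16)
with torch.no_grad():
    y_branches = conv3x3(x) + conv1x1(x) + x  # original three-branch block
    y_fused = fused(x)                          # single fused 3x3 conv
print(torch.allclose(y_branches, y_fused, atol=1e-5))  # True
```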
