res_block = 3 * 3 conv + 1 * 1 conv + input

In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import time

In [2]:
batch_size = 1
in_channels = 2
out_channels = 2
kernel_size = 3
w = 9
h = 9

x = torch.ones(batch_size, in_channels, w, h)  # 输入图片

## 方法1：原生写法

In [3]:
conv_2d = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
conv_2d_pointwise = nn.Conv2d(in_channels, out_channels, 1)
result1 = conv_2d(x) + conv_2d_pointwise(x) + x
result1

tensor([[[[-0.0053, -0.1841, -0.1841, -0.1841, -0.1841, -0.1841, -0.1841,
           -0.1841, -0.4957],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.3106, -0.4763, -0.4763, -0.4763, -0.4763, -0.4763, -0.4763,
           -0.4763, -0.1885]],

         [[ 1.5132,  0.8370,  0.8370,  0.8370,  0.8370,  0.837

## 方法2：算子融合

详情见[RepVGG: Making VGG-style ConvNets Great Again](https://arxiv.org/pdf/2101.03697.pdf)

把point-wise卷和x本身都写成3 * 3的卷积 最终把三个卷积都写成一个卷积

### 1. 改造

In [4]:
# 2*2*1*1 -> 2*2*3*3 第二个parameter是由于 从里到外4个维度 前两个维度在上下左右padding一层0
pointwise_to_conv_weight = F.pad(conv_2d_pointwise.weight, [1, 1, 1, 1, 0, 0, 0, 0])
conv_2d_for_pointwise = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
conv_2d_for_pointwise.weight = nn.Parameter(pointwise_to_conv_weight)
conv_2d_for_pointwise.bias = nn.Parameter(conv_2d_pointwise.bias)

In [5]:
# 2*2*3*3
zeros = torch.unsqueeze(torch.zeros(kernel_size, kernel_size), 0)
stars = torch.unsqueeze(F.pad(torch.ones(1, 1), [1, 1, 1, 1]), 0)
# 第一个通道的卷积核
stars_zeros = torch.unsqueeze(torch.cat([stars, zeros], 0), 0)
# 第二个通道的卷积核
zeros_stars = torch.unsqueeze(torch.cat([zeros, stars], 0), 0)

identity_to_conv_weight = torch.cat([stars_zeros, zeros_stars], 0)
identity_to_conv_bias = torch.zeros([out_channels])

conv_2d_for_identity = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
conv_2d_for_identity.weight = nn.Parameter(identity_to_conv_weight)
conv_2d_for_identity.bias = nn.Parameter(identity_to_conv_bias)

t1 = time.time()
result2 = conv_2d(x) + conv_2d_for_pointwise(x) + conv_2d_for_identity(x)
t2 = time.time()
print("原生写法 耗时: {:.10f}秒".format(t2 - t1))
print(torch.all(torch.isclose(result1, result2)))
print(result2)

原生写法 耗时: 0.0000000000秒
tensor(True)
tensor([[[[-0.0053, -0.1841, -0.1841, -0.1841, -0.1841, -0.1841, -0.1841,
           -0.1841, -0.4957],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.3106, -0.4763, -0.4763, -0.4763, -0.4763, -0.4763, -0.4763,
           -0.4763, -0.1885]],

         [[ 1.5132,  0.837

### 2. 融合

In [6]:
conv_2d_for_fusion = nn.Conv2d(in_channels, out_channels, kernel_size, padding='same')
conv_2d_for_fusion.weight = nn.Parameter(conv_2d.weight.data +
                                         conv_2d_for_pointwise.weight.data +
                                         conv_2d_for_identity.weight.data)

conv_2d_for_fusion.bias = nn.Parameter(conv_2d.bias.data +
                                       conv_2d_for_pointwise.bias.data +
                                       conv_2d_for_identity.bias.data)

t3 = time.time()
result3 = conv_2d_for_fusion(x)
t4 = time.time()
print("算子融合写法 耗时: {:.18f}秒".format(t4 - t3))
print(torch.all(torch.isclose(result2, result3)))
print(result3)

算子融合写法 耗时: 0.000000000000000000秒
tensor(True)
tensor([[[[-0.0053, -0.1841, -0.1841, -0.1841, -0.1841, -0.1841, -0.1841,
           -0.1841, -0.4957],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.0773, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728, -0.1728,
           -0.1728, -0.2089],
          [-0.3106, -0.4763, -0.4763, -0.4763, -0.4763, -0.4763, -0.4763,
           -0.4763, -0.1885]],

         [[ 1.51