<div>
   <img src="RepNetPrevie.jpg"/>
</div>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# 原始卷积操作

In [2]:
class Block(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.conv3x3 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
        self.conv3x3_bn = nn.BatchNorm2d(num_channels)
        self.conv1x1 = nn.Conv2d(num_channels, num_channels, kernel_size=1, bias=False)
        self.conv1x1_bn = nn.BatchNorm2d(num_channels)
        self.identity = nn.Identity()
        self.identity_bn = nn.BatchNorm2d(num_channels)
        self.apply(self.init)
        
    def init(self, m):
        if isinstance(m, nn.BatchNorm2d):
            m.weight.data.normal_()
            m.bias.data.normal_()
            m.running_mean.normal_()
            m.running_var.fill_(0.5)
    
    def forward(self, x):
        a = self.conv3x3_bn(self.conv3x3(x))
        b = self.conv1x1_bn(self.conv1x1(x))
        c = self.identity_bn(self.identity(x))
        return a + b + c

In [3]:
num_channels = 2
x = torch.ones(1, num_channels, 3, 3)
block = Block(num_channels)
block.eval()
block(x)

tensor([[[[ 3.8140,  3.8139,  3.7645],
          [ 3.7785,  3.8110,  3.7654],
          [ 3.8143,  3.8635,  3.8470]],

         [[-0.7286, -0.6416, -0.3531],
          [-0.7127, -0.5068,  0.0389],
          [-1.0060, -0.2569,  0.1836]]]], grad_fn=<AddBackward0>)

# 重定义参数为3个3x3卷积，实现等效的效果，BN暂时不修改

In [4]:
class Block_Reparameter(nn.Module):
    def __init__(self, num_channels, source):
        super().__init__()
        self.num_channels = num_channels
        self.conv3x3 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
        self.conv3x3_bn = nn.BatchNorm2d(num_channels)
        self.conv1x1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
        self.conv1x1_bn = nn.BatchNorm2d(num_channels)
        self.identity = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
        self.identity_bn = nn.BatchNorm2d(num_channels)
        self.reparam(source)
    
    def forward(self, x):
        a = self.conv3x3_bn(self.conv3x3(x))
        b = self.conv1x1_bn(self.conv1x1(x))
        c = self.identity_bn(self.identity(x))
        return a + b + c
    
    def reparam(self, source):
        
        # reparameter 3x3
        self.conv3x3.weight = source.conv3x3.weight
        
        # reparameter 1x1
        self.conv1x1.weight.data.fill_(0)
        for i in range(self.num_channels):
            for j in range(self.num_channels):
                self.conv1x1.weight.data[i, j, 1, 1] = source.conv1x1.weight.data[i, j, 0, 0]
        
        # reparameter identity
        self.identity.weight.data.fill_(0)
        for i in range(self.num_channels):
            self.identity.weight.data[i, i, 1, 1] = 1
            
        self.fill_bn(self.conv3x3_bn, source.conv3x3_bn)
        self.fill_bn(self.conv1x1_bn, source.conv1x1_bn)
        self.fill_bn(self.identity_bn, source.identity_bn)

    def fill_bn(self, bn, source):
        bn.weight = source.weight
        bn.bias = source.bias
        bn.running_mean = source.running_mean
        bn.running_var = source.running_var
        

In [12]:
reparam_block = Block_Reparameter(num_channels, block)
reparam_block.eval()
reparam_block(x)

tensor([[[[ 3.8140,  3.8139,  3.7645],
          [ 3.7785,  3.8110,  3.7654],
          [ 3.8143,  3.8635,  3.8470]],

         [[-0.7286, -0.6416, -0.3531],
          [-0.7127, -0.5068,  0.0389],
          [-1.0060, -0.2569,  0.1836]]]], grad_fn=<AddBackward0>)

# 合并3x3卷积和BN为1个

In [31]:
class BlockReparameterBN(nn.Module):
    def __init__(self, num_channels, source):
        super().__init__()
        self.num_channels = num_channels
        self.conv3x3 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=True)
        self.reparam(source)
    
    def forward(self, x):
        return self.conv3x3(x)
    
    def reparam(self, source):
        
        self.conv3x3.weight.data[...] = source.conv3x3.weight.data[...]
        self.conv3x3.bias.data.fill_(0)
        
        for i in range(self.num_channels):
            mul, add = self.get_mul_add(source.conv3x3_bn, i)
            self.conv3x3.weight.data[i] *= mul
            self.conv3x3.bias.data[i] += add
        
        # reparameter 1x1
        for i in range(self.num_channels):
            mul, add = self.get_mul_add(source.conv1x1_bn, i)
            self.conv3x3.bias.data[i] += add
            
            for j in range(self.num_channels):
                self.conv3x3.weight.data[i, j, 1, 1] += source.conv1x1.weight.data[i, j, 0, 0] * mul
        
        # reparameter identity
        for i in range(self.num_channels):
            mul, add = self.get_mul_add(source.identity_bn, i)
            self.conv3x3.bias.data[i] += add
            self.conv3x3.weight.data[i, i, 1, 1] += 1 * mul

    def get_mul_add(self, bn, i):
        mean = bn.running_mean[i]
        std = torch.sqrt(bn.running_var[i])
        gamma = bn.weight[i]
        beta = bn.bias[i]
        mul = 1 / std * gamma
        add = (-mean) / std * gamma + beta
        return mul, add
        

In [32]:
reparam_bn_block = BlockReparameterBN(num_channels, block)
reparam_bn_block.eval()
reparam_bn_block(x)

tensor([[[[ 3.8141,  3.8139,  3.7645],
          [ 3.7785,  3.8111,  3.7654],
          [ 3.8143,  3.8636,  3.8471]],

         [[-0.7286, -0.6416, -0.3531],
          [-0.7127, -0.5068,  0.0389],
          [-1.0060, -0.2569,  0.1836]]]], grad_fn=<MkldnnConvolutionBackward>)