In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
class Model(nn.Module):
    def __init__(self, ichannel, ochannel):
        super().__init__()
        self.ichannel = ichannel
        self.ochannel = ochannel
        self.conv1 = nn.Conv2d(ichannel, ochannel, 3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(ichannel, ochannel, 1, bias=True)
        self.bn1 = nn.BatchNorm2d(ochannel)
        self.bn2 = nn.BatchNorm2d(ochannel)
        self.reset_parameters()
        
    def reset_parameters(self):
        self.conv1.weight.data.fill_(0.5)
        self.conv1.bias.data.fill_(0.1)
        self.bn1.weight.data.fill_(0.3)
        self.bn1.bias.data.fill_(0.2)
        self.bn1.running_mean.data.fill_(0.81)
        self.bn1.running_var.data.fill_(0.23)
        self.conv2.weight.data.fill_(0.7)
        self.conv2.bias.data.fill_(0.9)
        self.bn2.weight.data.fill_(0.21)
        self.bn2.bias.data.fill_(0.35)
        self.bn2.running_mean.data.fill_(0.31)
        self.bn2.running_var.data.fill_(0.53)
        
    def forward(self, x):
        x1 = self.bn1(self.conv1(x))
        x2 = self.bn2(self.conv2(x))
        x3 = x
        return x1 + x2 + x3
    
class Fuse1(nn.Module):
    def __init__(self, model):
        super().__init__()
        
        self.ichannel = model.ichannel
        self.ochannel = model.ochannel
        self.conv1 = nn.Conv2d(model.ichannel, model.ochannel, 3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(model.ichannel, model.ochannel, 1, bias=True)
        self.reparametelize(model)
        
    def fuse_conv_bn(self, conv_out, conv, bn):
        
        # conv -> 
        # y = x * w + b
        
        # bn ->
        # t = (x - mean) / var
        # t = x * 1/var + (-mean / var)
        # y = t * gamma + beta
        # y = (x * w + b) * 1/var * gamma + (-mean/var) * gamma + beta
        # y = x * w * 1/var * gamma + (-mean/var) * gamma + beta + conv.b * 1/var * gamma
        
        # conv -> bn
        # y1 = x * conv.w + conv.b
        # y2 = (y1 - mean) / var
        # y3 = y2 * gamma + beta
        # output = ((x * conv.w + conv.b) - mean) / var * gamma + beta
        # output = (x * conv.w + conv.b - mean) / var * gamma + beta
        #        = x * conv.w / var * gamma + conv.b / var * gamma - mean / var * gamma + beta
        # weight = x * conv.w / var * gamma
        # bias   = conv.b / var * gamma - mean / var * gamma + beta
        std = torch.sqrt(bn.running_var.data) + bn.eps
        conv_out.weight.data[:] = conv.weight.data / std.view(1, -1, 1, 1) * bn.weight.data.view(1, -1, 1, 1)
        conv_out.bias.data[:]   = conv.bias.data / std * bn.weight.data + bn.bias.data + (-bn.running_mean.data / std) * bn.weight.data
        
    def reparametelize(self, model):
        self.fuse_conv_bn(self.conv1, model.conv1, model.bn1)
        self.fuse_conv_bn(self.conv2, model.conv2, model.bn2)
            
    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = x
        return x1 + x2 + x3
    
class Fuse2(nn.Module):
    def __init__(self, model):
        super().__init__()
        
        self.conv1 = nn.Conv2d(model.ichannel, model.ochannel, 3, padding=1, bias=True)
        self.reparametelize(model)
        
    def fuse_conv3_conv1_identity(self, conv_out, conv3, conv1):
        
        # conv3 + conv1 + identity
        #y1 = conv3(x)
        # y1 = conv3.w * x + conv3.b
        #y2 = conv1(x)
        # y2 = conv1.w * x + conv1.b
        #y3 = x
        #output = y1 + y2 + y3
        # output = conv3.w * x + conv3.b + conv1.w * x + conv1.b + x
        #        = (conv3.w + conv1.w + 1) * x + conv3.b + conv1.b
        conv_out.weight.data[:] = conv3.weight.data[:]
        conv_out.weight.data[:, :, 1, 1] += conv1.weight.data[:, :, 0, 0] + 1 / conv1.weight.data.shape[0]
        conv_out.bias.data[:] = conv3.bias.data[:] + conv1.bias.data[:]
        
    def reparametelize(self, model):
        self.fuse_conv3_conv1_identity(self.conv1, model.conv1, model.conv2)
            
    def forward(self, x):
        return self.conv1(x)

In [3]:
model = Model(3, 3).eval()
fuse1 = Fuse1(model).eval()
fuse2 = Fuse2(fuse1).eval()

In [4]:
x = torch.full((1, 3, 3, 3), 0.8)
model(x), fuse1(x), fuse2(x)

(tensor([[[[4.5632, 6.0645, 4.5632],
           [6.0645, 8.3164, 6.0645],
           [4.5632, 6.0645, 4.5632]],
 
          [[4.5632, 6.0645, 4.5632],
           [6.0645, 8.3164, 6.0645],
           [4.5632, 6.0645, 4.5632]],
 
          [[4.5632, 6.0645, 4.5632],
           [6.0645, 8.3164, 6.0645],
           [4.5632, 6.0645, 4.5632]]]], grad_fn=<AddBackward0>),
 tensor([[[[4.5632, 6.0645, 4.5632],
           [6.0645, 8.3164, 6.0645],
           [4.5632, 6.0645, 4.5632]],
 
          [[4.5632, 6.0645, 4.5632],
           [6.0645, 8.3164, 6.0645],
           [4.5632, 6.0645, 4.5632]],
 
          [[4.5632, 6.0645, 4.5632],
           [6.0645, 8.3164, 6.0645],
           [4.5632, 6.0645, 4.5632]]]], grad_fn=<AddBackward0>),
 tensor([[[[4.5632, 6.0645, 4.5632],
           [6.0645, 8.3164, 6.0645],
           [4.5632, 6.0645, 4.5632]],
 
          [[4.5632, 6.0645, 4.5632],
           [6.0645, 8.3164, 6.0645],
           [4.5632, 6.0645, 4.5632]],
 
          [[4.5632, 6.0645, 4.5632],
 