In [1]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
 
class ConvWithBn(nn.Module):
    """3x3 convolution (3 -> 8 channels, no bias) followed by ``BatchNorm2d``.

    On construction every conv / batch-norm sub-module is re-initialised with
    random values; the batch-norm scale, shift and running statistics are
    randomised too, so the normalisation is not an identity mapping.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(8)
        self._initialize_weights()

    def forward(self, x):
        """Return ``bn1(conv1(x))``."""
        return self.bn1(self.conv1(x))

    def _initialize_weights(self):
        """Overwrite conv and batch-norm tensors in place with random draws."""
        # In-place writes are done without autograd tracking.
        with torch.no_grad():
            for module in self.modules():
                if isinstance(module, nn.BatchNorm2d):
                    module.weight.normal_(0, 1)
                    module.bias.normal_(0, 1)
                    module.running_mean.normal_(0, 1)
                    # Variances are drawn from [1, 2), so they stay positive.
                    module.running_var.uniform_(1, 2)
                elif isinstance(module, nn.Conv2d):
                    kernel_h, kernel_w = module.kernel_size
                    n = kernel_h * kernel_w * module.out_channels
                    module.weight.normal_(0, math.sqrt(2. / n))
                    if module.bias is not None:
                        module.bias.zero_()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Conv(nn.Module):
    """Bias-free 3x3 convolution mapping 3 input channels to 8 output channels."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1, bias=False
        )

    def forward(self, x):
        """Return ``conv1`` applied to ``x``."""
        return self.conv1(x)

In [3]:
class ConvWithBias(nn.Module):
    """3x3 convolution (3 -> 8 channels) that carries a bias term."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1, bias=True
        )

    def forward(self, x):
        """Return ``conv1`` applied to ``x``."""
        return self.conv1(x)

In [4]:
# Reference model: convolution followed by batch norm with randomised statistics.
Model1 = ConvWithBn()
# Parameter/buffer mapping of the reference model; later cells read the BN tensors from it.
model1_cpkt = Model1.state_dict()
# Inference mode; the manual check further down reproduces out1 from the running statistics.
Model1.eval()
# Conv-only model; it is given the same kernel as Model1 below.
Model2 = Conv()
# Keep only the entries whose names also exist in Model2's state dict.
model2_cpkt = {k:v for k,v in model1_cpkt.items() if k in Model2.state_dict()}
Model2.load_state_dict(model2_cpkt)
Model2.eval()
 
# NOTE(review): `input` shadows the Python builtin; later cells reuse this name.
input = torch.randn(1,3,64,64)
# Forward passes run with autograd enabled (no torch.no_grad() here).
out1 = Model1(input)
out2 = Model2(input)

In [7]:
# First channel of the first sample produced by the conv + batch-norm model.
out1[0][0]

tensor([[-0.1019, -0.3648, -0.2868,  ..., -0.7173, -0.4325,  0.0201],
        [-0.2780, -0.2512,  0.1077,  ..., -0.6277, -0.0932, -0.3593],
        [-0.2696, -0.4547,  0.0357,  ...,  0.3470,  0.1927, -0.0093],
        ...,
        [-0.2300, -0.5385, -0.3354,  ..., -0.2428, -0.1889, -0.0223],
        [-0.1480,  0.1193, -0.4314,  ..., -0.0821, -0.3396, -0.2906],
        [-0.1546, -0.2099, -0.0995,  ..., -0.3481, -0.2938, -0.1365]],
       grad_fn=<SelectBackward>)

In [8]:
# Same slice from the conv-only model, shown for comparison with out1.
out2[0][0]

tensor([[ 0.3194, -0.8087, -0.4742,  ..., -2.3212, -1.0992,  0.8427],
        [-0.4362, -0.3214,  1.2184,  ..., -1.9368,  0.3565, -0.7853],
        [-0.4002, -1.1945,  0.9098,  ...,  2.2455,  1.5834,  0.7164],
        ...,
        [-0.2304, -1.5539, -0.6827,  ..., -0.2854, -0.0539,  0.6607],
        [ 0.1212,  1.2685, -1.0948,  ...,  0.4041, -0.7006, -0.4904],
        [ 0.0931, -0.1442,  0.3294,  ..., -0.7373, -0.5042,  0.1707]],
       grad_fn=<SelectBackward>)

In [9]:
# List the entry names stored in the conv + batch-norm state dict.
model1_cpkt.keys()

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked'])

In [11]:
# Shape of the batch-norm scale vector (bn1 was built as BatchNorm2d(8)).
model1_cpkt['bn1.weight'].shape

torch.Size([8])

In [12]:
# Pull the four batch-norm tensors (scale, shift, running mean, running variance)
# out of the reference state dict in one go.
bnw, bnb, mean, var = (
    model1_cpkt[key]
    for key in ('bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var')
)

In [15]:
# Give every per-channel vector singleton batch/height/width axes, i.e. shape
# (1, C, 1, 1), so it broadcasts against the (N, C, H, W) conv output.
bnwexp, bnbexp, meanexp, varexp = (
    vec.view(1, -1, 1, 1) for vec in (bnw, bnb, mean, var)
)

In [17]:
# Apply batch norm by hand to the conv-only output:
#   y = gamma * (x - running_mean) / sqrt(running_var + eps) + beta
# NOTE(review): eps is hard-coded to 1e-5 here rather than read from Model1.bn1.eps.
bnout = bnwexp * ((out2 - meanexp) / torch.sqrt(varexp + 1e-5)) + bnbexp
# NOTE(review): this is a signed sum, so positive and negative errors can cancel;
# a max-abs difference would be a stricter equivalence check.
(bnout - out1).sum()

tensor(0.0002, grad_fn=<SumBackward0>)

In [16]:
# Inspect the shape of the expanded batch-norm scale tensor.
bnwexp.shape

torch.Size([1, 8, 1, 1])

In [18]:
# Convolution with a bias term; it later receives the folded conv + batch-norm parameters.
Model3 = ConvWithBias()

In [19]:
# Gather the conv kernel and the batch-norm tensors needed for folding.
conv1w, bnw, bnb, bnmean, bnvar = (
    model1_cpkt[key]
    for key in (
        'conv1.weight',
        'bn1.weight',
        'bn1.bias',
        'bn1.running_mean',
        'bn1.running_var',
    )
)

In [21]:
# Per-output-channel vectors reshaped to (C_out, 1, 1, 1) so they broadcast
# over the (C_out, C_in, kH, kW) kernel.
bnwexp = bnw.view(-1, 1, 1, 1)
bnvarexp = bnvar.view(-1, 1, 1, 1)
# Folded kernel: W' = gamma * W / sqrt(var + eps)
new_conv1w = (bnwexp * conv1w) / torch.sqrt(bnvarexp + 1e-5)
# Folded bias (the conv itself has none): b' = beta - gamma * mean / sqrt(var + eps)
# NOTE(review): despite the name, this is the bias for conv1.
new_conv2b = bnb - bnw * bnmean / torch.sqrt(bnvar + 1e-5)

In [24]:
# Inspect the shape of the scale tensor used to fold batch norm into the kernel.
bnwexp.shape

torch.Size([8, 1, 1, 1])

In [22]:
# State dict for the bias-carrying conv, holding the folded kernel and bias.
merge_state_dict = {
    'conv1.weight': new_conv1w,
    'conv1.bias': new_conv2b,
}

In [23]:

# Install the folded parameters, switch to inference mode and run the same input.
Model3.load_state_dict(merge_state_dict)
Model3.eval()

out3 = Model3(input)
# NOTE(review): signed sum of the differences; opposite-sign errors can cancel.
print("Bias of merged ConvBn : ", (out3 - out1).sum())

Bias of merged ConvBn :  tensor(-0.0002, grad_fn=<SumBackward0>)


In [None]:
import torch
import os
from collections import OrderedDict
import cv2
import numpy as np
import torchvision.transforms as transforms


"""  Parameters and variables  """
# NOTE(review): both paths are absolute and machine-specific; adjust before running elsewhere.
# Text file read by the post-merge test; each line is split on a space into
# an image path and an integer label index.
IMAGENET = '/home/zym/ImageNet/ILSVRC2012_img_val_256xN_list.txt'
# Text file whose lines are indexed by label id to print class names.
LABEL = '/home/zym/ImageNet/synset.txt'
# Number of randomly picked images evaluated by the post-merge test.
TEST_ITER = 10
# When True, the merged weights are saved under the model directory.
SAVE = False
# When True, the merged weights are loaded into the network and spot-checked.
TEST_AFTER_MERGE = True


"""  Functions  """
def merge(params, name, layer, eps=1e-5):
    """Collect conv / batch-norm tensors and fold the BN into the conv.

    The function is called once per state-dict entry, in order. It keeps the
    most recent convolution kernel/bias and the batch-norm tensors seen since
    then in module-level globals, and performs the fold when the
    ``running_var`` entry arrives.

    Args:
        params: tensor for the current entry (its ``.data`` is used).
        name: state-dict key of the entry, e.g. ``'conv1.weight'``.
        layer: ``'Convolution'`` or ``'BatchNorm'``; anything else is ignored.
        eps: value added to the running variance before the square root.
            Defaults to ``1e-5``.

    Returns:
        ``(weights, bias)`` with the folded kernel and bias when ``name``
        contains ``'running_var'`` and ``layer`` is ``'BatchNorm'``;
        ``(None, None)`` for every other call.
    """
    # State carried across calls.
    global weights, bias
    global bn_param

    if layer == 'Convolution':
        # Remember the kernel (and bias, if any) of the latest conv layer.
        if 'weight' in name:
            weights = params.data
            # Implicit zero bias, created on the kernel's device and dtype so
            # the fold below never mixes devices.
            bias = weights.new_zeros(weights.size(0))
        elif 'bias' in name:
            bias = params.data
        # A new conv starts a new group of batch-norm tensors.
        bn_param = {}

    elif layer == 'BatchNorm':
        # Store under the last key component: weight / bias / running_mean / running_var.
        bn_param[name.split('.')[-1]] = params.data

        # The fold is triggered by running_var, which follows weight, bias and
        # running_mean among the tensors used here.
        if 'running_var' in name:
            running_var = bn_param['running_var']
            # Without affine parameters, fall back to scale 1 and shift 0.
            gamma = bn_param.get('weight')
            if gamma is None:
                gamma = torch.ones_like(running_var)
            beta = bn_param.get('bias')
            if beta is None:
                beta = torch.zeros_like(running_var)

            # scale = gamma / sqrt(var + eps);  W' = scale * W;  b' = scale * (b - mean) + beta
            scale = gamma / torch.sqrt(running_var + eps)
            weights = scale.view(-1, 1, 1, 1) * weights
            bias = scale * (bias - bn_param['running_mean']) + beta

            return weights, bias

    return None, None


"""  Main functions  """
# import pytorch model
import models.shufflenetv2.shufflenetv2_merge as shufflenetv2
pytorch_net = shufflenetv2.ShuffleNetV2().eval()
# Directory that is scanned for the checkpoint file below.
model_path = shufflenetv2.weight_file

# load weights
print('Finding trained model weights...')
# Stays None when no matching file exists, so the failure is reported here
# instead of surfacing later as a NameError.
trained_weights = None
try:
    for file in os.listdir(model_path):
        if 'pth' in file:
            print('Loading weights from %s ...' % file)
            # map_location='cpu': the network is never moved off the CPU in
            # this script and the merge helper creates its zero bias on CPU.
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            trained_weights = torch.load(os.path.join(model_path, file), map_location='cpu')
            # pytorch_net.load_state_dict(trained_weights)
            print('Weights load success')
            break
except Exception as exc:
    # Keep the ValueError type, but chain the cause so the real problem is visible.
    raise ValueError('No trained model found or loading error occurs') from exc

if trained_weights is None:
    raise ValueError('No trained model found or loading error occurs')

# go through pytorch net
print('Going through pytorch net weights...')
new_weights = OrderedDict()
inner_product_flag = False
for name, params in trained_weights.items():
    if name.endswith('num_batches_tracked'):
        # BatchNorm's step counter is a 0-dim tensor, not a weight. Left in, it
        # would fall into the inner-product branch below, set
        # inner_product_flag and disable all subsequent BN folding.
        continue

    n_dims = len(params.size())
    if n_dims == 4:
        # 4-D tensor: convolution kernel. Remember it until its BN is complete.
        merge(params, name, 'Convolution')
        prev_layer = name
    elif n_dims == 1 and not inner_product_flag:
        # 1-D tensor before the first inner-product layer: treated as a BN tensor.
        # NOTE(review): a 1-D conv bias takes this branch too and is stored as
        # BN 'bias', where the real BN bias then overwrites it.
        w, b = merge(params, name, 'BatchNorm')
        if w is not None:
            # Folded result is stored under the conv's own weight/bias names.
            new_weights[prev_layer] = w
            new_weights[prev_layer.replace('weight', 'bias')] = b
    else:
        # inner product layer
        # if meet inner product layer,
        # the next bias weight can be misclassified as 'BatchNorm' layer as len(params.size()) == 1
        new_weights[name] = params
        inner_product_flag = True

# align names in new_weights with pytorch model
# after removing the BatchNorm layers from the pytorch model,
# the layer names of the old model and the new model are mis-aligned
print('Aligning weight names...')
pytorch_net_key_list = list(pytorch_net.state_dict().keys())
new_weights_key_list = list(new_weights.keys())
assert len(pytorch_net_key_list) == len(new_weights_key_list), \
    'merged weights have %d entries but the model expects %d' % (
        len(new_weights_key_list), len(pytorch_net_key_list))
# Pair tensors with the model's key names purely by position, into a fresh dict.
# Renaming in place (pop + reinsert) can overwrite a not-yet-renamed entry
# whenever a target name equals one of the remaining old names.
new_weights = OrderedDict(zip(pytorch_net_key_list, new_weights.values()))

# save new weights
if SAVE:
    merged_name = file.replace('.pth', '_merged.pth')
    if merged_name == file:
        # The checkpoint was selected by "'pth' in file", so its name may not
        # contain '.pth'; never write over the source checkpoint.
        merged_name = file + '_merged.pth'
    torch.save(new_weights, os.path.join(model_path, merged_name))

# test merged pytorch model
if TEST_AFTER_MERGE:
    try:
        pytorch_net.load_state_dict(new_weights)
        print('Pytorch net load weights success~')
    except Exception as exc:
        # Chain the cause so missing/unexpected keys and shape mismatches stay visible.
        raise ValueError('Load new weights error') from exc

    print('-' * 50)
    with open(LABEL) as f:
        labels = f.read().splitlines()
    with open(IMAGENET) as f:
        images = f.read().splitlines()

    # Preprocessing is identical for every image, so build it once.
    preprocess = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                          std=[0.229, 0.224, 0.225])
                                     ])

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(TEST_ITER):
            # Each list line is "<image path> <label index>"; split on the last
            # space so that paths containing spaces still parse.
            image_path, label = images[np.random.randint(0, len(images))].rsplit(' ', 1)
            # image_path, label = images[0].split(' ')
            # cv2 default channel order is BGR
            # NOTE(review): the mean/std above are listed in the usual RGB order while
            # the image stays BGR - confirm which order the network was trained with.
            input_image = cv2.imread(image_path)
            if input_image is None:
                # cv2.imread signals an unreadable file by returning None.
                raise FileNotFoundError('Cannot read image: %s' % image_path)
            input_image = cv2.resize(input_image, (224, 224))
            input_image = preprocess(input_image)
            input_image = input_image.view(1, 3, 224, 224)
            output_logits = pytorch_net(input_image)
            _, index = output_logits.max(dim=1)
            print('true label: \t%s' % labels[int(label)])
            print('predict label:\t%s' % labels[int(index)])
            print('-' * 50)
