In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers 
import numpy as np

In [2]:
class RepVGGBlock(layers.Layer):
    """RepVGG block: parallel 3x3 / 1x1 / identity branches followed by ReLU.

    With ``deploy=False`` the block holds a 3x3 conv + BatchNormalization
    branch, a 1x1 conv + BatchNormalization branch and, only when
    ``out_channels == in_channels`` and ``stride == 1``, a
    BatchNormalization-only identity branch. The branch outputs are summed.

    With ``deploy=True`` the block holds a single biased 3x3 convolution
    (``rbr_reparam``) whose ``[kernel, bias]`` can be filled with the result
    of ``get_equivalent_kernel_bias()`` from a trained block.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (conv filters).
        stride: stride shared by every convolution of the block.
        groups: ``groups`` argument forwarded to every convolution.
        deploy: build the single-convolution form instead of the branches.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, in_channels, out_channels, stride=1, groups=1, deploy=False, **kwargs):
        super(RepVGGBlock, self).__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.stride = stride
        self.deploy = deploy
        self.act = layers.ReLU()
        if deploy:
            self.rbr_reparam = layers.Conv2D(out_channels, 3, stride, 'SAME', use_bias=True, groups=groups)
        else:
            # The identity shortcut only exists when it is shape-preserving.
            has_identity = out_channels == in_channels and stride == 1
            self.rbr_identity = layers.BatchNormalization() if has_identity else None
            self.rbr_3x3 = keras.Sequential([
                layers.Conv2D(out_channels, 3, stride, 'SAME', groups=groups, use_bias=False),
                layers.BatchNormalization(),
            ])
            self.rbr_1x1 = keras.Sequential([
                layers.Conv2D(out_channels, 1, stride, 'SAME', groups=groups, use_bias=False),
                layers.BatchNormalization(),
            ])

    def call(self, x, training=None):
        """Apply the block to ``x``.

        ``training`` is forwarded to the BatchNormalization-bearing branches;
        when left as ``None`` Keras resolves the mode from the calling context.
        """
        if hasattr(self, 'rbr_reparam'):
            return self.act(self.rbr_reparam(x))
        summed = self.rbr_3x3(x, training=training) + self.rbr_1x1(x, training=training)
        if self.rbr_identity is not None:
            summed = summed + self.rbr_identity(x, training=training)
        return self.act(summed)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(RepVGGBlock, self).get_config()
        config.update({
            'in_channels': self.in_channels,
            'out_channels': self.out_channels,
            'stride': self.stride,
            'groups': self.groups,
            'deploy': self.deploy,
        })
        return config

    def get_equivalent_kernel_bias(self):
        """Return ``[kernel, bias]`` of the single 3x3 conv equivalent to this block.

        For a multi-branch block each branch is folded with its
        BatchNormalization moving statistics and the results are summed.
        For a deploy block the existing ``rbr_reparam`` kernel and bias are
        returned. The block's weights must already have been created.
        """
        if hasattr(self, 'rbr_reparam'):
            return [self.rbr_reparam.kernel, self.rbr_reparam.bias]
        kernel3x3, bias3x3 = self.fuse_bn_tensor(self.rbr_3x3)
        kernel1x1, bias1x1 = self.fuse_bn_tensor(self.rbr_1x1)
        # A missing identity branch contributes (0, 0).
        kernelid, biasid = self.fuse_bn_tensor(self.rbr_identity)
        fused_kernel = kernel3x3 + self.pad_1x1_to_3x3_tensor(kernel1x1) + kernelid
        fused_bias = bias3x3 + bias1x1 + biasid
        return [fused_kernel, fused_bias]

    def pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Zero-pad a (1, 1, in, out) kernel to (3, 3, in, out).

        The single tap is placed at the centre for ``stride == 1`` and at the
        top-left corner for larger strides (the same result as before for
        strides 1 and 2).

        NOTE(review): the top-left placement relies on TensorFlow 'SAME'
        padding putting no padding before the input when the spatial size is
        a multiple of the stride; for other input sizes the fused 3x3 conv is
        not an exact equivalent of the 1x1 branch -- confirm for your inputs.
        """
        before = 1 if self.stride == 1 else 0
        after = 2 - before
        return tf.pad(kernel1x1, [[before, after], [before, after], [0, 0], [0, 0]])

    def fuse_bn_tensor(self, branch):
        """Fold a branch's BatchNormalization into a (kernel, bias) pair.

        Args:
            branch: ``None``, a ``keras.Sequential`` of [Conv2D,
                BatchNormalization], or a bare BatchNormalization layer
                (the identity branch).

        Returns:
            ``(0, 0)`` for ``None``; otherwise ``(kernel * s, beta - mean * s)``
            with ``s = gamma / sqrt(moving_variance + epsilon)``.
        """
        if branch is None:
            return 0, 0
        if isinstance(branch, keras.Sequential):
            conv, bn = branch.layers[0], branch.layers[1]
            kernel = conv.kernel
        else:
            bn = branch
            # Identity expressed as a 3x3 kernel: a single 1 at the centre that
            # maps every output channel to its own input channel (index taken
            # modulo the per-group input width).
            input_dim = self.in_channels // self.groups
            kernel = np.zeros((3, 3, input_dim, self.in_channels), dtype=np.float32)
            channel_ids = np.arange(self.in_channels)
            kernel[1, 1, channel_ids % input_dim, channel_ids] = 1
        inv = tf.math.rsqrt(bn.moving_variance + bn.epsilon) * bn.gamma
        # ``inv`` has one entry per output channel and broadcasts over the
        # kernel's last axis.
        return kernel * inv, bn.beta - bn.moving_mean * inv

In [3]:
def make_RepVGG(num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False):
    """Build a RepVGG classifier as a functional ``keras.Model``.

    The network is a stride-2 stem block, then one stage per entry of
    ``num_blocks`` (the first block of each stage has stride 2, the remaining
    ones stride 1), global average pooling and a softmax ``Dense`` layer.

    Args:
        num_blocks: number of blocks in each stage.
        num_classes: number of units of the final softmax layer.
        width_multiplier: per-stage width factors; stage ``i`` uses
            ``int(64 * 2**i * width_multiplier[i])`` channels. Required.
        override_groups_map: optional ``{block_index: groups}`` mapping. Block
            indices start at 1 with the first block after the stem and run
            consecutively across all stages; unlisted blocks use 1 group.
        deploy: build every block in its single-convolution form.

    Returns:
        A ``keras.Model`` taking ``(None, None, 3)`` images.

    Raises:
        TypeError: if ``width_multiplier`` is not provided.
    """
    if width_multiplier is None:
        raise TypeError('width_multiplier must be a sequence with one entry per stage')
    override_groups_map = override_groups_map or dict()
    in_planes = min(64, int(64 * width_multiplier[0]))

    input_x = layers.Input((None, None, 3))
    x = RepVGGBlock(3, in_planes, 2, deploy=deploy)(input_x)

    # Running block counter shared by all stages: the override map is keyed by
    # block position, so it must advance once per block, not once per stage.
    cur_layer_idx = 1
    for stage_idx, blocks_in_stage in enumerate(num_blocks):
        planes = int((64 * (2 ** stage_idx)) * width_multiplier[stage_idx])
        for block_stride in [2] + [1] * (blocks_in_stage - 1):
            cur_groups = override_groups_map.get(cur_layer_idx, 1)
            x = RepVGGBlock(in_planes, planes, block_stride, groups=cur_groups, deploy=deploy)(x)
            in_planes = planes
            cur_layer_idx += 1

    x = layers.GlobalAvgPool2D()(x)
    x = layers.Dense(num_classes, activation='softmax')(x)
    return keras.Model(input_x, x)

In [4]:
def repvgg_model_convert(model,deploy_model):
    """Copy a trained multi-branch model into its deploy-mode counterpart.

    Layers are paired by position. For every ``RepVGGBlock`` in ``model`` the
    fused ``[kernel, bias]`` from ``get_equivalent_kernel_bias()`` is written
    into the matching layer of ``deploy_model``; every other layer that owns
    weights (e.g. the final ``Dense``) is copied unchanged. No assumption is
    made about where the blocks or the classifier sit in the layer list.

    Args:
        model: source model built with ``deploy=False``.
        deploy_model: target model with the same layer sequence, built with
            ``deploy=True``. Modified in place.

    Raises:
        ValueError: if the two models do not have the same number of layers.
    """
    if len(model.layers) != len(deploy_model.layers):
        raise ValueError(
            'model and deploy_model must have the same number of layers, got %d and %d'
            % (len(model.layers), len(deploy_model.layers)))
    for src_layer, dst_layer in zip(model.layers, deploy_model.layers):
        if isinstance(src_layer, RepVGGBlock):
            dst_layer.set_weights(src_layer.get_equivalent_kernel_bias())
        elif src_layer.weights:
            dst_layer.set_weights(src_layer.get_weights())

In [5]:
def get_RepVGG_func_by_name(name,numclass=1000,deploy=False):
    """Build the RepVGG variant called ``name``.

    Despite its name this returns the value of ``make_RepVGG(...)`` for the
    selected variant, not a factory function.

    Args:
        name: one of 'RepVGG-A0', 'RepVGG-A1', 'RepVGG-A2', 'RepVGG-B0',
            'RepVGG-B1', 'RepVGG-B1g2', 'RepVGG-B1g4', 'RepVGG-B2',
            'RepVGG-B2g2', 'RepVGG-B2g4', 'RepVGG-B3', 'RepVGG-B3g2',
            'RepVGG-B3g4'.
        numclass: number of output classes passed to ``make_RepVGG``.
        deploy: passed to ``make_RepVGG``.

    Raises:
        KeyError: if ``name`` is not a known variant; the message lists the
            accepted names.
    """
    # Block indices that use grouped convolutions in the "g2"/"g4" variants.
    optional_groupwise_layers = list(range(2, 27, 2))
    g2_map = {l: 2 for l in optional_groupwise_layers}
    g4_map = {l: 4 for l in optional_groupwise_layers}

    a_blocks = [2, 4, 14, 1]
    b_blocks = [4, 6, 16, 1]

    # name -> (blocks per stage, width multipliers, group override map)
    variants = {
        'RepVGG-A0': (a_blocks, [0.75, 0.75, 0.75, 2.5], None),
        'RepVGG-A1': (a_blocks, [1, 1, 1, 2.5], None),
        'RepVGG-A2': (a_blocks, [1.5, 1.5, 1.5, 2.75], None),
        'RepVGG-B0': (b_blocks, [1, 1, 1, 2.5], None),
        'RepVGG-B1': (b_blocks, [2, 2, 2, 4], None),
        'RepVGG-B1g2': (b_blocks, [2, 2, 2, 4], g2_map),
        'RepVGG-B1g4': (b_blocks, [2, 2, 2, 4], g4_map),
        'RepVGG-B2': (b_blocks, [2.5, 2.5, 2.5, 5], None),
        'RepVGG-B2g2': (b_blocks, [2.5, 2.5, 2.5, 5], g2_map),
        'RepVGG-B2g4': (b_blocks, [2.5, 2.5, 2.5, 5], g4_map),
        'RepVGG-B3': (b_blocks, [3, 3, 3, 5], None),
        'RepVGG-B3g2': (b_blocks, [3, 3, 3, 5], g2_map),
        'RepVGG-B3g4': (b_blocks, [3, 3, 3, 5], g4_map),
    }
    try:
        num_blocks, width_multiplier, groups_map = variants[name]
    except KeyError:
        raise KeyError('Unknown RepVGG variant %r; expected one of: %s'
                       % (name, ', '.join(variants))) from None
    return make_RepVGG(list(num_blocks), numclass, width_multiplier, groups_map, deploy)

In [6]:
# Build RepVGG-B1g2 twice for 10 classes: once with the multi-branch blocks
# (deploy=False) and once with the single-convolution blocks (deploy=True).
model = get_RepVGG_func_by_name('RepVGG-B1g2', numclass=10, deploy=False)
deploy_model = get_RepVGG_func_by_name('RepVGG-B1g2', numclass=10, deploy=True)

In [7]:
# Load CIFAR-10 and divide the image arrays by 255.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train / 255
x_test = x_test / 255

In [8]:
# Compile both networks identically so their evaluation results are comparable.
model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=keras.metrics.sparse_categorical_accuracy,
)
deploy_model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=keras.metrics.sparse_categorical_accuracy,
)

In [9]:
# Train only the multi-branch model (one epoch, batches of 64).
model.fit(x=x_train, y=y_train, epochs=1, batch_size=64)



<tensorflow.python.keras.callbacks.History at 0x1dae24a05c8>

In [10]:
# Transfer the trained weights into the deploy-mode network.
repvgg_model_convert(model=model, deploy_model=deploy_model)

In [11]:
# Evaluate both networks on the test split; only the last expression
# (the deploy model's result) is shown as the cell output.
model.evaluate(x=x_test, y=y_test)
deploy_model.evaluate(x=x_test, y=y_test)



[1.4816360473632812, 0.4796999990940094]