In [1]:
import torch
from torch import nn
from repnet_arch import *

In [2]:
rep_blocks = [
    'RepConvBNBlock',
    'ACBlock',
    'RepVGGBlock',
    'DiverseBranchBlock',
    'ResRepBlock'
]

In [3]:
N, C, H, W = 1, 3, 512, 512
out_channels = [1, 3, 8, 16]
kernel_paddings = [(1, 0), (1, 1), (3,1), (3,0), (5,1), (5,2), (5,3), (5,4), (5,6)]
bias_list = [True, False]

comparison_dict = {t: {'convert': [], 'deploy': []} for t in rep_blocks}

for kernel, padding in kernel_paddings:
    for out_channel in out_channels:
        for bias in bias_list:
            x = torch.randn(N, C, H, W)
            for block_type in rep_blocks:
                if block_type == 'RepVGGBlock' and (kernel != 3 or \
                    padding != 1 or out_channel != C):  continue
                if block_type == 'DiverseBranchBlock' and \
                    (padding != kernel // 2 or \
                     (out_channel == 1 and out_channel != C)): continue
                if block_type == 'ResRepBlock' and \
                    (C != out_channel or kernel != 3 or padding != 1): continue

                nd_model = eval(block_type)(
                    in_channels=C,
                    out_channels=out_channel,
                    kernel_size=kernel,
                    padding=padding,
                    stride=1,
                    bias=bias,
                    activate='relu',
                    deploy=False).eval()
                
                nd_result = nd_model(x)
                nd_model.switch_to_deploy()
                cd_result = nd_model(x)

                d_model = eval(block_type)(
                    in_channels=C,
                    out_channels=out_channel,
                    kernel_size=kernel,
                    padding=padding,
                    stride=1,
                    bias=bias,
                    activate='relu',
                    deploy=True).eval()
                d_model.load_state_dict(nd_model.state_dict(), strict=True)
                d_result = d_model(x)

                comparison_dict[block_type]['convert'].append(
                    ((nd_result - cd_result) ** 2).sum().view(-1))
                comparison_dict[block_type]['deploy'].append(
                    ((d_result - cd_result) ** 2).sum().view(-1))

for block_type in rep_blocks:
    print(f"{block_type}:")
    convert_data = torch.cat(comparison_dict[block_type]['convert'])
    print('\tconvert:', torch.min(convert_data), '~', torch.max(convert_data),
          torch.mean(convert_data), '±', torch.std(convert_data))
    deploy_data = torch.cat(comparison_dict[block_type]['deploy'])
    print('\tdeploy  :', torch.min(deploy_data), '~', torch.max(deploy_data),
          torch.mean(deploy_data), '±', torch.std(deploy_data))

RepConvBNBlock:
	convert: tensor(7.0258e-11, grad_fn=<MinBackward1>) ~ tensor(3.5256e-08, grad_fn=<MaxBackward1>) tensor(9.9469e-09, grad_fn=<MeanBackward0>) ± tensor(1.1204e-08, grad_fn=<StdBackward0>)
	deploy  : tensor(0., grad_fn=<MinBackward1>) ~ tensor(0., grad_fn=<MaxBackward1>) tensor(0., grad_fn=<MeanBackward0>) ± tensor(0., grad_fn=<StdBackward0>)
ACBlock:
	convert: tensor(4.6152e-10, grad_fn=<MinBackward1>) ~ tensor(8.4251e-08, grad_fn=<MaxBackward1>) tensor(2.5009e-08, grad_fn=<MeanBackward0>) ± tensor(2.5993e-08, grad_fn=<StdBackward0>)
	deploy  : tensor(0., grad_fn=<MinBackward1>) ~ tensor(0., grad_fn=<MaxBackward1>) tensor(0., grad_fn=<MeanBackward0>) ± tensor(0., grad_fn=<StdBackward0>)
RepVGGBlock:
	convert: tensor(8.1415e-09, grad_fn=<MinBackward1>) ~ tensor(1.5122e-08, grad_fn=<MaxBackward1>) tensor(1.1632e-08, grad_fn=<MeanBackward0>) ± tensor(4.9359e-09, grad_fn=<StdBackward0>)
	deploy  : tensor(0., grad_fn=<MinBackward1>) ~ tensor(0., grad_fn=<MaxBackward1>) tensor