In [9]:
from timm.models import create_model
import torch
import torch.nn as nn

In [10]:
# Read the training checkpoint with every tensor mapped onto the CPU.
ckpt = torch.load(
    './r192_a70_pth.tar',
    map_location=torch.device('cpu'),
)

In [11]:
def fold_parameter(conv_layer, bn_layer):
    """Fold a batch-norm layer's inference-time transform into a convolution, in place.

    With running statistics, batch norm computes
    ``gamma * (y - mean) / sqrt(var + eps) + beta`` per output channel. For
    ``y = W * x + b`` this equals a convolution with

        W' = W * gamma / sqrt(var + eps)
        b' = beta + (b - mean) * gamma / sqrt(var + eps)

    After the call ``conv_layer`` alone reproduces ``bn_layer(conv_layer(x))``
    as evaluated in eval mode; ``bn_layer`` itself is not modified.

    Args:
        conv_layer: convolution whose ``weight`` has the output channels on
            dim 0 (``nn.Conv1d/2d/3d`` and subclasses). Transposed
            convolutions keep output channels on dim 1 and are not supported.
        bn_layer: batch norm applied directly to the output of ``conv_layer``.
            It must track running statistics; ``affine=False`` is accepted.

    Raises:
        ValueError: if ``bn_layer`` has no running statistics to fold.
    """
    bn_mean = bn_layer.running_mean
    bn_var = bn_layer.running_var
    if bn_mean is None or bn_var is None:
        raise ValueError(
            "cannot fold a batch norm layer that has no running statistics"
        )

    with torch.no_grad():
        # affine=False batch norms carry no gamma/beta; treat them as 1 and 0.
        if bn_layer.weight is not None:
            bn_gamma = bn_layer.weight.data
        else:
            bn_gamma = torch.ones_like(bn_mean)
        if bn_layer.bias is not None:
            bn_beta = bn_layer.bias.data
        else:
            bn_beta = torch.zeros_like(bn_mean)

        fold_factor = bn_gamma / torch.sqrt(bn_var + bn_layer.eps)
        fold_bias = bn_beta - bn_mean * fold_factor
        # A bias already present on the convolution is scaled by batch norm as
        # well, so it has to be carried into the folded bias, not discarded.
        if conv_layer.bias is not None:
            fold_bias = fold_bias + conv_layer.bias.data * fold_factor

        # Broadcast the per-output-channel factor over the remaining weight
        # dims: (out, 1, 1, 1) for a 2-D convolution.
        broadcast_shape = [-1] + [1] * (conv_layer.weight.dim() - 1)
        conv_layer.weight.data = (
            conv_layer.weight.data * fold_factor.reshape(broadcast_shape)
        )
        conv_layer.bias = nn.Parameter(fold_bias)

In [12]:
# Unpickle the complete model object (architecture and weights).
# torch.load relies on pickle, so only open files from a trusted source.
# map_location pins the tensors to the CPU, consistent with how the checkpoint
# is loaded, so this cell also works on a machine without the device the model
# was saved from.
model = torch.load('./efn_r192.pth', map_location=torch.device('cpu'))

In [13]:
# Copy the checkpoint weights into the model; strict matching (the default)
# raises if any key is missing or unexpected. Kept as the last expression so
# the key-matching report is displayed as the cell output.
model.load_state_dict(ckpt['state_dict'], strict=True)

<All keys matched successfully>

### Fold all batch norm layers

In [14]:
# Fold every batch norm into the convolution that feeds it, then replace the
# batch norm with nn.Identity. Slots that already hold nn.Identity are skipped,
# so re-running this cell on a folded model is a no-op instead of failing
# inside fold_parameter (nn.Identity has no batch-norm statistics).
def _fold_and_strip(parent, conv_name, bn_name):
    """Fold ``parent.<bn_name>`` into ``parent.<conv_name>``, then swap the BN for Identity."""
    bn_layer = getattr(parent, bn_name)
    if isinstance(bn_layer, nn.Identity):
        return  # already folded by an earlier run of this cell
    fold_parameter(getattr(parent, conv_name), bn_layer)
    setattr(parent, bn_name, nn.Identity())


# (conv, bn) attribute pairs for the two block layouts handled here.
PW_DW_PWL_PAIRS = (('conv_pw', 'bn1'), ('conv_dw', 'bn2'), ('conv_pwl', 'bn3'))
DW_PW_PAIRS = (('conv_dw', 'bn1'), ('conv_pw', 'bn2'))

# fold conv_stem
_fold_and_strip(model, 'conv_stem', 'bn1')

# fold blocks
for stage in model.blocks:
    for block in stage:
        # A block exposing conv_pwl has three conv/bn pairs; any other block
        # is assumed to have the two-pair conv_dw -> conv_pw layout.
        if hasattr(block, 'conv_pwl'):
            layer_pairs = PW_DW_PWL_PAIRS
        else:
            layer_pairs = DW_PW_PAIRS
        for conv_name, bn_name in layer_pairs:
            _fold_and_strip(block, conv_name, bn_name)

# fold conv_head
_fold_and_strip(model, 'conv_head', 'bn2')

In [15]:
# Persist the whole folded module object (not only its state_dict).
torch.save(obj=model, f='efn-r192-fold.pth')

In [16]:
# Display the module tree; after folding, each bn slot should print as Identity()
# and each convolution should have a bias.
model

EfficientNet(
  (conv_stem): Conv2dSame(3, 32, kernel_size=(3, 3), stride=(2, 2))
  (bn1): Identity()
  (act1): ReLU6(inplace=True)
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
        (bn1): Identity()
        (act1): ReLU6(inplace=True)
        (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (act2): Identity()
      )
    )
    (1): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1))
        (bn1): Identity()
        (act1): ReLU6(inplace=True)
        (conv_dw): Conv2dSame(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96)
        (bn2): Identity()
        (act2): ReLU6(inplace=True)
        (conv_pwl): Conv2d(96, 16, kernel_size=(1, 1), stride=(1, 1))
        (bn3): Identity()
      )
      (1): InvertedResidual(
        (conv_pw): Conv2d(16, 96, 