### Manually edited model

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init



class hswish(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``.

    ``x + 3`` allocates a fresh tensor, so the in-place ReLU6 never touches the
    caller's input.
    """

    def forward(self, x):
        gate = F.relu6(x + 3, inplace=True)
        return x * gate / 6


class hsigmoid(nn.Module):
    """Hard-sigmoid activation: ``relu6(x + 3) / 6``, mapping inputs into [0, 1].

    The in-place ReLU6 operates on the temporary ``x + 3``, not on ``x``.
    """

    def forward(self, x):
        shifted = x + 3
        return F.relu6(shifted, inplace=True) / 6


class SeModule(nn.Module):
    """Squeeze-and-excitation gate.

    Global-average-pools the input, passes it through a 1x1 conv bottleneck
    (``in_size -> in_size // reduction -> in_size``, each followed by BatchNorm)
    and a hard-sigmoid, then rescales the input channel-wise by that gate.

    The layers live in ``self.se`` at indices 0..6; those indices are the
    state_dict keys, so they must stay in this order.
    """

    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        squeezed = in_size // reduction
        layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_size, squeezed, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(squeezed),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, in_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size),
            hsigmoid(),
        ]
        self.se = nn.Sequential(*layers)

    def forward(self, x):
        gate = self.se(x)
        return x * gate


class Block(nn.Module):
    '''expand + depthwise + pointwise

    Args:
        kernel_size: kernel size of the depthwise conv; padding is
            ``kernel_size // 2``, so odd kernels preserve spatial size at stride 1.
        in_size: number of input channels.
        expand_size: channels after the 1x1 expansion (also the depthwise width).
        out_size: channels after the 1x1 projection.
        nolinear: activation module applied after the expansion and after the
            depthwise conv (the same instance is used at both positions).
        semodule: module applied to the projected output, or ``None`` to skip it.
        stride: stride of the depthwise conv. A residual shortcut is added only
            when ``stride == 1``.
    '''
    def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
        super(Block, self).__init__()
        self.stride = stride
        self.se = semodule

        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.nolinear1 = nolinear
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride, padding=kernel_size//2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.nolinear2 = nolinear
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)

        # Residual path (only used when stride == 1): identity when the channel
        # counts already match, otherwise a 1x1 conv + BN projecting in -> out.
        self.shortcut = nn.Sequential()
        if stride == 1 and in_size != out_size:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_size),
            )

    def forward(self, x):
        out = self.nolinear1(self.bn1(self.conv1(x)))
        out = self.nolinear2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # Compare by identity: None is a singleton and nn.Module defines no __eq__.
        if self.se is not None:
            out = self.se(out)
        # A strided block changes spatial size, so it has no residual connection.
        if self.stride == 1:
            out = out + self.shortcut(x)
        return out



Add 1x1 feature transforms to the student network for knowledge distillation.

In [19]:
class MobileNetV3_Small(nn.Module):
    """MobileNetV3-Small classifier with extra 1x1 feature-transform heads.

    ``transform0``..``transform4`` are ``feature_transform`` heads
    (16->64, 16->128, 24->256, 48->512, 96->512). Their input widths match the
    outputs of the stem, block0, block2, block6 and block10, and they are
    declared next to those layers, presumably as distillation adapters for
    those activations. They are registered submodules, so they are initialised
    and saved with the network. ``forward`` does not call them.

    Requires ``hswish``, ``Block``, ``SeModule`` and ``feature_transform`` to be
    defined at module level before the class is instantiated.
    """

    def __init__(self, num_classes=1000):
        super(MobileNetV3_Small, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = hswish()
        self.transform0 = feature_transform(16, 64)
        self.block0 = Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2)
        self.transform1 = feature_transform(16, 128)
        self.block1 = Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2)
        self.block2 = Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1)
        self.transform2 = feature_transform(24, 256)
        self.block3 = Block(5, 24, 96, 40, hswish(), SeModule(40), 2)
        self.block4 = Block(5, 40, 240, 40, hswish(), SeModule(40), 1)
        self.block5 = Block(5, 40, 240, 40, hswish(), SeModule(40), 1)
        self.block6 = Block(5, 40, 120, 48, hswish(), SeModule(48), 1)
        self.transform3 = feature_transform(48,512)
        self.block7 = Block(5, 48, 144, 48, hswish(), SeModule(48), 1)
        self.block8 = Block(5, 48, 288, 96, hswish(), SeModule(96), 2)
        self.block9 = Block(5, 96, 576, 96, hswish(), SeModule(96), 1)
        self.block10 = Block(5, 96, 576, 96, hswish(), SeModule(96), 1)
        self.transform4 = feature_transform(96, 512)


        self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(576)
        self.hs2 = hswish()
        self.linear3 = nn.Linear(576, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

    def init_params(self):
        """Initialise weights in place.

        Conv2d weights use Kaiming-normal (fan_out) and Linear weights use
        N(0, 0.001). All conv/linear biases are zeroed. BatchNorm2d gets
        weight 1 and bias 0. BatchNorm1d (``bn3``) is not matched here and keeps
        PyTorch's default initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)``.

        ``x`` must have 3 channels (``conv1`` expects 3). The trunk has five
        stride-2 stages (conv1, block0, block1, block3, block8). The fixed
        ``avg_pool2d(out, 7)`` therefore collapses the map to 1x1 only for a
        224x224 input (presumably the intended size). For other sizes the
        flattened width is not 576 and ``linear3`` fails. ``bn3`` is a
        BatchNorm1d, so a batch of size 1 raises in training mode.
        """
        out = self.hs1(self.bn1(self.conv1(x)))
        # Run the inverted-residual blocks in order through the registered submodules.
        for block in (
            self.block0, self.block1, self.block2, self.block3,
            self.block4, self.block5, self.block6, self.block7,
            self.block8, self.block9, self.block10,
        ):
            out = block(out)
        out = self.hs2(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, 7)
        out = out.view(out.size(0), -1)
        out = self.hs3(self.bn3(self.linear3(out)))
        out = self.linear4(out)
        return out

def feature_transform(inp, oup):
    """Build a 1x1 projection head: ``Conv2d(inp -> oup, with bias) -> ReLU``.

    A 1x1 kernel needs no padding, so the spatial size is preserved; only the
    channel count changes from ``inp`` to ``oup``.
    """
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size=1),
        nn.ReLU(inplace=True),
    )

In [20]:
# Instantiate the hand-built student model (default 1000 classes) and print its module tree.
net = MobileNetV3_Small()
print(net)

MobileNetV3_Small(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (hs1): hswish()
  (transform0): Sequential(
    (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU(inplace=True)
  )
  (block0): Block(
    (se): SeModule(
      (se): Sequential(
        (0): AdaptiveAvgPool2d(output_size=1)
        (1): Conv2d(16, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): Conv2d(4, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): hsigmoid()
      )
    )
    (conv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_s

### Pre-trained model

In [128]:
# Build the small MobileNetV3 from the project-local model module and load pretrained weights on CPU.
from mobilenetv3_model_test import mobilenetv3
import mobilenetv3_model_test  # NOTE(review): the module object is not used in the visible cells
net_small = mobilenetv3(mode='small')
# NOTE(review): torch.load unpickles the checkpoint; only load files from a trusted source.
state_dict = torch.load('../mobilenetv3_small_67.4.pth.tar', map_location=torch.device('cpu'))
net_small.load_state_dict(state_dict)

<All keys matched successfully>

In [129]:
# Insert 1x1 feature-transform heads at five tap points of the pretrained network:
# features[0] (16 -> 64 ch), features[1] (16 -> 128), features[3] (24 -> 256),
# features[7] (48 -> 512), features[11] (96 -> 512).
# NOTE(review): assigning into `_modules` either replaces the layer stored under that key
# or, if the key is new, appends it. If these containers are nn.Sequential (presumably),
# the transform then runs in the forward pass, so the next block receives the widened
# output (e.g. 64 channels instead of 16). Confirm mobilenetv3_model_test accounts for
# this; otherwise keep the transforms outside the network and apply them to hooked outputs.
net_small.features[0]._modules['3'] = feature_transform(16, 64)
net_small.features[1].conv._modules['9'] = feature_transform(16, 128)
net_small.features[3].conv._modules['9'] = feature_transform(24, 256)
net_small.features[7].conv._modules['9'] = feature_transform(48, 512)
net_small.features[11].conv._modules['9'] = feature_transform(96, 512)

In [130]:
def regist_hook(net):
    """Register forward hooks that collect the tapped feature maps of ``net``.

    Hooks are attached to ``net.features[0]._modules['3']`` and to
    ``net.features[i].conv._modules['9']`` for i in 1, 3, 7, 11. Tap points
    whose index is not present in ``net.features`` are skipped. Each hooked
    module's output is detached and appended to ``net.extraction``, which is
    reset to ``[]`` here.

    Hooks fire in forward-execution order, so ``net.extraction`` follows the
    forward pass, not registration order. The list is never cleared by the
    hooks: reset ``net.extraction = []`` before each forward pass, or it keeps
    growing. Calling this function twice registers duplicate hooks.

    Returns:
        list of hook handles (one per registered hook). Call ``.remove()`` on
        them to detach the hooks.
    """
    net.extraction = []

    def get(module, input, output):
        # Called by PyTorch after each forward of a hooked module; detach so the
        # stored features do not keep the autograd graph alive.
        net.extraction.append(output.detach())

    handles = []
    for name, module in net._modules['features']._modules.items():
        if name == '0':
            target = module._modules['3']
        elif name in ('1', '3', '7', '11'):
            target = module.conv._modules['9']
        else:
            continue
        handles.append(target.register_forward_hook(get))
    return handles

In [131]:
regist_hook(net_small)  # from now on each forward appends the tapped features to net_small.extraction