In [1]:
import torch
from torch import nn, Tensor
from torchinfo import summary

* MobileNetV3(2019)
* 논문링크 : https://arxiv.org/pdf/1905.02244
* 참조링크 : https://github.com/d-li14/mobilenetv3.pytorch/blob/master/mobilenetv3.py

In [2]:
def _make_divisible(v, divisor, min_value = None):
    if min_value is None:
        min_value = divisor
    
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

_make_divisible(32 // 4, 8)

8

<img src = 'https://velog.velcdn.com/images/pre_f_86/post/34a8cfb1-80cf-4b44-a989-bb4438932999/image.PNG'>

In [5]:
class h_relu6(nn.Module):
    '''
    h_relu = relu6(x + 3) / 6
    '''
    def __init__(self, ):
        super().__init__()
        self.relu = nn.ReLU6()
    def forward(self, x):
        return self.relu(x + 3) / 6

class h_swish(nn.Module):
    '''
    * mobilenetv3에서 제안한 h_swish
    * h_swish = x * h_relu6
    '''
    def __init__(self, ):
        super().__init__()
        self.sigmoid = h_relu6()
    def forward(self, x):
        return x * self.sigmoid(x)

hswish = h_swish()
hswish(torch.Tensor([1]))

tensor([0.6667])

<img src = 'https://img1.daumcdn.net/thumb/R1280x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdn%2FczppZ9%2Fbtq0A3bDlIZ%2FXRBEd9rp8rK9Tksnx6vUQ0%2Fimg.png'>

In [18]:
torch.rand((1,64,1,1)).view(1,-1).size()

torch.Size([1, 64])

In [6]:
class SELayer(nn.Module):
    def __init__(self, in_channel, reduction = 4):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d((1,1))
        self.excitation = nn.Sequential(
            nn.Linear(in_channel, _make_divisible(in_channel // reduction, 8)),
            nn.ReLU(),
            nn.Linear(_make_divisible(in_channel // reduction, 8) ,in_channel),
            h_relu6()
        )
    
    def forward(self, x):
        b, c, _, _ = x.size()
        out = self.squeeze(x).view(b,-1)
        out = self.excitation(out).view(b,c,1,1)
        return x * out

torch.ones((1,2,3,3)) * torch.ones((1,2)).view(1,2,1,1) * 2

tensor([[[[2., 2., 2.],
          [2., 2., 2.],
          [2., 2., 2.]],

         [[2., 2., 2.],
          [2., 2., 2.],
          [2., 2., 2.]]]])

In [7]:
def conv_3x3_bn(in_channels, out_channels, stride):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1, bias = False),
        nn.BatchNorm2d(out_channels),
        h_swish()
    )

def conv_1x1_bn(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = 1, padding = 1, bias = False),
        nn.BatchNorm2d(out_channels),
        h_swish()
    )

<img src = 'https://velog.velcdn.com/images/pre_f_86/post/604af1a4-21c9-42ce-9d70-8bf3c74dc1d0/image.PNG'>

In [8]:
class InvertedResidual(nn.Module):
    def __init__(self, in_channels, inner_channels, out_channels, kernel_size, stride, use_se, use_hs):
        super().__init__()
        assert stride in [1,2]
        
        self.identity = stride == 1 and in_channels == out_channels
        
        if in_channels == inner_channels:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, inner_channels, kernel_size = kernel_size, stride = stride, padding = (kernel_size - 1) // 2, groups = inner_channels, bias = False),
                nn.BatchNorm2d(inner_channels),
                h_swish() if use_hs else nn.ReLU(),
                SELayer(inner_channels) if use_se else nn.Identity(),
                nn.Conv2d(inner_channels, out_channels, kernel_size = 1, stride = 1, padding = 0, bias = False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, inner_channels, kernel_size = 1, stride = 1, padding = 0, bias = False),
                nn.BatchNorm2d(inner_channels),
                h_swish() if use_hs else nn.ReLU(),
                nn.Conv2d(inner_channels, inner_channels, kernel_size = kernel_size, stride = stride, padding = (kernel_size - 1) // 2, groups = inner_channels, bias = False),
                nn.BatchNorm2d(inner_channels),
                SELayer(inner_channels) if use_se else nn.Identity(),
                h_swish() if use_hs else nn.ReLU(),
                nn.Conv2d(inner_channels, out_channels, kernel_size = 1, stride = 1, padding = 0, bias = False),
                nn.BatchNorm2d(out_channels)
            )
            
    def forward(self, x):
        if self.identity:
            return x + self.conv(x)
        else:
            return self.conv(x)

In [9]:
identity = nn.Identity()
identity(torch.Tensor((1,64,8,8)))

tensor([ 1., 64.,  8.,  8.])

<img src = 'https://velog.velcdn.com/images/pre_f_86/post/47890b15-1f1e-43f9-84d3-40892b08aee2/image.PNG'>

In [10]:
class MobileNetV3(nn.Module):
    '''
    width_mult = MobileNetV1의 alpha와 같은 역할
    '''
    def __init__(self, cfgs, mode, num_classes = 1000, width_mult = 1., block = InvertedResidual):
        super().__init__()
        self.cfgs = cfgs
        assert mode in ['large', 'small']
        
        inner_channels = _make_divisible(16 * width_mult, 8)
        layers = [conv_3x3_bn(3, inner_channels, stride = 2)]
        for k, t, c, use_se, use_hs, s in self.cfgs:
            '''
            k = kernel_size
            t = 확장계수
            c = out_channels
            use_se, use_hs = SELayer, H_Swish사용유무
            s = stride
            '''
            out_channels = _make_divisible(c * width_mult, 8)
            exp_size = _make_divisible(inner_channels * t, 8)
            layers.append(
                block(in_channels = inner_channels, 
                      inner_channels = exp_size,
                      out_channels = out_channels, 
                      kernel_size = k, stride = s, 
                      use_se = use_se, use_hs = use_hs)
            )
            inner_channels = out_channels
        self.block = nn.Sequential(*layers)
        self.last_conv = conv_1x1_bn(inner_channels, exp_size)
        self.avg = nn.AdaptiveAvgPool2d((1,1))
        out_channels = {'large' : 1280, 'small' : 1024}
        out_channels = _make_divisible(out_channels[mode] * width_mult, 8) if width_mult > 1. else out_channels[mode]
        self.fc = nn.Sequential(
            nn.Linear(exp_size, out_channels),
            h_swish(),
            nn.Dropout(0.2),
            nn.Linear(out_channels, num_classes)
        )
        self._init_layers()
    
    def _init_layers(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode = 'fan_out', nonlinearity = 'relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
    def forward(self,x):
        out = self.block(x)
        out = self.last_conv(out)
        out = self.avg(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

In [11]:
def mobilenetv3_large(**kwargs):
    cfgs = [
        # kernel_size, t, out_channels, SE, HS, stride
        [3,   1,  16, False, False, 1],
        [3,   4,  24, False, False, 2],
        [3,   3,  24, False, False, 1],
        [5,   3,  40, True, False, 2],
        [5,   3,  40, True, False, 1],
        [5,   3,  40, True, False, 1],
        [3,   6,  80, False, True, 2],
        [3, 2.5,  80, False, True, 1],
        [3, 2.3,  80, False, True, 1],
        [3, 2.3,  80, False, True, 1],
        [3,   6, 112, True, True, 1],
        [3,   6, 112, True, True, 1],
        [5,   6, 160, True, True, 2],
        [5,   6, 160, True, True, 1],
        [5,   6, 160, True, True, 1]
    ]
    return MobileNetV3(cfgs, mode = 'large', **kwargs)

In [12]:
def mobilenetv3_small(**kwargs):
    cfgs = [
        # kernel_size, t, out_channels, SE, HS, stride
        [3,    1,  16, True, False, 2],
        [3,  4.5,  24, False, False, 2],
        [3, 3.67,  24, False, False, 1],
        [5,    4,  40, True, True, 2],
        [5,    6,  40, True, True, 1],
        [5,    6,  40, True, True, 1],
        [5,    3,  48, True, True, 1],
        [5,    3,  48, True, True, 1],
        [5,    6,  96, True, True, 2],
        [5,    6,  96, True, True, 1],
        [5,    6,  96, True, True, 1]
    ]
    return MobileNetV3(cfgs, mode = 'small', **kwargs)

In [13]:
summary(mobilenetv3_large(), (1,3,224,224))

Layer (type:depth-idx)                             Output Shape              Param #
MobileNetV3                                        [1, 1000]                 --
├─Sequential: 1-1                                  [1, 160, 7, 7]            --
│    └─Sequential: 2-1                             [1, 16, 112, 112]         --
│    │    └─Conv2d: 3-1                            [1, 16, 112, 112]         432
│    │    └─BatchNorm2d: 3-2                       [1, 16, 112, 112]         32
│    │    └─h_swish: 3-3                           [1, 16, 112, 112]         --
│    └─InvertedResidual: 2-2                       [1, 16, 112, 112]         --
│    │    └─Sequential: 3-4                        [1, 16, 112, 112]         464
│    └─InvertedResidual: 2-3                       [1, 24, 56, 56]           --
│    │    └─Sequential: 3-5                        [1, 24, 56, 56]           3,440
│    └─InvertedResidual: 2-4                       [1, 24, 56, 56]           --
│    │    └─Sequential: 3-6   

In [73]:
from torchvision.models import mobilenet_v3_large
summary(mobilenet_v3_large(), (1,3,224,224))

Layer (type:depth-idx)                             Output Shape              Param #
MobileNetV3                                        [1, 1000]                 --
├─Sequential: 1-1                                  [1, 960, 7, 7]            --
│    └─Conv2dNormActivation: 2-1                   [1, 16, 112, 112]         --
│    │    └─Conv2d: 3-1                            [1, 16, 112, 112]         432
│    │    └─BatchNorm2d: 3-2                       [1, 16, 112, 112]         32
│    │    └─Hardswish: 3-3                         [1, 16, 112, 112]         --
│    └─InvertedResidual: 2-2                       [1, 16, 112, 112]         --
│    │    └─Sequential: 3-4                        [1, 16, 112, 112]         464
│    └─InvertedResidual: 2-3                       [1, 24, 56, 56]           --
│    │    └─Sequential: 3-5                        [1, 24, 56, 56]           3,440
│    └─InvertedResidual: 2-4                       [1, 24, 56, 56]           --
│    │    └─Sequential: 3-6   

In [75]:
summary(mobilenetv3_small(), (1,3,224,224))

Layer (type:depth-idx)                             Output Shape              Param #
MobileNetV3                                        [1, 1000]                 --
├─Sequential: 1-1                                  [1, 96, 7, 7]             --
│    └─Sequential: 2-1                             [1, 16, 112, 112]         --
│    │    └─Conv2d: 3-1                            [1, 16, 112, 112]         432
│    │    └─BatchNorm2d: 3-2                       [1, 16, 112, 112]         32
│    │    └─h_swish: 3-3                           [1, 16, 112, 112]         --
│    └─InvertedResidual: 2-2                       [1, 16, 56, 56]           --
│    │    └─Sequential: 3-4                        [1, 16, 56, 56]           744
│    └─InvertedResidual: 2-3                       [1, 24, 28, 28]           --
│    │    └─Sequential: 3-5                        [1, 24, 28, 28]           3,864
│    └─InvertedResidual: 2-4                       [1, 24, 28, 28]           --
│    │    └─Sequential: 3-6   

In [74]:
from torchvision.models import mobilenet_v3_small
summary(mobilenet_v3_small(), (1,3,224,224))

Layer (type:depth-idx)                             Output Shape              Param #
MobileNetV3                                        [1, 1000]                 --
├─Sequential: 1-1                                  [1, 576, 7, 7]            --
│    └─Conv2dNormActivation: 2-1                   [1, 16, 112, 112]         --
│    │    └─Conv2d: 3-1                            [1, 16, 112, 112]         432
│    │    └─BatchNorm2d: 3-2                       [1, 16, 112, 112]         32
│    │    └─Hardswish: 3-3                         [1, 16, 112, 112]         --
│    └─InvertedResidual: 2-2                       [1, 16, 56, 56]           --
│    │    └─Sequential: 3-4                        [1, 16, 56, 56]           744
│    └─InvertedResidual: 2-3                       [1, 24, 28, 28]           --
│    │    └─Sequential: 3-5                        [1, 24, 28, 28]           3,864
│    └─InvertedResidual: 2-4                       [1, 24, 28, 28]           --
│    │    └─Sequential: 3-6   