In [9]:
import torch 
import torch.nn as nn
from torchsummary import summary

In [21]:
class MobileNetV2Block(nn.Module):
    def __init__(self, in_channels, out_channels, expansion_factor=6, stride=1, **kwargs):
        super(MobileNetV2Block, self).__init__()
        self.expansion_factor = expansion_factor
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.expansion_channels = in_channels * expansion_factor
        
        self.c1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.expansion_channels,
            kernel_size=1,
            bias=False
        )

        self.act1 = nn.ReLU6()

        self.bn1 = nn.BatchNorm2d(self.expansion_channels)

        self.d1 = nn.Conv2d(
            in_channels=self.expansion_channels,
            out_channels=self.expansion_channels,
            kernel_size=3,
            stride=self.stride,
            padding=1,
            groups=self.expansion_channels,
            bias=False
        )

        self.act2 = nn.ReLU6()

        self.bn2 = nn.BatchNorm2d(self.expansion_channels)

        self.c2= nn.Conv2d(
            in_channels=self.expansion_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=False
        )

        self.bn3 = nn.BatchNorm2d(self.out_channels)

        self.residual = (self.stride == 1 and self.in_channels == self.out_channels)

    def forward(self, x):
        out = x
        x = self.c1(x)
        x = self.act1(x)
        x = self.bn1(x)
        x = self.d1(x)
        x = self.act2(x)
        x = self.bn2(x)
        x = self.c2(x)
        x = self.bn3(x)

        if self.residual:
            x += out

        return x                

In [22]:
block = MobileNetV2Block(16,24)
summary(block, (16,224,224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 96, 224, 224]           1,536
             ReLU6-2         [-1, 96, 224, 224]               0
       BatchNorm2d-3         [-1, 96, 224, 224]             192
            Conv2d-4         [-1, 96, 224, 224]             864
             ReLU6-5         [-1, 96, 224, 224]               0
       BatchNorm2d-6         [-1, 96, 224, 224]             192
            Conv2d-7         [-1, 24, 224, 224]           2,304
       BatchNorm2d-8         [-1, 24, 224, 224]              48
Total params: 5,136
Trainable params: 5,136
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.06
Forward/backward pass size (MB): 238.88
Params size (MB): 0.02
Estimated Total Size (MB): 241.96
----------------------------------------------------------------


In [24]:
class MobileNetV2Layer(nn.Module):
    def __init__(self, in_channels, out_channels, n, **kwargs):
        super(MobileNetV2Layer, self).__init__()
        layers = [MobileNetV2Block(in_channels, out_channels, **kwargs)] + [MobileNetV2Block(out_channels, out_channels, **kwargs) for _ in range(n-1)]

        self.layer = nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer(x)
        return x        

In [26]:
layer = MobileNetV2Layer(16,24,4, expansion_factor=6,stride=2)
summary(layer, (16,224,224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 96, 224, 224]           1,536
             ReLU6-2         [-1, 96, 224, 224]               0
       BatchNorm2d-3         [-1, 96, 224, 224]             192
            Conv2d-4         [-1, 96, 112, 112]             864
             ReLU6-5         [-1, 96, 112, 112]               0
       BatchNorm2d-6         [-1, 96, 112, 112]             192
            Conv2d-7         [-1, 24, 112, 112]           2,304
       BatchNorm2d-8         [-1, 24, 112, 112]              48
  MobileNetV2Block-9         [-1, 24, 112, 112]               0
           Conv2d-10        [-1, 144, 112, 112]           3,456
            ReLU6-11        [-1, 144, 112, 112]               0
      BatchNorm2d-12        [-1, 144, 112, 112]             288
           Conv2d-13          [-1, 144, 56, 56]           1,296
            ReLU6-14          [-1, 144,

In [16]:
class MobileNetV2(nn.Module):
    def __init__(self, in_channels=3, 
                 classes=1000, 
                 expansion_factor_list = [-1,1,6,6,6,6,6,6], 
                 out_channel_list = [32,16,24,32,64,96,160,320,1280], 
                 layer_list = [1,1,2,3,4,3,3,1], 
                 stride_list = [2,1,2,2,2,1,2,1]):
        
        super(MobileNetV2, self).__init__()
        
        self.c1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channel_list[0],
            kernel_size=3,
            stride=stride_list[0],
            padding=1,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channel_list[0])
        self.act1 = nn.ReLU6()
        
        layers = []
        for i in range(1, len(expansion_factor_list)):
            expansion_factor = expansion_factor_list[i]
            in_channels = out_channel_list[i-1]
            out_channels = out_channel_list[i]
            n = layer_list[i]
            stride = stride_list[i]

            layers.append(MobileNetV2Layer(in_channels, 
                                           out_channels, 
                                           n,
                                           expansion_factor=expansion_factor,
                                           stride=stride))

        self.model = nn.Sequential(*layers)

        self.c2 = nn.Conv2d(
            in_channels=out_channel_list[-2],
            out_channels=out_channel_list[-1],
            kernel_size=1,
            stride=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channel_list[-1])
        self.act2 = nn.ReLU6()

        self.gap = nn.AdaptiveAvgPool2d(1)

        self.c3 = nn.Conv2d(
            in_channels=out_channel_list[-1],
            out_channels=classes,
            kernel_size=1,
            bias=False
        )

    def forward(self, x):
        x = self.c1(x)
        x = self.act1(x)
        x = self.bn1(x)

        x = self.model(x)

        x = self.c2(x)
        x = self.act2(x)
        x = self.bn2(x)

        x = self.gap(x)

        x = self.c3(x)
        

In [27]:
model = MobileNetV2()
summary(model,(3,224,224),device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 112, 112]             864
             ReLU6-2         [-1, 32, 112, 112]               0
       BatchNorm2d-3         [-1, 32, 112, 112]              64
            Conv2d-4         [-1, 32, 112, 112]           1,024
             ReLU6-5         [-1, 32, 112, 112]               0
       BatchNorm2d-6         [-1, 32, 112, 112]              64
            Conv2d-7         [-1, 32, 112, 112]             288
             ReLU6-8         [-1, 32, 112, 112]               0
       BatchNorm2d-9         [-1, 32, 112, 112]              64
           Conv2d-10         [-1, 16, 112, 112]             512
      BatchNorm2d-11         [-1, 16, 112, 112]              32
 MobileNetV2Block-12         [-1, 16, 112, 112]               0
 MobileNetV2Layer-13         [-1, 16, 112, 112]               0
           Conv2d-14         [-1, 96, 1