In [1]:
import torch
import torch.nn as nn

In [26]:
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual (bottleneck) block.

    Structure: 1x1 expansion conv -> 3x3 depthwise conv -> 1x1 projection conv,
    each followed by BatchNorm. The expansion and depthwise stages use ReLU6
    (matching the MobileNetV2 design and the ReLU6 used in the network stem);
    the projection is intentionally left without an activation ("linear
    bottleneck").

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        expansion_rate: integer factor; the hidden width is
            ``in_channels * expansion_rate``.
        stride: stride of the depthwise conv (1 keeps the spatial size,
            2 halves it since padding=1 with a 3x3 kernel).

    The identity shortcut ``x + F(x)`` is added only when ``stride == 1`` and
    ``in_channels == out_channels``, i.e. when input and output shapes match.
    """

    def __init__(self, in_channels, out_channels, expansion_rate, stride=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.expansion_rate = expansion_rate
        self.stride = stride

        hidden = in_channels * expansion_rate

        # 1x1 pointwise conv widening the channel count by `expansion_rate`.
        self.expansion_layer = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU6())

        # 3x3 depthwise conv (groups == channels, one filter per channel);
        # this is the only layer in the block that applies `stride`.
        self.depthwise_layer = nn.Sequential(
            nn.Conv2d(hidden, hidden, kernel_size=3, groups=hidden,
                      stride=stride, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU6())

        # 1x1 projection to `out_channels`; no activation (linear bottleneck).
        self.pointwise_layer = nn.Sequential(
            nn.Conv2d(hidden, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels))

    def forward(self, x):
        """Apply the block, adding the identity shortcut when shapes allow."""
        out = self.pointwise_layer(self.depthwise_layer(self.expansion_layer(x)))
        if self.stride == 1 and self.in_channels == self.out_channels:
            return x + out
        return out

In [40]:
class MobileNetV2(nn.Module):
    """MobileNetV2 image classifier built from `InvertedResidual` blocks.

    Pipeline: stride-2 stem conv -> seven inverted-residual stages (bk1..bk7)
    -> 1x1 conv to 1280 channels with BN + ReLU6 -> global average pool ->
    dropout + linear classifier.

    Following Table 2 of the MobileNetV2 paper, a stage's stride is applied by
    its *first* block; the remaining blocks run at the already-reduced
    resolution with stride 1, so they keep their identity shortcuts.

    Args:
        in_channels: channels of the input image (e.g. 3 for RGB).
        num_classes: number of output logits.
        dropout: dropout probability before the final linear layer
            (default 0.5, as before).
    """

    def __init__(self, in_channels, num_classes, dropout=0.5):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes

        # Stem: 3x3 conv, stride 2.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6())

        # Stages as (in, out, expansion, blocks, stride). bk1 and bk7 are single
        # bare blocks and bk2..bk6 are nn.Sequential, so parameter names and
        # shapes match the previous layout; only where the stride sits changed.
        self.bk1 = InvertedResidual(32, 16, 1, 1)
        self.bk2 = self._make_stage(16, 24, 6, 2, 2)
        self.bk3 = self._make_stage(24, 32, 6, 3, 2)
        self.bk4 = self._make_stage(32, 64, 6, 4, 2)
        self.bk5 = self._make_stage(64, 96, 6, 3, 1)
        self.bk6 = self._make_stage(96, 160, 6, 3, 2)
        self.bk7 = InvertedResidual(160, 320, 6, 1)

        # Final 1x1 conv followed by BN + ReLU6. Without the nonlinearity, this
        # conv, the (linear) average pool and the linear classifier would
        # compose into a single linear map, wasting the 1280-channel layer.
        self.conv1x1 = nn.Conv2d(320, 1280, 1)
        self.conv1x1_bn = nn.BatchNorm2d(1280)
        self.conv1x1_act = nn.ReLU6()

        # Global average pool to 1x1. For 224x224 inputs the feature map here
        # is 7x7, so this equals the former AvgPool2d(7), but it also works for
        # other input resolutions.
        self.pool = nn.AdaptiveAvgPool2d(1)

        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(1280, num_classes)
        )

    @staticmethod
    def _make_stage(in_channels, out_channels, expansion_rate, num_blocks, stride):
        """Build one stage: the first block changes channels and applies `stride`; the rest use stride 1."""
        layers = [InvertedResidual(in_channels, out_channels, expansion_rate, stride)]
        layers += [InvertedResidual(out_channels, out_channels, expansion_rate, 1)
                   for _ in range(num_blocks - 1)]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits; expects a 4-D (N, C, H, W) input batch."""
        x = self.conv1(x)
        for stage in (self.bk1, self.bk2, self.bk3, self.bk4,
                      self.bk5, self.bk6, self.bk7):
            x = stage(x)
        x = self.conv1x1_act(self.conv1x1_bn(self.conv1x1(x)))
        x = torch.flatten(self.pool(x), 1)
        return self.classifier(x)

In [41]:
# Shape smoke test: a random 224x224 RGB batch should yield (batch, num_classes) logits.
torch.manual_seed(0)  # reproducible input and weight init
x = torch.randn(64, 3, 224, 224)
model = MobileNetV2(3, 10).eval()  # eval(): BatchNorm uses running stats, Dropout disabled
with torch.no_grad():  # only the shape is checked, so skip building an autograd graph
    out = model(x)
assert out.shape == (64, 10), out.shape
print(out.shape)

torch.Size([64, 10])
