In [None]:
import torch
import torch.nn as nn
import torchvision.datasets as dsets
from torchvision import transforms

from PIL import Image
import numpy as np

In [None]:
# Run configuration. None of these names is referenced by the cells shown here;
# presumably they are consumed by later data-loading / training cells — TODO confirm.
epochs = 1  # presumably the number of passes over the training data
batch_size = 5  # presumably the number of samples per mini-batch
img_size = 299  # presumably the square input side length in pixels — verify against the transforms

In [None]:
class ConvBlock(nn.Module):
    """Convolution followed by batch normalisation and an in-place ReLU.

    Args:
        in_channel: number of channels in the input tensor.
        out_channel: number of channels produced by the convolution.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (``kernel_size``,
            ``stride``, ``padding``, ...). ``bias`` must not be supplied here
            because it is already fixed to ``False``.
    """

    def __init__(self, in_channel, out_channel, **kwargs):
        super().__init__()

        # The convolution carries no bias term; it is followed directly by
        # BatchNorm2d over the same number of channels.
        layers = [
            nn.Conv2d(in_channel, out_channel, bias=False, **kwargs),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through conv -> batch norm -> ReLU and return the result."""
        activated = self.conv(x)
        return activated

In [None]:
class InceptionA(nn.Module):
    """Four-branch block whose branches all keep the input's spatial size.

    Every branch uses stride 1 with size-preserving padding, so the branch
    outputs can be concatenated along dimension 1. The result has
    ``96 + 64 + pool_feat + 64`` channels.

    Args:
        in_channel: number of channels in the incoming feature map.
        pool_feat: number of channels produced by the pooling branch.
    """

    def __init__(self, in_channel, pool_feat):
        super().__init__()

        # 1x1 reduction followed by two stacked 3x3 convolutions -> 96 channels.
        stacked_3x3 = [
            ConvBlock(in_channel, 64, kernel_size=1),
            ConvBlock(64, 96, kernel_size=3, padding=1),
            ConvBlock(96, 96, kernel_size=3, padding=1),
        ]
        self.branch1 = nn.Sequential(*stacked_3x3)

        # 1x1 reduction followed by a single 5x5 convolution -> 64 channels.
        single_5x5 = [
            ConvBlock(in_channel, 48, kernel_size=1),
            ConvBlock(48, 64, kernel_size=5, padding=2),
        ]
        self.branch2 = nn.Sequential(*single_5x5)

        # 3x3 average pooling, then a 3x3 convolution -> pool_feat channels.
        pooled = [
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(in_channel, pool_feat, kernel_size=3, padding=1),
        ]
        self.branch3 = nn.Sequential(*pooled)

        # Plain 1x1 projection -> 64 channels.
        self.branch4 = ConvBlock(in_channel, 64, kernel_size=1)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)


In [None]:
class InceptionB(nn.Module):
    """Three-branch block whose branches all end in a stride-2, unpadded 3x3 step.

    Because every branch finishes with the same kernel size, stride and
    (absent) padding, the outputs share a spatial size and are concatenated on
    dimension 1. The result has ``96 + in_channel + 384`` channels.

    Args:
        in_channel: number of channels in the incoming feature map.
    """

    def __init__(self, in_channel):
        super().__init__()

        # 1x1 -> 3x3 (padded) -> 3x3 stride 2, ending with 96 channels.
        conv_path = [
            ConvBlock(in_channel, 64, kernel_size=1),
            ConvBlock(64, 96, kernel_size=3, padding=1),
            ConvBlock(96, 96, kernel_size=3, stride=2),
        ]
        self.branch1 = nn.Sequential(*conv_path)

        # Max pooling keeps the channel count equal to in_channel.
        self.branch2 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Single strided 3x3 convolution to 384 channels.
        self.branch3 = ConvBlock(in_channel, 384, kernel_size=3, stride=2)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)

In [None]:
class InceptionC(nn.Module):
    """Four-branch block built from factorised 7x1 / 1x7 convolutions.

    All branches keep the spatial size and each emits 192 channels, so the
    concatenated output always has ``4 * 192 = 768`` channels.

    Args:
        in_channel: number of channels in the incoming feature map.
        c: width of the intermediate layers in the two factorised branches.
    """

    def __init__(self, in_channel, c):
        super().__init__()

        # Shared settings for the two asymmetric kernels; the padding keeps
        # height and width unchanged.
        tall = dict(kernel_size=(7, 1), padding=(3, 0))
        wide = dict(kernel_size=(1, 7), padding=(0, 3))

        # 1x1 reduction followed by two (7x1, 1x7) pairs.
        double_factorised = [
            ConvBlock(in_channel, c, kernel_size=1),
            ConvBlock(c, c, **tall),
            ConvBlock(c, c, **wide),
            ConvBlock(c, c, **tall),
            ConvBlock(c, 192, **wide),
        ]
        self.branch1 = nn.Sequential(*double_factorised)

        # 1x1 reduction followed by a single (7x1, 1x7) pair.
        single_factorised = [
            ConvBlock(in_channel, c, kernel_size=1),
            ConvBlock(c, c, **tall),
            ConvBlock(c, 192, **wide),
        ]
        self.branch2 = nn.Sequential(*single_factorised)

        # Average pooling followed by a 1x1 projection.
        pooled = [
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(in_channel, 192, kernel_size=1),
        ]
        self.branch3 = nn.Sequential(*pooled)

        # Plain 1x1 projection.
        self.branch4 = ConvBlock(in_channel, 192, kernel_size=1)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)

In [None]:
class InceptionD(nn.Module):
    """Three-branch block whose branches all end in a stride-2, unpadded 3x3 step.

    The concatenated output has ``192 + 320 + in_channel`` channels.

    Args:
        in_channel: number of channels in the incoming feature map.
    """

    def __init__(self, in_channel):
        super().__init__()

        # 1x1 -> 1x7 -> 7x1 (all size-preserving) -> strided 3x3, 192 channels.
        factorised_path = [
            ConvBlock(in_channel, 192, kernel_size=1),
            ConvBlock(192, 192, kernel_size=(1, 7), padding=(0, 3)),
            ConvBlock(192, 192, kernel_size=(7, 1), padding=(3, 0)),
            ConvBlock(192, 192, kernel_size=3, stride=2),
        ]
        self.b1 = nn.Sequential(*factorised_path)

        # 1x1 -> strided 3x3, 320 channels.
        short_path = [
            ConvBlock(in_channel, 192, kernel_size=1),
            ConvBlock(192, 320, kernel_size=3, stride=2),
        ]
        self.b2 = nn.Sequential(*short_path)

        # Average pooling keeps the channel count equal to in_channel.
        self.b3 = nn.AvgPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results on dim 1."""
        branches = (self.b1, self.b2, self.b3)
        return torch.cat([branch(x) for branch in branches], 1)

In [None]:
class InceptionE(nn.Module):
    """Four-branch block in which two branches fan out into 1x3 / 3x1 pairs.

    All layers keep the spatial size. The concatenated output has
    ``(384 + 384) + (384 + 384) + 192 + 320 = 2048`` channels.

    Args:
        in_channel: number of channels in the incoming feature map.
    """

    def __init__(self, in_channel):
        super().__init__()

        # Shared settings for the two asymmetric kernels; the padding keeps
        # height and width unchanged.
        row = dict(kernel_size=(1, 3), padding=(0, 1))
        col = dict(kernel_size=(3, 1), padding=(1, 0))

        # Branch 1: 1x1 -> 3x3 stem, then parallel 1x3 and 3x1 heads.
        self.b1_1 = ConvBlock(in_channel, 448, kernel_size=1)
        self.b1_2 = ConvBlock(448, 384, kernel_size=3, padding=1)
        self.b1_2_1 = ConvBlock(384, 384, **row)
        self.b1_2_2 = ConvBlock(384, 384, **col)

        # Branch 2: 1x1 stem, then parallel 1x3 and 3x1 heads.
        self.b2 = ConvBlock(in_channel, 384, kernel_size=1)
        self.b2_1 = ConvBlock(384, 384, **row)
        self.b2_2 = ConvBlock(384, 384, **col)

        # Branch 3: average pooling followed by a 1x1 projection.
        pooled = [
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(in_channel, 192, kernel_size=1),
        ]
        self.b3 = nn.Sequential(*pooled)

        # Branch 4: plain 1x1 projection.
        self.b4 = ConvBlock(in_channel, 320, kernel_size=1)

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results on dim 1."""
        # Branch 1: both heads read the same stem output.
        deep_stem = self.b1_2(self.b1_1(x))
        deep = torch.cat([self.b1_2_1(deep_stem), self.b1_2_2(deep_stem)], 1)

        # Branch 2: same fan-out on a shallower stem.
        shallow_stem = self.b2(x)
        shallow = torch.cat([self.b2_1(shallow_stem), self.b2_2(shallow_stem)], 1)

        pooled = self.b3(x)
        projected = self.b4(x)

        return torch.cat([deep, shallow, pooled, projected], 1)

In [None]:
class InceptionAux(nn.Module):
    """Auxiliary classification head operating on an intermediate feature map.

    Pipeline: 5x5 average pooling with stride 3 -> 1x1 ConvBlock (128 channels)
    -> unpadded 5x5 ConvBlock (768 channels) -> global average pooling ->
    flatten -> linear layer producing ``classes`` outputs.

    Because of the stride-3 pooling followed by an unpadded 5x5 convolution,
    the input must be at least 17x17 spatially.

    Args:
        in_channel: number of channels in the incoming feature map.
        classes: number of outputs of the final linear layer.
    """

    def __init__(self, in_channel, classes):

        super(InceptionAux, self).__init__()

        self.l1 = ConvBlock(in_channel, 128, kernel_size=1)
        self.l2 = ConvBlock(128, 768, kernel_size=5)
        self.l3 = nn.Linear(768, classes)

    def forward(self, x):
        """Return a ``(batch, classes)`` tensor computed from feature map ``x``."""
        # Use the functional pooling ops: calling nn.AvgPool2d(x, ...) or
        # nn.AdaptiveAvgPool2d(x, ...) invokes the module *constructors* with
        # the tensor as an argument and raises TypeError instead of pooling.
        x = nn.functional.avg_pool2d(x, kernel_size=5, stride=3)
        x = self.l1(x)
        x = self.l2(x)
        # Collapse whatever spatial extent remains to 1x1 so that the flatten
        # below always yields 768 features for the linear layer.
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.l3(x)

        return x

In [None]:
class InceptionV3(nn.Module):
    """Inception-v3 style classifier assembled from the Inception A-E blocks.

    Layout: convolutional stem -> 3x InceptionA -> InceptionB -> 4x InceptionC
    (-> optional auxiliary head) -> InceptionD -> 2x InceptionE -> global
    average pooling -> dropout -> linear classifier.

    Args:
        in_channel: number of channels of the input images.
        classes: number of outputs of the classifier (and of the auxiliary
            head when enabled).
        aux: when True, an ``InceptionAux`` head is attached after the last
            InceptionC block and evaluated while the module is in training
            mode.

    ``forward`` returns a tuple ``(logits, auxres)``. ``auxres`` is ``None``
    unless ``aux`` is True and the module is in training mode.
    """

    def __init__(self, in_channel=3, classes=10, aux=False):

        super(InceptionV3, self).__init__()

        # forward() branches on this flag, so it has to be stored on the module.
        self.aux = aux

        # Stem: 3 -> 32 -> 32 -> 64 -> (pool) -> 80 -> 192 -> (pool).
        self.conv1_3x3 = ConvBlock(in_channel, 32, kernel_size=3, stride=2, padding=1)
        self.conv2_3x3 = ConvBlock(32, 32, kernel_size=3, padding=1)
        # Input width must equal conv2_3x3's 32 output channels; the layer
        # widens to the 64 channels that conv4_1x1 consumes.
        self.conv3_3x3 = ConvBlock(32, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(3, stride=2)
        self.conv4_1x1 = ConvBlock(64, 80, kernel_size=1)
        self.conv5_3x3 = ConvBlock(80, 192, kernel_size=3)
        self.pool2 = nn.MaxPool2d(3, stride=2)

        # InceptionA emits 96 + 64 + pool_feat + 64 channels: 256, 288, 288.
        self.a1 = InceptionA(192, 32)
        self.a2 = InceptionA(256, 64)
        self.a3 = InceptionA(288, 64)

        # InceptionB emits 96 + 288 + 384 = 768 channels.
        self.b1 = InceptionB(288)

        # InceptionC always emits 4 * 192 = 768 channels.
        self.c1 = InceptionC(768, 128)
        self.c2 = InceptionC(768, 160)
        self.c3 = InceptionC(768, 160)
        self.c4 = InceptionC(768, 192)

        if aux:
            self.inception_aux = InceptionAux(768, classes)

        # InceptionD emits 192 + 320 + 768 = 1280 channels.
        self.d1 = InceptionD(768)

        # InceptionE always emits 2048 channels.
        self.e1 = InceptionE(1280)
        self.e2 = InceptionE(2048)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout2d()
        self.linear = nn.Linear(2048, classes)

    def forward(self, x):
        """Run the network on ``x`` and return ``(logits, auxres)``."""

        auxres = None

        # Stem.
        x = self.conv1_3x3(x)
        x = self.conv2_3x3(x)
        x = self.conv3_3x3(x)
        x = self.pool1(x)
        x = self.conv4_1x1(x)
        x = self.conv5_3x3(x)
        x = self.pool2(x)

        x = self.a1(x)
        x = self.a2(x)
        x = self.a3(x)

        x = self.b1(x)

        x = self.c1(x)
        x = self.c2(x)
        x = self.c3(x)
        x = self.c4(x)

        # The auxiliary head only exists when aux=True and is only evaluated
        # in training mode; otherwise auxres stays None.
        if self.aux and self.training:
            auxres = self.inception_aux(x)

        x = self.d1(x)

        x = self.e1(x)
        x = self.e2(x)

        # Dropout is applied while the tensor is still (N, 2048, 1, 1).
        x = self.avgpool(x)
        x = self.dropout(x)
        x = torch.flatten(x, 1)
        x = self.linear(x)

        return x, auxres


In [None]:
# Build the network with the constructor defaults (in_channel=3, classes=10, aux=False).
model = InceptionV3()