In [None]:
# @title <p>Essential Imports
import os, shutil, json
from PIL import Image
from zipfile import ZipFile
import matplotlib.pyplot as plt
import numpy as np, pandas as pd, random as rd
import warnings
# Install a catch-all "ignore" filter at the front of the warnings filter list,
# silencing every warning category for the rest of the session.
# NOTE(review): this also hides deprecation / numerical warnings raised by
# torch and numpy — consider narrowing it to specific categories.
warnings.simplefilter("ignore")

In [None]:
# @title <p>Torch Essential Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
# Seed PyTorch's global RNG so weight initialisation and torch.randn inputs
# are reproducible across runs.
torch.manual_seed(0)

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# @title <p> Model Architecture
class ConvBlock(nn.Module):
    """Conv2d followed by BatchNorm2d and ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (also the BatchNorm feature count).
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(self, in_ch, out_ch, **kwargs):
        super().__init__()
        # Registration order (conv, bn, relu) fixes both the state_dict key
        # order and the order in which seeded initial weights are drawn.
        self.conv = nn.Conv2d(in_ch, out_ch, **kwargs)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> batch-norm -> ReLU to ``x``."""
        return self.relu(self.bn(self.conv(x)))

class InceptionBlock(nn.Module):
    """Inception module: four parallel branches concatenated along channels.

    Args:
        in_ch: channels of the incoming feature map.
        out1x1: output channels of the plain 1x1 branch.
        red3x3: 1x1 reduction channels feeding the 3x3 convolution.
        out3x3: output channels of the 3x3 branch.
        red5x5: 1x1 reduction channels feeding the 5x5 convolution.
        out5x5: output channels of the 5x5 branch.
        maxpool1x1: output channels of the 1x1 projection after max-pooling.

    Every branch uses stride 1 with "same" padding, so the output keeps the
    input's H x W and has ``out1x1 + out3x3 + out5x5 + maxpool1x1`` channels.
    """

    def __init__(self, in_ch, out1x1, red3x3, out3x3, red5x5, out5x5, maxpool1x1):
        super().__init__()
        # Branch construction order is kept fixed: it determines the order in
        # which seeded initial weights are drawn.

        # 1x1 convolution only.
        self.branch1 = ConvBlock(in_ch, out1x1, kernel_size=1)

        # 1x1 channel reduction, then a 3x3 convolution (padding 1 keeps H x W).
        self.branch2 = nn.Sequential(
            ConvBlock(in_ch, red3x3, kernel_size=1),
            ConvBlock(red3x3, out3x3, kernel_size=3, padding=1),
        )

        # 1x1 channel reduction, then a 5x5 convolution (padding 2 keeps H x W).
        self.branch3 = nn.Sequential(
            ConvBlock(in_ch, red5x5, kernel_size=1),
            ConvBlock(red5x5, out5x5, kernel_size=5, padding=2),
        )

        # 3x3 stride-1 max-pool (padding 1 keeps H x W), then a 1x1 projection.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=(3, 3), padding=1, stride=1),
            ConvBlock(in_ch, maxpool1x1, kernel_size=1),
        )

    def forward(self, x):
        """Run all four branches on ``x`` and stack their outputs on dim 1."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)

class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    Pipeline: AvgPool2d(5, stride 3) -> 1x1 ConvBlock to 128 channels ->
    flatten -> Linear(2048, 1024) -> ReLU -> Dropout(0.7) ->
    Linear(1024, num_classes).

    Args:
        in_ch: channels of the feature map this head is attached to.
        num_classes: number of output logits.

    NOTE: ``fc1`` expects exactly 2048 = 128 * 4 * 4 flattened features, i.e.
    a 4x4 map after pooling. A 14x14 input gives (14 - 5) // 3 + 1 = 4;
    other spatial sizes will fail at ``fc1``.
    """

    def __init__(self, in_ch, num_classes):
        super().__init__()
        # Same registration order as before so seeded initialisation and
        # state_dict keys are unchanged.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.7)
        self.pool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = ConvBlock(in_ch, 128, kernel_size=(1, 1))
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return ``(batch, num_classes)`` auxiliary logits for feature map ``x``."""
        feats = self.conv(self.pool(x))
        flat = feats.reshape(feats.size(0), -1)
        hidden = self.dropout(self.relu(self.fc1(flat)))
        return self.fc2(hidden)

class InceptionModel(nn.Module):
    """GoogLeNet (Inception v1) image classifier.

    Args:
        aux_logits: if True, build the two auxiliary classifiers (after
            ``inception4a`` and ``inception4d``). In training mode their logits
            are returned alongside the main logits.
        num_classes: output size of the main and auxiliary classifiers.

    ``forward(x)`` returns ``(aux1, aux2, logits)`` when ``aux_logits`` is set
    and the module is in training mode; otherwise it returns only ``logits`` of
    shape ``(batch, num_classes)``.

    The final pooling is adaptive, so the main head is no longer tied to the
    7x7 map produced by 224x224 inputs. The auxiliary heads still hard-code
    2048 flattened features, so training with ``aux_logits=True`` still needs
    inputs that yield that size (224x224 does).
    """

    def __init__(self, aux_logits=True, num_classes=1000):
        super().__init__()
        self.aux_logits = aux_logits
        # Shared 3x3 / stride-2 max-pool between stages; with padding=1 it maps
        # H -> ceil(H / 2).
        self.maxpool = nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=1)

        # Stem. NOTE(review): the original paper also has a 1x1 "reduce" conv
        # before the 3x3 conv; it is omitted here.
        self.conv1 = ConvBlock(3, 64, kernel_size=(7, 7), stride=2, padding=3)
        self.conv2 = ConvBlock(64, 192, kernel_size=(3, 3), stride=1, padding=1)

        # Each InceptionBlock's in_ch equals the previous block's summed output
        # channels (out1x1 + out3x3 + out5x5 + maxpool1x1).
        self.inception3a = InceptionBlock(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = InceptionBlock(256, 128, 128, 192, 32, 96, 64)

        self.inception4a = InceptionBlock(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = InceptionBlock(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = InceptionBlock(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = InceptionBlock(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = InceptionBlock(528, 256, 160, 320, 32, 128, 128)

        self.inception5a = InceptionBlock(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = InceptionBlock(832, 384, 192, 384, 48, 128, 128)

        # Global average pooling to 1x1. For a 7x7 map (224x224 input) this
        # equals the former AvgPool2d(kernel_size=7, stride=1). It also keeps
        # working for other input sizes, where that fixed kernel produced more
        # than 1024 features and broke fc1. Parameter-free, so state_dict keys
        # are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=0.4)
        self.fc1 = nn.Linear(1024, num_classes)

        # Created last, as before, so seeded initialisation of all earlier
        # layers is unchanged.
        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)
        else:
            self.aux1 = self.aux2 = None

    def forward(self, x):
        """Classify ``x``; see the class docstring for the return convention."""
        # Auxiliary heads are evaluated only when their logits will be returned.
        use_aux = self.aux_logits and self.training

        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.maxpool(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool(x)

        x = self.inception4a(x)
        aux1 = self.aux1(x) if use_aux else None

        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        aux2 = self.aux2(x) if use_aux else None

        x = self.inception4e(x)
        x = self.maxpool(x)
        x = self.inception5a(x)
        x = self.inception5b(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc1(x)

        if use_aux:
            return aux1, aux2, x
        return x

# Shape smoke test. A freshly built module is in training mode, so with the
# default aux_logits=True it returns (aux1, aux2, logits). Gradients are not
# needed for a shape check, so no autograd graph is built. BatchNorm running
# statistics are still updated because the model is in training mode.
model = InceptionModel()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    aux1, aux2, out = model(x)
assert aux1.shape == aux2.shape == out.shape == (1, 1000)
out.shape

torch.Size([1, 1000])