In [1]:
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ConvBlock(nn.Module):
    """Basic conv unit: ``Conv2d -> BatchNorm2d -> ReLU``.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the convolution (and normalised by
            the batch-norm layer).
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, **kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        # Kept as a single Sequential named ``conv_layer`` so state_dict keys
        # stay ``conv_layer.0.*`` / ``conv_layer.1.*``.
        self.conv_layer = nn.Sequential(*stages)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU to ``x``."""
        activated = self.conv_layer(x)
        return activated

In [3]:
class InceptionBlock(nn.Module):
    """Inception module: four parallel branches concatenated along channels.

    Branches:
        1. 1x1 conv -> ``out_channels``
        2. 1x1 conv -> ``red_3x3``, then 3x3 conv (padding 1) -> ``out_3x3``
        3. 1x1 conv -> ``red_5x5``, then 5x5 conv (padding 2) -> ``out_5x5``
        4. 3x3 max-pool (stride 1, padding 1), then 1x1 conv -> ``out_1x1pool``

    Note: ``out_channels`` is only the width of branch 1. The block's total
    output width is ``out_channels + out_3x3 + out_5x5 + out_1x1pool``.
    """

    def __init__(self, in_channels, out_channels, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool):
        super().__init__()
        # Branch 1: plain 1x1 projection.
        self.branch_1 = ConvBlock(in_channels, out_channels, kernel_size=1)

        # Branch 2: 1x1 reduction feeding a 3x3 conv.
        reduce_3x3 = ConvBlock(in_channels, red_3x3, kernel_size=1)
        conv_3x3 = ConvBlock(red_3x3, out_3x3, kernel_size=3, padding=1)
        self.branch_2 = nn.Sequential(reduce_3x3, conv_3x3)

        # Branch 3: 1x1 reduction feeding a 5x5 conv.
        reduce_5x5 = ConvBlock(in_channels, red_5x5, kernel_size=1)
        conv_5x5 = ConvBlock(red_5x5, out_5x5, kernel_size=5, padding=2)
        self.branch_3 = nn.Sequential(reduce_5x5, conv_5x5)

        # Branch 4: size-preserving max-pool followed by a 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        pool_proj = ConvBlock(in_channels, out_1x1pool, kernel_size=1)
        self.branch_4 = nn.Sequential(pool, pool_proj)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results on dim 1."""
        branches = (self.branch_1, self.branch_2, self.branch_3, self.branch_4)
        outputs = [branch(x) for branch in branches]
        return torch.cat(outputs, dim=1)

In [4]:
class Auxiliary_classifier(nn.Module):
    """Auxiliary classification head attached to an intermediate feature map.

    Pipeline: 5x5/stride-3 average pool -> 1x1 ``ConvBlock`` to 128 channels ->
    flatten -> Linear(2048, 1024) -> ReLU -> Dropout -> Linear(1024, num_classes).

    Args:
        in_channels: channel count of the incoming feature map.
        num_classes: number of output logits.

    The first linear layer expects 2048 = 128 * 4 * 4 features, so the pooled
    map must be 4x4 (e.g. a 14x14 input feature map: (14 - 5) // 3 + 1 == 4).
    """

    def __init__(self, in_channels, num_classes):
        # Must initialise *this* class's nn.Module base. Naming InceptionBlock
        # here raised TypeError, since this class does not derive from it.
        super().__init__()

        self.conv_layer = nn.Sequential(
            nn.AvgPool2d(kernel_size=5, stride=3),
            ConvBlock(in_channels, 128, kernel_size=1),
        )
        self.fc_layer = nn.Sequential(
            nn.Linear(2048, 1024),  # 128 channels * 4 * 4 pooled positions
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        """Return auxiliary logits of shape ``(batch, num_classes)`` for ``x``."""
        y = self.conv_layer(x)
        y = torch.flatten(y, 1)
        # The head lives on the module, not on the tensor.
        y = self.fc_layer(y)

        return y

In [5]:
class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) built from ``ConvBlock`` and ``InceptionBlock``.

    Args:
        auxiliary_classifier: if True, two ``Auxiliary_classifier`` heads are
            attached after ``inception4a`` and ``inception4d``. Their logits are
            returned together with the main logits while the module is in
            training mode.
        num_classes: number of logits produced by every head.

    The stem expects 3-channel input. For a 224x224 input the auxiliary tap
    points are 14x14, which is what the heads' ``Linear(2048, ...)`` layers
    require. The main head uses adaptive pooling and is not tied to 224x224.
    """

    def __init__(self, auxiliary_classifier=True, num_classes=10):
        super().__init__()

        assert auxiliary_classifier in (True, False), "auxiliary_classifier must be a boolean"
        self.auxiliary_classifier = auxiliary_classifier
        self.conv1 = ConvBlock(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = ConvBlock(in_channels=64, out_channels=192, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.inception3a = InceptionBlock(192, 64, 96, 128, 16, 32, 32)       # -> 256 channels
        self.inception3b = InceptionBlock(256, 128, 128, 192, 32, 96, 64)     # -> 480 channels
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.inception4a = InceptionBlock(480, 192, 96, 208, 16, 48, 64)      # -> 512 channels
        # auxiliary classifier 1 taps here
        self.inception4b = InceptionBlock(512, 160, 112, 224, 24, 64, 64)     # -> 512 channels
        self.inception4c = InceptionBlock(512, 128, 128, 256, 24, 64, 64)     # -> 512 channels
        self.inception4d = InceptionBlock(512, 112, 144, 288, 32, 64, 64)     # -> 528 channels
        # auxiliary classifier 2 taps here
        self.inception4e = InceptionBlock(528, 256, 160, 320, 32, 128, 128)  # -> 832 channels
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.inception5a = InceptionBlock(832, 256, 160, 320, 32, 128, 128)  # -> 832 channels
        self.inception5b = InceptionBlock(832, 384, 192, 384, 48, 128, 128)  # -> 1024 channels
        # Global average pool. Equals AvgPool2d(7) for 224x224 input (7x7 map
        # here) but also works for other input resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=0.4)
        self.fc1 = nn.Linear(1024, num_classes)
        if self.auxiliary_classifier:
            self.aux1 = Auxiliary_classifier(512, num_classes)
            self.aux2 = Auxiliary_classifier(528, num_classes)
        else:
            self.aux1 = self.aux2 = None

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            ``logits`` when auxiliary heads are disabled or the module is in
            eval mode; otherwise the tuple ``(logits, aux1_logits, aux2_logits)``.
        """
        use_aux = self.auxiliary_classifier and self.training

        # Stem: every stage consumes the previous stage's output ``y``.
        y = self.conv1(x)
        y = self.maxpool1(y)
        y = self.conv2(y)
        y = self.maxpool2(y)

        y = self.inception3a(y)
        y = self.inception3b(y)
        y = self.maxpool3(y)

        y = self.inception4a(y)
        if use_aux:
            aux1 = self.aux1(y)
        y = self.inception4b(y)
        y = self.inception4c(y)
        y = self.inception4d(y)
        if use_aux:
            aux2 = self.aux2(y)
        y = self.inception4e(y)
        y = self.maxpool4(y)

        y = self.inception5a(y)
        y = self.inception5b(y)
        y = self.avgpool(y)

        y = torch.flatten(y, 1)
        y = self.dropout(y)
        y = self.fc1(y)

        if use_aux:
            return y, aux1, aux2
        return y