## GoogLeNet

In [20]:
import torch
import torch.nn as nn

class Conv(nn.Module):
    """A 2-D convolution immediately followed by a ReLU activation.

    Args:
        input_channels: Number of channels in the incoming feature map.
        output_channels: Number of channels produced by the convolution.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (for example
            ``kernel_size``, ``stride``, ``padding``).
    """

    def __init__(self, input_channels, output_channels, **kwargs) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            **kwargs,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv(x))``."""
        features = self.conv(x)
        return self.relu(features)


class Inception(nn.Module):
    """Four parallel branches whose outputs are concatenated along channels.

    Every branch keeps the spatial size of its input (1x1 kernels need no
    padding; the 3x3 / 5x5 convolutions and the 3x3 max-pool use stride 1
    with padding 1 / 2 / 1), so the results can be joined on ``dim=1``.
    The output has ``out1x1 + out3x3 + out5x5 + out1x1pool`` channels.

    Args:
        input_channels: Channels of the incoming feature map.
        out1x1: Output channels of the plain 1x1 branch.
        inter3x3: Channels after the 1x1 reduction in the 3x3 branch.
        out3x3: Output channels of the 3x3 branch.
        inter5x5: Channels after the 1x1 reduction in the 5x5 branch.
        out5x5: Output channels of the 5x5 branch.
        out1x1pool: Output channels of the max-pool -> 1x1 branch.
    """

    def __init__(self, input_channels, out1x1, inter3x3, out3x3, inter5x5, out5x5, out1x1pool) -> None:
        super().__init__()

        # Branch 1: a single 1x1 convolution.
        self.branch_1x1 = Conv(input_channels, out1x1, kernel_size=(1, 1))

        # Branch 2: 1x1 channel reduction, then a 3x3 convolution.
        reduce_3x3 = Conv(input_channels, inter3x3, kernel_size=(1, 1))
        expand_3x3 = Conv(inter3x3, out3x3, kernel_size=(3, 3), padding=(1, 1))
        self.branch_3x3 = nn.Sequential(reduce_3x3, expand_3x3)

        # Branch 3: 1x1 channel reduction, then a 5x5 convolution.
        reduce_5x5 = Conv(input_channels, inter5x5, kernel_size=(1, 1))
        expand_5x5 = Conv(inter5x5, out5x5, kernel_size=(5, 5), padding=(2, 2))
        self.branch_5x5 = nn.Sequential(reduce_5x5, expand_5x5)

        # Branch 4: 3x3 max-pool (stride 1), then a 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        project = Conv(input_channels, out1x1pool, kernel_size=(1, 1))
        self.branch_1x1pool = nn.Sequential(pool, project)

    def forward(self, x):
        """Run all branches on ``x`` and concatenate them on the channel axis."""
        branches = (
            self.branch_1x1,
            self.branch_3x3,
            self.branch_5x5,
            self.branch_1x1pool,
        )
        return torch.cat([branch(x) for branch in branches], dim=1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    Pipeline: 5x5 average pool (stride 3) -> 1x1 ``Conv`` to 128 channels
    (which already includes a ReLU) -> flatten -> ``fc1`` + ReLU ->
    dropout (p=0.7) -> ``fc2``.

    Args:
        input_channels: Channels of the incoming feature map.
        num_classes: Number of logits produced by ``fc2``.

    Note:
        ``fc1`` expects 2048 input features, i.e. 128 channels * 4 * 4.
        With the 5x5 / stride-3 pool this corresponds to a 14x14 input
        feature map; other spatial sizes will fail inside ``fc1``.
    """

    def __init__(self, input_channels, num_classes=10) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=0.7)
        self.pool = nn.AvgPool2d(kernel_size=(5, 5), stride=(3, 3))
        self.conv = Conv(input_channels, 128, kernel_size=(1, 1))
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return auxiliary logits of shape ``(batch, num_classes)``."""
        x = self.conv(self.pool(x))
        # ``Conv`` ends with a ReLU, so its output is already non-negative;
        # applying ReLU again after flattening would be a no-op. Flatten only.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)


class GoogleNet(nn.Module):
    """GoogLeNet-style classifier with two auxiliary heads.

    Layout: convolutional stem -> inception 3a/3b -> inception 4a..4e ->
    inception 5a/5b -> 7x7 average pool -> dropout (p=0.4) -> linear layer.
    Auxiliary classifiers read the outputs of inception 4a and 4d.

    Args:
        input_channels: Channels of the input image.
        num_classes: Number of logits produced by every classifier head.

    Note:
        The network halves the spatial size five times, and the fixed 7x7
        pool plus the 1024-wide ``fc`` need a 7x7 final feature map, which
        is what a 224x224 input produces.
    """

    def __init__(self, input_channels=3, num_classes=10) -> None:
        super().__init__()

        # Every down-sampling max-pool in the network shares this geometry.
        halve = dict(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

        # Stem.
        self.conv1 = Conv(
            input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)
        )
        self.maxpool1 = nn.MaxPool2d(**halve)
        self.conv2 = Conv(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.maxpool2 = nn.MaxPool2d(**halve)

        # Inception argument order:
        # (input, out1x1, inter3x3, out3x3, inter5x5, out5x5, out1x1pool).
        # Each block's input width is the sum of the previous block's four
        # branch outputs.
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)

        self.maxpool3 = nn.MaxPool2d(**halve)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)

        self.maxpool4 = nn.MaxPool2d(**halve)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        # Main classification head.
        self.avg = nn.AvgPool2d(kernel_size=(7, 7), stride=(1, 1))
        self.dropout = nn.Dropout(p=0.4)
        self.fc = nn.Linear(1024, num_classes)

        # Auxiliary heads: 512 channels leave inception4a, 528 leave inception4d.
        self.aux1 = InceptionAux(512, num_classes)
        self.aux2 = InceptionAux(528, num_classes)

    def forward(self, x):
        """Return ``(main_logits, aux2_logits, aux1_logits)`` for ``x``."""
        up_to_4a = (
            self.conv1,
            self.maxpool1,
            self.conv2,
            self.maxpool2,
            self.inception3a,
            self.inception3b,
            self.maxpool3,
            self.inception4a,
        )
        for layer in up_to_4a:
            x = layer(x)
        aux1 = self.aux1(x)

        for layer in (self.inception4b, self.inception4c, self.inception4d):
            x = layer(x)
        aux2 = self.aux2(x)

        tail = (
            self.inception4e,
            self.maxpool4,
            self.inception5a,
            self.inception5b,
            self.avg,
        )
        for layer in tail:
            x = layer(x)

        batch_size = x.shape[0]
        flat = x.reshape(batch_size, -1)
        logits = self.fc(self.dropout(flat))

        # The second auxiliary head is returned before the first.
        return logits, aux2, aux1


In [21]:
# Smoke test: push one random 224x224 RGB image through the network and
# print the shape of the main classifier output.
googlenet = GoogleNet()
x = torch.rand(1, 3, 224, 224)

main_logits = googlenet(x)[0]
print(main_logits.shape)

torch.Size([1, 10])
