In [35]:
import torch
from torch import nn

In [36]:
class conv_block(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(conv_block, self).__init__()
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        print("forward")
        return self.relu(self.batchnorm(self.conv(x)))

<img src="inception images/inception block.png" alt="Alternative text" />

In [31]:
class Inception_block(nn.Module):
    def __init__(
        self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool
    ):
        super(Inception_block, self).__init__()
        self.branch1 = conv_block(in_channels, out_1x1, kernel_size=1)
        print("fwd")

        self.branch2 = nn.Sequential(
            conv_block(in_channels, red_3x3, kernel_size=1),
            conv_block(red_3x3, out_3x3, kernel_size=(3, 3), padding=1),
        )

        self.branch3 = nn.Sequential(
            conv_block(in_channels, red_5x5, kernel_size=1),
            conv_block(red_5x5, out_5x5, kernel_size=5, padding=2),
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            conv_block(in_channels, out_1x1pool, kernel_size=1),
        )

    def forward(self, x):
        return torch.cat(
            [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)], 1
        )


class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.7)
        self.pool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = conv_block(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.pool(x)
        x = self.conv(x)
        x = x.reshape(x.shape[0], -1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

<img src="inception images/inception model.png" alt="Alternative text" />

In [32]:
class GoogLeNet(nn.Module):
    def __init__(self, aux_logits=True, num_classes=1000):
        super(GoogLeNet, self).__init__()
        assert aux_logits == True or aux_logits == False
        self.aux_logits = aux_logits

        # Write in_channels, etc, all explicit in self.conv1, rest will write to
        # make everything as compact as possible, kernel_size=3 instead of (3,3)
        self.conv1 = conv_block(
            in_channels=3,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
        )

        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = conv_block(64, 192, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # In this order: in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool
        self.inception3a = Inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception4a = Inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception5a = Inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception_block(832, 384, 192, 384, 48, 128, 128)

        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.dropout = nn.Dropout(p=0.4)
        self.fc1 = nn.Linear(1024, num_classes)

        if self.aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)
        else:
            self.aux1 = self.aux2 = None

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        x = self.inception4a(x)

        # Auxiliary Softmax classifier 1
        if self.aux_logits and self.training:
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)

        # Auxiliary Softmax classifier 2
        if self.aux_logits and self.training:
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        x = self.maxpool4(x)
        x = self.inception5a(x)
        x = self.inception5b(x)
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.dropout(x)
        x = self.fc1(x)

        if self.aux_logits and self.training:
            return aux1, aux2, x
        else:
            return x


In [34]:
BATCH_SIZE = 5
x = torch.randn(BATCH_SIZE, 3, 224, 224)
model = GoogLeNet(aux_logits=True, num_classes=1000)
print(model(x)[2].shape)
assert model(x)[2].shape == torch.Size([BATCH_SIZE, 1000])

fwd
fwd
fwd
fwd
fwd
fwd
fwd
fwd
fwd
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
torch.Size([5, 1000])
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward
forward


In [25]:
model

GoogLeNet(
  (conv1): conv_block(
    (relu): ReLU()
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv2): conv_block(
    (relu): ReLU()
    (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (inception3a): Inception_block(
    (branch1): conv_block(
      (relu): ReLU()
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
      (batchnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch2): Sequential(
      (0): conv_block(
        (relu): ReLU()
        (conv): Conv2d(192, 96, kernel_size

In [22]:
x[2]

tensor([[[ 1.0806,  1.3238, -2.7762,  ..., -0.7524,  0.1655,  2.5799],
         [ 0.5409,  1.1752,  0.1432,  ..., -1.2042, -0.1257, -1.7857],
         [-0.6652,  1.5498, -0.8420,  ...,  1.2285, -0.9732, -0.2479],
         ...,
         [ 1.0623, -1.0364, -0.1100,  ..., -1.4402,  0.0262, -0.1533],
         [-1.2413, -0.1606, -0.1121,  ...,  0.5763,  2.3192, -2.2909],
         [-0.1457,  0.3257, -0.0425,  ...,  2.9775,  1.4676,  1.1675]],

        [[ 0.1807, -0.3026,  0.3154,  ...,  0.2478,  0.3980, -1.0058],
         [-0.5866, -1.2107, -1.3271,  ..., -1.8993, -0.3372, -0.3607],
         [ 0.6271, -0.8137, -0.5513,  ...,  0.7921,  0.3588, -0.5384],
         ...,
         [-0.1446,  0.9979,  0.1557,  ..., -2.1076, -0.8822,  1.0696],
         [ 1.0671,  1.2347,  1.2434,  ...,  0.8301, -0.1372, -1.6148],
         [-1.5599, -0.2296, -1.9155,  ..., -0.8523,  0.3211, -1.0275]],

        [[-2.0182,  0.1899, -0.1880,  ...,  0.6613,  0.0758,  1.1613],
         [-0.4421, -2.4837, -0.1279,  ...,  0

In [23]:
torch.Size([BATCH_SIZE, 1000])

torch.Size([5, 1000])

In [24]:
x[2]

tensor([[[ 1.0806,  1.3238, -2.7762,  ..., -0.7524,  0.1655,  2.5799],
         [ 0.5409,  1.1752,  0.1432,  ..., -1.2042, -0.1257, -1.7857],
         [-0.6652,  1.5498, -0.8420,  ...,  1.2285, -0.9732, -0.2479],
         ...,
         [ 1.0623, -1.0364, -0.1100,  ..., -1.4402,  0.0262, -0.1533],
         [-1.2413, -0.1606, -0.1121,  ...,  0.5763,  2.3192, -2.2909],
         [-0.1457,  0.3257, -0.0425,  ...,  2.9775,  1.4676,  1.1675]],

        [[ 0.1807, -0.3026,  0.3154,  ...,  0.2478,  0.3980, -1.0058],
         [-0.5866, -1.2107, -1.3271,  ..., -1.8993, -0.3372, -0.3607],
         [ 0.6271, -0.8137, -0.5513,  ...,  0.7921,  0.3588, -0.5384],
         ...,
         [-0.1446,  0.9979,  0.1557,  ..., -2.1076, -0.8822,  1.0696],
         [ 1.0671,  1.2347,  1.2434,  ...,  0.8301, -0.1372, -1.6148],
         [-1.5599, -0.2296, -1.9155,  ..., -0.8523,  0.3211, -1.0275]],

        [[-2.0182,  0.1899, -0.1880,  ...,  0.6613,  0.0758,  1.1613],
         [-0.4421, -2.4837, -0.1279,  ...,  0