In [None]:
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
class InceptionBlock(nn.Module):
  """Inception module: four parallel branches joined along the channel axis.

  The branches are:
    * a 1x1 convolution,
    * a 1x1 reduction followed by a 3x3 convolution (padding 1),
    * a 1x1 reduction followed by a 5x5 convolution (padding 2),
    * a 3x3 max-pool (stride 1, padding 1) followed by a 1x1 convolution.

  Every convolution is followed by a ReLU. All branches use stride 1 with
  padding chosen so the spatial size is preserved, which is what makes the
  channel-wise concatenation possible.

  Args:
    in_channels: number of channels of the input feature map.
    out_channels_1: output channels of the 1x1 branch.
    in_reduce_3: output channels of the 1x1 reduction in the 3x3 branch.
    out_reduce_3: output channels of the 3x3 convolution.
    in_reduce_5: output channels of the 1x1 reduction in the 5x5 branch.
    out_reduce_5: output channels of the 5x5 convolution.
    out_channels_pool: output channels of the 1x1 convolution after pooling.

  The output has
  ``out_channels_1 + out_reduce_3 + out_reduce_5 + out_channels_pool``
  channels.
  """

  def __init__(self, in_channels, out_channels_1, in_reduce_3, out_reduce_3, in_reduce_5, out_reduce_5, out_channels_pool):
    super().__init__()
    # Branch 1: plain 1x1 convolution.
    self.conv1_depth = nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels_1, kernel_size=1, stride=1),
        nn.ReLU()
    )

    # Branch 2: 1x1 channel reduction, then 3x3 convolution.
    self.conv3_depth = nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=in_reduce_3, kernel_size=1, stride=1),
        nn.ReLU(),
        nn.Conv2d(in_channels=in_reduce_3, out_channels=out_reduce_3, kernel_size=3, stride=1, padding=1),
        nn.ReLU()
    )

    # Branch 3: 1x1 channel reduction, then 5x5 convolution.
    self.conv5_depth = nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=in_reduce_5, kernel_size=1, stride=1),
        nn.ReLU(),
        nn.Conv2d(in_channels=in_reduce_5, out_channels=out_reduce_5, kernel_size=5, stride=1, padding=2),
        nn.ReLU()
    )

    # Branch 4: size-preserving 3x3 max-pool, then 1x1 projection.
    self.maxpool_depth = nn.Sequential(
        nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels_pool, kernel_size=1, stride=1),
        nn.ReLU()
    )

  def forward(self, x):
    """Run all four branches on ``x`` and concatenate them channel-wise.

    Args:
      x: tensor of shape ``(C, H, W)`` or ``(N, C, H, W)`` with
        ``C == in_channels``.

    Returns:
      Tensor with the same leading/spatial dimensions as ``x`` and the
      summed branch channel count on the channel axis.
    """
    branches = (
        self.conv1_depth(x),
        self.conv3_depth(x),
        self.conv5_depth(x),
        self.maxpool_depth(x),
    )
    # The channel axis is third from the end for both unbatched (C, H, W)
    # and batched (N, C, H, W) inputs. Using dim=0 would try to stack along
    # the batch axis for batched inputs and fail on mismatched channels.
    return torch.cat(branches, dim=-3)

In [None]:
class GoogLeNet(nn.Module):
  """GoogLeNet-style classifier built from a convolutional stem and nine
  Inception blocks, followed by average pooling, dropout and a linear layer.

  Args:
    num_classes: size of the output of the final linear layer
      (defaults to 1000).

  Notes:
    * ``forward`` returns softmax probabilities, not logits. Losses that
      apply their own log-softmax (e.g. ``nn.CrossEntropyLoss``) should not
      be fed this output directly.
    * The 7x7 average pool and the 1024-feature linear layer match the 7x7
      feature map that a 224x224 input produces after the five stride-2
      stages; other input sizes are not expected to work with this head.
    * No LocalResponseNorm layer is used in this implementation.
  """

  def __init__(self, num_classes=1000):
    super().__init__()
    # Stem.
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
    # One stateless pooling module, reused at every down-sampling point.
    self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1, stride=1)
    self.conv3 = nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, stride=1, padding=1)
    # Inception stages; each block's in_channels equals the sum of the
    # previous block's four branch outputs.
    self.inception3a = InceptionBlock(in_channels=192, out_channels_1=64, in_reduce_3=96, out_reduce_3=128, in_reduce_5=16, out_reduce_5=32, out_channels_pool=32)
    self.inception3b = InceptionBlock(in_channels=256, out_channels_1=128, in_reduce_3=128, out_reduce_3=192, in_reduce_5=32, out_reduce_5=96, out_channels_pool=64)
    self.inception4a = InceptionBlock(in_channels=480, out_channels_1=192, in_reduce_3=96, out_reduce_3=208, in_reduce_5=16, out_reduce_5=48, out_channels_pool=64)
    self.inception4b = InceptionBlock(in_channels=512, out_channels_1=160, in_reduce_3=112, out_reduce_3=224, in_reduce_5=24, out_reduce_5=64, out_channels_pool=64)
    self.inception4c = InceptionBlock(in_channels=512, out_channels_1=128, in_reduce_3=128, out_reduce_3=256, in_reduce_5=24, out_reduce_5=64, out_channels_pool=64)
    self.inception4d = InceptionBlock(in_channels=512, out_channels_1=112, in_reduce_3=144, out_reduce_3=288, in_reduce_5=32, out_reduce_5=64, out_channels_pool=64)
    self.inception4e = InceptionBlock(in_channels=528, out_channels_1=256, in_reduce_3=160, out_reduce_3=320, in_reduce_5=32, out_reduce_5=128, out_channels_pool=128)
    self.inception5a = InceptionBlock(in_channels=832, out_channels_1=256, in_reduce_3=160, out_reduce_3=320, in_reduce_5=32, out_reduce_5=128, out_channels_pool=128)
    self.inception5b = InceptionBlock(in_channels=832, out_channels_1=384, in_reduce_3=192, out_reduce_3=384, in_reduce_5=48, out_reduce_5=128, out_channels_pool=128)
    # Classification head.
    self.avg_pool = nn.AvgPool2d(kernel_size=7, stride=1)
    self.dropout = nn.Dropout(p=0.4)
    self.linear = nn.Linear(in_features=1024, out_features=num_classes)

  def forward(self, x):
    """Compute class probabilities for ``x``.

    Args:
      x: image tensor of shape ``(3, H, W)`` or ``(N, 3, H, W)``.

    Returns:
      Tensor of shape ``(num_classes,)`` for unbatched input or
      ``(N, num_classes)`` for batched input; values along the last axis
      sum to 1.
    """
    x = self.max_pool(F.relu(self.conv1(x)))
    x = F.relu(self.conv2(x))
    x = self.max_pool(F.relu(self.conv3(x)))
    x = self.inception3a(x)
    x = self.max_pool(self.inception3b(x))
    x = self.inception4a(x)
    x = self.inception4b(x)
    x = self.inception4c(x)
    x = self.inception4d(x)
    x = self.max_pool(self.inception4e(x))
    x = self.inception5a(x)
    x = self.avg_pool(self.inception5b(x))
    # Collapse only (C, H, W) so a leading batch axis, if any, survives.
    x = torch.flatten(x, start_dim=-3)
    # Dropout is active in training mode only; it is a no-op under eval().
    x = self.dropout(x)
    # Normalise over the class axis explicitly (implicit dim is deprecated).
    return F.softmax(self.linear(x), dim=-1)

In [None]:
# Instantiate the GoogLeNet network defined above with its default arguments.
model = GoogLeNet()

In [None]:
# Random 3-channel 224x224 input (no batch dimension) used as a smoke-test sample.
test = torch.rand(size=(3, 224, 224))

In [None]:
# Display the shape of the sample input.
test.shape

torch.Size([3, 224, 224])

In [None]:
# Forward pass of the random sample through the network.
pred = model(test)



In [None]:
# Display the output shape; its length matches the final linear layer's out_features.
pred.shape

torch.Size([1000])

In [None]:
from torchsummary import summary

In [None]:
# summary(model, (3, 224, 224))