In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
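# Inception block: four parallel branches (1x1, 1x1->3x3, 1x1->5x5, maxpool->1x1) concatenated along channels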
class Inception(nn.Module):
    """Inception block: four parallel branches concatenated along the channel dim.

    Branches:
        conv1: 1x1 conv
        conv2: 1x1 reduce -> ReLU -> 3x3 conv (padding 1)
        conv3: 1x1 reduce -> ReLU -> 5x5 conv (padding 2)
        conv4: 3x3 max-pool (stride 1, padding 1) -> 1x1 conv
    Every branch output is passed through ReLU before concatenation. All
    branches preserve the spatial size, so the output is
    (N, conv1 + conv2 + conv3 + conv4, H, W).

    Args:
        in_channels: number of input channels.
        out_channels: dict of output widths with keys 'conv1', 'conv2',
            'conv3', 'conv4'. Optional keys 'conv2_reduce' and 'conv3_reduce'
            set the 1x1 reduction widths of the 3x3 and 5x5 branches
            (default 64 each).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Reduction widths were previously fixed at 64; keep that as the default.
        reduce3 = out_channels.get('conv2_reduce', 64)
        reduce5 = out_channels.get('conv3_reduce', 64)

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels['conv1'], kernel_size=1)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=reduce3, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=reduce3, out_channels=out_channels['conv2'], kernel_size=3, stride=1, padding=1)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=reduce5, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=reduce5, out_channels=out_channels['conv3'], kernel_size=5, stride=1, padding=2)
        )
        # Pool branch: stride-1 / padding-1 max-pool keeps H and W unchanged.
        self.conv4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels['conv4'], kernel_size=1)
        )

    def forward(self, x):
        """Apply the four branches to ``x`` and concatenate them on dim 1."""
        z1 = F.relu(self.conv1(x))
        z2 = F.relu(self.conv2(x))
        z3 = F.relu(self.conv3(x))
        z4 = F.relu(self.conv4(x))
        return torch.cat((z1, z2, z3, z4), dim=1)

In [None]:
torch.manual_seed(0)  # reproducible input and weight initialization

x = torch.rand((32, 3, 32, 32))

out_channels = {
    'conv1': 128,
    'conv2': 192,
    'conv3': 96,
    'conv4': 64
}
inception = Inception(3, out_channels)

with torch.no_grad():  # shape check only; no autograd graph needed
    y = inception(x)

# Branches are concatenated on dim 1 (channels add up); spatial size is preserved.
assert y.shape == (x.shape[0], sum(out_channels.values()), x.shape[2], x.shape[3])
print(y.shape)

In [None]:
# GoogLeNet-style network: conv stem -> three Inception blocks -> pooled linear classifier
class GoogLeNet(nn.Module):
    """Small GoogLeNet-style classifier: conv stem, three Inception blocks, linear head.

    Args:
        channels: config dict with keys 'inception1', 'inception2', 'inception3'
            (each ``{'in': int, 'out': dict}``, forwarded to ``Inception``) and
            'classifier' (``{'in': int, 'out': int}`` for the final Linear layer).
            The stem emits 128 channels, so ``channels['inception1']['in']`` must
            be 128. Each later block's 'in', and the classifier's 'in', must equal
            the previous block's concatenated output width.
    """

    def __init__(self, channels):
        super().__init__()
        # Three stride-2 stages; spatial size goes 32 -> 16 -> 8 -> 4 for a 32x32 input.
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),  # downsample x2
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),  # downsample x2
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # downsample x2
        )

        self.inceptions = nn.Sequential(
            Inception(in_channels=channels['inception1']['in'], out_channels=channels['inception1']['out']),
            Inception(in_channels=channels['inception2']['in'], out_channels=channels['inception2']['out']),
            Inception(in_channels=channels['inception3']['in'], out_channels=channels['inception3']['out'])
        )

        # Global average pooling to 1x1 keeps the Linear input at channels['classifier']['in']
        # for any input resolution. A fixed AvgPool2d(kernel_size=4) only collapses 4x4
        # feature maps (32x32 inputs) and breaks the Linear layer for larger images.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(in_features=channels['classifier']['in'], out_features=channels['classifier']['out'])
        )

    def forward(self, inputs):
        """Map (N, 3, H, W) images to (N, channels['classifier']['out']) scores (no softmax)."""
        feature1 = self.stem(inputs)
        feature2 = self.inceptions(feature1)
        outputs = self.classifier(feature2)
        return outputs

# Branch output widths shared by every Inception stage (sum = 480 channels per stage).
INCEPTION_OUT = {
    'conv1': 128,
    'conv2': 192,
    'conv3': 96,
    'conv4': 64
}
STEM_OUT_CHANNELS = 128  # channels produced by GoogLeNet.stem
NUM_CLASSES = 10


def build_channels(stem_out, stage_outs, num_classes):
    """Build the GoogLeNet ``channels`` config, chaining each stage's input width.

    Args:
        stem_out: channels entering the first Inception block.
        stage_outs: sequence of per-stage branch-width dicts (keys 'conv1'..'conv4').
        num_classes: classifier output size.

    Returns:
        dict with keys 'inception1'..'inceptionN' and 'classifier'. Each stage's
        'in' is the previous stage's concatenated output width, so the numbers
        cannot drift out of sync with the branch widths.
    """
    config = {}
    in_ch = stem_out
    for i, out in enumerate(stage_outs, start=1):
        # Copy so stages never share one mutable dict.
        config[f'inception{i}'] = {'in': in_ch, 'out': dict(out)}
        # Inception concatenates its four branches, so the next input is their sum.
        in_ch = sum(out[k] for k in ('conv1', 'conv2', 'conv3', 'conv4'))
    config['classifier'] = {'in': in_ch, 'out': num_classes}
    return config


channels = build_channels(STEM_OUT_CHANNELS, [INCEPTION_OUT] * 3, NUM_CLASSES)
googlenet = GoogLeNet(channels=channels)

In [None]:
x = torch.rand(64, 3, 32, 32)

with torch.no_grad():  # shape check only; no autograd graph needed
    y = googlenet(x)

# One score vector per image, sized by the classifier config.
assert y.shape == (x.shape[0], channels['classifier']['out'])
print(y.shape)