## GoogLeNet
Batch normalization is applied before each ReLU activation.

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [38]:
class Inception(nn.Module):
    """Inception block: four parallel branches concatenated along channels.

    Each convolution is followed by batch normalization and then ReLU.
    The convolutions are lazy, so their input channel count is inferred
    on the first forward pass. Kernel sizes and paddings are paired so that
    every branch keeps the input's spatial size.

    Args:
        c1: Output channels of the 1x1 branch.
        c2: Indexable pair; ``c2[0]`` channels for the 1x1 convolution and
            ``c2[1]`` channels for the 3x3 convolution that follows it.
        c3: Indexable pair; ``c3[0]`` channels for the 1x1 convolution and
            ``c3[1]`` channels for the 5x5 convolution that follows it.
        c4: Output channels of the 1x1 convolution after the 3x3 max-pool.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, c1, c2, c3, c4, **kwargs):
        super().__init__(**kwargs)

        reduce_3x3, out_3x3 = c2[0], c2[1]
        reduce_5x5, out_5x5 = c3[0], c3[1]

        # Branch 1: a single 1x1 convolution.
        self.b1 = nn.LazyConv2d(c1, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(c1)

        # Branch 2: 1x1 channel reduction, then a 3x3 convolution.
        self.b2_1 = nn.LazyConv2d(reduce_3x3, kernel_size=1)
        self.bn2_1 = nn.BatchNorm2d(reduce_3x3)
        self.b2_2 = nn.LazyConv2d(out_3x3, kernel_size=3, padding=1)
        self.bn2_2 = nn.BatchNorm2d(out_3x3)

        # Branch 3: 1x1 channel reduction, then a 5x5 convolution.
        self.b3_1 = nn.LazyConv2d(reduce_5x5, kernel_size=1)
        self.bn3_1 = nn.BatchNorm2d(reduce_5x5)
        self.b3_2 = nn.LazyConv2d(out_5x5, kernel_size=5, padding=2)
        self.bn3_2 = nn.BatchNorm2d(out_5x5)

        # Branch 4: stride-1 3x3 max-pool, then a 1x1 convolution.
        self.b4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.b4_2 = nn.LazyConv2d(c4, kernel_size=1)
        self.bn4 = nn.BatchNorm2d(c4)

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them on dim 1.

        The result has ``c1 + c2[1] + c3[1] + c4`` channels, ordered
        branch 1 through branch 4.
        """
        branch1 = F.relu(self.bn1(self.b1(x)))

        hidden2 = F.relu(self.bn2_1(self.b2_1(x)))
        branch2 = F.relu(self.bn2_2(self.b2_2(hidden2)))

        hidden3 = F.relu(self.bn3_1(self.b3_1(x)))
        branch3 = F.relu(self.bn3_2(self.b3_2(hidden3)))

        pooled = self.b4_1(x)
        branch4 = F.relu(self.bn4(self.b4_2(pooled)))

        # dim=1 is the channel dimension.
        return torch.cat([branch1, branch2, branch3, branch4], dim=1)

In [40]:
class GoogleNet(nn.Module):
    """GoogLeNet-style classifier with batch normalization before each ReLU.

    The network is a two-stage convolutional stem, three groups of
    ``Inception`` blocks, global average pooling, and a linear classifier.
    All convolutions and the final linear layer are lazy, so input channel
    and feature counts are inferred on the first forward pass.

    Args:
        num_classes: Number of output units of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Stem stage 1: 7x7 stride-2 convolution, then stride-2 max-pool.
        self.backbone_1 = nn.Sequential(
            nn.LazyConv2d(64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Stem stage 2: 1x1 reduction to 64 channels followed by the 3x3
        # convolution to 192 channels that GoogLeNet places before the
        # second max-pool. The 3x3 convolution uses padding=1, so it does
        # not change the spatial size.
        self.backbone_2 = nn.Sequential(
            nn.LazyConv2d(64, kernel_size=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.LazyConv2d(192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Inception group 1: two blocks, then stride-2 max-pool.
        self.neck_1 = nn.Sequential(
            Inception(64, (96, 128), (16, 32), 32),
            Inception(128, (128, 192), (32, 96), 64),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Inception group 2: five blocks, then stride-2 max-pool.
        self.neck_2 = nn.Sequential(
            Inception(192, (96, 208), (16, 48), 64),
            Inception(160, (112, 224), (24, 64), 64),
            Inception(128, (128, 256), (24, 64), 64),
            Inception(112, (144, 288), (32, 64), 64),
            Inception(256, (160, 320), (32, 128), 128),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Inception group 3: two blocks, then global average pooling to 1x1
        # and flattening. The last block emits 384 + 384 + 128 + 128 = 1024
        # channels, which become the classifier's input features.
        self.neck_3 = nn.Sequential(
            Inception(256, (160, 320), (32, 128), 128),
            Inception(384, (192, 384), (48, 128), 128),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        self.head = nn.Sequential(
            nn.LazyLinear(num_classes),
        )

    def forward(self, x):
        """Return the classifier output for ``x``.

        Assumes ``x`` is a batched image tensor of shape
        ``(batch, channels, height, width)`` -- TODO confirm against callers.
        For such input the result has shape ``(batch, num_classes)``.
        """
        x = self.backbone_1(x)
        x = self.backbone_2(x)
        x = self.neck_1(x)
        x = self.neck_2(x)
        x = self.neck_3(x)
        x = self.head(x)
        return x

In [44]:
# Random stand-in batch: 10 samples, 3 channels, 28x28 spatial size.
X = torch.randn((10, 3, 28, 28))

In [45]:
# Smoke test: build a 10-class model and pass the batch X through it once.
# This first forward call is also what lets the lazy layers infer their
# input sizes.
model = GoogleNet(num_classes=10)
out = model(X)
print(out)

tensor([[-1.9737e-01, -6.5973e-04, -4.5003e-02,  2.4237e-01, -4.4973e-01,
         -7.8477e-02,  1.8349e-01, -3.9047e-01,  5.1757e-01,  4.5902e-01],
        [-7.4174e-02, -2.1426e-01, -2.3938e-01, -3.4963e-01, -4.7464e-01,
         -1.3004e-01,  4.1213e-01,  7.2614e-02,  4.6444e-01, -2.8055e-01],
        [-4.3328e-02,  8.2026e-01, -2.7893e-01,  4.3066e-01,  5.2307e-01,
         -1.2237e-01,  2.0621e-01, -1.4017e-01,  4.5774e-01,  2.5361e-01],
        [ 1.7975e-01,  8.5689e-01, -7.2231e-01, -8.9433e-02,  1.7784e-01,
          1.3084e-01,  3.0835e-01,  4.9903e-02,  4.2639e-01,  3.1438e-01],
        [ 5.6343e-02, -8.2611e-02, -3.5372e-01,  9.2317e-01,  6.0609e-01,
          3.2935e-01, -2.1700e-01, -3.5961e-01,  3.5112e-01,  4.6720e-01],
        [-7.2575e-02, -2.5553e-01, -8.5229e-01,  3.6594e-01, -2.9135e-01,
         -1.6591e-01,  8.0243e-01, -5.1526e-01,  3.0935e-01, -4.6002e-01],
        [-2.2295e-01, -2.9023e-01,  7.1409e-01, -4.5263e-01,  2.7272e-01,
          1.9450e-01,  5.9090e-0