In [2]:
import torch
import torch.nn as nn

In [3]:
class ConvBlock(nn.Module):
    """Convolution followed by batch normalisation and a ReLU activation.

    Args:
        inchannels: number of channels in the input feature map.
        outchannels: number of channels produced by the convolution (and
            normalised by the batch-norm layer).
        kernel_size: forwarded unchanged to ``nn.Conv2d``.
        stride: forwarded unchanged to ``nn.Conv2d``.
        padding: forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, inchannels, outchannels, kernel_size, stride, padding) -> None:
        super().__init__()
        # Registration order (conv, relu, batchnorm) is kept so module/parameter
        # ordering and state_dict keys stay stable.
        self.conv = nn.Conv2d(
            in_channels=inchannels,
            out_channels=outchannels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.relu = nn.ReLU()
        self.batchnorm = nn.BatchNorm2d(num_features=outchannels)

    def forward(self, x):
        """Return ``relu(batchnorm(conv(x)))``."""
        features = self.conv(x)
        normalised = self.batchnorm(features)
        return self.relu(normalised)

In [9]:
class InceptionBlock(nn.Module):
    """GoogLeNet-style Inception module with four parallel branches.

    ``config`` is a sequence laid out as
    ``[in_channels, n1x1, n3x3_reduce, n3x3, n5x5_reduce, n5x5, pool_proj]``.

    Branches (every conv is a ``ConvBlock``: conv -> batchnorm -> relu):
        1. 1x1 conv                        -> ``n1x1`` channels
        2. 1x1 reduce -> 3x3 conv          -> ``n3x3`` channels
        3. 1x1 reduce -> 5x5 conv          -> ``n5x5`` channels
        4. 3x3 max-pool -> 1x1 projection  -> ``pool_proj`` channels

    All branches are padded so they keep the input's spatial size. Their
    outputs are concatenated along dim 1, giving
    ``n1x1 + n3x3 + n5x5 + pool_proj`` output channels.
    """

    def __init__(self, config) -> None:
        super().__init__()
        inchannels = config[0]
        # Kept as an attribute (per-branch widths, without in_channels).
        self.config = config[1:]
        n1x1 = self.config[0]
        n3x3_reduce = self.config[1]
        n3x3 = self.config[2]
        n5x5_reduce = self.config[3]
        n5x5 = self.config[4]
        n_pool_proj = self.config[5]

        # Branch 1: plain 1x1 convolution.
        self.conv_1x1 = ConvBlock(inchannels, n1x1, (1, 1), 1, (0, 0))

        # Branch 2: the "reduce" layer is a 1x1 conv that only shrinks the
        # channel count before the expensive 3x3 conv (as in GoogLeNet).
        self.conv_3x3_reduce = ConvBlock(inchannels, n3x3_reduce, (1, 1), 1, (0, 0))
        self.conv_3x3 = ConvBlock(n3x3_reduce, n3x3, (3, 3), 1, (1, 1))

        # Branch 3: 1x1 channel reduction, then 5x5 conv (padding 2 keeps H, W).
        self.conv_5x5_reduce = ConvBlock(inchannels, n5x5_reduce, (1, 1), 1, (0, 0))
        self.conv_5x5 = ConvBlock(n5x5_reduce, n5x5, (5, 5), 1, (2, 2))

        # Branch 4: stride-1, padded max-pool (keeps H, W) then 1x1 projection.
        self.max_pool = nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.pool_proj = ConvBlock(inchannels, n_pool_proj, (1, 1), 1, (0, 0))

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them on dim 1.

        Expects a batched ``(N, C, H, W)`` tensor; the concatenation on dim 1
        only lines up the branch outputs when dim 0 is the batch.
        """
        branches = (
            self.conv_1x1(x),
            self.conv_3x3(self.conv_3x3_reduce(x)),
            self.conv_5x5(self.conv_5x5_reduce(x)),
            self.pool_proj(self.max_pool(x)),
        )
        return torch.cat(branches, dim=1)

In [10]:
# Smoke test for the inception-3b configuration: 256 input channels should map
# to 128 + 192 + 96 + 64 = 480 output channels, with the 28x28 size preserved.
b = InceptionBlock([256, 128, 128, 192, 32, 96, 64])
x = torch.randn((16, 256, 28, 28))
with torch.no_grad():  # shape check only; no autograd graph needed
    out_shape = b(x).shape
assert out_shape == torch.Size([16, 480, 28, 28]), out_shape
print(out_shape)

done
done
torch.Size([16, 480, 28, 28])


In [135]:
class GoogleNet(nn.Module):
    """GoogLeNet (Inception v1) style classifier built from ``InceptionBlock``s.

    Args:
        inchannel: number of channels of the input images.
        num_class: number of output logits.
        dropout: dropout probability applied before the final linear layer.
            Defaults to 0.4.

    Global average pooling is adaptive, so inputs of any spatial size large
    enough to survive the five stride-2 stages are accepted. For a 224x224
    input the final feature map is 7x7 before pooling.
    """

    def __init__(self, inchannel, num_class, dropout=0.4) -> None:
        super().__init__()
        # Stem. ReLU is applied after each stem conv, matching the ReLU that
        # every ConvBlock applies inside the inception modules.
        self.conv1 = nn.Conv2d(inchannel, 64, (7, 7), 2, (3, 3))
        self.max_pool1 = nn.MaxPool2d((3, 3), 2, (1, 1))
        self.conv2 = nn.Conv2d(64, 192, (3, 3), 1, (1, 1))
        self.max_pool2 = nn.MaxPool2d((3, 3), 2, (1, 1))
        self.relu = nn.ReLU()

        # Columns: in_channel, #1x1, #3x3reduce, #3x3, #5x5reduce, #5x5, #pool_proj
        self.config = {
            'incep_3a': [192, 64, 96, 128, 16, 32, 32],
            'incep_3b': [256, 128, 128, 192, 32, 96, 64],

            'incep_4a': [480, 192, 96, 208, 16, 48, 64],
            'incep_4b': [512, 160, 112, 224, 24, 64, 64],
            'incep_4c': [512, 128, 128, 256, 24, 64, 64],
            'incep_4d': [512, 112, 144, 288, 32, 64, 64],
            'incep_4e': [528, 256, 160, 320, 32, 128, 128],

            'incep_5a': [832, 256, 160, 320, 32, 128, 128],
            'incep_5b': [832, 384, 192, 384, 48, 128, 128],
        }
        self.IC_3a = InceptionBlock(self.config["incep_3a"])
        self.IC_3b = InceptionBlock(self.config["incep_3b"])
        self.max_pool3 = nn.MaxPool2d((3, 3), 2, (1, 1))

        self.IC_4a = InceptionBlock(self.config["incep_4a"])
        self.IC_4b = InceptionBlock(self.config["incep_4b"])
        self.IC_4c = InceptionBlock(self.config["incep_4c"])
        self.IC_4d = InceptionBlock(self.config["incep_4d"])
        self.IC_4e = InceptionBlock(self.config["incep_4e"])
        self.max_pool4 = nn.MaxPool2d((3, 3), 2, (1, 1))

        self.IC_5a = InceptionBlock(self.config["incep_5a"])
        self.IC_5b = InceptionBlock(self.config["incep_5b"])
        # Adaptive pooling gives the same result as a 7x7 average pool on the
        # 7x7 map produced by 224x224 inputs. It also always yields 1x1, so the
        # 1024-wide linear layer below works for other input sizes too.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fcs = nn.Sequential(
            nn.Dropout(dropout),
            # 1024 = output channels of incep_5b (384 + 384 + 128 + 128).
            nn.Linear(1024, num_class),
        )

    def forward(self, x):
        """Map a batch of images ``(N, inchannel, H, W)`` to logits ``(N, num_class)``."""
        x = self.max_pool1(self.relu(self.conv1(x)))
        x = self.max_pool2(self.relu(self.conv2(x)))

        x = self.IC_3b(self.IC_3a(x))
        x = self.max_pool3(x)

        x = self.IC_4e(self.IC_4d(self.IC_4c(self.IC_4b(self.IC_4a(x)))))
        x = self.max_pool4(x)

        x = self.IC_5b(self.IC_5a(x))
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        return self.fcs(x)
        

In [136]:
# End-to-end smoke test: a batch of 16 RGB 224x224 images should produce
# one logit per class for each image.
net = GoogleNet(3, 120)
x = torch.randn((16, 3, 224, 224))
with torch.no_grad():  # shape check only; no autograd graph needed
    logits_shape = net(x).shape
assert logits_shape == torch.Size([16, 120]), logits_shape
print(logits_shape)