# GoogLeNet V2

<img src="imgs/inceptionv2.png">

<img src="imgs/googlenetv2.png">

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class InceptionV2(nn.Module):
    """Inception module of GoogLeNet V2 (BN-Inception).

    Up to four parallel branches are applied to the same input and their
    outputs are concatenated along dim 1 (channels), in this order:

    1. 1x1 conv                          -> ``out1`` channels (only when ``strides == 1``)
    2. 1x1 conv -> 3x3 conv              -> ``out2`` channels
    3. 1x1 conv -> 3x3 conv -> 3x3 conv  -> ``out3`` channels
    4. 3x3 pool [-> 1x1 conv]            -> ``out4`` channels when ``strides == 1``;
       otherwise the pooled input is passed through unchanged (``in1`` channels).

    Every convolution is followed by ReLU. Every convolution except the
    pool-branch projection is also followed by BatchNorm (eps=0.001) before
    the ReLU.

    Args:
        in1: number of input channels.
        out1: output channels of branch 1. Ignored (may be ``None``) when
            ``strides != 1``.
        mid2: bottleneck channels of branch 2.
        out2: output channels of branch 2.
        mid3: bottleneck channels of branch 3.
        out3: output channels of branch 3.
        out4: output channels of the pool-branch 1x1 projection. Ignored (may
            be ``None``) when ``strides != 1``.
        pool: ``'max'`` or ``'avg'`` pooling for branch 4.
        strides: stride of the last conv in branches 2 and 3 and of the pool.

    The output has ``out1 + out2 + out3 + out4`` channels when ``strides == 1``,
    and ``out2 + out3 + in1`` channels otherwise.

    Raises:
        RuntimeError: if ``pool`` is neither ``'max'`` nor ``'avg'``.
    """

    def __init__(self, in1, out1, mid2, out2, mid3, out3, out4, pool='max', strides=1):
        super(InceptionV2, self).__init__()

        # Branch 1: 1x1 conv. A strided module has no such branch.
        if strides == 1:
            self.conv1 = self._conv_bn_relu(in1, out1, kernel_size=1)
        else:
            self.conv1 = None

        # Branch 2: 1x1 conv -> 3x3 conv. The 3x3 carries the stride.
        self.conv2_1 = self._conv_bn_relu(in1, mid2, kernel_size=1)
        self.conv2_2 = self._conv_bn_relu(mid2, out2, kernel_size=3, stride=strides, padding=1)

        # Branch 3: 1x1 conv -> 3x3 conv -> 3x3 conv. The last 3x3 carries the stride.
        self.conv3_1 = self._conv_bn_relu(in1, mid3, kernel_size=1)
        self.conv3_2 = self._conv_bn_relu(mid3, out3, kernel_size=3, padding=1)
        self.conv3_3 = self._conv_bn_relu(out3, out3, kernel_size=3, stride=strides, padding=1)

        # Branch 4: 3x3 pool, optionally followed by a 1x1 projection.
        if pool == 'max':
            self.pool4 = nn.MaxPool2d(kernel_size=3, stride=strides, padding=1)
        elif pool == 'avg':
            self.pool4 = nn.AvgPool2d(kernel_size=3, stride=strides, padding=1)
        else:
            raise RuntimeError('undefined parameter for pool: {!r}'.format(pool))

        if strides == 1:
            self._conv4 = nn.Conv2d(in_channels=in1, out_channels=out4, kernel_size=1)
            self.conv4 = nn.Sequential(self.pool4, self._conv4, nn.ReLU(inplace=True))
        else:
            # Strided (grid-reduction) module: the pooled input is concatenated
            # as-is, so this branch contributes ``in1`` channels.
            self.conv4 = nn.Sequential(self.pool4)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, **conv_kwargs):
        """Return ``Conv2d -> BatchNorm2d(eps=0.001) -> ReLU`` as one ``nn.Sequential``.

        The conv sits at index 0 and the BatchNorm at index 1, so parameter
        names in ``state_dict`` are ``<name>.0.*`` / ``<name>.1.*``.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv_kwargs),
            nn.BatchNorm2d(num_features=out_channels, eps=0.001),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate the results along dim 1."""
        branches = []
        if self.conv1 is not None:
            branches.append(self.conv1(x))
        branches.append(self.conv2_2(self.conv2_1(x)))
        branches.append(self.conv3_3(self.conv3_2(self.conv3_1(x))))
        branches.append(self.conv4(x))
        return torch.cat(branches, 1)

class GoogLeNetV2(nn.Module):
    """GoogLeNet V2 (BN-Inception) image classifier built from ``InceptionV2`` modules.

    Pipeline: 7x7/2 conv -> 3x3/2 max pool -> 1x1 conv -> 3x3 conv -> 3x3/2 max pool
    -> ten ``InceptionV2`` modules (``layer7`` and ``layer12`` are stride-2
    reduction modules) -> global average pool -> linear classifier. Each stem
    convolution is followed by ReLU.

    The input must have 3 channels. For a 224x224 image the last feature map is
    1024x7x7. The global pool adapts to whatever spatial size reaches it, so the
    input resolution is not tied to 224.

    Args:
        num_classes: output size of the final linear layer (default 1000).
        verbose: when True (the default), ``forward`` prints the output shape of
            every stage. This helps when inspecting the architecture; turn it
            off for training.
    """

    def __init__(self, num_classes=1000, verbose=True):
        super(GoogLeNetV2, self).__init__()
        self.verbose = verbose

        # Stem.
        self.layer1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=2, padding=3)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv3_1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 1))
        self.conv3_2 = nn.Conv2d(in_channels=64, out_channels=192, kernel_size=(3, 3), padding=1)
        # Kept as an attribute for existing callers; forward() calls conv3_1 and
        # conv3_2 directly so it can apply ReLU after each of them.
        self.layer3 = nn.Sequential(self.conv3_1, self.conv3_2)
        self.pool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer5 = InceptionV2(192, 64, 64, 64, 64, 96, 32, pool='avg', strides=1)
        self.layer6 = InceptionV2(256, 64, 64, 96, 64, 96, 64, pool='avg', strides=1)
        # Stride-2 reduction: out1/out4 are unused, hence None.
        self.layer7 = InceptionV2(320, None, 128, 160, 64, 96, None, pool='avg', strides=2)

        self.layer8 = InceptionV2(576, 224, 64, 96, 96, 128, 128, pool='avg', strides=1)
        self.layer9 = InceptionV2(576, 192, 96, 128, 96, 128, 128, pool='avg', strides=1)
        self.layer10 = InceptionV2(576, 160, 128, 160, 128, 160, 96, pool='avg', strides=1)
        self.layer11 = InceptionV2(576, 96, 128, 192, 160, 192, 96, pool='avg', strides=1)
        # Stride-2 reduction: out1/out4 are unused, hence None.
        self.layer12 = InceptionV2(576, None, 128, 192, 192, 256, None, pool='avg', strides=2)

        self.layer13 = InceptionV2(1024, 352, 192, 320, 160, 224, 128, pool='avg', strides=1)
        self.layer14 = InceptionV2(1024, 352, 192, 320, 192, 224, 128, pool='max', strides=1)

        # Global average pool: equivalent to AvgPool2d(7) on the 7x7 map of a
        # 224x224 input, but also valid for other input sizes.
        self.pool15 = nn.AdaptiveAvgPool2d(1)
        self.linear16 = nn.Linear(1024, num_classes)

    def _trace(self, name, tensor):
        """Print ``tensor``'s shape labelled ``name`` when verbose; return ``tensor`` unchanged."""
        if self.verbose:
            print(name + ' ', tensor.size())
        return tensor

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for the image batch ``x``."""
        network = self._trace('layer1', F.relu(self.layer1(x)))
        network = self._trace('pool2', self.pool2(network))
        network = F.relu(self.conv3_2(F.relu(self.conv3_1(network))))
        network = self._trace('layer3', network)
        network = self._trace('pool4', self.pool4(network))
        for name in ('layer5', 'layer6', 'layer7', 'layer8', 'layer9',
                     'layer10', 'layer11', 'layer12', 'layer13', 'layer14'):
            network = self._trace(name, getattr(self, name)(network))
        network = self._trace('pool15', self.pool15(network))
        network = network.view(network.size(0), -1)
        return self._trace('linear16', self.linear16(network))
        

In [3]:
# Smoke test: push one random 3x224x224 image through the network on the GPU
# when available (CPU otherwise) and report the shape of the resulting logits.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
input_data = torch.randn(1, 3, 224, 224)
net = GoogLeNetV2().to(device)
input_data = input_data.to(device)
output = net(input_data)
print(output.shape)

layer1  torch.Size([1, 64, 112, 112])
pool2  torch.Size([1, 64, 56, 56])
layer3  torch.Size([1, 192, 56, 56])
pool4  torch.Size([1, 192, 28, 28])
layer5  torch.Size([1, 256, 28, 28])
layer6  torch.Size([1, 320, 28, 28])
layer7  torch.Size([1, 576, 14, 14])
layer8  torch.Size([1, 576, 14, 14])
layer9  torch.Size([1, 576, 14, 14])
layer10  torch.Size([1, 576, 14, 14])
layer11  torch.Size([1, 576, 14, 14])
layer12  torch.Size([1, 1024, 7, 7])
layer13  torch.Size([1, 1024, 7, 7])
layer14  torch.Size([1, 1024, 7, 7])
pool15  torch.Size([1, 1024, 1, 1])
linear16  torch.Size([1, 1000])
torch.Size([1, 1000])
