In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F

## 网络结构

In [2]:
class Inception(nn.Module):
    def __init__(self, in_c, c1, c2, c3, c4):
        super(Inception, self).__init__()
        self.p1_1 = nn.Conv2d(in_c, c1, kernel_size=1)
        self.p2_1 = nn.Conv2d(in_c, c2[0], kernel_size=1)
        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
        self.p3_1 = nn.Conv2d(in_c, c3[0], kernel_size=1)
        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.p4_2 = nn.Conv2d(in_c, c4, kernel_size=1)
 
    def forward(self, x):
        p1 = F.relu(self.p1_1(x))
        p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))
        p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))
        p4 = F.relu(self.p4_2(self.p4_1(x)))
        return torch.cat((p1, p2, p3, p4), dim=1) 

In [3]:
class GlobalAvgPool2d(nn.Module):
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=7)
    
class FlattenLayer(torch.nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x): 
        return x.view(x.shape[0], -1)

In [4]:
class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(GoogLeNet, self).__init__()
        
        self.b1 = nn.Sequential(nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3),
                           nn.ReLU(),
                           nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.b2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=1),
                           nn.Conv2d(64, 128, kernel_size=3, padding=1),
                           nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.b3 = nn.Sequential(Inception(128, 64, (64, 64), (64, 64), 64),
                           Inception(256, 64, (64, 64), (64, 64), 64),
                           Inception(256, 64, (64, 64), (64, 64), 64),
                           GlobalAvgPool2d())

        self.output=nn.Sequential(FlattenLayer(),
                                  nn.Dropout(p=0.4),
                                  nn.Linear(4096, 10))

    def forward(self, x):
        x = self.b1(x)
        print(x.shape)
        x = self.b2(x)
        print(x.shape)
        x = self.b3(x)
        print(x.shape)
        x = self.output(x)
        print(x.shape)

        return x

In [5]:
gnet = GoogLeNet()
x = torch.rand(1, 3, 224, 224)
# for blk in gnet.children(): 
#     x = blk(x)
#     print('output shape: ', x.shape)
gnet(x)

torch.Size([1, 32, 56, 56])
torch.Size([1, 128, 28, 28])
torch.Size([1, 256, 4, 4])
torch.Size([1, 10])


tensor([[-0.0237, -0.0036,  0.0089, -0.0201,  0.0020, -0.0241,  0.0231,  0.0169,
          0.0124, -0.0219]], grad_fn=<AddmmBackward0>)