In [1]:
import os 
import time
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

## 1 模型定义
- 在GoogLeNet的定义中去掉max_pool2 与 max_pool4，防止size因为池化降维而变成0，**注意在forward中也要做相应更改**。

In [2]:
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Inception(nn.Module):
    def __init__(self, in_channels, n1_1, n3x3red, n3x3, n5x5red, n5x5, pool_plane):
        super(Inception, self).__init__()
        # first line
        self.branch1x1 = BasicConv2d(in_channels, n1_1, kernel_size=1)

        # second line
        self.branch3x3 = nn.Sequential(
            BasicConv2d(in_channels, n3x3red, kernel_size=1),
            BasicConv2d(n3x3red, n3x3, kernel_size=3, padding=1)
        )

        # third line
        self.branch5x5 = nn.Sequential(
            BasicConv2d(in_channels, n5x5red, kernel_size=1),
            BasicConv2d(n5x5red, n5x5, kernel_size=5, padding=2)
        )

        # fourth line
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_plane, kernel_size=1)
        )

    def forward(self, x):
        y1 = self.branch1x1(x)
        y2 = self.branch3x3(x)
        y3 = self.branch5x5(x)
        y4 = self.branch_pool(x)
        output = torch.cat([y1, y2, y3, y4], 1)
        return output


class GoogLeNet(nn.Module):
    def __init__(self, in_channels, num_classes=10):
        super(GoogLeNet, self).__init__()

        self.conv1 = BasicConv2d(in_channels, 64, kernel_size=7, stride=2, padding=3)

        self.max_pool1 = nn.MaxPool2d(3, stride=2)

        self.conv2 = BasicConv2d(64, 192, kernel_size=3, stride=1, padding=1)

        # self.max_pool2 = nn.MaxPool2d(3, stride=2)

        self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
        self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

        self.max_pool3 = nn.MaxPool2d(3, stride=2)

        self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
        self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
        self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
        self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
        self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

        # self.max_pool4 = nn.MaxPool2d(3, stride=2)

        self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
        self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

        self.avg_pool = nn.AvgPool2d(7)

        self.dropout = nn.Dropout(0.4)

        self.classifier = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        # x = self.max_pool2(x)
        x = self.a3(x)
        x = self.b3(x)
        x = self.max_pool3(x)
        x = self.a4(x)
        x = self.b4(x)
        x = self.c4(x)
        x = self.d4(x)
        x = self.e4(x)
        # x = self.max_pool4(x)
        x = self.a5(x)
        x = self.b5(x)
        x = self.avg_pool(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

## 2 模型训练与评估类

In [3]:
def timer(func):
    def wrapper(*args, **kwargs):
        start = time.time()
        func(*args,**kwargs)
        end = time.time()
        cost = end - start
        print("Cost time: {} mins.".format(cost/60)) 
    return wrapper

class CNNModel(object):
    def __init__(self, model, train_data, test_data, model_dir, model_name,
                 best_valid_loss=float('inf'), n_split=0.9, batch_size=64, epochs=10):
        self.batch_size = batch_size
        self.epochs = epochs
        self.best_valid_loss = best_valid_loss
        self.model_dir = model_dir
        self.model_name = model_name
        self.n_split = n_split
        
        self.train_data =  train_data
        self.test_data = test_data
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.init_data()
        self.init_iterator()
        self.init_model_path()
        
        self.model = self.init_model(model)
        self.optimizer = self.set_optimizer()
        self.criterion = self.set_criterion()
    
    def init_data(self):
        n_train = int(len(self.train_data)*self.n_split)
        n_validation = len(self.train_data) - n_train
        self.train_data, self.valid_data = torch.utils.data.random_split(self.train_data, [n_train, n_validation])
    
    def init_iterator(self):
        self.train_iterator = torch.utils.data.DataLoader(self.train_data, shuffle=True, batch_size=self.batch_size)
        self.valid_iterator = torch.utils.data.DataLoader(self.valid_data, batch_size=self.batch_size)
        self.test_iterator = torch.utils.data.DataLoader(self.test_data, batch_size=self.batch_size)
        
    def set_optimizer(self):
        optimizer = optim.Adam(self.model.parameters()) 
        return optimizer
    
    def set_criterion(self):
        criterion = nn.CrossEntropyLoss()
        return criterion
    
    def init_model(self, model):
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.to(self.device)
        return model
        
    def init_model_path(self):
        if not os.path.isdir(self.model_dir):
            os.makedirs(self.model_dir)
        self.model_path = os.path.join(self.model_dir, self.model_name)
        
    # 定义评估函数
    def accu(self, fx, y):
        pred = fx.max(1,keepdim=True)[1]
        correct = pred.eq(y.view_as(pred)).sum()  # 得到该batch的准确度
        acc = correct.float()/pred.shape[0]
        return acc

    def train(self):
        epoch_loss = 0   # 积累变量
        epoch_acc = 0    # 积累变量
        self.model.train()    # 该函数表示PHASE=Train

        for (x,y) in self.train_iterator:  # 拿去每一个minibatch
            x = x.to(self.device)
            y = y.to(self.device)
            self.optimizer.zero_grad()
            fx = self.model(x)           # 进行forward
            loss = self.criterion(fx,y)  # 计算Loss,train_loss
            type(loss)
            acc = self.accu(fx,y)      # 计算精确度，train_accu
            loss.backward()     # 进行BP
            self.optimizer.step()    # 统一更新模型
            epoch_loss += loss.item()
            epoch_acc += acc.item()

        return epoch_loss/len(self.train_iterator),epoch_acc/len(self.train_iterator)

    def evaluate(self, iterator):
        epoch_loss = 0
        epoch_acc = 0
        self.model.eval()
        with torch.no_grad():
            for (x,y) in iterator:
                x = x.to(self.device)
                y = y.to(self.device)
                fx = self.model(x)
                loss = self.criterion(fx,y)
                acc = self.accu(fx,y)
                epoch_loss += loss.item()
                epoch_acc += acc.item()
        return epoch_loss/len(iterator),epoch_acc/len(iterator)
    
    @timer
    def train_fit(self):
        info = 'Epoch:{0} | Train Loss:{1} | Train Acc:{2} | Val Loss:{3} | Val Acc:{4}'
        for epoch in range(self.epochs):
            train_loss, train_acc = self.train()
            valid_loss, valid_acc = self.evaluate(self.valid_iterator)
            if valid_loss < self.best_valid_loss:  # 如果是最好的模型就保存到文件夹
                self.best_valid_loss = valid_loss
                torch.save(self.model.state_dict(), self.model_path)
            print(info.format(epoch+1, train_loss, train_acc, valid_loss, valid_acc))
    
    def get_acc(self):
        self.model.load_state_dict(torch.load(self.model_path))
        test_loss, test_acc = self.evaluate(self.test_iterator)
        print('| Test Loss: {0} | Test Acc: {1} |'.format(test_loss,test_acc))


## 3 数据集的准备

In [4]:
data_trans = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

In [5]:
train_data = datasets.MNIST('data', train=True, download=True, transform=data_trans)
test_data = datasets.MNIST('data', train=False, download=True, transform=data_trans)

## 4 模型训练

In [6]:
epochs = 20
n_split = 0.9
batch_size = 64
model_dir = 'models'
best_valid_loss = float('inf')
model_name = "googlenet.pt"
model = GoogLeNet(in_channels=1,num_classes=10)

obj = CNNModel(model=model, 
               train_data=train_data, 
               test_data=test_data, 
               model_dir=model_dir, 
               model_name=model_name,
               best_valid_loss=best_valid_loss, 
               n_split=n_split, 
               batch_size=batch_size, 
               epochs=epochs)

In [7]:
print(obj.device)

cuda


In [8]:
print(obj.model)

DataParallel(
  (module): GoogLeNet(
    (conv1): BasicConv2d(
      (conv): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (max_pool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): BasicConv2d(
      (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
    (a3): Inception(
      (branch1x1): BasicConv2d(
        (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
      (branch3x3): Sequential(
        (0): BasicConv2d(
          (conv): Conv2d(192, 96, kernel_size=(1, 1), stride=

In [9]:
obj.train_fit()

Epoch:1 | Train Loss:0.20343272557473296 | Train Acc:0.9438129443127962 | Val Loss:0.0860832029002461 | Val Acc:0.9770057626227115
Epoch:2 | Train Loss:0.07472551454485381 | Train Acc:0.9803700138466053 | Val Loss:0.055696024757591965 | Val Acc:0.9851507094312222
Epoch:3 | Train Loss:0.056842647595507625 | Train Acc:0.9853006516587678 | Val Loss:0.04943874551657033 | Val Acc:0.9861480498567541
Epoch:4 | Train Loss:0.04640269758382821 | Train Acc:0.987330914320539 | Val Loss:0.04173827684543868 | Val Acc:0.9894725179418604
Epoch:5 | Train Loss:0.04042175161554059 | Train Acc:0.9891883886255924 | Val Loss:0.040216217998989875 | Val Acc:0.9876994680851063
Epoch:6 | Train Loss:0.03700891847201398 | Train Acc:0.9899412519982641 | Val Loss:0.045699269768405465 | Val Acc:0.9877548763092528
Epoch:7 | Train Loss:0.029438126076374787 | Train Acc:0.9917802132701422 | Val Loss:0.03431864774369813 | Val Acc:0.9896387413461157
Epoch:8 | Train Loss:0.027639807626755105 | Train Acc:0.9926688388625592 

In [10]:
obj.get_acc()

| Test Loss: 0.026469486231684304 | Test Acc: 0.9940286624203821 |
