In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline
from d2l import torch as d2l
import random
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
import time

In [2]:
# 百度来的，不然下载不动。。
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [3]:
mean, std = [0.491, 0.482, 0.446], [0.247, 0.243, 0.261]

In [4]:
class TrainDataset(data.Dataset):
    def __init__(self, S):
        super().__init__()
        self.dataset = torchvision.datasets.CIFAR10(
            root="../data", train=True, download=True)
        
        self.S = S
        
        self.trans = [transforms.RandomCrop(64),
                      transforms.RandomHorizontalFlip(p=0.5),
                      transforms.ColorJitter(brightness=0.1,
                                             contrast=0.1,
                                             saturation=0.1,
                                             hue=0),
                      transforms.ToTensor(),
                      transforms.Normalize(mean, std, inplace=True)]
        self.trans = transforms.Compose(self.trans)
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, index):
        if isinstance(self.S, int):
            resize = transforms.Resize(max(64, self.S))
        elif isinstance(self.S, list):
            assert(len(self.S) == 2)
            resize = transforms.Resize(
                random.randint(self.S[0], self.S[1]))

        return (self.trans(resize(self.dataset[index][0])),
                self.dataset[index][1])

In [5]:
class TestDataset(data.Dataset):
    def __init__(self, Q, horizontal_flip=False):
        super().__init__()
        self.dataset = torchvision.datasets.CIFAR10(
            root="../data", train=False, download=True)
        
        self.Q = Q
        
        self.trans = [transforms.ToTensor(),
                      transforms.Normalize(mean, std, inplace=True)]
        if horizontal_flip:
            self.trans.insert(0, transforms.RandomHorizontalFlip(p=1))
        self.trans = transforms.Compose(self.trans)
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, index):
        assert(isinstance(self.Q, int))
        resize = transforms.Resize(max(64, self.Q))
            
        return (self.trans(resize(self.dataset[index][0])),
                self.dataset[index][1])

In [6]:
class VGG_block(nn.Module):
    def __init__(self, in_channels, out_channels, num_3x3, conv_1x1=False, batch_norm=False):
        super().__init__()
        layers = []
        layers += [nn.Conv2d(in_channels, out_channels,
                             kernel_size=3, stride=1, padding=1),
                   nn.BatchNorm2d(out_channels),
                   nn.ReLU(inplace=True)] if batch_norm else \
                  [nn.Conv2d(in_channels, out_channels,
                             kernel_size=3, stride=1, padding=1),
                   nn.ReLU(inplace=True)]
        if num_3x3 > 1:
            for i in range(1, num_3x3):
                layers += [nn.Conv2d(out_channels, out_channels,
                                     kernel_size=3, stride=1, padding=1),
                           nn.BatchNorm2d(out_channels),
                           nn.ReLU(inplace=True)] if batch_norm else \
                          [nn.Conv2d(out_channels, out_channels,
                                     kernel_size=3, stride=1, padding=1),
                           nn.ReLU(inplace=True)]
        if conv_1x1:
            layers += [nn.Conv2d(out_channels, out_channels,
                                 kernel_size=1, stride=1, padding=0),
                       nn.BatchNorm2d(out_channels),
                       nn.ReLU(inplace=True)] if batch_norm else \
                      [nn.Conv2d(out_channels, out_channels,
                                 kernel_size=1, stride=1, padding=0),
                       nn.ReLU(inplace=True)]
        layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        
        self.block = nn.Sequential(*layers)
    
    def forward(self, X):
        return self.block(X)

In [7]:
class VGG_mini(nn.Module):
    def __init__(self, configuration, batch_norm=False, dropout=0.5):
        super().__init__()
        self.configurations = {
            'A': [[3,   8,  1, False],
                  [8,   16, 1, False],
                  [16,  32, 2, False],
                  [32,  64, 2, False],
                  [64,  64, 2, False]],
            
            'B': [[3,   8,  2, False],
                  [8,   16, 2, False],
                  [16,  32, 2, False],
                  [32,  64, 2, False],
                  [64,  64, 2, False]],
            
            'C': [[3,   8,  2, False],
                  [8,   16, 2, False],
                  [16,  32, 2, True],
                  [32,  64, 2, True],
                  [64,  64, 2, True]],
            
            'D': [[3,   8,  2, False],
                  [8,   16, 2, False],
                  [16,  32, 3, False],
                  [32,  64, 3, False],
                  [64,  64, 3, False]],
            
            'E': [[3,   8,  2, False],
                  [8,   16, 2, False],
                  [16,  32, 4, False],
                  [32,  64, 4, False],
                  [64,  64, 4, False]]
        }
        self.configuration = configuration
        self.batch_norm = batch_norm
        self.blocks = []
        for arg_list in self.configurations[self.configuration]:
            self.blocks.append(VGG_block(*arg_list, self.batch_norm))
        self.blocks = nn.Sequential(*self.blocks)
        
        self.FC = nn.Sequential(
            nn.Conv2d(64, 512, kernel_size=2),
            nn.Dropout2d(p=dropout, inplace=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=1),
            nn.Dropout2d(p=dropout, inplace=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 10, kernel_size=1),
            # 把后两个空间维度合并成一个
            nn.Flatten(start_dim=2, end_dim=-1)
        )

    def print_num_params(self):
        total_params = sum(p.numel() for p in self.parameters())
        print(f'{total_params:,} total parameters.')
        total_trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad)
        print(f'{total_trainable_params:,} trainable parameters.')
        
    def forward(self, X):
        output = self.blocks(X)
        output = self.FC(output)
        # 这里为了实现简单，把空间平均的操作放在softmax前面了
        # 输出维度为 (`batch_size`, 10)
        return output.mean(dim=2)

In [8]:
def single_scale_eval_SQ(S):
    if isinstance(S, int):
        Q = S
    elif isinstance(S, list):
        assert(len(S) == 2)
        Q = int(0.5*(S[0] + S[1]))
    return S, Q

def multi_scale_eval_SQ(S):
    if isinstance(S, int):
        Q = [S-8, S, S+8]
    elif isinstance(S, list):
        assert(len(S) == 2)
        Q = [S[0], int(0.5*(S[0] + S[1])), S[1]]
    return S, Q

In [9]:
def get_S_and_Q(S, single_scale_eval):
    if single_scale_eval:
        return single_scale_eval_SQ(S)
    else:
        return multi_scale_eval_SQ(S)

In [10]:
class Evaluater:
    def __init__(self, S, batch_size, mode='single'):
        self.single_scale = mode=='single'
        self.S, self.Q = get_S_and_Q(S, self.single_scale)
        if self.single_scale:
            self.datasets = [TestDataset(self.Q, False), TestDataset(self.Q, True)]
        else:
            self.datasets = []
            for q in self.Q:
                self.datasets += [TestDataset(q, False), TestDataset(q, True)]
        self.dataloaders = [data.DataLoader(dataset,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=8) for dataset in self.datasets]
    def evaluate(self, net, criterion):
        net.eval()
        loss, accuracy = [], []
        outputs = {}
        with torch.no_grad():
            for dataloader in self.dataloaders:
                for i, (input, _) in enumerate(dataloader):
                    input = input.to(device)
                    output = net(input)
                    try:
                        outputs[i] += F.softmax(output, dim=1)
                    except KeyError:
                        outputs[i] = F.softmax(output, dim=1)
            for i, (_, target) in enumerate(self.dataloaders[0]):
                target = target.to(device)
                loss.append(criterion(outputs[i] / len(self.datasets), target))
                accuracy.append((outputs[i].argmax(dim=1)==target).sum() / target.shape[0])
        loss = torch.tensor(loss).mean().item()
        accuracy = torch.tensor(accuracy).mean().item()
        return loss, accuracy

In [11]:
def get_lr(optimizer):
    return (optimizer.state_dict()['param_groups'][0]['lr'])

In [12]:
def train_VGG(net,
              batch_size,
              num_epochs,
              lr,
              evaluater,
              S=[64, 128],
              weight_decay=5e-4):

    writer = SummaryWriter(f'runs/VGG-mini-{net.configuration}'+'-batchnorm' if net.batch_norm else '')
    cifar_train = TrainDataset(S)
    train_iter = data.DataLoader(cifar_train, batch_size=batch_size,
                                 shuffle=True, num_workers=8)
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.normal_(m.weight, mean=0, std=0.1)
    net.apply(init_weights)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           factor=0.1,
                                                           patience=3,
                                                           threshold=1e-3,
                                                           verbose=True)
    criterion = nn.CrossEntropyLoss()
    timer, num_batches = d2l.Timer(), len(train_iter)
    for epoch in range(num_epochs):
        tic = time.time()
        # 训练损失之和，训练准确率之和，范例数
        metric = d2l.Accumulator(3)
        net.train()
        for i, (input, target) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            input, target = input.to(device), target.to(device)
            output = net(input)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(loss * input.shape[0],
                           d2l.accuracy(output, target),
                           input.shape[0])
            timer.stop()
            train_loss = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            
            # if (i+1) % (num_batches//20) == 0:
            #     print(f"loss: {train_loss:.4f}, acc: {train_acc:.4f}")
        valid_loss, valid_acc = evaluater.evaluate(net, criterion)
        writer.add_scalar('train/loss', train_loss, global_step=epoch+1)
        writer.add_scalar('train/accuracy', train_acc, global_step=epoch+1)
        writer.add_scalar('valid/loss', valid_loss, global_step=epoch+1)
        writer.add_scalar('valid/accuracy', valid_acc, global_step=epoch+1)
        writer.add_scalar('learning rate', get_lr(optimizer), global_step=epoch+1)
        scheduler.step(valid_loss)
        toc = time.time()
        print(f"epoch {epoch+1:2d}, train loss: {train_loss:.4f}, train accuracy: {train_acc:.4f}, \
valid loss: {valid_loss:.4f}, valid accuracy: {valid_acc:.4f}, time: {toc-tic:.4f}")
    valid_loss, valid_acc = evaluater.evaluate(net, criterion)
    print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, '
          f'valid loss {valid_loss:.3f}, valid acc {valid_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

In [13]:
net = VGG_mini(configuration='A',
               batch_norm=True,
               dropout=0.5).to(device)
net.print_num_params()
S = [64, 128]
mode = 'multi'
evaluater = Evaluater(S, mode=mode, batch_size=256)
net

544,618 total parameters.
544,618 trainable parameters.
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


VGG_mini(
  (blocks): Sequential(
    (0): VGG_block(
      (block): Sequential(
        (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (1): VGG_block(
      (block): Sequential(
        (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (2): VGG_block(
      (block): Sequential(
        (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): 

In [14]:
evaluater.evaluate(net, nn.CrossEntropyLoss())

(2.302577495574951, 0.10205078125)

In [15]:
train_VGG(net,
          batch_size=256,
          num_epochs=70,
          lr=1e-3,
          evaluater=evaluater,
          S=[64, 128],
          weight_decay=5e-4)

Files already downloaded and verified
epoch  1, train loss: 2.6515, train accuracy: 0.1854, valid loss: 2.2295, valid accuracy: 0.2988, time: 12.5906
epoch  2, train loss: 1.9511, train accuracy: 0.2614, valid loss: 2.1940, valid accuracy: 0.3642, time: 13.1923
epoch  3, train loss: 1.8434, train accuracy: 0.3063, valid loss: 2.1711, valid accuracy: 0.4008, time: 12.8884
epoch  4, train loss: 1.7576, train accuracy: 0.3432, valid loss: 2.1673, valid accuracy: 0.4279, time: 12.5945
epoch  5, train loss: 1.6900, train accuracy: 0.3739, valid loss: 2.1620, valid accuracy: 0.4605, time: 12.5882
epoch  6, train loss: 1.6283, train accuracy: 0.3983, valid loss: 2.1356, valid accuracy: 0.4918, time: 12.5929
epoch  7, train loss: 1.5745, train accuracy: 0.4230, valid loss: 2.1312, valid accuracy: 0.4603, time: 12.5385
epoch  8, train loss: 1.5124, train accuracy: 0.4536, valid loss: 2.1052, valid accuracy: 0.4648, time: 12.6961
epoch  9, train loss: 1.4602, train accuracy: 0.4735, valid loss: 

In [16]:
evaluater.evaluate(net, nn.CrossEntropyLoss())

(1.778253197669983, 0.8248046636581421)

In [17]:
torch.save(net, f'VGG-mini-{net.configuration}' + '-batchnorm.pth' if net.batch_norm else '.pth')