In [1]:
import torch
import torch.nn as nn
from torch.utils import data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline
from d2l import torch as d2l
import random
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Workaround found via a web search (Baidu); without it the dataset download
# does not go through.
# NOTE(review): this swaps ssl's default HTTPS context factory for the
# unverified one, so certificate verification is disabled for every HTTPS
# request in this process that relies on the default context.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [3]:
class TrainDataset(data.Dataset):
    """CIFAR-10 training split with scale jittering and light augmentation.

    Every sample is resized with ``transforms.Resize(scale)``, randomly
    cropped to ``CROP_SIZE`` x ``CROP_SIZE``, randomly flipped horizontally,
    colour-jittered and converted to a tensor.

    Args:
        S: Training scale. Either an ``int`` (fixed scale) or a two-element
            ``list``/``tuple`` ``[S_min, S_max]``; in the latter case an
            integer scale is drawn with ``random.randint`` (both ends
            inclusive) independently for every sample.  Any scale smaller
            than ``CROP_SIZE`` is raised to ``CROP_SIZE`` so the random crop
            always fits inside the resized image.
    """

    # Side length of the random crop produced for the network.
    CROP_SIZE = 224

    def __init__(self, S):
        super().__init__()
        self.dataset = torchvision.datasets.CIFAR10(
            root="../data", train=True, download=True)

        self.S = S

        self.trans = transforms.Compose([
            transforms.RandomCrop(self.CROP_SIZE),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1,
                                   contrast=0.1,
                                   saturation=0.1,
                                   hue=0),
            transforms.ToTensor()])

    def __len__(self):
        """Number of training images."""
        return len(self.dataset)

    def _sample_scale(self):
        """Return the resize target for one sample (never below ``CROP_SIZE``).

        Raises:
            ValueError: if ``S`` is a sequence whose length is not 2.
            TypeError: if ``S`` is neither an int nor a list/tuple.
        """
        S = self.S
        if isinstance(S, int):
            scale = S
        elif isinstance(S, (list, tuple)):
            if len(S) != 2:
                raise ValueError(
                    f'S must contain exactly two bounds, got {len(S)}')
            scale = random.randint(S[0], S[1])
        else:
            raise TypeError(
                f'S must be an int or a two-element list/tuple, '
                f'got {type(S).__name__}')
        # The fixed-scale case was already clamped; clamp the sampled scale
        # as well, otherwise the crop would be larger than the image.
        return max(self.CROP_SIZE, scale)

    def __getitem__(self, index):
        """Return ``(augmented image tensor, label)`` for ``index``."""
        # Index the underlying dataset once instead of once per tuple field.
        image, label = self.dataset[index]
        resize = transforms.Resize(self._sample_scale())
        return self.trans(resize(image)), label

In [4]:
class TestDataset(data.Dataset):
    """CIFAR-10 test split resized to a fixed test scale.

    Args:
        Q: Test scale passed to ``transforms.Resize``.  An ``int``, or a
            ``float`` such as the midpoint ``0.5 * (S_min + S_max)``; floats
            are truncated to ``int`` because a resize target is a pixel
            count.  Values below ``MIN_SIZE`` are raised to ``MIN_SIZE``.
        horizontal_flip: If True, every image is flipped horizontally
            (``RandomHorizontalFlip(p=1)``) before conversion to a tensor.
    """

    # Smallest resize target that is ever used.
    MIN_SIZE = 224

    def __init__(self, Q, horizontal_flip=False):
        super().__init__()
        self.dataset = torchvision.datasets.CIFAR10(
            root="../data", train=False, download=True)

        self.Q = Q

        steps = [transforms.ToTensor()]
        if horizontal_flip:
            # p=1 makes the "random" flip deterministic.
            steps.insert(0, transforms.RandomHorizontalFlip(p=1))
        self.trans = transforms.Compose(steps)

    def __len__(self):
        """Number of test images."""
        return len(self.dataset)

    def _resize_size(self):
        """Return the integer resize target derived from ``Q``.

        Raises:
            TypeError: if ``Q`` is not an int or float.
        """
        Q = self.Q
        if not isinstance(Q, (int, float)):
            raise TypeError(
                f'Q must be an int or float, got {type(Q).__name__}')
        return max(self.MIN_SIZE, int(Q))

    def __getitem__(self, index):
        """Return ``(image tensor, label)`` for ``index``."""
        # Index the underlying dataset once instead of once per tuple field.
        image, label = self.dataset[index]
        resize = transforms.Resize(self._resize_size())
        return self.trans(resize(image)), label

In [5]:
class VGG_block(nn.Module):
    """One VGG stage: 3x3 conv + ReLU pairs, optional 1x1 conv + ReLU, 2x2 max-pool.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by every convolution of the stage.
        num_3x3: Number of 3x3 convolutions; one is always created, even
            for values below 1.
        conv_1x1: If True, a 1x1 convolution (+ ReLU) follows the 3x3 stack.
    """

    def __init__(self, in_channels, out_channels, num_3x3, conv_1x1=False):
        super().__init__()

        def conv_relu(c_in, kernel, pad):
            # Every convolution in the stage has stride 1 and the same width.
            return [nn.Conv2d(c_in, out_channels,
                              kernel_size=kernel, stride=1, padding=pad),
                    nn.ReLU(inplace=True)]

        # The first convolution changes the channel count ...
        stage = conv_relu(in_channels, 3, 1)
        # ... the remaining ones keep it.
        extra_convs = range(1, num_3x3) if num_3x3 > 1 else ()
        for _ in extra_convs:
            stage.extend(conv_relu(out_channels, 3, 1))
        if conv_1x1:
            stage.extend(conv_relu(out_channels, 1, 0))
        # Halve the spatial resolution at the end of the stage.
        stage.append(nn.MaxPool2d(kernel_size=2, stride=2))

        self.block = nn.Sequential(*stage)

    def forward(self, X):
        """Apply the stage to ``X``."""
        return self.block(X)

In [6]:
class VGG(nn.Module):
    """VGG-style network whose classifier is built from convolutions.

    The feature extractor is five ``VGG_block`` stages chosen by
    ``configuration``.  The classifier is a 7x7 convolution followed by two
    1x1 convolutions, so the network accepts any input that is at least
    224 pixels on each side: larger inputs yield a spatial map of class
    scores, which ``forward`` averages.

    Args:
        configuration: Key into ``configurations`` ('A', 'B', 'C', 'D' or
            'E'); an unknown key raises ``KeyError``.
        num_classes: Number of output classes (defaults to 10).
    """

    # Stage specifications per configuration; every row is the argument list
    # of one VGG_block: [in_channels, out_channels, num_3x3, conv_1x1].
    configurations = {
        'A': [[3,   64,  1, False],
              [64,  128, 1, False],
              [128, 256, 2, False],
              [256, 512, 2, False],
              [512, 512, 2, False]],

        'B': [[3,   64,  2, False],
              [64,  128, 2, False],
              [128, 256, 2, False],
              [256, 512, 2, False],
              [512, 512, 2, False]],

        'C': [[3,   64,  2, False],
              [64,  128, 2, False],
              [128, 256, 2, True],
              [256, 512, 2, True],
              [512, 512, 2, True]],

        'D': [[3,   64,  2, False],
              [64,  128, 2, False],
              [128, 256, 3, False],
              [256, 512, 3, False],
              [512, 512, 3, False]],

        'E': [[3,   64,  2, False],
              [64,  128, 2, False],
              [128, 256, 4, False],
              [256, 512, 4, False],
              [512, 512, 4, False]]
    }

    def __init__(self, configuration, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        self.blocks = nn.Sequential(
            *(VGG_block(*stage_args)
              for stage_args in self.configurations[configuration]))

        self.FC = nn.Sequential(
            nn.Conv2d(512, 4096, kernel_size=7),
            nn.Dropout2d(p=0.5, inplace=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(4096, 4096, kernel_size=1),
            nn.Dropout2d(p=0.5, inplace=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(4096, num_classes, kernel_size=1),
            # Merge the two trailing spatial dimensions into a single one.
            nn.Flatten(start_dim=2, end_dim=-1)
        )

    def print_num_params(self):
        """Print the total and the trainable parameter counts."""
        total_params = sum(p.numel() for p in self.parameters())
        print(f'{total_params:,} total parameters.')
        total_trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad)
        print(f'{total_trainable_params:,} trainable parameters.')

    def forward(self, X):
        """Return class scores of shape ``(batch_size, num_classes)``."""
        output = self.blocks(X)
        output = self.FC(output)
        # For simplicity the spatial averaging is applied to the raw scores,
        # i.e. before any softmax.
        return output.mean(dim=2)

In [7]:
def single_scale_eval_SQ(S):
    """Return ``(S, Q)`` for single-scale evaluation.

    Args:
        S: Training scale: an ``int`` (fixed scale) or a two-element
            ``list``/``tuple`` ``[S_min, S_max]``.

    Returns:
        ``(S, Q)`` where ``S`` is passed through unchanged and ``Q`` is the
        test scale: ``S`` itself for a fixed scale, otherwise the midpoint
        of the range.  The midpoint uses floor division so that integer
        bounds give an integer ``Q``, which the test dataset requires.

    Raises:
        ValueError: if ``S`` is a sequence whose length is not 2.
        TypeError: if ``S`` is neither an int nor a list/tuple.
    """
    if isinstance(S, int):
        return S, S
    if isinstance(S, (list, tuple)):
        if len(S) != 2:
            raise ValueError(
                f'S must contain exactly two bounds, got {len(S)}')
        return S, (S[0] + S[1]) // 2
    raise TypeError(
        f'S must be an int or a two-element list/tuple, '
        f'got {type(S).__name__}')

def multi_scale_eval_SQ(S):
    """Return ``(S, Q)`` for multi-scale evaluation.

    Args:
        S: Training scale: an ``int`` (fixed scale) or a two-element
            ``list``/``tuple`` ``[S_min, S_max]``.

    Returns:
        ``(S, Q)`` where ``S`` is passed through unchanged and ``Q`` is a
        list of three test scales: ``[S - 32, S, S + 32]`` for a fixed
        scale, otherwise ``[S_min, midpoint, S_max]``.  The midpoint uses
        floor division so that integer bounds give an integer scale, which
        the test dataset requires.

    Raises:
        ValueError: if ``S`` is a sequence whose length is not 2.
        TypeError: if ``S`` is neither an int nor a list/tuple.
    """
    if isinstance(S, int):
        return S, [S - 32, S, S + 32]
    if isinstance(S, (list, tuple)):
        if len(S) != 2:
            raise ValueError(
                f'S must contain exactly two bounds, got {len(S)}')
        low, high = S[0], S[1]
        return S, [low, (low + high) // 2, high]
    raise TypeError(
        f'S must be an int or a two-element list/tuple, '
        f'got {type(S).__name__}')

In [8]:
def get_S_and_Q(S, single_scale_eval):
    """Return ``(S, Q)`` from the single- or multi-scale helper.

    Args:
        S: Training scale, forwarded unchanged to the chosen helper.
        single_scale_eval: Truthy selects ``single_scale_eval_SQ``,
            falsy selects ``multi_scale_eval_SQ``.
    """
    resolve = single_scale_eval_SQ if single_scale_eval else multi_scale_eval_SQ
    return resolve(S)

In [9]:
class Evaluater:
    """Evaluates a network on the test set, averaging over several views.

    For every test scale two datasets are created, one plain and one
    horizontally flipped.  The network outputs of all views are averaged
    before the loss is computed.

    Args:
        S: Training scale (int or ``[S_min, S_max]``); the test scale(s)
            ``Q`` are derived from it with ``get_S_and_Q``.
        batch_size: Batch size of every test data loader.
        mode: ``'single'`` selects single-scale evaluation; any other value
            selects multi-scale evaluation.
    """

    def __init__(self, S, batch_size, mode='single'):
        self.single_scale = (mode == 'single')
        self.S, self.Q = get_S_and_Q(S, self.single_scale)
        # Single-scale mode yields one scale, multi-scale mode a list.
        test_scales = [self.Q] if self.single_scale else self.Q
        self.datasets = []
        for q in test_scales:
            self.datasets += [TestDataset(q, False), TestDataset(q, True)]
        # shuffle=False keeps batch i of every loader on the same images,
        # which evaluate() relies on when it sums outputs across views.
        self.dataloaders = [data.DataLoader(dataset,
                                            batch_size=batch_size,
                                            shuffle=False)
                            for dataset in self.datasets]

    def evaluate(self, net, criterion):
        """Return ``(loss, accuracy)`` of ``net`` as Python floats.

        Both values are averaged per example (each batch is weighted by its
        size), assuming ``criterion`` returns the mean loss of a batch.
        ``net`` is left in evaluation mode.  Returns ``(nan, nan)`` when the
        loaders yield no examples.

        Args:
            net: Module mapping an input batch to per-class scores.
            criterion: Callable ``criterion(scores, target)`` returning a
                scalar tensor.
        """
        net.eval()
        # Run on the device the model lives on; fall back to the notebook's
        # global `device` only for a model without parameters.
        first_param = next(net.parameters(), None)
        run_device = device if first_param is None else first_param.device

        summed_outputs = []  # per batch: scores summed over all views
        targets = []         # per batch: labels, taken from the first view
        with torch.no_grad():
            for view_idx, dataloader in enumerate(self.dataloaders):
                for i, (input, target) in enumerate(dataloader):
                    output = net(input.to(run_device))
                    if view_idx == 0:
                        summed_outputs.append(output)
                        # Keep the labels now rather than iterating (and
                        # re-transforming) the first loader a second time.
                        targets.append(target.to(run_device))
                    else:
                        summed_outputs[i] += output

            num_views = len(self.dataloaders)
            total_loss, total_correct, num_examples = 0.0, 0, 0
            for scores, target in zip(summed_outputs, targets):
                batch_examples = target.shape[0]
                batch_loss = criterion(scores / num_views, target)
                total_loss += batch_loss.item() * batch_examples
                # argmax is unaffected by the division by num_views.
                total_correct += (scores.argmax(dim=1) == target).sum().item()
                num_examples += batch_examples

        if num_examples == 0:
            return float('nan'), float('nan')
        return total_loss / num_examples, total_correct / num_examples

In [10]:
def train_VGG(net,
              batch_size,
              num_epochs,
              lr,
              evaluater,
              S=256,
              weight_decay=5e-4):
    """Train ``net`` on the CIFAR-10 training set and plot the progress.

    The weights of all ``nn.Linear`` / ``nn.Conv2d`` layers are
    re-initialised from N(0, 0.1^2) before training.  Optimisation uses Adam
    with ``ReduceLROnPlateau``.  The scheduler is stepped on the validation
    loss, which is computed every fifth epoch.

    Args:
        net: Model to train; it must already live on the global ``device``.
        batch_size: Training batch size.
        num_epochs: Number of passes over the training set.
        lr: Initial learning rate.
        evaluater: Object whose ``evaluate(net, criterion)`` returns
            ``(valid_loss, valid_acc)``.
        S: Training scale forwarded to ``TrainDataset``.
        weight_decay: Weight decay passed to the optimizer.
    """
    cifar_train = TrainDataset(S)
    train_iter = data.DataLoader(cifar_train, batch_size=batch_size,
                                 shuffle=True)

    def init_weights(m):
        # Exact type match on purpose: subclasses keep their own init.
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.normal_(m.weight, mean=0, std=0.1)
    net.apply(init_weights)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           factor=0.1,
                                                           patience=1,
                                                           threshold=1e-2,
                                                           verbose=True)
    criterion = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc',
                                    'valid loss', 'valid acc'],
                            figsize=(7.5, 5.5))
    timer, num_batches = d2l.Timer(), len(train_iter)
    # Plot the running training metrics about five times per epoch.  max()
    # keeps the modulo below from dividing by zero when an epoch has fewer
    # than five batches.
    plot_every = max(1, num_batches // 5)
    # Placeholders so the final report still works when a metric was never
    # computed (e.g. no validation ran because num_epochs < 5).
    train_loss = train_acc = valid_loss = valid_acc = float('nan')
    for epoch in range(num_epochs):
        # Sum of training loss, sum of correct predictions, number of examples.
        metric = d2l.Accumulator(3)
        net.train()
        for i, (input, target) in enumerate(train_iter):
            timer.start()
            optimizer.zero_grad()
            input, target = input.to(device), target.to(device)
            output = net(input)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            with torch.no_grad():
                # The criterion returns a batch mean; scale it back to a sum.
                metric.add(loss * input.shape[0],
                           d2l.accuracy(output, target),
                           input.shape[0])
            timer.stop()
            train_loss = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_loss, train_acc, None, None))
        if (epoch+1) % 5 == 0:
            valid_loss, valid_acc = evaluater.evaluate(net, criterion)
            scheduler.step(valid_loss)
            animator.add(epoch + 1, (None, None, valid_loss, valid_acc))
    print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, '
          f'valid loss {valid_loss:.3f}, valid acc {valid_acc:.3f}')
    elapsed = timer.sum()
    # Throughput is only defined once at least one batch has been timed.
    if num_epochs > 0 and elapsed > 0:
        print(f'{metric[2] * num_epochs / elapsed:.1f} examples/sec '
              f'on {str(device)}')

In [11]:
# Build the network with configuration 'A' and move it to the selected device.
net = VGG(configuration='A').to(device)
# Training scale; the evaluater derives its test scale(s) from it.
S = 256
# 'single' selects single-scale evaluation; any other value selects multi-scale.
mode = 'single'
evaluater = Evaluater(S, mode=mode, batch_size=256)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Train with the same scale `S` the evaluater was built from, so the training
# and test scales cannot drift apart when `S` is edited in the cell above.
train_VGG(net,
          batch_size=128,
          num_epochs=5,
          lr=1e-2,
          evaluater=evaluater,
          S=S,
          weight_decay=5e-4)

KeyboardInterrupt: 