In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline
from d2l import torch as d2l
import random
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
import time
from modules import *

In [2]:
# Workaround found online: without it the CIFAR-10 download fails here.
# SECURITY: this swaps the process-wide default HTTPS context for one that does
# NOT verify certificates, so every later HTTPS connection in this kernel that
# relies on Python's default context (e.g. urllib downloads) is unauthenticated.
# Set ALLOW_UNVERIFIED_HTTPS = False once the data is cached under ../data.
import ssl

ALLOW_UNVERIFIED_HTTPS = True
if ALLOW_UNVERIFIED_HTTPS:
    ssl._create_default_https_context = ssl._create_unverified_context

In [3]:
cifar_train = torchvision.datasets.CIFAR10(root="../data", train=True, download=True)
print(cifar_train.data.shape) # (50000, 32, 32, 3)
# Per-pixel mean training image in [0, 1], permuted from (H, W, C) to (C, H, W)
# to line up with the tensors produced by transforms.ToTensor().
# Averaging the uint8 array directly (numpy accumulates integer input in float64)
# avoids materialising a float64 copy of all 50,000 images (~1.2 GB). Casting to
# float32 keeps mean-subtracted images float32 instead of promoting them to float64.
# NOTE(review): this mean also covers the 5,000 images later held out for
# validation; compute it on the training subset only to avoid that leak.
mean_pic = torch.from_numpy(cifar_train.data.mean(axis=0) / 255).float().permute(2, 0, 1)
print(mean_pic.shape)

Files already downloaded and verified
(50000, 32, 32, 3)
torch.Size([3, 32, 32])


In [4]:
# Hold out a fixed, seeded validation split from the CIFAR-10 training images.
# Reuses `cifar_train` from the previous cell instead of loading the training
# set from disk a second time; the split sizes follow from the dataset length.
NUM_VALID = 5000
train_and_valid = data.random_split(cifar_train,
                                    [len(cifar_train) - NUM_VALID, NUM_VALID],
                                    generator=torch.Generator().manual_seed(42))

Files already downloaded and verified


In [5]:
class TrainDataset(data.Dataset):
    """Wrap a dataset of ``(image, label)`` pairs with training preprocessing.

    Each image is converted to a tensor, has the mean image subtracted, is
    randomly cropped to 32x32 after 4-pixel padding (the padding is applied
    after mean subtraction) and is horizontally flipped with probability 0.5.

    Args:
        dataset: indexable dataset yielding ``(image, label)`` pairs whose image
            ``transforms.ToTensor`` accepts (e.g. a PIL image).
        mean: mean image to subtract, shaped like the ``ToTensor`` output.
            Defaults to the notebook-global ``mean_pic``.
    """
    def __init__(self, dataset, mean=None):
        super().__init__()
        self.dataset = dataset
        # Resolve the mean once, as float32, so subtraction keeps images float32
        # (a float64 mean would promote every sample to float64).
        self.mean = torch.as_tensor(mean_pic if mean is None else mean, dtype=torch.float32)
        # A bound method (unlike a lambda) can be pickled, so this also works
        # with DataLoader worker processes.
        self.trans = transforms.Compose([transforms.ToTensor(),
                                         transforms.Lambda(self._subtract_mean),
                                         transforms.RandomCrop(32, padding=4),
                                         transforms.RandomHorizontalFlip(p=0.5),
                                         # Safety net: guarantees float32 output.
                                         transforms.ConvertImageDtype(torch.float)])

    def _subtract_mean(self, pic):
        """Return ``pic`` with the stored mean image subtracted."""
        return pic - self.mean

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # Look the sample up once: every lookup into the wrapped dataset can do
        # real work (Subset index mapping, image construction), and the previous
        # version did it twice per sample.
        image, label = self.dataset[index]
        return self.trans(image), label

In [6]:
class TestDataset(data.Dataset):
    """Wrap a dataset of ``(image, label)`` pairs with evaluation preprocessing.

    Each image is converted to a tensor and has the mean image subtracted; no
    random augmentation is applied, so results are deterministic.

    Args:
        dataset: indexable dataset yielding ``(image, label)`` pairs whose image
            ``transforms.ToTensor`` accepts (e.g. a PIL image).
        mean: mean image to subtract, shaped like the ``ToTensor`` output.
            Defaults to the notebook-global ``mean_pic``.
    """
    def __init__(self, dataset, mean=None):
        super().__init__()
        self.dataset = dataset
        # Resolve the mean once, as float32, so subtraction keeps images float32
        # (a float64 mean would promote every sample to float64).
        self.mean = torch.as_tensor(mean_pic if mean is None else mean, dtype=torch.float32)
        # A bound method (unlike a lambda) can be pickled, so this also works
        # with DataLoader worker processes.
        self.trans = transforms.Compose([transforms.ToTensor(),
                                         transforms.Lambda(self._subtract_mean),
                                         # Safety net: guarantees float32 output.
                                         transforms.ConvertImageDtype(torch.float)])

    def _subtract_mean(self, pic):
        """Return ``pic`` with the stored mean image subtracted."""
        return pic - self.mean

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # Look the sample up once: every lookup into the wrapped dataset can do
        # real work (Subset index mapping, image construction), and the previous
        # version did it twice per sample.
        image, label = self.dataset[index]
        return self.trans(image), label

In [7]:
# Augmented training subset, plus un-augmented validation and test sets.
train_subset, valid_subset = train_and_valid
train_dataset = TrainDataset(train_subset)
valid_dataset = TestDataset(valid_subset)
cifar_test = torchvision.datasets.CIFAR10(root="../data", train=False, download=True)
test_dataset = TestDataset(cifar_test)

Files already downloaded and verified


In [8]:
def evaluate_loss_acc(net, data_iter, criterion, device=device):
    """Compute a model's average loss and accuracy over a dataset.

    Puts ``net`` in eval mode (it is left in eval mode on return) and runs it
    over every batch of ``data_iter`` without tracking gradients.

    Args:
        net: model to evaluate; assumed to already live on ``device``.
        data_iter: iterable of ``(input, target)`` batches, e.g. a DataLoader.
        criterion: loss called as ``criterion(output, target)``; assumed to
            return the mean loss over the batch (the ``nn.CrossEntropyLoss``
            default reduction).
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(loss, accuracy)``: the per-sample average loss and the fraction of
        correctly classified samples.

    Raises:
        ValueError: if ``data_iter`` yields no samples.
    """
    net.eval()
    # Running sums: batch-size-weighted loss, correct predictions, samples seen.
    metric = d2l.Accumulator(3)
    with torch.no_grad():
        for inputs, targets in data_iter:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = net(inputs)
            num_samples = targets.numel()
            # Weight each batch's mean loss by its size so a short final batch
            # does not skew the average (same convention as train_ResNet).
            metric.add(criterion(outputs, targets).item() * num_samples,
                       d2l.accuracy(outputs, targets),
                       num_samples)
    if metric[2] == 0:
        raise ValueError("data_iter yielded no samples")
    return metric[0] / metric[2], metric[1] / metric[2]

In [9]:
def get_lr(optimizer):
    """Return the current learning rate of the optimizer's first parameter group."""
    first_group = optimizer.param_groups[0]
    return first_group['lr']

In [10]:
def _train_one_epoch(net, train_iter, optimizer, criterion, timer):
    """Run one optimisation pass; return (avg loss, accuracy, #samples seen)."""
    net.train()
    # Running sums: batch-size-weighted loss, correct predictions, samples seen.
    metric = d2l.Accumulator(3)
    for inputs, targets in train_iter:
        timer.start()
        optimizer.zero_grad()
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            metric.add(loss * inputs.shape[0],
                       d2l.accuracy(outputs, targets),
                       inputs.shape[0])
        timer.stop()
    return metric[0] / metric[2], metric[1] / metric[2], metric[2]


def train_ResNet(net,
                 batch_size,
                 lr,
                 num_epochs,
                 weight_decay=1e-4):
    """Train ``net`` on ``train_dataset`` and validate on ``valid_dataset`` each epoch.

    Re-initialises every ``nn.Linear``/``nn.Conv2d`` weight with Kaiming-normal
    init, trains with SGD (momentum 0.9), and lowers the learning rate with
    ``ReduceLROnPlateau`` driven by the validation loss. Per-epoch train/valid
    loss and accuracy and the learning rate are printed and logged to
    TensorBoard under ``runs/``.

    Args:
        net: model to train. It must expose ``n``, ``option`` and ``batch_norm``
            attributes (used in the TensorBoard run name) and is assumed to
            already be on ``device``.
        batch_size: mini-batch size for both the training and validation loaders.
        lr: initial SGD learning rate.
        num_epochs: number of passes over the training set; must be >= 1.
        weight_decay: L2 penalty passed to SGD.

    Raises:
        ValueError: if ``num_epochs`` < 1.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    train_iter = data.DataLoader(train_dataset, batch_size=batch_size,
                                 shuffle=True, num_workers=0)
    valid_iter = data.DataLoader(valid_dataset, batch_size=batch_size,
                                 shuffle=False, num_workers=0)

    def init_weights(m):
        # Kaiming (He) init suited to ReLU; biases keep their default init.
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='relu')
    net.apply(init_weights)

    optimizer = torch.optim.SGD(net.parameters(),
                                lr=lr,
                                weight_decay=weight_decay,
                                momentum=0.9)
    # Alternative tried: CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2),
    # stepped with scheduler.step() (no argument) once per epoch.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=20, threshold=0.0001, verbose=True)
    scheduler_name = type(scheduler).__name__
    writer = SummaryWriter(f'runs/ResNet_CIFAR_n={net.n}_{net.option}_bn={net.batch_norm}_{scheduler_name}_weight_decay={weight_decay}')
    criterion = nn.CrossEntropyLoss()
    timer = d2l.Timer()
    try:
        for epoch in range(num_epochs):
            tic = time.time()
            train_loss, train_acc, num_samples = _train_one_epoch(
                net, train_iter, optimizer, criterion, timer)
            valid_loss, valid_acc = evaluate_loss_acc(net, valid_iter, criterion, device)
            step = epoch + 1
            writer.add_scalar('train/loss', train_loss, global_step=step)
            writer.add_scalar('train/accuracy', train_acc, global_step=step)
            writer.add_scalar('valid/loss', valid_loss, global_step=step)
            writer.add_scalar('valid/accuracy', valid_acc, global_step=step)
            # Logged before the scheduler step, i.e. the LR used during this epoch.
            writer.add_scalar('learning rate', get_lr(optimizer), global_step=step)
            scheduler.step(valid_loss)
            toc = time.time()
            print(f"epoch {epoch+1:3d}, train loss: {train_loss:.4f}, train accuracy: {train_acc:.4f}, "
                  f"valid loss: {valid_loss:.4f}, valid accuracy: {valid_acc:.4f}, time: {toc-tic:.4f}")
    finally:
        # Flush pending events and release the log file even if training is
        # interrupted (e.g. a KeyboardInterrupt in the notebook).
        writer.close()
    print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, '
          f'valid loss {valid_loss:.3f}, valid acc {valid_acc:.3f}')
    # Assumes every epoch sees the same number of samples.
    print(f'{num_samples * num_epochs / timer.sum():.1f} examples/sec '
          f'on {str(device)}')

In [11]:
# Seed the global RNGs so weight initialisation, DataLoader shuffling, random
# augmentations and dropout repeat between runs (the train/valid split uses its
# own seeded generator). Some GPU kernels may still be nondeterministic.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
net = ResNet_CIFAR(n=18, option='A', batch_norm=True, dropout=0.5).to(device)
net.print_num_params()

1,727,962 total parameters.
1,727,962 trainable parameters.


In [None]:
# Hyper-parameters for this run; lr is the initial SGD learning rate.
train_hparams = dict(batch_size=256, lr=0.1, num_epochs=300, weight_decay=5e-4)
train_ResNet(net, **train_hparams)

epoch   1, train loss: 2.8001, train accuracy: 0.1223, valid loss: 2.2500, valid accuracy: 0.1494, time: 61.6174
epoch   2, train loss: 2.0884, train accuracy: 0.1987, valid loss: 1.8948, valid accuracy: 0.2668, time: 61.0321
epoch   3, train loss: 1.8530, train accuracy: 0.2825, valid loss: 1.7821, valid accuracy: 0.3180, time: 60.9301
epoch   4, train loss: 1.7036, train accuracy: 0.3550, valid loss: 1.5742, valid accuracy: 0.4108, time: 61.1335
epoch   5, train loss: 1.5688, train accuracy: 0.4148, valid loss: 1.4644, valid accuracy: 0.4670, time: 61.1319
epoch   6, train loss: 1.4149, train accuracy: 0.4835, valid loss: 1.3594, valid accuracy: 0.4880, time: 61.2271
epoch   7, train loss: 1.2613, train accuracy: 0.5469, valid loss: 1.5640, valid accuracy: 0.4814, time: 61.2310
epoch   8, train loss: 1.1194, train accuracy: 0.6059, valid loss: 1.6668, valid accuracy: 0.5138, time: 61.1517
epoch   9, train loss: 1.0075, train accuracy: 0.6513, valid loss: 1.3836, valid accuracy: 0.572

In [None]:
# Final evaluation on the held-out CIFAR-10 test set.
test_criterion = nn.CrossEntropyLoss()
test_iter = data.DataLoader(test_dataset,
                            batch_size=256,
                            shuffle=False,
                            num_workers=0)
test_loss, test_acc = evaluate_loss_acc(net, test_iter, test_criterion)
print(test_loss, test_acc)

In [None]:
# Save only the weights; the filename records the net's `n` and `option`
# attributes and the measured test accuracy.
checkpoint_path = f'ResNet_CIFAR_n={net.n}_{net.option}_acc={test_acc}.pth'
torch.save(net.state_dict(), checkpoint_path)