In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline
from d2l import torch as d2l
import random
import time
import pandas as pd
from PIL import Image
from modules import *
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Workaround found via Baidu: without it the dataset download does not work.
# WARNING: this swaps ssl's default HTTPS context factory for the *unverified* one,
# so code relying on that default no longer checks server certificates.
# NOTE(review): process-wide side effect — prefer fixing the local CA bundle; confirm
# whether this is still needed once the data is cached under ../data.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [3]:
# Load the CIFAR-10 training images (downloaded into ../data when missing).
cifar_train = torchvision.datasets.CIFAR10(
    root="../data",
    train=True,
    download=True,
)
print(cifar_train.data.shape)  # (50000, 32, 32, 3)

# Divide the raw pixel values by 255, then average over axis 0 (the image axis)
# to get one mean image; finally reorder its axes from (H, W, C) to (C, H, W).
cifardata = cifar_train.data / 255
mean_pic = torch.as_tensor(cifardata.mean(axis=0)).permute(2, 0, 1)
print(mean_pic.shape)

Files already downloaded and verified
(50000, 32, 32, 3)
torch.Size([3, 32, 32])


In [4]:
# Split the training set into 45,000 training and 5,000 validation samples.
# The generator is seeded so the same split is produced on every run.
split_sizes = [45000, 5000]
split_generator = torch.Generator().manual_seed(42)
train_and_valid = data.random_split(
    torchvision.datasets.CIFAR10(root="../data", train=True, download=True),
    split_sizes,
    generator=split_generator,
)

Files already downloaded and verified


In [5]:
class TrainDataset(data.Dataset):
    """Wrap an ``(image, label)`` dataset with mean subtraction and optional augmentation.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs; the
            image must be accepted by ``transforms.ToTensor``.
        aug: When True, a random 32x32 crop (after 4-pixel padding) and a random
            horizontal flip are applied after the mean subtraction.
    """

    def __init__(self, dataset, aug=True):
        super().__init__()
        self.dataset = dataset
        steps = [transforms.ToTensor(),
                 # Centre every image with the module-level mean image `mean_pic`.
                 transforms.Lambda(lambda pic: pic-mean_pic.to(pic.device))]
        if aug:
            steps += [transforms.RandomCrop(32, padding=4),
                      transforms.RandomHorizontalFlip(p=0.5)]
        # Cast last: the subtraction can promote the dtype when `mean_pic` is
        # wider than float32, and the model is fed float32 tensors.
        steps.append(transforms.ConvertImageDtype(torch.float))
        self.trans = transforms.Compose(steps)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        # Index the wrapped dataset once; indexing it separately for the image
        # and the label would load every sample twice.
        sample = self.dataset[index]
        return (self.trans(sample[0]), sample[1])

In [6]:
class TestDataset(data.Dataset):
    """Wrap an ``(image, label)`` dataset for evaluation: mean subtraction, no augmentation.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs; the
            image must be accepted by ``transforms.ToTensor``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        # Centre every image with the module-level mean image `mean_pic`.
        center = transforms.Lambda(lambda pic: pic-mean_pic.to(pic.device))
        # Cast last: the subtraction can promote the dtype when `mean_pic` is
        # wider than float32.
        self.trans = transforms.Compose([transforms.ToTensor(),
                                         center,
                                         transforms.ConvertImageDtype(torch.float)])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        # Index the wrapped dataset once; indexing it separately for the image
        # and the label would load every sample twice.
        sample = self.dataset[index]
        return (self.trans(sample[0]), sample[1])

In [7]:
# The training split is augmented; validation and test data are only mean-centred.
train_subset, valid_subset = train_and_valid
train_dataset = TrainDataset(train_subset, aug=True)
valid_dataset = TestDataset(valid_subset)
cifar_test = torchvision.datasets.CIFAR10(root="../data", train=False, download=True)
test_dataset = TestDataset(cifar_test)

Files already downloaded and verified


In [8]:
def evaluate_loss_acc(net, data_iter, criterion, device=device):
    """Compute the mean loss and the accuracy of a model over a dataset.

    Args:
        net: Model to evaluate. It is switched to eval mode and left in that mode.
        data_iter: Iterable yielding ``(input, target)`` batches.
        criterion: Callable ``criterion(output, target)`` returning a scalar tensor.
            It is assumed to return the *mean* loss of the batch (as
            ``nn.CrossEntropyLoss`` does by default), because each batch loss is
            re-weighted by the batch size.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``(loss, accuracy)``: the per-sample mean loss and the fraction of
        correct predictions.

    Raises:
        ValueError: If ``data_iter`` yields no samples.
    """
    net.eval()  # switch to evaluation mode
    # Running sums: weighted loss, correct predictions, number of samples.
    metric = d2l.Accumulator(3)
    with torch.no_grad():
        for inputs, targets in data_iter:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = net(inputs)
            num_samples = targets.numel()
            # Weight by batch size so a shorter final batch does not skew the
            # average (a plain mean of per-batch means would over-weight it).
            metric.add(criterion(outputs, targets).item() * num_samples,
                       d2l.accuracy(outputs, targets),
                       num_samples)
    if metric[2] == 0:
        raise ValueError("data_iter yielded no samples to evaluate")
    return metric[0] / metric[2], metric[1] / metric[2]

In [9]:
def get_lr(optimizer):
    """Return the current learning rate of the optimizer's first parameter group.

    Reads ``optimizer.param_groups`` directly instead of building a full
    ``state_dict()`` just to look up one scalar.
    """
    return optimizer.param_groups[0]['lr']

In [10]:
def train_DenseNet(net,
                   batch_size,
                   lr,
                   num_epochs,
                   weight_decay=1e-4):
    """Train ``net`` with SGD, validating and logging to TensorBoard after every epoch.

    Reads the module-level ``train_dataset``, ``valid_dataset`` and ``device``.
    Each call re-initialises the weights of all ``nn.Linear`` / ``nn.Conv2d``
    layers with Kaiming-normal values before training. Batches are moved to
    ``device``; ``net`` itself is not moved by this function.

    Args:
        net: Model to train. Its ``L``, ``k`` and ``theta`` attributes are used to
            name the TensorBoard run directory.
        batch_size: Batch size of the training and validation loaders.
        lr: Initial learning rate for SGD (momentum 0.9).
        num_epochs: Number of passes over the training set.
        weight_decay: Weight decay passed to the optimizer.
    """
    train_iter = data.DataLoader(train_dataset, batch_size=batch_size,
                                 shuffle=True, num_workers=8)
    valid_iter = data.DataLoader(valid_dataset, batch_size=batch_size,
                                 shuffle=False, num_workers=8)

    def init_weights(m):
        # Exact type match: subclasses of these layers are not re-initialised.
        if type(m) in (nn.Linear, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='relu')

    net.apply(init_weights)
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=lr,
                                weight_decay=weight_decay,
                                momentum=0.9)
    # The scheduler is stepped on the validation loss once per epoch.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, threshold=0.001, verbose=True)
    criterion = nn.CrossEntropyLoss()
    timer = d2l.Timer()
    writer = SummaryWriter(f'runs/DenseNet_CIFAR_L={net.L}_k={net.k}_theta={net.theta}')
    try:
        for epoch in range(num_epochs):
            tic = time.time()
            # Running sums: weighted training loss, correct predictions, samples.
            metric = d2l.Accumulator(3)
            net.train()
            for inputs, targets in train_iter:
                timer.start()
                optimizer.zero_grad()
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                with torch.no_grad():
                    metric.add(loss * inputs.shape[0],
                               d2l.accuracy(outputs, targets),
                               inputs.shape[0])
                timer.stop()
            # Epoch averages are only needed once all batches are accumulated.
            train_loss = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
            valid_loss, valid_acc = evaluate_loss_acc(net, valid_iter, criterion, device)
            step = epoch + 1
            writer.add_scalar('train/loss', train_loss, global_step=step)
            writer.add_scalar('train/accuracy', train_acc, global_step=step)
            writer.add_scalar('valid/loss', valid_loss, global_step=step)
            writer.add_scalar('valid/accuracy', valid_acc, global_step=step)
            writer.add_scalar('learning rate', get_lr(optimizer), global_step=step)
            scheduler.step(valid_loss)
            toc = time.time()
            print(f"epoch {epoch+1:3d}, train loss: {train_loss:.4f}, train accuracy: {train_acc:.4f}, "
                  f"valid loss: {valid_loss:.4f}, valid accuracy: {valid_acc:.4f}, time: {toc-tic:.4f}")
        # The summary needs at least one finished epoch to have values to report.
        if num_epochs > 0:
            print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, '
                  f'valid loss {valid_loss:.3f}, valid acc {valid_acc:.3f}')
            print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
                  f'on {str(device)}')
    finally:
        # Close the writer even when training is interrupted (e.g. by
        # KeyboardInterrupt) so already-logged scalars are written out.
        writer.close()

In [11]:
# Model configuration. NOTE(review): L, k and theta presumably follow the DenseNet
# paper (depth, growth rate, compression) — confirm against `modules`.
densenet_kwargs = dict(L=40, k=12, theta=0.5, block=Bottleneck, num_classes=10)
net = DenseNet(**densenet_kwargs).to(device)
net.print_num_params()

486,436 total parameters.
486,436 trainable parameters.


In [12]:
# Hyperparameters of this training run, gathered in one place.
train_kwargs = dict(batch_size=64, lr=0.1, num_epochs=50, weight_decay=1e-4)
train_DenseNet(net, **train_kwargs)

epoch   1, train loss: 1.6552, train accuracy: 0.3928, valid loss: 1.6033, valid accuracy: 0.4588, time: 52.6598
epoch   2, train loss: 1.1895, train accuracy: 0.5782, valid loss: 0.9570, valid accuracy: 0.6662, time: 51.3201
epoch   3, train loss: 0.9363, train accuracy: 0.6723, valid loss: 1.2877, valid accuracy: 0.5914, time: 52.8655
epoch   4, train loss: 0.7862, train accuracy: 0.7285, valid loss: 1.0148, valid accuracy: 0.6826, time: 52.0783
epoch   5, train loss: 0.6905, train accuracy: 0.7637, valid loss: 0.7683, valid accuracy: 0.7448, time: 50.9403
epoch   6, train loss: 0.6227, train accuracy: 0.7874, valid loss: 0.6483, valid accuracy: 0.7812, time: 50.7741
epoch   7, train loss: 0.5717, train accuracy: 0.8045, valid loss: 0.6400, valid accuracy: 0.7810, time: 51.9899
epoch   8, train loss: 0.5434, train accuracy: 0.8151, valid loss: 0.7675, valid accuracy: 0.7536, time: 51.9333
epoch   9, train loss: 0.5217, train accuracy: 0.8226, valid loss: 0.6384, valid accuracy: 0.789

KeyboardInterrupt: 

In [None]:
# Evaluate the trained network on the CIFAR-10 test set.
test_criterion = nn.CrossEntropyLoss()
test_iter = data.DataLoader(test_dataset,
                            batch_size=256,
                            shuffle=False,
                            num_workers=0)
test_loss, test_acc = evaluate_loss_acc(net, test_iter, test_criterion)
print(test_loss, test_acc)

In [None]:
# Save the weights; the file name records the model configuration and test accuracy.
checkpoint_path = f'DenseNet_CIFAR_L={net.L}_k={net.k}_theta={net.theta}_acc={test_acc}.pth'
torch.save(net.state_dict(), checkpoint_path)