In [1]:
%matplotlib inline
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, Dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10

In [2]:
# a special module that converts [batch, channel, w, h] to [batch, units]
class Flatten(nn.Module):
    """Collapse every non-batch dimension: ``[batch, ...] -> [batch, prod(...)]``."""

    def forward(self, input):
        # ``reshape`` gives the same result as ``view`` on contiguous tensors but,
        # unlike ``view``, also accepts non-contiguous ones (e.g. after permute /
        # transpose), copying only when a view is impossible.
        return input.reshape(input.size(0), -1)

In [3]:
# Pick the compute device once for the whole notebook. On CPU-only hosts the
# plain string 'cpu' is kept (not torch.device) because later cells test
# ``device == 'cpu'`` to decide on DataLoader worker counts.
cuda_availability = torch.cuda.is_available()
device = (
    torch.device('cuda:{}'.format(torch.cuda.current_device()))
    if cuda_availability
    else 'cpu'
)

In [4]:
# Per-channel normalisation constants shared by both pipelines.
# NOTE(review): this std triple is a widely copied set that reportedly differs
# from the true per-pixel CIFAR-10 std (~0.247, 0.243, 0.261) -- TODO confirm.
# Presumably kept as-is because the saved checkpoints were trained with it.
means = np.array((0.4914, 0.4822, 0.4465))
stds = np.array((0.2023, 0.1994, 0.2010))


def _to_normalized_tensor():
    """Return fresh ``[ToTensor, Normalize(means, stds)]`` steps for a pipeline."""
    return [transforms.ToTensor(), transforms.Normalize(means, stds)]


# Training pipeline: random 32x32 crop of a 4-px padded image, a random rotation
# in [-30, 30] degrees and a random horizontal flip, then tensor + normalisation.
transform_augment = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomRotation([-30, 30]),
        transforms.RandomHorizontalFlip(),
    ]
    + _to_normalized_tensor()
)

# Evaluation pipeline: no augmentation, identical normalisation.
transform_test = transforms.Compose(_to_normalized_tensor())

In [5]:
class DatasetFromSubset(Dataset):
    """Wrap a subset so that its own transform is applied lazily on access.

    Subsets produced by ``random_split`` all index into one parent dataset, so
    they cannot carry different transforms themselves; this wrapper lets the
    train and validation parts use different pipelines.

    Based on https://discuss.pytorch.org/t/torch-utils-data-dataset-random-split/32209/3

    Args:
        subset: any indexable, sized collection of ``(sample, target)`` pairs.
        transform: optional callable applied to ``sample`` only; skipped when falsy.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        return (self.transform(sample) if self.transform else sample), target

In [6]:
# NOTE: despite their names, train_loader / val_loader / test_loader are Dataset
# objects; the actual DataLoaders are built inside train() and get_accuracy().
# download=True fetches CIFAR-10 on a fresh machine and skips the download when
# the files are already present and verified.
train_val_data = CIFAR10("./cifar_data/", train=True, download=True)

# Hold out 10% of the official training set for validation.
val_part = int(0.1 * len(train_val_data))
train_part = len(train_val_data) - val_part

# A seeded generator makes the train/val split identical across kernel restarts,
# so validation accuracy (and checkpoint selection) is reproducible.
train_subset, val_subset = random_split(
    train_val_data,
    [train_part, val_part],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DatasetFromSubset(train_subset, transform=transform_augment)
val_loader = DatasetFromSubset(val_subset, transform=transform_test)
test_loader = CIFAR10("./cifar_data/", train=False, transform=transform_test, download=True)

In [7]:
def compute_loss(model, X_batch, y_batch):
    """Return the mean cross-entropy loss of ``model`` on one batch.

    Inputs are converted to float32 and labels to int64 on the notebook-global
    ``device`` (already-matching tensors are passed through without a copy).
    ``F.cross_entropy`` already averages over the batch by default, so the
    trailing ``.mean()`` just returns the same scalar.
    """
    inputs = torch.as_tensor(X_batch, dtype=torch.float32, device=device)
    targets = torch.as_tensor(y_batch, dtype=torch.int64, device=device)
    return F.cross_entropy(model(inputs), targets).mean()

In [66]:
def get_accuracy(model, batch_size=512, dataset=None):
    """Evaluate ``model`` on a labelled dataset, print a graded report, return accuracy.

    Args:
        model: classifier returning logits of shape ``(batch, n_classes)``; the
            predicted class is the arg-max over dim 1.
        batch_size: evaluation batch size. It only affects speed and memory,
            not the result (accuracy is counted per sample).
        dataset: dataset yielding ``(input, label)`` pairs. Defaults to the
            notebook-global ``test_loader`` (the CIFAR-10 test split).

    Returns:
        float: fraction of correctly classified samples (0.0 for an empty dataset).

    Side effects: puts ``model`` into eval mode (``model.train(False)``).
    """
    if dataset is None:
        dataset = test_loader

    num_workers = 0 if device == 'cpu' else 4
    # Evaluation order is irrelevant, so no shuffling.
    test_batch_gen = torch.utils.data.DataLoader(dataset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=num_workers,
                                                 pin_memory=False)

    model.train(False)
    n_correct = 0
    n_total = 0
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        for X_batch, y_batch in test_batch_gen:
            X_batch = X_batch.to(device, non_blocking=True)
            y_pred = model(X_batch).argmax(dim=1).cpu()
            # Count samples instead of averaging per-batch means: the last batch
            # is usually smaller and would otherwise be over-weighted.
            n_correct += (y_pred == y_batch.cpu()).sum().item()
            n_total += y_batch.numel()

    test_accuracy = n_correct / n_total if n_total else 0.0

    print("Final results:")
    print("  test accuracy:\t\t{:.2f} %".format(
        test_accuracy * 100))

    if test_accuracy * 100 > 95:
        print("Double-check, than consider applying for NIPS'17. SRSly.")
    elif test_accuracy * 100 > 90:
        print("U'r freakin' amazin'!")
    elif test_accuracy * 100 > 80:
        print("Achievement unlocked: 110lvl Warlock!")
    elif test_accuracy * 100 > 70:
        print("Achievement unlocked: 80lvl Warlock!")
    elif test_accuracy * 100 > 60:
        print("Achievement unlocked: 70lvl Warlock!")
    elif test_accuracy * 100 > 50:
        print("Achievement unlocked: 60lvl Warlock!")
    else:
        print("We need more magic!")

    return test_accuracy

In [62]:
def _evaluate_accuracy(model, batch_gen):
    """Return the sample-weighted accuracy of ``model`` over ``batch_gen`` as a Python float.

    Switches the model to eval mode and disables autograd. Returns 0.0 for an
    empty loader instead of dividing by zero.
    """
    model.train(False)
    n_correct, n_total = 0, 0
    with torch.no_grad():
        for X_batch, y_batch in batch_gen:
            X_batch = X_batch.to(device, non_blocking=True)
            y_pred = model(X_batch).argmax(dim=1).cpu()
            n_correct += (y_pred == y_batch.cpu()).sum().item()
            n_total += y_batch.numel()
    return n_correct / n_total if n_total else 0.0


def train(model, opt, sch, batch_size=512, patience=10, num_epochs=10, model_name='simple_model', save=False):
    """Train ``model`` with a per-batch LR schedule, optional checkpointing and early stopping.

    Reads notebook globals: ``train_loader`` / ``val_loader`` (Dataset objects
    despite their names), ``device``, ``compute_loss`` and ``get_accuracy``.

    Args:
        model: network to train in place.
        opt: optimizer over ``model``'s parameters.
        sch: LR scheduler, stepped once per training batch (e.g. OneCycleLR).
        batch_size: batch size for the training and validation DataLoaders.
        patience: stop after this many *consecutive* epochs whose validation
            accuracy is lower than the previous epoch's.
        num_epochs: maximum number of epochs.
        model_name: checkpoint file prefix: ``models/{model_name}_{epoch+1}.pth``.
        save: if True, save a checkpoint whenever validation accuracy improves
            on the best seen so far.

    After training (or early stopping) the final test accuracy is reported
    once via ``get_accuracy(model)``.
    """
    import os  # only needed to create the checkpoint directory

    trigger_times = 0

    # np.Inf was removed in NumPy 2.0; np.inf works on every version.
    prev_val_accuracy = -np.inf
    val_accuracy_max = -np.inf

    num_workers = 0 if device == 'cpu' else 4
    pin_memory = False

    train_batch_gen = torch.utils.data.DataLoader(train_loader,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  num_workers=num_workers,
                                                  pin_memory=pin_memory)

    # Order does not matter for an accuracy estimate, so no shuffling.
    val_batch_gen = torch.utils.data.DataLoader(val_loader,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                num_workers=num_workers,
                                                pin_memory=pin_memory)

    for epoch in range(num_epochs):
        start_time = time.time()

        model.train(True)
        epoch_losses = []
        for X_batch, y_batch in train_batch_gen:
            X_batch, y_batch = X_batch.to(device, non_blocking=True), y_batch.to(device, non_blocking=True)
            loss = compute_loss(model, X_batch, y_batch)
            loss.backward()
            opt.step()
            sch.step()  # per-batch schedule
            opt.zero_grad(set_to_none=True)
            epoch_losses.append(loss.item())

        # Python float, so checkpoints stay loadable with torch.load(weights_only=True).
        val_accuracy = _evaluate_accuracy(model, val_batch_gen)

        print("Epoch {} of {} took {:.3f}s".format(
            epoch + 1, num_epochs, time.time() - start_time))
        print("  training loss (in-iteration): \t{:.6f}".format(
            np.mean(epoch_losses)))
        print("  validation accuracy: \t\t\t{:.2f} %".format(
            val_accuracy * 100))
        print("  learning rate: \t\t\t{:.4f}".format(
            sch.get_last_lr()[0]))

        if save and val_accuracy > val_accuracy_max:
            os.makedirs('models', exist_ok=True)
            torch.save({
                'epoch': epoch,
                'valid_accuracy': val_accuracy,
                'state_dict': model.state_dict(),
                'optimizer': opt.state_dict(),
            }, 'models/' + model_name + f'_{epoch+1}.pth')

            val_accuracy_max = val_accuracy

        if val_accuracy < prev_val_accuracy:
            trigger_times += 1
            if trigger_times >= patience:
                # The final evaluation below runs exactly once, also on early stop.
                print('Early stopping\n')
                break
        else:
            trigger_times = 0

        prev_val_accuracy = val_accuracy

    get_accuracy(model)

### Simple model

In [46]:
def _build_simple_model():
    """Small two-conv baseline used to test architecture-independent training features.

    For a 32x32 input: 5x5 conv -> 28, 3x3 max-pool (stride 3) -> 9,
    5x5 conv -> 5, so the flattened features are 64 * 5 * 5 = 1600 units.
    Conv biases are disabled because each conv is followed by BatchNorm.
    """
    return nn.Sequential(
        # stage 1: 3 -> 32 channels
        nn.Conv2d(3, 32, kernel_size=(5, 5), bias=False),
        nn.BatchNorm2d(32),
        nn.MaxPool2d(kernel_size=(3, 3)),
        nn.GELU(),
        # stage 2: 32 -> 64 channels
        nn.Conv2d(32, 64, kernel_size=(5, 5), bias=False),
        nn.BatchNorm2d(64),
        nn.GELU(),
        # classifier head
        Flatten(),
        nn.Linear(1600, 64),
        nn.GELU(),
        nn.Linear(64, 10),
    )


simple_model = _build_simple_model().to(device)

In [52]:
n_epochs = 10
batch_size = 512
patience = 5
# Optimizer steps per epoch = number of batches the training DataLoader yields
# (ceil division: the final partial batch is kept). OneCycleLR needs this value
# to size its schedule.
steps_per_epoch = -(-len(train_loader) // batch_size)

# Per the OneCycleLR docs, the scheduler overwrites the optimizer's starting lr
# (max_lr / div_factor), so lr=1e-4 here only matters before the scheduler exists.
optimizer = torch.optim.AdamW(
    simple_model.parameters(),
    lr=1e-4,
    weight_decay=1e-2,
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.01,
    steps_per_epoch=steps_per_epoch,
    epochs=n_epochs,
)

In [53]:
train(
    simple_model,
    optimizer,
    sch=scheduler,
    batch_size=batch_size,
    patience=patience,
    num_epochs=n_epochs,
)

Epoch 1 of 10 took 7.988s
  training loss (in-iteration): 	0.942151
  validation accuracy: 			68.96 %
  learning rate: 			0.0028 %
Epoch 2 of 10 took 7.712s
  training loss (in-iteration): 	1.017422
  validation accuracy: 			66.21 %
  learning rate: 			0.0076 %
Epoch 3 of 10 took 7.843s
  training loss (in-iteration): 	1.094179
  validation accuracy: 			64.32 %
  learning rate: 			0.0100 %
Epoch 4 of 10 took 7.566s
  training loss (in-iteration): 	1.072090
  validation accuracy: 			66.85 %
  learning rate: 			0.0095 %
Epoch 5 of 10 took 8.052s
  training loss (in-iteration): 	1.038102
  validation accuracy: 			67.48 %
  learning rate: 			0.0081 %
Epoch 6 of 10 took 7.636s
  training loss (in-iteration): 	0.990658
  validation accuracy: 			70.42 %
  learning rate: 			0.0061 %
Epoch 7 of 10 took 7.746s
  training loss (in-iteration): 	0.942064
  validation accuracy: 			72.13 %
  learning rate: 			0.0039 %
Epoch 8 of 10 took 7.677s
  training loss (in-iteration): 	0.902883
  validation ac

### CNN v.1

In [34]:
class ConvBlock(nn.Module):
    """Conv2d -> optional 2x2 pooling -> optional BatchNorm2d -> optional activation -> Dropout.

    Args:
        in_dim, out_dim: input / output channel counts of the convolution.
        kernel_size: side length of the square convolution kernel.
        padding, dilation: forwarded to ``nn.Conv2d``.
        activation: module/callable applied after normalisation; skipped when falsy.
        dropout: drop probability of the trailing ``nn.Dropout`` (always present).
        batch_norm: if truthy, apply BatchNorm2d after pooling. The BatchNorm
            layer is created either way, so it always appears in the state_dict.
        pool: ``'max'`` (MaxPool2d(2, stride=2)), ``'mean'`` (AvgPool2d(2)) or None.

    Raises:
        NotImplementedError: for any other ``pool`` value.
    """

    def __init__(self, in_dim, out_dim, kernel_size, padding, activation, dropout, batch_norm, pool=None, dilation=1):
        super().__init__()
        self.out_channels = out_dim
        self.batch_norm = batch_norm

        # Submodules are registered in this order (batch_norm_2d, activation,
        # dropout, conv, pool); repr and parameter order depend on it.
        self.batch_norm_2d = nn.BatchNorm2d(out_dim)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)
        self.conv = nn.Conv2d(
            in_dim,
            out_dim,
            kernel_size=(kernel_size, kernel_size),
            padding=padding,
            dilation=dilation,
        )

        if pool is None:
            self.pool = None
        elif pool == 'max':
            self.pool = nn.MaxPool2d(2, stride=2)
        elif pool == 'mean':
            self.pool = nn.AvgPool2d(2)
        else:
            raise NotImplementedError

    def forward(self, x):
        out = self.conv(x)
        if self.pool:
            out = self.pool(out)
        if self.batch_norm:
            out = self.batch_norm_2d(out)
        if self.activation:
            out = self.activation(out)
        return self.dropout(out)


class CNNv1(nn.Module):
    """Eight ConvBlocks (a 7x7 stem, then 3x3 convs) followed by one linear classifier.

    Widths go 64 -> 128 -> 256 -> 512, each used by two consecutive blocks; the
    second block of each pair max-pools by 2, so a 32x32 input leaves the conv
    stack as 512 x 2 x 2 = 2048 features. The attribute names ``conv1``..``conv8``
    and ``fc`` are what the saved state_dicts are keyed on, so they must not change.
    """

    def __init__(self):
        super().__init__()
        self.activation = nn.modules.activation.GELU()  # shared by every block
        self.dropout = 0.05
        self.batch_norm = True

        # (in_channels, out_channels, kernel_size, padding, pool)
        layer_specs = (
            (3, 64, 7, 3, None),
            (64, 64, 3, 1, 'max'),
            (64, 128, 3, 1, None),
            (128, 128, 3, 1, 'max'),
            (128, 256, 3, 1, None),
            (256, 256, 3, 1, 'max'),
            (256, 512, 3, 1, None),
            (512, 512, 3, 1, 'max'),
        )
        for idx, (c_in, c_out, k, pad, pool) in enumerate(layer_specs, start=1):
            block = ConvBlock(c_in, c_out, k, pad, self.activation, self.dropout, self.batch_norm, pool=pool)
            setattr(self, 'conv{}'.format(idx), block)

        self.fc = nn.Linear(2048, 10)

    def forward(self, x):
        for idx in range(1, 9):
            x = getattr(self, 'conv{}'.format(idx))(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

In [56]:
# Full 100-epoch CNNv1 training run (about 17 s/epoch in the log below).
# Disabled by default so that "Run All" reuses the saved checkpoint instead of
# retraining; set RETRAIN_CNNV1 = True to regenerate models/CNNv1_*.pth.
RETRAIN_CNNV1 = False

if RETRAIN_CNNV1:
    model_v1 = CNNv1().to(device)

    n_epochs = 100
    batch_size = 512
    patience = 5
    # ceil division: number of batches per epoch, including the last partial one
    steps_per_epoch = -(-len(train_loader) // batch_size)

    optimizer = torch.optim.AdamW(model_v1.parameters(), lr=1e-3, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2, steps_per_epoch=steps_per_epoch, epochs=n_epochs)

    train(model_v1, optimizer, sch=scheduler, batch_size=batch_size, patience=patience, num_epochs=n_epochs, model_name='CNNv1', save=True)

Epoch 1 of 100 took 16.631s
  training loss (in-iteration): 	1.677367
  validation accuracy: 			46.05 %
  learning rate: 			0.0004
Epoch 2 of 100 took 16.643s
  training loss (in-iteration): 	1.279906
  validation accuracy: 			60.00 %
  learning rate: 			0.0005
Epoch 3 of 100 took 16.932s
  training loss (in-iteration): 	1.090327
  validation accuracy: 			67.76 %
  learning rate: 			0.0006
Epoch 4 of 100 took 16.983s
  training loss (in-iteration): 	0.975070
  validation accuracy: 			67.92 %
  learning rate: 			0.0008
Epoch 5 of 100 took 17.140s
  training loss (in-iteration): 	0.883390
  validation accuracy: 			73.33 %
  learning rate: 			0.0010
Epoch 6 of 100 took 17.336s
  training loss (in-iteration): 	0.814127
  validation accuracy: 			76.65 %
  learning rate: 			0.0013
Epoch 7 of 100 took 17.548s
  training loss (in-iteration): 	0.769796
  validation accuracy: 			78.36 %
  learning rate: 			0.0016
Epoch 8 of 100 took 17.399s
  training loss (in-iteration): 	0.726065
  validation 

In [68]:
model_v1_best = CNNv1().to(device)

# map_location: restore a checkpoint saved on GPU onto the current device
# (e.g. a CPU-only machine).
# weights_only=False: this checkpoint's 'valid_accuracy' entry may be a NumPy
# scalar (an np.mean result), which the default weights_only=True unpickler of
# torch>=2.6 rejects. This runs full pickle -- only use it on trusted, locally
# produced files.
checkpoint = torch.load('models/CNNv1_95.pth', map_location=device, weights_only=False)
model_v1_best.load_state_dict(checkpoint['state_dict'])

model_v1_best.eval()

CNNv1(
  (activation): GELU()
  (conv1): ConvBlock(
    (batch_norm_2d): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): GELU()
    (dropout): Dropout(p=0.05, inplace=False)
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  )
  (conv2): ConvBlock(
    (batch_norm_2d): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): GELU()
    (dropout): Dropout(p=0.05, inplace=False)
    (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv3): ConvBlock(
    (batch_norm_2d): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): GELU()
    (dropout): Dropout(p=0.05, inplace=False)
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (conv4): ConvBlock(
    (batch_norm_2d): BatchNorm2d

In [73]:
get_accuracy(model=model_v1_best);

Final results:
  test accuracy:		90.87 %
U'r freakin' amazin'!


### Iteration path

1. Выбрана простая модель, на которой можно было быстро проверять фичи, не зависящие от архитектуры. С её помощью были добавлены:
    * Автоматическое сохранение параметров модели с лучшим значением accuracy на валидации
    * 1cycle policy (`OneCycleLR`)
    * Early stopping
2. Собрана сеть из 8 сверточных слоев и одного полносвязного (CNNv1); лучший чекпоинт (`models/CNNv1_95.pth`) дает 90.87 % accuracy на тестовой выборке.


