In [1]:
import os
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

data_dir = os.environ.get('PYTORCH_DATA_DIR') or './data/cifar10/'

num_workers = 4
batch_size = 64

transform = torchvision.transforms.ToTensor()

train_set = datasets.CIFAR10(root = data_dir, train = True,
                             download = True, transform = transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size = batch_size,
                                    shuffle = True, num_workers = num_workers)

test_set = datasets.CIFAR10(root = data_dir, train = False,
                            download = True, transform = transform)

test_loader = torch.utils.data.DataLoader(test_set, batch_size = batch_size,
                                    shuffle = False, num_workers = num_workers)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar10/cifar-10-python.tar.gz to ./data/cifar10/




Files already downloaded and verified


In [2]:
from torch import nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    def __init__(self, nb_channels, kernel_size):
        super().__init__()

        self.conv1 = nn.Conv2d(nb_channels, nb_channels, kernel_size,
                               padding = (kernel_size-1)//2)
        self.bn1 = nn.BatchNorm2d(nb_channels)

        self.conv2 = nn.Conv2d(nb_channels, nb_channels, kernel_size,
                               padding = (kernel_size-1)//2)
        self.bn2 = nn.BatchNorm2d(nb_channels)

    def forward(self, x):
        y = self.bn1(self.conv1(x))
        y = F.relu(y)
        y = self.bn1(self.conv1(x))
        y += x
        y = F.relu(y)
        return y

In [3]:
class Monster(nn.Module):
    def __init__(self, nb_blocks, nb_channels):
        super().__init__()
        
        alexnet = torchvision.models.alexnet(weights = 'IMAGENET1K_V1')

        self.features = nn.Sequential(alexnet.features[0], nn.ReLU(inplace = True))

        dummy = self.features(torch.zeros(1, 3, 32, 32)).size()
        alexnet_nb_channels = dummy[1]
        alexnet_map_size = tuple(dummy[2:4])

        self.conv = nn.Conv2d(alexnet_nb_channels, nb_channels, kernel_size = 1)

        self.resblocks = nn.Sequential(
            *(ResBlock(nb_channels, kernel_size = 3) for _ in range(nb_blocks))
        )

        self.avg = nn.AvgPool2d(kernel_size = alexnet_map_size)
        self.fc = nn.Linear(nb_channels, 10)

    def forward(self, x):
        x = self.features(x)
        x = F.relu(self.conv(x))
        x = self.resblocks(x)
        x = F.relu(self.avg(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


In [22]:
nb_epochs = 50
nb_blocks, nb_channels = 8, 64
device = 'cuda'

model, criterion = Monster(nb_blocks, nb_channels), nn.CrossEntropyLoss()

model.to(device)
criterion.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr = 1e-2, momentum = 0.9)

for e in range(nb_epochs):
    for p in model.features.parameters():
        p.requires_grad = e >= nb_epochs // 2

    acc_loss = 0.0

    print(f'Starting epoch {e}...')
    mini_batch = 0
    for input, targets in iter(train_loader):
        input, targets = input.to(device), targets.to(device)

        output = model(input)
        preds = torch.argmax(output.data, 1)
        diff_count = torch.count_nonzero(preds - targets)
        batch_size = targets.size(0)

        training_error = float(diff_count) / batch_size

        loss = criterion(output, targets)
        acc_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        mini_batch += 1
        print(f'Epoch {e}: finishing mini batch {mini_batch}, training error = {training_error}, loss = {loss.item()}')

    print(f'Epoch {e} completed, acc_loss = {acc_loss}')

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch 43: finishing mini batch 488, training error = 0.09375, loss = 0.24551428854465485
Epoch 43: finishing mini batch 489, training error = 0.078125, loss = 0.2678520679473877
Epoch 43: finishing mini batch 490, training error = 0.015625, loss = 0.052144475281238556
Epoch 43: finishing mini batch 491, training error = 0.0625, loss = 0.16373704373836517
Epoch 43: finishing mini batch 492, training error = 0.109375, loss = 0.2425498515367508
Epoch 43: finishing mini batch 493, training error = 0.078125, loss = 0.2730174660682678
Epoch 43: finishing mini batch 494, training error = 0.0625, loss = 0.1500380039215088
Epoch 43: finishing mini batch 495, training error = 0.09375, loss = 0.17088575661182404
Epoch 43: finishing mini batch 496, training error = 0.03125, loss = 0.0829305648803711
Epoch 43: finishing mini batch 497, training error = 0.03125, loss = 0.07369601726531982
Epoch 43: finishing mini batch 498, training er

In [24]:
nb_test_errors, nb_test_samples = 0, 0

model.eval()

mini_batch = 0
for input, targets in iter(test_loader):
    input, targets = input.to(device), targets.to(device)

    output = model(input)
    preds = torch.argmax(output.data, 1)
    diff_count = torch.count_nonzero(preds - targets)
    batch_size = targets.size(0)
    test_error = float(diff_count) / batch_size

    mini_batch += 1
    print(f'Mini batch {mini_batch}: test error = {test_error}')



Mini batch 1: test error = 0.3125
Mini batch 2: test error = 0.296875
Mini batch 3: test error = 0.28125
Mini batch 4: test error = 0.25
Mini batch 5: test error = 0.3125
Mini batch 6: test error = 0.28125
Mini batch 7: test error = 0.328125
Mini batch 8: test error = 0.140625
Mini batch 9: test error = 0.203125
Mini batch 10: test error = 0.125
Mini batch 11: test error = 0.25
Mini batch 12: test error = 0.296875
Mini batch 13: test error = 0.234375
Mini batch 14: test error = 0.203125
Mini batch 15: test error = 0.25
Mini batch 16: test error = 0.203125
Mini batch 17: test error = 0.265625
Mini batch 18: test error = 0.359375
Mini batch 19: test error = 0.140625
Mini batch 20: test error = 0.3125
Mini batch 21: test error = 0.234375
Mini batch 22: test error = 0.28125
Mini batch 23: test error = 0.1875
Mini batch 24: test error = 0.28125
Mini batch 25: test error = 0.28125
Mini batch 26: test error = 0.171875
Mini batch 27: test error = 0.265625
Mini batch 28: test error = 0.328125
M