# Assignment 1
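CIFAR-10 image classification: a plain CNN baseline, a ResNet-style network and a conv/self-attention hybrid are defined below; the ResNet (`num_block=2`) is the one trained, reaching 93.24% test accuracy in the saved run.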

In [1]:
# import some necessary packages
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision.datasets as tv_datasets
import torchvision.transforms as tv_transforms

from tqdm import tqdm

In [2]:
# some experimental setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the global torch RNG so weight init, data shuffling and torchvision's
# random transforms (which draw from torch's RNG in recent versions) repeat
# across Restart & Run All. NOTE: cuDNN kernels may still be nondeterministic on GPU.
seed = 0
torch.manual_seed(seed)

num_epochs = 128
batch_size = 256
num_workers = 8

optim_name = "AdamW"  # looked up on torch.optim by name in the training cell
optim_kwargs = dict(
    lr=3e-4,
)


# preprocessing pipeline for input images
def _to_normalized_tensor():
    """Return fresh ToTensor + Normalize(0.5, 0.5) steps, mapping [0, 1] pixels to [-1, 1]."""
    return [
        tv_transforms.ToTensor(),
        tv_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]


transformation = {
    # random augmentation is applied to the training split only
    "train": tv_transforms.Compose([
        tv_transforms.RandomRotation(degrees=15),
        tv_transforms.RandomHorizontalFlip(),
        tv_transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.1),
        tv_transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        *_to_normalized_tensor(),
    ]),
    # evaluation images are only converted and normalised
    "test": tv_transforms.Compose(_to_normalized_tensor()),
}

print(f"device: {device}")


device: cuda:0


In [3]:
# prepare datasets and loaders (only the training split is shuffled)
dataset, loader = {}, {}
for data_type in ("train", "test"):
    is_train = data_type == "train"
    dataset[data_type] = tv_datasets.CIFAR10(
        root="./data", train=is_train, download=True, transform=transformation[data_type],
    )
    loader[data_type] = torch.utils.data.DataLoader(
        dataset[data_type],
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        # page-locked host memory speeds up host->GPU copies; useless on CPU
        pin_memory=(device.type == "cuda"),
        # reuse worker processes across epochs instead of re-spawning them each epoch
        # (DataLoader rejects persistent_workers=True when num_workers == 0)
        persistent_workers=num_workers > 0,
    )

print(f"train sample: {len(dataset['train'])}, test samples: {len(dataset['test'])}")
print(f"train batches: {len(loader['train'])}, test batches: {len(loader['test'])}")

Files already downloaded and verified
Files already downloaded and verified
train sample: 50000, test samples: 10000
train batches: 196, test batches: 40


### Baseline

In [4]:
class BaselineModel(nn.Module):
    def __init__(self):
        super(BaselineModel, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 128, 3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(2), nn.Dropout(0.3),
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(2), nn.Dropout(0.3),
            nn.Conv2d(256, 512, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, 3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(2), nn.Dropout(0.3),
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 512), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(512, 256), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.model(x)


### Residual Mechanism

In [5]:
class MyResBlock(nn.Module):
    """Residual block: two 3x3 conv-BN-ReLU layers plus a skip connection.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
        stride: stride of the first 3x3 conv (and of the projection shortcut);
            ``stride=2`` halves the spatial size.

    The skip path is the identity when input and output shapes match. Otherwise
    a 1x1 convolution (``conv3``) projects the input to the output shape so it
    can be added to the residual branch.
    """

    def __init__(self, in_c, out_c, stride=1):
        super(MyResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, 3, stride=stride, padding=1)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        # The identity shortcut only works when both the channel count AND the
        # spatial size are unchanged; a stride > 1 also needs a projection.
        if in_c != out_c or stride != 1:
            self.conv3 = nn.Conv2d(in_c, out_c, 1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_c)
        self.bn2 = nn.BatchNorm2d(out_c)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        # NOTE(review): the residual branch is ReLU'd before the addition, unlike
        # the standard ResNet basic block; kept as-is to preserve the trained model.
        y = F.relu(self.bn2(self.conv2(y)))
        if self.conv3 is not None:
            x = self.conv3(x)
        return F.relu(x + y)

class MyResnet(nn.Module):
    """ResNet-style 10-class classifier built from ``MyResBlock`` stages.

    Layout: a 3x3 conv stem (3 -> 64 channels), then four downsampling stages
    (64 -> 128 -> 256 -> 512 -> 1024 channels, each opened by a stride-2
    ``MyResBlock`` and closed by ``Dropout(0.05)``), then global average
    pooling and a linear head producing 10 logits.

    Args:
        num_block: depth knob. With ``num_block=1`` each stage is a single
            downsampling block. Every additional unit adds 1, 2, 4, 2 and 1
            identity blocks at the 64-, 128-, 256-, 512- and 1024-channel
            widths respectively.
    """

    def __init__(self, num_block=1):
        super().__init__()
        extra = num_block - 1

        self.before_res = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # identity blocks at stem width (none when num_block == 1)
        blocks = [MyResBlock(64, 64) for _ in range(extra)]
        # (input channels, output channels, identity blocks per extra unit)
        stage_table = ((64, 128, 2), (128, 256, 4), (256, 512, 2), (512, 1024, 1))
        for c_in, c_out, repeat in stage_table:
            blocks.append(MyResBlock(c_in, c_out, 2))
            blocks.extend(MyResBlock(c_out, c_out) for _ in range(extra * repeat))
            blocks.append(nn.Dropout(0.05))
        self.res = nn.Sequential(*blocks)

        self.after_res = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(1024, 10),
        )

    def forward(self, x):
        features = self.before_res(x)
        features = self.res(features)
        return self.after_res(features)


### Attention

In [6]:
# our network architecture
class MyAttentionBlock(nn.Module):
    def __init__(self, d_attn, num_heads=4):
        super(MyAttentionBlock, self).__init__()
        self.num_heads = num_heads
        self.mha = nn.MultiheadAttention(embed_dim=d_attn, num_heads=num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Conv2d(d_attn, d_attn*4, 1), nn.ReLU(inplace=True),
            nn.Conv2d(d_attn*4, d_attn, 1),
        )
        self.bn1 = nn.BatchNorm2d(d_attn)
        self.bn2 = nn.BatchNorm2d(d_attn)

    def forward(self, x):
        B, C, W, H = x.shape
        x_flat = x.flatten(2).permute(0, 2, 1)  # B, L, C
        attn_out, _ = self.mha(x_flat, x_flat, x_flat)
        attn_out = attn_out.permute(0, 2, 1).reshape(B, C, W, H)
        attn_out = F.relu(attn_out + x)
        attn_out = self.bn1(attn_out)
        out = self.ffn(attn_out) + attn_out
        out = self.bn2(out)
        return out

class MyAttention(nn.Module):
    """Hybrid CNN / self-attention 10-class classifier.

    A 3x3 conv stem (3 -> 64 channels) feeds three stages. Each stage
    downsamples with a stride-2 ``MyResBlock`` (64 -> 128 -> 256 -> 512
    channels) and then applies ``num_block`` ``MyAttentionBlock`` layers at
    that width. Global average pooling and a linear layer give 10 logits.
    """

    def __init__(self, num_block=1):
        super().__init__()
        self.before_attn = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        stages = []
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            stages.append(MyResBlock(c_in, c_out, 2))
            stages += [MyAttentionBlock(c_out) for _ in range(num_block)]
        self.attn = nn.Sequential(*stages)

        self.after_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        feats = self.before_attn(x)
        feats = self.attn(feats)
        return self.after_attn(feats)


In [7]:
# instantiate the residual network and move its parameters to the selected device
model = MyResnet(num_block=2).to(device)

# report the number of trainable parameters, in millions
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"number of parameters: {num_trainable / 1_000_000:.2f}M")

number of parameters: 53.24M


## Start Training

In [8]:
# the network optimizer (class looked up by name from the config cell)
optimizer = getattr(optim, optim_name)(model.parameters(), **optim_kwargs)
# cosine learning-rate decay over the whole run, stepped once per epoch
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# loss function
criterion = nn.CrossEntropyLoss()

# training loop
model.train()
for epoch in range(num_epochs):
    running_loss, seen = 0.0, 0
    for img, target in tqdm(loader["train"], desc=f"epoch {epoch + 1}"):
        # non_blocking only has an effect when the loader uses pinned memory
        img = img.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        pred = model(img)
        loss = criterion(pred, target)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss returns the batch mean; weight it by batch size so the
        # smaller final batch does not skew the epoch average
        running_loss += loss.item() * target.size(0)
        seen += target.size(0)
    print(f"[epoch={epoch + 1:3d}] loss: {running_loss / max(seen, 1):.3f}")
    scheduler.step()

print("Finished Training")


100%|██████████| 196/196 [00:10<00:00, 18.08it/s]

[epoch=  1] loss: 1.567



100%|██████████| 196/196 [00:10<00:00, 18.91it/s]

[epoch=  2] loss: 1.122



100%|██████████| 196/196 [00:10<00:00, 19.21it/s]

[epoch=  3] loss: 0.912



100%|██████████| 196/196 [00:10<00:00, 18.98it/s]

[epoch=  4] loss: 0.773



100%|██████████| 196/196 [00:10<00:00, 19.10it/s]

[epoch=  5] loss: 0.675



100%|██████████| 196/196 [00:10<00:00, 19.07it/s]

[epoch=  6] loss: 0.616



100%|██████████| 196/196 [00:10<00:00, 19.04it/s]

[epoch=  7] loss: 0.560



100%|██████████| 196/196 [00:10<00:00, 19.12it/s]

[epoch=  8] loss: 0.515



100%|██████████| 196/196 [00:10<00:00, 19.14it/s]

[epoch=  9] loss: 0.484



100%|██████████| 196/196 [00:10<00:00, 19.01it/s]

[epoch= 10] loss: 0.442



100%|██████████| 196/196 [00:10<00:00, 19.04it/s]

[epoch= 11] loss: 0.423



100%|██████████| 196/196 [00:10<00:00, 19.09it/s]

[epoch= 12] loss: 0.402



100%|██████████| 196/196 [00:10<00:00, 19.11it/s]

[epoch= 13] loss: 0.374



100%|██████████| 196/196 [00:10<00:00, 19.05it/s]

[epoch= 14] loss: 0.354



100%|██████████| 196/196 [00:10<00:00, 19.05it/s]

[epoch= 15] loss: 0.341



100%|██████████| 196/196 [00:10<00:00, 18.97it/s]

[epoch= 16] loss: 0.322



100%|██████████| 196/196 [00:10<00:00, 19.11it/s]

[epoch= 17] loss: 0.301



100%|██████████| 196/196 [00:10<00:00, 19.07it/s]

[epoch= 18] loss: 0.287



100%|██████████| 196/196 [00:10<00:00, 18.81it/s]

[epoch= 19] loss: 0.279



100%|██████████| 196/196 [00:10<00:00, 19.02it/s]

[epoch= 20] loss: 0.268



100%|██████████| 196/196 [00:10<00:00, 19.07it/s]

[epoch= 21] loss: 0.261



100%|██████████| 196/196 [00:10<00:00, 19.00it/s]

[epoch= 22] loss: 0.237



100%|██████████| 196/196 [00:10<00:00, 18.96it/s]

[epoch= 23] loss: 0.231



100%|██████████| 196/196 [00:10<00:00, 18.48it/s]

[epoch= 24] loss: 0.223



100%|██████████| 196/196 [00:10<00:00, 18.61it/s]

[epoch= 25] loss: 0.212



100%|██████████| 196/196 [00:10<00:00, 18.88it/s]

[epoch= 26] loss: 0.206



100%|██████████| 196/196 [00:10<00:00, 18.21it/s]

[epoch= 27] loss: 0.191



100%|██████████| 196/196 [00:10<00:00, 18.93it/s]

[epoch= 28] loss: 0.187



100%|██████████| 196/196 [00:10<00:00, 18.93it/s]

[epoch= 29] loss: 0.180



100%|██████████| 196/196 [00:10<00:00, 18.77it/s]

[epoch= 30] loss: 0.167



100%|██████████| 196/196 [00:10<00:00, 19.12it/s]

[epoch= 31] loss: 0.170



100%|██████████| 196/196 [00:10<00:00, 19.00it/s]

[epoch= 32] loss: 0.153



100%|██████████| 196/196 [00:10<00:00, 18.99it/s]

[epoch= 33] loss: 0.154



100%|██████████| 196/196 [00:10<00:00, 18.99it/s]

[epoch= 34] loss: 0.149



100%|██████████| 196/196 [00:10<00:00, 18.84it/s]

[epoch= 35] loss: 0.137



100%|██████████| 196/196 [00:10<00:00, 19.10it/s]

[epoch= 36] loss: 0.135



100%|██████████| 196/196 [00:10<00:00, 19.08it/s]

[epoch= 37] loss: 0.128



100%|██████████| 196/196 [00:10<00:00, 18.97it/s]

[epoch= 38] loss: 0.123



100%|██████████| 196/196 [00:10<00:00, 19.11it/s]

[epoch= 39] loss: 0.123



100%|██████████| 196/196 [00:10<00:00, 19.10it/s]

[epoch= 40] loss: 0.111



100%|██████████| 196/196 [00:10<00:00, 19.05it/s]

[epoch= 41] loss: 0.108



100%|██████████| 196/196 [00:10<00:00, 19.01it/s]

[epoch= 42] loss: 0.105



100%|██████████| 196/196 [00:10<00:00, 19.13it/s]

[epoch= 43] loss: 0.098



100%|██████████| 196/196 [00:10<00:00, 19.11it/s]

[epoch= 44] loss: 0.093



100%|██████████| 196/196 [00:10<00:00, 19.06it/s]

[epoch= 45] loss: 0.094



100%|██████████| 196/196 [00:10<00:00, 19.11it/s]

[epoch= 46] loss: 0.091



100%|██████████| 196/196 [00:10<00:00, 19.10it/s]

[epoch= 47] loss: 0.085



100%|██████████| 196/196 [00:10<00:00, 19.17it/s]

[epoch= 48] loss: 0.081



100%|██████████| 196/196 [00:10<00:00, 19.21it/s]

[epoch= 49] loss: 0.082



100%|██████████| 196/196 [00:10<00:00, 19.07it/s]

[epoch= 50] loss: 0.072



100%|██████████| 196/196 [00:10<00:00, 19.04it/s]

[epoch= 51] loss: 0.073



100%|██████████| 196/196 [00:10<00:00, 19.06it/s]

[epoch= 52] loss: 0.071



100%|██████████| 196/196 [00:10<00:00, 19.14it/s]

[epoch= 53] loss: 0.066



100%|██████████| 196/196 [00:10<00:00, 19.06it/s]

[epoch= 54] loss: 0.067



100%|██████████| 196/196 [00:10<00:00, 19.09it/s]

[epoch= 55] loss: 0.063



100%|██████████| 196/196 [00:10<00:00, 18.91it/s]

[epoch= 56] loss: 0.061



100%|██████████| 196/196 [00:10<00:00, 18.88it/s]

[epoch= 57] loss: 0.055



100%|██████████| 196/196 [00:10<00:00, 18.90it/s]

[epoch= 58] loss: 0.056



100%|██████████| 196/196 [00:10<00:00, 18.96it/s]

[epoch= 59] loss: 0.054



100%|██████████| 196/196 [00:10<00:00, 19.01it/s]

[epoch= 60] loss: 0.051



100%|██████████| 196/196 [00:10<00:00, 18.94it/s]

[epoch= 61] loss: 0.048



100%|██████████| 196/196 [00:10<00:00, 18.92it/s]

[epoch= 62] loss: 0.045



100%|██████████| 196/196 [00:10<00:00, 18.87it/s]

[epoch= 63] loss: 0.049



100%|██████████| 196/196 [00:10<00:00, 18.94it/s]

[epoch= 64] loss: 0.044



100%|██████████| 196/196 [00:10<00:00, 18.99it/s]

[epoch= 65] loss: 0.047



100%|██████████| 196/196 [00:10<00:00, 18.94it/s]

[epoch= 66] loss: 0.040



100%|██████████| 196/196 [00:10<00:00, 19.00it/s]

[epoch= 67] loss: 0.042



100%|██████████| 196/196 [00:10<00:00, 19.07it/s]

[epoch= 68] loss: 0.036



100%|██████████| 196/196 [00:10<00:00, 18.99it/s]

[epoch= 69] loss: 0.039



100%|██████████| 196/196 [00:10<00:00, 19.08it/s]

[epoch= 70] loss: 0.032



100%|██████████| 196/196 [00:10<00:00, 19.01it/s]

[epoch= 71] loss: 0.034



100%|██████████| 196/196 [00:10<00:00, 19.10it/s]

[epoch= 72] loss: 0.031



100%|██████████| 196/196 [00:10<00:00, 19.16it/s]

[epoch= 73] loss: 0.032



100%|██████████| 196/196 [00:10<00:00, 19.06it/s]

[epoch= 74] loss: 0.030



100%|██████████| 196/196 [00:10<00:00, 19.11it/s]

[epoch= 75] loss: 0.028



100%|██████████| 196/196 [00:10<00:00, 18.85it/s]

[epoch= 76] loss: 0.023



100%|██████████| 196/196 [00:10<00:00, 18.95it/s]

[epoch= 77] loss: 0.025



100%|██████████| 196/196 [00:10<00:00, 19.12it/s]

[epoch= 78] loss: 0.024



100%|██████████| 196/196 [00:10<00:00, 19.00it/s]

[epoch= 79] loss: 0.025



100%|██████████| 196/196 [00:10<00:00, 19.07it/s]

[epoch= 80] loss: 0.023



100%|██████████| 196/196 [00:10<00:00, 19.06it/s]

[epoch= 81] loss: 0.022



100%|██████████| 196/196 [00:10<00:00, 19.13it/s]

[epoch= 82] loss: 0.020



100%|██████████| 196/196 [00:10<00:00, 19.16it/s]

[epoch= 83] loss: 0.020



100%|██████████| 196/196 [00:10<00:00, 19.06it/s]

[epoch= 84] loss: 0.018



100%|██████████| 196/196 [00:10<00:00, 19.18it/s]

[epoch= 85] loss: 0.020



100%|██████████| 196/196 [00:10<00:00, 19.12it/s]

[epoch= 86] loss: 0.018



100%|██████████| 196/196 [00:10<00:00, 18.92it/s]

[epoch= 87] loss: 0.017



100%|██████████| 196/196 [00:10<00:00, 19.02it/s]

[epoch= 88] loss: 0.016



100%|██████████| 196/196 [00:10<00:00, 18.98it/s]

[epoch= 89] loss: 0.015



100%|██████████| 196/196 [00:10<00:00, 19.05it/s]

[epoch= 90] loss: 0.013



100%|██████████| 196/196 [00:10<00:00, 19.19it/s]

[epoch= 91] loss: 0.014



100%|██████████| 196/196 [00:10<00:00, 19.21it/s]

[epoch= 92] loss: 0.011



100%|██████████| 196/196 [00:10<00:00, 19.05it/s]

[epoch= 93] loss: 0.011



100%|██████████| 196/196 [00:10<00:00, 19.04it/s]

[epoch= 94] loss: 0.012



100%|██████████| 196/196 [00:10<00:00, 19.07it/s]

[epoch= 95] loss: 0.011



100%|██████████| 196/196 [00:10<00:00, 18.92it/s]

[epoch= 96] loss: 0.012



100%|██████████| 196/196 [00:10<00:00, 18.95it/s]

[epoch= 97] loss: 0.010



100%|██████████| 196/196 [00:10<00:00, 19.13it/s]

[epoch= 98] loss: 0.010



100%|██████████| 196/196 [00:10<00:00, 19.12it/s]

[epoch= 99] loss: 0.009



100%|██████████| 196/196 [00:10<00:00, 19.15it/s]

[epoch=100] loss: 0.009



100%|██████████| 196/196 [00:10<00:00, 19.10it/s]

[epoch=101] loss: 0.007



100%|██████████| 196/196 [00:10<00:00, 19.18it/s]

[epoch=102] loss: 0.008



100%|██████████| 196/196 [00:10<00:00, 19.14it/s]

[epoch=103] loss: 0.007



100%|██████████| 196/196 [00:10<00:00, 18.83it/s]

[epoch=104] loss: 0.007



100%|██████████| 196/196 [00:10<00:00, 18.95it/s]

[epoch=105] loss: 0.007



100%|██████████| 196/196 [00:10<00:00, 18.88it/s]

[epoch=106] loss: 0.007



100%|██████████| 196/196 [00:10<00:00, 18.96it/s]

[epoch=107] loss: 0.006



100%|██████████| 196/196 [00:10<00:00, 19.17it/s]

[epoch=108] loss: 0.007



100%|██████████| 196/196 [00:10<00:00, 19.06it/s]

[epoch=109] loss: 0.005



100%|██████████| 196/196 [00:10<00:00, 18.99it/s]

[epoch=110] loss: 0.005



100%|██████████| 196/196 [00:10<00:00, 19.12it/s]

[epoch=111] loss: 0.005



100%|██████████| 196/196 [00:10<00:00, 19.06it/s]

[epoch=112] loss: 0.005



100%|██████████| 196/196 [00:10<00:00, 19.09it/s]

[epoch=113] loss: 0.005



100%|██████████| 196/196 [00:10<00:00, 19.05it/s]

[epoch=114] loss: 0.005



100%|██████████| 196/196 [00:10<00:00, 19.09it/s]

[epoch=115] loss: 0.005



100%|██████████| 196/196 [00:10<00:00, 18.66it/s]

[epoch=116] loss: 0.004



100%|██████████| 196/196 [00:10<00:00, 19.06it/s]

[epoch=117] loss: 0.004



100%|██████████| 196/196 [00:10<00:00, 18.96it/s]

[epoch=118] loss: 0.005



100%|██████████| 196/196 [00:10<00:00, 19.01it/s]

[epoch=119] loss: 0.003



100%|██████████| 196/196 [00:10<00:00, 19.16it/s]

[epoch=120] loss: 0.004



100%|██████████| 196/196 [00:10<00:00, 18.97it/s]

[epoch=121] loss: 0.004



100%|██████████| 196/196 [00:10<00:00, 19.00it/s]

[epoch=122] loss: 0.004



100%|██████████| 196/196 [00:10<00:00, 18.87it/s]

[epoch=123] loss: 0.005



100%|██████████| 196/196 [00:10<00:00, 18.95it/s]

[epoch=124] loss: 0.004



100%|██████████| 196/196 [00:10<00:00, 19.39it/s]

[epoch=125] loss: 0.004



100%|██████████| 196/196 [00:10<00:00, 19.24it/s]

[epoch=126] loss: 0.003



100%|██████████| 196/196 [00:10<00:00, 19.14it/s]

[epoch=127] loss: 0.004



100%|██████████| 196/196 [00:10<00:00, 19.04it/s]

[epoch=128] loss: 0.004
Finished Training





## Evaluating test-set accuracy

In [9]:
# top-1 accuracy on the held-out test split (dropout off, BN uses running stats)
model.eval()
correct = total = 0
with torch.no_grad():
    for images, labels in loader["test"]:
        images = images.to(device)
        labels = labels.to(device)

        # predicted class = index of the largest logit
        predicted = model(images).argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy of the network on the {total} test images: {100 * correct / total:.2f}%")

Accuracy of the network on the 10000 test images: 93.24%
