## CIFAR-10 Training

This is an example of how to train the ViT on the CIFAR-10 dataset.

Apart from weight decay, no regularization or data augmentation is applied, so the model overfits heavily: training accuracy reaches 100% after 24 epochs, while test accuracy ends at about 69%.


### Imports

In [1]:
import random

import torch
import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm

from ViT import ViT

### Set seeds

In [None]:
seed = 42
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
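# cuDNN benchmark mode may select different algorithms between runs,
# so it is disabled when deterministic behavior is wanted
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True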
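
The seeds above fix the global RNG state, but the shuffling order of a `DataLoader` can still depend on whatever consumed random numbers before it. A minimal sketch of one way to pin it down, following the pattern described in the PyTorch reproducibility notes: pass a dedicated, seeded `torch.Generator` and a `worker_init_fn`. The names `seed_worker`, `make_loader`, and the toy dataset are illustrative assumptions and not part of this notebook.

```python
import random

import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id: int) -> None:
    # Each worker process derives its Python RNG seed from the seed PyTorch gave it.
    random.seed(torch.initial_seed() % 2**32)


def make_loader(dataset, seed: int, batch_size: int = 4) -> DataLoader:
    # A dedicated generator fixes the shuffling order independently of global RNG state.
    generator = torch.Generator()
    generator.manual_seed(seed)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=0, worker_init_fn=seed_worker,
                      generator=generator)


# Toy stand-in for the CIFAR-10 training set (hypothetical data).
toy_dataset = TensorDataset(torch.arange(16))

first_order = torch.cat([batch[0] for batch in make_loader(toy_dataset, seed=42)])
second_order = torch.cat([batch[0] for batch in make_loader(toy_dataset, seed=42)])

# Two loaders built with the same seed visit the samples in the same order.
print(torch.equal(first_order, second_order))
```

In the notebook, the same two arguments could be passed to `train_loader`; `num_workers=0` is used here only to keep the sketch lightweight.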

### Load CIFAR dataset

In [3]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2)

test_loader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified
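
As a quick sanity check on the loaders above, the sketch below works out how many batches each loader yields and what `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` does to a pixel value. It assumes the standard CIFAR-10 split of 50,000 training and 10,000 test images and the `DataLoader` default of keeping the final, smaller batch; the helper names are illustrative.

```python
import math

train_size, test_size, batch_size = 50_000, 10_000, 64  # standard CIFAR-10 split


def num_batches(num_samples: int, batch_size: int) -> int:
    # With drop_last=False (the default), the last partial batch is kept.
    return math.ceil(num_samples / batch_size)


def normalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    # Per-channel operation applied by transforms.Normalize: (x - mean) / std.
    return (value - mean) / std


print(num_batches(train_size, batch_size))  # 782
print(num_batches(test_size, batch_size))   # 157
print(normalize(0.0), normalize(1.0))       # -1.0 1.0
```

The batch counts match the `782/782` and `157/157` totals shown by the progress bars in the training output, and the normalization maps the `[0, 1]` range produced by `ToTensor()` to `[-1, 1]`.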


### Create the model, optimizer, scheduler, and loss function

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
epochs = 24
learning_rate = 1e-4
image_dims = (3, 32, 32)
num_patches = 8
embedding_dim = 256 # 512
mlp_hidden_dim = 256 # 128; defined here but not passed to ViT below
num_heads = 16
output_dim = 10
num_blocks = 12

model = ViT(image_dims=image_dims, 
            num_heads=num_heads, 
            num_patches=num_patches, 
            embedding_dim=embedding_dim, 
            num_blocks=num_blocks, 
            output_dim=output_dim).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=8)
loss_fn = torch.nn.CrossEntropyLoss()
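
Because `T_max=8` is smaller than `epochs=24`, the learning rate does not decay monotonically over the run. The sketch below evaluates the closed-form cosine annealing formula from the PyTorch documentation, assuming `eta_min=0` (the default) and one scheduler step per epoch; `cosine_annealing_lr` is an illustrative helper, not a PyTorch function.

```python
import math


def cosine_annealing_lr(step: int, base_lr: float = 1e-4,
                        t_max: int = 8, eta_min: float = 0.0) -> float:
    # eta_t = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max))
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


# Learning rate used during each of the 24 epochs (step = epochs already completed).
schedule = [cosine_annealing_lr(step) for step in range(24)]
for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.2e}")
```

Under these assumptions the learning rate falls to zero for epoch 9, climbs back to `1e-4` for epoch 17, and decays again toward the end of training. Setting `T_max=epochs` would instead give a single decay over the whole run.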

### Training loop

In [5]:
for epoch in range(epochs):
    model.train()  # training-mode behavior (matters if the model has dropout or similar layers)
    train_loss = 0
    correct, total = 0, 0
    for batch in tqdm(train_loader, desc=f'Epoch: {epoch + 1} in training'):
        imgs, labels = batch
        imgs = imgs.to(device)
        labels = labels.to(device)
        preds = model(imgs)
        loss = loss_fn(preds, labels)

        # Accumulate the mean of the per-batch losses
        train_loss += loss.item() / len(train_loader)

        correct += (preds.argmax(dim=1) == labels).sum().item()
        total += len(imgs)

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()
    lr_scheduler.step()  # the scheduler is stepped once per epoch

    print(f"Epoch {epoch + 1} / {epochs} loss: {train_loss:.2f} accuracy: {correct / total * 100:.2f}%")

    model.eval()  # inference-mode behavior for evaluation
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0
        for batch in tqdm(test_loader, desc='Testing'):
            imgs, labels = batch
            imgs = imgs.to(device)
            labels = labels.to(device)

            preds = model(imgs)
            loss = loss_fn(preds, labels)
            test_loss += loss.item() / len(test_loader)
            correct += (preds.argmax(dim=1) == labels).sum().item()
            total += len(imgs)
        print(f"Test loss: {test_loss:.2f} accuracy: {correct / total * 100:.2f}%")

Epoch: 1 in training: 100%|██████████| 782/782 [00:33<00:00, 23.32it/s]


Epoch 1 / 24 loss: 1.61 accuracy: 41.20%


Testing: 100%|██████████| 157/157 [00:06<00:00, 26.08it/s]


Test loss: 1.44 accuracy: 49.33%


Epoch: 2 in training: 100%|██████████| 782/782 [00:33<00:00, 23.15it/s]


Epoch 2 / 24 loss: 1.17 accuracy: 57.85%


Testing: 100%|██████████| 157/157 [00:06<00:00, 25.88it/s]


Test loss: 1.14 accuracy: 58.59%


Epoch: 3 in training: 100%|██████████| 782/782 [00:33<00:00, 23.02it/s]


Epoch 3 / 24 loss: 0.98 accuracy: 64.96%


Testing: 100%|██████████| 157/157 [00:05<00:00, 26.42it/s]


Test loss: 1.05 accuracy: 62.73%


Epoch: 4 in training: 100%|██████████| 782/782 [00:32<00:00, 23.71it/s]


Epoch 4 / 24 loss: 0.82 accuracy: 70.82%


Testing: 100%|██████████| 157/157 [00:05<00:00, 27.38it/s]


Test loss: 0.99 accuracy: 64.80%


Epoch: 5 in training: 100%|██████████| 782/782 [00:32<00:00, 23.91it/s]


Epoch 5 / 24 loss: 0.66 accuracy: 76.41%


Testing: 100%|██████████| 157/157 [00:06<00:00, 25.37it/s]


Test loss: 0.97 accuracy: 66.17%


Epoch: 6 in training: 100%|██████████| 782/782 [00:32<00:00, 23.89it/s]


Epoch 6 / 24 loss: 0.50 accuracy: 82.22%


Testing: 100%|██████████| 157/157 [00:06<00:00, 25.75it/s]


Test loss: 0.97 accuracy: 67.78%


Epoch: 7 in training: 100%|██████████| 782/782 [00:34<00:00, 22.99it/s]


Epoch 7 / 24 loss: 0.37 accuracy: 87.27%


Testing: 100%|██████████| 157/157 [00:06<00:00, 26.15it/s]


Test loss: 1.03 accuracy: 68.37%


Epoch: 8 in training: 100%|██████████| 782/782 [00:34<00:00, 22.99it/s]


Epoch 8 / 24 loss: 0.28 accuracy: 91.00%


Testing: 100%|██████████| 157/157 [00:05<00:00, 26.55it/s]


Test loss: 1.08 accuracy: 68.44%


Epoch: 9 in training: 100%|██████████| 782/782 [00:35<00:00, 22.33it/s]


Epoch 9 / 24 loss: 0.25 accuracy: 92.24%


Testing: 100%|██████████| 157/157 [00:06<00:00, 25.66it/s]


Test loss: 1.08 accuracy: 68.44%


Epoch: 10 in training: 100%|██████████| 782/782 [00:27<00:00, 28.85it/s]


Epoch 10 / 24 loss: 0.25 accuracy: 91.93%


Testing: 100%|██████████| 157/157 [00:06<00:00, 25.96it/s]


Test loss: 1.12 accuracy: 68.48%


Epoch: 11 in training: 100%|██████████| 782/782 [00:34<00:00, 22.78it/s]


Epoch 11 / 24 loss: 0.26 accuracy: 91.16%


Testing: 100%|██████████| 157/157 [00:05<00:00, 26.63it/s]


Test loss: 1.19 accuracy: 67.36%


Epoch: 12 in training: 100%|██████████| 782/782 [00:33<00:00, 23.05it/s]


Epoch 12 / 24 loss: 0.27 accuracy: 90.21%


Testing: 100%|██████████| 157/157 [00:06<00:00, 26.03it/s]


Test loss: 1.29 accuracy: 66.47%


Epoch: 13 in training: 100%|██████████| 782/782 [00:34<00:00, 22.96it/s]


Epoch 13 / 24 loss: 0.30 accuracy: 88.92%


Testing: 100%|██████████| 157/157 [00:05<00:00, 26.36it/s]


Test loss: 1.29 accuracy: 66.27%


Epoch: 14 in training: 100%|██████████| 782/782 [00:34<00:00, 22.85it/s]


Epoch 14 / 24 loss: 0.33 accuracy: 88.05%


Testing: 100%|██████████| 157/157 [00:06<00:00, 25.23it/s]


Test loss: 1.23 accuracy: 66.53%


Epoch: 15 in training: 100%|██████████| 782/782 [00:34<00:00, 22.94it/s]


Epoch 15 / 24 loss: 0.34 accuracy: 87.70%


Testing: 100%|██████████| 157/157 [00:06<00:00, 26.03it/s]


Test loss: 1.30 accuracy: 65.03%


Epoch: 16 in training: 100%|██████████| 782/782 [00:33<00:00, 23.02it/s]


Epoch 16 / 24 loss: 0.32 accuracy: 88.44%


Testing: 100%|██████████| 157/157 [00:06<00:00, 26.09it/s]


Test loss: 1.34 accuracy: 64.55%


Epoch: 17 in training: 100%|██████████| 782/782 [00:33<00:00, 23.03it/s]


Epoch 17 / 24 loss: 0.28 accuracy: 90.00%


Testing: 100%|██████████| 157/157 [00:06<00:00, 25.95it/s]


Test loss: 1.42 accuracy: 65.48%


Epoch: 18 in training: 100%|██████████| 782/782 [00:33<00:00, 23.01it/s]


Epoch 18 / 24 loss: 0.22 accuracy: 92.17%


Testing: 100%|██████████| 157/157 [00:06<00:00, 26.13it/s]


Test loss: 1.43 accuracy: 64.86%


Epoch: 19 in training: 100%|██████████| 782/782 [00:33<00:00, 23.07it/s]


Epoch 19 / 24 loss: 0.16 accuracy: 94.29%


Testing: 100%|██████████| 157/157 [00:06<00:00, 25.13it/s]


Test loss: 1.64 accuracy: 65.85%


Epoch: 20 in training: 100%|██████████| 782/782 [00:34<00:00, 23.00it/s]


Epoch 20 / 24 loss: 0.10 accuracy: 96.31%


Testing: 100%|██████████| 157/157 [00:05<00:00, 26.17it/s]


Test loss: 1.77 accuracy: 66.91%


Epoch: 21 in training: 100%|██████████| 782/782 [00:33<00:00, 23.07it/s]


Epoch 21 / 24 loss: 0.05 accuracy: 98.17%


Testing: 100%|██████████| 157/157 [00:06<00:00, 25.80it/s]


Test loss: 2.16 accuracy: 67.08%


Epoch: 22 in training: 100%|██████████| 782/782 [00:33<00:00, 23.09it/s]


Epoch 22 / 24 loss: 0.02 accuracy: 99.30%


Testing: 100%|██████████| 157/157 [00:05<00:00, 26.61it/s]


Test loss: 2.59 accuracy: 68.14%


Epoch: 23 in training: 100%|██████████| 782/782 [00:33<00:00, 23.08it/s]


Epoch 23 / 24 loss: 0.00 accuracy: 99.93%


Testing: 100%|██████████| 157/157 [00:06<00:00, 25.88it/s]


Test loss: 2.73 accuracy: 69.10%


Epoch: 24 in training: 100%|██████████| 782/782 [00:33<00:00, 23.17it/s]


Epoch 24 / 24 loss: 0.00 accuracy: 100.00%


Testing: 100%|██████████| 157/157 [00:06<00:00, 25.88it/s]

Test loss: 2.76 accuracy: 69.25%
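
The log above shows the test loss bottoming out early while the training loss keeps falling toward zero. One common response to this pattern is early stopping. The sketch below replays the test losses printed above through a simple patience rule; the function name and `patience=3` are illustrative assumptions, and in practice the monitored loss would normally come from a separate validation split rather than the test set.

```python
# Test losses copied from the training log above, epochs 1 to 24.
test_losses = [1.44, 1.14, 1.05, 0.99, 0.97, 0.97, 1.03, 1.08,
               1.08, 1.12, 1.19, 1.29, 1.29, 1.23, 1.30, 1.34,
               1.42, 1.43, 1.64, 1.77, 2.16, 2.59, 2.73, 2.76]


def early_stopping_epoch(losses, patience=3):
    """Return (best_epoch, stop_epoch) for a strict-improvement patience rule."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch
    return best_epoch, len(losses)


best_epoch, stop_epoch = early_stopping_epoch(test_losses)
print(best_epoch, stop_epoch)  # 5 8
```

With this rule, training would have stopped after epoch 8 and kept the epoch 5 weights. Test accuracy, by contrast, keeps fluctuating and ends slightly higher at epoch 24, so the choice of monitored metric matters.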



