In [1]:
! pip install torch numpy timm==0.5.4 tqdm

Collecting timm==0.5.4
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m431.5/431.5 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2

In [2]:
# Fetch the project repository into ./Vision-Transformer. The import cell later
# pulls `engine`, `utils` and `models`, which presumably live in this checkout
# (verify against the repository contents).
!git clone https://github.com/tsungchiehchen/Vision-Transformer.git

Cloning into 'Vision-Transformer'...
remote: Enumerating objects: 27, done.[K
remote: Counting objects: 100% (27/27), done.[K
remote: Compressing objects: 100% (20/20), done.[K
remote: Total 27 (delta 10), reused 24 (delta 7), pack-reused 0[K
Receiving objects: 100% (27/27), 24.42 KiB | 555.00 KiB/s, done.
Resolving deltas: 100% (10/10), done.


In [3]:
# Make the cloned repository the working directory for the rest of the
# notebook; relative paths used later (e.g. './teacher.pth') resolve against it.
%cd ./Vision-Transformer

/content/Vision-Transformer


In [4]:
import argparse
import datetime
import os
import sys
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn

from timm.models import create_model

from engine import train_one_epoch, train_one_epoch_distillation, evaluate
from utils import get_training_dataloader, get_test_dataloader
import models

In [None]:
# Experiment setup: ViT-Base with randomly initialised weights
# (pretrained=False), a few-shot training loader, the test loader, the loss
# and the optimizer.

# Per-channel normalisation statistics handed to the data-loader helpers.
# NOTE(review): these values look like the commonly quoted CIFAR-100
# statistics rather than CIFAR-10 ones - confirm they are intended.
MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
CHECKPOINT_PATH = './checkpoint'
MODEL_NAME = 'vit_base_patch16_224'
num_classes = 10
EPOCHS = 5
LR = 0.0001
WD = 0.0
shots = 1000  # the assertion below checks the train set holds shots * num_classes samples

print(f"Creating model: {MODEL_NAME}")
model = create_model(
        MODEL_NAME,
        pretrained=False,
        num_classes=num_classes,  # single source of truth instead of a repeated literal
        img_size=224)
# Use the first GPU when one is present, otherwise fall back to the CPU so the
# cell still runs on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

cifar10_training_loader = get_training_dataloader(
    MEAN,
    STD,
    num_workers=2,
    batch_size=16,
    shuffle=True,
    shots=shots
)

# Sanity check on the size of the sub-sampled training set.
assert shots*num_classes == len(cifar10_training_loader.dataset), (
    f"expected {shots*num_classes} training samples, "
    f"got {len(cifar10_training_loader.dataset)}")

cifar10_test_loader = get_test_dataloader(
    MEAN,
    STD,
    num_workers=4,
    batch_size=256,
    shuffle=False
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)


# Count only the parameters that will receive gradients.
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

Creating model: vit_base_patch16_224
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 79445281.30it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
number of params: 86567656




In [None]:
# Train `model` for EPOCHS epochs, then report top-1 accuracy on the test set.
print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS+1):
    train_stats = train_one_epoch(
        model, criterion, cifar10_training_loader,
        optimizer, device, epoch)
    # Epochs are numbered from 1, so "every 10th epoch" is `% 10 == 0`
    # (`== 9` fired after epochs 9, 19, ...). The last epoch is skipped here
    # because the final evaluation below covers it.
    if epoch % 10 == 0 and epoch < EPOCHS:
        test_stats = evaluate(cifar10_test_loader, model, criterion, device)
        # len(loader) is the number of batches; the image count is len(loader.dataset).
        print(f"Accuracy of the network on the {len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

# Final evaluation of the trained model.
test_stats = evaluate(cifar10_test_loader, model, criterion, device)
print(f"Accuracy of the network on the {len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

Start training for 5 epochs
Epoch: [1]  [  0/625]  eta: 0:28:48  loss: 7.2056 (7.2056)  time: 2.7654  data: 0.4869  max mem: 2096
Epoch: [1]  [100/625]  eta: 0:01:44  loss: 2.2015 (2.4863)  time: 0.1713  data: 0.0057  max mem: 3065
Epoch: [1]  [200/625]  eta: 0:01:19  loss: 2.1340 (2.3034)  time: 0.1701  data: 0.0055  max mem: 3065
Epoch: [1]  [300/625]  eta: 0:00:59  loss: 2.0563 (2.2224)  time: 0.1778  data: 0.0084  max mem: 3065
Epoch: [1]  [400/625]  eta: 0:00:40  loss: 2.1087 (2.1713)  time: 0.1774  data: 0.0065  max mem: 3065
Epoch: [1]  [500/625]  eta: 0:00:22  loss: 1.8859 (2.1306)  time: 0.1793  data: 0.0054  max mem: 3065
Epoch: [1]  [600/625]  eta: 0:00:04  loss: 1.9450 (2.1032)  time: 0.1898  data: 0.0100  max mem: 3065
Epoch: [1]  [624/625]  eta: 0:00:00  loss: 2.0016 (2.0996)  time: 0.1834  data: 0.0051  max mem: 3065
Epoch: [1] Total time: 0:01:53 (0.1819 s / it)
Averaged stats: loss: 2.0016 (2.0996)
Epoch: [2]  [  0/625]  eta: 0:03:33  loss: 2.1157 (2.1157)  time: 0.341

In [None]:
# Calculate throughput: test samples per second of wall-clock time for one
# full `evaluate` pass (data loading included).
if torch.cuda.is_available():
    torch.cuda.synchronize()  # finish queued GPU work so it is not billed to the timed region
start_time = time.perf_counter()  # monotonic, high-resolution clock meant for intervals
test_stats = evaluate(cifar10_test_loader, model, criterion, device)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for kernels still in flight before stopping the clock
end_time = time.perf_counter()
num_samples = len(cifar10_test_loader.dataset)
throughput = num_samples / (end_time - start_time)
print("Throughput: {}".format(throughput))

Test:  [ 0/40]  eta: 0:01:58  loss: 1.6862 (1.6862)  acc1: 40.6250 (40.6250)  acc5: 85.1562 (85.1562)  time: 2.9730  data: 1.8310  max mem: 3158
Test:  [20/40]  eta: 0:00:23  loss: 1.7490 (1.7590)  acc1: 36.3281 (35.9189)  acc5: 88.6719 (88.4301)  time: 1.0917  data: 0.0786  max mem: 3158
Test:  [39/40]  eta: 0:00:01  loss: 1.7465 (1.7480)  acc1: 36.7188 (36.1300)  acc5: 88.6719 (88.0300)  time: 1.0233  data: 0.0669  max mem: 3158
Test: Total time: 0:00:44 (1.1050 s / it)
* Acc@1 36.130 Acc@5 88.030 loss 1.748
Throughput: 226.21906627248507


# Q2 Fine-tuning Pretrained ViT

In [None]:
# Experiment setup: ViT-Base initialised from pretrained weights
# (pretrained=True), a few-shot training loader, the test loader, the loss
# and the optimizer.

# Per-channel normalisation statistics handed to the data-loader helpers.
# NOTE(review): these values look like the commonly quoted CIFAR-100
# statistics rather than CIFAR-10 ones - confirm they are intended.
MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
CHECKPOINT_PATH = './checkpoint'
MODEL_NAME = 'vit_base_patch16_224'
num_classes = 10
EPOCHS = 5
LR = 0.0001
WD = 0.0
shots = 1000  # the assertion below checks the train set holds shots * num_classes samples

print(f"Creating model: {MODEL_NAME}")
model = create_model(
        MODEL_NAME,
        pretrained=True,
        num_classes=num_classes,  # single source of truth instead of a repeated literal
        img_size=224)
# Use the first GPU when one is present, otherwise fall back to the CPU so the
# cell still runs on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

cifar10_training_loader = get_training_dataloader(
    MEAN,
    STD,
    num_workers=2,
    batch_size=16,
    shuffle=True,
    shots=shots
)

# Sanity check on the size of the sub-sampled training set.
assert shots*num_classes == len(cifar10_training_loader.dataset), (
    f"expected {shots*num_classes} training samples, "
    f"got {len(cifar10_training_loader.dataset)}")

cifar10_test_loader = get_test_dataloader(
    MEAN,
    STD,
    num_workers=4,
    batch_size=256,
    shuffle=False
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)


# Count only the parameters that will receive gradients.
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

Creating model: vit_base_patch16_224


Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /root/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
100%|██████████| 330M/330M [00:05<00:00, 59.9MB/s]


Files already downloaded and verified
Files already downloaded and verified
number of params: 85806346


In [None]:
# Fine-tune `model` for EPOCHS epochs, then report top-1 accuracy on the test set.
print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS+1):
    train_stats = train_one_epoch(
        model, criterion, cifar10_training_loader,
        optimizer, device, epoch)
    # Epochs are numbered from 1, so "every 10th epoch" is `% 10 == 0`
    # (`== 9` fired after epochs 9, 19, ...). The last epoch is skipped here
    # because the final evaluation below covers it.
    if epoch % 10 == 0 and epoch < EPOCHS:
        test_stats = evaluate(cifar10_test_loader, model, criterion, device)
        # len(loader) is the number of batches; the image count is len(loader.dataset).
        print(f"Accuracy of the network on the {len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

# Final evaluation of the fine-tuned model.
test_stats = evaluate(cifar10_test_loader, model, criterion, device)
print(f"Accuracy of the network on the {len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

Start training for 5 epochs
Epoch: [1]  [  0/625]  eta: 0:04:01  loss: 2.3790 (2.3790)  time: 0.3858  data: 0.1525  max mem: 3158
Epoch: [1]  [100/625]  eta: 0:01:38  loss: 2.0979 (2.2390)  time: 0.1886  data: 0.0055  max mem: 3158
Epoch: [1]  [200/625]  eta: 0:01:20  loss: 2.0896 (2.1396)  time: 0.1938  data: 0.0096  max mem: 3158
Epoch: [1]  [300/625]  eta: 0:01:01  loss: 1.8388 (2.0790)  time: 0.1859  data: 0.0053  max mem: 3158
Epoch: [1]  [400/625]  eta: 0:00:42  loss: 1.7095 (2.0115)  time: 0.1875  data: 0.0071  max mem: 3158
Epoch: [1]  [500/625]  eta: 0:00:23  loss: 1.5865 (1.9341)  time: 0.1861  data: 0.0057  max mem: 3158
Epoch: [1]  [600/625]  eta: 0:00:04  loss: 1.4219 (1.8547)  time: 0.1886  data: 0.0058  max mem: 3158
Epoch: [1]  [624/625]  eta: 0:00:00  loss: 1.4646 (1.8390)  time: 0.1900  data: 0.0065  max mem: 3158
Epoch: [1] Total time: 0:01:58 (0.1890 s / it)
Averaged stats: loss: 1.4646 (1.8390)
Epoch: [2]  [  0/625]  eta: 0:03:39  loss: 1.2923 (1.2923)  time: 0.351

In [None]:
# Calculate throughput: test samples per second of wall-clock time for one
# full `evaluate` pass (data loading included).
if torch.cuda.is_available():
    torch.cuda.synchronize()  # finish queued GPU work so it is not billed to the timed region
start_time = time.perf_counter()  # monotonic, high-resolution clock meant for intervals
test_stats = evaluate(cifar10_test_loader, model, criterion, device)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for kernels still in flight before stopping the clock
end_time = time.perf_counter()
num_samples = len(cifar10_test_loader.dataset)
throughput = num_samples / (end_time - start_time)
print("Throughput: {}".format(throughput))

Test:  [ 0/40]  eta: 0:02:29  loss: 0.6622 (0.6622)  acc1: 77.7344 (77.7344)  acc5: 98.4375 (98.4375)  time: 3.7399  data: 2.5347  max mem: 3158
Test:  [20/40]  eta: 0:00:25  loss: 0.5921 (0.5938)  acc1: 80.0781 (80.3571)  acc5: 99.2188 (99.0699)  time: 1.1569  data: 0.1107  max mem: 3158
Test:  [39/40]  eta: 0:00:01  loss: 0.5921 (0.5866)  acc1: 80.4688 (80.3500)  acc5: 99.2188 (99.1700)  time: 1.0263  data: 0.0535  max mem: 3364
Test: Total time: 0:00:46 (1.1643 s / it)
* Acc@1 80.350 Acc@5 99.170 loss 0.587
Throughput: 214.69383078072903


# Q3 ViT model on a small device

In [5]:
# Experiment setup: the smaller ViT-Tiny model initialised from pretrained
# weights, a few-shot training loader, the test loader, the loss and the
# optimizer.

# Per-channel normalisation statistics handed to the data-loader helpers.
# NOTE(review): these values look like the commonly quoted CIFAR-100
# statistics rather than CIFAR-10 ones - confirm they are intended.
MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
CHECKPOINT_PATH = './checkpoint'
MODEL_NAME = 'vit_tiny_patch16_224'
num_classes = 10
EPOCHS = 5
LR = 0.0001
WD = 0.0
shots = 1000  # the assertion below checks the train set holds shots * num_classes samples

print(f"Creating model: {MODEL_NAME}")
model = create_model(
        MODEL_NAME,
        pretrained=True,
        num_classes=num_classes,  # single source of truth instead of a repeated literal
        img_size=224)
# Use the first GPU when one is present, otherwise fall back to the CPU so the
# cell still runs on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

cifar10_training_loader = get_training_dataloader(
    MEAN,
    STD,
    num_workers=2,
    batch_size=16,
    shuffle=True,
    shots=shots
)

# Sanity check on the size of the sub-sampled training set.
assert shots*num_classes == len(cifar10_training_loader.dataset), (
    f"expected {shots*num_classes} training samples, "
    f"got {len(cifar10_training_loader.dataset)}")

cifar10_test_loader = get_test_dataloader(
    MEAN,
    STD,
    num_workers=4,
    batch_size=256,
    shuffle=False
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)


# Count only the parameters that will receive gradients.
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

Creating model: vit_tiny_patch16_224


Downloading: "https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth" to /root/.cache/torch/hub/checkpoints/deit_tiny_patch16_224-a1311bcf.pth
100%|██████████| 21.9M/21.9M [00:00<00:00, 26.5MB/s]


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:10<00:00, 16135777.33it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
number of params: 5526346




In [6]:
# Fine-tune the small model for EPOCHS epochs, then report top-1 accuracy on
# the test set.
print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS+1):
    train_stats = train_one_epoch(
        model, criterion, cifar10_training_loader,
        optimizer, device, epoch)
    # Epochs are numbered from 1, so "every 10th epoch" is `% 10 == 0`
    # (`== 9` fired after epochs 9, 19, ...). The last epoch is skipped here
    # because the final evaluation below covers it.
    if epoch % 10 == 0 and epoch < EPOCHS:
        test_stats = evaluate(cifar10_test_loader, model, criterion, device)
        # len(loader) is the number of batches; the image count is len(loader.dataset).
        print(f"Accuracy of the network on the {len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

# Final evaluation of the fine-tuned model.
test_stats = evaluate(cifar10_test_loader, model, criterion, device)
print(f"Accuracy of the network on the {len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

Start training for 5 epochs


  return F.conv2d(input, weight, bias, self.stride,


Epoch: [1]  [  0/625]  eta: 0:10:37  loss: 3.0119 (3.0119)  time: 1.0207  data: 0.1758  max mem: 463
Epoch: [1]  [100/625]  eta: 0:00:39  loss: 2.0155 (2.1878)  time: 0.0863  data: 0.0097  max mem: 513
Epoch: [1]  [200/625]  eta: 0:00:27  loss: 2.0032 (2.1210)  time: 0.0524  data: 0.0043  max mem: 513
Epoch: [1]  [300/625]  eta: 0:00:19  loss: 2.0289 (2.0550)  time: 0.0705  data: 0.0077  max mem: 513
Epoch: [1]  [400/625]  eta: 0:00:13  loss: 1.7647 (1.9978)  time: 0.0559  data: 0.0049  max mem: 513
Epoch: [1]  [500/625]  eta: 0:00:07  loss: 1.5791 (1.9344)  time: 0.0530  data: 0.0050  max mem: 513
Epoch: [1]  [600/625]  eta: 0:00:01  loss: 1.6066 (1.8826)  time: 0.0540  data: 0.0046  max mem: 513
Epoch: [1]  [624/625]  eta: 0:00:00  loss: 1.6054 (1.8729)  time: 0.0509  data: 0.0043  max mem: 513
Epoch: [1] Total time: 0:00:37 (0.0603 s / it)
Averaged stats: loss: 1.6054 (1.8729)
Epoch: [2]  [  0/625]  eta: 0:06:24  loss: 1.6375 (1.6375)  time: 0.6146  data: 0.3943  max mem: 513
Epoch:

In [7]:
# Calculate throughput: test samples per second of wall-clock time for one
# full `evaluate` pass (data loading included).
if torch.cuda.is_available():
    torch.cuda.synchronize()  # finish queued GPU work so it is not billed to the timed region
start_time = time.perf_counter()  # monotonic, high-resolution clock meant for intervals
test_stats = evaluate(cifar10_test_loader, model, criterion, device)
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for kernels still in flight before stopping the clock
end_time = time.perf_counter()
num_samples = len(cifar10_test_loader.dataset)
throughput = num_samples / (end_time - start_time)
print("Throughput: {}".format(throughput))

Test:  [ 0/40]  eta: 0:01:51  loss: 0.7713 (0.7713)  acc1: 71.8750 (71.8750)  acc5: 98.0469 (98.0469)  time: 2.7812  data: 2.4156  max mem: 760
Test:  [20/40]  eta: 0:00:12  loss: 0.8422 (0.8150)  acc1: 71.0938 (71.9866)  acc5: 98.0469 (98.0655)  time: 0.5387  data: 0.2472  max mem: 760
Test:  [39/40]  eta: 0:00:00  loss: 0.8209 (0.7990)  acc1: 72.2656 (72.4500)  acc5: 98.0469 (98.1400)  time: 0.4578  data: 0.1973  max mem: 760
Test: Total time: 0:00:22 (0.5559 s / it)
* Acc@1 72.450 Acc@5 98.140 loss 0.799
Throughput: 449.6078784800346


# Q4 Knowledge Distillation

In [8]:
# Step 1: Train the teacher
# Builds the large (ViT-Base) teacher from pretrained weights together with its
# own loss and optimizer. The data loaders are not recreated here; the training
# cell that follows reuses `cifar10_training_loader` / `cifar10_test_loader`
# from an earlier cell.
MODEL_NAME = 'vit_base_patch16_224'
num_classes = 10
EPOCHS = 5
LR = 0.0001
WD = 0.0

print(f"Creating model: {MODEL_NAME}")
teacher = create_model(
        MODEL_NAME,
        pretrained=True,
        num_classes=num_classes,  # single source of truth instead of a repeated literal
        img_size=224)
# Use the first GPU when one is present, otherwise fall back to the CPU so the
# cell still runs on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
teacher = teacher.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher.parameters(), lr=LR, weight_decay=WD)


# Count only the parameters that will receive gradients.
n_parameters = sum(p.numel() for p in teacher.parameters() if p.requires_grad)
print('number of params:', n_parameters)

Creating model: vit_base_patch16_224


Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /root/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
100%|██████████| 330M/330M [00:03<00:00, 87.9MB/s]


number of params: 85806346


In [9]:
# Fine-tune the teacher for EPOCHS epochs, then report its top-1 accuracy on
# the test set.
print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS+1):
    train_stats = train_one_epoch(
        teacher, criterion, cifar10_training_loader,
        optimizer, device, epoch)
    # Epochs are numbered from 1, so "every 10th epoch" is `% 10 == 0`
    # (`== 9` fired after epochs 9, 19, ...). The last epoch is skipped here
    # because the final evaluation below covers it.
    if epoch % 10 == 0 and epoch < EPOCHS:
        test_stats = evaluate(cifar10_test_loader, teacher, criterion, device)
        # len(loader) is the number of batches; the image count is len(loader.dataset).
        print(f"Accuracy of the network on the {len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

# Final evaluation of the fine-tuned teacher.
test_stats = evaluate(cifar10_test_loader, teacher, criterion, device)
print(f"Accuracy of the network on the {len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

Start training for 5 epochs
Epoch: [1]  [  0/625]  eta: 0:04:16  loss: 2.4119 (2.4119)  time: 0.4107  data: 0.1498  max mem: 2139
Epoch: [1]  [100/625]  eta: 0:01:34  loss: 1.7949 (2.1203)  time: 0.1803  data: 0.0077  max mem: 3092
Epoch: [1]  [200/625]  eta: 0:01:17  loss: 1.3583 (1.8583)  time: 0.1773  data: 0.0046  max mem: 3092
Epoch: [1]  [300/625]  eta: 0:00:58  loss: 1.1574 (1.6417)  time: 0.1801  data: 0.0054  max mem: 3092
Epoch: [1]  [400/625]  eta: 0:00:40  loss: 1.0959 (1.4966)  time: 0.1803  data: 0.0051  max mem: 3092
Epoch: [1]  [500/625]  eta: 0:00:22  loss: 0.8490 (1.3832)  time: 0.1796  data: 0.0047  max mem: 3092
Epoch: [1]  [600/625]  eta: 0:00:04  loss: 0.7958 (1.2901)  time: 0.1836  data: 0.0076  max mem: 3092
Epoch: [1]  [624/625]  eta: 0:00:00  loss: 0.7124 (1.2705)  time: 0.1796  data: 0.0043  max mem: 3092
Epoch: [1] Total time: 0:01:53 (0.1815 s / it)
Averaged stats: loss: 0.7124 (1.2705)
Epoch: [2]  [  0/625]  eta: 0:03:29  loss: 0.6773 (0.6773)  time: 0.335

In [10]:
# save finetuned teacher model
# Only the state_dict is written (not the module object), so loading it back
# requires rebuilding the same architecture first and then calling
# load_state_dict on it.
torch.save(teacher.state_dict(), './teacher.pth')

In [11]:
# Step 2: restore the fine-tuned teacher and distil it into a ViT-Tiny student.

# Rebuild the teacher architecture and load the weights saved in
# './teacher.pth'. pretrained=False: load_state_dict (strict by default)
# overwrites every tensor, so fetching the pretrained checkpoint first would
# be wasted work.
teacher = create_model(
        'vit_base_patch16_224',
        pretrained=False,
        num_classes=num_classes,
        img_size=224)
# Use the first GPU when one is present, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
teacher = teacher.to(device)
# map_location lets a checkpoint written from a GPU session load on whatever
# device this session actually has.
teacher.load_state_dict(torch.load('./teacher.pth', map_location=device))

# Confirm the restored teacher reproduces its fine-tuned accuracy.
test_stats = evaluate(cifar10_test_loader, teacher, criterion, device)
# len(loader) is the number of batches; the image count is len(loader.dataset).
print(f"Accuracy of the network on the {len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

# Train the student
# Freeze the teacher: it only supplies targets and must not be updated.
for p in teacher.parameters():
    p.requires_grad = False
# Keep the frozen teacher in inference mode.
# NOTE(review): assumes train_one_epoch_distillation does not switch the
# teacher back to train mode - confirm in engine.py.
teacher.eval()

MODEL_NAME = 'vit_tiny_patch16_224'

model = create_model(
        MODEL_NAME,
        pretrained=True,
        num_classes=num_classes,
        img_size=224)
model = model.to(device)

# Only the student's parameters are optimised.
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)


n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)


print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS+1):
    # alpha / temp are forwarded unchanged to the distillation routine; their
    # exact meaning is defined in engine.py (not visible here).
    train_stats = train_one_epoch_distillation(
        teacher, model, criterion, cifar10_training_loader,
        optimizer, device, epoch, alpha=1.0, temp=1.0)
    # Evaluate after every odd epoch; the last epoch is skipped here because
    # the final evaluation below covers it (previously it was evaluated twice).
    if epoch % 2 == 1 and epoch < EPOCHS:
        test_stats = evaluate(cifar10_test_loader, model, criterion, device)
        print(f"Accuracy of the network on the {len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

# Final evaluation of the distilled student.
test_stats = evaluate(cifar10_test_loader, model, criterion, device)
print(f"Accuracy of the network on the {len(cifar10_test_loader.dataset)} test images: {test_stats['acc1']:.1f}%")

Test:  [ 0/40]  eta: 0:01:53  loss: 0.5829 (0.5829)  acc1: 82.4219 (82.4219)  acc5: 98.4375 (98.4375)  time: 2.8268  data: 1.7344  max mem: 3510
Test:  [20/40]  eta: 0:00:22  loss: 0.5100 (0.5287)  acc1: 83.2031 (83.3519)  acc5: 99.2188 (99.2746)  time: 1.0361  data: 0.0759  max mem: 3510
Test:  [39/40]  eta: 0:00:01  loss: 0.5019 (0.5256)  acc1: 83.2031 (83.2100)  acc5: 99.2188 (99.2400)  time: 0.9519  data: 0.0510  max mem: 3510
Test: Total time: 0:00:41 (1.0429 s / it)
* Acc@1 83.210 Acc@5 99.240 loss 0.526
Accuracy of the network on the 40 test images: 83.2%
number of params: 5526346
Start training for 5 epochs
Epoch: [1]  [  0/625]  eta: 0:03:26  loss: 2.1664 (2.1664)  time: 0.3309  data: 0.1566  max mem: 3510
Epoch: [1]  [100/625]  eta: 0:00:56  loss: 2.1607 (2.2988)  time: 0.0994  data: 0.0045  max mem: 3510
Epoch: [1]  [200/625]  eta: 0:00:45  loss: 2.0223 (2.1863)  time: 0.1005  data: 0.0047  max mem: 3510
Epoch: [1]  [300/625]  eta: 0:00:34  loss: 1.8026 (2.1052)  time: 0.107