In [1]:
! pip install torch numpy timm==0.5.4 tqdm

Collecting timm==0.5.4
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m431.5/431.5 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.

In [2]:
!git clone https://github.com/tsungchiehchen/Vision-Transformer.git

Cloning into 'Vision-Transformer'...
remote: Enumerating objects: 33, done.[K
remote: Counting objects: 100% (33/33), done.[K
remote: Compressing objects: 100% (24/24), done.[K
remote: Total 33 (delta 13), reused 29 (delta 9), pack-reused 0[K
Receiving objects: 100% (33/33), 31.23 KiB | 15.62 MiB/s, done.
Resolving deltas: 100% (13/13), done.


In [3]:
%cd ./Vision-Transformer

/content/Vision-Transformer


In [4]:
import argparse
import datetime
import os
import sys
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn

from timm.models import create_model

from engine import train_one_epoch, train_one_epoch_distillation, evaluate
from utils import get_training_dataloader, get_test_dataloader
import models

In [5]:
MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
CHECKPOINT_PATH = './checkpoint'
MODEL_NAME = 'vit_base_patch16_224'
num_classes = 10
EPOCHS = 5
LR = 0.0001
WD = 0.0
shots = 1000

print(f"Creating model: {MODEL_NAME}")
model = create_model(
        MODEL_NAME,
        pretrained=False,
        num_classes=10,
        img_size=224)
device = 'cuda:0' # device = 'cpu'
model = model.to(device)

cifar10_training_loader = get_training_dataloader(
    MEAN,
    STD,
    num_workers=2,
    batch_size=16,
    shuffle=True,
    shots=shots
)

assert (shots*num_classes == len(cifar10_training_loader.dataset))

cifar10_test_loader = get_test_dataloader(
    MEAN,
    STD,
    num_workers=4,
    batch_size=256,
    shuffle=False
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)


n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

Creating model: vit_base_patch16_224
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:12<00:00, 13411296.15it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
number of params: 86567656




In [6]:
print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS+1):
    train_stats = train_one_epoch(
        model, criterion, cifar10_training_loader,
        optimizer, device, epoch)
    if epoch % 10 == 9:
        test_stats = evaluate(cifar10_test_loader, model, criterion, device)
        print(f"Accuracy of the network on the {len(cifar10_test_loader)} test images: {test_stats['acc1']:.1f}%")

test_stats = evaluate(cifar10_test_loader, model, criterion, device)
print(f"Accuracy of the network on the {len(cifar10_test_loader)} test images: {test_stats['acc1']:.1f}%")

Start training for 5 epochs
Epoch: [1]  [  0/625]  eta: 0:12:02  loss: 6.8875 (6.8875)  time: 1.1564  data: 0.1729  max mem: 2096
Epoch: [1]  [100/625]  eta: 0:01:36  loss: 2.1863 (2.5013)  time: 0.1718  data: 0.0049  max mem: 3065
Epoch: [1]  [200/625]  eta: 0:01:17  loss: 2.0443 (2.3014)  time: 0.1768  data: 0.0060  max mem: 3065
Epoch: [1]  [300/625]  eta: 0:00:59  loss: 2.0471 (2.2176)  time: 0.1782  data: 0.0048  max mem: 3065
Epoch: [1]  [400/625]  eta: 0:00:40  loss: 2.0320 (2.1723)  time: 0.1870  data: 0.0075  max mem: 3065
Epoch: [1]  [500/625]  eta: 0:00:22  loss: 1.9226 (2.1359)  time: 0.1852  data: 0.0048  max mem: 3065
Epoch: [1]  [600/625]  eta: 0:00:04  loss: 1.9641 (2.1090)  time: 0.1850  data: 0.0070  max mem: 3065
Epoch: [1]  [624/625]  eta: 0:00:00  loss: 1.8869 (2.1024)  time: 0.1813  data: 0.0050  max mem: 3065
Epoch: [1] Total time: 0:01:54 (0.1829 s / it)
Epoch: [1] Training Accuracy: 23.53%
Averaged stats: loss: 1.8869 (2.1024)
Epoch: [2]  [  0/625]  eta: 0:03:1

In [7]:
# Calculate througput
start_time = time.time()
test_stats = evaluate(cifar10_test_loader, model, criterion, device)
end_time = time.time()
num_samples = len(cifar10_test_loader.dataset)
throughput = num_samples / (end_time - start_time)
print("Throughput: {}".format(throughput))

Test:  [ 0/40]  eta: 0:01:44  loss: 1.6280 (1.6280)  acc1: 38.2812 (38.2812)  acc5: 91.0156 (91.0156)  time: 2.6090  data: 1.5375  max mem: 3156
Test:  [20/40]  eta: 0:00:22  loss: 1.6459 (1.6371)  acc1: 38.2812 (38.1324)  acc5: 89.4531 (89.7879)  time: 1.0650  data: 0.0889  max mem: 3156
Test:  [39/40]  eta: 0:00:01  loss: 1.6502 (1.6464)  acc1: 38.6719 (38.2600)  acc5: 88.6719 (89.1000)  time: 0.9484  data: 0.0540  max mem: 3156
Test: Total time: 0:00:41 (1.0459 s / it)
* Acc@1 38.260 Acc@5 89.100 loss 1.646
Throughput: 238.9909673432086


# Q2 Fine-tuning Pretrained ViT

In [8]:
MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
CHECKPOINT_PATH = './checkpoint'
MODEL_NAME = 'vit_base_patch16_224'
num_classes = 10
EPOCHS = 5
LR = 0.0001
WD = 0.0
shots = 1000

print(f"Creating model: {MODEL_NAME}")
model = create_model(
        MODEL_NAME,
        pretrained=True,
        num_classes=10,
        img_size=224)
device = 'cuda:0' # device = 'cpu'
model = model.to(device)

cifar10_training_loader = get_training_dataloader(
    MEAN,
    STD,
    num_workers=2,
    batch_size=16,
    shuffle=True,
    shots=shots
)

assert (shots*num_classes == len(cifar10_training_loader.dataset))

cifar10_test_loader = get_test_dataloader(
    MEAN,
    STD,
    num_workers=4,
    batch_size=256,
    shuffle=False
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)


n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

Creating model: vit_base_patch16_224


Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /root/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
100%|██████████| 330M/330M [00:01<00:00, 179MB/s]


Files already downloaded and verified
Files already downloaded and verified
number of params: 85806346


In [9]:
print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS+1):
    train_stats = train_one_epoch(
        model, criterion, cifar10_training_loader,
        optimizer, device, epoch)
    if epoch % 10 == 9:
        test_stats = evaluate(cifar10_test_loader, model, criterion, device)
        print(f"Accuracy of the network on the {len(cifar10_test_loader)} test images: {test_stats['acc1']:.1f}%")

test_stats = evaluate(cifar10_test_loader, model, criterion, device)
print(f"Accuracy of the network on the {len(cifar10_test_loader)} test images: {test_stats['acc1']:.1f}%")

Start training for 5 epochs
Epoch: [1]  [  0/625]  eta: 0:04:40  loss: 2.2632 (2.2632)  time: 0.4489  data: 0.1838  max mem: 3156
Epoch: [1]  [100/625]  eta: 0:01:38  loss: 2.1144 (2.2261)  time: 0.1855  data: 0.0051  max mem: 3156
Epoch: [1]  [200/625]  eta: 0:01:19  loss: 1.9214 (2.1336)  time: 0.1842  data: 0.0048  max mem: 3156
Epoch: [1]  [300/625]  eta: 0:01:00  loss: 1.8683 (2.0609)  time: 0.1920  data: 0.0105  max mem: 3156
Epoch: [1]  [400/625]  eta: 0:00:42  loss: 1.6490 (1.9844)  time: 0.1822  data: 0.0047  max mem: 3156
Epoch: [1]  [500/625]  eta: 0:00:23  loss: 1.3657 (1.9079)  time: 0.1886  data: 0.0074  max mem: 3156
Epoch: [1]  [600/625]  eta: 0:00:04  loss: 1.3566 (1.8288)  time: 0.1847  data: 0.0048  max mem: 3156
Epoch: [1]  [624/625]  eta: 0:00:00  loss: 1.3284 (1.8111)  time: 0.1904  data: 0.0091  max mem: 3156
Epoch: [1] Total time: 0:01:57 (0.1874 s / it)
Epoch: [1] Training Accuracy: 32.97%
Averaged stats: loss: 1.3284 (1.8111)
Epoch: [2]  [  0/625]  eta: 0:03:3

In [10]:
# Calculate througput
start_time = time.time()
test_stats = evaluate(cifar10_test_loader, model, criterion, device)
end_time = time.time()
num_samples = len(cifar10_test_loader.dataset)
throughput = num_samples / (end_time - start_time)
print("Throughput: {}".format(throughput))

Test:  [ 0/40]  eta: 0:02:12  loss: 0.6096 (0.6096)  acc1: 77.7344 (77.7344)  acc5: 99.2188 (99.2188)  time: 3.3081  data: 2.1698  max mem: 3156
Test:  [20/40]  eta: 0:00:23  loss: 0.5715 (0.5838)  acc1: 80.8594 (80.6362)  acc5: 98.8281 (98.9211)  time: 1.0619  data: 0.0785  max mem: 3156
Test:  [39/40]  eta: 0:00:01  loss: 0.5826 (0.5795)  acc1: 80.0781 (80.3900)  acc5: 99.2188 (99.0800)  time: 0.9607  data: 0.0468  max mem: 3156
Test: Total time: 0:00:42 (1.0732 s / it)
* Acc@1 80.390 Acc@5 99.080 loss 0.579
Throughput: 232.89691721747135


# Q3 ViT model on a small device

In [11]:
MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
CHECKPOINT_PATH = './checkpoint'
MODEL_NAME = 'vit_tiny_patch16_224'
num_classes = 10
EPOCHS = 5
LR = 0.0001
WD = 0.0
shots = 1000

print(f"Creating model: {MODEL_NAME}")
model = create_model(
        MODEL_NAME,
        pretrained=True,
        num_classes=10,
        img_size=224)
device = 'cuda:0' # device = 'cpu'
model = model.to(device)

cifar10_training_loader = get_training_dataloader(
    MEAN,
    STD,
    num_workers=2,
    batch_size=16,
    shuffle=True,
    shots=shots
)

assert (shots*num_classes == len(cifar10_training_loader.dataset))

cifar10_test_loader = get_test_dataloader(
    MEAN,
    STD,
    num_workers=4,
    batch_size=256,
    shuffle=False
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)


n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

Creating model: vit_tiny_patch16_224


Downloading: "https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth" to /root/.cache/torch/hub/checkpoints/deit_tiny_patch16_224-a1311bcf.pth
100%|██████████| 21.9M/21.9M [00:00<00:00, 134MB/s]


Files already downloaded and verified
Files already downloaded and verified
number of params: 5526346


In [12]:
print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS+1):
    train_stats = train_one_epoch(
        model, criterion, cifar10_training_loader,
        optimizer, device, epoch)
    if epoch % 10 == 9:
        test_stats = evaluate(cifar10_test_loader, model, criterion, device)
        print(f"Accuracy of the network on the {len(cifar10_test_loader)} test images: {test_stats['acc1']:.1f}%")

test_stats = evaluate(cifar10_test_loader, model, criterion, device)
print(f"Accuracy of the network on the {len(cifar10_test_loader)} test images: {test_stats['acc1']:.1f}%")

Start training for 5 epochs


  return F.conv2d(input, weight, bias, self.stride,


Epoch: [1]  [  0/625]  eta: 0:03:05  loss: 2.3241 (2.3241)  time: 0.2961  data: 0.1723  max mem: 3156
Epoch: [1]  [100/625]  eta: 0:00:30  loss: 2.2290 (2.2492)  time: 0.0554  data: 0.0047  max mem: 3156
Epoch: [1]  [200/625]  eta: 0:00:26  loss: 1.9586 (2.1558)  time: 0.0637  data: 0.0067  max mem: 3156
Epoch: [1]  [300/625]  eta: 0:00:19  loss: 1.8633 (2.0821)  time: 0.0561  data: 0.0050  max mem: 3156
Epoch: [1]  [400/625]  eta: 0:00:13  loss: 1.7633 (2.0298)  time: 0.0854  data: 0.0110  max mem: 3156
Epoch: [1]  [500/625]  eta: 0:00:07  loss: 1.5799 (1.9757)  time: 0.0548  data: 0.0046  max mem: 3156
Epoch: [1]  [600/625]  eta: 0:00:01  loss: 1.4734 (1.9215)  time: 0.0923  data: 0.0134  max mem: 3156
Epoch: [1]  [624/625]  eta: 0:00:00  loss: 1.4968 (1.9079)  time: 0.0565  data: 0.0058  max mem: 3156
Epoch: [1] Total time: 0:00:39 (0.0627 s / it)
Epoch: [1] Training Accuracy: 27.69%
Averaged stats: loss: 1.4968 (1.9079)
Epoch: [2]  [  0/625]  eta: 0:02:42  loss: 1.4784 (1.4784)  ti

In [13]:
# Calculate througput
start_time = time.time()
test_stats = evaluate(cifar10_test_loader, model, criterion, device)
end_time = time.time()
num_samples = len(cifar10_test_loader.dataset)
throughput = num_samples / (end_time - start_time)
print("Throughput: {}".format(throughput))

Test:  [ 0/40]  eta: 0:01:32  loss: 0.8062 (0.8062)  acc1: 73.4375 (73.4375)  acc5: 98.0469 (98.0469)  time: 2.3161  data: 1.9110  max mem: 3156
Test:  [20/40]  eta: 0:00:11  loss: 0.7798 (0.7825)  acc1: 72.6562 (73.3259)  acc5: 98.4375 (98.3631)  time: 0.4887  data: 0.1797  max mem: 3156
Test:  [39/40]  eta: 0:00:00  loss: 0.7851 (0.7798)  acc1: 72.6562 (73.1600)  acc5: 98.4375 (98.3500)  time: 0.3562  data: 0.1108  max mem: 3156
Test: Total time: 0:00:18 (0.4709 s / it)
* Acc@1 73.160 Acc@5 98.350 loss 0.780
Throughput: 530.8179297502885


# Q4 Knowledge Distillation

In [14]:
# Step 1: Train the teacher
MODEL_NAME = 'vit_base_patch16_224'
num_classes = 10
EPOCHS = 5
LR = 0.0001
WD = 0.0

print(f"Creating model: {MODEL_NAME}")
teacher = create_model(
        MODEL_NAME,
        pretrained=True,
        num_classes=10,
        img_size=224)
device = 'cuda:0' # device = 'cpu'
teacher = teacher.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher.parameters(), lr=LR, weight_decay=WD)


n_parameters = sum(p.numel() for p in teacher.parameters() if p.requires_grad)
print('number of params:', n_parameters)

Creating model: vit_base_patch16_224
number of params: 85806346


In [15]:
print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS+1):
    train_stats = train_one_epoch(
        teacher, criterion, cifar10_training_loader,
        optimizer, device, epoch)
    if epoch % 10 == 9:
        test_stats = evaluate(cifar10_test_loader, teacher, criterion, device)
        print(f"Accuracy of the network on the {len(cifar10_test_loader)} test images: {test_stats['acc1']:.1f}%")

test_stats = evaluate(cifar10_test_loader, teacher, criterion, device)
print(f"Accuracy of the network on the {len(cifar10_test_loader)} test images: {test_stats['acc1']:.1f}%")

Start training for 5 epochs
Epoch: [1]  [  0/625]  eta: 0:04:20  loss: 2.4588 (2.4588)  time: 0.4176  data: 0.1720  max mem: 3156
Epoch: [1]  [100/625]  eta: 0:01:40  loss: 1.9478 (2.1220)  time: 0.1913  data: 0.0069  max mem: 3156
Epoch: [1]  [200/625]  eta: 0:01:20  loss: 1.4271 (1.8895)  time: 0.1855  data: 0.0048  max mem: 3156
Epoch: [1]  [300/625]  eta: 0:01:01  loss: 1.0312 (1.6797)  time: 0.1909  data: 0.0107  max mem: 3156
Epoch: [1]  [400/625]  eta: 0:00:42  loss: 0.9570 (1.5365)  time: 0.1832  data: 0.0049  max mem: 3156
Epoch: [1]  [500/625]  eta: 0:00:23  loss: 0.8266 (1.4257)  time: 0.1922  data: 0.0101  max mem: 3156
Epoch: [1]  [600/625]  eta: 0:00:04  loss: 0.6655 (1.3293)  time: 0.1867  data: 0.0050  max mem: 3156
Epoch: [1]  [624/625]  eta: 0:00:00  loss: 0.8220 (1.3121)  time: 0.1864  data: 0.0048  max mem: 3156
Epoch: [1] Total time: 0:01:57 (0.1880 s / it)
Epoch: [1] Training Accuracy: 52.63%
Averaged stats: loss: 0.8220 (1.3121)
Epoch: [2]  [  0/625]  eta: 0:04:0

In [18]:
# save finetuned teacher model
torch.save(teacher.state_dict(), './teacher.pth')

In [19]:
teacher = create_model(
        'vit_base_patch16_224',
        pretrained=True,
        num_classes=10,
        img_size=224)
device = 'cuda:0' # device = 'cpu'
teacher = teacher.to(device)
teacher.load_state_dict(torch.load('./teacher.pth'))

test_stats = evaluate(cifar10_test_loader, teacher, criterion, device)
print(f"Accuracy of the network on the {len(cifar10_test_loader)} test images: {test_stats['acc1']:.1f}%")

# Train the student
for p in teacher.parameters():
    p.requires_grad = False

MODEL_NAME = 'vit_tiny_patch16_224'

model = create_model(
        MODEL_NAME,
        pretrained=True,
        num_classes=10,
        img_size=224)
device = 'cuda:0' # device = 'cpu'
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)


n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)


print(f"Start training for {EPOCHS} epochs")

for epoch in range(1, EPOCHS+1):
    train_stats = train_one_epoch_distillation(
        teacher, model, criterion, cifar10_training_loader,
        optimizer, device, epoch, alpha=1.0, temp=1.0)
    if epoch % 2 == 1:
        test_stats = evaluate(cifar10_test_loader, model, criterion, device)
        print(f"Accuracy of the network on the {len(cifar10_test_loader)} test images: {test_stats['acc1']:.1f}%")

test_stats = evaluate(cifar10_test_loader, model, criterion, device)
print(f"Accuracy of the network on the {len(cifar10_test_loader)} test images: {test_stats['acc1']:.1f}%")

Test:  [ 0/40]  eta: 0:02:53  loss: 0.5764 (0.5764)  acc1: 81.2500 (81.2500)  acc5: 99.2188 (99.2188)  time: 4.3401  data: 3.1810  max mem: 3512
Test:  [20/40]  eta: 0:00:24  loss: 0.6046 (0.6022)  acc1: 80.8594 (81.2872)  acc5: 98.8281 (98.7909)  time: 1.0940  data: 0.0651  max mem: 3512
Test:  [39/40]  eta: 0:00:01  loss: 0.5773 (0.5969)  acc1: 79.6875 (81.0300)  acc5: 98.8281 (98.8400)  time: 1.0070  data: 0.0554  max mem: 3512
Test: Total time: 0:00:45 (1.1347 s / it)
* Acc@1 81.030 Acc@5 98.840 loss 0.597
Accuracy of the network on the 40 test images: 81.0%
number of params: 5526346
Start training for 5 epochs
Epoch: [1]  [  0/625]  eta: 0:03:13  loss: 3.0322 (3.0322)  time: 0.3098  data: 0.1962  max mem: 3512
Epoch: [1]  [100/625]  eta: 0:00:56  loss: 2.1571 (2.2451)  time: 0.1201  data: 0.0094  max mem: 3512
Epoch: [1]  [200/625]  eta: 0:00:44  loss: 1.7630 (2.1014)  time: 0.1020  data: 0.0047  max mem: 3512
Epoch: [1]  [300/625]  eta: 0:00:34  loss: 1.5722 (1.9864)  time: 0.102