# ViT for classification on CIFAR-10

In [1]:
# Load the autoreload extension and reload imported modules before each cell
# runs, so edits to the project's .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Show the GPU/driver status and which Python interpreter the shell resolves.
!nvidia-smi
!which python

Wed Oct  2 19:34:41 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   41C    P0              40W / 300W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
# Import the project's training entry point, the CIFAR datasets and the model
# definitions. The package-qualified form is tried first; if the ``ViT``
# package cannot be imported, the same modules are imported as top-level
# modules instead (presumably when the notebook runs from inside the ViT
# directory -- TODO confirm).
# NOTE(review): the star import hides which names come from ``model``
# (the cells below rely on ``ViT`` being one of them); consider importing
# them explicitly.
try:
    from ViT.train import train
    from ViT.utils import cifar_train_set, cifar_val_set
    from ViT.model import *
except ImportError:
    # Only a failed import triggers the fallback. Any other exception raised
    # while importing the project modules is a real error and must surface
    # instead of being masked by a bare ``except``.
    from train import train
    from utils import cifar_train_set, cifar_val_set
    from model import *

from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F

import datetime

# Mini-batch loaders over the two dataset objects imported from the utils
# module. Training batches are reshuffled at every epoch.
train_loader = DataLoader(cifar_train_set, batch_size=256, shuffle=True, drop_last=False, pin_memory=True)
# Validation does not need shuffling: a fixed order keeps evaluation
# deterministic and skips generating a random permutation on every pass.
val_loader = DataLoader(cifar_val_set, batch_size=500, shuffle=False, drop_last=False, pin_memory=True)

import os

# Make sure the directory used by get_outdir() exists. exist_ok=True makes the
# call safe to re-run and removes the window between checking for the
# directory and creating it.
os.makedirs("ViT/log", exist_ok=True)

def timestr():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

def get_outdir(time_str):
    """Build the path ``ViT/log/<time_str>.out`` for the given time string."""
    return f"ViT/log/{time_str}.out"

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Hyperparameters for this training run

epochs = 50
patch_size = 8
embed_dim = 256
n_layers = 6
heads = 8
attn_dim = 512
mlp_dim = None  # default to 4*embed_dim
pool = 'cls'
dropout = 0.0

# Build the model for 32x32 inputs and 10 output classes.
model = ViT(
    image_size=32,
    patch_size=patch_size,
    num_classes=10,
    embed_dim=embed_dim,
    n_layers=n_layers,
    heads=heads,
    attn_dim=attn_dim,
    mlp_dim=mlp_dim,
    pool=pool,
    dropout=dropout,
)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=2e-4,
    betas=(0.9, 0.999),
    weight_decay=5e-4,
)

# The timestamp is used to build the outdir path handed to train().
time_str = timestr()
print(f"Time string: {time_str}")

# Report how many parameters will receive gradients.
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(
    epochs=epochs,
    model=model,
    optimizer=optimizer,
    criterion=nn.CrossEntropyLoss(),
    train_loader=train_loader,
    val_loader=val_loader,
    outdir=get_outdir(time_str),
)

Time string: 20241002_154301
The model has 6,356,234 trainable parameters


Epoch 2/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 1/50, Train Loss: 2.0066, Train Accuracy: 24.60%, Val Loss: 1.8415, Val Accuracy: 31.70%


Epoch 3/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 2/50, Train Loss: 1.7390, Train Accuracy: 35.72%, Val Loss: 1.6409, Val Accuracy: 39.93%


Epoch 4/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 3/50, Train Loss: 1.6144, Train Accuracy: 41.03%, Val Loss: 1.5837, Val Accuracy: 43.22%


Epoch 5/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 4/50, Train Loss: 1.5596, Train Accuracy: 43.02%, Val Loss: 1.5484, Val Accuracy: 43.87%


Epoch 6/50:   1%|▎                                                      | 1/196 [00:00<00:31,  6.12it/s, Train Loss=1.56]

Epoch 5/50, Train Loss: 1.5098, Train Accuracy: 44.88%, Val Loss: 1.4888, Val Accuracy: 46.59%


Epoch 7/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 6/50, Train Loss: 1.4709, Train Accuracy: 46.43%, Val Loss: 1.4622, Val Accuracy: 47.08%


Epoch 8/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 7/50, Train Loss: 1.4370, Train Accuracy: 47.75%, Val Loss: 1.4017, Val Accuracy: 49.35%


Epoch 9/50:   1%|▎                                                      | 1/196 [00:00<00:32,  6.06it/s, Train Loss=1.37]

Epoch 8/50, Train Loss: 1.4167, Train Accuracy: 48.39%, Val Loss: 1.4207, Val Accuracy: 48.88%


Epoch 10/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 9/50, Train Loss: 1.3943, Train Accuracy: 49.30%, Val Loss: 1.3814, Val Accuracy: 50.10%


Epoch 11/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 10/50, Train Loss: 1.3747, Train Accuracy: 50.03%, Val Loss: 1.3663, Val Accuracy: 50.92%


Epoch 12/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 11/50, Train Loss: 1.3591, Train Accuracy: 50.80%, Val Loss: 1.3587, Val Accuracy: 50.62%


Epoch 13/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 12/50, Train Loss: 1.3472, Train Accuracy: 51.46%, Val Loss: 1.3412, Val Accuracy: 51.62%


Epoch 14/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 13/50, Train Loss: 1.3341, Train Accuracy: 52.02%, Val Loss: 1.3381, Val Accuracy: 51.07%


Epoch 15/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 14/50, Train Loss: 1.3228, Train Accuracy: 52.03%, Val Loss: 1.3055, Val Accuracy: 53.16%


Epoch 16/50:   1%|▎                                                     | 1/196 [00:00<00:31,  6.20it/s, Train Loss=1.25]

Epoch 15/50, Train Loss: 1.3082, Train Accuracy: 52.83%, Val Loss: 1.3194, Val Accuracy: 52.44%


Epoch 17/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.01it/s, Train Loss=1.26]

Epoch 16/50, Train Loss: 1.2975, Train Accuracy: 53.26%, Val Loss: 1.3298, Val Accuracy: 52.27%


Epoch 18/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 17/50, Train Loss: 1.2887, Train Accuracy: 53.45%, Val Loss: 1.2982, Val Accuracy: 52.66%


Epoch 19/50:   1%|▎                                                     | 1/196 [00:00<00:31,  6.14it/s, Train Loss=1.34]

Epoch 18/50, Train Loss: 1.2733, Train Accuracy: 54.33%, Val Loss: 1.3083, Val Accuracy: 52.91%


Epoch 20/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 19/50, Train Loss: 1.2602, Train Accuracy: 54.55%, Val Loss: 1.2738, Val Accuracy: 54.45%


Epoch 21/50:   1%|▎                                                     | 1/196 [00:00<00:32,  5.95it/s, Train Loss=1.22]

Epoch 20/50, Train Loss: 1.2547, Train Accuracy: 54.82%, Val Loss: 1.2953, Val Accuracy: 53.28%


Epoch 22/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 21/50, Train Loss: 1.2423, Train Accuracy: 55.31%, Val Loss: 1.2621, Val Accuracy: 54.77%


Epoch 23/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 22/50, Train Loss: 1.2351, Train Accuracy: 55.32%, Val Loss: 1.2557, Val Accuracy: 54.51%


Epoch 24/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 23/50, Train Loss: 1.2264, Train Accuracy: 55.95%, Val Loss: 1.2436, Val Accuracy: 55.78%


Epoch 25/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 24/50, Train Loss: 1.2148, Train Accuracy: 56.27%, Val Loss: 1.2404, Val Accuracy: 55.92%


Epoch 26/50:   0%|                                                              | 0/196 [00:00<?, ?it/s, Train Loss=1.15]

Epoch 25/50, Train Loss: 1.2011, Train Accuracy: 56.73%, Val Loss: 1.2312, Val Accuracy: 55.86%


Epoch 27/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 26/50, Train Loss: 1.1936, Train Accuracy: 57.13%, Val Loss: 1.1965, Val Accuracy: 56.77%


Epoch 28/50:   1%|▎                                                     | 1/196 [00:00<00:31,  6.15it/s, Train Loss=1.07]

Epoch 27/50, Train Loss: 1.1841, Train Accuracy: 57.24%, Val Loss: 1.2048, Val Accuracy: 56.57%


Epoch 29/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.02it/s, Train Loss=1.16]

Epoch 28/50, Train Loss: 1.1780, Train Accuracy: 57.80%, Val Loss: 1.1990, Val Accuracy: 57.18%


Epoch 30/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 29/50, Train Loss: 1.1665, Train Accuracy: 58.04%, Val Loss: 1.1759, Val Accuracy: 57.43%


Epoch 31/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.01it/s, Train Loss=1.17]

Epoch 30/50, Train Loss: 1.1572, Train Accuracy: 58.44%, Val Loss: 1.1902, Val Accuracy: 56.99%


Epoch 32/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.05it/s, Train Loss=1.03]

Epoch 31/50, Train Loss: 1.1577, Train Accuracy: 58.50%, Val Loss: 1.1795, Val Accuracy: 57.86%


Epoch 33/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 32/50, Train Loss: 1.1384, Train Accuracy: 59.22%, Val Loss: 1.1663, Val Accuracy: 58.44%


Epoch 34/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 33/50, Train Loss: 1.1347, Train Accuracy: 59.46%, Val Loss: 1.1661, Val Accuracy: 58.06%


Epoch 35/50:   1%|▎                                                     | 1/196 [00:00<00:31,  6.15it/s, Train Loss=1.19]

Epoch 34/50, Train Loss: 1.1216, Train Accuracy: 59.80%, Val Loss: 1.1722, Val Accuracy: 58.41%


Epoch 36/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 35/50, Train Loss: 1.1206, Train Accuracy: 59.94%, Val Loss: 1.1486, Val Accuracy: 58.67%


Epoch 37/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 36/50, Train Loss: 1.1114, Train Accuracy: 60.29%, Val Loss: 1.1327, Val Accuracy: 59.54%


Epoch 38/50:   1%|▎                                                     | 1/196 [00:00<00:31,  6.15it/s, Train Loss=1.15]

Epoch 37/50, Train Loss: 1.1053, Train Accuracy: 60.54%, Val Loss: 1.1504, Val Accuracy: 58.83%


Epoch 39/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 38/50, Train Loss: 1.1008, Train Accuracy: 60.40%, Val Loss: 1.1234, Val Accuracy: 59.61%


Epoch 40/50:   1%|▎                                                     | 1/196 [00:00<00:31,  6.18it/s, Train Loss=1.16]

Epoch 39/50, Train Loss: 1.0913, Train Accuracy: 60.92%, Val Loss: 1.1302, Val Accuracy: 59.90%


Epoch 41/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 40/50, Train Loss: 1.0840, Train Accuracy: 61.23%, Val Loss: 1.1220, Val Accuracy: 60.38%


Epoch 42/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 41/50, Train Loss: 1.0804, Train Accuracy: 61.55%, Val Loss: 1.1171, Val Accuracy: 60.12%


Epoch 43/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 42/50, Train Loss: 1.0669, Train Accuracy: 62.17%, Val Loss: 1.1145, Val Accuracy: 60.22%


Epoch 44/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 43/50, Train Loss: 1.0672, Train Accuracy: 61.92%, Val Loss: 1.1114, Val Accuracy: 60.36%


Epoch 45/50:   1%|▎                                                     | 1/196 [00:00<00:31,  6.11it/s, Train Loss=1.06]

Epoch 44/50, Train Loss: 1.0623, Train Accuracy: 62.08%, Val Loss: 1.1050, Val Accuracy: 61.05%


Epoch 46/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 45/50, Train Loss: 1.0486, Train Accuracy: 62.68%, Val Loss: 1.0681, Val Accuracy: 61.80%


Epoch 47/50:   1%|▎                                                     | 1/196 [00:00<00:31,  6.15it/s, Train Loss=1.04]

Epoch 46/50, Train Loss: 1.0432, Train Accuracy: 62.73%, Val Loss: 1.1257, Val Accuracy: 59.97%


Epoch 48/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.12it/s, Train Loss=0.974]

Epoch 47/50, Train Loss: 1.0389, Train Accuracy: 63.02%, Val Loss: 1.0709, Val Accuracy: 61.71%


Epoch 49/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.02it/s, Train Loss=0.988]

Epoch 48/50, Train Loss: 1.0321, Train Accuracy: 63.07%, Val Loss: 1.0732, Val Accuracy: 61.37%


Epoch 50/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.05it/s, Train Loss=1.09]

Epoch 49/50, Train Loss: 1.0256, Train Accuracy: 63.22%, Val Loss: 1.0755, Val Accuracy: 61.80%


                                                                                                                         

Epoch 50/50, Train Loss: 1.0150, Train Accuracy: 63.70%, Val Loss: 1.1179, Val Accuracy: 60.34%


In [3]:
# Hyperparameters for this training run

epochs = 50
patch_size = 8
embed_dim = 256
n_layers = 6
heads = 8
attn_dim = 512
mlp_dim = None # default to 4*embed_dim
pool = 'cls'
dropout = 0.0

model = ViT(image_size=32, patch_size=patch_size, num_classes=10, embed_dim=embed_dim, n_layers=n_layers, heads=heads, attn_dim=attn_dim, mlp_dim=mlp_dim, pool=pool, dropout=dropout)

# Resume from the weights stored in ViT.pth (written by the save cell at the
# end of the notebook). map_location="cpu" deserialises the tensors on the CPU
# even if they were saved from a GPU, so loading also works on a machine
# without CUDA; load_state_dict then copies the values into the model's own
# parameters.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you
# trust, and consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load("ViT.pth", map_location="cpu"))

# NOTE(review): a fresh Adam is created here and only the model's state dict
# is checkpointed, so the optimizer's running statistics start over on resume.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.9, 0.999), weight_decay=5e-4)

# The timestamp is used to build the outdir path handed to train().
time_str = timestr()

print(f"Time string: {time_str}")

# Report how many parameters will receive gradients.
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(epochs=epochs, model=model, optimizer=optimizer, criterion=nn.CrossEntropyLoss(), 
      train_loader=train_loader, val_loader=val_loader, outdir=get_outdir(time_str))

Epoch 1/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Time string: 20241002_193454
The model has 6,356,234 trainable parameters


Epoch 2/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 1/50, Train Loss: 1.0157, Train Accuracy: 63.58%, Val Loss: 1.0695, Val Accuracy: 61.76%


Epoch 3/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 2/50, Train Loss: 1.0047, Train Accuracy: 64.19%, Val Loss: 1.0594, Val Accuracy: 62.25%


Epoch 4/50:   1%|▎                                                      | 1/196 [00:00<00:32,  6.02it/s, Train Loss=1.05]

Epoch 3/50, Train Loss: 0.9953, Train Accuracy: 64.43%, Val Loss: 1.0660, Val Accuracy: 62.32%


Epoch 5/50:   1%|▎                                                      | 1/196 [00:00<00:31,  6.16it/s, Train Loss=1.09]

Epoch 4/50, Train Loss: 0.9968, Train Accuracy: 64.52%, Val Loss: 1.0648, Val Accuracy: 62.69%


Epoch 6/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 5/50, Train Loss: 0.9931, Train Accuracy: 64.51%, Val Loss: 1.0309, Val Accuracy: 63.68%


Epoch 7/50:   1%|▎                                                         | 1/196 [00:00<00:32,  6.03it/s, Train Loss=1]

Epoch 6/50, Train Loss: 0.9848, Train Accuracy: 64.62%, Val Loss: 1.0466, Val Accuracy: 63.14%


Epoch 8/50:   1%|▎                                                     | 1/196 [00:00<00:31,  6.18it/s, Train Loss=0.979]

Epoch 7/50, Train Loss: 0.9822, Train Accuracy: 64.91%, Val Loss: 1.0363, Val Accuracy: 63.15%


Epoch 9/50:   1%|▎                                                      | 1/196 [00:00<00:32,  6.03it/s, Train Loss=0.93]

Epoch 8/50, Train Loss: 0.9763, Train Accuracy: 65.17%, Val Loss: 1.0339, Val Accuracy: 63.41%


Epoch 10/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 9/50, Train Loss: 0.9689, Train Accuracy: 65.43%, Val Loss: 1.0243, Val Accuracy: 63.99%


Epoch 11/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.17it/s, Train Loss=0.972]

Epoch 10/50, Train Loss: 0.9633, Train Accuracy: 65.82%, Val Loss: 1.0326, Val Accuracy: 63.65%


Epoch 12/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 11/50, Train Loss: 0.9603, Train Accuracy: 65.57%, Val Loss: 1.0116, Val Accuracy: 64.26%


Epoch 13/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.12it/s, Train Loss=0.862]

Epoch 12/50, Train Loss: 0.9583, Train Accuracy: 65.97%, Val Loss: 1.0180, Val Accuracy: 64.40%


Epoch 14/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.02it/s, Train Loss=0.856]

Epoch 13/50, Train Loss: 0.9492, Train Accuracy: 66.13%, Val Loss: 1.0227, Val Accuracy: 64.26%


Epoch 15/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 14/50, Train Loss: 0.9548, Train Accuracy: 65.90%, Val Loss: 1.0087, Val Accuracy: 64.38%


Epoch 16/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.15it/s, Train Loss=0.914]

Epoch 15/50, Train Loss: 0.9424, Train Accuracy: 66.25%, Val Loss: 1.0391, Val Accuracy: 63.10%


Epoch 17/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.02it/s, Train Loss=0.894]

Epoch 16/50, Train Loss: 0.9401, Train Accuracy: 66.42%, Val Loss: 1.0239, Val Accuracy: 63.67%


Epoch 18/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.14it/s, Train Loss=0.886]

Epoch 17/50, Train Loss: 0.9317, Train Accuracy: 66.66%, Val Loss: 1.0306, Val Accuracy: 63.44%


Epoch 19/50:   1%|▎                                                     | 1/196 [00:00<00:31,  6.15it/s, Train Loss=1.01]

Epoch 18/50, Train Loss: 0.9312, Train Accuracy: 66.50%, Val Loss: 1.0046, Val Accuracy: 64.94%


Epoch 20/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.98it/s, Train Loss=0.787]

Epoch 19/50, Train Loss: 0.9276, Train Accuracy: 66.90%, Val Loss: 1.0090, Val Accuracy: 64.47%


Epoch 21/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.18it/s, Train Loss=0.896]

Epoch 20/50, Train Loss: 0.9262, Train Accuracy: 67.02%, Val Loss: 1.0157, Val Accuracy: 64.45%


Epoch 22/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 21/50, Train Loss: 0.9156, Train Accuracy: 67.34%, Val Loss: 0.9859, Val Accuracy: 65.52%


Epoch 23/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 22/50, Train Loss: 0.9090, Train Accuracy: 67.53%, Val Loss: 0.9844, Val Accuracy: 65.16%


Epoch 24/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.03it/s, Train Loss=0.85]

Epoch 23/50, Train Loss: 0.9025, Train Accuracy: 67.92%, Val Loss: 0.9958, Val Accuracy: 64.85%


Epoch 25/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.16it/s, Train Loss=0.914]

Epoch 24/50, Train Loss: 0.9078, Train Accuracy: 67.40%, Val Loss: 0.9941, Val Accuracy: 64.63%


Epoch 26/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 25/50, Train Loss: 0.8989, Train Accuracy: 68.03%, Val Loss: 0.9794, Val Accuracy: 65.53%


Epoch 27/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.03it/s, Train Loss=0.962]

Epoch 26/50, Train Loss: 0.8946, Train Accuracy: 68.07%, Val Loss: 0.9799, Val Accuracy: 65.97%


Epoch 28/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.14it/s, Train Loss=0.802]

Epoch 27/50, Train Loss: 0.8920, Train Accuracy: 68.04%, Val Loss: 0.9983, Val Accuracy: 64.57%


Epoch 29/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 28/50, Train Loss: 0.8868, Train Accuracy: 68.29%, Val Loss: 0.9632, Val Accuracy: 65.78%


Epoch 30/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.13it/s, Train Loss=0.905]

Epoch 29/50, Train Loss: 0.8803, Train Accuracy: 68.58%, Val Loss: 0.9774, Val Accuracy: 65.85%


Epoch 31/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 30/50, Train Loss: 0.8771, Train Accuracy: 68.74%, Val Loss: 0.9585, Val Accuracy: 66.10%


Epoch 32/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 31/50, Train Loss: 0.8795, Train Accuracy: 68.77%, Val Loss: 0.9497, Val Accuracy: 66.60%


Epoch 33/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.03it/s, Train Loss=0.975]

Epoch 32/50, Train Loss: 0.8712, Train Accuracy: 69.08%, Val Loss: 0.9534, Val Accuracy: 66.18%


Epoch 34/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.96it/s, Train Loss=0.811]

Epoch 33/50, Train Loss: 0.8644, Train Accuracy: 68.97%, Val Loss: 0.9643, Val Accuracy: 65.72%


Epoch 35/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.17it/s, Train Loss=0.832]

Epoch 34/50, Train Loss: 0.8633, Train Accuracy: 69.02%, Val Loss: 0.9821, Val Accuracy: 65.86%


Epoch 36/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.04it/s, Train Loss=0.84]

Epoch 35/50, Train Loss: 0.8552, Train Accuracy: 69.46%, Val Loss: 0.9672, Val Accuracy: 65.51%


Epoch 37/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.01it/s, Train Loss=0.905]

Epoch 36/50, Train Loss: 0.8542, Train Accuracy: 69.44%, Val Loss: 0.9634, Val Accuracy: 65.86%


Epoch 38/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.06it/s, Train Loss=0.87]

Epoch 37/50, Train Loss: 0.8511, Train Accuracy: 69.54%, Val Loss: 0.9583, Val Accuracy: 66.57%


Epoch 39/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.15it/s, Train Loss=0.848]

Epoch 38/50, Train Loss: 0.8465, Train Accuracy: 69.83%, Val Loss: 0.9713, Val Accuracy: 65.89%


Epoch 40/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 39/50, Train Loss: 0.8404, Train Accuracy: 70.14%, Val Loss: 0.9340, Val Accuracy: 67.21%


Epoch 41/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.03it/s, Train Loss=0.826]

Epoch 40/50, Train Loss: 0.8367, Train Accuracy: 70.27%, Val Loss: 0.9476, Val Accuracy: 66.89%


Epoch 42/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.04it/s, Train Loss=0.912]

Epoch 41/50, Train Loss: 0.8383, Train Accuracy: 70.02%, Val Loss: 0.9433, Val Accuracy: 67.18%


Epoch 43/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.97it/s, Train Loss=0.729]

Epoch 42/50, Train Loss: 0.8259, Train Accuracy: 70.47%, Val Loss: 0.9393, Val Accuracy: 67.38%


Epoch 44/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.06it/s, Train Loss=0.845]

Epoch 43/50, Train Loss: 0.8303, Train Accuracy: 70.24%, Val Loss: 0.9504, Val Accuracy: 67.11%


Epoch 45/50:   1%|▎                                                     | 1/196 [00:00<00:31,  6.13it/s, Train Loss=0.88]

Epoch 44/50, Train Loss: 0.8247, Train Accuracy: 70.39%, Val Loss: 0.9358, Val Accuracy: 66.99%


Epoch 46/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.09it/s, Train Loss=0.686]

Epoch 45/50, Train Loss: 0.8196, Train Accuracy: 70.61%, Val Loss: 0.9456, Val Accuracy: 66.48%


Epoch 47/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.17it/s, Train Loss=0.829]

Epoch 46/50, Train Loss: 0.8174, Train Accuracy: 70.64%, Val Loss: 0.9481, Val Accuracy: 67.52%


Epoch 48/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 47/50, Train Loss: 0.8117, Train Accuracy: 71.20%, Val Loss: 0.9124, Val Accuracy: 67.68%


Epoch 49/50:   1%|▎                                                    | 1/196 [00:00<00:31,  6.15it/s, Train Loss=0.798]

Epoch 48/50, Train Loss: 0.8114, Train Accuracy: 70.97%, Val Loss: 0.9263, Val Accuracy: 67.69%


Epoch 50/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.03it/s, Train Loss=0.906]

Epoch 49/50, Train Loss: 0.8088, Train Accuracy: 71.11%, Val Loss: 0.9241, Val Accuracy: 67.91%


                                                                                                                         

Epoch 50/50, Train Loss: 0.7997, Train Accuracy: 71.20%, Val Loss: 0.9492, Val Accuracy: 67.31%


In [None]:
# Hyperparameters for this training run

epochs = 50
patch_size = 8
embed_dim = 256
n_layers = 6
heads = 8
attn_dim = 512
mlp_dim = None # default to 4*embed_dim
pool = 'cls'
dropout = 0.0

model = ViT(image_size=32, patch_size=patch_size, num_classes=10, embed_dim=embed_dim, n_layers=n_layers, heads=heads, attn_dim=attn_dim, mlp_dim=mlp_dim, pool=pool, dropout=dropout)

# Resume from the weights stored in ViT.pth (written by the save cell at the
# end of the notebook). map_location="cpu" deserialises the tensors on the CPU
# even if they were saved from a GPU, so loading also works on a machine
# without CUDA; load_state_dict then copies the values into the model's own
# parameters.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you
# trust, and consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load("ViT.pth", map_location="cpu"))

# NOTE(review): a fresh Adam is created here and only the model's state dict
# is checkpointed, so the optimizer's running statistics start over on resume.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.9, 0.999), weight_decay=5e-4)

# The timestamp is used to build the outdir path handed to train().
time_str = timestr()

print(f"Time string: {time_str}")

# Report how many parameters will receive gradients.
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(epochs=epochs, model=model, optimizer=optimizer, criterion=nn.CrossEntropyLoss(), 
      train_loader=train_loader, val_loader=val_loader, outdir=get_outdir(time_str))

In [4]:
# Write the model's state dict to ViT.pth; the resume cells load this file.
torch.save(obj=model.state_dict(), f='ViT.pth')