# Vit for classification in CIFAR-10

In [1]:
%load_ext autoreload
%autoreload 2
!nvidia-smi
!which python

Fri Oct  4 17:54:52 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   43C    P0              50W / 184W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2-32GB           On  | 00000004:05:00.0 Off |  

In [2]:
try:
    from ViT.train import train
    from ViT.utils import cifar_train_set, cifar_val_set
    from ViT.model import *
except:
    from train import train
    from utils import cifar_train_set, cifar_val_set
    from model import *

from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F

import datetime

train_loader = DataLoader(cifar_train_set, 256, shuffle=True, drop_last=False, pin_memory=True)
val_loader = DataLoader(cifar_val_set, 500, shuffle=True, drop_last=False, pin_memory=True)

import os
if not os.path.exists("ViT/log"):
    os.makedirs("ViT/log")

def timestr():
    now = datetime.datetime.now()
    return now.strftime("%Y%m%d_%H%M%S")

def get_outdir(time_str):
    outdir = f"ViT/log/{time_str}.out"
    return outdir

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Here is the hyperparameters
# 0-50
epochs = 50
patch_size = 8
embed_dim = 256
n_layers = 6
heads = 8
attn_dim = 512
mlp_dim = None # default to 4*embed_dim
pool = 'cls'
dropout = 0.1
mlp_dropout = 0.0

model = ViT(image_size=32, patch_size=patch_size, num_classes=10, embed_dim=embed_dim, n_layers=n_layers, heads=heads, attn_dim=attn_dim, mlp_dim=mlp_dim, pool=pool, dropout=dropout, mlp_dropout=mlp_dropout)

# last_time_str = "20241003_124314"

# # load model
# model.load_state_dict(torch.load(f"ViT/models/{last_time_str}.pth"))

# print(f"models loaded from ViT/models/{last_time_str}.pth")

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999), weight_decay=5e-4)

time_str = timestr()

print(f"Time string: {time_str}")

# print the model and the number of parameters
# print(model.transformer)
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(epochs=epochs, model=model, optimizer=optimizer, criterion=nn.CrossEntropyLoss(), 
      train_loader=train_loader, val_loader=val_loader, outdir=get_outdir(time_str))

torch.save(model.state_dict(), f"ViT/models/{time_str}.pth")

print(f"models saved to ViT/models/{time_str}.pth")

Time string: 20241004_175456
The model has 6,356,234 trainable parameters


Epoch 2/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 1/50, Train Loss: 1.9537, Train Accuracy: 26.69%, Val Loss: 1.7563, Val Accuracy: 35.77%


Epoch 3/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 2/50, Train Loss: 1.6786, Train Accuracy: 38.62%, Val Loss: 1.5335, Val Accuracy: 44.66%


Epoch 4/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 3/50, Train Loss: 1.5890, Train Accuracy: 41.87%, Val Loss: 1.4881, Val Accuracy: 46.45%


Epoch 5/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 4/50, Train Loss: 1.5377, Train Accuracy: 44.22%, Val Loss: 1.4159, Val Accuracy: 49.27%


Epoch 6/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 5/50, Train Loss: 1.4966, Train Accuracy: 45.65%, Val Loss: 1.4128, Val Accuracy: 48.93%


Epoch 7/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 6/50, Train Loss: 1.4756, Train Accuracy: 46.37%, Val Loss: 1.3956, Val Accuracy: 49.67%


Epoch 8/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 7/50, Train Loss: 1.4455, Train Accuracy: 47.34%, Val Loss: 1.3911, Val Accuracy: 49.27%


Epoch 9/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 8/50, Train Loss: 1.4232, Train Accuracy: 48.31%, Val Loss: 1.3322, Val Accuracy: 51.40%


Epoch 10/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.49]

Epoch 9/50, Train Loss: 1.4011, Train Accuracy: 49.18%, Val Loss: 1.3464, Val Accuracy: 51.60%


Epoch 11/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.20it/s, Train Loss=1.37]

Epoch 10/50, Train Loss: 1.3769, Train Accuracy: 49.99%, Val Loss: 1.3400, Val Accuracy: 52.11%


Epoch 12/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 11/50, Train Loss: 1.3656, Train Accuracy: 50.54%, Val Loss: 1.3036, Val Accuracy: 53.49%


Epoch 13/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 12/50, Train Loss: 1.3485, Train Accuracy: 51.70%, Val Loss: 1.2453, Val Accuracy: 55.70%


Epoch 14/50:   1%|▎                                                   | 1/196 [00:00<00:37,  5.13it/s, Train Loss=1.3]

Epoch 13/50, Train Loss: 1.3325, Train Accuracy: 51.97%, Val Loss: 1.2648, Val Accuracy: 54.08%


Epoch 15/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 14/50, Train Loss: 1.3193, Train Accuracy: 52.14%, Val Loss: 1.2447, Val Accuracy: 55.04%


Epoch 16/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 15/50, Train Loss: 1.3106, Train Accuracy: 52.84%, Val Loss: 1.2332, Val Accuracy: 55.40%


Epoch 17/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.27]

Epoch 16/50, Train Loss: 1.3044, Train Accuracy: 53.12%, Val Loss: 1.2466, Val Accuracy: 55.40%


Epoch 18/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.34]

Epoch 17/50, Train Loss: 1.2813, Train Accuracy: 54.10%, Val Loss: 1.2524, Val Accuracy: 54.87%


Epoch 19/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 18/50, Train Loss: 1.2777, Train Accuracy: 54.02%, Val Loss: 1.1892, Val Accuracy: 57.84%


Epoch 20/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 19/50, Train Loss: 1.2702, Train Accuracy: 54.21%, Val Loss: 1.1718, Val Accuracy: 58.32%


Epoch 21/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.18it/s, Train Loss=1.31]

Epoch 20/50, Train Loss: 1.2556, Train Accuracy: 54.85%, Val Loss: 1.1994, Val Accuracy: 57.04%


Epoch 22/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.34]

Epoch 21/50, Train Loss: 1.2464, Train Accuracy: 55.07%, Val Loss: 1.1741, Val Accuracy: 58.27%


Epoch 23/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.22it/s, Train Loss=1.17]

Epoch 22/50, Train Loss: 1.2370, Train Accuracy: 55.59%, Val Loss: 1.2198, Val Accuracy: 56.20%


Epoch 24/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 23/50, Train Loss: 1.2274, Train Accuracy: 55.86%, Val Loss: 1.1589, Val Accuracy: 58.71%


Epoch 25/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.22it/s, Train Loss=1.17]

Epoch 24/50, Train Loss: 1.2193, Train Accuracy: 56.08%, Val Loss: 1.1673, Val Accuracy: 58.50%


Epoch 26/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.22it/s, Train Loss=1.15]

Epoch 25/50, Train Loss: 1.2087, Train Accuracy: 56.58%, Val Loss: 1.1604, Val Accuracy: 58.62%


Epoch 27/50:   1%|▎                                                  | 1/196 [00:00<00:38,  5.11it/s, Train Loss=1.21]

Epoch 26/50, Train Loss: 1.2056, Train Accuracy: 56.69%, Val Loss: 1.1728, Val Accuracy: 57.89%


Epoch 28/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 27/50, Train Loss: 1.1967, Train Accuracy: 57.07%, Val Loss: 1.1360, Val Accuracy: 59.87%


Epoch 29/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 28/50, Train Loss: 1.1876, Train Accuracy: 57.21%, Val Loss: 1.1032, Val Accuracy: 60.38%


Epoch 30/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.22it/s, Train Loss=1.16]

Epoch 29/50, Train Loss: 1.1762, Train Accuracy: 57.82%, Val Loss: 1.1141, Val Accuracy: 60.09%


Epoch 31/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 30/50, Train Loss: 1.1740, Train Accuracy: 57.93%, Val Loss: 1.1192, Val Accuracy: 59.92%


Epoch 32/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 31/50, Train Loss: 1.1724, Train Accuracy: 57.98%, Val Loss: 1.0993, Val Accuracy: 60.99%


Epoch 33/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 32/50, Train Loss: 1.1626, Train Accuracy: 58.28%, Val Loss: 1.0830, Val Accuracy: 61.24%


Epoch 34/50:   1%|▎                                                   | 1/196 [00:00<00:37,  5.22it/s, Train Loss=1.1]

Epoch 33/50, Train Loss: 1.1537, Train Accuracy: 58.65%, Val Loss: 1.1581, Val Accuracy: 58.29%


Epoch 35/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 34/50, Train Loss: 1.1449, Train Accuracy: 59.18%, Val Loss: 1.0733, Val Accuracy: 61.31%


Epoch 36/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.15it/s, Train Loss=1.35]

Epoch 35/50, Train Loss: 1.1435, Train Accuracy: 59.06%, Val Loss: 1.0957, Val Accuracy: 60.98%


Epoch 37/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.03]

Epoch 36/50, Train Loss: 1.1392, Train Accuracy: 59.22%, Val Loss: 1.0806, Val Accuracy: 61.46%


Epoch 38/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 37/50, Train Loss: 1.1345, Train Accuracy: 59.49%, Val Loss: 1.0682, Val Accuracy: 62.07%


Epoch 39/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.11]

Epoch 38/50, Train Loss: 1.1330, Train Accuracy: 59.36%, Val Loss: 1.0702, Val Accuracy: 61.53%


Epoch 40/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.13]

Epoch 39/50, Train Loss: 1.1206, Train Accuracy: 59.78%, Val Loss: 1.0997, Val Accuracy: 60.97%


Epoch 41/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 40/50, Train Loss: 1.1220, Train Accuracy: 59.79%, Val Loss: 1.0223, Val Accuracy: 63.67%


Epoch 42/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.19it/s, Train Loss=1.03]

Epoch 41/50, Train Loss: 1.1184, Train Accuracy: 59.96%, Val Loss: 1.0527, Val Accuracy: 62.43%


Epoch 43/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.22it/s, Train Loss=1.11]

Epoch 42/50, Train Loss: 1.1080, Train Accuracy: 60.19%, Val Loss: 1.0538, Val Accuracy: 62.30%


Epoch 44/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.19]

Epoch 43/50, Train Loss: 1.1023, Train Accuracy: 60.72%, Val Loss: 1.0628, Val Accuracy: 62.14%


Epoch 45/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.22it/s, Train Loss=1.16]

Epoch 44/50, Train Loss: 1.1016, Train Accuracy: 60.53%, Val Loss: 1.0287, Val Accuracy: 63.40%


Epoch 46/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.20it/s, Train Loss=1.06]

Epoch 45/50, Train Loss: 1.0960, Train Accuracy: 60.78%, Val Loss: 1.0615, Val Accuracy: 62.75%


Epoch 47/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.20it/s, Train Loss=1.01]

Epoch 46/50, Train Loss: 1.0917, Train Accuracy: 61.19%, Val Loss: 1.0511, Val Accuracy: 62.57%


Epoch 48/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.22it/s, Train Loss=1.11]

Epoch 47/50, Train Loss: 1.0886, Train Accuracy: 60.89%, Val Loss: 1.0537, Val Accuracy: 61.89%


Epoch 49/50:   1%|▎                                                   | 1/196 [00:00<00:37,  5.22it/s, Train Loss=1.1]

Epoch 48/50, Train Loss: 1.0794, Train Accuracy: 61.30%, Val Loss: 1.0399, Val Accuracy: 62.79%


Epoch 50/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 49/50, Train Loss: 1.0813, Train Accuracy: 61.16%, Val Loss: 0.9974, Val Accuracy: 65.11%


                                                                                                                      

Epoch 50/50, Train Loss: 1.0788, Train Accuracy: 61.42%, Val Loss: 1.0130, Val Accuracy: 64.12%


FileNotFoundError: [Errno 2] No such file or directory: 'ViT/models/20241004_175456.pth'

In [5]:
# Here is the hyperparameters
# 50-100
epochs = 50
patch_size = 8
embed_dim = 256
n_layers = 6
heads = 8
attn_dim = 512
mlp_dim = None # default to 4*embed_dim
pool = 'cls'
dropout = 0.1
mlp_dropout = 0.0

model = ViT(image_size=32, patch_size=patch_size, num_classes=10, embed_dim=embed_dim, n_layers=n_layers, heads=heads, attn_dim=attn_dim, mlp_dim=mlp_dim, pool=pool, dropout=dropout, mlp_dropout=mlp_dropout)

last_time_str = "20241004_175456"

# load model
model.load_state_dict(torch.load(f"ViT/models/{last_time_str}.pth"))

print(f"models loaded from ViT/models/{last_time_str}.pth")

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999), weight_decay=5e-4)

time_str = timestr()

print(f"Time string: {time_str}")

# print the model and the number of parameters
# print(model.transformer)
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(epochs=epochs, model=model, optimizer=optimizer, criterion=nn.CrossEntropyLoss(), 
      train_loader=train_loader, val_loader=val_loader, outdir=get_outdir(time_str))

torch.save(model.state_dict(), f"ViT/models/{time_str}.pth")

print(f"models saved to ViT/models/{time_str}.pth")

Epoch 1/50:   0%|                                                            | 0/196 [00:00<?, ?it/s, Train Loss=1.02]

models loaded from ViT/models/20241004_175456.pth
Time string: 20241004_183234
The model has 6,356,234 trainable parameters


Epoch 2/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 1/50, Train Loss: 1.0816, Train Accuracy: 61.38%, Val Loss: 1.0409, Val Accuracy: 62.71%


Epoch 3/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 2/50, Train Loss: 1.0665, Train Accuracy: 62.11%, Val Loss: 0.9814, Val Accuracy: 65.04%


Epoch 4/50:   1%|▎                                                   | 1/196 [00:00<00:37,  5.20it/s, Train Loss=1.16]

Epoch 3/50, Train Loss: 1.0720, Train Accuracy: 61.68%, Val Loss: 1.0032, Val Accuracy: 65.30%


Epoch 5/50:   1%|▎                                                   | 1/196 [00:00<00:37,  5.20it/s, Train Loss=1.07]

Epoch 4/50, Train Loss: 1.0625, Train Accuracy: 62.25%, Val Loss: 0.9859, Val Accuracy: 65.55%


Epoch 6/50:   1%|▎                                                   | 1/196 [00:00<00:37,  5.20it/s, Train Loss=1.08]

Epoch 5/50, Train Loss: 1.0628, Train Accuracy: 61.95%, Val Loss: 1.0168, Val Accuracy: 63.93%


Epoch 7/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.22it/s, Train Loss=0.963]

Epoch 6/50, Train Loss: 1.0570, Train Accuracy: 62.20%, Val Loss: 1.0164, Val Accuracy: 64.39%


Epoch 8/50:   1%|▎                                                  | 1/196 [00:00<00:38,  5.12it/s, Train Loss=0.994]

Epoch 7/50, Train Loss: 1.0540, Train Accuracy: 62.48%, Val Loss: 0.9956, Val Accuracy: 64.46%


Epoch 9/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.995]

Epoch 8/50, Train Loss: 1.0537, Train Accuracy: 62.15%, Val Loss: 0.9956, Val Accuracy: 64.78%


Epoch 10/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 9/50, Train Loss: 1.0476, Train Accuracy: 62.52%, Val Loss: 0.9534, Val Accuracy: 66.20%


Epoch 11/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.06]

Epoch 10/50, Train Loss: 1.0474, Train Accuracy: 62.68%, Val Loss: 1.0183, Val Accuracy: 64.30%


Epoch 12/50:   0%|                                                           | 0/196 [00:00<?, ?it/s, Train Loss=1.02]

Epoch 11/50, Train Loss: 1.0434, Train Accuracy: 62.84%, Val Loss: 1.0042, Val Accuracy: 64.35%


Epoch 13/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.996]

Epoch 12/50, Train Loss: 1.0399, Train Accuracy: 62.73%, Val Loss: 0.9925, Val Accuracy: 64.28%


Epoch 14/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.12]

Epoch 13/50, Train Loss: 1.0327, Train Accuracy: 63.20%, Val Loss: 1.0075, Val Accuracy: 64.15%


Epoch 15/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.22it/s, Train Loss=0.961]

Epoch 14/50, Train Loss: 1.0358, Train Accuracy: 63.29%, Val Loss: 0.9812, Val Accuracy: 65.15%


Epoch 16/50:   1%|▎                                                   | 1/196 [00:00<00:37,  5.20it/s, Train Loss=1.1]

Epoch 15/50, Train Loss: 1.0385, Train Accuracy: 62.96%, Val Loss: 0.9953, Val Accuracy: 64.26%


Epoch 17/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.18it/s, Train Loss=1.15]

Epoch 16/50, Train Loss: 1.0306, Train Accuracy: 63.40%, Val Loss: 1.0602, Val Accuracy: 61.45%


Epoch 18/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.07]

Epoch 17/50, Train Loss: 1.0238, Train Accuracy: 63.45%, Val Loss: 0.9701, Val Accuracy: 65.53%


Epoch 19/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.96]

Epoch 18/50, Train Loss: 1.0260, Train Accuracy: 63.55%, Val Loss: 1.0211, Val Accuracy: 63.83%


Epoch 20/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.20it/s, Train Loss=0.959]

Epoch 19/50, Train Loss: 1.0191, Train Accuracy: 63.77%, Val Loss: 0.9984, Val Accuracy: 65.17%


Epoch 21/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.05]

Epoch 20/50, Train Loss: 1.0226, Train Accuracy: 63.42%, Val Loss: 0.9994, Val Accuracy: 64.09%


Epoch 22/50:   1%|▎                                                  | 1/196 [00:00<00:38,  5.09it/s, Train Loss=1.08]

Epoch 21/50, Train Loss: 1.0215, Train Accuracy: 63.74%, Val Loss: 0.9742, Val Accuracy: 65.14%


Epoch 23/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.22it/s, Train Loss=0.881]

Epoch 22/50, Train Loss: 1.0192, Train Accuracy: 63.84%, Val Loss: 1.0336, Val Accuracy: 63.07%


Epoch 24/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 23/50, Train Loss: 1.0160, Train Accuracy: 63.80%, Val Loss: 0.9396, Val Accuracy: 66.68%


Epoch 25/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.95]

Epoch 24/50, Train Loss: 1.0052, Train Accuracy: 63.81%, Val Loss: 0.9514, Val Accuracy: 66.35%


Epoch 26/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.03]

Epoch 25/50, Train Loss: 1.0091, Train Accuracy: 64.15%, Val Loss: 1.0064, Val Accuracy: 64.39%


Epoch 27/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.20it/s, Train Loss=1.07]

Epoch 26/50, Train Loss: 1.0121, Train Accuracy: 63.84%, Val Loss: 1.0558, Val Accuracy: 62.89%


Epoch 28/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.22it/s, Train Loss=1.05]

Epoch 27/50, Train Loss: 1.0065, Train Accuracy: 64.26%, Val Loss: 0.9779, Val Accuracy: 65.26%


Epoch 29/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 28/50, Train Loss: 1.0033, Train Accuracy: 64.11%, Val Loss: 0.9177, Val Accuracy: 67.70%


Epoch 30/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.987]

Epoch 29/50, Train Loss: 1.0064, Train Accuracy: 64.11%, Val Loss: 0.9726, Val Accuracy: 66.08%


Epoch 31/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.20it/s, Train Loss=1.03]

Epoch 30/50, Train Loss: 1.0048, Train Accuracy: 64.13%, Val Loss: 0.9858, Val Accuracy: 64.70%


Epoch 32/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.15it/s, Train Loss=0.996]

Epoch 31/50, Train Loss: 0.9978, Train Accuracy: 64.45%, Val Loss: 0.9951, Val Accuracy: 64.41%


Epoch 33/50:   1%|▎                                                 | 1/196 [00:00<00:38,  5.07it/s, Train Loss=0.961]

Epoch 32/50, Train Loss: 0.9973, Train Accuracy: 64.50%, Val Loss: 0.9752, Val Accuracy: 65.39%


Epoch 34/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.22it/s, Train Loss=0.946]

Epoch 33/50, Train Loss: 0.9967, Train Accuracy: 64.52%, Val Loss: 0.9663, Val Accuracy: 65.14%


Epoch 35/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.01]

Epoch 34/50, Train Loss: 0.9968, Train Accuracy: 64.31%, Val Loss: 0.9781, Val Accuracy: 65.39%


Epoch 36/50:   1%|▎                                                  | 1/196 [00:00<00:38,  5.06it/s, Train Loss=1.07]

Epoch 35/50, Train Loss: 0.9884, Train Accuracy: 64.72%, Val Loss: 0.9516, Val Accuracy: 66.28%


Epoch 37/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.03]

Epoch 36/50, Train Loss: 0.9888, Train Accuracy: 64.74%, Val Loss: 0.9597, Val Accuracy: 66.38%


Epoch 38/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.22it/s, Train Loss=0.951]

Epoch 37/50, Train Loss: 0.9894, Train Accuracy: 64.73%, Val Loss: 0.9500, Val Accuracy: 66.49%


Epoch 39/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.22it/s, Train Loss=0.956]

Epoch 38/50, Train Loss: 0.9868, Train Accuracy: 64.84%, Val Loss: 0.9900, Val Accuracy: 64.91%


Epoch 40/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.921]

Epoch 39/50, Train Loss: 0.9842, Train Accuracy: 64.88%, Val Loss: 0.9262, Val Accuracy: 67.26%


Epoch 41/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.22it/s, Train Loss=0.905]

Epoch 40/50, Train Loss: 0.9810, Train Accuracy: 65.13%, Val Loss: 0.9220, Val Accuracy: 67.54%


Epoch 42/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.939]

Epoch 41/50, Train Loss: 0.9865, Train Accuracy: 64.99%, Val Loss: 0.9329, Val Accuracy: 66.66%


Epoch 43/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.863]

Epoch 42/50, Train Loss: 0.9835, Train Accuracy: 64.69%, Val Loss: 0.9660, Val Accuracy: 65.58%


Epoch 44/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.976]

Epoch 43/50, Train Loss: 0.9846, Train Accuracy: 64.96%, Val Loss: 0.9490, Val Accuracy: 66.05%


Epoch 45/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=1.02]

Epoch 44/50, Train Loss: 0.9724, Train Accuracy: 65.47%, Val Loss: 0.9995, Val Accuracy: 64.55%


Epoch 46/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.22it/s, Train Loss=0.938]

Epoch 45/50, Train Loss: 0.9784, Train Accuracy: 65.21%, Val Loss: 0.9336, Val Accuracy: 67.07%


Epoch 47/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 46/50, Train Loss: 0.9750, Train Accuracy: 65.19%, Val Loss: 0.9136, Val Accuracy: 67.54%


Epoch 48/50:   1%|▎                                                 | 1/196 [00:00<00:38,  5.11it/s, Train Loss=0.884]

Epoch 47/50, Train Loss: 0.9741, Train Accuracy: 65.44%, Val Loss: 0.9142, Val Accuracy: 67.47%


Epoch 49/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.968]

Epoch 48/50, Train Loss: 0.9777, Train Accuracy: 65.10%, Val Loss: 0.9450, Val Accuracy: 66.44%


Epoch 50/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 49/50, Train Loss: 0.9698, Train Accuracy: 65.35%, Val Loss: 0.9035, Val Accuracy: 68.33%


                                                                                                                      

Epoch 50/50, Train Loss: 0.9678, Train Accuracy: 65.54%, Val Loss: 0.9433, Val Accuracy: 66.02%
models saved to ViT/models/20241004_183234.pth


In [6]:
# Here is the hyperparameters
# 100-150
epochs = 50
patch_size = 8
embed_dim = 256
n_layers = 6
heads = 8
attn_dim = 512
mlp_dim = None # default to 4*embed_dim
pool = 'cls'
dropout = 0.1
mlp_dropout = 0.0

model = ViT(image_size=32, patch_size=patch_size, num_classes=10, embed_dim=embed_dim, n_layers=n_layers, heads=heads, attn_dim=attn_dim, mlp_dim=mlp_dim, pool=pool, dropout=dropout, mlp_dropout=mlp_dropout)

last_time_str = "20241004_183234"

# load model
model.load_state_dict(torch.load(f"ViT/models/{last_time_str}.pth"))

print(f"models loaded from ViT/models/{last_time_str}.pth")

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=5e-4)

time_str = timestr()

print(f"Time string: {time_str}")

# print the model and the number of parameters
# print(model.transformer)
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(epochs=epochs, model=model, optimizer=optimizer, criterion=nn.CrossEntropyLoss(), 
      train_loader=train_loader, val_loader=val_loader, outdir=get_outdir(time_str))

torch.save(model.state_dict(), f"ViT/models/{time_str}.pth")

print(f"models saved to ViT/models/{time_str}.pth")

Epoch 1/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

models loaded from ViT/models/20241004_183234.pth
Time string: 20241004_211435
The model has 6,356,234 trainable parameters


Epoch 2/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 1/50, Train Loss: 0.8939, Train Accuracy: 68.21%, Val Loss: 0.8825, Val Accuracy: 68.77%


Epoch 3/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 2/50, Train Loss: 0.8686, Train Accuracy: 69.03%, Val Loss: 0.8620, Val Accuracy: 69.41%


Epoch 4/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 3/50, Train Loss: 0.8610, Train Accuracy: 69.32%, Val Loss: 0.8562, Val Accuracy: 69.68%


Epoch 5/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 4/50, Train Loss: 0.8500, Train Accuracy: 69.76%, Val Loss: 0.8555, Val Accuracy: 69.74%


Epoch 6/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.20it/s, Train Loss=0.906]

Epoch 5/50, Train Loss: 0.8505, Train Accuracy: 69.80%, Val Loss: 0.8584, Val Accuracy: 69.96%


Epoch 7/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 6/50, Train Loss: 0.8489, Train Accuracy: 69.74%, Val Loss: 0.8534, Val Accuracy: 70.08%


Epoch 8/50:   0%|                                                                             | 0/196 [00:00<?, ?it/s]

Epoch 7/50, Train Loss: 0.8414, Train Accuracy: 70.01%, Val Loss: 0.8494, Val Accuracy: 70.07%


Epoch 9/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.20it/s, Train Loss=0.755]

Epoch 8/50, Train Loss: 0.8445, Train Accuracy: 69.80%, Val Loss: 0.8595, Val Accuracy: 69.78%


Epoch 10/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 9/50, Train Loss: 0.8362, Train Accuracy: 70.38%, Val Loss: 0.8485, Val Accuracy: 70.53%


Epoch 11/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 10/50, Train Loss: 0.8287, Train Accuracy: 70.65%, Val Loss: 0.8332, Val Accuracy: 70.90%


Epoch 12/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.20it/s, Train Loss=0.864]

Epoch 11/50, Train Loss: 0.8281, Train Accuracy: 70.64%, Val Loss: 0.8587, Val Accuracy: 69.88%


Epoch 13/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.20it/s, Train Loss=0.786]

Epoch 12/50, Train Loss: 0.8239, Train Accuracy: 70.58%, Val Loss: 0.8579, Val Accuracy: 70.06%


Epoch 14/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 13/50, Train Loss: 0.8289, Train Accuracy: 70.39%, Val Loss: 0.8249, Val Accuracy: 71.62%


Epoch 15/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.684]

Epoch 14/50, Train Loss: 0.8246, Train Accuracy: 70.83%, Val Loss: 0.8659, Val Accuracy: 69.99%


Epoch 16/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.18it/s, Train Loss=0.807]

Epoch 15/50, Train Loss: 0.8229, Train Accuracy: 70.84%, Val Loss: 0.8404, Val Accuracy: 70.80%


Epoch 17/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.20it/s, Train Loss=0.897]

Epoch 16/50, Train Loss: 0.8202, Train Accuracy: 70.81%, Val Loss: 0.8405, Val Accuracy: 70.79%


Epoch 18/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 17/50, Train Loss: 0.8226, Train Accuracy: 70.69%, Val Loss: 0.8137, Val Accuracy: 71.64%


Epoch 19/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.866]

Epoch 18/50, Train Loss: 0.8158, Train Accuracy: 70.95%, Val Loss: 0.8452, Val Accuracy: 71.08%


Epoch 20/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.77]

Epoch 19/50, Train Loss: 0.8164, Train Accuracy: 70.88%, Val Loss: 0.8545, Val Accuracy: 70.34%


Epoch 21/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.20it/s, Train Loss=0.832]

Epoch 20/50, Train Loss: 0.8133, Train Accuracy: 70.91%, Val Loss: 0.8607, Val Accuracy: 69.74%


Epoch 22/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.16it/s, Train Loss=0.986]

Epoch 21/50, Train Loss: 0.8140, Train Accuracy: 70.94%, Val Loss: 0.8577, Val Accuracy: 70.45%


Epoch 23/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 22/50, Train Loss: 0.8099, Train Accuracy: 71.05%, Val Loss: 0.8682, Val Accuracy: 69.78%


Epoch 24/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.838]

Epoch 23/50, Train Loss: 0.8116, Train Accuracy: 71.05%, Val Loss: 0.8405, Val Accuracy: 70.73%


Epoch 25/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.20it/s, Train Loss=0.773]

Epoch 24/50, Train Loss: 0.8049, Train Accuracy: 71.44%, Val Loss: 0.8409, Val Accuracy: 70.79%


Epoch 26/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.723]

Epoch 25/50, Train Loss: 0.8041, Train Accuracy: 71.46%, Val Loss: 0.8489, Val Accuracy: 70.77%


Epoch 27/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.19it/s, Train Loss=0.826]

Epoch 26/50, Train Loss: 0.8016, Train Accuracy: 71.42%, Val Loss: 0.8201, Val Accuracy: 71.49%


Epoch 28/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.91]

Epoch 27/50, Train Loss: 0.7968, Train Accuracy: 71.55%, Val Loss: 0.8574, Val Accuracy: 70.16%


Epoch 29/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.771]

Epoch 28/50, Train Loss: 0.8033, Train Accuracy: 71.40%, Val Loss: 0.8270, Val Accuracy: 71.11%


Epoch 30/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.727]

Epoch 29/50, Train Loss: 0.8018, Train Accuracy: 71.44%, Val Loss: 0.8414, Val Accuracy: 70.67%


Epoch 31/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.854]

Epoch 30/50, Train Loss: 0.8009, Train Accuracy: 71.32%, Val Loss: 0.8352, Val Accuracy: 71.23%


Epoch 32/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.748]

Epoch 31/50, Train Loss: 0.7987, Train Accuracy: 71.41%, Val Loss: 0.8565, Val Accuracy: 70.44%


Epoch 33/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.20it/s, Train Loss=0.858]

Epoch 32/50, Train Loss: 0.7992, Train Accuracy: 71.45%, Val Loss: 0.8496, Val Accuracy: 70.36%


Epoch 34/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.18it/s, Train Loss=0.713]

Epoch 33/50, Train Loss: 0.7910, Train Accuracy: 71.64%, Val Loss: 0.8476, Val Accuracy: 70.95%


Epoch 35/50:   1%|▎                                                  | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.79]

Epoch 34/50, Train Loss: 0.7861, Train Accuracy: 71.91%, Val Loss: 0.8332, Val Accuracy: 71.15%


Epoch 36/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.864]

Epoch 35/50, Train Loss: 0.7875, Train Accuracy: 72.01%, Val Loss: 0.8483, Val Accuracy: 70.64%


Epoch 37/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.22it/s, Train Loss=0.825]

Epoch 36/50, Train Loss: 0.7883, Train Accuracy: 71.83%, Val Loss: 0.8347, Val Accuracy: 71.39%


Epoch 38/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.20it/s, Train Loss=0.759]

Epoch 37/50, Train Loss: 0.7930, Train Accuracy: 71.67%, Val Loss: 0.8582, Val Accuracy: 70.06%


Epoch 39/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.772]

Epoch 38/50, Train Loss: 0.7862, Train Accuracy: 71.95%, Val Loss: 0.8250, Val Accuracy: 71.22%


Epoch 40/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.22it/s, Train Loss=0.738]

Epoch 39/50, Train Loss: 0.7834, Train Accuracy: 72.02%, Val Loss: 0.8446, Val Accuracy: 70.62%


Epoch 41/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 40/50, Train Loss: 0.7839, Train Accuracy: 71.96%, Val Loss: 0.8115, Val Accuracy: 71.58%


Epoch 42/50:   1%|▎                                                 | 1/196 [00:00<00:38,  5.12it/s, Train Loss=0.764]

Epoch 41/50, Train Loss: 0.7839, Train Accuracy: 72.08%, Val Loss: 0.8150, Val Accuracy: 71.72%


Epoch 43/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 42/50, Train Loss: 0.7817, Train Accuracy: 72.17%, Val Loss: 0.8082, Val Accuracy: 72.10%


Epoch 44/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.20it/s, Train Loss=0.825]

Epoch 43/50, Train Loss: 0.7792, Train Accuracy: 72.17%, Val Loss: 0.8420, Val Accuracy: 70.79%


Epoch 45/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.783]

Epoch 44/50, Train Loss: 0.7801, Train Accuracy: 72.20%, Val Loss: 0.8284, Val Accuracy: 71.14%


Epoch 46/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.822]

Epoch 45/50, Train Loss: 0.7716, Train Accuracy: 72.56%, Val Loss: 0.8436, Val Accuracy: 70.90%


Epoch 47/50:   0%|                                                                            | 0/196 [00:00<?, ?it/s]

Epoch 46/50, Train Loss: 0.7774, Train Accuracy: 72.29%, Val Loss: 0.8010, Val Accuracy: 72.19%


Epoch 48/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.13it/s, Train Loss=0.724]

Epoch 47/50, Train Loss: 0.7774, Train Accuracy: 72.34%, Val Loss: 0.8159, Val Accuracy: 71.32%


Epoch 49/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.797]

Epoch 48/50, Train Loss: 0.7731, Train Accuracy: 72.61%, Val Loss: 0.8153, Val Accuracy: 71.60%


Epoch 50/50:   1%|▎                                                 | 1/196 [00:00<00:37,  5.21it/s, Train Loss=0.784]

Epoch 49/50, Train Loss: 0.7717, Train Accuracy: 72.41%, Val Loss: 0.8315, Val Accuracy: 71.30%


                                                                                                                      

Epoch 50/50, Train Loss: 0.7715, Train Accuracy: 72.62%, Val Loss: 0.8201, Val Accuracy: 71.46%
models saved to ViT/models/20241004_211435.pth


In [None]:
# Here is the hyperparameters
# 150-200
epochs = 50
patch_size = 8
embed_dim = 256
n_layers = 6
heads = 8
attn_dim = 512
mlp_dim = None # default to 4*embed_dim
pool = 'cls'
dropout = 0.1
mlp_dropout = 0.0

model = ViT(image_size=32, patch_size=patch_size, num_classes=10, embed_dim=embed_dim, n_layers=n_layers, heads=heads, attn_dim=attn_dim, mlp_dim=mlp_dim, pool=pool, dropout=dropout, mlp_dropout=mlp_dropout)

last_time_str = "20241004_211435"

# load model
model.load_state_dict(torch.load(f"ViT/models/{last_time_str}.pth"))

print(f"models loaded from ViT/models/{last_time_str}.pth")

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=5e-4)

time_str = timestr()

print(f"Time string: {time_str}")

# print the model and the number of parameters
# print(model.transformer)
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(epochs=epochs, model=model, optimizer=optimizer, criterion=nn.CrossEntropyLoss(), 
      train_loader=train_loader, val_loader=val_loader, outdir=get_outdir(time_str))

torch.save(model.state_dict(), f"ViT/models/{time_str}.pth")

print(f"models saved to ViT/models/{time_str}.pth")