# ViT for classification on CIFAR-10

In [2]:
# Enable autoreload so edits to imported modules are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2
# Record the GPU/driver state and the active Python interpreter next to the results.
!nvidia-smi
!which python

Wed Oct  2 14:13:24 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   41C    P0              40W / 300W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
# Import the project code. The package-qualified form is tried first; the flat
# form is the fallback (presumably for running the notebook from inside the
# ViT/ directory itself -- confirm against the repository layout).
# NOTE(review): the star import is what brings `ViT` into scope for the cells
# below; an explicit import list would make that dependency visible.
try:
    from ViT.train import train
    from ViT.utils import cifar_train_set, cifar_val_set
    from ViT.model import *
except ImportError:
    # Only a failed import should trigger the fallback. A bare `except:` would
    # also swallow KeyboardInterrupt and genuine errors raised while the
    # project modules execute, hiding the real failure behind a second,
    # unrelated import error.
    from train import train
    from utils import cifar_train_set, cifar_val_set
    from model import *

from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F

import datetime

# Mini-batch loaders over the two CIFAR splits imported above. Both loaders
# shuffle, keep the final partial batch, and use pinned host memory.
# NOTE(review): shuffling the validation split is not required for evaluation.
train_loader = DataLoader(
    cifar_train_set,
    batch_size=256,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)
val_loader = DataLoader(
    cifar_val_set,
    batch_size=500,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)

import os

# Make sure the directory that get_outdir() writes into exists.
# exist_ok=True makes the call idempotent and removes the check-then-create
# race of testing os.path.exists() first.
os.makedirs("ViT/log", exist_ok=True)

def timestr():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``.

    The result is used as a unique-per-second run identifier.
    """
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

def get_outdir(time_str):
    """Build the log-file path for a run.

    Args:
        time_str: Run identifier (normally the value returned by ``timestr()``).

    Returns:
        The string ``"ViT/log/<time_str>.out"``.
    """
    return "ViT/log/{}.out".format(time_str)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Hyperparameters for a training run that starts from freshly initialised weights.

epochs = 50
patch_size = 8
embed_dim = 256
n_layers = 6
heads = 8
attn_mlp_dim = 512
mlp_dim = 512
pool = 'cls'

# Build the classifier for 32x32 inputs and 10 output classes.
model = ViT(
    image_size=32,
    patch_size=patch_size,
    num_classes=10,
    embed_dim=embed_dim,
    n_layers=n_layers,
    heads=heads,
    attn_mlp_dim=attn_mlp_dim,
    mlp_dim=mlp_dim,
    pool=pool,
)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=2e-4,
    betas=(0.9, 0.999),
    weight_decay=5e-4,
)

# The timestamp identifies this run and names its log file.
time_str = timestr()

print("Time string: {}".format(time_str))

# Report the number of trainable parameters.
# (Uncomment to inspect the transformer stack: print(model.transformer))
print('The model has {:,} trainable parameters'.format(
    sum(p.numel() for p in model.parameters() if p.requires_grad)
))

train(
    epochs=epochs,
    model=model,
    optimizer=optimizer,
    criterion=nn.CrossEntropyLoss(),
    train_loader=train_loader,
    val_loader=val_loader,
    outdir=get_outdir(time_str),
)

Time string: 20241002_091830
The model has 6,359,562 trainable parameters


Epoch 2/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 1/50, Train Loss: 2.0089, Train Accuracy: 24.63%, Val Loss: 1.8227, Val Accuracy: 32.91%


Epoch 3/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 2/50, Train Loss: 1.7361, Train Accuracy: 36.08%, Val Loss: 1.6336, Val Accuracy: 40.28%


Epoch 4/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 3/50, Train Loss: 1.6261, Train Accuracy: 40.61%, Val Loss: 1.5836, Val Accuracy: 42.45%


Epoch 5/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 4/50, Train Loss: 1.5654, Train Accuracy: 42.90%, Val Loss: 1.5123, Val Accuracy: 45.19%


Epoch 6/50:   1%|▎                                                      | 1/196 [00:00<00:32,  6.00it/s, Train Loss=1.6]

Epoch 5/50, Train Loss: 1.5155, Train Accuracy: 44.71%, Val Loss: 1.5135, Val Accuracy: 45.03%


Epoch 7/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 6/50, Train Loss: 1.4834, Train Accuracy: 46.23%, Val Loss: 1.4772, Val Accuracy: 46.55%


Epoch 8/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 7/50, Train Loss: 1.4466, Train Accuracy: 47.31%, Val Loss: 1.4147, Val Accuracy: 49.00%


Epoch 9/50:   1%|▎                                                      | 1/196 [00:00<00:32,  5.99it/s, Train Loss=1.4]

Epoch 8/50, Train Loss: 1.4240, Train Accuracy: 48.27%, Val Loss: 1.4380, Val Accuracy: 48.50%


Epoch 10/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 9/50, Train Loss: 1.4047, Train Accuracy: 48.96%, Val Loss: 1.3805, Val Accuracy: 49.75%


Epoch 11/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 10/50, Train Loss: 1.3813, Train Accuracy: 50.00%, Val Loss: 1.3691, Val Accuracy: 50.72%


Epoch 12/50:   1%|▎                                                    | 1/196 [00:00<00:33,  5.85it/s, Train Loss=1.35]

Epoch 11/50, Train Loss: 1.3705, Train Accuracy: 50.49%, Val Loss: 1.3803, Val Accuracy: 49.92%


Epoch 13/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 12/50, Train Loss: 1.3512, Train Accuracy: 51.05%, Val Loss: 1.3570, Val Accuracy: 51.12%


Epoch 14/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 13/50, Train Loss: 1.3411, Train Accuracy: 51.40%, Val Loss: 1.3434, Val Accuracy: 51.36%


Epoch 15/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 14/50, Train Loss: 1.3307, Train Accuracy: 51.91%, Val Loss: 1.3058, Val Accuracy: 52.99%


Epoch 16/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.04it/s, Train Loss=1.27]

Epoch 15/50, Train Loss: 1.3176, Train Accuracy: 52.36%, Val Loss: 1.3390, Val Accuracy: 51.61%


Epoch 17/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.01it/s, Train Loss=1.26]

Epoch 16/50, Train Loss: 1.3088, Train Accuracy: 52.71%, Val Loss: 1.3243, Val Accuracy: 53.22%


Epoch 18/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.02it/s, Train Loss=1.2]

Epoch 17/50, Train Loss: 1.2916, Train Accuracy: 53.55%, Val Loss: 1.3104, Val Accuracy: 52.61%


Epoch 19/50:   1%|▎                                                    | 1/196 [00:00<00:33,  5.89it/s, Train Loss=1.46]

Epoch 18/50, Train Loss: 1.2865, Train Accuracy: 53.68%, Val Loss: 1.3308, Val Accuracy: 52.61%


Epoch 20/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 19/50, Train Loss: 1.2775, Train Accuracy: 53.91%, Val Loss: 1.2829, Val Accuracy: 54.20%


Epoch 21/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.01it/s, Train Loss=1.24]

Epoch 20/50, Train Loss: 1.2661, Train Accuracy: 54.39%, Val Loss: 1.2861, Val Accuracy: 53.94%


Epoch 22/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 21/50, Train Loss: 1.2544, Train Accuracy: 54.63%, Val Loss: 1.2777, Val Accuracy: 54.16%


Epoch 23/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 22/50, Train Loss: 1.2499, Train Accuracy: 54.77%, Val Loss: 1.2629, Val Accuracy: 54.47%


Epoch 24/50:   1%|▎                                                    | 1/196 [00:00<00:35,  5.53it/s, Train Loss=1.19]

Epoch 23/50, Train Loss: 1.2411, Train Accuracy: 55.37%, Val Loss: 1.2634, Val Accuracy: 54.61%


Epoch 25/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 24/50, Train Loss: 1.2321, Train Accuracy: 55.40%, Val Loss: 1.2501, Val Accuracy: 55.06%


Epoch 26/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 25/50, Train Loss: 1.2187, Train Accuracy: 56.16%, Val Loss: 1.2432, Val Accuracy: 55.40%


Epoch 27/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 26/50, Train Loss: 1.2127, Train Accuracy: 56.53%, Val Loss: 1.2035, Val Accuracy: 56.91%


Epoch 28/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.05it/s, Train Loss=1.11]

Epoch 27/50, Train Loss: 1.2007, Train Accuracy: 56.73%, Val Loss: 1.2454, Val Accuracy: 55.08%


Epoch 29/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.98it/s, Train Loss=1.16]

Epoch 28/50, Train Loss: 1.1928, Train Accuracy: 57.24%, Val Loss: 1.2245, Val Accuracy: 55.86%


Epoch 30/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 29/50, Train Loss: 1.1855, Train Accuracy: 57.41%, Val Loss: 1.1823, Val Accuracy: 58.16%


Epoch 31/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.03it/s, Train Loss=1.21]

Epoch 30/50, Train Loss: 1.1766, Train Accuracy: 57.75%, Val Loss: 1.2210, Val Accuracy: 56.30%


Epoch 32/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.02it/s, Train Loss=1.1]

Epoch 31/50, Train Loss: 1.1686, Train Accuracy: 58.06%, Val Loss: 1.1950, Val Accuracy: 56.71%


Epoch 33/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 32/50, Train Loss: 1.1555, Train Accuracy: 58.70%, Val Loss: 1.1690, Val Accuracy: 58.40%


Epoch 34/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.02it/s, Train Loss=1.07]

Epoch 33/50, Train Loss: 1.1534, Train Accuracy: 58.62%, Val Loss: 1.1837, Val Accuracy: 57.22%


Epoch 35/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.00it/s, Train Loss=1.21]

Epoch 34/50, Train Loss: 1.1410, Train Accuracy: 59.10%, Val Loss: 1.1910, Val Accuracy: 57.58%


Epoch 36/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 35/50, Train Loss: 1.1367, Train Accuracy: 59.34%, Val Loss: 1.1682, Val Accuracy: 57.94%


Epoch 37/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 36/50, Train Loss: 1.1319, Train Accuracy: 59.27%, Val Loss: 1.1655, Val Accuracy: 58.62%


Epoch 38/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 37/50, Train Loss: 1.1169, Train Accuracy: 60.09%, Val Loss: 1.1544, Val Accuracy: 58.80%


Epoch 39/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 38/50, Train Loss: 1.1171, Train Accuracy: 59.88%, Val Loss: 1.1380, Val Accuracy: 59.93%


Epoch 40/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.91it/s, Train Loss=1.13]

Epoch 39/50, Train Loss: 1.1067, Train Accuracy: 60.25%, Val Loss: 1.1386, Val Accuracy: 59.46%


Epoch 41/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 40/50, Train Loss: 1.1017, Train Accuracy: 60.64%, Val Loss: 1.1320, Val Accuracy: 60.05%


Epoch 42/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 41/50, Train Loss: 1.0928, Train Accuracy: 61.00%, Val Loss: 1.1245, Val Accuracy: 60.24%


Epoch 43/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 42/50, Train Loss: 1.0894, Train Accuracy: 61.15%, Val Loss: 1.1136, Val Accuracy: 60.40%


Epoch 44/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 43/50, Train Loss: 1.0831, Train Accuracy: 61.21%, Val Loss: 1.1064, Val Accuracy: 61.26%


Epoch 45/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.00it/s, Train Loss=1.03]

Epoch 44/50, Train Loss: 1.0729, Train Accuracy: 61.61%, Val Loss: 1.1102, Val Accuracy: 61.01%


Epoch 46/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 45/50, Train Loss: 1.0682, Train Accuracy: 61.83%, Val Loss: 1.0931, Val Accuracy: 60.70%


Epoch 47/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.03it/s, Train Loss=1.12]

Epoch 46/50, Train Loss: 1.0602, Train Accuracy: 62.03%, Val Loss: 1.1064, Val Accuracy: 60.38%


Epoch 48/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 47/50, Train Loss: 1.0577, Train Accuracy: 62.30%, Val Loss: 1.0928, Val Accuracy: 61.57%


Epoch 49/50:   1%|▎                                                   | 1/196 [00:00<00:32,  6.00it/s, Train Loss=0.948]

Epoch 48/50, Train Loss: 1.0497, Train Accuracy: 62.32%, Val Loss: 1.1186, Val Accuracy: 60.49%


Epoch 50/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 49/50, Train Loss: 1.0473, Train Accuracy: 62.52%, Val Loss: 1.0925, Val Accuracy: 61.13%


                                                                                                                        

Epoch 50/50, Train Loss: 1.0331, Train Accuracy: 63.01%, Val Loss: 1.0868, Val Accuracy: 60.67%


In [4]:
# Hyperparameters for a resumed run. The architecture settings must match the
# ones the checkpoint was saved with, because load_state_dict() is strict by
# default.

epochs = 20
patch_size = 8
embed_dim = 256
n_layers = 6
heads = 8
attn_mlp_dim = 512
mlp_dim = 512
pool = 'cls'

model = ViT(image_size=32, patch_size=patch_size, num_classes=10, embed_dim=embed_dim, n_layers=n_layers, heads=heads, attn_mlp_dim=attn_mlp_dim, mlp_dim=mlp_dim, pool=pool)

# Resume from the saved weights. map_location="cpu" lets the checkpoint
# deserialize on any machine regardless of the device it was saved from;
# load_state_dict() then copies the values into the model's own parameters.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you
# trust (recent PyTorch releases offer weights_only=True to restrict this).
# NOTE(review): "ViT.pth" is presumably written by train(); confirm.
model.load_state_dict(torch.load("ViT.pth", map_location="cpu"))

# NOTE(review): a fresh Adam instance is created here, so the optimizer state
# of the previous run is not restored.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.9, 0.999), weight_decay=5e-4)

# The timestamp identifies this run and names its log file.
time_str = timestr()

print(f"Time string: {time_str}")

# Report the number of trainable parameters.
# (Uncomment to inspect the transformer stack: print(model.transformer))
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(epochs=epochs, model=model, optimizer=optimizer, criterion=nn.CrossEntropyLoss(), 
      train_loader=train_loader, val_loader=val_loader, outdir=get_outdir(time_str))



Time string: 20241002_122629
The model has 6,359,562 trainable parameters




Epoch 1/20, Train Loss: 1.0269, Train Accuracy: 63.00%, Val Loss: 1.1021, Val Accuracy: 61.55%




Epoch 2/20, Train Loss: 1.0262, Train Accuracy: 63.26%, Val Loss: 1.0776, Val Accuracy: 61.30%




Epoch 3/20, Train Loss: 1.0108, Train Accuracy: 63.96%, Val Loss: 1.0612, Val Accuracy: 62.45%




Epoch 4/20, Train Loss: 1.0083, Train Accuracy: 63.92%, Val Loss: 1.0760, Val Accuracy: 62.45%




Epoch 5/20, Train Loss: 1.0022, Train Accuracy: 64.44%, Val Loss: 1.0520, Val Accuracy: 62.87%




Epoch 6/20, Train Loss: 0.9957, Train Accuracy: 64.36%, Val Loss: 1.0622, Val Accuracy: 62.15%




Epoch 7/20, Train Loss: 0.9910, Train Accuracy: 64.69%, Val Loss: 1.0355, Val Accuracy: 63.52%




Epoch 8/20, Train Loss: 0.9878, Train Accuracy: 64.69%, Val Loss: 1.0539, Val Accuracy: 62.91%




Epoch 9/20, Train Loss: 0.9771, Train Accuracy: 65.07%, Val Loss: 1.0261, Val Accuracy: 63.37%




Epoch 10/20, Train Loss: 0.9706, Train Accuracy: 65.40%, Val Loss: 1.0586, Val Accuracy: 62.47%




Epoch 11/20, Train Loss: 0.9668, Train Accuracy: 65.36%, Val Loss: 1.0139, Val Accuracy: 63.96%




Epoch 12/20, Train Loss: 0.9686, Train Accuracy: 65.44%, Val Loss: 1.0253, Val Accuracy: 64.05%




Epoch 13/20, Train Loss: 0.9648, Train Accuracy: 65.75%, Val Loss: 1.0341, Val Accuracy: 63.29%




Epoch 14/20, Train Loss: 0.9515, Train Accuracy: 66.14%, Val Loss: 1.0017, Val Accuracy: 64.17%




Epoch 15/20, Train Loss: 0.9535, Train Accuracy: 65.92%, Val Loss: 1.0350, Val Accuracy: 63.46%




Epoch 16/20, Train Loss: 0.9447, Train Accuracy: 66.39%, Val Loss: 1.0172, Val Accuracy: 63.84%




Epoch 17/20, Train Loss: 0.9390, Train Accuracy: 66.59%, Val Loss: 1.0216, Val Accuracy: 64.35%




Epoch 18/20, Train Loss: 0.9420, Train Accuracy: 66.31%, Val Loss: 1.0138, Val Accuracy: 64.41%




Epoch 19/20, Train Loss: 0.9295, Train Accuracy: 66.84%, Val Loss: 1.0001, Val Accuracy: 64.85%




Epoch 20/20, Train Loss: 0.9294, Train Accuracy: 67.07%, Val Loss: 0.9909, Val Accuracy: 65.47%


Here is the model:

```text
AttentionLayers(
  (layers): ModuleList(
    (0): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (1): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (2): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (3): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (4): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (5): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (6): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (7): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (8): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (9): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (10): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (11): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
  )
  (adaptive_mlp): Identity()
  (final_norm): LayerNorm(
    (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
  )
  (skip_combines): ModuleList(
    (0): None
    (1): None
    (2): None
    (3): None
    (4): None
    (5): None
    (6): None
    (7): None
    (8): None
    (9): None
    (10): None
    (11): None
  )
)

```

In [3]:
# Hyperparameters for a resumed run. The architecture settings must match the
# ones the checkpoint was saved with, because load_state_dict() is strict by
# default.

epochs = 30
patch_size = 8
embed_dim = 256
n_layers = 6
heads = 8
attn_mlp_dim = 512
mlp_dim = 512
pool = 'cls'

model = ViT(image_size=32, patch_size=patch_size, num_classes=10, embed_dim=embed_dim, n_layers=n_layers, heads=heads, attn_mlp_dim=attn_mlp_dim, mlp_dim=mlp_dim, pool=pool)

# Resume from the saved weights. map_location="cpu" lets the checkpoint
# deserialize on any machine regardless of the device it was saved from;
# load_state_dict() then copies the values into the model's own parameters.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you
# trust (recent PyTorch releases offer weights_only=True to restrict this).
# NOTE(review): "ViT.pth" is presumably written by train(); confirm.
model.load_state_dict(torch.load("ViT.pth", map_location="cpu"))

# NOTE(review): a fresh Adam instance is created here, so the optimizer state
# of the previous run is not restored.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.9, 0.999), weight_decay=5e-4)

# The timestamp identifies this run and names its log file.
time_str = timestr()

print(f"Time string: {time_str}")

# Report the number of trainable parameters.
# (Uncomment to inspect the transformer stack: print(model.transformer))
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(epochs=epochs, model=model, optimizer=optimizer, criterion=nn.CrossEntropyLoss(), 
      train_loader=train_loader, val_loader=val_loader, outdir=get_outdir(time_str))

Epoch 1/30:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Time string: 20241002_131005
The model has 6,359,562 trainable parameters


Epoch 2/30:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 1/30, Train Loss: 0.9234, Train Accuracy: 67.25%, Val Loss: 1.0420, Val Accuracy: 63.05%


Epoch 3/30:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 2/30, Train Loss: 0.9180, Train Accuracy: 67.12%, Val Loss: 1.0031, Val Accuracy: 65.03%


Epoch 4/30:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 3/30, Train Loss: 0.9122, Train Accuracy: 67.41%, Val Loss: 0.9941, Val Accuracy: 65.73%


Epoch 5/30:   1%|▎                                                     | 1/196 [00:00<00:32,  6.04it/s, Train Loss=0.994]

Epoch 4/30, Train Loss: 0.9063, Train Accuracy: 67.68%, Val Loss: 1.0017, Val Accuracy: 64.61%


Epoch 6/30:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 5/30, Train Loss: 0.9082, Train Accuracy: 67.81%, Val Loss: 0.9821, Val Accuracy: 65.67%


Epoch 7/30:   1%|▎                                                         | 1/196 [00:00<00:33,  5.91it/s, Train Loss=1]

Epoch 6/30, Train Loss: 0.8995, Train Accuracy: 68.06%, Val Loss: 1.0026, Val Accuracy: 64.69%


Epoch 8/30:   1%|▎                                                     | 1/196 [00:00<00:32,  5.91it/s, Train Loss=0.939]

Epoch 7/30, Train Loss: 0.8943, Train Accuracy: 68.18%, Val Loss: 0.9877, Val Accuracy: 64.80%


Epoch 9/30:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 8/30, Train Loss: 0.8902, Train Accuracy: 68.15%, Val Loss: 0.9722, Val Accuracy: 65.72%


Epoch 10/30:   1%|▎                                                    | 1/196 [00:00<00:32,  5.92it/s, Train Loss=0.995]

Epoch 9/30, Train Loss: 0.8843, Train Accuracy: 68.46%, Val Loss: 0.9759, Val Accuracy: 65.46%


Epoch 11/30:   1%|▎                                                    | 1/196 [00:00<00:32,  6.02it/s, Train Loss=0.863]

Epoch 10/30, Train Loss: 0.8772, Train Accuracy: 68.74%, Val Loss: 0.9813, Val Accuracy: 65.72%


Epoch 12/30:   1%|▎                                                    | 1/196 [00:00<00:33,  5.87it/s, Train Loss=0.892]

Epoch 11/30, Train Loss: 0.8799, Train Accuracy: 68.52%, Val Loss: 0.9767, Val Accuracy: 65.74%


Epoch 13/30:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 12/30, Train Loss: 0.8817, Train Accuracy: 68.56%, Val Loss: 0.9508, Val Accuracy: 66.42%


Epoch 14/30:   1%|▎                                                    | 1/196 [00:00<00:32,  5.93it/s, Train Loss=0.799]

Epoch 13/30, Train Loss: 0.8699, Train Accuracy: 68.88%, Val Loss: 0.9620, Val Accuracy: 65.85%


Epoch 15/30:   1%|▎                                                    | 1/196 [00:00<00:32,  6.01it/s, Train Loss=0.846]

Epoch 14/30, Train Loss: 0.8630, Train Accuracy: 69.13%, Val Loss: 0.9587, Val Accuracy: 66.49%


Epoch 16/30:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 15/30, Train Loss: 0.8661, Train Accuracy: 69.03%, Val Loss: 0.9484, Val Accuracy: 66.65%


Epoch 17/30:   1%|▎                                                    | 1/196 [00:00<00:33,  5.84it/s, Train Loss=0.816]

Epoch 16/30, Train Loss: 0.8534, Train Accuracy: 69.41%, Val Loss: 0.9827, Val Accuracy: 65.73%


Epoch 18/30:   1%|▎                                                    | 1/196 [00:00<00:33,  5.91it/s, Train Loss=0.821]

Epoch 17/30, Train Loss: 0.8496, Train Accuracy: 69.87%, Val Loss: 0.9709, Val Accuracy: 65.66%


Epoch 19/30:   1%|▎                                                    | 1/196 [00:00<00:33,  5.90it/s, Train Loss=0.906]

Epoch 18/30, Train Loss: 0.8512, Train Accuracy: 69.65%, Val Loss: 0.9510, Val Accuracy: 66.57%


Epoch 20/30:   1%|▎                                                    | 1/196 [00:00<00:32,  5.92it/s, Train Loss=0.736]

Epoch 19/30, Train Loss: 0.8460, Train Accuracy: 69.93%, Val Loss: 0.9609, Val Accuracy: 66.49%


Epoch 21/30:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 20/30, Train Loss: 0.8446, Train Accuracy: 69.83%, Val Loss: 0.9478, Val Accuracy: 66.69%


Epoch 22/30:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 21/30, Train Loss: 0.8417, Train Accuracy: 69.99%, Val Loss: 0.9423, Val Accuracy: 66.67%


Epoch 23/30:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 22/30, Train Loss: 0.8395, Train Accuracy: 70.21%, Val Loss: 0.9347, Val Accuracy: 66.89%


Epoch 24/30:   1%|▎                                                    | 1/196 [00:00<00:33,  5.90it/s, Train Loss=0.739]

Epoch 23/30, Train Loss: 0.8357, Train Accuracy: 70.34%, Val Loss: 0.9374, Val Accuracy: 66.79%


Epoch 25/30:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 24/30, Train Loss: 0.8307, Train Accuracy: 70.46%, Val Loss: 0.9338, Val Accuracy: 67.32%


Epoch 26/30:   1%|▎                                                    | 1/196 [00:00<00:32,  6.03it/s, Train Loss=0.798]

Epoch 25/30, Train Loss: 0.8221, Train Accuracy: 70.76%, Val Loss: 0.9389, Val Accuracy: 67.01%


Epoch 27/30:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 26/30, Train Loss: 0.8213, Train Accuracy: 70.70%, Val Loss: 0.9321, Val Accuracy: 67.63%


Epoch 28/30:   1%|▎                                                    | 1/196 [00:00<00:32,  6.03it/s, Train Loss=0.725]

Epoch 27/30, Train Loss: 0.8180, Train Accuracy: 70.91%, Val Loss: 0.9472, Val Accuracy: 67.10%


Epoch 29/30:   1%|▎                                                    | 1/196 [00:00<00:33,  5.85it/s, Train Loss=0.759]

Epoch 28/30, Train Loss: 0.8122, Train Accuracy: 71.09%, Val Loss: 0.9434, Val Accuracy: 67.11%


Epoch 30/30:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 29/30, Train Loss: 0.8127, Train Accuracy: 70.98%, Val Loss: 0.9221, Val Accuracy: 67.85%


                                                                                                                         

Epoch 30/30, Train Loss: 0.8042, Train Accuracy: 71.25%, Val Loss: 0.9663, Val Accuracy: 66.23%


In [4]:
# Hyperparameters for this run. The model is rebuilt with these settings and,
# when a checkpoint is available, continues training from the saved weights.

epochs = 50
patch_size = 8
embed_dim = 256
n_layers = 6
heads = 8
attn_mlp_dim = 512
mlp_dim = 512
pool = 'cls'

# Weights file to resume from.
checkpoint_path = "ViT.pth"

model = ViT(
    image_size=32,
    patch_size=patch_size,
    num_classes=10,
    embed_dim=embed_dim,
    n_layers=n_layers,
    heads=heads,
    attn_mlp_dim=attn_mlp_dim,
    mlp_dim=mlp_dim,
    pool=pool,
)

# Load the saved weights when they exist. The guard keeps this cell runnable on
# a fresh checkout where no checkpoint has been written yet.
if os.path.exists(checkpoint_path):
    # map_location="cpu" lets a checkpoint written from GPU tensors be read on
    # any machine; load_state_dict then copies the values into the model's own
    # parameters.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    print(f"Resumed weights from {checkpoint_path}")
else:
    print(f"No checkpoint found at {checkpoint_path}; training the freshly constructed model")

# Only the model weights are restored above; the optimizer is created anew, so
# no optimizer state carries over from the earlier run.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.9, 0.999), weight_decay=5e-4)

# The timestamp names this run's log file (see get_outdir).
time_str = timestr()

print(f"Time string: {time_str}")

# Report the model size (trainable parameters only).
# print(model.transformer)
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(epochs=epochs, model=model, optimizer=optimizer, criterion=nn.CrossEntropyLoss(),
      train_loader=train_loader, val_loader=val_loader, outdir=get_outdir(time_str))

Epoch 1/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Time string: 20241002_141337
The model has 6,359,562 trainable parameters


Epoch 2/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 1/50, Train Loss: 0.8041, Train Accuracy: 71.33%, Val Loss: 0.9552, Val Accuracy: 66.63%


Epoch 3/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 2/50, Train Loss: 0.7979, Train Accuracy: 71.35%, Val Loss: 0.9191, Val Accuracy: 67.68%


Epoch 4/50:   1%|▎                                                     | 1/196 [00:00<00:33,  5.91it/s, Train Loss=0.824]

Epoch 3/50, Train Loss: 0.7847, Train Accuracy: 71.99%, Val Loss: 0.9274, Val Accuracy: 67.82%


Epoch 5/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.02it/s, Train Loss=0.885]

Epoch 4/50, Train Loss: 0.7838, Train Accuracy: 71.97%, Val Loss: 0.9488, Val Accuracy: 66.93%


Epoch 6/50:   1%|▎                                                     | 1/196 [00:00<00:33,  5.86it/s, Train Loss=0.837]

Epoch 5/50, Train Loss: 0.7843, Train Accuracy: 72.21%, Val Loss: 0.9311, Val Accuracy: 67.45%


Epoch 7/50:   1%|▎                                                     | 1/196 [00:00<00:33,  5.91it/s, Train Loss=0.831]

Epoch 6/50, Train Loss: 0.7724, Train Accuracy: 72.64%, Val Loss: 0.9268, Val Accuracy: 67.81%


Epoch 8/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 7/50, Train Loss: 0.7743, Train Accuracy: 72.57%, Val Loss: 0.9147, Val Accuracy: 67.57%


Epoch 9/50:   0%|                                                                                | 0/196 [00:00<?, ?it/s]

Epoch 8/50, Train Loss: 0.7688, Train Accuracy: 72.70%, Val Loss: 0.9116, Val Accuracy: 68.12%


Epoch 10/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.95it/s, Train Loss=0.872]

Epoch 9/50, Train Loss: 0.7603, Train Accuracy: 72.83%, Val Loss: 0.9195, Val Accuracy: 68.50%


Epoch 11/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.00it/s, Train Loss=0.709]

Epoch 10/50, Train Loss: 0.7597, Train Accuracy: 72.82%, Val Loss: 0.9192, Val Accuracy: 68.28%


Epoch 12/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 11/50, Train Loss: 0.7640, Train Accuracy: 72.64%, Val Loss: 0.9092, Val Accuracy: 68.17%


Epoch 13/50:   1%|▎                                                    | 1/196 [00:00<00:33,  5.88it/s, Train Loss=0.657]

Epoch 12/50, Train Loss: 0.7567, Train Accuracy: 73.11%, Val Loss: 0.9127, Val Accuracy: 68.25%


Epoch 14/50:   1%|▎                                                    | 1/196 [00:00<00:33,  5.75it/s, Train Loss=0.824]

Epoch 13/50, Train Loss: 0.7465, Train Accuracy: 73.30%, Val Loss: 0.9230, Val Accuracy: 67.65%


Epoch 15/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 14/50, Train Loss: 0.7483, Train Accuracy: 73.20%, Val Loss: 0.8971, Val Accuracy: 68.86%


Epoch 16/50:   1%|▎                                                    | 1/196 [00:00<00:33,  5.85it/s, Train Loss=0.765]

Epoch 15/50, Train Loss: 0.7498, Train Accuracy: 73.46%, Val Loss: 0.9086, Val Accuracy: 68.35%


Epoch 17/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 16/50, Train Loss: 0.7442, Train Accuracy: 73.40%, Val Loss: 0.8949, Val Accuracy: 68.80%


Epoch 18/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.01it/s, Train Loss=0.788]

Epoch 17/50, Train Loss: 0.7379, Train Accuracy: 73.79%, Val Loss: 0.8977, Val Accuracy: 69.07%


Epoch 19/50:   1%|▎                                                    | 1/196 [00:00<00:33,  5.86it/s, Train Loss=0.698]

Epoch 18/50, Train Loss: 0.7339, Train Accuracy: 73.75%, Val Loss: 0.8982, Val Accuracy: 69.09%


Epoch 20/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 19/50, Train Loss: 0.7220, Train Accuracy: 74.34%, Val Loss: 0.8888, Val Accuracy: 68.89%


Epoch 21/50:   1%|▎                                                    | 1/196 [00:00<00:33,  5.90it/s, Train Loss=0.701]

Epoch 20/50, Train Loss: 0.7293, Train Accuracy: 74.07%, Val Loss: 0.9013, Val Accuracy: 69.13%


Epoch 22/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.92it/s, Train Loss=0.715]

Epoch 21/50, Train Loss: 0.7254, Train Accuracy: 74.07%, Val Loss: 0.9011, Val Accuracy: 69.27%


Epoch 23/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 22/50, Train Loss: 0.7232, Train Accuracy: 74.10%, Val Loss: 0.8849, Val Accuracy: 69.50%


Epoch 24/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 23/50, Train Loss: 0.7212, Train Accuracy: 74.23%, Val Loss: 0.8837, Val Accuracy: 69.45%


Epoch 25/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.99it/s, Train Loss=0.689]

Epoch 24/50, Train Loss: 0.7127, Train Accuracy: 74.45%, Val Loss: 0.8952, Val Accuracy: 69.52%


Epoch 26/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.92it/s, Train Loss=0.534]

Epoch 25/50, Train Loss: 0.7124, Train Accuracy: 74.77%, Val Loss: 0.8996, Val Accuracy: 68.95%


Epoch 27/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.01it/s, Train Loss=0.723]

Epoch 26/50, Train Loss: 0.7034, Train Accuracy: 74.77%, Val Loss: 0.8861, Val Accuracy: 69.26%


Epoch 28/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.00it/s, Train Loss=0.626]

Epoch 27/50, Train Loss: 0.7021, Train Accuracy: 74.92%, Val Loss: 0.8847, Val Accuracy: 69.44%


Epoch 29/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 28/50, Train Loss: 0.6988, Train Accuracy: 74.91%, Val Loss: 0.8751, Val Accuracy: 70.11%


Epoch 30/50:   1%|▎                                                    | 1/196 [00:00<00:33,  5.84it/s, Train Loss=0.646]

Epoch 29/50, Train Loss: 0.6977, Train Accuracy: 75.11%, Val Loss: 0.9040, Val Accuracy: 69.23%


Epoch 31/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.92it/s, Train Loss=0.638]

Epoch 30/50, Train Loss: 0.6987, Train Accuracy: 74.88%, Val Loss: 0.8767, Val Accuracy: 69.77%


Epoch 32/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.97it/s, Train Loss=0.682]

Epoch 31/50, Train Loss: 0.6892, Train Accuracy: 75.55%, Val Loss: 0.9066, Val Accuracy: 69.19%


Epoch 33/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.98it/s, Train Loss=0.758]

Epoch 32/50, Train Loss: 0.6890, Train Accuracy: 75.30%, Val Loss: 0.8760, Val Accuracy: 69.94%


Epoch 34/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 33/50, Train Loss: 0.6822, Train Accuracy: 75.47%, Val Loss: 0.8741, Val Accuracy: 69.81%


Epoch 35/50:   1%|▎                                                    | 1/196 [00:00<00:33,  5.85it/s, Train Loss=0.648]

Epoch 34/50, Train Loss: 0.6820, Train Accuracy: 75.77%, Val Loss: 0.8847, Val Accuracy: 69.60%


Epoch 36/50:   1%|▎                                                    | 1/196 [00:00<00:33,  5.88it/s, Train Loss=0.659]

Epoch 35/50, Train Loss: 0.6718, Train Accuracy: 75.73%, Val Loss: 0.8846, Val Accuracy: 69.83%


Epoch 37/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.97it/s, Train Loss=0.552]

Epoch 36/50, Train Loss: 0.6751, Train Accuracy: 75.88%, Val Loss: 0.8955, Val Accuracy: 69.24%


Epoch 38/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.99it/s, Train Loss=0.717]

Epoch 37/50, Train Loss: 0.6720, Train Accuracy: 76.13%, Val Loss: 0.8830, Val Accuracy: 69.46%


Epoch 39/50:   1%|▎                                                     | 1/196 [00:00<00:32,  5.97it/s, Train Loss=0.73]

Epoch 38/50, Train Loss: 0.6662, Train Accuracy: 75.94%, Val Loss: 0.8830, Val Accuracy: 69.63%


Epoch 40/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.01it/s, Train Loss=0.55]

Epoch 39/50, Train Loss: 0.6641, Train Accuracy: 76.10%, Val Loss: 0.8770, Val Accuracy: 70.37%


Epoch 41/50:   1%|▎                                                    | 1/196 [00:00<00:33,  5.91it/s, Train Loss=0.623]

Epoch 40/50, Train Loss: 0.6596, Train Accuracy: 76.32%, Val Loss: 0.8814, Val Accuracy: 69.65%


Epoch 42/50:   1%|▎                                                    | 1/196 [00:00<00:33,  5.86it/s, Train Loss=0.596]

Epoch 41/50, Train Loss: 0.6596, Train Accuracy: 76.26%, Val Loss: 0.8970, Val Accuracy: 70.24%


Epoch 43/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.00it/s, Train Loss=0.605]

Epoch 42/50, Train Loss: 0.6574, Train Accuracy: 76.40%, Val Loss: 0.8805, Val Accuracy: 69.99%


Epoch 44/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.99it/s, Train Loss=0.724]

Epoch 43/50, Train Loss: 0.6455, Train Accuracy: 76.99%, Val Loss: 0.8874, Val Accuracy: 69.62%


Epoch 45/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.99it/s, Train Loss=0.663]

Epoch 44/50, Train Loss: 0.6471, Train Accuracy: 76.86%, Val Loss: 0.8839, Val Accuracy: 70.37%


Epoch 46/50:   1%|▎                                                      | 1/196 [00:00<00:32,  6.02it/s, Train Loss=0.5]

Epoch 45/50, Train Loss: 0.6452, Train Accuracy: 76.97%, Val Loss: 0.8816, Val Accuracy: 70.10%


Epoch 47/50:   1%|▎                                                    | 1/196 [00:00<00:32,  6.00it/s, Train Loss=0.586]

Epoch 46/50, Train Loss: 0.6395, Train Accuracy: 77.03%, Val Loss: 0.9027, Val Accuracy: 69.15%


Epoch 48/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 47/50, Train Loss: 0.6391, Train Accuracy: 77.26%, Val Loss: 0.8625, Val Accuracy: 70.59%


Epoch 49/50:   1%|▎                                                    | 1/196 [00:00<00:32,  5.94it/s, Train Loss=0.572]

Epoch 48/50, Train Loss: 0.6350, Train Accuracy: 77.36%, Val Loss: 0.8842, Val Accuracy: 69.63%


Epoch 50/50:   1%|▎                                                     | 1/196 [00:00<00:32,  6.00it/s, Train Loss=0.65]

Epoch 49/50, Train Loss: 0.6255, Train Accuracy: 77.51%, Val Loss: 0.8927, Val Accuracy: 69.73%


                                                                                                                         

Epoch 50/50, Train Loss: 0.6304, Train Accuracy: 77.51%, Val Loss: 0.9104, Val Accuracy: 69.07%


In [4]:
# Write the model's state dict (weights only, not the full module) to ViT.pth
# so that a later session can load it back with load_state_dict.
torch.save(
    obj=model.state_dict(),
    f='ViT.pth',
)