# ViT for classification on CIFAR-10

In [1]:
# Environment setup: auto-reload edited project modules (train/utils/model) without
# restarting the kernel, and record which GPU and Python interpreter this run used.
%load_ext autoreload
%autoreload 2
!nvidia-smi
!which python

Fri Oct  4 13:33:08 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   44C    P0              51W / 184W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
# Project imports. The package is importable as `ViT` when the notebook runs from the
# repository root; when it runs from inside the ViT/ directory the modules are imported
# directly. Only ImportError triggers the fallback, so a real bug inside one of these
# modules surfaces instead of being silently masked.
try:
    from ViT.train import train
    from ViT.utils import cifar_train_set, cifar_val_set
    from ViT.model import ViT
except ImportError:
    from train import train
    from utils import cifar_train_set, cifar_val_set
    from model import ViT

from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F

import datetime
import os

TRAIN_BATCH_SIZE = 256
VAL_BATCH_SIZE = 500

train_loader = DataLoader(cifar_train_set, TRAIN_BATCH_SIZE, shuffle=True, drop_last=False, pin_memory=True)
# Validation batches are not shuffled: presumably train() aggregates loss/accuracy over
# the whole validation set (confirm), so order only adds non-determinism.
val_loader = DataLoader(cifar_val_set, VAL_BATCH_SIZE, shuffle=False, drop_last=False, pin_memory=True)

# Directory for per-run training logs (see get_outdir); exist_ok makes re-runs idempotent.
os.makedirs("ViT/log", exist_ok=True)

def timestr():
    """Return the current local time as a ``YYYYmmdd_HHMMSS`` string, used to name a run."""
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

def get_outdir(time_str):
    """Return the training-log file path ``ViT/log/<time_str>.out`` for one run."""
    return f"ViT/log/{time_str}.out"

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Hyperparameters

seed = 0              # seeds weight initialization and the training-set shuffle order
epochs = 50
patch_size = 8        # 32x32 inputs -> (32 / 8) ** 2 = 16 patches per image
embed_dim = 256
n_layers = 6
heads = 8
attn_dim = 512
mlp_dim = None # default to 4*embed_dim
pool = 'cls'
dropout = 0.1
lr = 5e-4
weight_decay = 5e-4   # Adam adds this as an L2 term to the gradient (not decoupled as in AdamW)

# Seed before building the model so initialization is reproducible; the training
# DataLoader's shuffle also draws from the global torch RNG when no generator is given.
torch.manual_seed(seed)

model = ViT(image_size=32, patch_size=patch_size, num_classes=10, embed_dim=embed_dim, n_layers=n_layers, heads=heads, attn_dim=attn_dim, mlp_dim=mlp_dim, pool=pool, dropout=dropout)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)

time_str = timestr()

print(f"Time string: {time_str}")

# Print the number of trainable parameters (uncomment to also print the transformer stack).
# print(model.transformer)
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

# Create the checkpoint directory before training so a finished run cannot fail at save time.
os.makedirs("ViT/models", exist_ok=True)

train(epochs=epochs, model=model, optimizer=optimizer, criterion=nn.CrossEntropyLoss(),
      train_loader=train_loader, val_loader=val_loader, outdir=get_outdir(time_str))

# Save the trained weights, named by the same timestamp as the log file.
torch.save(model.state_dict(), f"ViT/models/{time_str}.pth")

Time string: 20241004_133312
The model has 6,359,562 trainable parameters


Epoch 2/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 1/50, Train Loss: 1.9669, Train Accuracy: 26.23%, Val Loss: 1.7118, Val Accuracy: 37.17%


Epoch 3/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 2/50, Train Loss: 1.6896, Train Accuracy: 37.97%, Val Loss: 1.5440, Val Accuracy: 44.06%


Epoch 4/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 3/50, Train Loss: 1.6042, Train Accuracy: 41.44%, Val Loss: 1.4898, Val Accuracy: 47.19%


Epoch 5/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 4/50, Train Loss: 1.5500, Train Accuracy: 43.60%, Val Loss: 1.4661, Val Accuracy: 46.92%


Epoch 6/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 5/50, Train Loss: 1.5178, Train Accuracy: 44.67%, Val Loss: 1.4180, Val Accuracy: 48.83%


Epoch 7/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 6/50, Train Loss: 1.4886, Train Accuracy: 46.10%, Val Loss: 1.3669, Val Accuracy: 51.26%


Epoch 8/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch 7/50, Train Loss: 1.4547, Train Accuracy: 46.95%, Val Loss: 1.3577, Val Accuracy: 51.31%


Epoch 9/50:   1%|▎                                                     | 1/196 [00:00<00:38,  5.12it/s, Train Loss=1.33]

Epoch 8/50, Train Loss: 1.4381, Train Accuracy: 47.78%, Val Loss: 1.3719, Val Accuracy: 50.86%


Epoch 10/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 9/50, Train Loss: 1.4172, Train Accuracy: 48.61%, Val Loss: 1.3351, Val Accuracy: 51.55%


Epoch 11/50:   1%|▎                                                    | 1/196 [00:00<00:38,  5.13it/s, Train Loss=1.41]

Epoch 10/50, Train Loss: 1.3918, Train Accuracy: 49.85%, Val Loss: 1.3652, Val Accuracy: 51.80%


Epoch 12/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 11/50, Train Loss: 1.3840, Train Accuracy: 49.91%, Val Loss: 1.3084, Val Accuracy: 52.82%


Epoch 13/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 12/50, Train Loss: 1.3704, Train Accuracy: 50.40%, Val Loss: 1.2775, Val Accuracy: 53.90%


Epoch 14/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 13/50, Train Loss: 1.3542, Train Accuracy: 51.16%, Val Loss: 1.2541, Val Accuracy: 55.00%


Epoch 15/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 14/50, Train Loss: 1.3304, Train Accuracy: 52.06%, Val Loss: 1.2384, Val Accuracy: 55.77%


Epoch 16/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 15/50, Train Loss: 1.3251, Train Accuracy: 52.35%, Val Loss: 1.2379, Val Accuracy: 55.36%


Epoch 17/50:   1%|▎                                                     | 1/196 [00:00<00:38,  5.12it/s, Train Loss=1.3]

Epoch 16/50, Train Loss: 1.3095, Train Accuracy: 52.42%, Val Loss: 1.2401, Val Accuracy: 55.10%


Epoch 18/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 17/50, Train Loss: 1.2963, Train Accuracy: 53.14%, Val Loss: 1.2234, Val Accuracy: 55.93%


Epoch 19/50:   1%|▎                                                     | 1/196 [00:00<00:38,  5.12it/s, Train Loss=1.2]

Epoch 18/50, Train Loss: 1.2865, Train Accuracy: 53.74%, Val Loss: 1.2258, Val Accuracy: 56.06%


Epoch 20/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 19/50, Train Loss: 1.2736, Train Accuracy: 54.31%, Val Loss: 1.1872, Val Accuracy: 57.14%


Epoch 21/50:   1%|▎                                                    | 1/196 [00:00<00:37,  5.15it/s, Train Loss=1.34]

Epoch 20/50, Train Loss: 1.2591, Train Accuracy: 54.54%, Val Loss: 1.1940, Val Accuracy: 56.88%


Epoch 22/50:   1%|▎                                                    | 1/196 [00:00<00:37,  5.13it/s, Train Loss=1.41]

Epoch 21/50, Train Loss: 1.2496, Train Accuracy: 54.81%, Val Loss: 1.2242, Val Accuracy: 56.41%


Epoch 23/50:   1%|▎                                                    | 1/196 [00:00<00:38,  5.08it/s, Train Loss=1.15]

Epoch 22/50, Train Loss: 1.2370, Train Accuracy: 55.45%, Val Loss: 1.1940, Val Accuracy: 57.42%


Epoch 24/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 23/50, Train Loss: 1.2343, Train Accuracy: 55.57%, Val Loss: 1.1976, Val Accuracy: 56.76%


Epoch 25/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 24/50, Train Loss: 1.2133, Train Accuracy: 56.51%, Val Loss: 1.1350, Val Accuracy: 59.70%


Epoch 26/50:   1%|▎                                                    | 1/196 [00:00<00:37,  5.14it/s, Train Loss=1.19]

Epoch 25/50, Train Loss: 1.2108, Train Accuracy: 56.56%, Val Loss: 1.1437, Val Accuracy: 59.18%


Epoch 27/50:   1%|▎                                                    | 1/196 [00:00<00:38,  5.12it/s, Train Loss=1.21]

Epoch 26/50, Train Loss: 1.2083, Train Accuracy: 56.61%, Val Loss: 1.1670, Val Accuracy: 57.68%


Epoch 28/50:   1%|▎                                                    | 1/196 [00:00<00:38,  5.13it/s, Train Loss=1.21]

Epoch 27/50, Train Loss: 1.1885, Train Accuracy: 57.24%, Val Loss: 1.1580, Val Accuracy: 58.33%


Epoch 29/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 28/50, Train Loss: 1.1833, Train Accuracy: 57.38%, Val Loss: 1.1123, Val Accuracy: 59.54%


Epoch 30/50:   1%|▎                                                    | 1/196 [00:00<00:37,  5.15it/s, Train Loss=1.24]

Epoch 29/50, Train Loss: 1.1789, Train Accuracy: 57.63%, Val Loss: 1.1496, Val Accuracy: 58.99%


Epoch 31/50:   1%|▎                                                    | 1/196 [00:00<00:37,  5.13it/s, Train Loss=1.14]

Epoch 30/50, Train Loss: 1.1718, Train Accuracy: 57.87%, Val Loss: 1.1618, Val Accuracy: 59.15%


Epoch 32/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 31/50, Train Loss: 1.1708, Train Accuracy: 57.95%, Val Loss: 1.1034, Val Accuracy: 60.85%


Epoch 33/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 32/50, Train Loss: 1.1559, Train Accuracy: 58.50%, Val Loss: 1.0916, Val Accuracy: 60.69%


Epoch 34/50:   0%|                                                             | 0/196 [00:00<?, ?it/s, Train Loss=1.03]

Epoch 33/50, Train Loss: 1.1532, Train Accuracy: 58.68%, Val Loss: 1.1199, Val Accuracy: 59.84%


Epoch 35/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 34/50, Train Loss: 1.1382, Train Accuracy: 59.26%, Val Loss: 1.0802, Val Accuracy: 61.54%


Epoch 36/50:   1%|▎                                                    | 1/196 [00:00<00:38,  5.11it/s, Train Loss=1.25]

Epoch 35/50, Train Loss: 1.1345, Train Accuracy: 59.31%, Val Loss: 1.1131, Val Accuracy: 60.47%


Epoch 37/50:   1%|▎                                                    | 1/196 [00:00<00:37,  5.14it/s, Train Loss=1.05]

Epoch 36/50, Train Loss: 1.1343, Train Accuracy: 59.49%, Val Loss: 1.0811, Val Accuracy: 61.49%


Epoch 38/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 37/50, Train Loss: 1.1301, Train Accuracy: 59.63%, Val Loss: 1.0684, Val Accuracy: 61.89%


Epoch 39/50:   1%|▎                                                    | 1/196 [00:00<00:38,  5.13it/s, Train Loss=1.11]

Epoch 38/50, Train Loss: 1.1211, Train Accuracy: 59.73%, Val Loss: 1.1246, Val Accuracy: 60.24%


Epoch 40/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 39/50, Train Loss: 1.1164, Train Accuracy: 60.16%, Val Loss: 1.0518, Val Accuracy: 62.17%


Epoch 41/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 40/50, Train Loss: 1.1102, Train Accuracy: 60.27%, Val Loss: 1.0503, Val Accuracy: 62.70%


Epoch 42/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch 41/50, Train Loss: 1.1061, Train Accuracy: 60.32%, Val Loss: 1.0140, Val Accuracy: 64.33%


Epoch 43/50:   1%|▎                                                    | 1/196 [00:00<00:38,  5.10it/s, Train Loss=1.04]

Epoch 42/50, Train Loss: 1.1017, Train Accuracy: 60.69%, Val Loss: 1.0326, Val Accuracy: 63.54%


Epoch 44/50:   1%|▎                                                     | 1/196 [00:00<00:38,  5.12it/s, Train Loss=1.1]

Epoch 43/50, Train Loss: 1.0962, Train Accuracy: 60.97%, Val Loss: 1.0429, Val Accuracy: 62.45%


Epoch 45/50:   1%|▎                                                    | 1/196 [00:00<00:38,  5.12it/s, Train Loss=1.15]

Epoch 44/50, Train Loss: 1.0932, Train Accuracy: 60.88%, Val Loss: 1.0419, Val Accuracy: 62.99%


Epoch 46/50:   0%|                                                                | 0/196 [00:00<?, ?it/s, Train Loss=1]

Epoch 45/50, Train Loss: 1.0879, Train Accuracy: 61.10%, Val Loss: 1.0425, Val Accuracy: 62.53%


Epoch 47/50:   1%|▎                                                    | 1/196 [00:00<00:38,  5.12it/s, Train Loss=1.06]

Epoch 46/50, Train Loss: 1.0823, Train Accuracy: 61.38%, Val Loss: 1.0473, Val Accuracy: 62.32%


Epoch 48/50:   1%|▎                                                    | 1/196 [00:00<00:37,  5.13it/s, Train Loss=1.04]

Epoch 47/50, Train Loss: 1.0784, Train Accuracy: 61.52%, Val Loss: 1.0472, Val Accuracy: 62.54%


Epoch 49/50:   1%|▎                                                    | 1/196 [00:00<00:38,  5.12it/s, Train Loss=1.04]

Epoch 48/50, Train Loss: 1.0690, Train Accuracy: 61.86%, Val Loss: 1.0343, Val Accuracy: 63.00%


Epoch 50/50:   1%|▎                                                    | 1/196 [00:00<00:37,  5.14it/s, Train Loss=1.05]

Epoch 49/50, Train Loss: 1.0728, Train Accuracy: 61.53%, Val Loss: 1.0479, Val Accuracy: 62.79%


                                                                                                                        

Epoch 50/50, Train Loss: 1.0618, Train Accuracy: 62.00%, Val Loss: 0.9984, Val Accuracy: 63.92%


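Here is the printed transformer stack (`AttentionLayers`) of the model: 6 attention blocks alternating with 6 feed-forward blocks, each pre-normalized with a residual connection.

Note: the dropout layers inside these blocks print `p=0.0` even though `dropout = 0.1` is passed to `ViT`; presumably `dropout` is applied elsewhere in the model (e.g. to the embeddings) — confirm against `model.py`.

```text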
AttentionLayers(
  (layers): ModuleList(
    (0): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (1): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (2): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (3): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (4): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (5): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (6): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (7): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (8): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (9): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (10): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (11): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
  )
  (adaptive_mlp): Identity()
  (final_norm): LayerNorm(
    (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
  )
  (skip_combines): ModuleList(
    (0): None
    (1): None
    (2): None
    (3): None
    (4): None
    (5): None
    (6): None
    (7): None
    (8): None
    (9): None
    (10): None
    (11): None
  )
)

```