# ViT for image classification on CIFAR-10

In [1]:
# Enable the autoreload extension; mode 2 reloads imported modules before each
# cell executes, so edits to the project's .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Show the visible GPU(s) and the `python` executable found on the shell PATH.
!nvidia-smi
!which python

Wed Sep 25 10:41:05 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   42C    P0              40W / 300W |      1MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
# Import the project code. The package-qualified form is tried first; when the
# `ViT` package cannot be imported, fall back to the flat module names.
# Only ImportError triggers the fallback, so unrelated failures (or a
# KeyboardInterrupt) are no longer silently swallowed by a bare `except:`.
try:
    from ViT.train import train
    from ViT.utils import cifar_train_set, cifar_val_set
    from ViT.model import *
except ImportError:
    from train import train
    from utils import cifar_train_set, cifar_val_set
    from model import *

from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F

import datetime

# Training batches are reshuffled every epoch; the last (smaller) batch is kept.
train_loader = DataLoader(
    cifar_train_set,
    batch_size=256,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)
# Validation does not need shuffling: a fixed order keeps evaluation deterministic
# and avoids generating a random permutation on every pass.
val_loader = DataLoader(
    cifar_val_set,
    batch_size=500,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)

import os

# Create the log directory (and any missing parents). `exist_ok=True` replaces the
# separate existence check, which was racy and redundant, and keeps the cell
# safe to re-run.
os.makedirs("ViT/log", exist_ok=True)

def timestr():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    stamp_format = "%Y%m%d_%H%M%S"
    return datetime.datetime.now().strftime(stamp_format)

def get_outdir(time_str):
    """Build the path of the run's ``.out`` log file under ``ViT/log`` for *time_str*."""
    return f"ViT/log/{time_str}.out"

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Hyperparameters

seed = 42  # fixed RNG seed so runs can be repeated
epochs = 50
patch_size = 8
embed_dim = 256
n_layers = 6
heads = 8
attn_dim = 512
mlp_dim = None  # None lets ViT choose its default (4*embed_dim per the original note; confirm in model.py)
pool = 'cls'
dropout = 0.0

# Seed PyTorch before the model is built so that weight initialisation and the
# DataLoader shuffling order are repeatable across runs. GPU kernels may still
# introduce small nondeterminism.
torch.manual_seed(seed)

model = ViT(
    image_size=32,
    patch_size=patch_size,
    num_classes=10,
    embed_dim=embed_dim,
    n_layers=n_layers,
    heads=heads,
    attndim=attn_dim,
    mlp_dim=mlp_dim,
    pool=pool,
    dropout=dropout,
)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.9, 0.999), weight_decay=5e-4)

# The timestamp identifies this run; it names the log file passed to train().
time_str = timestr()

print(f"Time string: {time_str}")

# Report the number of trainable parameters.
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(
    epochs=epochs,
    model=model,
    optimizer=optimizer,
    criterion=nn.CrossEntropyLoss(),
    train_loader=train_loader,
    val_loader=val_loader,
    outdir=get_outdir(time_str),
)

ViT(
  (patch_embedding): Linear(in_features=48, out_features=256, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Sequential(
    (0): tranformer_layer(
      (QKV): Linear(in_features=256, out_features=768, bias=True)
      (fc): Linear(in_features=256, out_features=256, bias=True)
      (mlp): Sequential(
        (0): Linear(in_features=256, out_features=512, bias=True)
        (1): GELU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=512, out_features=256, bias=True)
      )
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): tranformer_layer(
      (QKV): Linear(in_features=256, out_features=768, bias=True)
      (fc): Linear(in_features=256, out_features=256, bias=True)
      (mlp): Sequential(
        (0): Linear(in_features=256, out_features=512, bias=True)
        (1): GELU()
    

Epoch 2/50:   0%|                                       | 0/196 [00:00<?, ?it/s]

Epoch [1/50], Train Loss: 1.9702, Val Loss: 1.7904, Val Accuracy: 34.23%


Epoch 3/50:   0%|                                       | 0/196 [00:00<?, ?it/s]

Epoch [2/50], Train Loss: 1.6894, Val Loss: 1.5948, Val Accuracy: 41.66%


Epoch 4/50:   0%|                                       | 0/196 [00:00<?, ?it/s]

Epoch [3/50], Train Loss: 1.5796, Val Loss: 1.5294, Val Accuracy: 44.84%


Epoch 5/50:   0%|                                       | 0/196 [00:00<?, ?it/s]

Epoch [4/50], Train Loss: 1.5141, Val Loss: 1.5360, Val Accuracy: 45.52%


Epoch 6/50:   0%|                                       | 0/196 [00:00<?, ?it/s]

Epoch [5/50], Train Loss: 1.4673, Val Loss: 1.5013, Val Accuracy: 46.47%


Epoch 7/50:   0%|                                       | 0/196 [00:00<?, ?it/s]

Epoch [6/50], Train Loss: 1.4444, Val Loss: 1.4376, Val Accuracy: 48.24%


Epoch 8/50:   0%|                                       | 0/196 [00:00<?, ?it/s]

Epoch [7/50], Train Loss: 1.4211, Val Loss: 1.4136, Val Accuracy: 49.21%


Epoch 9/50:   0%|                                       | 0/196 [00:00<?, ?it/s]

Epoch [8/50], Train Loss: 1.3926, Val Loss: 1.3810, Val Accuracy: 50.26%


Epoch 10/50:   0%|                                      | 0/196 [00:00<?, ?it/s]

Epoch [9/50], Train Loss: 1.3671, Val Loss: 1.3539, Val Accuracy: 51.37%


Epoch 11/50:   0%|                                      | 0/196 [00:00<?, ?it/s]

Epoch [10/50], Train Loss: 1.3485, Val Loss: 1.3562, Val Accuracy: 51.58%


Epoch 12/50:   0%|                                      | 0/196 [00:00<?, ?it/s]

Epoch [11/50], Train Loss: 1.3398, Val Loss: 1.3275, Val Accuracy: 52.27%


Epoch 13/50:   0%|                                      | 0/196 [00:00<?, ?it/s]

Epoch [12/50], Train Loss: 1.3268, Val Loss: 1.3013, Val Accuracy: 53.18%


Epoch 14/50:   0%|                                      | 0/196 [00:00<?, ?it/s]

Epoch [13/50], Train Loss: 1.3118, Val Loss: 1.3450, Val Accuracy: 51.51%


Epoch 15/50:   0%|                                      | 0/196 [00:00<?, ?it/s]

Epoch [14/50], Train Loss: 1.2965, Val Loss: 1.2867, Val Accuracy: 54.31%


Epoch 15/50:  61%|██████▋    | 119/196 [00:24<00:15,  4.89it/s, Train Loss=1.29]

KeyboardInterrupt: 

In [5]:
# Write the model's state dict (rather than the module object itself) to 'ViT.pth'.
model_state = model.state_dict()
torch.save(model_state, 'ViT.pth')