# Vit for classification in CIFAR-10

In [1]:
%load_ext autoreload
%autoreload 2
!nvidia-smi
!which python

Tue Oct  1 20:07:42 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   43C    P0              40W / 300W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [6]:
try:
    from ViT.train import train
    from ViT.utils import cifar_train_set, cifar_val_set
    from ViT.model import *
except:
    from train import train
    from utils import cifar_train_set, cifar_val_set
    from model import *

from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F

import datetime

train_loader = DataLoader(cifar_train_set, 256, shuffle=True, drop_last=False, pin_memory=True)
val_loader = DataLoader(cifar_val_set, 500, shuffle=True, drop_last=False, pin_memory=True)

import os
if not os.path.exists("ViT/log"):
    os.makedirs("ViT/log")

def timestr():
    now = datetime.datetime.now()
    return now.strftime("%Y%m%d_%H%M%S")

def get_outdir(time_str):
    outdir = f"ViT/log/{time_str}.out"
    return outdir

In [3]:
# Here is the hyperparameters

epochs = 50
patch_size = 4
embed_dim = 256
n_layers = 6
heads = 8
attn_mlp_dim = 512
mlp_dim = 512
pool = 'cls'

model = ViT(image_size=32, patch_size=patch_size, num_classes=10, embed_dim=embed_dim, n_layers=n_layers, heads=heads, attn_mlp_dim=attn_mlp_dim, mlp_dim=mlp_dim, pool=pool)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999), weight_decay=5e-4)

time_str = timestr()

# print the model and the number of parameters
# print(model.transformer)
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(epochs=epochs, model=model, optimizer=optimizer, criterion=nn.CrossEntropyLoss(), 
      train_loader=train_loader, val_loader=val_loader, outdir=get_outdir(time_str))

The model has 6,334,986 trainable parameters


Epoch 2/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [1/50], Train Loss: 1.9715, Val Loss: 1.8063, Val Accuracy: 33.16%


Epoch 3/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [2/50], Train Loss: 1.7009, Val Loss: 1.5996, Val Accuracy: 41.37%


Epoch 4/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [3/50], Train Loss: 1.5775, Val Loss: 1.5126, Val Accuracy: 45.01%


Epoch 5/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [4/50], Train Loss: 1.5039, Val Loss: 1.4591, Val Accuracy: 47.28%


Epoch 6/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [5/50], Train Loss: 1.4555, Val Loss: 1.4533, Val Accuracy: 47.61%


Epoch 7/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [6/50], Train Loss: 1.4213, Val Loss: 1.4251, Val Accuracy: 48.24%


Epoch 8/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [7/50], Train Loss: 1.3732, Val Loss: 1.3547, Val Accuracy: 50.49%


Epoch 9/50:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [8/50], Train Loss: 1.3421, Val Loss: 1.3035, Val Accuracy: 53.10%


Epoch 10/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [9/50], Train Loss: 1.3066, Val Loss: 1.3063, Val Accuracy: 52.82%


Epoch 11/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [10/50], Train Loss: 1.2710, Val Loss: 1.2850, Val Accuracy: 53.82%


Epoch 12/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [11/50], Train Loss: 1.2539, Val Loss: 1.2236, Val Accuracy: 55.87%


Epoch 13/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [12/50], Train Loss: 1.2296, Val Loss: 1.2039, Val Accuracy: 56.60%


Epoch 14/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [13/50], Train Loss: 1.2095, Val Loss: 1.1686, Val Accuracy: 57.51%


Epoch 15/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [14/50], Train Loss: 1.1823, Val Loss: 1.1618, Val Accuracy: 57.72%


Epoch 16/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [15/50], Train Loss: 1.1709, Val Loss: 1.1637, Val Accuracy: 57.72%


Epoch 17/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [16/50], Train Loss: 1.1508, Val Loss: 1.1281, Val Accuracy: 59.05%


Epoch 18/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [17/50], Train Loss: 1.1354, Val Loss: 1.1497, Val Accuracy: 58.45%


Epoch 19/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [18/50], Train Loss: 1.1185, Val Loss: 1.0975, Val Accuracy: 61.03%


Epoch 20/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [19/50], Train Loss: 1.0979, Val Loss: 1.1393, Val Accuracy: 59.32%


Epoch 21/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [20/50], Train Loss: 1.0932, Val Loss: 1.0734, Val Accuracy: 61.71%


Epoch 22/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [21/50], Train Loss: 1.0782, Val Loss: 1.0953, Val Accuracy: 60.38%


Epoch 23/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [22/50], Train Loss: 1.0649, Val Loss: 1.0984, Val Accuracy: 61.21%


Epoch 24/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [23/50], Train Loss: 1.0503, Val Loss: 1.0588, Val Accuracy: 62.49%


Epoch 25/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [24/50], Train Loss: 1.0415, Val Loss: 1.0795, Val Accuracy: 61.41%


Epoch 26/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [25/50], Train Loss: 1.0278, Val Loss: 1.0214, Val Accuracy: 63.08%


Epoch 27/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [26/50], Train Loss: 1.0186, Val Loss: 1.0242, Val Accuracy: 63.53%


Epoch 28/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [27/50], Train Loss: 1.0136, Val Loss: 1.0721, Val Accuracy: 61.76%


Epoch 29/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [28/50], Train Loss: 1.0086, Val Loss: 1.0313, Val Accuracy: 62.72%


Epoch 30/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [29/50], Train Loss: 1.0058, Val Loss: 1.0316, Val Accuracy: 63.10%


Epoch 31/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [30/50], Train Loss: 0.9887, Val Loss: 1.0428, Val Accuracy: 63.11%


Epoch 32/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [31/50], Train Loss: 0.9799, Val Loss: 0.9986, Val Accuracy: 64.45%


Epoch 33/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [32/50], Train Loss: 0.9824, Val Loss: 1.0147, Val Accuracy: 64.13%


Epoch 34/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [33/50], Train Loss: 0.9703, Val Loss: 1.0015, Val Accuracy: 63.89%


Epoch 35/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [34/50], Train Loss: 0.9654, Val Loss: 0.9525, Val Accuracy: 66.09%


Epoch 36/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [35/50], Train Loss: 0.9552, Val Loss: 0.9894, Val Accuracy: 64.21%


Epoch 37/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [36/50], Train Loss: 0.9501, Val Loss: 0.9758, Val Accuracy: 64.40%


Epoch 38/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [37/50], Train Loss: 0.9415, Val Loss: 0.9644, Val Accuracy: 66.52%


Epoch 39/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [38/50], Train Loss: 0.9422, Val Loss: 0.9623, Val Accuracy: 65.64%


Epoch 40/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [39/50], Train Loss: 0.9344, Val Loss: 0.9894, Val Accuracy: 65.28%


Epoch 41/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [40/50], Train Loss: 0.9247, Val Loss: 0.9746, Val Accuracy: 64.39%


Epoch 42/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [41/50], Train Loss: 0.9261, Val Loss: 0.9519, Val Accuracy: 66.59%


Epoch 43/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [42/50], Train Loss: 0.9126, Val Loss: 1.0129, Val Accuracy: 63.99%


Epoch 44/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [43/50], Train Loss: 0.9080, Val Loss: 0.9260, Val Accuracy: 67.39%


Epoch 45/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [44/50], Train Loss: 0.9058, Val Loss: 0.9473, Val Accuracy: 65.90%


Epoch 46/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [45/50], Train Loss: 0.9031, Val Loss: 0.9362, Val Accuracy: 67.16%


Epoch 47/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [46/50], Train Loss: 0.8964, Val Loss: 0.9393, Val Accuracy: 66.59%


Epoch 48/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [47/50], Train Loss: 0.8894, Val Loss: 0.9557, Val Accuracy: 66.18%


Epoch 49/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [48/50], Train Loss: 0.8880, Val Loss: 0.9249, Val Accuracy: 67.41%


Epoch 50/50:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [49/50], Train Loss: 0.8809, Val Loss: 0.9021, Val Accuracy: 68.07%


                                                                                                                        

Epoch [50/50], Train Loss: 0.8700, Val Loss: 0.9801, Val Accuracy: 65.52%


([1.971512812132738,
  1.7008938789367676,
  1.5775153722081865,
  1.5038623791568133,
  1.455479241755544,
  1.4213460665576312,
  1.3731926284274276,
  1.3421498555309919,
  1.3066049765567391,
  1.270990970183392,
  1.2538534183891452,
  1.2295535535228497,
  1.209460308357161,
  1.1822837128931163,
  1.1709410134626894,
  1.150828129174758,
  1.1354368189159705,
  1.1185487612169616,
  1.0979098580321487,
  1.0932189293053685,
  1.0781740795592873,
  1.0649321380318428,
  1.0503459989416355,
  1.0414876907455677,
  1.0278346727089005,
  1.0185638307308664,
  1.0135750378272972,
  1.0086105070552047,
  1.0058433045538104,
  0.9887293206185711,
  0.9799468304429736,
  0.982350057181047,
  0.9703457075722364,
  0.9653663066576939,
  0.9552437115688713,
  0.9500971001629926,
  0.941549830290736,
  0.9422043859958649,
  0.9344139588730676,
  0.9247079038498353,
  0.9260517915292662,
  0.9126110301942242,
  0.9080354063486566,
  0.9058180122959371,
  0.9030759653874806,
  0.8963843444172

In [7]:
# Here is the hyperparameters

epochs = 20
patch_size = 4
embed_dim = 256
n_layers = 6
heads = 8
attn_mlp_dim = 512
mlp_dim = 512
pool = 'cls'

time_str = timestr()

print(f"Time string: {time_str}")

# print the model and the number of parameters
# print(model.transformer)
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

train(epochs=epochs, model=model, optimizer=optimizer, criterion=nn.CrossEntropyLoss(), 
      train_loader=train_loader, val_loader=val_loader, outdir=get_outdir(time_str))

Epoch 1/20:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

The model has 6,334,986 trainable parameters


Epoch 2/20:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [1/20], Train Loss: 0.8726, Val Loss: 0.9260, Val Accuracy: 67.18%


Epoch 3/20:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [2/20], Train Loss: 0.8689, Val Loss: 0.9008, Val Accuracy: 68.13%


Epoch 4/20:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [3/20], Train Loss: 0.8643, Val Loss: 0.9199, Val Accuracy: 67.56%


Epoch 5/20:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [4/20], Train Loss: 0.8616, Val Loss: 0.8978, Val Accuracy: 68.13%


Epoch 6/20:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [5/20], Train Loss: 0.8571, Val Loss: 0.9314, Val Accuracy: 67.13%


Epoch 7/20:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [6/20], Train Loss: 0.8546, Val Loss: 0.9029, Val Accuracy: 67.84%


Epoch 8/20:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [7/20], Train Loss: 0.8471, Val Loss: 0.8969, Val Accuracy: 68.29%


Epoch 9/20:   0%|                                                                               | 0/196 [00:00<?, ?it/s]

Epoch [8/20], Train Loss: 0.8410, Val Loss: 0.8803, Val Accuracy: 68.23%


Epoch 10/20:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [9/20], Train Loss: 0.8366, Val Loss: 0.9122, Val Accuracy: 67.71%


Epoch 11/20:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [10/20], Train Loss: 0.8391, Val Loss: 0.8842, Val Accuracy: 68.72%


Epoch 12/20:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [11/20], Train Loss: 0.8349, Val Loss: 0.8940, Val Accuracy: 68.07%


Epoch 13/20:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [12/20], Train Loss: 0.8315, Val Loss: 0.8815, Val Accuracy: 68.19%


Epoch 14/20:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [13/20], Train Loss: 0.8218, Val Loss: 0.9012, Val Accuracy: 68.15%


Epoch 15/20:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [14/20], Train Loss: 0.8244, Val Loss: 0.8954, Val Accuracy: 68.39%


Epoch 16/20:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [15/20], Train Loss: 0.8170, Val Loss: 0.8791, Val Accuracy: 68.49%


Epoch 17/20:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [16/20], Train Loss: 0.8150, Val Loss: 0.8640, Val Accuracy: 69.16%


Epoch 18/20:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [17/20], Train Loss: 0.8092, Val Loss: 0.8790, Val Accuracy: 68.65%


Epoch 19/20:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [18/20], Train Loss: 0.8048, Val Loss: 0.8885, Val Accuracy: 68.46%


Epoch 20/20:   0%|                                                                              | 0/196 [00:00<?, ?it/s]

Epoch [19/20], Train Loss: 0.8014, Val Loss: 0.8918, Val Accuracy: 68.79%


                                                                                                                        

Epoch [20/20], Train Loss: 0.8041, Val Loss: 0.8647, Val Accuracy: 69.97%


Here is the model:

```bash
AttentionLayers(
  (layers): ModuleList(
    (0): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (1): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (2): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (3): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (4): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (5): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (6): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (7): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (8): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (9): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
    (10): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): Attention(
        (to_q): Linear(in_features=256, out_features=512, bias=False)
        (to_k): Linear(in_features=256, out_features=512, bias=False)
        (to_v): Linear(in_features=256, out_features=512, bias=False)
        (attend): Attend(
          (attn_dropout): Dropout(p=0.0, inplace=False)
        )
        (to_out): Linear(in_features=512, out_features=256, bias=False)
      )
      (2): Residual()
    )
    (11): ModuleList(
      (0): ModuleList(
        (0): LayerNorm(
          (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
        )
        (1): None
        (2): None
      )
      (1): FeedForward(
        (ff): Sequential(
          (0): Sequential(
            (0): Linear(in_features=256, out_features=1024, bias=True)
            (1): MyGELU()
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
      (2): Residual()
    )
  )
  (adaptive_mlp): Identity()
  (final_norm): LayerNorm(
    (ln): LayerNorm((256,), eps=1e-05, elementwise_affine=False)
  )
  (skip_combines): ModuleList(
    (0): None
    (1): None
    (2): None
    (3): None
    (4): None
    (5): None
    (6): None
    (7): None
    (8): None
    (9): None
    (10): None
    (11): None
  )
)

```

In [8]:
# save the model
torch.save(model.state_dict(), 'ViT.pth')