In [51]:
import numpy as np
import torch
import torch.nn as nn
from timm.models import create_model
from vit_models import VisionTransformer
import lightning as L
from lightning import Trainer
import torchvision
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchmetrics.classification import Accuracy
from torchvision import transforms

In [52]:
def get_n_params(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

def assert_tensors_equal(t1, t2):
    a1, a2, = t1.detach().numpy(), t2.detach().numpy()
    np.testing.assert_allclose(a1, a2)

In [53]:
model_name = "vit_base_patch16_384"
model_official = create_model(model_name, pretrained=True)
model_official.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [54]:
custom_config = {
        "img_size": 384,
        "in_chans": 3,
        "patch_size": 16,
        "embed_dim": 768,
        "depth": 12,
        "n_heads": 12,
        "qkv_bias": True,
        "mlp_ratio": 4,
        "n_classes": 1000
}

In [55]:
model_custom = VisionTransformer(**custom_config)
model_custom.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (head): Linear(in_features=768, out_features=1000, bi

In [56]:
for (n_o, p_o), (n_c, p_c) in zip(model_official.named_parameters(), model_custom.named_parameters()):
    assert p_o.numel() == p_c.numel()
    print(f"{n_o} | {n_c}")
    p_c.data[:] = p_o.data
    assert_tensors_equal(p_c.data, p_o.data)

cls_token | cls_token
pos_embed | pos_embed
patch_embed.proj.weight | patch_embed.proj.weight
patch_embed.proj.bias | patch_embed.proj.bias
blocks.0.norm1.weight | blocks.0.norm1.weight
blocks.0.norm1.bias | blocks.0.norm1.bias
blocks.0.attn.qkv.weight | blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias | blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight | blocks.0.attn.proj.weight
blocks.0.attn.proj.bias | blocks.0.attn.proj.bias
blocks.0.norm2.weight | blocks.0.norm2.weight
blocks.0.norm2.bias | blocks.0.norm2.bias
blocks.0.mlp.fc1.weight | blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias | blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight | blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias | blocks.0.mlp.fc2.bias
blocks.1.norm1.weight | blocks.1.norm1.weight
blocks.1.norm1.bias | blocks.1.norm1.bias
blocks.1.attn.qkv.weight | blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias | blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight | blocks.1.attn.proj.weight
blocks.1.attn.proj.bias | blocks.1.attn.proj.b

In [57]:
for param in model_custom.parameters():
    param.requires_grad = False
    
model_custom.head = nn.Linear(custom_config['embed_dim'], 200)

In [58]:
model_custom.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (head): Linear(in_features=768, out_features=200, bia

In [93]:
compose = transforms.Compose([
transforms.Resize(size=(384, 384), antialias=True),
transforms.ToTensor()
])

In [85]:
train_dataset = torchvision.datasets.ImageFolder(root='/Users/ykamoji/Documents/ImageDatabase/imageNet/tiny-imagenet-200/train', transform=compose)
test_dataset = torchvision.datasets.ImageFolder(root='/Users/ykamoji/Documents/ImageDatabase/imageNet/tiny-imagenet-200/val', transform=compose)

In [86]:
train_labels = {}
for mapping in open('/Users/ykamoji/Documents/ImageDatabase/imageNet/tiny-imagenet-200/words.txt','r').readlines():
    maps = mapping.replace('\n','').split('\t')
    train_labels[maps[0]] = maps[1]

def train_map_labels(dataset):
    class_labels = {}
    for ind in range(0, 100000, 500):
        class_name = dataset.imgs[ind][0].split('/')[-1].split('_')[0]
        class_label = dataset[ind][1]
        class_labels[class_label] = class_name
    return class_labels

# for ind, (k,v) in enumerate(train_labels.items()):
#     print(k,v)
#     if ind > 10:
#         break

train_class_labels = train_map_labels(train_dataset)

for ind, (k, v) in enumerate(train_class_labels.items()):
    print(k, v, train_labels[v])
    if ind == 20:
        break

0 n01443537 goldfish, Carassius auratus
1 n01629819 European fire salamander, Salamandra salamandra
2 n01641577 bullfrog, Rana catesbeiana
3 n01644900 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
4 n01698640 American alligator, Alligator mississipiensis
5 n01742172 boa constrictor, Constrictor constrictor
6 n01768244 trilobite
7 n01770393 scorpion
8 n01774384 black widow, Latrodectus mactans
9 n01774750 tarantula
10 n01784675 centipede
11 n01855672 goose
12 n01882714 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
13 n01910747 jellyfish
14 n01917289 brain coral
15 n01944390 snail
16 n01945685 slug
17 n01950731 sea slug, nudibranch
18 n01983481 American lobster, Northern lobster, Maine lobster, Homarus americanus
19 n01984695 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
20 n02002724 black stork, Ciconia nigra


In [87]:
for i in range(0, 10000, 50):
    class_label = test_dataset.imgs[i][0].split('/')[-3]
    print(test_dataset[i][1], class_label, train_labels[class_label])
    if i == 1000:
        break

0 n01443537 goldfish, Carassius auratus
1 n01629819 European fire salamander, Salamandra salamandra
2 n01641577 bullfrog, Rana catesbeiana
3 n01644900 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
4 n01698640 American alligator, Alligator mississipiensis
5 n01742172 boa constrictor, Constrictor constrictor
6 n01768244 trilobite
7 n01770393 scorpion
8 n01774384 black widow, Latrodectus mactans
9 n01774750 tarantula
10 n01784675 centipede
11 n01855672 goose
12 n01882714 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
13 n01910747 jellyfish
14 n01917289 brain coral
15 n01944390 snail
16 n01945685 slug
17 n01950731 sea slug, nudibranch
18 n01983481 American lobster, Northern lobster, Maine lobster, Homarus americanus
19 n01984695 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
20 n02002724 black stork, Ciconia nigra


In [94]:
batch_size = 300
lr = 0.001
class VisionTransformerWrapper(L.LightningModule):
    
    def __init__(self, model):
        super(VisionTransformerWrapper, self).__init__()
        self.model = model
        self.accuracy = Accuracy(task="multiclass", num_classes=200)
        
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, log_name="train")
    
    def validation_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, log_name="val")
        
    def _common_step(self, batch, batch_idx, log_name=""):
        images, labels = batch
        outputs = self(images)
        # print(outputs)
        # print(labels)
        
        loss = F.cross_entropy(outputs, labels)
        self.log(f"{log_name}_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        self.accuracy(outputs, labels)
        
        self.log(f"{log_name}_acc", self.accuracy, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss
        
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=lr)
        
    def train_dataloader(self):
        train_dataset = torchvision.datasets.ImageFolder(root='/Users/ykamoji/Documents/ImageDatabase/imageNet/tiny-imagenet-200/train', transform=compose)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=11, persistent_workers=True)
        return train_loader

    def val_dataloader(self):
        test_dataset = torchvision.datasets.ImageFolder(root='/Users/ykamoji/Documents/ImageDatabase/imageNet/tiny-imagenet-200/val', transform=compose)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=11, persistent_workers=True)
        return test_loader

In [95]:
model_custom.train()
trainer = Trainer(max_epochs=1, fast_dev_run=False, 
                  log_every_n_steps=6, 
                  val_check_interval=0.01, 
                  limit_val_batches=0.05,
                  max_steps=50)
model = VisionTransformerWrapper(model_custom)
trainer.fit(model)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name     | Type               | Params
------------------------------------------------
0 | model    | VisionTransformer  | 86.2 M
1 | accuracy | MulticlassAccuracy | 0     
------------------------------------------------
153 K     Trainable params
86.1 M    Non-trainable params
86.2 M    Total params
344.977   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

RuntimeError: MPS backend out of memory (MPS allocated: 13.82 GB, other allocations: 19.51 GB, max allowed: 36.27 GB). Tried to allocate 4.46 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).