In [None]:
import numpy as np
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import lightning as L
from lightning import Trainer
import torchvision
from torch.utils.data import DataLoader, Subset
from PIL import Image
from vit_models import VisionTransformer
import torch.nn.functional as F
from torchmetrics.classification import Accuracy

In [None]:
# Training hyperparameters: mini-batch size for the data loaders and the
# learning rate handed to the Adam optimizer.
batch_size, lr = 200, 1e-3

In [None]:
# Per-sample preprocessing pipeline: currently only converts the image to a tensor.
# Disabled option kept for reference:
#   transforms.Resize(size=(384, 384), antialias=True)
compose = transforms.Compose([transforms.ToTensor()])
# plt.imshow(torch.permute(resize(train_dataset[index][0]), (1,2,0)).numpy())

In [None]:
# Training split read from disk; every sample goes through `compose`.
train_dataset = torchvision.datasets.ImageFolder(
    transform=compose,
    root='/Users/ykamoji/Documents/ImageDatabase/imageNet/tiny-imagenet-200/train',
)

# Shuffled mini-batches served by 11 worker processes that stay alive
# between epochs (persistent_workers=True).
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=11,
    persistent_workers=True,
)

In [None]:
# Lookup table built from words.txt: first tab-separated field -> second field
# (used below as class id -> human-readable name).
labels = {}
# The context manager closes the file handle; the original left it open.
with open('/Users/ykamoji/Documents/ImageDatabase/imageNet/tiny-imagenet-200/words.txt', 'r') as words_file:
    for mapping in words_file:
        maps = mapping.rstrip('\n').split('\t')
        if len(maps) < 2:
            # Blank or malformed line (no tab): skip instead of raising IndexError.
            continue
        labels[maps[0]] = maps[1]

def map_labels(dataset):
    """Map each class index of an ImageFolder-style dataset to its class id.

    The class id is the part of the image file name before the first
    underscore (``<id>_<n>.<ext>``), taken from the first image listed for
    each class in ``dataset.imgs``.

    Unlike the previous version this does not assume exactly 200 classes of
    500 images each, and it never decodes an image: the class index is read
    from the ``(path, class_index)`` pairs in ``dataset.imgs``.

    Args:
        dataset: object exposing ``imgs``, an iterable of
            ``(path, class_index)`` pairs (as torchvision's ImageFolder does).
            NOTE(review): a ``target_transform`` on the dataset is not applied
            here, because the indices come straight from ``imgs`` — confirm
            that is acceptable if one is ever configured.

    Returns:
        dict mapping class index -> class id string, in order of first
        appearance in ``dataset.imgs``. Empty if the dataset has no images.
    """
    class_labels = {}
    for path, class_label in dataset.imgs:
        if class_label in class_labels:
            # Only the first image of each class is needed.
            continue
        # Accept both '/' and '\\' separators when extracting the file name.
        file_name = path.replace('\\', '/').rsplit('/', 1)[-1]
        class_labels[class_label] = file_name.split('_')[0]
    return class_labels

# for ind, (k,v) in enumerate(labels.items()):
#     print(k,v)
#     if ind > 10:
#         break

# Class index -> class id (file-name prefix) for the training set.
class_labels = map_labels(train_dataset)

# Preview the first 21 entries: class index, class id, readable name.
for position, (class_index, class_id) in enumerate(class_labels.items()):
    if position > 20:
        break
    print(class_index, class_id, labels[class_id])

In [None]:
# Show one training image together with its label in all three forms:
# numeric class index, class id from the file name, and readable name.
index = 4600
# Fetch the sample once; the original indexed the dataset twice and so
# loaded and transformed the same image twice.
image, label_num = train_dataset[index]
label_class = train_dataset.imgs[index][0].split('/')[-1].split('_')[0]
# Move the channel axis last, which is the layout imshow expects.
plt.imshow(torch.permute(image, (1, 2, 0)).numpy())
plt.title(labels[class_labels[label_num]])
print(label_num, label_class, labels[class_labels[label_num]])

In [None]:
# Keyword arguments forwarded to the VisionTransformer constructor.
custom_config = dict(
    img_size=64,
    in_chans=3,
    patch_size=16,
    embed_dim=768,
    depth=12,
    n_heads=12,
    qkv_bias=True,
    mlp_ratio=4,
    n_classes=200,
)

In [None]:
class VisionTransformerWrapper(L.LightningModule):
    """Lightning training wrapper around a classifier built as ``model(**kvargs)``.

    Training minimises cross-entropy between the network outputs and the
    integer class labels, and logs the loss and multiclass accuracy.
    """

    def __init__(self, model, **kvargs):
        """Build the wrapped network and its accuracy metric.

        Args:
            model: callable (e.g. a model class) invoked as ``model(**kvargs)``.
            **kvargs: forwarded unchanged to ``model``. When ``n_classes`` is
                given it also sizes the accuracy metric, so the metric cannot
                drift from the model configuration; otherwise the previous
                hard-coded value of 200 is used.
        """
        super().__init__()
        self.model = model(**kvargs)
        self.accuracy = Accuracy(task="multiclass", num_classes=kvargs.get("n_classes", 200))

    def forward(self, x):
        """Return the wrapped model's output for the input batch ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the loss and accuracy for one training batch.

        Args:
            batch: ``(images, labels)`` pair as produced by the train loader.
            batch_idx: index of the batch (unused).

        Returns:
            The cross-entropy loss tensor for this batch.
        """
        images, labels = batch
        outputs = self(images)
        loss = F.cross_entropy(outputs, labels)
        self.log("loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        # Update the metric state with this batch, then log the metric object.
        self.accuracy(outputs, labels)
        self.log('train_acc_step', self.accuracy)

        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using the notebook-level ``lr``."""
        return torch.optim.Adam(self.parameters(), lr=lr)

    def train_dataloader(self):
        """Build the training DataLoader.

        Reads the notebook-level ``compose`` transform and ``batch_size``.
        """
        train_dataset = torchvision.datasets.ImageFolder(root='/Users/ykamoji/Documents/ImageDatabase/imageNet/tiny-imagenet-200/train', transform=compose)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=11, persistent_workers=True)
        return train_loader

In [None]:
# Single-device training run on Apple's MPS backend.
# NOTE(review): per the Lightning docs, fast_dev_run=True limits the run to a
# single batch as a smoke test, so max_epochs has no effect while it is set —
# confirm this is intended before relying on the predictions below.
trainer = Trainer(
    accelerator="mps",
    devices=1,
    max_epochs=5,
    fast_dev_run=True,
)
model = VisionTransformerWrapper(VisionTransformer, **custom_config)
trainer.fit(model)

In [None]:
# Put the trained wrapper into evaluation mode before running predictions.
model.eval()

In [None]:
# Sanity-check inference: classify the first training image and print the
# five most probable classes with their softmax probabilities.
img = train_dataset[0][0]
print(img.shape)

# Gradients are not needed for prediction, so skip building the autograd graph.
with torch.no_grad():
    # Add a batch dimension and place the input on the same device as the model.
    batch = img.unsqueeze(0).to(device=model.device, dtype=torch.float32)
    logits = model(batch)
    probs = F.softmax(logits, dim=-1)

top_probs, top_ixs = probs[0].topk(5)
for i, (ix_, prob_) in enumerate(zip(top_ixs, top_probs)):
    ix = ix_.item()
    prob = prob_.item()
    # class index -> class id -> human-readable name
    cls = labels[class_labels[ix]]
    print(f"{i}: {cls:<45} --- {prob:.4f}")

In [None]:
# Ask PyTorch to release cached memory held by the MPS backend.
torch.mps.empty_cache()