In [1]:
import warnings
warnings.filterwarnings("ignore")

import torch
import numpy as np
import torchvision
import matplotlib.pyplot as plt
import multiprocessing
import torchvision
import timm
import os
from transformers import *
from datasets import load_dataset
from PIL import Image
from torchinfo import summary
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import time

from torch.amp import autocast
from torch.cuda.amp import GradScaler

# Ask the CUDA caching allocator to hand back any cached memory it is not using.
torch.cuda.empty_cache()

# Device string used for the model and for every batch: the GPU when one is
# available, otherwise the CPU.
if torch.cuda.is_available():
    device="cuda"
else:
    device="cpu"

    PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.0.1+cu117)
    Python  3.9.18 (you have 3.9.17)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


### Loading the data and the pretrained model

In [2]:
# List the ViT variants for which timm offers pretrained weights; each entry has
# the form "<architecture>.<pretrained_tag>", which is the format `model_name` must use.
timm.list_models("vit*",pretrained=True)

['vit_base_patch8_224.augreg2_in21k_ft_in1k',
 'vit_base_patch8_224.augreg_in21k',
 'vit_base_patch8_224.augreg_in21k_ft_in1k',
 'vit_base_patch8_224.dino',
 'vit_base_patch14_dinov2.lvd142m',
 'vit_base_patch16_224.augreg2_in21k_ft_in1k',
 'vit_base_patch16_224.augreg_in1k',
 'vit_base_patch16_224.augreg_in21k',
 'vit_base_patch16_224.augreg_in21k_ft_in1k',
 'vit_base_patch16_224.dino',
 'vit_base_patch16_224.mae',
 'vit_base_patch16_224.orig_in21k_ft_in1k',
 'vit_base_patch16_224.sam_in1k',
 'vit_base_patch16_224_miil.in21k',
 'vit_base_patch16_224_miil.in21k_ft_in1k',
 'vit_base_patch16_384.augreg_in1k',
 'vit_base_patch16_384.augreg_in21k_ft_in1k',
 'vit_base_patch16_384.orig_in21k_ft_in1k',
 'vit_base_patch16_clip_224.datacompxl',
 'vit_base_patch16_clip_224.laion2b',
 'vit_base_patch16_clip_224.laion2b_ft_in1k',
 'vit_base_patch16_clip_224.laion2b_ft_in12k',
 'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k',
 'vit_base_patch16_clip_224.openai',
 'vit_base_patch16_clip_224.openai

In [3]:
batch_size=32
# One DataLoader worker process per logical CPU core.
cpu_count=multiprocessing.cpu_count()

# NOTE(review): the recorded run of this cell stopped with
# "RuntimeError: Invalid pretrained tag" for this name. The tag has to be one of the
# "<architecture>.<tag>" entries that timm.list_models(..., pretrained=True) reports
# for the installed timm version - confirm before re-running.
model_name="vit_base_patch16_siglip_512.webli"

model = timm.create_model(model_name, pretrained=True)

# Build the preprocessing pipeline from the model's own data configuration so the
# inputs match what the pretrained weights expect (evaluation-mode transform, i.e.
# is_training=False, is used for both splits).
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

# Plain resize + ToTensor pipeline. NOTE(review): not referenced by any other
# visible cell - kept in case a later cell relies on it.
transform_data=torchvision.transforms.Compose([torchvision.transforms.Resize(size=(224,224)),
                                                torchvision.transforms.ToTensor()])

# ImageFolder derives the class labels from the sub-directory names.
train_data=torchvision.datasets.ImageFolder("./data/train",transform=transforms)
test_data=torchvision.datasets.ImageFolder("./data/test",transform=transforms)

train_loader=torch.utils.data.DataLoader(train_data,shuffle=True,batch_size=batch_size,num_workers=cpu_count)
# Evaluation data is read in a fixed order: shuffling it does not change the
# metrics and only makes per-batch results irreproducible between runs.
test_loader=torch.utils.data.DataLoader(test_data,shuffle=False,batch_size=batch_size,num_workers=cpu_count)

RuntimeError: Invalid pretrained tag (webli') for vit_base_patch16_siglip_512.

##### Freezing layers

In [None]:
# Freeze the first 120 parameter tensors yielded by model.parameters()
# (the author counted 153 in total); everything after them stays trainable.
# `count` ends up holding the number of tensors that were frozen.
count=0
for position,param in enumerate(model.parameters()):
    if position>=120:
        break
    param.requires_grad=False
    count=position+1

##### Model Modification

In [None]:
# Replace the classification head with a new linear layer that keeps the head's
# input width and emits 1024 values per sample.
head_in_features=model.head.in_features
model.head=nn.Linear(in_features=head_in_features,out_features=1024)

class modified_vit(nn.Module):
    """Wrap a backbone and append an MLP head that ends in 2 outputs.

    The first linear layer has ``in_features=1024``, so the wrapped ``model``
    must emit 1024 values per sample. The head then maps
    1024 -> 2048 -> 512 -> 256 -> 128 with a ReLU after each of those layers,
    followed by a final 128 -> 2 linear layer without activation.
    """

    def __init__(self,model):
        super().__init__()
        self.model=model

        # Widths of the hidden stack; consecutive pairs become Linear+ReLU stages.
        widths=(1024,2048,512,256,128)
        stages=[]
        for fan_in,fan_out in zip(widths[:-1],widths[1:]):
            stages.append(nn.Linear(in_features=fan_in,out_features=fan_out))
            stages.append(nn.ReLU())
        # Output layer: raw scores for the 2 classes.
        stages.append(nn.Linear(in_features=widths[-1],out_features=2))
        self.sequential=nn.Sequential(*stages)

    def forward(self,x):
        """Run ``x`` through the backbone, then through the MLP head."""
        features=self.model(x)
        return self.sequential(features)

In [None]:
# NOTE(review): the modified_vit wrapper below is left disabled, so what gets trained
# is the backbone whose head was replaced by a 1024-output linear layer (i.e. the
# loss sees 1024 logits per sample) - confirm this is intended.
#model=modified_vit(model)
# Place the model on the device selected at the top of the notebook.
model=model.to(device)

In [None]:
# Show torchinfo's summary of the model.
summary(model)

In [None]:
# Display the model's repr (its module tree) as the cell output.
model

### Visualizing Transformed Data

In [None]:
# Show a 3x3 grid of randomly chosen, already-transformed training images.
#
# A single batch is fetched up front: every `iter(train_loader)` call builds a
# fresh iterator (and, with num_workers > 0, a fresh set of worker processes),
# so fetching inside the loop repeated that start-up cost for each panel.
images,labels=next(iter(train_loader))

plt.figure(figsize=(9,9))
for i in range(1,10):
    # `high` is exclusive in torch.randint, so len(images) makes every sample of
    # the batch eligible and stays in range even if the batch is shorter than batch_size.
    rand_ind=torch.randint(0,len(images),size=(1,)).item()
    label=labels[rand_ind]
    # imshow expects channels last, the loader yields channels first.
    image=images[rand_ind].permute(1,2,0)
    # Select the panel first; title/axis/imshow all act on the current axes,
    # so titling before this call would label the previous panel.
    plt.subplot(3,3,i)
    plt.title(train_data.classes[label.item()])
    plt.axis(False)
    plt.imshow(image)
    

### Optimizer, Loss, and Log Directory

In [None]:
# AdamW with its default hyperparameters, handed every parameter of the model
# (parameters whose requires_grad was switched off get no gradient to apply).
optimizer=torch.optim.AdamW(model.parameters())
# Cross-entropy loss, applied to the model's raw outputs and the class labels.
loss_fn=nn.CrossEntropyLoss()

In [None]:
# TensorBoard run directory, named after the model.
# NOTE(review): there is no separator between the model name and the
# "feature_extractor_tensorboard" suffix, so the two run together in the folder
# name - confirm whether an underscore was intended before changing the path.
log_dir=f"./{model_name}feature_extractor_tensorboard"
writer=SummaryWriter(log_dir=log_dir)

def accuracy_fn(logits,true):
    """Return the fraction of samples whose top-scoring class matches the label.

    Args:
        logits: tensor of class scores with the classes along dimension 1,
            e.g. shape (batch, num_classes).
        true: tensor of target class indices, one per sample; both (batch,)
            and (batch, 1) layouts are accepted.

    Returns:
        float: number of correct predictions divided by ``len(logits)``.

    Raises:
        ZeroDivisionError: if ``logits`` is empty.
    """
    # argmax is taken on the raw scores: softmax is monotonic, so it cannot change
    # the winner, but its rounding can collapse near-equal scores into a tie.
    predicted=torch.argmax(logits,dim=1).reshape(-1)
    # Flatten the labels too, so a (batch, 1) column is compared element-wise
    # instead of being broadcast into a batch x batch comparison.
    targets=true.reshape(-1)
    return torch.eq(predicted,targets).sum().item()/len(logits)
    

### Testing the model on a single batch

In [None]:
# Smoke test: push one training batch through the model without tracking
# gradients and print the raw outputs, the loss and the accuracy.
model.eval()
with torch.inference_mode():
    x,y=next(iter(train_loader))
    x=x.to(device)
    y=y.to(device)
    logits=model(x)
    print(logits)
    # The logits go to the loss as-is: squeezing them would drop the batch
    # dimension whenever the batch holds a single sample and make the loss fail.
    print(loss_fn(logits,y))
    print(accuracy_fn(logits,y))

### Training Loop

In [None]:
epochs=10

# Per-epoch history of the metrics; one value is appended to each list per epoch.
train_accuracy=[]
test_accuracy=[]
train_loss=[]
test_loss=[]

# Checkpointing stays off by default (it was disabled in this notebook);
# set to True to store the weights after every epoch.
save_checkpoints=False

for i in tqdm(range(epochs)):
    print("Training:")
    model.train()

    # Running totals for this epoch. They are weighted by the number of samples
    # in each batch so that a shorter final batch does not skew the epoch mean.
    train_correct=0.0
    train_loss_total=0.0
    train_samples=0
    test_correct=0.0
    test_loss_total=0.0
    test_samples=0

    # The `with` block closes the progress bar on exit.
    with tqdm(total=len(train_loader)) as pbar:
        for x,y in train_loader:
            x=x.to(device)
            y=y.to(device)
            batch_len=y.size(0)

            #Resetting any old gradient values
            optimizer.zero_grad()

            #Calculating model output
            logits=model(x)
            # No squeeze on the logits: it would remove the batch dimension of a
            # single-sample batch and the loss would reject the shapes.
            loss=loss_fn(logits,y)

            #Track of metrics (loss_fn returns the batch mean, hence the re-weighting)
            train_correct+=accuracy_fn(logits.type(torch.float32),y)*batch_len
            train_loss_total+=loss.item()*batch_len
            train_samples+=batch_len

            #Back Propagation
            loss.backward()

            #Update Parameters
            optimizer.step()

            #Progress Bar Update
            pbar.update(1)

    #Tensorboard & Metrics for the dataset
    net_train_accuracy=train_correct/train_samples
    net_train_loss=train_loss_total/train_samples
    train_accuracy.append(net_train_accuracy)
    train_loss.append(net_train_loss)
    writer.add_scalar("Train Accuracy",net_train_accuracy,i)
    writer.add_scalar("Train Loss",net_train_loss,i)

    #Evaluation
    print("Testing:")
    model.eval()

    with tqdm(total=len(test_loader)) as pbar2:
        for x,y in test_loader:
            x=x.to(device)
            y=y.to(device)
            batch_len=y.size(0)

            #Setting inference mode
            with torch.inference_mode():
                logits=model(x)
                loss=loss_fn(logits.type(torch.float32),y)

                #Track of metrics
                test_correct+=accuracy_fn(logits,y)*batch_len
                test_loss_total+=loss.item()*batch_len
                test_samples+=batch_len

            #Progress Bar Update
            pbar2.update(1)

    #Tensorboard & Metrics for the dataset
    net_test_accuracy=test_correct/test_samples
    net_test_loss=test_loss_total/test_samples
    test_accuracy.append(net_test_accuracy)
    test_loss.append(net_test_loss)
    writer.add_scalar("Test Accuracy",net_test_accuracy,i)
    writer.add_scalar("Test Loss",net_test_loss,i)
    # Push the scalars to disk now so TensorBoard shows the epoch immediately
    # and nothing is lost if the kernel is interrupted.
    writer.flush()

    #Saving the model (opt-in, see save_checkpoints above)
    if save_checkpoints:
        checkpoint_dir=f"./{model_name}_feature_extractor/"
        # exist_ok avoids failing on the second epoch without hiding other errors.
        os.makedirs(checkpoint_dir,exist_ok=True)
        torch.save(model.state_dict(),f"{checkpoint_dir}checkpoint-{i+1}.pth")

    print(f"Epoch {i+1}:\nTrain Accuracy: {net_train_accuracy}  Train Loss: {net_train_loss}  Test Accuracy: {net_test_accuracy}  Test Loss: {net_test_loss}")
    print("\n")