# **ViT fine-tuning**

## **Download datasets**

In [None]:
# Download the COCO 2014 train/val image archives (-nc: do not re-download files that already exist).
!wget -nc http://images.cocodataset.org/zips/train2014.zip
!wget -nc http://images.cocodataset.org/zips/val2014.zip

In [None]:
# Unzip all files
!unzip -q train2014.zip
!unzip -q val2014.zip

# NOTE: deletes every .zip in the working directory, not only the two above.
!rm *.zip

In [None]:
# VQA v2 zip files: questions and annotations for the train and val splits
# (-nc: do not re-download files that already exist).
!wget -nc https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip
!wget -nc https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip
!wget -nc https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip
!wget -nc https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip

In [None]:
# Extract the VQA v2 question/annotation archives (the JSON files read later live in /content).
!unzip -q v2_Questions_Train_mscoco.zip
!unzip -q v2_Questions_Val_mscoco.zip
!unzip -q v2_Annotations_Train_mscoco.zip
!unzip -q v2_Annotations_Val_mscoco.zip

# NOTE: deletes every .zip in the working directory.
!rm *.zip

## **Import libraries**

In [None]:
# standard libs
import os
import io
import glob
import time
import json
import re
import pickle
from glob import glob
import seaborn as sns

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from tqdm import tqdm
from PIL import Image

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# TorchVision
import torchvision
from torchvision import transforms, utils, models
from torchvision.models import vit_b_16, ViT_B_16_Weights
from transformers import CLIPModel, CLIPProcessor

## **Extract from datasets**

In [None]:
# connect to Drive
# Mounts Google Drive at /content/drive; later cells write checkpoints,
# pickled question lists and features under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# images
images_train_folder = '/content/train2014'
images_val_folder = '/content/val2014'

# glob() returns files in arbitrary (filesystem-dependent) order. Sort so the
# prefix subset taken in the next cell is the same on every run/machine.
train_images = sorted(glob(os.path.join(images_train_folder, '*.jpg')))
val_images = sorted(glob(os.path.join(images_val_folder, '*.jpg')))

print(len(train_images))
print(len(val_images))

In [None]:
# Shrink the working set: keep only the first 1/80th of each image list.
train_images = train_images[: len(train_images) // 80]
val_images = val_images[: len(val_images) // 80]

print(len(train_images))
print(len(val_images))

In [None]:
# File locations for VQA v2 questions/annotations and COCO image folders.

# train split
questions_train_json = '/content/v2_OpenEnded_mscoco_train2014_questions.json'
annotations_train_json = '/content/v2_mscoco_train2014_annotations.json'
images_train_folder = '/content/train2014'

# val split
questions_val_json = '/content/v2_OpenEnded_mscoco_val2014_questions.json'
annotations_val_json = '/content/v2_mscoco_val2014_annotations.json'
images_val_folder = '/content/val2014'

In [None]:
# Load the VQA JSON files. Each parses to a dict; the record lists live under
# its 'questions' / 'annotations' keys.
def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, encoding='utf-8') as fh:
        return json.load(fh)


train_questions_list = _load_json(questions_train_json)
val_questions_list = _load_json(questions_val_json)
train_annotations_list = _load_json(annotations_train_json)
val_annotations_list = _load_json(annotations_val_json)

In [None]:
# extract image_id from filename
def extract_image_id(filename):
    """Return the integer COCO image id encoded in an image filename.

    The id is the text after the last underscore and before the first dot,
    e.g. '.../COCO_train2014_000000123456.jpg' -> 123456.
    """
    tail = os.path.basename(filename).rpartition('_')[2]  # '000000123456.jpg'
    digits, _, _ = tail.partition('.')                      # '000000123456'
    return int(digits)

In [None]:
# image IDs: set of COCO ids for the image files kept above
train_image_ids = {extract_image_id(path) for path in train_images}
val_image_ids = {extract_image_id(path) for path in val_images}

In [None]:
# create lists: attach answer fields from the annotations to each question record
train_list = train_questions_list['questions']
val_list = val_questions_list['questions']


def _attach_annotations(questions, annotations):
    """Copy answer fields from *annotations* onto the matching *questions* in place.

    Records are matched on 'question_id' rather than list position, so the merge
    stays correct even if the two JSON files are not in the same order.
    Raises KeyError if a question has no annotation.
    """
    by_qid = {ann['question_id']: ann for ann in annotations}
    for q in questions:
        ann = by_qid[q['question_id']]
        q['multiple_choice_answer'] = ann['multiple_choice_answer']
        q['answers'] = ann['answers']
        q['answer_type'] = ann['answer_type']


_attach_annotations(train_list, train_annotations_list['annotations'])
_attach_annotations(val_list, val_annotations_list['annotations'])

In [None]:
# Keep only questions whose image belongs to the reduced (1/80) image subset.
filtered_train_list = [record for record in train_list if record['image_id'] in train_image_ids]
filtered_val_list = [record for record in val_list if record['image_id'] in val_image_ids]

print("Filtered train questions:", len(filtered_train_list))
print("Filtered val questions:", len(filtered_val_list))

In [None]:
# Check whether every kept train image has at least one question.
filtered_train_image_ids = {record['image_id'] for record in filtered_train_list}

print("Train image IDs from images:", len(train_image_ids))
print("Train image IDs from filtered questions:", len(filtered_train_image_ids))
print("Are train image ID sets equal?", filtered_train_image_ids == train_image_ids)

In [None]:
# Output folder on Drive for the checkpoint, pickled lists and extracted features.
drive_path, new_folder = '/content/drive/MyDrive', 'ViT_Embeddings'
full_path = os.path.join(drive_path, new_folder)

if not os.path.exists(full_path):
    os.makedirs(full_path)
    print(f"Folder '{new_folder}' created at {full_path}")


## **ViT implementation**

### **Create dataset**

In [None]:
# DATASET
class CustomDataset(Dataset):
    """Dataset over image file paths; each item is the RGB image, optionally transformed.

    Args:
        list_images: sequence of image file paths.
        transform: optional callable applied to the PIL image; its return value
            (e.g. a list of augmented views) is what the item becomes.
    """

    def __init__(self, list_images, transform=None):
        self.list_images = list_images
        self.transform = transform

    def __len__(self):
        return len(self.list_images)

    def __getitem__(self, idx):
        with Image.open(self.list_images[idx]) as raw:
            rgb = raw.convert('RGB')
        return self.transform(rgb) if self.transform else rgb

### **Define transformations for images**

In [None]:
# Source: https://medium.com/@prabowoyogawicaksana/self-supervised-pre-training-with-simclr-79830997be34

# Image paths used for contrastive training (the .jpg lists collected above).
train_image_paths, val_image_paths = train_images, val_images

# ------------ TRANSFORMS ------------
def get_complete_transform(output_shape, s=1.0):
    """Build the SimCLR-style augmentation pipeline applied to one view.

    Args:
        output_shape: size passed to RandomResizedCrop (e.g. 224).
        s: colour-jitter strength multiplier.

    Returns:
        A transforms.Compose that maps a PIL image to a normalised tensor.
    """
    jitter = transforms.ColorJitter(0.4 * s, 0.4 * s, 0.4 * s, 0.1 * s)
    blur = transforms.GaussianBlur(kernel_size=23)
    return transforms.Compose([
        transforms.RandomResizedCrop(output_shape, scale=(0.3, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([blur], p=0.5),
        transforms.ToTensor(),
        # CLIP per-channel mean/std
        transforms.Normalize(
            mean=[0.4815, 0.4578, 0.4082],
            std=[0.2686, 0.2613, 0.2758],
        ),
    ])

class ContrastiveLearningViewGenerator:
    """Callable returning `n_views` independently transformed copies of one input.

    `base_transform` is applied once per view, so a random transform yields
    different augmented views of the same image.
    """

    def __init__(self, base_transform, n_views=2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x):
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transform(x))
        return views


### **Loss function**

In [None]:
# LOSS FUNCTION

def ntxent_loss(features, temp=0.5):
    """Build NT-Xent (SimCLR) logits and targets for a batch of paired projections.

    Rows i and i + N/2 of `features` are treated as the two views of the same
    image. For each row, the logits are its similarity to its positive
    (column 0) followed by its similarities to every other non-self row. All
    logits are divided by `temp`, so cross-entropy against target 0 gives the
    NT-Xent loss.

    Args:
        features: (N, D) tensor with N even. Assumed L2-normalised, so the dot
            product is the cosine similarity.
        temp: temperature.

    Returns:
        (logits, labels): logits of shape (N, N - 1) and an int64 tensor of N zeros.

    Raises:
        ValueError: if N is odd.
    """
    batch_size = features.shape[0]
    if batch_size % 2 != 0:
        raise ValueError("Batch size should be even for SimCLR.")

    # Work on whatever device the features live on (no dependency on a global DEVICE).
    device = features.device
    half = batch_size // 2

    # pos_mask[i, j] is True when rows i and j come from the same image.
    pair_ids = torch.arange(half, device=device).repeat(2)
    pos_mask = pair_ids.unsqueeze(0) == pair_ids.unsqueeze(1)

    similarity_matrix = features @ features.T

    # Drop the diagonal (self-similarity) so each row keeps N - 1 entries.
    off_diag = ~torch.eye(batch_size, dtype=torch.bool, device=device)
    pos_mask = pos_mask[off_diag].view(batch_size, -1)
    similarity_matrix = similarity_matrix[off_diag].view(batch_size, -1)

    positives = similarity_matrix[pos_mask].view(batch_size, -1)
    negatives = similarity_matrix[~pos_mask].view(batch_size, -1)

    logits = torch.cat([positives, negatives], dim=1) / temp
    # The positive is always column 0.
    labels = torch.zeros(batch_size, dtype=torch.long, device=device)
    return logits, labels


### **ViT model**

In [None]:
class Identity(nn.Module):
    """No-op module: forward returns its input unchanged."""

    def __init__(self):
        # Was misspelled `_init_`, which left a dead method that raised
        # AttributeError when called.
        super().__init__()

    def forward(self, x):
        return x

class SimCLR(nn.Module):
    """SimCLR model: CLIP ViT-L/14 vision encoder, CLS token, MLP projection head."""

    def __init__(self, linear_eval=False, unfreeze_vit=True):
        """
        CLIP ViT-L/14 feature extractor with projection head.
        Args:
            linear_eval (bool): If True, forward() takes a single image tensor
                instead of a list of views to concatenate.
            unfreeze_vit (bool): If True, re-enables gradients for vision-model
                parameters whose names contain "encoder.layers.21",
                "encoder.layers.22" or "post_layernorm"; everything else stays frozen.
        """
        super(SimCLR, self).__init__()
        self.linear_eval = linear_eval

        # Load Hugging Face CLIP ViT-L/14
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_model.eval()

        # Freeze all by default
        for param in self.clip_model.parameters():
            param.requires_grad = False

        # Unfreeze selected vision layers.
        # NOTE(review): ViT-L/14 is usually 24 encoder layers (0-23); layer 23 is not
        # listed here, so the final block stays frozen. Confirm this is intended.
        # NOTE(review): forward() reads last_hidden_state. In HF's CLIP vision model,
        # post_layernorm appears to be applied only to pooler_output, so these params
        # may receive no gradient. Confirm.
        if unfreeze_vit:
            for name, param in self.clip_model.vision_model.named_parameters():
                if "encoder.layers.21" in name or \
                   "encoder.layers.22" in name or \
                   "post_layernorm" in name:
                    param.requires_grad = True

        # Extract the vision encoder
        self.encoder = self.clip_model.vision_model

        # Projection head: 1024-d CLS token -> 128-d embedding
        self.projection = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128)
        )

    def forward(self, x):
        """
        Args:
            x: list of views [view1_tensor, view2_tensor] if not linear_eval
               or single tensor if linear_eval
        Returns:
            L2-normalised projection, shape (N, 128). N is the total number of
            images across all views.
        """
        if not self.linear_eval:
            # Stack the views along the batch axis: (n_views * B, C, H, W)
            x = torch.cat(x, dim=0)

        # Forward through vision transformer; take the CLS token (position 0)
        outputs = self.encoder(pixel_values=x)
        cls_token = outputs.last_hidden_state[:, 0, :]

        projection = self.projection(cls_token)
        return F.normalize(projection, dim=1)

    def get_preprocess(self):
        """Return a deterministic eval-time image transform for this model.

        Steps: bicubic resize of the short side to 224, 224x224 center crop,
        ToTensor, then normalisation with the CLIP mean/std
        (0.4815, 0.4578, 0.4082) / (0.2686, 0.2613, 0.2758).
        The previous version returned ``self.preprocess``, which was never set
        and raised AttributeError.
        """
        return transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.4815, 0.4578, 0.4082],
                std=[0.2686, 0.2613, 0.2758]
            ),
        ])

### **ViT training**

In [None]:
# CONFIG

BATCH_SZ = 64
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

out_shape = 224
num_workers = 2
EPOCHS = 10
model_filename = 'ViT_def.pth'
best_val_loss = float('inf')  # initialize best loss

# Transforms and Dataset
# Each dataset item becomes a list of n_views (default 2) augmented views of one image.
base_transforms = get_complete_transform(output_shape=out_shape)
custom_transform = ContrastiveLearningViewGenerator(base_transform=base_transforms)

# NOTE(review): validation reuses the same random augmentations, so val loss/acc
# vary from run to run.
ds_train = CustomDataset(list_images=train_image_paths, transform=custom_transform)
ds_val = CustomDataset(list_images=val_image_paths, transform=custom_transform)

# drop_last=True keeps every training batch at exactly BATCH_SZ images.
# NOTE(review): with fewer than BATCH_SZ training images, train_dl is empty and the
# per-epoch averages below divide by zero.
train_dl = DataLoader(ds_train, batch_size=BATCH_SZ, num_workers=num_workers, shuffle=True, drop_last=True, pin_memory=True)
val_dl = DataLoader(ds_val, batch_size=BATCH_SZ, num_workers=num_workers, shuffle=False, drop_last=False, pin_memory=True)

# Model & Optimizer
model = SimCLR().to(DEVICE)
criterion = nn.CrossEntropyLoss()
# NOTE(review): all parameters are passed (frozen ones included) and no lr is given,
# so Adam's default learning rate is used. Consider filtering on requires_grad
# and setting an explicit lr.
optimizer = torch.optim.Adam(model.parameters())

# Metrics
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

# TRAINING LOOP
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    # TRAIN
    model.train()
    train_loss, train_correct, train_total = 0.0, 0, 0

    for batch_idx, views in enumerate(train_dl):
        # views is a list of batched view tensors; SimCLR.forward concatenates them.
        projections = model([view.to(DEVICE) for view in views])
        # ntxent_loss puts the positive in logit column 0, so every target is 0.
        logits, labels = ntxent_loss(projections)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # "Accuracy": fraction of rows where the positive has the highest logit.
        preds = torch.argmax(logits, dim=1)
        train_correct += (preds == labels).sum().item()
        train_total += labels.size(0)
        train_loss += loss.item()

        print(f"Train -> Batch {batch_idx+1}/{len(train_dl)} - Loss: {loss.item():.4f} - Acc: {(preds == labels).float().mean().item():.4f}")

    # Mean of per-batch losses
    avg_train_loss = train_loss / len(train_dl)
    train_acc = train_correct / train_total

    # VAL
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0

    with torch.no_grad():
        for batch_idx, views in enumerate(val_dl):
            projections = model([view.to(DEVICE) for view in views])
            logits, labels = ntxent_loss(projections)
            loss = criterion(logits, labels)

            preds = torch.argmax(logits, dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)
            val_loss += loss.item()

            print(f"Val -> Batch {batch_idx+1}/{len(val_dl)} - Loss: {loss.item():.4f} - Acc: {(preds == labels).float().mean().item():.4f}")

    # Mean of per-batch losses. The last val batch may be smaller (drop_last=False)
    # but is weighted equally.
    avg_val_loss = val_loss / len(val_dl)
    val_acc = val_correct / val_total

    # Store metrics
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)

    # Save Model with the lowest val loss (state_dict only, overwritten on each improvement)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), os.path.join(full_path, model_filename))

    # Summary
    print(f"\nEpoch {epoch+1}:")
    print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val   Loss: {avg_val_loss:.4f} | Val   Acc: {val_acc:.4f}")

### **Plot results**

In [None]:
def _plot_metric_curves(train_values, val_values, train_label, val_label, ylabel, title, filename):
    """Draw per-epoch train/validation curves, save them to *filename*, and show the figure."""
    plt.figure(figsize=(10, 4))
    plt.plot(train_values, label=train_label, marker='o')
    plt.plot(val_values, label=val_label, marker='o')
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(filename)  # Save plot to file
    plt.show()


# Plot Loss
_plot_metric_curves(train_losses, val_losses, 'Train Loss', 'Validation Loss',
                    'Loss', 'Loss over Epochs', 'loss_plot.png')

# Plot Accuracy
_plot_metric_curves(train_accuracies, val_accuracies, 'Train Accuracy', 'Validation Accuracy',
                    'Accuracy', 'Accuracy over Epochs', 'accuracy_plot.png')

In [None]:
def plot_cosine_similarity(model, dataloader, device, save_path="cosine_similarity.png"):
    """Histogram the cosine similarity between projections of paired views.

    For each batch, entries 0 and 1 are projected separately (each passed to the
    model as a one-element list). Their per-sample cosine similarities are
    collected and plotted. The figure is saved to *save_path* and shown.
    """
    model.eval()
    cosine_sims = []

    def _as_batch(tensor):
        # promote a single (C, H, W) image to a batch of one
        return tensor.unsqueeze(0) if tensor.dim() == 3 else tensor

    with torch.no_grad():
        for views in dataloader:
            first = _as_batch(views[0].to(device))
            second = _as_batch(views[1].to(device))
            sim = torch.nn.functional.cosine_similarity(model([first]), model([second]))
            cosine_sims.extend(sim.cpu().numpy())

    plt.figure(figsize=(6, 4))
    sns.histplot(cosine_sims, bins=50, kde=True, color="purple")
    plt.title("Cosine Similarity of Positive Pairs")
    plt.xlabel("Cosine Similarity")
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig(save_path)  # Save plot to file
    plt.show()

# Positive-pair similarity histogram on the validation loader, using the
# in-memory model as it stands after the final training epoch.
plot_cosine_similarity(model, val_dl, device=DEVICE)

### **Save question lists for all images**

In [None]:
# Re-collect ALL images (no 1/80 subsetting) for the final question lists and features.
train_images, val_images = (
    glob(os.path.join(folder, '*.jpg')) for folder in (images_train_folder, images_val_folder)
)

print(len(train_images))
print(len(val_images))

In [None]:
# image IDs for the full image lists
train_image_ids = {extract_image_id(path) for path in train_images}
val_image_ids = {extract_image_id(path) for path in val_images}

In [None]:
# create lists: attach answer fields from the annotations to each question record
train_list = train_questions_list['questions']
val_list = val_questions_list['questions']


def _attach_annotations(questions, annotations):
    """Copy answer fields from *annotations* onto the matching *questions* in place.

    Records are matched on 'question_id' rather than list position, so the merge
    stays correct even if the two JSON files are not in the same order.
    Raises KeyError if a question has no annotation.
    """
    by_qid = {ann['question_id']: ann for ann in annotations}
    for q in questions:
        ann = by_qid[q['question_id']]
        q['multiple_choice_answer'] = ann['multiple_choice_answer']
        q['answers'] = ann['answers']
        q['answer_type'] = ann['answer_type']


_attach_annotations(train_list, train_annotations_list['annotations'])
_attach_annotations(val_list, val_annotations_list['annotations'])

In [None]:
# Keep questions whose image is among the collected image files.
filtered_train_list = [record for record in train_list if record['image_id'] in train_image_ids]
filtered_val_list = [record for record in val_list if record['image_id'] in val_image_ids]

print("Filtered train questions:", len(filtered_train_list))
print("Filtered val questions:", len(filtered_val_list))

In [None]:
# Check whether every collected train image has at least one question.
filtered_train_image_ids = {record['image_id'] for record in filtered_train_list}

print("Train image IDs from images:", len(train_image_ids))
print("Train image IDs from filtered questions:", len(filtered_train_image_ids))
print("Are train image ID sets equal?", filtered_train_image_ids == train_image_ids)

In [None]:
# Persist the filtered question lists (questions plus attached answers) to Drive.
train_list_path = os.path.join(full_path, 'train_list.pkl')
val_list_path = os.path.join(full_path, 'val_list.pkl')

with open(train_list_path, 'wb') as f:
    pickle.dump(filtered_train_list, f)
print(f"Saved train_list to {train_list_path}")

with open(val_list_path, 'wb') as f:
    pickle.dump(filtered_val_list, f)
print(f"Saved val_list to {val_list_path}")

### **Extract features**

In [None]:
# Rebuild the model for feature extraction and load the fine-tuned checkpoint.
# linear_eval=True: forward() takes one image tensor (no view concatenation).
# unfreeze_vit=False: every CLIP parameter stays frozen.
model = SimCLR(linear_eval=True, unfreeze_vit=False)

state_dict = torch.load('/content/drive/MyDrive/ViT_Embeddings/ViT_def.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()

model.to(DEVICE)

In [None]:
# dataset for JPG images
class ImageDataset(Dataset):
    """One item per unique COCO image referenced by a VQA question list.

    Args:
        json_list: question records, each with an 'image_id' key.
        image_folder: directory holding the COCO images, e.g. '/content/train2014'.
        split: COCO split tag used in filenames, 'train2014' or 'val2014'.
        transform: callable applied to the RGB PIL image. If None, the
            notebook-level ``transform`` is used, as before.

    Items are dicts ``{"image": transformed image, "id": image id}``.
    """

    def __init__(self, json_list, image_folder, split="train2014", transform=None):
        # Sorted unique ids, so item order is deterministic across runs.
        self.imgids = sorted({i['image_id'] for i in json_list})
        self.image_folder = image_folder  # e.g., '/content/train2014'
        self.split = split                # 'train2014' or 'val2014'
        self.transform = transform

    def __len__(self):
        return len(self.imgids)

    def __getitem__(self, idx):
        img_id = self.imgids[idx]
        # COCO naming: COCO_<split>_<12-digit zero-padded id>.jpg
        filename = f"COCO_{self.split}_{str(img_id).rjust(12, '0')}.jpg"
        file_path = os.path.join(self.image_folder, filename)

        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Image not found: {file_path}")

        # Close the file handle as soon as the pixels are loaded.
        with Image.open(file_path) as img:
            image = img.convert('RGB')
        tf = self.transform if self.transform is not None else transform
        image = tf(image)
        return {"image": image, "id": img_id}

In [None]:
# feature extraction function
def extract_and_save_features(json_list, image_folder, split_name, batch_size=256, num_workers=2, device=None):
    """Embed every unique image referenced by *json_list* and pickle the result.

    Uses the notebook-level ``model`` (switched to eval mode with
    ``linear_eval=True``). Writes ``{"img_ids": int32 array, "features": float16
    array}`` to ``<full_path>/<split_name>_image_features.pkl``.

    Args:
        json_list: question records with an 'image_id' key.
        image_folder: folder containing the split's COCO .jpg files.
        split_name: COCO split tag ('train2014' / 'val2014'), also used in the output filename.
        batch_size, num_workers: DataLoader settings.
        device: inference device. Defaults to CUDA when available, otherwise CPU
            (previously 'cuda' was hard-coded and failed on CPU-only runtimes).

    Raises:
        ValueError: if *json_list* references no images.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    dataset = ImageDataset(json_list, image_folder, split=split_name)
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False,
                        pin_memory=str(device).startswith('cuda'))

    model.eval()
    model.linear_eval = True  # forward() takes a single batch tensor, not a list of views
    model.to(device)

    all_features = []
    all_img_ids = []

    with torch.no_grad():
        for batch in tqdm(loader, desc=f"Extracting features for {split_name}"):
            features = model(batch['image'].to(device))
            all_features.append(features.cpu().numpy().astype('float16'))
            # Default collation turns the int ids into a tensor; store plain ints.
            all_img_ids.extend(torch.as_tensor(batch['id']).tolist())

    if not all_features:
        raise ValueError(f"No images to extract for split {split_name!r}")

    all_features = np.vstack(all_features)
    all_img_ids = np.array(all_img_ids, dtype=np.int32)

    # save
    save_path = os.path.join(full_path, f"{split_name}_image_features.pkl")
    with open(save_path, "wb") as f:
        pickle.dump({"img_ids": all_img_ids, "features": all_features}, f)

    print(f"Saved {len(all_img_ids)} features to {save_path}")

In [None]:
# Number of distinct train images referenced by the filtered questions.
imgids = list({record['image_id'] for record in filtered_train_list})
print(len(imgids))

In [None]:
# Deterministic eval-time image transform. Normalisation must match what the
# model saw during fine-tuning: CLIP mean/std, as in get_complete_transform.
# The ImageNet mean/std used previously shifted the input distribution.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4815, 0.4578, 0.4082],
        std=[0.2686, 0.2613, 0.2758]
    )
])

# extract
extract_and_save_features(filtered_train_list, "/content/train2014", "train2014")
extract_and_save_features(filtered_val_list, "/content/val2014", "val2014")