In [1]:
import os
import gc
import pickle
import datetime
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm
from pycocotools.coco import COCO
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

In [2]:
class PrecomputedDataset(Dataset):
    """Paired image/caption embeddings read from two pickle files.

    The caption mapping drives iteration: each caption key yields one
    (image embedding, caption embedding) pair, and the image is found through
    the caption entry's "img_fn" field.

    NOTE(review): pickle.load can execute arbitrary code while loading; only
    use embedding files that come from a trusted source.
    """

    def __init__(self, emb_images_file, emb_captions_file, device="cuda"):
        """Unpickle both embedding files and remember the target device.

        Args:
            emb_images_file: Path to the pickle holding the image embeddings,
                indexed by the value stored under "img_fn" in a caption entry.
            emb_captions_file: Path to the pickle holding the caption entries;
                each entry is indexed with "emb" and "img_fn".
            device: Device every returned tensor is moved to.
        """
        with open(emb_images_file, "rb") as images_fh:
            self.eimg = pickle.load(images_fh)

        with open(emb_captions_file, "rb") as captions_fh:
            self.etext = pickle.load(captions_fh)

        self.device = device
        # Freeze the caption keys in a list so integer indices can address them.
        self.keys = [*self.etext.keys()]

    def __len__(self):
        """Return the number of caption entries."""
        return len(self.keys)

    def __getitem__(self, idx):
        """Return ``(image_embedding, caption_embedding)`` for caption number ``idx``.

        Both tensors are detached copies placed on ``self.device``.
        """
        caption_entry = self.etext[self.keys[idx]]

        # Assumes caption_entry["emb"] is already a tensor.
        caption_embedding = caption_entry["emb"].clone().detach().to(self.device)

        # Assumes self.eimg[...] is already a tensor.
        image_key = caption_entry["img_fn"]
        image_embedding = self.eimg[image_key].clone().detach().to(self.device)

        return image_embedding, caption_embedding

In [3]:
class CLIPModel(nn.Module):
    """Two linear heads that project image and text features into one shared space."""

    def __init__(self, image_encoder_dim, text_encoder_dim, output_dim):
        """Build the projection heads.

        Args:
            image_encoder_dim: Size of the incoming image feature vectors.
            text_encoder_dim: Size of the incoming text feature vectors.
            output_dim: Size of the shared embedding space.
        """
        super().__init__()

        # The image head is created first, then the text head.
        self.image_proj = nn.Linear(image_encoder_dim, output_dim)
        self.text_proj = nn.Linear(text_encoder_dim, output_dim)
        self.embed_dim = output_dim

        print(image_encoder_dim, text_encoder_dim, output_dim)

    def forward(self, image_tensor, text_tensor):
        """Project both inputs and scale every vector to unit L2 norm.

        The scaling is a plain division by the norm along the last dimension
        (no epsilon), so an all-zero projection would produce NaNs.

        Returns:
            Tuple ``(image_embeddings, text_embeddings)``.
        """
        projected_image = self.image_proj(image_tensor)
        projected_text = self.text_proj(text_tensor)

        image_length = projected_image.norm(dim=-1, keepdim=True)
        text_length = projected_text.norm(dim=-1, keepdim=True)

        return projected_image / image_length, projected_text / text_length

In [4]:
class CLIPLoss(nn.Module):
    """Symmetric cross-entropy contrastive loss with a learnable temperature."""

    def __init__(self, initial_temperature=0.07):
        """Create the loss.

        Args:
            initial_temperature: Starting value of the learnable temperature
                that divides the similarity logits.
        """
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        # Registered as a parameter so an optimizer can update it.
        self.temperature = nn.Parameter(torch.tensor(initial_temperature))

    def forward(self, image_features, text_features):
        """Return the mean of the image-to-text and text-to-image cross-entropies.

        Row ``i`` of ``image_features`` is treated as the match of row ``i`` of
        ``text_features`` (the targets are ``0..N-1``).
        """
        similarity = image_features @ text_features.t()
        logits = similarity / self.temperature

        targets = torch.arange(logits.shape[0], device=logits.device)

        image_to_text = self.loss_fn(logits, targets)
        text_to_image = self.loss_fn(logits.t(), targets)
        return (image_to_text + text_to_image) / 2

In [5]:
class CLIPLossWithCompensation(nn.Module):
    """Contrastive loss whose logits are re-weighted by intra-modal similarity.

    Every logit is multiplied by ``1 + alpha * (image_sim + text_sim) / 2``,
    where the two similarity matrices are the image-to-image and text-to-text
    cosine similarities of the batch. Both ``temperature`` and ``alpha`` are
    learnable parameters.
    """

    def __init__(self, initial_temperature=0.07, initial_alpha=0.5):
        """Create the loss.

        Args:
            initial_temperature: Starting value of the learnable temperature.
            initial_alpha: Starting value of the learnable compensation scale.
        """
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        # Temperature is registered before alpha.
        self.temperature = nn.Parameter(torch.tensor(initial_temperature))
        self.alpha = nn.Parameter(torch.tensor(initial_alpha))

    def forward(self, image_features, text_features):
        """Return the symmetric cross-entropy over the re-weighted logits.

        Row ``i`` of ``image_features`` is treated as the match of row ``i`` of
        ``text_features``.
        """
        # Temperature-scaled cross-modal logits.
        logits = (image_features @ text_features.t()) / self.temperature

        # Unit-length copies, so the products below are cosine similarities.
        unit_images = F.normalize(image_features, p=2, dim=-1)
        unit_texts = F.normalize(text_features, p=2, dim=-1)
        image_sim = unit_images @ unit_images.t()
        text_sim = unit_texts @ unit_texts.t()

        # Per-pair weight derived from the mean intra-modal similarity.
        weights = 1 + self.alpha * (image_sim + text_sim) / 2
        weighted_logits = logits * weights

        # Positive pairs sit on the diagonal.
        targets = torch.arange(logits.shape[0], device=logits.device)

        image_to_text = self.loss_fn(weighted_logits, targets)
        text_to_image = self.loss_fn(weighted_logits.t(), targets)
        return (image_to_text + text_to_image) / 2

In [None]:
# Create the two top-level output folders up front; existing folders are left alone.
for output_root in ("models", "results"):
    os.makedirs(output_root, exist_ok=True)

In [6]:
# Validation split: precomputed COCO val2017 image and caption embeddings.
dataset_val = PrecomputedDataset(emb_images_file="data/coco/embs/val2017_images.pkl",
                                 emb_captions_file="data/coco/embs/val2017_captions.pkl")
# Validation batches are kept in a fixed order: the loss and the in-batch retrieval
# accuracy both depend on which samples share a batch, so reshuffling every epoch
# would make the per-epoch validation numbers (and best-checkpoint selection) noisy.
dataloader_val = DataLoader(dataset_val, batch_size=32, shuffle=False)
# (number of batches, number of image embeddings, number of caption entries)
len(dataloader_val), len(list(dataset_val.eimg.keys())), len(list(dataset_val.etext.keys()))

(157, 5000, 5000)

In [7]:
# Training split: precomputed COCO train2017 image and caption embeddings.
dataset_train = PrecomputedDataset(
    emb_images_file="data/coco/embs/train2017_images.pkl",
    emb_captions_file="data/coco/embs/train2017_captions.pkl",
)
dataloader_train = DataLoader(dataset_train, shuffle=True, batch_size=32)
# (number of batches, number of image embeddings, number of caption entries)
(
    len(dataloader_train),
    len(dataset_train.eimg.keys()),
    len(dataset_train.etext.keys()),
)

(3697, 118287, 118287)

In [8]:
# Number of training epochs handed to the train() calls in this notebook.
EPOCH = 50

In [9]:
def _history_row(loss_fn, train_loss, val_loss, val_accuracy, batch, learning_rate, epoch):
    """Build one history record; ``alpha`` is 0.0 for losses without that parameter."""
    alpha_param = getattr(loss_fn, "alpha", None)
    return {
        "train_loss": train_loss,
        "val_loss": val_loss,
        "val_accuracy": val_accuracy,
        "temperature": loss_fn.temperature.item(),
        "alpha": 0.0 if alpha_param is None else alpha_param.item(),
        "batch": batch,
        "learning rate": learning_rate,
        "epoch": epoch,
    }


def train(loss_fn, epoch, mark):
    """Train a fresh CLIPModel with ``loss_fn`` and validate after every epoch.

    Reads the notebook-level ``dataloader_train`` and ``dataloader_val``. For each
    run, two timestamped folders are created:

    * ``models/{mark}_{timestamp}/model.pth`` - checkpoint, overwritten whenever
      the rounded validation loss is less than or equal to the best seen so far.
    * ``results/{mark}_{timestamp}/`` - one similarity heatmap per epoch plus
      ``history.csv``, which is rewritten after every epoch.

    Args:
        loss_fn: Loss module called as ``loss_fn(image_features, text_features)``.
            It must expose a ``temperature`` parameter; an ``alpha`` parameter is
            logged when present. Its parameters are optimised together with the
            model, and the module is moved to the GPU in place.
        epoch: Number of epochs to train; also the cosine schedule length.
        mark: Prefix for the two output folder names.

    Notes:
        * ``train_loss`` in the history and progress bar is the loss of the most
          recent batch, not an epoch average.
        * "Best Acc" in the progress bar is the validation accuracy of the epoch
          with the best validation loss, not the maximum accuracy.
        * Validation accuracy is in-batch retrieval accuracy (image -> caption),
          averaged over validation batches.
    """
    model = CLIPModel(image_encoder_dim=2048, text_encoder_dim=768, output_dim=512)
    model = model.to("cuda")
    # Keep the learnable loss parameters on the same device as the model.
    loss_fn = loss_fn.to("cuda")

    optimizer = optim.Adam(list(model.parameters()) + list(loss_fn.parameters()), lr=1e-4, weight_decay=1e-6)
    # Anneal over the number of epochs actually requested for this run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch)

    current_time = datetime.datetime.now()
    directory_name = current_time.strftime("%Y%m%d%H%M%S")
    directory_name_models = f"models/{mark}_{directory_name}"
    directory_name_results = f"results/{mark}_{directory_name}"
    # os.mkdir does not accept exist_ok; makedirs does and also creates parents.
    os.makedirs(directory_name_models, exist_ok=True)
    os.makedirs(directory_name_results, exist_ok=True)
    print(directory_name_models, directory_name_results)

    pbar = tqdm(range(epoch))
    history = []
    val_loss = 0.0
    val_loss_best = float("inf")
    val_accuracy = 0.0
    val_accuracy_best = 0.0
    current_lr = scheduler.get_last_lr()[0]

    for ep in pbar:
        # Train
        train_loss = 0.0
        model.train()
        for batch, (tensor_img, tensor_text) in enumerate(dataloader_train):
            optimizer.zero_grad()
            image_features, text_features = model(tensor_img, tensor_text)
            loss = loss_fn(image_features, text_features)
            loss.backward()
            optimizer.step()
            train_loss = loss.item()

            # val_loss / val_accuracy here are still those of the previous epoch.
            history.append(_history_row(loss_fn, train_loss, val_loss, val_accuracy, batch + 1, current_lr, ep))
            train_loss = round(train_loss, 5)
            pbar.set_description(f"Train Loss: {train_loss}, Val Loss: {val_loss}, Best Acc: {val_accuracy_best}, Best V. Loss: {val_loss_best}")

        # Validation
        val_loss = 0.0
        val_accuracy = 0.0
        all_cosine_sim = []
        model.eval()
        with torch.no_grad():
            for tensor_img, tensor_text in dataloader_val:
                image_features, text_features = model(tensor_img, tensor_text)
                loss = loss_fn(image_features, text_features)
                val_loss += loss.item()

                # Cosine similarity for validation metrics
                cosine_sim = torch.matmul(image_features, text_features.T).cpu().numpy()
                all_cosine_sim.append(cosine_sim)

                preds = cosine_sim.argmax(axis=1)  # Predicted indices
                targets = np.arange(len(tensor_img))  # Ground truth: perfect match
                val_accuracy += accuracy_score(targets, preds)

        # Average validation metrics
        val_loss /= len(dataloader_val)
        val_accuracy /= len(dataloader_val)

        # Combine similarity matrices by padding to maximum shape; a smaller final
        # batch only contributes to the top-left corner of the sum.
        max_rows = max(sim.shape[0] for sim in all_cosine_sim)
        max_cols = max(sim.shape[1] for sim in all_cosine_sim)
        padded_cosine_sim = np.zeros((max_rows, max_cols))

        for sim in all_cosine_sim:
            padded_cosine_sim[:sim.shape[0], :sim.shape[1]] += sim

        cosine_sim_matrix = padded_cosine_sim / len(all_cosine_sim)

        # Visualization
        plt.figure(figsize=(8, 6))
        sns.heatmap(cosine_sim_matrix, cmap="viridis", annot=False)
        plt.title(f"Epoch:{ep} Val Acc: {val_accuracy:.4f}, Loss: {val_loss:.4f}")
        plt.xlabel("Text Embeddings")
        plt.ylabel("Image Embeddings")
        plt.savefig(f"{directory_name_results}/similarity_matrix_epoch_{ep:02}.png")
        plt.close()

        # Epoch summary row is marked with batch == 0.
        history.append(_history_row(loss_fn, train_loss, val_loss, val_accuracy, 0, current_lr, ep))
        pd.DataFrame(history).to_csv(f"{directory_name_results}/history.csv", index=False)

        # Refresh Learning rate
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]

        # Save model
        train_loss = round(train_loss, 5)
        val_loss = round(val_loss, 5)
        if val_loss_best >= val_loss:
            val_loss_best = val_loss
            val_accuracy_best = val_accuracy

            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Learned temperature (and alpha, if any) belong to the checkpoint too.
                'loss_fn_state_dict': loss_fn.state_dict(),
            }, f"{directory_name_models}/model.pth")

        pbar.set_description(f"Train Loss: {train_loss}, Val Loss: {val_loss}, Best Acc: {val_accuracy_best}, Best V. Loss: {val_loss_best}")

    # Drop the references and release cached GPU memory before the next run.
    del model
    del optimizer
    del scheduler
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Run with the similarity-compensated loss; train() prefixes this run's output folders with "compensated".
train(loss_fn=CLIPLossWithCompensation(), epoch=EPOCH, mark="compensated")

2048 768 512




Directory name: models/compensated_20241217092922


Train Loss: 0.0477, Val Loss: 0.0, Best Acc: 0.0, Best V. Loss: inf:   0% 0/50 [00:09<?, ?it/s] 

In [10]:
# Run with the plain symmetric contrastive loss; train() prefixes this run's output folders with "baseline".
train(loss_fn=CLIPLoss(), epoch=EPOCH, mark="baseline")

2048 768 512




Directory name: models/baseline_20241217101413


Train Loss: 0.00495, Val Loss: 1.23966, Best Acc: 0.6711783439490446, Best V. Loss: 0.96155: 100% 50/50 [13:40<00:00, 16.41s/it]


In [11]:
from IPython import get_ipython

# Ask the running kernel to shut down and start again (restart=True).
# NOTE(review): presumably here to release GPU memory once training is done - confirm.
# Everything defined in the cells above is gone after this cell runs.
get_ipython().kernel.do_shutdown(restart=True)

{'status': 'ok', 'restart': True}