In [47]:
from torch import nn
from torch.nn import functional as F
import torch
import numpy as np

# class SpaceTranslator(nn.Module):
#     def __init__(
#         self,
#         input_dim,
#         output_dim,
#         hidden_layers,
#         activation,
#         dropout_rate,
#         init_method: str = 'xavier'
#     ):
#         super().__init__()

#         self.init_method = init_method.lower()
#         if self.init_method not in ['xavier', 'kaiming']:
#             raise ValueError("Unsupported init_method")

#         layers = []
#         last = input_dim

#         for hidden in hidden_layers:
#             layers += [
#                 nn.Linear(last, hidden),
#                 nn.LayerNorm(hidden),
#                 activation(),
#                 nn.Dropout(dropout_rate)
#             ]
#             last = hidden

#         layers.append(nn.Linear(last, output_dim))
#         self.net = nn.Sequential(*layers)

#         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))

#         # Apply the chosen initialization
#         self.apply(self.init_weights)

#     def init_weights(self, module):
#         if isinstance(module, nn.Linear):
#             if self.init_method == 'kaiming':
#                 nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')
#             else:
#                 nn.init.xavier_uniform_(module.weight)
            
#             if module.bias is not None:
#                 nn.init.constant_(module.bias, 0.0)
                
#         elif isinstance(module, nn.LayerNorm):
#             nn.init.ones_(module.weight)
#             nn.init.zeros_(module.bias)


#     def forward(self, x):
#         return self.net(x)
#         #return F.normalize(self.net(x), p=2, dim=1)


import torch
from typing import Optional
from torch import nn
from torch.nn import functional as F

class SpaceTranslator(nn.Module):
    """Two-headed MLP producing a unit direction multiplied by a learned magnitude.

    Both heads read the same input. ``dir_head`` emits an ``output_dim`` vector
    that ``forward`` L2-normalises along the last dimension; ``scale_head`` emits
    one Softplus-activated scalar per sample that rescales the unit vector.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the vector emitted by the direction head.
        dir_hidden_dims: Hidden layer widths of the direction head.
        scale_hidden_dims: Hidden layer widths of the scale head.
        activation: Zero-argument callable returning an activation module
            (e.g. ``nn.ReLU``); called once per hidden layer.
        dropout_rate: Dropout probability used after every hidden layer.
        init_method: ``'xavier'`` or ``'kaiming'`` (case-insensitive); selects
            the initialiser applied to every ``nn.Linear`` weight.

    Raises:
        ValueError: If ``init_method`` is not one of the supported names.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        dir_hidden_dims: list[int],
        scale_hidden_dims: list[int],
        activation=nn.ReLU,
        dropout_rate: float=0.3,
        init_method: str = 'xavier'
    ):
        super().__init__()

        self.init_method = init_method.lower()
        if self.init_method not in ('xavier', 'kaiming'):
            raise ValueError("Unsupported init_method")

        def make_head(widths, out_dim, positive_output):
            """Stack Linear -> activation -> LayerNorm -> Dropout per width, then a final Linear."""
            modules = []
            in_features = input_dim
            for width in widths:
                modules.append(nn.Linear(in_features, width))
                modules.append(activation())
                modules.append(nn.LayerNorm(width))
                modules.append(nn.Dropout(dropout_rate))
                in_features = width
            modules.append(nn.Linear(in_features, out_dim))
            if positive_output:
                # Softplus keeps the predicted magnitude from going negative.
                modules.append(nn.Softplus())
            return nn.Sequential(*modules)

        self.dir_head = make_head(dir_hidden_dims, output_dim, positive_output=False)
        self.scale_head = make_head(scale_hidden_dims, 1, positive_output=True)

        # Learnable log-scale read by the contrastive loss; initial value log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))

        # Visit every sub-module once and re-initialise it.
        self.apply(self.init_weights)

    def init_weights(self, module):
        """Initialise one sub-module in place; intended to be passed to ``nn.Module.apply``.

        LayerNorm gets weight 1 / bias 0. Linear weights use the initialiser
        selected by ``self.init_method`` and Linear biases are zeroed. Any other
        module type is left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return

        if not isinstance(module, nn.Linear):
            return

        if self.init_method != 'kaiming':
            nn.init.xavier_uniform_(module.weight)
        else:
            nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')

        if module.bias is not None:
            nn.init.constant_(module.bias, 0.0)

    def forward(self, x):
        """Return the L2-normalised direction head output scaled by the scale head output."""
        unit_direction = F.normalize(self.dir_head(x), p=2, dim=-1)
        magnitude = self.scale_head(x)
        return unit_direction * magnitude

In [32]:
from pathlib import Path
import numpy as np
import pandas as pd

'''Code from https://github.com/Mamiglia/challenge'''

def mrr(pred_indices: np.ndarray, gt_indices: np.ndarray) -> float:
    """
    Mean Reciprocal Rank of the ground-truth item within each prediction row.

    Args:
        pred_indices: (N, K) array; row ``i`` holds the ranked predictions for query ``i``
        gt_indices: (N,) array of ground truth indices
    Returns:
        The mean over queries of ``1 / rank`` of the first match (1-based rank),
        counting 0 for queries whose ground truth does not appear in their row.
    """
    def reciprocal_rank(row: int) -> float:
        hits = np.where(pred_indices[row] == gt_indices[row])[0]
        # Only the first (best-ranked) occurrence counts.
        return 1.0 / (hits[0] + 1) if hits.size > 0 else 0.0

    return np.mean([reciprocal_rank(row) for row in range(len(gt_indices))])


def recall_at_k(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int) -> float:
    """Fraction of queries whose ground truth appears among their first ``k`` predictions.

    Args:
        pred_indices: 2-D array; row ``i`` holds the ranked predictions for query ``i``
        gt_indices: (N,) array of ground truth indices
        k: number of leading predictions to inspect per query
    Returns:
        recall: Recall@k in [0, 1]
    """
    n_queries = len(gt_indices)
    hits = sum(
        1
        for row in range(n_queries)
        if gt_indices[row] in pred_indices[row, :k]
    )
    return hits / n_queries

import numpy as np

def ndcg(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int = 100) -> float:
    """
    NDCG@k for the single-relevant-item case.

    With exactly one relevant item per query the ideal DCG is 1, so each query
    contributes ``1 / log2(rank + 1)`` where ``rank`` is the 1-based position of
    the ground truth within the first ``k`` predictions (0 if it is absent).

    Args:
        pred_indices: 2-D array; row ``i`` holds the ranked predictions for query ``i``
        gt_indices: (N,) array of ground truth indices
        k: number of leading predictions to inspect per query
    Returns:
        ndcg: mean NDCG@k over all queries
    """
    n_queries = len(gt_indices)
    total_gain = 0.0
    for row in range(n_queries):
        hit_positions = np.where(pred_indices[row, :k] == gt_indices[row])[0]
        if hit_positions.size == 0:
            continue
        first_hit_rank = hit_positions[0] + 1
        total_gain += 1.0 / np.log2(first_hit_rank + 1)  # DCG (IDCG = 1)
    return total_gain / n_queries



@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """Evaluate chunk-local retrieval of images from translated caption embeddings.

    Queries are processed in consecutive chunks of ``batch_size`` rows. For each
    chunk the candidates are the rows of ``image_embd`` in the *same* row range,
    so every query is ranked against at most ``batch_size`` candidates rather
    than the whole gallery. Similarity is a plain dot product; it equals cosine
    similarity only when both inputs are already L2-normalised, which this
    function does not do.

    Args:
        translated_embd: (N, D) translated caption embeddings (tensor or ndarray).
        image_embd: image embeddings (tensor or ndarray); sliced with the same
            row ranges as the queries and also indexed by ``gt_indices``.
        gt_indices: (N,) ground-truth index for each query. The slice belonging
            to a chunk is used to turn chunk-local ranks into labels.
        max_indices: number of top predictions kept per query.
        batch_size: number of queries (and candidates) per chunk.

    Returns:
        dict with keys ``mrr``, ``ndcg``, ``recall_at_1``, ``recall_at_3``,
        ``recall_at_5``, ``recall_at_10``, ``recall_at_50`` and ``l2_dist``
        (mean L2 distance between each query and its ground-truth embedding).
    """
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()

    n_queries = translated_embd.shape[0]

    # Per-chunk results, concatenated once all chunks are processed.
    all_sorted_indices = []
    l2_distances = []

    for start_idx in range(0, n_queries, batch_size):
        batch_slice = slice(start_idx, min(start_idx + batch_size, n_queries))
        batch_translated = translated_embd[batch_slice]
        batch_img_embd = image_embd[batch_slice]
        batch_gt = gt_indices[batch_slice]

        # Similarity of every query in the chunk to every candidate in the chunk.
        batch_similarity = batch_translated @ batch_img_embd.T

        # A short final chunk can have fewer candidates than max_indices;
        # topk raises if k exceeds the number of columns.
        k = min(max_indices, batch_similarity.shape[1])
        # .cpu() so the conversion also works when the embeddings live on a GPU.
        batch_indices = batch_similarity.topk(k=k, dim=1, sorted=True).indices.cpu().numpy()

        # Map chunk-local candidate positions back to their ground-truth labels.
        batch_preds = np.asarray(batch_gt[batch_indices])
        if k < max_indices:
            # Pad to a fixed width so chunks can be concatenated. -1 never matches
            # a label as long as labels are non-negative (assumed; they are
            # produced by torch.arange in eval_on_val).
            padding = np.full(
                (batch_preds.shape[0], max_indices - k), -1, dtype=batch_preds.dtype
            )
            batch_preds = np.concatenate([batch_preds, padding], axis=1)
        all_sorted_indices.append(batch_preds)

        # L2 distance between each query and its ground-truth embedding.
        batch_gt_embeddings = image_embd[batch_gt]
        batch_l2 = (batch_translated - batch_gt_embeddings).norm(dim=1)
        l2_distances.append(batch_l2)

    sorted_indices = np.concatenate(all_sorted_indices, axis=0)

    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }

    results = {
        name: func(sorted_indices, gt_indices)
        for name, func in metrics.items()
    }

    l2_dist = torch.cat(l2_distances, dim=0).mean().item()
    results['l2_dist'] = l2_dist

    return results

def eval_on_val(x_val: torch.Tensor, y_val: "np.ndarray | torch.Tensor", model: nn.Module, device) -> dict:
    """Translate validation inputs with ``model`` and score retrieval against ``y_val``.

    Query ``i`` is treated as matching row ``i`` of ``y_val`` (ground truth is
    ``arange(len(y_val))``).

    Args:
        x_val: Input embeddings; must be a tensor because it is moved with ``.to(device)``.
        y_val: Target embeddings, row-aligned with ``x_val``.
        model: Network applied to ``x_val``. It is switched to eval mode and left there.
        device: Device on which the forward pass runs; predictions are moved back to CPU.

    Returns:
        The metrics dict produced by ``evaluate_retrieval``.
    """
    gt_indices = torch.arange(len(y_val))

    model.eval()

    with torch.inference_mode():
        translated = model(x_val.to(device)).to('cpu')

    results = evaluate_retrieval(translated, y_val, gt_indices)

    return results


def generate_submission(model: nn.Module, test_path: Path, output_file="submission.csv", device=None):
    """Run ``model`` on the test captions and write the predictions to a CSV file.

    Args:
        model: Network mapping caption embeddings to predicted embeddings.
        test_path: ``.npz`` archive containing ``captions/ids`` and
            ``captions/embeddings``.
        output_file: Destination CSV path.
        device: Device for the forward pass (passed to ``Tensor.to``).

    Returns:
        DataFrame with columns ``id`` and ``embedding`` (one list of floats per row),
        identical to what is written to ``output_file``.
    """
    # The archive keeps a file handle open; the context manager closes it even on error.
    with np.load(test_path) as test_data:
        sample_ids = test_data['captions/ids']
        test_embds = torch.from_numpy(test_data['captions/embeddings']).float()

    # Predict in eval mode so dropout cannot perturb the submission, then restore
    # whatever mode the caller had the model in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred_embds = model(test_embds.to(device)).cpu().numpy()
    finally:
        model.train(was_training)

    print("Generating submission file...")

    df_submission = pd.DataFrame({'id': sample_ids, 'embedding': pred_embds.tolist()})

    df_submission.to_csv(output_file, index=False, float_format='%.17g')
    print(f"‚úì Saved submission to {output_file}")

    return df_submission

In [55]:
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

def moco_info_nce_loss(q, k, queue, logit_scale):
    """InfoNCE loss with one positive per query and the queue rows as negatives.

    Args:
        q: Query embeddings; for 2-D input, shape (batch, dim).
        k: Positive key for each query, same shape as ``q``.
        queue: Negative keys, one per row.
        logit_scale: Tensor holding the log of the logit multiplier. It is
            clamped to [log(0.01), log(100)] before being exponentiated.

    Returns:
        Scalar cross-entropy loss where the positive is always class 0.
    """
    # Positive: row-wise dot product of each query with its own key -> [batch, 1]
    positive_logits = (q * k).sum(dim=1, keepdim=True)

    # Negatives: every query against every queue entry -> [batch, queue_size]
    negative_logits = torch.matmul(q, queue.T)

    clamped_log_scale = logit_scale.clamp(min=np.log(0.01), max=np.log(100))

    all_logits = torch.cat((positive_logits, negative_logits), dim=1)
    all_logits = all_logits * clamped_log_scale.exp()

    # The positive sits in column 0 for every row.
    targets = all_logits.new_zeros(all_logits.size(0), dtype=torch.long)

    return F.cross_entropy(all_logits, targets)


@torch.no_grad()
def enqueue(queue, keys, queue_ptr):
    """Write ``keys`` into the ring buffer ``queue`` in place and advance ``queue_ptr``.

    Key ``j`` lands at row ``(ptr + j) % queue_size`` where ``ptr`` is the value
    of ``queue_ptr[0]`` on entry; afterwards ``queue_ptr[0]`` is
    ``(ptr + len(keys)) % queue_size``. When more keys are supplied than the
    queue can hold, only the newest ``queue_size`` of them are kept (the older
    ones would have been overwritten anyway).

    Args:
        queue: (queue_size, dim) tensor, modified in place.
        keys: (batch, dim) tensor of new entries.
        queue_ptr: 1-element integer tensor holding the next write position;
            modified in place.
    """
    batch_size = keys.shape[0]
    queue_size = queue.shape[0]
    ptr = int(queue_ptr[0])
    next_ptr = (ptr + batch_size) % queue_size

    if batch_size > queue_size:
        # Skip the keys that would be overwritten within this same call and
        # start where the first surviving key belongs.
        ptr = (ptr + batch_size - queue_size) % queue_size
        keys = keys[-queue_size:]
        batch_size = queue_size

    tail_room = queue_size - ptr
    if batch_size <= tail_room:
        # Fits without wrapping: one contiguous slice.
        queue[ptr:ptr + batch_size, :] = keys
    else:
        # Wrap-around: fill to the end, then continue from the start.
        queue[ptr:, :] = keys[:tail_room, :]
        queue[:batch_size - tail_room, :] = keys[tail_room:, :]

    # Advance the write pointer.
    queue_ptr[0] = next_ptr

def train_model_moco(model, save_path, train_dataset, val_dataset, batch_size, epochs, lr, patience, queue_size, weight_decay):
    """Train ``model`` with the queue-based InfoNCE loss, early-stopping on validation MRR.

    Each step computes the loss of the model output against the batch's image
    embeddings (positives) and the current queue (negatives), then appends the
    batch's image embeddings to the queue. The queue starts as all zeros.
    After every epoch the validation loss (against the same queue, which is
    not updated during validation) and the retrieval metrics are computed; the
    learning-rate scheduler and early stopping both follow the validation MRR.

    Args:
        model: Network to train; must expose a ``logit_scale`` parameter.
        save_path: File that receives the ``state_dict`` whenever MRR improves.
        train_dataset: Dataset yielding ``(text_embedding, image_embedding)`` pairs.
        val_dataset: Dataset of the same structure used for validation.
        batch_size: Mini-batch size for both loaders.
        epochs: Maximum number of epochs.
        lr: AdamW learning rate.
        patience: Number of consecutive non-improving epochs before stopping.
        queue_size: Number of negative keys kept in the queue.
        weight_decay: AdamW weight decay.

    Returns:
        ``model`` with the weights of the last epoch that ran. The best weights
        are only on disk at ``save_path``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # mode='max' because the monitored quantity (MRR) should increase.
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.7,
        patience=3,
        threshold=0.001,
        min_lr=1e-6
    )

    # Size the queue from the target embeddings rather than a hard-coded width,
    # so the function works for any image-embedding dimension.
    embed_dim = train_dataset[0][1].shape[-1]
    queue = torch.zeros(queue_size, embed_dim, device=device)
    queue_ptr = torch.zeros(1, dtype=torch.long, device=device)

    best_mrr = -float('inf')
    no_improvements = 0

    for epoch in range(1, epochs+1):
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)

        # ---- Training ----
        for text_batch, image_emb_batch in progress_bar:
            text_batch, image_emb_batch = text_batch.to(device), image_emb_batch.to(device)

            optimizer.zero_grad()

            q = model(text_batch)
            k = image_emb_batch

            loss = moco_info_nce_loss(q, k, queue, model.logit_scale)
            loss.backward()
            optimizer.step()

            # The keys of this batch become negatives for later batches.
            with torch.no_grad():
                enqueue(queue, k, queue_ptr)

            running_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_train_loss = running_loss / len(train_loader)

        # ---- Validation loss (queue is read but not updated) ----
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for text_batch, image_emb_batch in val_loader:
                text_batch, image_emb_batch = text_batch.to(device), image_emb_batch.to(device)

                q = model(text_batch)
                k = image_emb_batch

                loss = moco_info_nce_loss(q, k, queue, model.logit_scale)
                running_val_loss += loss.item()
        avg_val_loss = running_val_loss / len(val_loader)

        # ---- Retrieval metrics on the validation set ----
        results = test(val_dataset, model, device)
        mrr = results["mrr"]

        scheduler.step(mrr)

        print(f"Epoch {epoch:03d} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f} | "
              f"MRR: {mrr:.6f} | Recall-1: {results['recall_at_1']:.6f} | LR: {optimizer.param_groups[0]['lr']:.2e}")

        # ---- Early stopping + model saving, both based on MRR ----
        if mrr > best_mrr:
            best_mrr = mrr
            no_improvements = 0

            # Save best model
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), save_path)
            print(f"üíæ Saved new best model (MRR={mrr:.6f})")

        else:
            no_improvements += 1
            if no_improvements >= patience:
                print("‚èπ Early stopping triggered based on MRR.")
                break

    print("‚úÖ Training complete")
    return model



def get_data(data_path: Path):
    """Load caption embeddings and the image embedding paired with each caption.

    Args:
        data_path: ``.npz`` archive containing ``captions/embeddings``,
            ``images/embeddings`` and ``captions/label``.

    Returns:
        ``(X_abs, y_abs)``: the caption embeddings, and for every caption the row
        of ``images/embeddings`` selected by the argmax of that caption's label row.
    """
    # The context manager closes the archive even if a key is missing.
    with np.load(data_path) as data:
        caption_embeddings = data['captions/embeddings']
        image_embeddings = data['images/embeddings']
        caption_labels = data['captions/label']

    # Each label row marks the matching image by its largest entry.
    target_rows = np.argmax(caption_labels, axis=1)

    X_abs = torch.tensor(caption_embeddings)
    y_abs = torch.tensor(image_embeddings[target_rows])

    return X_abs, y_abs

def get_datasets(X_abs, y_abs) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset]:
    """Pair inputs with targets and split them 80/20 into train and validation subsets.

    The split uses a generator seeded with 42, so it is the same on every call
    for inputs of the same length.

    Args:
        X_abs: Input tensor; first dimension indexes samples.
        y_abs: Target tensor with the same first dimension as ``X_abs``.

    Returns:
        ``(train_dataset, val_dataset)``, each yielding ``(x, y)`` pairs.
    """
    print('Texts shape', X_abs.shape)
    print('Images shape', y_abs.shape)

    dataset = TensorDataset(X_abs, y_abs)
    train_dataset, val_dataset = random_split(dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42))

    return train_dataset, val_dataset

def test(val_dataset: TensorDataset, model: nn.Module, device):
    """Compute retrieval metrics for ``model`` on the whole of ``val_dataset``.

    The dataset is collated into a single batch (batch size equals its length)
    and handed to ``eval_on_val``, whose metrics dict is returned.
    """
    full_batch_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
    # Batch size equals the dataset length, so the loader yields exactly one batch.
    x_val, y_val = next(iter(full_batch_loader))
    return eval_on_val(x_val, y_val, model=model, device=device)

def center(X: torch.Tensor):
    """Subtract the per-column mean from ``X``.

    Returns:
        ``(centered, mean)`` where ``mean`` is taken over dimension 0 and keeps
        that dimension (size 1), so it can be subtracted from other data later.
    """
    column_means = torch.mean(X, dim=0, keepdim=True)
    centered = X - column_means
    return centered, column_means

In [6]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Locations of the training and test .npz archives.
data_path= '/kaggle/input/aml-competition/train/train/train.npz'
test_path= '/kaggle/input/aml-competition/test/test/test.clean.npz'

# Checkpoint file passed to train_model_moco, which writes the best weights there.
save_path = './models/dir-model.pth'

In [11]:
# Caption embeddings (x) and the image embedding paired with each caption (y).
x, y = get_data(data_path)

#x_centered, x_center = center(x)

# Seeded 80/20 train/validation split.
train_dataset, val_dataset = get_datasets(x, y)

Texts shape torch.Size([125000, 1024])
Images shape torch.Size([125000, 1536])


In [64]:
# Model input/output widths are taken from the loaded data.
input_dim = x.shape[1]
output_dim = y.shape[1]
#hidden_layers = [1472, 2048]
dropout_rate = 0.3881417338814056
batch_size = 256
lr = 0.00017095147919728726
epochs = 250
# Early-stopping patience (epochs without MRR improvement).
patience = 5

# Number of negative keys kept in the queue.
queue_size = 32768
weight_decay = 5.9917829998790125e-06

# Parameter dump the values in this cell were copied from (presumably the best Optuna trial; confirm):
# {'lr': 0.00017095147919728726, 'weight_decay': 5.9917829998790125e-06, 'batch_size': 256, 'queue_size': 32768, 'dropout_rate': 0.3881417338814056, 'activation': 'silu', 'init_method': 'xavier', 'n_layers_dir': 1, 'dir_dim_0': 2048, 'n_layers_scale': 1, 'scale_dim_0': 512}
model_args = {
    'input_dim': input_dim,
    'output_dim': output_dim,
    #'hidden_layers': hidden_layers,
    'dir_hidden_dims': [2048],
    'scale_hidden_dims': [512],
    'dropout_rate': dropout_rate,
    'activation': nn.SiLU,
    'init_method': 'xavier'
}

model = SpaceTranslator(**model_args)

train_model_moco(model, save_path, train_dataset, val_dataset, batch_size, epochs, lr, patience, queue_size, weight_decay)

print('Finished training. Now testing using best model...')

# train_model_moco returns the last-epoch weights; reload the best checkpoint before evaluating.
state = torch.load(save_path)
model.load_state_dict(state)
results = test(val_dataset, model, device)
print("Test Results:", results)

                                                                         

Epoch 001 | Train Loss: 7.385201 | Val Loss: 5.580341 | MRR: 0.845623 | Recall-1: 0.758880 | LR: 1.71e-04
üíæ Saved new best model (MRR=0.845623)


                                                                         

Epoch 002 | Train Loss: 5.526271 | Val Loss: 4.973348 | MRR: 0.886905 | Recall-1: 0.820360 | LR: 1.71e-04
üíæ Saved new best model (MRR=0.886905)


                                                                         

Epoch 003 | Train Loss: 4.876020 | Val Loss: 4.694013 | MRR: 0.902863 | Recall-1: 0.844800 | LR: 1.71e-04
üíæ Saved new best model (MRR=0.902863)


                                                                         

Epoch 004 | Train Loss: 4.456214 | Val Loss: 4.501168 | MRR: 0.911738 | Recall-1: 0.858000 | LR: 1.71e-04
üíæ Saved new best model (MRR=0.911738)


                                                                         

Epoch 005 | Train Loss: 4.128471 | Val Loss: 4.389642 | MRR: 0.918572 | Recall-1: 0.868760 | LR: 1.71e-04
üíæ Saved new best model (MRR=0.918572)


                                                                         

Epoch 006 | Train Loss: 3.849420 | Val Loss: 4.266974 | MRR: 0.923166 | Recall-1: 0.876680 | LR: 1.71e-04
üíæ Saved new best model (MRR=0.923166)


                                                                         

Epoch 007 | Train Loss: 3.590361 | Val Loss: 4.244143 | MRR: 0.924325 | Recall-1: 0.878240 | LR: 1.71e-04
üíæ Saved new best model (MRR=0.924325)


                                                                         

Epoch 008 | Train Loss: 3.355412 | Val Loss: 4.154324 | MRR: 0.926931 | Recall-1: 0.882320 | LR: 1.71e-04
üíæ Saved new best model (MRR=0.926931)


                                                                         

Epoch 009 | Train Loss: 3.137401 | Val Loss: 4.190953 | MRR: 0.928013 | Recall-1: 0.884240 | LR: 1.71e-04
üíæ Saved new best model (MRR=0.928013)


                                                                          

Epoch 010 | Train Loss: 2.932488 | Val Loss: 4.229918 | MRR: 0.928188 | Recall-1: 0.884360 | LR: 1.71e-04
üíæ Saved new best model (MRR=0.928188)


                                                                          

Epoch 011 | Train Loss: 2.744548 | Val Loss: 4.222723 | MRR: 0.928733 | Recall-1: 0.885600 | LR: 1.71e-04
üíæ Saved new best model (MRR=0.928733)


                                                                          

Epoch 012 | Train Loss: 2.583290 | Val Loss: 4.269889 | MRR: 0.929174 | Recall-1: 0.886400 | LR: 1.71e-04
üíæ Saved new best model (MRR=0.929174)


                                                                          

Epoch 013 | Train Loss: 2.435150 | Val Loss: 4.294111 | MRR: 0.927875 | Recall-1: 0.884560 | LR: 1.71e-04


                                                                          

Epoch 014 | Train Loss: 2.302876 | Val Loss: 4.305423 | MRR: 0.928858 | Recall-1: 0.886760 | LR: 1.71e-04


                                                                          

Epoch 015 | Train Loss: 2.182079 | Val Loss: 4.440578 | MRR: 0.927141 | Recall-1: 0.884080 | LR: 1.71e-04


                                                                          

Epoch 016 | Train Loss: 2.070082 | Val Loss: 4.449347 | MRR: 0.928678 | Recall-1: 0.886240 | LR: 1.20e-04


                                                                          

Epoch 017 | Train Loss: 1.852690 | Val Loss: 4.631016 | MRR: 0.928252 | Recall-1: 0.886000 | LR: 1.20e-04
‚èπ Early stopping triggered based on MRR.
‚úÖ Training complete
Finished training. Now testing using best model...
Test Results: {'mrr': 0.9291739106035963, 'ndcg': 0.946409809518304, 'recall_at_1': 0.8864, 'recall_at_3': 0.96796, 'recall_at_5': 0.98244, 'recall_at_10': 0.99228, 'recall_at_50': 0.99928, 'l2_dist': 25.875186920166016}


In [57]:
# Write test-set predictions to CSV; as the cell's last expression, the returned DataFrame is displayed.
generate_submission(model, Path(test_path), output_file="2hmlp_memory-bank.csv", device=device)

Generating submission file...
‚úì Saved submission to 2hmlp_memory-bank.csv


Unnamed: 0,id,embedding
0,1,"[-0.009818177670240402, 0.011352929286658764, ..."
1,2,"[0.002839854918420315, -0.007432268001139164, ..."
2,3,"[0.003888091305270791, 0.01134349126368761, 0...."
3,4,"[0.03038654662668705, -0.023458782583475113, -..."
4,5,"[0.035906244069337845, 0.02111988514661789, 0...."
...,...,...
1495,1496,"[-0.006532762665301561, -0.02181270718574524, ..."
1496,1497,"[0.011069669388234615, 0.022174226120114326, 0..."
1497,1498,"[0.013020381331443787, -0.021422527730464935, ..."
1498,1499,"[-0.04034360498189926, -0.006283098831772804, ..."


In [None]:
# Persist the current model parameters (state_dict only, not the module object).
torch.save(model.state_dict(), "model_weights.pth")

In [62]:
import optuna
from optuna.pruners import MedianPruner

def objective(
    trial,
    train_dataset,
    val_dataset,
    epochs: int = 15,
    device=None
):
    """Optuna objective: train one sampled configuration and return its best validation MRR.

    Samples optimiser, batch, queue and architecture hyperparameters from
    ``trial``, trains a ``SpaceTranslator`` with the queue-based InfoNCE loss
    for up to ``epochs`` epochs, reports the validation MRR to Optuna after
    every epoch and lets the study's pruner stop the trial early.

    Args:
        trial: Optuna trial used for sampling, reporting and pruning.
        train_dataset: Dataset yielding ``(text_embedding, image_embedding)`` pairs.
        val_dataset: Dataset of the same structure used to compute MRR.
        epochs: Maximum number of training epochs for the trial.
        device: Torch device; defaults to CUDA when available, else CPU.

    Returns:
        The highest validation MRR observed during the trial.

    Raises:
        optuna.exceptions.TrialPruned: When the pruner decides to stop the trial.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------------------------
    # HYPERPARAMETERS
    # ------------------------------

    # Optimizer params
    lr = trial.suggest_float("lr", 5e-5, 5e-3, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)

    # Training params
    batch_size = trial.suggest_categorical("batch_size", [32, 256, 512])

    # MoCo queue
    queue_size = trial.suggest_categorical("queue_size", [8192, 12288, 16384, 24576, 32768])

    # Model params
    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)
    # The choices are derived from the mapping so that every name Optuna can
    # sample has a constructor (a choice missing from the dict would KeyError).
    activations = {
        "gelu": nn.GELU,
        "silu": nn.SiLU,
    }
    activation = activations[trial.suggest_categorical("activation", list(activations))]
    init_method = trial.suggest_categorical("init_method", ["xavier", "kaiming"])

    # DIR head
    n_layers_dir = trial.suggest_int("n_layers_dir", 1, 3)
    dir_hidden_dims = [
        trial.suggest_categorical(f"dir_dim_{i}", [1024, 1472, 1856, 2048])
        for i in range(n_layers_dir)
    ]

    # SCALE head
    n_layers_scale = trial.suggest_int("n_layers_scale", 1, 3)
    scale_hidden_dims = [
        trial.suggest_categorical(f"scale_dim_{i}", [256, 512, 768, 1024, 1280, 1472, 1856, 2048])
        for i in range(n_layers_scale)
    ]

    # ------------------------------
    # DATA LOADERS
    # ------------------------------
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # ------------------------------
    # MODEL
    # ------------------------------
    # Read the embedding widths from the data instead of hard-coding them.
    sample_text, sample_image = train_dataset[0]
    input_dim = sample_text.shape[-1]
    output_dim = sample_image.shape[-1]

    model_args = {
        "input_dim": input_dim,
        "output_dim": output_dim,
        "dir_hidden_dims": dir_hidden_dims,
        "scale_hidden_dims": scale_hidden_dims,
        "dropout_rate": dropout_rate,
        # Use the sampled activation so the search over it actually has an effect.
        "activation": activation,
        "init_method": init_method,
    }
    model = SpaceTranslator(**model_args).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.75,
        patience=4,
        threshold=1e-4,
        min_lr=1e-6,
    )

    queue = torch.zeros(queue_size, output_dim, device=device)
    queue_ptr = torch.zeros(1, dtype=torch.long, device=device)

    best_mrr = -float("inf")

    # ------------------------------
    # TRAINING LOOP
    # ------------------------------
    for epoch in range(1, epochs + 1):
        model.train()

        for text_batch, image_emb_batch in train_loader:
            text_batch = text_batch.to(device)
            image_emb_batch = image_emb_batch.to(device)

            optimizer.zero_grad()

            q = model(text_batch)
            k = image_emb_batch
            loss = moco_info_nce_loss(q, k, queue, model.logit_scale)

            loss.backward()
            optimizer.step()

            # The keys of this batch become negatives for later batches.
            with torch.no_grad():
                enqueue(queue, k, queue_ptr)

        # ------------------------------
        # VALIDATION + MRR
        # ------------------------------
        model.eval()
        with torch.no_grad():
            results = test(val_dataset, model, device)
        mrr = results["mrr"]

        # Step scheduler based on MRR
        scheduler.step(mrr)

        # Report to Optuna
        trial.report(mrr, epoch)

        if mrr > best_mrr:
            best_mrr = mrr

        # Pruning based on MRR
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    print(f"Trial {trial.number} finished ‚Äî Best MRR: {best_mrr:.6f}")
    return best_mrr   # the study maximizes this value


def run_optuna_search(
    data_path: Path,
    n_trials: int = 150,
    epochs: int = 30,
    n_jobs: int = 1,
    sampler=None,
    pruner=None
):
    """Run an Optuna study that maximises validation MRR and print a summary.

    Args:
        data_path: Archive passed to ``get_data``.
        n_trials: Number of trials to run.
        epochs: Maximum epochs per trial (forwarded to ``objective``).
        n_jobs: Number of parallel jobs for ``study.optimize``.
        sampler: Optional Optuna sampler; ``None`` uses Optuna's default.
        pruner: Optional Optuna pruner; ``None`` uses a ``MedianPruner`` with
            5 startup trials and 3 warm-up steps.

    Returns:
        The completed ``optuna`` study.
    """
    pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=3) if pruner is None else pruner

    # Load once and share the same split across all trials.
    captions, images = get_data(data_path)
    train_dataset, val_dataset = get_datasets(captions, images)

    # MRR is a higher-is-better metric.
    study = optuna.create_study(direction="maximize", pruner=pruner, sampler=sampler)

    def run_trial(trial):
        return objective(
            trial,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            epochs=epochs,
        )

    study.optimize(run_trial, n_trials=n_trials, n_jobs=n_jobs)

    print("\n==== OPTUNA FINISHED ====")
    print("Trials:", len(study.trials))
    print("Best trial:")
    winner = study.best_trial
    print(f"  Best MRR: {winner.value:.6f}")
    print("  Params:")
    for name, value in winner.params.items():
        print(f"    {name}: {value}")

    return study

In [None]:
# Launch the hyperparameter search (100 trials, up to 35 epochs each, single job).
study = run_optuna_search(data_path=data_path, n_trials=100, epochs=35, n_jobs=1)
# Keep the full trial history on disk for later inspection.
study.trials_dataframe().to_csv("optuna_trials.csv", index=False)

print("Best params:", study.best_params)
print("Best trial number:", study.best_trial.number)

[I 2025-11-15 14:30:12,232] A new study created in memory with name: no-name-48ff3ae4-c006-47ae-8cea-f27bf647f40c


Texts shape torch.Size([125000, 1024])
Images shape torch.Size([125000, 1536])
