In [34]:
from torch import nn
from torch.nn import functional as F
import torch
import numpy as np # Assumendo che np sia importato

class SpaceTranslator(nn.Module):
    """MLP that translates embeddings from one space into another, with unit-norm outputs.

    Architecture: for each width in ``hidden_layers`` a block
    ``Linear -> LayerNorm -> activation() -> Dropout``, followed by a final
    ``Linear(last, output_dim)``. The output is L2-normalised along the last dimension.
    The module also owns a learnable scalar ``logit_scale`` (log of the inverse
    temperature, initialised to log(1/0.07)) that the contrastive losses read.

    Args:
        input_dim: size of the input embeddings.
        output_dim: size of the output embeddings.
        hidden_layers: iterable of hidden widths (may be empty -> a single Linear).
        activation: activation *class* (e.g. ``nn.GELU``), instantiated once per hidden block.
        dropout_rate: dropout probability applied after each activation.
        init_method: ``'xavier'`` or ``'kaiming'`` (case-insensitive) for Linear weights.

    Raises:
        ValueError: if ``init_method`` is not one of the supported values.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_layers,
        activation,
        dropout_rate,
        init_method: str = 'xavier'
    ):
        super().__init__()

        self.init_method = init_method.lower()
        if self.init_method not in ['xavier', 'kaiming']:
            raise ValueError("Unsupported init_method")

        layers = []
        last = input_dim
        for hidden in hidden_layers:
            layers += [
                nn.Linear(last, hidden),
                nn.LayerNorm(hidden),
                activation(),
                nn.Dropout(dropout_rate),
            ]
            last = hidden

        layers.append(nn.Linear(last, output_dim))
        self.net = nn.Sequential(*layers)

        # Learnable log inverse-temperature, initialised to temperature 0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))

        # Apply the selected initialisation to every Linear / LayerNorm submodule.
        self.apply(self.init_weights)

    def init_weights(self, module):
        """Initialise one submodule: Linear -> Xavier/Kaiming-uniform weights and zero bias;
        LayerNorm -> weight 1, bias 0. Other module types are left untouched."""
        if isinstance(module, nn.Linear):
            if self.init_method == 'kaiming':
                # NOTE(review): gain is computed for ReLU regardless of `activation` (e.g. GELU).
                nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')
            else:
                nn.init.xavier_uniform_(module.weight)

            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)

        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Map ``x`` of shape (..., input_dim) to unit-L2-norm vectors of shape (..., output_dim).

        Normalising over the last dimension is identical to ``dim=1`` for (batch, features)
        inputs, and additionally supports a single unbatched vector or extra leading dims.
        """
        return F.normalize(self.net(x), p=2, dim=-1)

In [24]:
from pathlib import Path
import numpy as np
import pandas as pd

'''Code from https://github.com/Mamiglia/challenge'''

def mrr(pred_indices: np.ndarray, gt_indices: np.ndarray) -> float:
    """Mean Reciprocal Rank over N queries.

    For each query, the score is 1 / (1-based position of the first occurrence of the
    ground-truth index in that query's prediction row), or 0.0 if it never occurs.
    The per-query scores are averaged.

    Args:
        pred_indices: (N, K) predicted indices per query, best first.
        gt_indices: (N,) ground-truth index for each query.
    Returns:
        The mean reciprocal rank.
    """
    def _reciprocal_rank(query_idx):
        hit_positions = np.where(pred_indices[query_idx] == gt_indices[query_idx])[0]
        if hit_positions.size == 0:
            return 0.0
        return 1.0 / (hit_positions[0] + 1)

    return np.mean([_reciprocal_rank(query_idx) for query_idx in range(len(gt_indices))])


def recall_at_k(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int) -> float:
    """Recall@k: fraction of queries whose ground truth appears among their first k predictions.

    Args:
        pred_indices: (N, K) predicted indices per query, best first.
        gt_indices: (N,) ground-truth index for each query.
        k: number of leading predictions to inspect.
    Returns:
        Recall@k as a float in [0, 1].
    """
    n_queries = len(gt_indices)
    hits = sum(
        1
        for query_idx in range(n_queries)
        if gt_indices[query_idx] in pred_indices[query_idx, :k]
    )
    return hits / n_queries

import numpy as np

def ndcg(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int = 100) -> float:
    """NDCG@k with a single relevant item per query.

    With one relevant item the ideal DCG is 1, so each query contributes
    1 / log2(rank + 1) when its ground truth is found at 1-based ``rank`` within the
    first k predictions, and 0 otherwise. The contributions are averaged over queries.

    Args:
        pred_indices: (N, K) predicted indices per query, best first.
        gt_indices: (N,) ground-truth index for each query.
        k: number of leading predictions to inspect.
    Returns:
        Mean NDCG@k.
    """
    per_query_gain = []
    for query_idx in range(len(gt_indices)):
        hit_positions = np.where(pred_indices[query_idx, :k] == gt_indices[query_idx])[0]
        # 0-based position p corresponds to rank p + 1, hence log2(p + 2).
        per_query_gain.append(1.0 / np.log2(hit_positions[0] + 2) if hit_positions.size > 0 else 0.0)
    return sum(per_query_gain) / len(gt_indices)



@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """Evaluate retrieval performance using dot-product (cosine for unit vectors) similarity.

    Queries are processed in consecutive blocks of ``batch_size``. Each query is ranked only
    among the image embeddings at the same positions of its own block (the "narrow gate"
    design of the original code), not against the whole gallery.

    Args:
        translated_embd: (N_captions, D) translated caption embeddings (tensor or ndarray).
        image_embd: (N_images, D) image embeddings (tensor or ndarray), indexed by ``gt_indices``.
        gt_indices: (N_captions,) ground-truth image index for each caption (tensor, ndarray or list).
        max_indices: maximum number of ranked predictions kept per query.
        batch_size: number of queries (and candidate images) per block.
    Returns:
        dict with 'mrr', 'ndcg', 'recall_at_{1,3,5,10,50}' and 'l2_dist' (mean L2 distance
        between each translated embedding and its ground-truth image embedding).
    """
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()
    # Keep ground truth as a CPU numpy array so it indexes/compares cleanly with numpy predictions.
    if isinstance(gt_indices, torch.Tensor):
        gt_np = gt_indices.detach().cpu().numpy()
    else:
        gt_np = np.asarray(gt_indices)

    n_queries = translated_embd.shape[0]

    all_sorted_indices = []
    l2_distances = []

    for start_idx in range(0, n_queries, batch_size):
        batch_slice = slice(start_idx, min(start_idx + batch_size, n_queries))
        batch_translated = translated_embd[batch_slice]
        batch_img_embd = image_embd[batch_slice]
        batch_gt = gt_np[batch_slice]

        # Similarity only among this block's queries and candidates.
        batch_similarity = batch_translated @ batch_img_embd.T

        # A short block (typically the last) can have fewer candidates than max_indices,
        # in which case topk(k=max_indices) would raise. Clamp k and pad with -1 (never a valid index).
        k = min(max_indices, batch_similarity.shape[1])
        top_local = batch_similarity.topk(k=k, dim=1, sorted=True).indices.cpu().numpy()
        batch_pred = np.full((top_local.shape[0], max_indices), -1, dtype=np.int64)
        batch_pred[:, :k] = batch_gt[top_local]
        all_sorted_indices.append(batch_pred)

        # L2 distance between each translated caption and its ground-truth image.
        gt_index_tensor = torch.as_tensor(batch_gt, device=image_embd.device)
        batch_gt_embeddings = image_embd[gt_index_tensor]
        l2_distances.append((batch_translated - batch_gt_embeddings).norm(dim=1))

    sorted_indices = np.concatenate(all_sorted_indices, axis=0)

    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }

    results = {
        name: func(sorted_indices, gt_np)
        for name, func in metrics.items()
    }

    results['l2_dist'] = torch.cat(l2_distances, dim=0).mean().item()

    return results

def eval_on_val(x_val: np.ndarray, y_val: np.ndarray, model: nn.Module, device) -> dict:
    """Translate validation captions with ``model`` and score retrieval against ``y_val``.

    Row i of ``x_val`` is paired with row i of ``y_val`` (ground truth ``arange(len(y_val))``).
    The model is switched to eval mode and left there.

    Args:
        x_val: (N, D_in) caption embeddings, as an ndarray or tensor.
        y_val: (N, D_out) matching image embeddings, as an ndarray or tensor.
        model: translator to evaluate.
        device: device the model lives on; inputs are moved there for the forward pass.
    Returns:
        The metrics dict produced by ``evaluate_retrieval``.
    """
    # Honour the ndarray type hint: numpy arrays have no `.to(device)`.
    if isinstance(x_val, np.ndarray):
        x_val = torch.from_numpy(x_val).float()

    gt_indices = torch.arange(len(y_val))

    model.eval()
    with torch.inference_mode():
        translated = model(x_val.to(device)).to('cpu')

    return evaluate_retrieval(translated, y_val, gt_indices)


def generate_submission(model: nn.Module, test_path: Path, output_file="submission-dirmodel.csv", device=None):
    """Translate the test caption embeddings with ``model`` and write a submission CSV.

    Args:
        model: trained translator. It is switched to eval mode (dropout off) for inference
            and its previous train/eval mode is restored afterwards.
        test_path: path to an ``.npz`` holding 'captions/ids' and 'captions/embeddings'.
        output_file: destination CSV path.
        device: device to run inference on. Defaults to the device of the model's first
            parameter, or CPU if the model has no parameters.
    Returns:
        The submission DataFrame with columns 'id' and 'embedding' (list of floats per row).
    """
    # Close the npz archive as soon as the arrays are read.
    with np.load(test_path) as test_data:
        sample_ids = test_data['captions/ids']
        test_embds = torch.from_numpy(test_data['captions/embeddings']).float()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Without eval() any Dropout layer would randomly mask the submitted embeddings.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred_embds = model(test_embds.to(device)).cpu().numpy()
    finally:
        model.train(was_training)

    print("Generating submission file...")

    df_submission = pd.DataFrame({'id': sample_ids, 'embedding': pred_embds.tolist()})

    df_submission.to_csv(output_file, index=False, float_format='%.17g')
    print(f"✓ Saved submission to {output_file}")

    return df_submission

In [47]:
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

def moco_info_nce_loss(q, k, queue, logit_scale):
    # Positivo
    l_pos = torch.sum(q * k, dim=1, keepdim=True)  # [batch,1]

    # Negativi
    l_neg = q @ queue.T  # [batch, queue_size]

    logit_scale = torch.clamp(logit_scale, min=np.log(0.01), max=np.log(100))

    logits = torch.cat([l_pos, l_neg], dim=1)
    logits = logits * logit_scale.exp()

    # Labels: positivo sempre in posizione 0
    labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)

    return F.cross_entropy(logits, labels)


@torch.no_grad()
def enqueue(queue, keys, queue_ptr):
    """Push ``keys`` into the circular ``queue`` in place and advance ``queue_ptr``.

    Key i of the batch lands in slot ``(ptr + i) % queue_size``, exactly as if the keys were
    pushed one at a time. When the batch is larger than the queue, only the newest
    ``queue_size`` keys survive (older ones would be overwritten anyway).

    Args:
        queue: (queue_size, D) tensor, modified in place.
        keys: (B, D) new entries.
        queue_ptr: 1-element long tensor with the next write position, modified in place.
    """
    batch_size = keys.shape[0]
    queue_size = queue.shape[0]
    ptr = int(queue_ptr[0])

    n_kept = min(batch_size, queue_size)
    first_kept = batch_size - n_kept
    # At most queue_size slots are written, so they are all distinct.
    slots = (ptr + first_kept + torch.arange(n_kept, device=queue.device)) % queue_size
    queue[slots] = keys[first_kept:].to(device=queue.device, dtype=queue.dtype)

    queue_ptr[0] = (ptr + batch_size) % queue_size

def train_model_moco(model, save_path, train_dataset, val_dataset, batch_size, epochs, lr, patience, queue_size):
    """Train ``model`` with a MoCo-style InfoNCE loss against a FIFO queue of image embeddings.

    At each step the batch's own image embeddings are the positives and the queued image
    embeddings are the negatives; afterwards the batch's image embeddings are enqueued.
    After every epoch the validation loss (computed against the current queue) and the
    retrieval metrics from ``test`` are printed. The LR is reduced on a validation-loss
    plateau, the lowest-validation-loss weights are saved to ``save_path``, and training
    stops after ``patience`` epochs without improvement.

    Args:
        model: translator exposing ``logit_scale``; moved to CUDA if available, else CPU.
        save_path: checkpoint path for the best state_dict (parent dirs are created).
        train_dataset / val_dataset: datasets of (text_embedding, image_embedding) pairs.
        batch_size, epochs, lr, patience: training hyper-parameters.
        queue_size: number of negatives kept in the queue.
    Returns:
        The model in the state of the last completed epoch (not reloaded from the best checkpoint).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    # The scheduler tracks the validation loss (to be minimised), hence mode='min'.
    scheduler = ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, threshold=0.003, min_lr=1e-5
    )

    # Queue width follows the target (image) embedding size of the data.
    target_dim = train_dataset[0][1].shape[-1]
    # NOTE(review): the queue starts zero-filled, so unfilled slots act as negatives with
    # similarity 0 until the queue has been filled once.
    queue = torch.zeros(queue_size, target_dim, device=device)
    queue_ptr = torch.zeros(1, dtype=torch.long, device=device)

    best_val_loss = float('inf')
    loss_no_improvements = 0

    for epoch in range(1, epochs+1):
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)

        for text_batch, image_emb_batch in progress_bar:
            text_batch, image_emb_batch = text_batch.to(device), image_emb_batch.to(device)

            optimizer.zero_grad()

            q = model(text_batch)
            k = image_emb_batch

            loss = moco_info_nce_loss(q, k, queue, model.logit_scale)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                enqueue(queue, k, queue_ptr)

            running_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_train_loss = running_loss / len(train_loader)

        # --- Validation (negatives come from the training queue) ---
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for text_batch, image_emb_batch in val_loader:
                text_batch, image_emb_batch = text_batch.to(device), image_emb_batch.to(device)

                q = model(text_batch)
                k = image_emb_batch

                loss = moco_info_nce_loss(q, k, queue, model.logit_scale)
                running_val_loss += loss.item()
        avg_val_loss = running_val_loss / len(val_loader)

        results = test(val_dataset, model, device)
        # Named val_mrr so it does not shadow the module-level `mrr` metric function.
        val_mrr = results['mrr']

        scheduler.step(avg_val_loss)

        print(f"Epoch {epoch:03d} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f} | "
              f"MRR: {val_mrr:.6f} | Recall-1: {results['recall_at_1']:.6f} | LR: {optimizer.param_groups[0]['lr']:.2e}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            loss_no_improvements = 0

            # Checkpoint the weights with the best validation loss so far.
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), save_path)
            print(f"💾 Saved new best model (Val Loss={avg_val_loss:.6f})")
        else:
            loss_no_improvements += 1
            if loss_no_improvements >= patience:
                print("⏹ Early stopping triggered based on Validation Loss.")
                break

    print("✅ Training complete")
    return model

def get_data(data_path: Path):
    """Load caption embeddings and, row-aligned, the embedding of each caption's image.

    Args:
        data_path: path to an ``.npz`` with 'captions/embeddings', 'images/embeddings' and
            'captions/label'. The label matrix has one row per caption and one column per
            image; the argmax column is taken as the matching image (presumably one-hot).
    Returns:
        (X_abs, y_abs): caption embeddings and matched image embeddings, as tensors.
    """
    # The context manager closes the archive even if a key lookup fails.
    with np.load(data_path) as data:
        caption_embeddings = data['captions/embeddings']
        image_embeddings = data['images/embeddings']
        caption_labels = data['captions/label']

    matched_images = image_embeddings[np.argmax(caption_labels, axis=1)]
    X_abs, y_abs = torch.tensor(caption_embeddings), torch.tensor(matched_images)

    return X_abs, y_abs

def get_datasets(X_abs, y_abs, split_fractions=(0.8, 0.2), seed: int = 42) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset]:
    """Wrap (X_abs, y_abs) in a TensorDataset and split it reproducibly into train/val.

    Args:
        X_abs: input tensor; first dimension is the number of samples.
        y_abs: target tensor, row-aligned with ``X_abs``.
        split_fractions: (train, val) fractions passed to ``random_split``.
        seed: seed of the generator driving the split.
    Returns:
        (train_dataset, val_dataset) as ``Subset`` objects (not tensors).
    """
    print('Texts shape', X_abs.shape)
    print('Images shape', y_abs.shape)

    dataset = TensorDataset(X_abs, y_abs)
    train_dataset, val_dataset = random_split(
        dataset, list(split_fractions), generator=torch.Generator().manual_seed(seed)
    )

    return train_dataset, val_dataset

def test(val_dataset: TensorDataset, model: nn.Module, device):
    """Score ``model`` on the whole of ``val_dataset`` as one batch, via ``eval_on_val``.

    Returns:
        The metrics dict produced by ``eval_on_val``.
    """
    full_batch_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
    x_val, y_val = next(iter(full_batch_loader))
    return eval_on_val(x_val, y_val, model=model, device=device)

def center(X: torch.Tensor):
    """Subtract the per-column mean.

    Returns:
        (centered, mean): ``X`` minus its column mean, and that mean with shape (1, D).
    """
    column_mean = torch.mean(X, dim=0, keepdim=True)
    centered = X - column_mean
    return centered, column_mean

In [26]:
# Competition data (Kaggle input mount) and the checkpoint location.
data_path = '/kaggle/input/aml-competition/train/train/train.npz'
test_path = '/kaggle/input/aml-competition/test/test/test.clean.npz'
save_path = './models/dir-model.pth'

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [31]:
# Caption embeddings and, row-aligned, the embedding of each caption's matching image.
x, y = get_data(data_path)

# Seeded train/validation split of the (caption, image) pairs.
train_dataset, val_dataset = get_datasets(x, y)

Texts shape torch.Size([125000, 1024])
Images shape torch.Size([125000, 1536])


In [49]:
# --- Model / training hyper-parameters ---
input_dim = x.shape[1]
output_dim = y.shape[1]
hidden_layers = [1256, 2048]
dropout_rate = 0.5
batch_size = 256
lr = 0.01
epochs = 250
patience = 5

# Number of queued negatives for the MoCo loss.
queue_size = 10000

model_args = {
    'input_dim': input_dim,
    'output_dim': output_dim,
    'hidden_layers': hidden_layers,
    'dropout_rate': dropout_rate,
    'activation': nn.GELU
}

model = SpaceTranslator(**model_args)

train_model_moco(model, save_path, train_dataset, val_dataset, batch_size, epochs, lr, patience, queue_size)

print('Finished training. Now testing using best model...')

# Reload the best checkpoint; map_location keeps this working if it was written on another device.
state = torch.load(save_path, map_location=device)
model.load_state_dict(state)
results = test(val_dataset, model, device)
print("Test Results:", results)

                                                                         

Epoch 001 | Train Loss: 8.727778 | Val Loss: 5.525506 | MRR: 0.725039 | Recall-1: 0.590840 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.725039)


                                                                         

Epoch 002 | Train Loss: 5.365557 | Val Loss: 4.570784 | MRR: 0.837310 | Recall-1: 0.743840 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.837310)


                                                                         

Epoch 003 | Train Loss: 4.706854 | Val Loss: 4.170775 | MRR: 0.870314 | Recall-1: 0.793680 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.870314)


                                                                         

Epoch 004 | Train Loss: 4.321810 | Val Loss: 3.900128 | MRR: 0.887549 | Recall-1: 0.819160 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.887549)


                                                                         

Epoch 005 | Train Loss: 4.050368 | Val Loss: 3.720185 | MRR: 0.898636 | Recall-1: 0.837000 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.898636)


                                                                         

Epoch 006 | Train Loss: 3.821990 | Val Loss: 3.571072 | MRR: 0.906888 | Recall-1: 0.849640 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.906888)


                                                                         

Epoch 007 | Train Loss: 3.631227 | Val Loss: 3.457870 | MRR: 0.912080 | Recall-1: 0.857120 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.912080)


                                                                         

Epoch 008 | Train Loss: 3.470839 | Val Loss: 3.378798 | MRR: 0.915122 | Recall-1: 0.862640 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.915122)


                                                                         

Epoch 009 | Train Loss: 3.333408 | Val Loss: 3.284214 | MRR: 0.919675 | Recall-1: 0.869840 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.919675)


                                                                          

Epoch 010 | Train Loss: 3.199440 | Val Loss: 3.215446 | MRR: 0.922864 | Recall-1: 0.874840 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.922864)


                                                                          

Epoch 011 | Train Loss: 3.080352 | Val Loss: 3.192078 | MRR: 0.923270 | Recall-1: 0.874840 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.923270)


                                                                          

Epoch 012 | Train Loss: 2.976312 | Val Loss: 3.119946 | MRR: 0.927357 | Recall-1: 0.881920 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.927357)


                                                                          

Epoch 013 | Train Loss: 2.867031 | Val Loss: 3.097158 | MRR: 0.928782 | Recall-1: 0.884200 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.928782)


                                                                          

Epoch 014 | Train Loss: 2.774466 | Val Loss: 3.060852 | MRR: 0.930447 | Recall-1: 0.886720 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.930447)


                                                                          

Epoch 015 | Train Loss: 2.683679 | Val Loss: 3.018417 | MRR: 0.931809 | Recall-1: 0.889000 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.931809)


                                                                          

Epoch 016 | Train Loss: 2.593929 | Val Loss: 2.993433 | MRR: 0.932717 | Recall-1: 0.890920 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.932717)


                                                                          

Epoch 017 | Train Loss: 2.516491 | Val Loss: 2.986677 | MRR: 0.933537 | Recall-1: 0.892760 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.933537)


                                                                          

Epoch 018 | Train Loss: 2.445432 | Val Loss: 2.973553 | MRR: 0.933315 | Recall-1: 0.891440 | LR: 1.00e-02


                                                                          

Epoch 019 | Train Loss: 2.368953 | Val Loss: 2.946566 | MRR: 0.934461 | Recall-1: 0.893560 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.934461)


                                                                          

Epoch 020 | Train Loss: 2.308971 | Val Loss: 2.932775 | MRR: 0.934600 | Recall-1: 0.893960 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.934600)


                                                                          

Epoch 021 | Train Loss: 2.239273 | Val Loss: 2.912142 | MRR: 0.935215 | Recall-1: 0.894720 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.935215)


                                                                          

Epoch 022 | Train Loss: 2.173672 | Val Loss: 2.896063 | MRR: 0.936930 | Recall-1: 0.897880 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.936930)


                                                                          

Epoch 023 | Train Loss: 2.119078 | Val Loss: 2.897631 | MRR: 0.936983 | Recall-1: 0.898040 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.936983)


                                                                          

Epoch 024 | Train Loss: 2.072061 | Val Loss: 2.890368 | MRR: 0.938616 | Recall-1: 0.900800 | LR: 1.00e-02
üíæ Saved new best model (MRR=0.938616)


                                                                          

Epoch 025 | Train Loss: 2.010360 | Val Loss: 2.869364 | MRR: 0.937530 | Recall-1: 0.899320 | LR: 1.00e-02


                                                                          

Epoch 026 | Train Loss: 1.961379 | Val Loss: 2.903009 | MRR: 0.937849 | Recall-1: 0.899280 | LR: 1.00e-02


                                                                          

Epoch 027 | Train Loss: 1.916301 | Val Loss: 2.864191 | MRR: 0.937485 | Recall-1: 0.899280 | LR: 1.00e-02


                                                                          

Epoch 028 | Train Loss: 1.871107 | Val Loss: 2.867248 | MRR: 0.938820 | Recall-1: 0.901320 | LR: 5.00e-03
üíæ Saved new best model (MRR=0.938820)


                                                                          

Epoch 029 | Train Loss: 1.596607 | Val Loss: 2.851360 | MRR: 0.941144 | Recall-1: 0.905040 | LR: 5.00e-03
üíæ Saved new best model (MRR=0.941144)


                                                                          

Epoch 030 | Train Loss: 1.497734 | Val Loss: 2.861686 | MRR: 0.941666 | Recall-1: 0.906040 | LR: 5.00e-03
üíæ Saved new best model (MRR=0.941666)


                                                                          

Epoch 031 | Train Loss: 1.452512 | Val Loss: 2.872866 | MRR: 0.941747 | Recall-1: 0.906160 | LR: 5.00e-03
üíæ Saved new best model (MRR=0.941747)


                                                                          

Epoch 032 | Train Loss: 1.409349 | Val Loss: 2.909070 | MRR: 0.940283 | Recall-1: 0.904240 | LR: 2.50e-03


                                                                          

Epoch 033 | Train Loss: 1.279878 | Val Loss: 2.898021 | MRR: 0.942009 | Recall-1: 0.906920 | LR: 2.50e-03
üíæ Saved new best model (MRR=0.942009)


                                                                          

Epoch 034 | Train Loss: 1.216193 | Val Loss: 2.897232 | MRR: 0.941551 | Recall-1: 0.905920 | LR: 2.50e-03
‚èπ Early stopping triggered based on Validation Loss.
‚úÖ Training complete. Best MRR: 0.942009
Finished training. Now testing using best model...
Test Results: {'mrr': 0.942009235100382, 'ndcg': 0.9561381249868742, 'recall_at_1': 0.90692, 'recall_at_3': 0.97312, 'recall_at_5': 0.98604, 'recall_at_10': 0.9942, 'recall_at_50': 0.9994, 'l2_dist': 25.848154067993164}


In [50]:
submission_df = generate_submission(model, Path(test_path), output_file="mojo-pin.csv", device=device)
submission_df

Generating submission file...
‚úì Saved submission to mojo-pin.csv


Unnamed: 0,id,embedding
0,1,"[0.02428484708070755, 0.027138926088809967, -0..."
1,2,"[-0.006308969110250473, -0.005466327536851168,..."
2,3,"[0.011657739989459515, -0.03169963136315346, 0..."
3,4,"[0.060520488768815994, -0.02953926846385002, -..."
4,5,"[0.06194249540567398, 0.04723084717988968, 0.0..."
...,...,...
1495,1496,"[0.010205189697444439, -0.0313861183822155, 0...."
1496,1497,"[0.006104839034378529, 0.04210679605603218, 0...."
1497,1498,"[0.03780222684144974, -0.0013050471898168325, ..."
1498,1499,"[-0.01464053988456726, 0.003734411671757698, 0..."


In [51]:
# Persist the model's current weights.
torch.save(obj=model.state_dict(), f="model_weights.pth")

In [None]:
import optuna
from optuna.pruners import MedianPruner

def objective(trial, train_dataset, val_dataset, epochs: int = 15, device=None):
    """Optuna objective: train a ``SpaceTranslator`` with sampled hyper-parameters.

    Args:
        trial: Optuna trial used to sample hyper-parameters and report per-epoch MRR.
        train_dataset: dataset of (caption_embedding, image_embedding) pairs for training.
        val_dataset: dataset of the same form, used for the validation loss and MRR.
        epochs: maximum number of training epochs.
        device: torch device; defaults to CUDA when available, else CPU.
    Returns:
        The best validation MRR observed over the completed epochs.
    Raises:
        optuna.exceptions.TrialPruned: when the pruner stops the trial early.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- MLP hyper-parameters ---
    n_dir_layers = trial.suggest_int("n_dir_layers", 2, 4)
    # NOTE(review): 1546 may be a typo for 1536 — confirm.
    layer_choices = [1024, 1152, 1546, 2048, 4096]
    dir_hidden_dims = [trial.suggest_categorical(f"dir_units_l{i}", layer_choices) for i in range(n_dir_layers)]

    # --- Training hyper-parameters ---
    batch_size = trial.suggest_categorical("batch_size", [64, 512, 1024, 2048, 4096])
    lr = trial.suggest_float("lr", 1e-3, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    dropout_rate = trial.suggest_categorical("dropout_rate", [0.3, 0.4, 0.5])
    activation_name = 'GELU'
    activation = nn.GELU

    # --- DataLoaders ---
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # --- Model ---
    # Dimensions come from the data. SpaceTranslator takes a single `hidden_layers` list;
    # it has no separate direction/scale branches.
    sample_x, sample_y = train_dataset[0]
    model_args = {
        'input_dim': sample_x.shape[-1],
        'output_dim': sample_y.shape[-1],
        'hidden_layers': dir_hidden_dims,
        'dropout_rate': dropout_rate,
        'activation': activation
    }
    model = SpaceTranslator(**model_args).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=1, threshold=0.01, min_lr=1e-7
    )

    best_mrr = float('-inf')
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            output = model(X_batch)

            # NOTE(review): `info_nce_loss` is not defined in the visible cells; it must exist in the
            # kernel with signature (pred, target, logit_scale) — TODO confirm.
            loss = info_nce_loss(output, y_batch, model.logit_scale)

            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        # --- Validation ---
        model.eval()
        val_loss = 0.0
        with torch.inference_mode():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)

                output = model(X_batch)
                loss = info_nce_loss(output, y_batch, model.logit_scale)
                val_loss += loss.item()
        val_loss /= len(val_loader)

        # --- MRR (named val_mrr so it does not shadow the module-level `mrr` function) ---
        results = test(val_dataset, model, device)
        val_mrr = results['mrr']
        best_mrr = max(best_mrr, val_mrr)

        scheduler.step(val_mrr)

        trial.report(val_mrr, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

        print(f"[Trial {trial.number} | Epoch {epoch+1}/{epochs}] "
              f"Train Loss: {train_loss:.4f} Val Loss: {val_loss:.4f} "
              f"MRR: {val_mrr:.4f} Activation: {activation_name} | LR: {optimizer.param_groups[0]['lr']:.2e}")

    return best_mrr


def run_optuna_search(data_path: Path, n_trials: int = 150, epochs: int = 30, n_jobs: int = 1, sampler=None, pruner=None):
    """Run an Optuna study that maximises the validation MRR returned by ``objective``.

    Args:
        data_path: training ``.npz`` loaded with ``get_data`` and split with ``get_datasets``.
        n_trials: number of trials to run.
        epochs: maximum epochs per trial.
        n_jobs: parallel trials passed to ``study.optimize``.
        sampler: Optuna sampler; ``None`` lets Optuna use its default sampler.
        pruner: Optuna pruner; defaults to ``MedianPruner(n_startup_trials=5, n_warmup_steps=1)``.
    Returns:
        The completed ``optuna.Study``.
    """
    if pruner is None:
        pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=1)

    X, y = get_data(data_path)
    train_dataset, val_dataset = get_datasets(X, y)

    study = optuna.create_study(direction="maximize", sampler=sampler, pruner=pruner)
    func = lambda trial: objective(trial, train_dataset=train_dataset, val_dataset=val_dataset, epochs=epochs)
    study.optimize(func, n_trials=n_trials, n_jobs=n_jobs)

    print("Study statistics:")
    print("  Number of finished trials: ", len(study.trials))
    print("  Best trial:")
    trial = study.best_trial
    print("    Value: ", trial.value)
    print("    Params: ")
    for k, v in trial.params.items():
        print(f"      {k}: {v}")

    return study

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm

def train_model_moco_full(model, save_path, full_dataset, batch_size, epochs, lr, queue_size):
    """Full-dataset variant of ``train_model_moco`` (e.g. train + val merged).

    Differences from ``train_model_moco``:
    - no validation pass, no LR scheduler and no early stopping;
    - runs exactly ``epochs`` epochs and saves the final weights once, at the end.

    Args:
        model: translator exposing ``logit_scale``; moved to CUDA if available, else CPU.
        save_path: where the final state_dict is written (parent dirs are created).
        full_dataset: dataset of (text_embedding, image_embedding) pairs.
        batch_size, epochs, lr: training hyper-parameters.
        queue_size: number of negatives kept in the queue.
    Returns:
        The trained model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    train_loader = DataLoader(full_dataset, batch_size=batch_size, shuffle=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    # Queue width follows the target (image) embedding size of the data; it starts zero-filled.
    target_dim = full_dataset[0][1].shape[-1]
    queue = torch.zeros(queue_size, target_dim, device=device)
    queue_ptr = torch.zeros(1, dtype=torch.long, device=device)

    print("=== FULL-DATASET MoCo TRAINING (NO VALIDATION) ===")
    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"[FULL] Epoch {epoch}/{epochs}", leave=False)
        for text_batch, image_emb_batch in progress_bar:
            text_batch, image_emb_batch = text_batch.to(device), image_emb_batch.to(device)

            optimizer.zero_grad()

            # Text projection (query) and its corresponding image embedding (positive key).
            q = model(text_batch)
            k = image_emb_batch

            loss = moco_info_nce_loss(q, k, queue, model.logit_scale)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                enqueue(queue, k, queue_ptr)

            running_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_train_loss = running_loss / len(train_loader)
        print(f"[FULL] Epoch {epoch:03d} | Train Loss: {avg_train_loss:.6f}")

    # Save the final weights (there is no "best" epoch without validation).
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"💾 Saved full-dataset MoCo model to {save_path}")

    return model


In [None]:
from torch.utils.data import TensorDataset

# Union of the train and validation pairs.
full_dataset = TensorDataset(x, y)

# Start from the best train/val checkpoint.
state = torch.load(save_path, map_location=device)
model.load_state_dict(state)
model.to(device)

# No validation: the model is trained on the full dataset.
save_path_full = "./models/dir-model-moco-full.pth"

epochs_full = 50

model = train_model_moco_full(
    model,
    save_path_full,
    full_dataset,
    batch_size=batch_size,
    epochs=epochs_full,
    lr=lr,
    queue_size=queue_size
)

# val_dataset is a subset of full_dataset, so the metrics below are NOT a held-out estimate.
print("Finished FULL-DATASET training. Now testing on the validation set "
      "(NOTE: it was part of the training data, so these metrics are optimistic)...")

state_full = torch.load(save_path_full, map_location=device)
model.load_state_dict(state_full)
model.to(device)

results_full = test(val_dataset, model, device)
print("Test Results after full-dataset training:", results_full)
print(f"MRR: {results_full['mrr']:.6f} | Recall@1: {results_full['recall_at_1']:.6f}")

generate_submission(
    model,
    Path(test_path),
    output_file="mojo-pin-full.csv",
    device=device
)
print("new submission: mojo-pin-full.csv")
