# Challenge Advanced Machine Learning
### Group STM: Luca Moresca, Valerio Santini, Nicholas Suozzi

## Project Overview

This project addresses the AML Challenge on **text → latent visual space** mapping, with the goal of producing textual embeddings that match the image embeddings generated by a predefined VAE.
The architecture follows a VSE++ approach enriched with a multi-slot mechanism to achieve finer-grained alignment. The main model (TextToVis) projects the textual embedding into the visual space, while an auxiliary module (SlotAuxHead) generates multiple “slot” vectors, useful only during training to improve discrimination and internal diversity of representations.

The training uses a combination of contrastive losses:
- **Global triplet loss** (classic VSE++)
- **Slot-based triplet loss** (max-over-slot)
- **InfoNCE on slots**  
- **ISDL** to prevent collapse between slots  
- **Condensation loss** to maintain consistency between global and slot embeddings  

This configuration allows for more robust and accurate alignment, reducing dependence on a single global vector and leveraging more granular signals.
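
Read off the training loop further down in this notebook (with USE_SLOT_TRIPLET enabled), the combined objective can be written compactly as below. The symbol names are shorthand introduced here for illustration; the weights correspond to the constants LAMBDA_TRI_GLOBAL, LAMBDA_SLOT, LAMBDA_ISDL and LAMBDA_COND.

$$
\mathcal{L} = \mathcal{L}_{\text{tri}}^{\text{slot}}
+ \lambda_{g}\,\mathcal{L}_{\text{tri}}^{\text{global}}
+ \lambda_{s}\,\mathcal{L}_{\text{NCE}}^{\text{slot}}
+ \lambda_{d}\,\mathcal{L}_{\text{ISDL}}
+ \lambda_{c}\,\mathcal{L}_{\text{cond}},
\qquad
\lambda_{g}=0.30,\ \lambda_{s}=0.15,\ \lambda_{d}=0.02,\ \lambda_{c}=0.05
$$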

The pipeline includes:
- a stable image-level split, so that validation is free of train/validation leakage
- saving checkpoints for each epoch and selecting the best one based on validation metrics
- qualitative analysis through visual retrieval
- an ensemble strategy of models trained with different seeds, used in inference to increase stability and performance

The result is a system optimised for the retrieval metrics used in validation (MRR, NDCG, recall@k), combining classic metric learning techniques with multi-slot representations and diversity regularisation.

In [None]:
from pathlib import Path
import random
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import numpy as np
import hashlib
from scipy import sparse

from challenge.src.common import load_data, prepare_train_data, generate_submission
from challenge.src.eval.metrics import mrr, recall_at_k, ndcg
from challenge.src.eval.visualize import visualize_retrieval

RUN_ID = 3

MODEL_PATH = f"models/vsepp_text2vis_best_seed{RUN_ID}.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 24
BATCH_SIZE = 128
LR_INIT = 2e-4
LR_LATE = 2e-5
LR_STEP_EPOCH = 16
WEIGHT_DECAY = 0.0
MARGIN = 0.20

K_SLOTS_AUX = 4
TAU_SLOT = 0.09
LAMBDA_SLOT = 0.15

USE_SLOT_TRIPLET = True
LAMBDA_TRI_GLOBAL = 0.30

LAMBDA_ISDL = 0.02
ISDL_DELTA = 0.55

LAMBDA_COND = 0.05

SEED = 42 + RUN_ID
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE.type == "cuda":
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

### Image-based split train/validation

The dataset is split using an image-level split, so that all captions for the same image go into the same set.

- The embeddings are loaded (X for text, y for image).
- Each image is assigned to validation using a stable MD5 hash.
- Captions inherit the split of their image.
- Dedicated DataLoaders are built for training/validation.
- For validation, the image gallery (val_img_embd) and val_label mapping, necessary for calculating the MRR, are also prepared.

The resulting split is reproducible across runs and keeps every image, together with all of its captions, on one side only, which avoids leakage between training and validation.
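
As a minimal, standalone sketch of the idea, the snippet below applies the same MD5-based hash to made-up image names and lets captions inherit the split of their image. The names and the 5-captions-per-image layout are hypothetical and only serve the illustration.

```python
import hashlib

def stable_hash(name: str) -> float:
    """Map a string to a deterministic value in [0, 1] using the first 32 bits of its MD5 digest."""
    h = hashlib.md5(name.encode("utf-8")).hexdigest()
    return int(h[:8], 16) / 0xFFFFFFFF

# Hypothetical image names and caption-to-image indices, for illustration only.
image_names = [f"img_{i:04d}.jpg" for i in range(1000)]
caption_to_image = [i // 5 for i in range(5000)]  # 5 captions per image

val_ratio = 0.10
img_is_val = [stable_hash(n) < val_ratio for n in image_names]
cap_is_val = [img_is_val[j] for j in caption_to_image]  # captions inherit the image split

# The assignment depends only on the name, so recomputing it gives the same split.
assert img_is_val == [stable_hash(n) < val_ratio for n in image_names]
print(f"validation images: {sum(img_is_val)} / {len(image_names)}")
print(f"validation captions: {sum(cap_is_val)} / {len(caption_to_image)}")
```

Because the hash depends only on the file name, adding or removing other images does not move an existing image from one side of the split to the other.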

In [None]:
train_data = load_data("data/train/train.npz")
X, y, label = prepare_train_data(train_data)

img_names_all = train_data['images/names']
img_emb_all = torch.from_numpy(train_data['images/embeddings']).float()

val_ratio = 0.10
def stable_hash(name: str) -> float:
    h = hashlib.md5(name.encode('utf-8')).hexdigest()
    return int(h[:8], 16) / 0xFFFFFFFF

img_hash = np.array([stable_hash(nm) for nm in img_names_all])
IMG_VAL_MASK = (img_hash < val_ratio)
IMG_TRAIN_MASK = ~IMG_VAL_MASK

cap_to_img = train_data['captions/label']
if sparse.issparse(cap_to_img):
    cap_gt_img_idx = cap_to_img.argmax(axis=1).A1
else:
    cap_gt_img_idx = np.argmax(cap_to_img, axis=1)

CAP_TRAIN_MASK = IMG_TRAIN_MASK[cap_gt_img_idx]
CAP_VAL_MASK = IMG_VAL_MASK[cap_gt_img_idx]

assert not (IMG_TRAIN_MASK & IMG_VAL_MASK).any(), "Train/val image overlap > 0"

X_train = X[CAP_TRAIN_MASK]
y_train = y[CAP_TRAIN_MASK]
X_val = X[CAP_VAL_MASK]
y_val = y[CAP_VAL_MASK]

print(f"Train captions: {len(X_train):,} | Val captions: {len(X_val):,}")
print(f"Train images: {int(IMG_TRAIN_MASK.sum()):,} | Val images: {int(IMG_VAL_MASK.sum()):,}")

train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val,   y_val)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=(DEVICE.type=='cuda'))
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=(DEVICE.type=='cuda'))

val_img_embd = F.normalize(img_emb_all[torch.from_numpy(IMG_VAL_MASK)], dim=-1).cpu()
val_img_file = img_names_all[IMG_VAL_MASK]

global_to_val = -np.ones(len(img_names_all), dtype=np.int64)
global_to_val[np.where(IMG_VAL_MASK)[0]] = np.arange(IMG_VAL_MASK.sum(), dtype=np.int64)
val_gt_global = cap_gt_img_idx[CAP_VAL_MASK]
val_label = global_to_val[val_gt_global].astype(np.int64)

### Main model and auxiliary slot head

**TextToVis**  
This is the main adapter: a small MLP that projects the textual embedding into visual space (d_vis).  
The output is normalised to use cosine similarity during training and retrieval.

**SlotAuxHead**  
Auxiliary head that generates K vectors in visual space from the same textual embedding.  
These slots are only used during training (InfoNCE, triplet on slot, ISDL) to provide a richer and more fine-grained signal, but are not used in inference.

Both networks are lightweight and completely separate:  
- model produces the final vector to be used in submission  
- slot_head produces the slots used as regularisation during training.
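
The slot head's final step, splitting one flat output into K vectors and normalising each of them separately, can be sketched as follows. NumPy is used here instead of PyTorch purely to keep the example small; the random array stands in for the output of the last linear layer and the sizes are toy values.

```python
import numpy as np

rng = np.random.default_rng(0)
B, K, d_vis = 2, 4, 8  # toy sizes, chosen only for this illustration

# Stand-in for the output of the slot head's last linear layer: shape (B, K * d_vis).
flat = rng.normal(size=(B, K * d_vis))

# Split the flat vector into K slots, then L2-normalise each slot independently.
slots = flat.reshape(B, K, d_vis)
slots = slots / np.linalg.norm(slots, axis=-1, keepdims=True)

print(slots.shape)                      # (2, 4, 8)
print(np.linalg.norm(slots, axis=-1))   # every entry is 1.0 up to rounding
```

Normalising along the last axis means each slot lives on the unit sphere on its own, so slot-to-image dot products behave like cosine similarities when the image embeddings are normalised too.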

In [None]:
class TextToVis(nn.Module):
    def __init__(self, d_text=1024, d_vis=1536, hidden=2048):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(d_text, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_vis, bias=True)
        )
    def forward(self, t):
        z = self.proj(t)
        return F.normalize(z, dim=-1)

class SlotAuxHead(nn.Module):
    """Auxiliary head: produces K slot vectors for the slot-level losses (InfoNCE, slot triplet, ISDL)."""
    def __init__(self, d_text=1024, d_vis=1536, K=4, hidden=1024):
        super().__init__()
        self.K = K
        self.ff = nn.Sequential(
            nn.Linear(d_text, hidden), nn.GELU(),
            nn.Linear(hidden, K * d_vis, bias=True)
        )
        self.d_vis = d_vis
    def forward(self, t):
        B = t.size(0)
        S = self.ff(t).view(B, self.K, self.d_vis)
        return F.normalize(S, dim=-1)

model = TextToVis(d_text=X.shape[1], d_vis=y.shape[1], hidden=2048).to(DEVICE)
slot_head = SlotAuxHead(d_text=X.shape[1], d_vis=y.shape[1], K=K_SLOTS_AUX, hidden=1024).to(DEVICE)

### Loss functions

This section lists all the losses used to guide the alignment between textual and visual embeddings. The standard triplet (triplet_hard_neg_cos) works on the global vector Z and follows the VSE++ setting: each positive pair is compared with the hardest negative, i.e. the most similar non-matching image in the batch. The triplet_hard_neg_slot_max variant extends the same logic to slots: the positive score is the maximum similarity over the K slots with the matching image, and the negative score is the maximum over all slots and all other images. Because a single well-aligned slot is enough to satisfy the margin, this variant is more sensitive to details and is intended to improve rank-1 retrieval.
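
A small worked example may help here. The sketch below re-implements both triplet variants in NumPy on a made-up similarity matrix; the function names and numbers are illustrative and not part of the notebook's code.

```python
import numpy as np

def triplet_hard_negative(sim: np.ndarray, margin: float) -> float:
    """sim[i, j]: similarity between text i and image j; positives lie on the diagonal."""
    pos = np.diag(sim)
    masked = sim.copy()
    np.fill_diagonal(masked, -np.inf)   # exclude the positive pair
    hardest_neg = masked.max(axis=1)    # most similar wrong image per text
    return float(np.maximum(0.0, margin + hardest_neg - pos).mean())

def triplet_hard_negative_slots(sim_slots: np.ndarray, margin: float) -> float:
    """sim_slots[i, k, j]: similarity between slot k of text i and image j."""
    B = sim_slots.shape[0]
    idx = np.arange(B)
    pos = sim_slots[idx, :, idx].max(axis=1)         # best slot for the matching image
    masked = sim_slots.copy()
    masked[idx, :, idx] = -np.inf                    # hide the positive for every slot
    hardest_neg = masked.reshape(B, -1).max(axis=1)  # max over slots and other images
    return float(np.maximum(0.0, margin + hardest_neg - pos).mean())

# Toy similarity matrix for a batch of 3 (made-up numbers).
sim = np.array([
    [0.9, 0.2, 0.1],
    [0.5, 0.6, 0.3],
    [0.0, 0.1, 0.8],
])
# Only the second row violates the margin (0.2 + 0.5 - 0.6 = 0.1), so the mean is about 0.0333.
print(triplet_hard_negative(sim, margin=0.2))

# With a single slot the slot variant reduces to the global one.
print(triplet_hard_negative_slots(sim[:, None, :], margin=0.2))
```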

To support this mechanism, *condensation loss* is introduced, which pushes Z closer to the slot that best matches its image, thus transferring the fine-grained accuracy of the slots to the final vector used in inference. loss_slot_ce, on the other hand, provides a direct contrastive signal for each slot, exploiting in-batch negatives.
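
Written out for a batch of $B$ pairs, with $s_{i,k}$ the $k$-th slot of caption $i$, $z_i$ its global vector and $y_i$ the matching image embedding, these two losses take roughly the following form. The notation is introduced here for illustration; the inner product equals the cosine similarity only when $y_i$ has unit norm.

$$
k_i^{*} = \arg\max_{k}\,\langle s_{i,k}, y_i\rangle,
\qquad
\mathcal{L}_{\text{cond}} = \frac{1}{B}\sum_{i=1}^{B}\bigl(1 - \cos(z_i, s_{i,k_i^{*}})\bigr)
$$

$$
\mathcal{L}_{\text{NCE}}^{\text{slot}} = -\frac{1}{KB}\sum_{k=1}^{K}\sum_{i=1}^{B}
\log\frac{\exp\bigl(\langle s_{i,k}, y_i\rangle/\tau\bigr)}{\sum_{j=1}^{B}\exp\bigl(\langle s_{i,k}, y_j\rangle/\tau\bigr)}
$$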

Finally, isdl_hinge keeps the slots diversified by penalising those that are too similar to each other, stabilising multi-slot learning and ensuring that the information captured is not redundant.
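
To make the hinge behaviour concrete, here is a NumPy sketch of the same penalty on two hand-built cases: two identical slots and two orthogonal slots. It mirrors the logic described above but is an illustration, not the notebook's PyTorch function.

```python
import numpy as np

def isdl_hinge_np(slots: np.ndarray, delta: float) -> float:
    """slots: (B, K, D), unit-norm along the last axis. Penalises pairwise slot cosines above delta."""
    B, K, _ = slots.shape
    gram = slots @ slots.transpose(0, 2, 1)      # (B, K, K) pairwise cosines
    off_diag = gram[:, ~np.eye(K, dtype=bool)]   # (B, K * (K - 1)), self-similarities removed
    return float(np.maximum(0.0, off_diag - delta).mean())

# One sample, two slots in 2-D.
collapsed = np.array([[[1.0, 0.0], [1.0, 0.0]]])  # identical slots, cosine 1
diverse = np.array([[[1.0, 0.0], [0.0, 1.0]]])    # orthogonal slots, cosine 0

print(isdl_hinge_np(collapsed, delta=0.55))  # approximately 0.45
print(isdl_hinge_np(diverse, delta=0.55))    # 0.0
```

Slots whose cosine stays below delta contribute nothing, so the term only acts when slots start to collapse onto each other.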


In [None]:
def triplet_hard_neg_cos(Z, Y, margin):
    """Standard VSE++ triplet loss on the global vector Z."""
    S = Z @ Y.t()
    pos = S.diag()
    B = S.size(0)
    S = S.clone()
    S[torch.arange(B), torch.arange(B)] = float('-inf')
    neg = S.max(dim=1).values
    return F.relu(margin + neg - pos).mean()

def triplet_hard_neg_slot_max(S_T, Y, margin):
    """MPAS-lite: triplet loss with hard negatives, using the max-over-slot similarity."""
    B, K, D = S_T.shape
    idx = torch.arange(B, device=S_T.device)
    sims = torch.einsum('bkd,nd->bkn', S_T, Y)     # (B, K, B)
    # Positive: best slot similarity with the matching image, i.e. entries [i, k, i].
    pos_max = sims[idx, :, idx].max(dim=1).values  # (B,)
    # Mask only the positive image of each query, then take the hardest negative
    # over all slots and all remaining images.
    sims = sims.clone()
    sims[idx, :, idx] = float('-inf')
    neg_max = sims.reshape(B, -1).max(dim=1).values

    return F.relu(margin + neg_max - pos_max).mean()

def loss_condensation(Z, S_T, Y):
    """Condensation: pulls Z towards the slot that best aligns with Y."""
    sims_ty = torch.einsum('bkd,bd->bk', S_T, Y)
    best_idx = sims_ty.max(dim=1).indices
    s_star = S_T[torch.arange(S_T.size(0), device=S_T.device), best_idx]
    z_n = F.normalize(Z, dim=-1)
    s_n = F.normalize(s_star, dim=-1)
    cos_zs = torch.sum(z_n * s_n, dim=1)

    return (1.0 - cos_zs).mean()

def loss_slot_ce(S_T, Y, tau=0.07):
    """Per-slot cross-entropy / InfoNCE with in-batch negatives."""
    B, K, D = S_T.shape
    Yt = Y.t()
    losses = []
    target = torch.arange(B, device=S_T.device)
    for k in range(K):
        logits = (S_T[:, k, :] @ Yt) / tau
        losses.append(F.cross_entropy(logits, target))
    return torch.stack(losses).mean()

def isdl_hinge(S_T, delta=0.55):
    """Penalises only inter-slot similarities greater than delta."""
    B, K, D = S_T.shape
    if K <= 1:
        return S_T.new_zeros(())
    C = (S_T @ S_T.transpose(1, 2)).clamp(-1, 1)
    mask = ~torch.eye(K, device=S_T.device, dtype=torch.bool)
    C_off = C[:, mask].view(B, -1)
    return F.relu(C_off - delta).mean()


### Retrieval performance evaluation

These functions are used to efficiently measure the quality of text-image alignment during training.

evaluate_retrieval_global: calculates the similarities between all predicted vectors Z and the image gallery, processing them in blocks. For each query, it obtains the top-k most similar indices and from these derives MRR, NDCG and the various recall@k values. An average L2 distance from the ground truth is also estimated, which is useful as a rough indicator of the geometric consistency of the mapping.
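
For reference, the sketch below shows how MRR and recall@k can be derived from a matrix of top-k indices using their standard definitions. The notebook's own mrr and recall_at_k come from challenge.src.eval.metrics, whose implementation is not shown here and may differ in details; the function names below are hypothetical.

```python
import numpy as np

def mrr_from_topk(pred: np.ndarray, gt: np.ndarray) -> float:
    """pred: (Nq, k) ranked gallery indices; gt: (Nq,) correct gallery index per query."""
    hits = pred == gt[:, None]
    found = hits.any(axis=1)
    rank = hits.argmax(axis=1) + 1   # 1-based rank of the first hit (1 if no hit, masked below)
    return float(np.where(found, 1.0 / rank, 0.0).mean())

def recall_at_k_from_topk(pred: np.ndarray, gt: np.ndarray, k: int) -> float:
    """Fraction of queries whose correct index appears among the first k predictions."""
    return float((pred[:, :k] == gt[:, None]).any(axis=1).mean())

pred = np.array([
    [2, 0, 1],   # correct image 2 at rank 1
    [0, 1, 2],   # correct image 1 at rank 2
    [0, 1, 2],   # correct image 5 is not retrieved
])
gt = np.array([2, 1, 5])

print(mrr_from_topk(pred, gt))              # (1 + 1/2 + 0) / 3 = 0.5
print(recall_at_k_from_topk(pred, gt, 1))   # one query out of three
print(recall_at_k_from_topk(pred, gt, 3))   # two queries out of three
```

Queries whose correct image falls outside the top-k contribute zero, so with topk=100 the reported MRR is a lower bound on the full-gallery MRR.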

In [None]:
@torch.no_grad()
def evaluate_retrieval_global(Z: torch.Tensor, gallery: torch.Tensor, gt_indices: np.ndarray, topk: int = 100, chunk: int = 512):
    Z = Z.to('cpu')
    gallery = gallery.to('cpu')
    Nq, Ng = Z.size(0), gallery.size(0)
    topk = min(topk, Ng)

    all_topk = []
    for start in range(0, Nq, chunk):
        end = min(start + chunk, Nq)
        sims = Z[start:end] @ gallery.T
        topk_idx = torch.topk(sims, k=topk, dim=1, largest=True, sorted=True).indices
        all_topk.append(topk_idx.cpu().numpy())
    pred_indices = np.vstack(all_topk).astype(np.int64)

    l2_dist = (Z - gallery[torch.from_numpy(gt_indices)]).norm(dim=1).mean().item()
    return {
        'mrr': mrr(pred_indices, gt_indices),
        'ndcg': ndcg(pred_indices, gt_indices),
        'recall_at_1':  recall_at_k(pred_indices, gt_indices, 1),
        'recall_at_3':  recall_at_k(pred_indices, gt_indices, 3),
        'recall_at_5':  recall_at_k(pred_indices, gt_indices, 5),
        'recall_at_10': recall_at_k(pred_indices, gt_indices, 10),
        'recall_at_50': recall_at_k(pred_indices, gt_indices, 50),
        'l2_dist': l2_dist,
    }

### Model training 

The train_model function manages the entire training process, combining the main model and the slot head. Both are optimised jointly with Adam, and the learning rate follows a per-step CosineAnnealingLR schedule that decays from LR_INIT to LR_LATE over the whole run, which keeps updates small and stable in the last epochs.
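
The shape of this schedule can be inspected without training anything. The snippet below evaluates the closed-form cosine annealing curve with the notebook's LR_INIT, LR_LATE and EPOCHS; the number of steps per epoch is a hypothetical placeholder, since it depends on the dataset size and batch size.

```python
import math

LR_INIT, LR_LATE = 2e-4, 2e-5   # values from the configuration cell
EPOCHS = 24
steps_per_epoch = 100            # hypothetical; depends on dataset size and batch size
T_max = EPOCHS * steps_per_epoch

def cosine_lr(step: int) -> float:
    """Closed-form cosine annealing from LR_INIT down to LR_LATE over T_max steps."""
    return LR_LATE + 0.5 * (LR_INIT - LR_LATE) * (1 + math.cos(math.pi * step / T_max))

# Learning rate at the start, the midpoint and the end of training.
for epoch in (0, 12, 24):
    print(f"epoch {epoch:2d}: lr = {cosine_lr(epoch * steps_per_epoch):.2e}")
```

The curve is flat near both ends and steepest in the middle, so most of the decay happens around the halfway point of training.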

During each epoch, the model processes the batches, generating both the global vector Z and the slots S_T. The different components of the loss (global triplet, triplet on slots, InfoNCE on slots, ISDL and, if active, condensation loss) are combined to guide the fine-grained alignment between text and image. After backpropagation, the parameters of both modules are updated consistently.

At the end of each epoch, the model is evaluated on the entire validation set: first the mean cosine similarity between each prediction and its ground-truth image embedding is computed, then the predicted embeddings are scored against the validation gallery with evaluate_retrieval_global, which returns all retrieval metrics. Training runs for the predefined number of epochs; a checkpoint is saved after every epoch, and the weights with the best validation MRR are additionally saved to MODEL_PATH.

In [None]:
def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, device: torch.device, epochs: int) -> nn.Module:
    # The main model and the auxiliary slot head (module-level slot_head) share one optimiser.
    opt = torch.optim.Adam(list(model.parameters()) + list(slot_head.parameters()), lr=LR_INIT, weight_decay=WEIGHT_DECAY)
    steps_per_epoch = len(train_loader)
    # The scheduler is stepped once per batch, so T_max is expressed in optimiser steps.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs * steps_per_epoch, eta_min=LR_LATE)
    best_mrr = -1.0
    ckpt_dir = Path(MODEL_PATH).parent / f"checkpoints_vsepp_seed{RUN_ID}"
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    
    for epoch in range(1, epochs + 1):
        model.train()
        for Xb, Yb in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
            Xb = Xb.to(device, non_blocking=True)
            Yb = Yb.to(device, non_blocking=True)
            Zb = model(Xb)
            Sb = slot_head(Xb)

            if USE_SLOT_TRIPLET:
                loss_tri_slot = triplet_hard_neg_slot_max(Sb, Yb, margin=MARGIN)
                loss_tri_global = triplet_hard_neg_cos(Zb, Yb, margin=MARGIN)
                loss_tri = loss_tri_slot + LAMBDA_TRI_GLOBAL * loss_tri_global
            else:
                loss_tri = triplet_hard_neg_cos(Zb, Yb, margin=MARGIN)

            loss_slot = loss_slot_ce(Sb, Yb, tau=TAU_SLOT)

            if LAMBDA_ISDL > 0.0:
                loss_isd = isdl_hinge(Sb, delta=ISDL_DELTA)
            else:
                loss_isd = Sb.new_zeros(())

            if LAMBDA_COND > 0.0 and USE_SLOT_TRIPLET:
                loss_cond = loss_condensation(Zb, Sb, Yb)
            else:
                loss_cond = Zb.new_zeros(())

            loss = (loss_tri + LAMBDA_SLOT * loss_slot + LAMBDA_ISDL * loss_isd + LAMBDA_COND * loss_cond)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(list(model.parameters()) + list(slot_head.parameters()), 1.0)
            opt.step()
            scheduler.step()
    
        # Validation: mean positive-pair cosine and retrieval embeddings in a single pass.
        model.eval()
        val_cos_sum, val_batches = 0.0, 0
        preds_val = []
        with torch.no_grad():
            for Xb, Yb in val_loader:
                Xb = Xb.to(device, non_blocking=True)
                Yb = F.normalize(Yb.to(device, non_blocking=True), dim=-1)
                Zb = model(Xb)
                val_cos_sum += float(torch.einsum('bd,bd->b', Zb, Yb).mean().item())
                val_batches += 1
                preds_val.append(Zb.cpu())
        # The model output is already L2-normalised; the extra normalisation is a no-op safeguard.
        Z_val = F.normalize(torch.cat(preds_val, dim=0), dim=-1).cpu()
        print(f"val mean positive cosine: {val_cos_sum / max(val_batches, 1):.4f}")

        res_val = evaluate_retrieval_global(Z_val, val_img_embd, val_label, topk=100, chunk=512)
        curr_mrr = res_val['mrr']

        print(f"losses: tri={float(loss_tri.item()):.4f} | "
              f"slot={float(loss_slot.item()):.4f} | "
              f"isd={float(loss_isd.item()):.4f} | "
              f"cond={float(loss_cond.item()):.4f}")

        ckpt_path = ckpt_dir / f"epoch_{epoch:03d}.pth"
        torch.save(model.state_dict(), ckpt_path)
        print(f"Saved {ckpt_path.name}")

        if curr_mrr > best_mrr:
            best_mrr = curr_mrr
            torch.save(model.state_dict(), MODEL_PATH)
            print(f"New best MRR ({best_mrr:.5f}),  saved {Path(MODEL_PATH).name}")

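    print(f"\nTraining finished. Best val MRR: {best_mrr:.5f}, checkpoint: {MODEL_PATH}")
    # Return the weights with the best validation MRR rather than those of the last epoch.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    return model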

print(f"RUN_ID = {RUN_ID}, SEED = {SEED}")
model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=DEVICE,
    epochs=EPOCHS,
)

### Preparation of the validation gallery and evaluation functions

After training, the model is switched to evaluation mode and the same image split used previously is reconstructed. This ensures that the validation gallery remains identical to the one seen during training and that no leakage occurs between training and validation.

All captions associated with the validation images are then selected, the relevant DataLoader is reconstructed, and the embeddings of the gallery images are normalised. For each caption, the exact index of the correct image within the gallery is also calculated: this mapping (val_label) is essential for measuring MRR.
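
The index mapping is easy to get wrong, so a toy version may be useful. In the sketch below, a made-up mask over six images is turned into a lookup table from global image index to gallery position, which is then applied to the validation captions. All values are hypothetical.

```python
import numpy as np

# Hypothetical toy split: 6 images, of which images 1 and 4 are in validation.
img_val_mask = np.array([False, True, False, False, True, False])

# Global image index -> position in the validation gallery (-1 if not in validation).
global_to_val = -np.ones(len(img_val_mask), dtype=np.int64)
global_to_val[np.where(img_val_mask)[0]] = np.arange(img_val_mask.sum(), dtype=np.int64)

# Three validation captions pointing at global image indices 4, 1 and 4.
val_gt_global = np.array([4, 1, 4])
val_label = global_to_val[val_gt_global]

print(global_to_val)  # [-1  0 -1 -1  1 -1]
print(val_label)      # [1 0 1]
```

A value of -1 in val_label would indicate a caption whose image is not in the gallery, which should never happen when captions inherit the split of their image.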

The evaluate_retrieval_global function performs the actual retrieval: it calculates the similarities between the predicted vectors and the gallery, extracts the top-k most similar ones and, from these, computes MRR, NDCG and recall@k. The calculation is done in blocks to avoid memory problems.

In [None]:
torch.manual_seed(int.from_bytes(os.urandom(4), "little"))
model.eval()

img_names_all = train_data['images/names']
img_emb_all = torch.from_numpy(train_data['images/embeddings']).float()

val_ratio = 0.10

img_hash = np.array([stable_hash(nm) for nm in img_names_all])
IMG_VAL_MASK = (img_hash < val_ratio)
IMG_TRAIN_MASK = ~IMG_VAL_MASK
assert not (IMG_VAL_MASK & IMG_TRAIN_MASK).any(), "Train/val image overlap > 0"

cap_to_img = train_data['captions/label']
if sparse.issparse(cap_to_img):
    cap_gt_img_idx = cap_to_img.argmax(axis=1).A1
else:
    cap_gt_img_idx = np.argmax(cap_to_img, axis=1)

CAP_TRAIN_MASK = IMG_TRAIN_MASK[cap_gt_img_idx]
CAP_VAL_MASK = IMG_VAL_MASK[cap_gt_img_idx]

X_val = X[CAP_VAL_MASK]
y_val = y[CAP_VAL_MASK]
val_dataset = TensorDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=(DEVICE.type=='cuda'))

print(f"Val captions: {len(X_val):,}, Val images: {int(IMG_VAL_MASK.sum()):,}")

val_img_embd = F.normalize(img_emb_all[torch.from_numpy(IMG_VAL_MASK)], dim=-1).cpu()  # (N_img_val, D)
val_img_file = img_names_all[IMG_VAL_MASK]

global_to_val = -np.ones(len(img_names_all), dtype=np.int64)
global_to_val[np.where(IMG_VAL_MASK)[0]] = np.arange(IMG_VAL_MASK.sum(), dtype=np.int64)
val_gt_global = cap_gt_img_idx[CAP_VAL_MASK]
val_label = global_to_val[val_gt_global]


In [None]:
model.eval()

preds_val = []
with torch.no_grad():
    for Xb, _ in DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False):
        Xb = Xb.to(DEVICE, non_blocking=True)
        Eb = model(Xb)
        preds_val.append(Eb.cpu())
Z_val_best = F.normalize(torch.cat(preds_val, dim=0), dim=-1).cpu()

res_val = evaluate_retrieval_global(Z_val_best, val_img_embd, val_label, topk=100, chunk=512)
print("\nVal (image-level split, global gallery): BEST CKPT (current seed)")
for k, v in res_val.items():
    print(f"{k:15s}: {v:.4f}")

for _ in range(3):
    i = np.random.randint(0, len(X_val))
    with torch.no_grad():
        zb = model(X_val[i:i+1].to(DEVICE)).cpu()
    caption_text = train_data['captions/text'][CAP_VAL_MASK][i]
    gt_idx = int(val_label[i])
    visualize_retrieval(zb, gt_idx, val_img_file, caption_text, val_img_embd, k=5)


### Ensemble of the best models and submission generation

To improve the quality of the final embeddings, an ensemble of models trained with different seeds can be used.  
Each checkpoint is loaded into a separate instance of TextToVis, kept in eval mode, so as to combine independent but architecturally identical predictions. In this notebook we use a single seed for simplicity, but the same code supports multiple checkpoints (e.g., 3–6 seeds).

During inference on the test set, each batch of text embeddings is passed to all models in the ensemble; their outputs are summed and then averaged. Final normalisation ensures that the resulting embeddings remain compatible with the cosine similarity used in evaluation.
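
Why the final normalisation matters can be seen on a two-model toy case. The vectors below are made up, and NumPy is used instead of PyTorch to keep the example short.

```python
import numpy as np

def l2_normalize(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Scale vectors to unit L2 norm along the given axis."""
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

# Unit-norm outputs of two hypothetical models for the same caption.
z_a = l2_normalize(np.array([[1.0, 0.0]]))
z_b = l2_normalize(np.array([[0.0, 1.0]]))

# Averaging unit vectors that point in different directions shrinks the norm.
z_mean = (z_a + z_b) / 2
print(np.linalg.norm(z_mean, axis=-1))   # about 0.707

# Renormalising puts the ensemble output back on the unit sphere.
z_ens = l2_normalize(z_mean)
print(np.linalg.norm(z_ens, axis=-1))    # 1.0 up to rounding
```

The more the models disagree, the shorter the averaged vector becomes, so renormalising keeps all ensemble outputs on the same scale regardless of how much the seeds agree.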

The embeddings produced are concatenated and verified in their final form, then saved in the format required by the challenge using generate_submission.  
With several seeds active, the averaged submission is expected to be more stable and robust than a single model; with only one checkpoint listed, as here, the ensemble reduces to that model.


In [None]:
CHECKPOINTS_ENSEMBLE = [
    # "models/vsepp_text2vis_best_seed0.pth",
    # "models/vsepp_text2vis_best_seed1.pth",
    # "models/vsepp_text2vis_best_seed2.pth",
    "models/vsepp_text2vis_best_seed3.pth",
    # "models/vsepp_text2vis_best_seed4.pth",
    # "models/vsepp_text2vis_best_seed5.pth",
]

def build_model_for_inference():
    m = TextToVis(d_text=X.shape[1], d_vis=y.shape[1], hidden=2048).to(DEVICE)
    return m

models_ens = []
for ckpt_path in CHECKPOINTS_ENSEMBLE:
    state = torch.load(ckpt_path, map_location=DEVICE)
    m = build_model_for_inference()
    m.load_state_dict(state)
    m.eval()
    models_ens.append(m)
    print(f"Loaded for ensemble: {ckpt_path}")

test = load_data("data/test/test.clean.npz")
test_ids = test['captions/ids']
test_emb = torch.from_numpy(test['captions/embeddings']).float()

pred_chunks = []
with torch.inference_mode():
    for Xb in DataLoader(test_emb, batch_size=BATCH_SIZE, shuffle=False, pin_memory=(DEVICE.type=='cuda')):
        Xb = Xb.to(DEVICE, non_blocking=True)
        Z_agg = None
        for m in models_ens:
            Zb = m(Xb)
            Z_agg = Zb if Z_agg is None else (Z_agg + Zb)

        Z_agg = Z_agg / len(models_ens)
        Z_agg = F.normalize(Z_agg, dim=-1)
        pred_chunks.append(Z_agg.cpu().to(torch.float32))

pred_test = torch.cat(pred_chunks, dim=0)

assert pred_test.shape[0] == len(test_ids)
assert pred_test.shape[1] == y.shape[1], f"{pred_test.shape[1]} != {y.shape[1]} (D_img)"

generate_submission(test_ids, pred_test, "submission_ensemble.csv")
print("Submission ready: submission_ensemble.csv")