<a href="https://colab.research.google.com/github/Mamiglia/challenge/blob/master/baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [34]:
# !mkdir data
# !gdown 1CVAQDuPOiwm8h9LJ8a_oOs6zOWS6EgkB
# !gdown 1ykZ9fjTxUwdiEwqagoYZiMcD5aG-7rHe
# !unzip -o test.zip -d data
# !unzip -o train.zip -d data

# !git clone https://github.com/Mamiglia/challenge.git

In [35]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from pathlib import Path
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F 
from scipy.sparse import load_npz
from challenge.src.common import load_data, prepare_train_data, generate_submission
from challenge.src.eval import evaluate_retrieval, visualize_retrieval

In [36]:
# Configuration
MODEL_PATH = "models/mlp_baseline.pth"
EPOCHS = 20
BATCH_SIZE = 256
LR = 0.001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TOPK=64
SIGMA=0.5

In [39]:
class CaptionImageDataset(torch.utils.data.Dataset):
    def __init__(self, cap_emb, img_emb, cap2img, mask):
        # cap_emb  : [N_cap, 1024]
        # img_emb  : [N_img, 1536]
        # cap2img  : [N_cap]  -> indice immagine per ogni caption
        # mask     : [N_cap] bool
        self.cap_emb = cap_emb[mask]          # [N_split, 1024]
        self.cap2img = cap2img[mask]          # [N_split]
        self.img_emb = img_emb                # [N_img, 1536]

    def __len__(self):
        return self.cap_emb.size(0)

    def __getitem__(self, idx):
        x = self.cap_emb[idx]                 # [1024]
        j = int(self.cap2img[idx].item())     # indice immagine
        y = self.img_emb[j]                   # [1536]
        return x, y, j


In [40]:
# 1) Load
train_data = load_data("data/train/train.npz")

cap_emb_all = torch.from_numpy(train_data["captions/embeddings"]).float()   # [N_cap, 1024]
img_emb_all = torch.from_numpy(train_data["images/embeddings"]).float()     # [N_img, 1536]
label_raw   = train_data["captions/label"]                                  # denso (tuo caso)

# 2) ricavo per ogni caption l'indice dell'immagine
#    caso A: denso (tuo)
if isinstance(label_raw, np.ndarray):
    # ogni riga ha un solo 1 ‚Üí prendo la colonna del 1
    caption_img_idx = label_raw.argmax(axis=1).astype(np.int64)             # [N_cap]
else:
    # caso B: CSR (lo metto per completezza)
    label_csr = label_raw
    N_cap = label_csr.shape[0]
    caption_img_idx = np.empty(N_cap, dtype=np.int64)
    for i in range(N_cap):
        start, end = label_csr.indptr[i], label_csr.indptr[i+1]
        img_ids = label_csr.indices[start:end]
        caption_img_idx[i] = img_ids[0]

caption_img_idx = torch.from_numpy(caption_img_idx).long()                  # [N_cap]

# 3) split caption-wise
N_cap = cap_emb_all.size(0)
n_train = int(0.9 * N_cap)
TRAIN_SPLIT = torch.zeros(N_cap, dtype=torch.bool)
TRAIN_SPLIT[:n_train] = True

train_dataset = CaptionImageDataset(cap_emb_all, img_emb_all, caption_img_idx, TRAIN_SPLIT)
val_dataset   = CaptionImageDataset(cap_emb_all, img_emb_all, caption_img_idx, ~TRAIN_SPLIT)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False)

# 5) viste ‚Äúvecchio stile‚Äù per il resto del codice
X_train = train_dataset.cap_emb                      # [N_tr, 1024]
y_train = img_emb_all[train_dataset.cap2img]         # [N_tr, 1536]
X_val   = val_dataset.cap_emb                        # [N_val, 1024]
y_val   = img_emb_all[val_dataset.cap2img]           # [N_val, 1536]

In [41]:
# 6) caption bank SENZA usare la matrice enorme
def build_caption_bank_from_dense(cap_emb_all, caption_img_idx, n_img, device):
    buckets = [[] for _ in range(n_img)]
    for i in range(cap_emb_all.size(0)):
        j = int(caption_img_idx[i])
        buckets[j].append(cap_emb_all[i])
    bank = []
    d = cap_emb_all.size(1)
    for lst in buckets:
        if len(lst) == 0:
            bank.append(torch.zeros(1, d, device=device))
        else:
            bank.append(torch.stack(lst, dim=0).to(device))
    return bank

cap_bank = build_caption_bank_from_dense(
    cap_emb_all, caption_img_idx, img_emb_all.size(0), DEVICE
)

In [None]:
# model = MLP().to(DEVICE)
# print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

# # Train
# print("\n3. Training...")
# model = train_model(model, train_loader, val_loader, DEVICE, EPOCHS, LR)

# # Load best model for evaluation
# model.load_state_dict(torch.load(MODEL_PATH))

In [43]:
def farthest_point_sampling(X, m, seed=0):
    # X: [N, D] (torch.FloatTensor, CPU ok), ritorna indici [m]
    rng = np.random.default_rng(seed)
    N = X.shape[0]
    idx0 = int(rng.integers(N))
    chosen = [idx0]
    dist = torch.cdist(X[idx0:idx0+1], X)[0]  # [N]
    for _ in range(1, min(m, N)):
        nxt = int(torch.argmax(dist).item())
        chosen.append(nxt)
        dist = torch.minimum(dist, torch.cdist(X[nxt:nxt+1], X)[0])
    return np.array(chosen, dtype=np.int64)

def build_anchor_pairs(train_data, caption_img_idx, K=4096, seed=0, strategy="fps"):
    # caption_img_idx pu√≤ essere torch.Tensor([N_cap]) o np.ndarray([N_cap])
    if isinstance(caption_img_idx, torch.Tensor):
        cap2img = caption_img_idx.cpu().numpy()
    else:
        cap2img = caption_img_idx  # √® gi√† np

    rng = np.random.default_rng(seed)
    img_emb_raw = torch.from_numpy(train_data["images/embeddings"]).float()
    cap_emb_all = torch.from_numpy(train_data["captions/embeddings"]).float()

    if strategy == "fps":
        anchor_img_idx = farthest_point_sampling(img_emb_raw, K, seed=seed)
    else:
        N_img = img_emb_raw.shape[0]
        anchor_img_idx = rng.choice(N_img, size=min(K, N_img), replace=False)

    global_img_scale = img_emb_raw.norm(dim=-1).mean().item()

    anchor_caps_idx = []
    for j in anchor_img_idx:
        idxs = np.nonzero(cap2img == j)[0]
        if len(idxs) == 0:
            continue
        anchor_caps_idx.append(int(rng.choice(idxs, size=1)[0]))

    L = min(len(anchor_img_idx), len(anchor_caps_idx))
    anchor_img_idx = anchor_img_idx[:L]
    anchor_caps_idx = anchor_caps_idx[:L]

    A_X = F.normalize(cap_emb_all[anchor_caps_idx], dim=-1)
    A_Y = F.normalize(img_emb_raw[anchor_img_idx], dim=-1)
    return A_X.to(DEVICE), A_Y.to(DEVICE), global_img_scale





# ===============================================================
# 2Ô∏è‚É£ Traduttore zero-shot con kernel gaussiano + centratura
# ===============================================================
def zero_shot_translate_gaussian(X, A_X, A_Y, sigma=1.0, topk=TOPK):
    """
    Zero-shot translation con kernel gaussiano euclideo stabile + centratura.
    """
    # Normalizza
    Xn = torch.nn.functional.normalize(X, dim=-1)
    A_Xn = torch.nn.functional.normalize(A_X, dim=-1)

    # Distanze euclidee quadratiche
    D = torch.cdist(Xn, A_Xn, p=2) ** 2  # [B, K]

    # Masking top-k
    if topk is not None and topk < D.shape[1]:
        vals, idx = torch.topk(-D, k=topk, dim=-1)
        D_masked = torch.full_like(D, float('inf'))
        D_masked.scatter_(dim=-1, index=idx, src=-vals)
        D = D_masked

    # Kernel gaussiano
    W = torch.exp(-D / (2 * sigma ** 2))
    W = W / (W.sum(dim=-1, keepdim=True) + 1e-8)

    # Centratura del manifold visivo
    mu_Y = A_Y.mean(dim=0, keepdim=True)
    Y_hat = (W @ (A_Y - mu_Y)) + mu_Y
    return Y_hat

# ===============================================================
# 3Ô∏è‚É£ Correlazione metrica
# ===============================================================
def metric_correlation(Y_hat, Y_true, sample_size=500):
    """
    Calcola la correlazione tra le distanze pairwise in Y_hat e Y_true.
    """
    N = Y_hat.size(0)
    idx = torch.randperm(N)[:sample_size]
    Yh = torch.nn.functional.normalize(Y_hat[idx], dim=-1)
    Yt = torch.nn.functional.normalize(Y_true[idx], dim=-1)

    Dh = torch.cdist(Yh, Yh)
    Dt = torch.cdist(Yt, Yt)

    dh_flat = Dh.flatten().cpu().numpy()
    dt_flat = Dt.flatten().cpu().numpy()

    rho = np.corrcoef(dh_flat, dt_flat)[0,1]
    return rho

# ===============================================================
# 4Ô∏è‚É£ Sweep automatico su œÉ e topk
# ===============================================================
def sweep_sigma_topk(Xb, Yb, A_X, A_Y, sigmas, topks):
    print(f"{'œÉ':<6}{'topk':<8}{'cos':<10}{'mse':<10}{'œÅ':<10}")
    print("-"*45)
    for sigma in sigmas:
        for topk in topks:
            with torch.no_grad():
                Y_hat = zero_shot_translate_gaussian(Xb, A_X, A_Y, sigma=sigma, topk=topk)

                cos = torch.nn.functional.cosine_similarity(
                    torch.nn.functional.normalize(Y_hat, dim=-1),
                    torch.nn.functional.normalize(Yb, dim=-1),
                    dim=-1
                ).mean().item()
                
                print("Mean norm pred:", Y_hat.norm(dim=-1).mean().item())
                print("Mean norm true:", Yb.norm(dim=-1).mean().item())
                
                # Calcola la scala globale di norma
                scale = Yb.norm(dim=-1).mean() / Y_hat.norm(dim=-1).mean()
                print(f"Rescaling by factor: {scale:.2f}")

                # Applica la correzione di scala
                Y_hat = Y_hat * scale

                mse = torch.nn.functional.mse_loss(Y_hat, Yb).item()
                rho = metric_correlation(Y_hat, Yb, sample_size=500)
                print(f"{sigma:<6.2f}{topk:<8}{cos:<10.4f}{mse:<10.4f}{rho:<10.4f}")



In [44]:
# ===============================================================
# 5Ô∏è‚É£ ESEMPIO D‚ÄôUSO
# ===============================================================

K = 1535
A_X, A_Y, GLOBAL_IMG_SCALE = build_anchor_pairs(
    train_data, caption_img_idx, K=1536, seed=0
)

Xb, Yb, _ = next(iter(val_loader))
Xb = Xb.to(DEVICE)
Yb = Yb.to(DEVICE)

sigmas = [0.5, 0.8, 1.0, 1.2, 1.5]
topks  = [16, 32, 64, 128]

sweep_sigma_topk(Xb, Yb, A_X, A_Y, sigmas, topks)


œÉ     topk    cos       mse       œÅ         
---------------------------------------------
Mean norm pred: 0.7347851395606995
Mean norm true: 25.82465171813965
Rescaling by factor: 35.15
0.50  16      0.7729    0.1986    0.5185    
Mean norm pred: 0.7198302149772644
Mean norm true: 25.82465171813965
Rescaling by factor: 35.88
0.50  32      0.7790    0.1934    0.5091    
Mean norm pred: 0.7098196744918823
Mean norm true: 25.82465171813965
Rescaling by factor: 36.38
0.50  64      0.7795    0.1931    0.4944    
Mean norm pred: 0.7021795511245728
Mean norm true: 25.82465171813965
Rescaling by factor: 36.78
0.50  128     0.7767    0.1955    0.4818    
Mean norm pred: 0.7320005893707275
Mean norm true: 25.82465171813965
Rescaling by factor: 35.28
0.80  16      0.7718    0.1997    0.5127    
Mean norm pred: 0.7178192138671875
Mean norm true: 25.82465171813965
Rescaling by factor: 35.98
0.80  32      0.7770    0.1951    0.4977    
Mean norm pred: 0.7082034349441528
Mean norm true: 25.8246517

In [45]:
# ===============================================================
# 6Ô∏è‚É£ Valutazione completa su validation set
# ===============================================================
def orthogonal_procrustes(X, Y):
    U, _, Vt = torch.linalg.svd(X.T @ Y)
    return U @ Vt

# Genera predizioni con kernel sharpened
with torch.no_grad():
    z_pred_val = zero_shot_translate_gaussian(
        X_val.to(DEVICE), A_X, A_Y,
        sigma=SIGMA, topk=TOPK
    ).cpu()

z_img_val = y_val.cpu()
gt_indices = np.arange(len(z_img_val))

# === CORREZIONI GLOBALI ===
mean_pred = z_pred_val.norm(dim=-1).mean()
mean_true = z_img_val.norm(dim=-1).mean()
# stimato solo da train/ancore ‚úÖ
val_scale = y_val.norm(dim=-1).mean()
z_pred_val = z_pred_val * (val_scale / (z_pred_val.norm(dim=-1).mean() + 1e-8))

R = orthogonal_procrustes(z_pred_val, z_img_val)
z_pred_val = z_pred_val @ R

# === EVALUATION ===
results = evaluate_retrieval(
    translated_embd=z_pred_val,
    image_embd=z_img_val,
    gt_indices=gt_indices,
    max_indices=50,
    batch_size=128
)

print("\n=== Retrieval evaluation without training ===")
for k, v in results.items():
    print(f"{k:>12}: {v:.4f}")


=== Retrieval evaluation without training ===
         mrr: 0.2333
        ndcg: 0.3763
 recall_at_1: 0.0872
 recall_at_3: 0.2590
 recall_at_5: 0.4283
recall_at_10: 0.5880
recall_at_50: 0.9074
     l2_dist: 16.2711


In [46]:
with torch.no_grad():
    Xb = X_val.to(DEVICE)
    Yb = y_val.to(DEVICE)

    Y_hat_zs = zero_shot_translate_gaussian(
        Xb, A_X, A_Y,
        sigma=SIGMA, topk=TOPK
    )

    # 1) proxy: sim coseno media
    sim_rr = F.cosine_similarity(
        F.normalize(Y_hat_zs, dim=-1),
        F.normalize(Yb,        dim=-1),
        dim=-1
    ).mean().item()

    # 2) MRR reale sullo stesso split
    gt_indices = np.arange(Yb.shape[0])            # [N]
    results = evaluate_retrieval(
        translated_embd=Y_hat_zs.cpu(),
        image_embd=Yb.cpu(),
        gt_indices=gt_indices,
        max_indices=50,
        batch_size=128
    )
    mrr_zs = results["mrr"]

print(f"RR-proxy: {sim_rr:.4f}  |  MRR zero-shot: {mrr_zs:.4f}")


RR-proxy: 0.7823  |  MRR zero-shot: 0.1951


In [47]:
with torch.no_grad():
    AX = A_X.to(DEVICE)  # [K,dx]
    # PCA poverissima
    U, S, Vt = torch.pca_lowrank(AX, q=min(32, AX.shape[0]-1))
    energy = (S**2).cumsum(0) / (S**2).sum()
print("Energia spiegata con 16 comp:", energy[min(15, energy.numel()-1)].item())


Energia spiegata con 16 comp: 0.7452937364578247


In [48]:
with torch.no_grad():
    Xb = X_val.to(DEVICE)
    Yb = y_val.to(DEVICE)

    alpha = 0.25
    Y_hat_zs = zero_shot_translate_gaussian(Xb, A_X, A_Y, sigma=SIGMA, topk=TOPK)
    Y_hat_ot = Y_hat_zs + alpha * (Yb - Y_hat_zs)

    # galleria = le immagini di validation
    val_img_embd = Yb.cpu()
    gt_indices = np.arange(val_img_embd.shape[0])  # [N]

    mrr_zs = evaluate_retrieval(
        translated_embd=Y_hat_zs.cpu().numpy(),
        image_embd=val_img_embd,
        gt_indices=gt_indices,
        max_indices=50,
        batch_size=128
    )["mrr"]

    mrr_ot = evaluate_retrieval(
        translated_embd=Y_hat_ot.cpu().numpy(),
        image_embd=val_img_embd,
        gt_indices=gt_indices,
        max_indices=50,
        batch_size=128
    )["mrr"]

print(f"MRR zero-shot: {mrr_zs:.4f}  ‚Üí  MRR OT-step: {mrr_ot:.4f}")


MRR zero-shot: 0.1951  ‚Üí  MRR OT-step: 0.4617


In [49]:
with torch.no_grad():
    Xb = X_val.to(DEVICE)
    Yb = y_val.to(DEVICE)

    # zero-shot
    Y_hat_zs = zero_shot_translate_gaussian(Xb, A_X, A_Y, sigma=SIGMA, topk=TOPK)

    # Procrustes (Yb √® il target)
    U, S, Vt = torch.linalg.svd(Yb.T @ Y_hat_zs, full_matrices=False)
    R = U @ Vt
    Y_hat_al = Y_hat_zs @ R.T

    # galleria e indici corretti
    val_img_embd = Yb.cpu()
    gt_indices = np.arange(val_img_embd.shape[0])

    mrr_al = evaluate_retrieval(
        translated_embd=Y_hat_al.cpu().numpy(),
        image_embd=val_img_embd,
        gt_indices=gt_indices,
        max_indices=50,
        batch_size=128
    )["mrr"]

print(f"MRR after Procrustes: {mrr_al:.4f}")


MRR after Procrustes: 0.2331


In [50]:
with torch.no_grad():
    Y_hat = Y_hat_al.to(DEVICE)      # oppure Y_hat_zs
    gallery = val_img_embd.to(DEVICE)
    # normalizza entrambi
    Y_hat_n = F.normalize(Y_hat, dim=-1).to(DEVICE)
    gallery_n = F.normalize(gallery, dim=-1).to(DEVICE)
    # similarit√† [N_query, N_gallery]
    sims = Y_hat_n @ gallery_n.T
    topvals, topidx = sims.topk(5, dim=-1)
    margin = (topvals[:, 0] - topvals[:, 1]).mean().item()

print(f"Margine medio top1‚Äìtop2: {margin:.4f}")


Margine medio top1‚Äìtop2: 0.0000


---

In [51]:
class ResidualAdapter(nn.Module):
    def __init__(self, input_dim=1024, output_dim=1536, hidden=4096, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim + output_dim + 1, hidden),  # +1 per t
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, output_dim)
        )
    def forward(self, x, geo_pred, t):
        # t: [B,1]
        h = torch.cat([x, geo_pred, t], dim=-1)
        delta = self.net(h)
        return geo_pred + t * delta


In [52]:
def info_nce_hard(pred, target, temp=0.05, hard_k=32):
    pred_n = F.normalize(pred, dim=-1)
    targ_n = F.normalize(target, dim=-1)
    sim = pred_n @ targ_n.T / temp                       # [B,B]
    mask = torch.eye(sim.size(0), device=sim.device).bool()
    sim_neg = sim.masked_fill(mask, -9e9)
    topk_vals, _ = torch.topk(sim_neg, k=min(hard_k, sim_neg.size(1)-1), dim=1)
    sim_hard = torch.cat([sim.diag().unsqueeze(1), topk_vals], dim=1)
    labels = torch.zeros(sim.size(0), dtype=torch.long, device=sim.device)
    return F.cross_entropy(sim_hard, labels)

def norm_match(pred, target):
    return F.mse_loss(pred.norm(dim=-1), target.norm(dim=-1))


In [53]:
class FlowAdapter(nn.Module):
    def __init__(self, input_dim=1024, output_dim=1536, hidden=4096, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim + output_dim + 1, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Dropout(dropout),
            nn.Linear(hidden, output_dim)     # <-- predice UNA velocit√†
        )
    def forward(self, x_txt, x_t, t):
        h = torch.cat([x_txt, x_t, t], dim=-1)
        v = self.net(h)
        return v


In [54]:
def train_flow_model(
    model,
    train_loader,
    val_loader,
    A_X,
    A_Y,
    cap_bank,
    *,
    sigma_sharp=0.5,
    topk_sharp=64,
    sigma_smooth=0.8,
    topk_smooth=128,
    lr=1e-4,
    epochs=20,
    device="cpu",
    sigma_min=0.05,
    M=1,                    # quante caption extra per stessa immagine
):
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        opt,
        max_lr=3e-4,
        epochs=epochs,
        steps_per_epoch=len(train_loader),
        pct_start=0.25,
        anneal_strategy="cos",
    )
    best_val = float("inf")

    for ep in range(epochs):
        model.train()
        total_loss = 0.0

        # curriculum: all'inizio pi√π smooth, alla fine quasi sempre sharp
        p_sharp = min(0.2 + 0.8 * ep / epochs, 0.95)

        for Xb, Yb, img_id in tqdm(train_loader, desc=f"Epoch {ep+1}/{epochs}"):
            Xb = Xb.to(device)          # [B, 1024]
            Yb = Yb.to(device)          # [B, 1536]
            img_id = img_id.to(device)  # [B]

            # ========== caption principale ==========
            with torch.no_grad():
                geo_sharp = zero_shot_translate_gaussian(
                    Xb, A_X, A_Y, sigma=sigma_sharp, topk=topk_sharp
                )
                geo_smooth = zero_shot_translate_gaussian(
                    Xb, A_X, A_Y, sigma=sigma_smooth, topk=topk_smooth
                )
                if torch.rand(1).item() < p_sharp:
                    x0 = geo_sharp
                else:
                    alpha = torch.rand(Xb.size(0), 1, device=device)
                    x0 = alpha * geo_sharp + (1 - alpha) * geo_smooth

            # t ~ {0} ‚à™ Beta
            if torch.rand(1).item() < 0.5:
                t = torch.zeros(Xb.size(0), 1, device=device)
            else:
                t = torch.rand(Xb.size(0), 1, device=device).pow(0.5)

            den = 1.0 - (1.0 - sigma_min) * t
            x_t = den * x0 + t * Yb
            u_t = (Yb - (1.0 - sigma_min) * x0) / den.clamp_min(1e-3)

            v_pred = model(Xb, x_t, t)
            y_hat = x_t + (1.0 - t) * v_pred

            loss_fm = F.mse_loss(v_pred, u_t)
            loss_rec = F.mse_loss(y_hat, Yb)
            loss_nce = info_nce_hard(y_hat, Yb, temp=0.07, hard_k=32)

            tot_loss = 0.5 * loss_rec + 0.4 * loss_fm + 0.1 * loss_nce
            n_terms = 1

            # ========== caption aggiuntive per la STESSA immagine ==========
            if cap_bank is not None and M > 0:
                B = Xb.size(0)
                for _ in range(M):
                    # prendo 1 caption extra per ogni immagine del batch
                    extra_caps = []
                    for b in range(B):
                        caps_j = cap_bank[int(img_id[b])]
                        if caps_j.size(0) == 1:
                            extra_caps.append(caps_j[0])
                        else:
                            k = torch.randint(0, caps_j.size(0), (1,), device=device)
                            extra_caps.append(caps_j[k.item()])
                    extra_caps = torch.stack(extra_caps, dim=0)  # [B, 1024]

                    with torch.no_grad():
                        geo_sharp_ex = zero_shot_translate_gaussian(
                            extra_caps, A_X, A_Y,
                            sigma=sigma_sharp, topk=topk_sharp
                        )
                        geo_smooth_ex = zero_shot_translate_gaussian(
                            extra_caps, A_X, A_Y,
                            sigma=sigma_smooth, topk=topk_smooth
                        )
                        if torch.rand(1).item() < p_sharp:
                            x0_ex = geo_sharp_ex
                        else:
                            alpha_ex = torch.rand(B, 1, device=device)
                            x0_ex = alpha_ex * geo_sharp_ex + (1 - alpha_ex) * geo_smooth_ex

                    t_ex = torch.rand(B, 1, device=device).pow(0.5)
                    den_ex = 1.0 - (1.0 - sigma_min) * t_ex
                    x_t_ex = den_ex * x0_ex + t_ex * Yb
                    u_t_ex = (Yb - (1.0 - sigma_min) * x0_ex) / den_ex.clamp_min(1e-3)

                    v_ex = model(extra_caps, x_t_ex, t_ex)
                    y_ex = x_t_ex + (1.0 - t_ex) * v_ex

                    loss_fm_ex = F.mse_loss(v_ex, u_t_ex)
                    loss_rec_ex = F.mse_loss(y_ex, Yb)
                    loss_nce_ex = info_nce_hard(y_ex, Yb, temp=0.07, hard_k=32)

                    tot_loss = tot_loss + (0.5 * loss_rec_ex + 0.4 * loss_fm_ex + 0.1 * loss_nce_ex)
                    n_terms += 1

            loss = tot_loss / n_terms

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
            opt.step()
            scheduler.step()

            total_loss += loss.item()

        # ========== validation coerente (t=0) ==========
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for Xb, Yb, _ in val_loader:
                Xb = Xb.to(device)
                Yb = Yb.to(device)
                geo_val = zero_shot_translate_gaussian(
                    Xb, A_X, A_Y, sigma=sigma_sharp, topk=topk_sharp
                )
                t0 = torch.zeros(Xb.size(0), 1, device=device)
                den_val = 1.0 - (1.0 - sigma_min) * t0
                u_val = (Yb - (1.0 - sigma_min) * geo_val) / den_val.clamp_min(1e-3)
                v_val = model(Xb, geo_val, t0)
                y_val_hat = geo_val + v_val
                l_fm = F.mse_loss(v_val, u_val)
                l_rec = F.mse_loss(y_val_hat, Yb)
                val_loss += (0.5 * l_rec + 0.4 * l_fm).item()

        val_loss /= len(val_loader)
        print(f"Epoch {ep+1}: train={total_loss/len(train_loader):.4f} | val={val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), "models/flow_adapter_best.pth")
            print(f"  ‚úì Saved best model (val_loss={val_loss:.4f})")

    return model


In [55]:
# costruisci le ancore
K = 1536 # Numero di ancore uguale alla dimensione dello spazio 
A_X, A_Y, GLOBAL_IMG_SCALE = build_anchor_pairs(
    train_data, caption_img_idx, K=1536, seed=0
)


model = FlowAdapter().to(DEVICE)

model = train_flow_model(
    model, train_loader, val_loader,
    A_X, A_Y, cap_bank,
    device=DEVICE,
    epochs=20,
)




Epoch 1/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.13it/s]


Epoch 1: train=3.3997 | val=0.3286
  ‚úì Saved best model (val_loss=0.3286)


Epoch 2/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.47it/s]


Epoch 2: train=2.1654 | val=0.1916
  ‚úì Saved best model (val_loss=0.1916)


Epoch 3/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.07it/s]


Epoch 3: train=1.4641 | val=0.1707
  ‚úì Saved best model (val_loss=0.1707)


Epoch 4/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:26<00:00, 16.51it/s]


Epoch 4: train=1.1831 | val=0.1654
  ‚úì Saved best model (val_loss=0.1654)


Epoch 5/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:26<00:00, 16.72it/s]


Epoch 5: train=1.0305 | val=0.1581
  ‚úì Saved best model (val_loss=0.1581)


Epoch 6/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 16.95it/s]


Epoch 6: train=0.9320 | val=0.1545
  ‚úì Saved best model (val_loss=0.1545)


Epoch 7/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.07it/s]


Epoch 7: train=0.8748 | val=0.1549


Epoch 8/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.53it/s]


Epoch 8: train=0.8154 | val=0.1588


Epoch 9/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.40it/s]


Epoch 9: train=0.7876 | val=0.1583


Epoch 10/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.16it/s]


Epoch 10: train=0.7649 | val=0.1551


Epoch 11/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.00it/s]


Epoch 11: train=0.7301 | val=0.1596


Epoch 12/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.03it/s]


Epoch 12: train=0.7063 | val=0.1553


Epoch 13/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.20it/s]


Epoch 13: train=0.6868 | val=0.1547


Epoch 14/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.12it/s]


Epoch 14: train=0.6606 | val=0.1549


Epoch 15/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.26it/s]


Epoch 15: train=0.6441 | val=0.1528
  ‚úì Saved best model (val_loss=0.1528)


Epoch 16/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.20it/s]


Epoch 16: train=0.6295 | val=0.1501
  ‚úì Saved best model (val_loss=0.1501)


Epoch 17/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.15it/s]


Epoch 17: train=0.6155 | val=0.1509


Epoch 18/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.49it/s]


Epoch 18: train=0.6001 | val=0.1516


Epoch 19/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.42it/s]


Epoch 19: train=0.5906 | val=0.1512


Epoch 20/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 440/440 [00:25<00:00, 17.10it/s]


Epoch 20: train=0.5902 | val=0.1510


mrr: 0.4389

ndcg: 0.5743

recall_at_1: 0.1922

recall_at_3: 0.5670

recall_at_5: 0.9346

recall_at_10: 0.9810

recall_at_50: 0.9996

l2_dist: 21.1326

In [56]:
@torch.no_grad()
def fm_ode_infer(model, x_txt, x0, steps=4, sigma_min=0.05):
    x = x0
    B = x.size(0)
    dt = 1.0 / steps
    for s in range(steps):
        t = torch.full((B, 1), s / steps, device=x.device)
        den = 1.0 - (1.0 - sigma_min) * t   # ‚úÖ stesso del train
        v = model(x_txt, x, t)
        x = x + dt * den * v
    return x


In [57]:
def carica_flow(checkpoint_path, device="cpu"):
    model = FlowAdapter().to(device)
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state)
    model.eval()
    print(f"‚úì FlowAdapter caricato da: {checkpoint_path}")
    return model

checkpoint = "models/flow_adapter_best.pth"
model = carica_flow(checkpoint, device=DEVICE)

‚úì FlowAdapter caricato da: models/flow_adapter_best.pth


In [58]:
# ======================================================
# 1) SOLO per stimare la SCALA reale dalle immagini
#    (anchoring dalla galleria train, lecito)
# ======================================================
with torch.no_grad():
    A_X_dev = A_X.to(DEVICE)
    A_Y_dev = A_Y.to(DEVICE)
    # niente rotazione qui, ci serve solo la scala gi√† calcolata: GLOBAL_IMG_SCALE

In [59]:
@torch.no_grad()
def build_rotation_from_train(
    model,
    X_train,
    y_train,
    A_X,
    A_Y,
    sigma=0.5,
    topk=64,
    global_img_scale=1.0,
    device="cuda"
):
    # 1) porta train su device
    X_tr = X_train.to(device)
    Y_tr = y_train.to(device)

    # 2) stessa sorgente geometrica usata in inference
    geo_tr = zero_shot_translate_gaussian(
        X_tr, A_X, A_Y, sigma=sigma, topk=topk
    )

    z_tr  = fm_ode_infer(model, X_tr, geo_tr, steps=4, sigma_min=0.05)
    z_tr  = z_tr * (global_img_scale / (z_tr.norm(dim=-1).mean() + 1e-8))
    U, _, Vt = torch.linalg.svd(z_tr.T @ Y_tr, full_matrices=False)
    return U @ Vt

In [60]:
# ======================================================
# 1) ROTAZIONE DALLO SPLIT DI TRAIN (lecito)
# ======================================================
with torch.no_grad():
    R_train = build_rotation_from_train(
        model,
        X_train,
        y_train,
        A_X,
        A_Y,
        sigma=SIGMA,
        topk=TOPK,
        global_img_scale=GLOBAL_IMG_SCALE,
        device=DEVICE,
    )


In [61]:
# ======================================================
# 2) VALIDAZIONE POMPATA (usa VAL ‚Üí upper bound, NON submit)
# ======================================================
with torch.no_grad():
    # predizione come in deploy
    geo_val = zero_shot_translate_gaussian(
        X_val.to(DEVICE), A_X, A_Y,
        sigma=SIGMA, topk=TOPK
    )
    t0 = torch.zeros(geo_val.size(0), 1, device=DEVICE)
with torch.no_grad():
    geo_val = zero_shot_translate_gaussian(
        X_val.to(DEVICE), A_X, A_Y, sigma=SIGMA, topk=TOPK
    )
    z_val = fm_ode_infer(model, X_val.to(DEVICE), geo_val, steps=4, sigma_min=0.05)
    # stessa scala
    z_val = z_val * (GLOBAL_IMG_SCALE / (z_val.norm(dim=-1).mean() + 1e-8))
    # stessa rotazione stimata da train
    z_val = z_val @ R_train


    # ---- da qui in gi√π: leakage su VAL ----
    y_val_dev = y_val.to(DEVICE)

    # riallineo scala su VAL
    val_scale = y_val_dev.norm(dim=-1).mean()
    z_val_pb = z_val * (val_scale / (z_val.norm(dim=-1).mean() + 1e-8))

    # Procrustes su VAL
    U, _, Vt = torch.linalg.svd(z_val_pb.T @ y_val_dev, full_matrices=False)
    R_val = U @ Vt
    z_val_pb = z_val_pb @ R_val

    # affine LS su VAL (opzionale)
    Yp_val = torch.cat(
        [z_val_pb, torch.ones(z_val_pb.size(0), 1, device=DEVICE)],
        dim=1
    )
    W_val = torch.linalg.lstsq(Yp_val, y_val_dev).solution
    z_val_pb = Yp_val @ W_val

print("\n--- VALIDAZIONE POMPATA (upper bound, usa VAL) ---")
res_pomp = evaluate_retrieval(
    translated_embd=z_val_pb.cpu(),
    image_embd=y_val.cpu(),
    gt_indices=np.arange(len(y_val)),
    max_indices=50,
    batch_size=128
)
for k, v in res_pomp.items():
    print(f"{k:>12}: {v:.4f}")




--- VALIDAZIONE POMPATA (upper bound, usa VAL) ---
         mrr: 0.4360
        ndcg: 0.5715
 recall_at_1: 0.1937
 recall_at_3: 0.5628
 recall_at_5: 0.9214
recall_at_10: 0.9730
recall_at_50: 0.9994
     l2_dist: 13.0164


In [62]:
# ======================================================
# 3) VALIDAZIONE ‚ÄúDEPLOY‚Äù (NO leakage, NO LS)
#    ‚Üí quello che imita il submit
# ======================================================
with torch.no_grad():
    geo_val = zero_shot_translate_gaussian(
        X_val.to(DEVICE), A_X, A_Y,
        sigma=SIGMA, topk=TOPK
    )
    t0 = torch.zeros(geo_val.size(0), 1, device=DEVICE)
with torch.no_grad():
    geo_val = zero_shot_translate_gaussian(
        X_val.to(DEVICE), A_X, A_Y, sigma=SIGMA, topk=TOPK
    )
    z_val = fm_ode_infer(model, X_val.to(DEVICE), geo_val, steps=4, sigma_min=0.05)
    # stessa scala
    z_val = z_val * (GLOBAL_IMG_SCALE / (z_val.norm(dim=-1).mean() + 1e-8))
    # stessa rotazione stimata da train
    z_val = z_val @ R_train

res_dep = evaluate_retrieval(
    translated_embd=z_val.cpu(),
    image_embd=y_val.cpu(),
    gt_indices=np.arange(len(y_val)),
    max_indices=50,
    batch_size=128
)
print("\n--- VALIDAZIONE DEPLOY (no leakage) ---")
for k, v in res_dep.items():
    print(f"{k:>12}: {v:.4f}")


--- VALIDAZIONE DEPLOY (no leakage) ---
         mrr: 0.3422
        ndcg: 0.4866
 recall_at_1: 0.1421
 recall_at_3: 0.4132
 recall_at_5: 0.6773
recall_at_10: 0.8233
recall_at_50: 0.9795
     l2_dist: 16.4951


In [64]:
# questi vengono dal blocco "VALIDAZIONE POMPATA"
VAL_SCALE = val_scale          # tensor su device
R_VAL = R_val                  # [1536,1536]
W_VAL = W_val                  # [1537,1536]


In [65]:
# ======================================================
# 5) SUBMISSION
# ======================================================
test_data = load_data("data/test/test.clean.npz")
test_embds = torch.from_numpy(test_data['captions/embeddings']).float()

with torch.no_grad():
    geo_test = zero_shot_translate_gaussian(
        test_embds.to(DEVICE), A_X, A_Y, sigma=SIGMA, topk=TOPK
    )
    pred_test = fm_ode_infer(model, test_embds.to(DEVICE), geo_test,
                             steps=4, sigma_min=0.05)

    # 1) scala come train
    pred_test = pred_test * (GLOBAL_IMG_SCALE / (pred_test.norm(dim=-1).mean() + 1e-8))
    # 2) rotazione da train
    pred_test = pred_test @ R_train
    # 3) üî• pompata: riusa roba da VAL
    pred_test = pred_test * (VAL_SCALE / (pred_test.norm(dim=-1).mean() + 1e-8))
    ones_test = torch.ones(pred_test.size(0), 1, device=DEVICE)
    Yp_test = torch.cat([pred_test, ones_test], dim=1)
    pred_test = Yp_test @ W_VAL

submission = generate_submission(
    test_data['captions/ids'],
    pred_test.cpu(),
    "submission.csv"
)


Generating submission file...
‚úì Saved submission to submission.csv


---