In [1]:
from typing import Optional
from torch import nn
from torch.nn import functional as F
import torch
from pathlib import Path
import numpy as np
import pandas as pd


from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

In [2]:
def mrr(pred_indices: np.ndarray, gt_indices: np.ndarray) -> float:
    """Mean Reciprocal Rank of the ground-truth item inside each ranked list.

    Args:
        pred_indices: (N, K) ranked candidate indices per query, best first.
        gt_indices: (N,) ground-truth index for every query.
    Returns:
        The average over all queries of 1 / (1-based position of the first
        hit); a query whose ground truth is missing from its row counts as 0.
    """
    def _reciprocal_rank(ranked_row, target) -> float:
        hit_positions = np.where(ranked_row == target)[0]
        # Only the earliest occurrence decides the rank.
        return 1.0 / (hit_positions[0] + 1) if hit_positions.size > 0 else 0.0

    per_query = [
        _reciprocal_rank(pred_indices[q], gt_indices[q])
        for q in range(len(gt_indices))
    ]
    return np.mean(per_query)


def recall_at_k(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int) -> float:
    """Fraction of queries whose ground truth shows up in the first k predictions.

    Args:
        pred_indices: (N, K) ranked candidate indices per query, best first.
        gt_indices: (N,) ground-truth index for every query.
        k: how many leading predictions of each row are inspected.
    Returns:
        Number of queries with a hit in their first k columns divided by N.
    """
    n_queries = len(gt_indices)
    hits = sum(
        1
        for q in range(n_queries)
        if gt_indices[q] in pred_indices[q, :k]
    )
    return hits / n_queries



def ndcg(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int = 100) -> float:
    """NDCG@k for the single-relevant-item case.

    Args:
        pred_indices: (N, K) ranked candidate indices per query, best first.
        gt_indices: (N,) ground-truth index for every query.
        k: how many leading predictions of each row are inspected.
    Returns:
        Mean over all queries of 1 / log2(rank + 1), where rank is the 1-based
        position of the ground truth inside the first k predictions; queries
        without a hit add nothing. With one relevant item the ideal DCG is 1,
        so no further normalisation is applied.
    """
    n_queries = len(gt_indices)
    gain_sum = 0.0
    for q in range(n_queries):
        hit_positions = np.where(pred_indices[q, :k] == gt_indices[q])[0]
        if hit_positions.size == 0:
            continue
        # 0-based position p is rank p + 1, hence a discount of log2(p + 2).
        gain_sum += 1.0 / np.log2(hit_positions[0] + 2)
    return gain_sum / n_queries
    
@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """Evaluate in-batch retrieval of images from translated caption embeddings.

    Queries are processed in chunks of ``batch_size``. For each chunk the
    candidates are ONLY the rows ``image_embd[start:end]`` at the same
    positions as the queries, and the candidate at local position j is
    reported under the id ``gt_indices[start + j]``. Ranking uses the raw dot
    product ``translated @ image.T``; this equals cosine similarity only when
    both inputs are L2-normalised by the caller.

    Args:
        translated_embd: (N_captions, D) translated caption embeddings
            (NumPy array or torch tensor).
        image_embd: (N_images, D) image embeddings (NumPy array or torch tensor).
        gt_indices: (N_captions,) ground-truth image index of every caption
            (NumPy array, torch tensor or sequence of ints).
        max_indices: number of ranked predictions kept per query.
        batch_size: number of queries (and candidates) per chunk.
    Returns:
        results: dict with 'mrr', 'ndcg', 'recall_at_{1,3,5,10,50}' and
        'l2_dist' (mean Euclidean distance between each translated embedding
        and ``image_embd[gt_indices[i]]``).
    Raises:
        ValueError: if there are no queries.
    """
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()

    # Keep the ground truth as a host-side NumPy array so fancy indexing and
    # the NumPy-based metrics behave the same whatever the caller passed in.
    if isinstance(gt_indices, torch.Tensor):
        gt_indices = gt_indices.detach().cpu().numpy()
    else:
        gt_indices = np.asarray(gt_indices)

    n_queries = translated_embd.shape[0]
    if n_queries == 0:
        raise ValueError("evaluate_retrieval needs at least one query embedding")

    all_sorted_indices = []
    l2_distances = []

    for start_idx in range(0, n_queries, batch_size):
        batch_slice = slice(start_idx, min(start_idx + batch_size, n_queries))
        batch_translated = translated_embd[batch_slice]
        batch_img_embd = image_embd[batch_slice]
        batch_gt = gt_indices[batch_slice]

        # Similarity of every query in the chunk against the chunk's candidates.
        batch_similarity = batch_translated @ batch_img_embd.T

        # A short final chunk may hold fewer candidates than max_indices;
        # asking topk for more than exist would raise.
        k = min(max_indices, batch_similarity.shape[1])
        # .cpu() makes the conversion valid for CUDA inputs as well.
        batch_indices = batch_similarity.topk(k=k, dim=1, sorted=True).indices.cpu().numpy()

        # Translate chunk-local candidate positions into global ids.
        ranked = batch_gt[batch_indices]
        if k < max_indices:
            # Pad with -1 (never a valid id) so every chunk has the same width
            # and the pieces can be concatenated.
            ranked = np.pad(ranked, ((0, 0), (0, max_indices - k)), constant_values=-1)
        all_sorted_indices.append(ranked)

        # Distance between each translated embedding and its true image embedding.
        gt_rows = torch.as_tensor(batch_gt, dtype=torch.long, device=image_embd.device)
        batch_gt_embeddings = image_embd[gt_rows]
        batch_l2 = (batch_translated - batch_gt_embeddings).norm(dim=1)
        l2_distances.append(batch_l2)

    sorted_indices = np.concatenate(all_sorted_indices, axis=0)

    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }

    results = {
        name: func(sorted_indices, gt_indices)
        for name, func in metrics.items()
    }

    l2_dist = torch.cat(l2_distances, dim=0).mean().item()
    results['l2_dist'] = l2_dist

    return results

In [3]:
# Training archive on the Kaggle input mount; read by get_data.
data_path= '/kaggle/input/aml-competition/train/train/train.npz'
# Cleaned test archive; in the visible cells it is only referenced from commented-out code.
test_path= '/kaggle/input/aml-competition/test/test/test.clean.npz'
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def analyze_tensor_distances(X: torch.Tensor, Y: torch.Tensor):
    """Print how close X and Y are, row by row.

    Two numbers are reported, each averaged over the rows: the Euclidean
    distance between matching rows and their cosine similarity. Nothing is
    returned; the report goes to stdout framed by dashed rules.
    """
    rule = "-" * 40

    print(rule)
    # Euclidean distance between matching rows, then the mean over rows.
    mean_euclidean = torch.norm(X - Y, dim=1).mean().item()
    print(f"üìè Mean Euclidean distance (per row): {mean_euclidean:.4f}")

    print(rule)
    # Cosine similarity between matching rows, then the mean over rows.
    mean_cosine_sim = F.cosine_similarity(X, Y, dim=1).mean().item()
    print(f"üìê Mean Cosine similarity (per row): {mean_cosine_sim:.4f}")

    print(rule, end='\n\n')

def get_data(data_path: Path):
    """Load caption embeddings and the image embedding paired with each caption.

    The archive must contain the arrays 'captions/embeddings',
    'images/embeddings' and 'captions/label'. For every caption the column
    holding the largest value in its 'captions/label' row selects the image
    whose embedding becomes the target.

    Args:
        data_path: location of the .npz archive (anything np.load accepts).
    Returns:
        (X_abs, y_abs): caption embeddings and, row-aligned with them, the
        selected image embeddings, both as torch tensors.
    Raises:
        KeyError: if one of the expected arrays is missing from the archive.
    """
    # The context manager closes the archive even when a key lookup fails,
    # which the explicit close() at the end of the happy path did not.
    with np.load(data_path) as data:
        caption_embeddings = data['captions/embeddings']
        image_embeddings = data['images/embeddings']
        caption_labels = data['captions/label']

    # One image row per caption: the argmax of each label row.
    target_rows = np.argmax(caption_labels, axis=1)

    X_abs = torch.tensor(caption_embeddings)
    y_abs = torch.tensor(image_embeddings[target_rows])

    return X_abs, y_abs
    
def get_datasets(X_abs, y_abs, val_fraction: float = 0.1, seed: int = 42) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset]:
    """Pair inputs with targets and split them into train and validation sets.

    Args:
        X_abs: input tensor; its first dimension indexes the samples.
        y_abs: target tensor with the same first dimension as X_abs.
        val_fraction: share of the samples assigned to the validation subset
            (default 0.1, i.e. the original 90/10 split).
        seed: seed of the private generator driving the shuffle, so the split
            is reproducible and independent of the global RNG state.
    Returns:
        (train_dataset, val_dataset): two Subset views over one TensorDataset.
        (The previous annotation promised four tensors, which was never what
        this function returned.)
    """
    print('Texts shape', X_abs.shape)
    print('Images shape', y_abs.shape)

    dataset = TensorDataset(X_abs, y_abs)
    train_dataset, val_dataset = random_split(
        dataset,
        [1.0 - val_fraction, val_fraction],
        generator=torch.Generator().manual_seed(seed),
    )

    return train_dataset, val_dataset

In [4]:
def anchor_pruning(X: torch.Tensor, n_samples: int, threshold: float = 0.65, seed=324) -> torch.Tensor:
    """Greedily pick mutually distant rows of X (farthest-point style).

    The distance between two rows a and b is ``1 - |a . b|``. The first row is
    drawn at random after seeding; afterwards the row whose distance to its
    closest already-picked row is largest gets added, until ``n_samples`` rows
    are picked or that largest distance falls below ``threshold``.

    Note: this reseeds torch's global RNG (and the CUDA RNG when X lives on a
    CUDA device) with ``seed``.

    Args:
        X: (N, D) matrix of candidate rows (presumably unit-norm - TODO confirm).
        n_samples: upper bound on how many rows are picked.
        threshold: selection stops once the best candidate is closer than this.
        seed: seed for the random choice of the first row.
    Returns:
        1-D long tensor of picked row indices on X's device, in pick order.
    """
    n_rows, _ = X.shape
    dev = X.device

    def _dist_to(row_idx: int) -> torch.Tensor:
        # Distance of every row of X to the row at row_idx.
        anchor = X[row_idx].unsqueeze(0)
        return 1.0 - torch.abs(X @ anchor.T).squeeze(1)

    torch.manual_seed(seed)
    if dev.type == 'cuda':
        torch.cuda.manual_seed(seed)

    chosen = [torch.randint(0, n_rows, (1,), device=dev).item()]
    # nearest[i] = distance from row i to the closest row chosen so far.
    nearest = _dist_to(chosen[0])

    while len(chosen) < n_samples:
        candidate = torch.argmax(nearest).item()
        if nearest[candidate].item() < threshold:
            break

        chosen.append(candidate)
        nearest = torch.minimum(nearest, _dist_to(candidate))

    return torch.tensor(chosen, dtype=torch.long, device=dev)


def anchor_completion(Ax: torch.Tensor, Ay: torch.Tensor, desired_dim: int):
    """Extend paired anchor matrices with extra rows up to ``desired_dim`` rows.

    When more anchors than ``desired_dim`` are supplied, both matrices are
    returned untouched. Otherwise the source anchors are extended with leading
    rows of the identity matrix of size ``Ax.size(1)``, and the target anchors
    with the matching rows of ``(I @ Ax.T) @ pinv(Ay.T)``.

    Args:
        Ax: (k, Dx) source anchors.
        Ay: (k, Dy) target anchors, row-aligned with Ax.
        desired_dim: positive number of anchor rows wanted.
    Returns:
        (Ax, Ay) after extension. At most ``Ax.size(1)`` rows can be appended
        because the extra rows are slices of that identity matrix.
    """
    assert Ax.size(0) == Ay.size(0)
    assert isinstance(desired_dim, int) and desired_dim > 0

    n_anchors = Ax.size(0)
    if n_anchors > desired_dim:
        return Ax, Ay

    missing = desired_dim - n_anchors

    # Canonical basis on the source side and its counterpart on the target
    # side, obtained through the anchors that already exist.
    basis_x = torch.eye(Ax.size(1)).to(Ax.device)
    basis_y = (basis_x @ Ax.T) @ torch.linalg.pinv(Ay.T)

    completed_x = torch.cat([Ax, basis_x[:missing]], dim=0)
    completed_y = torch.cat([Ay, basis_y[:missing]], dim=0)
    return completed_x, completed_y


def anchor_subspaces(X_normalized: torch.Tensor, Y_normalized: torch.Tensor, anchors_no: int, omega: int, delta: float):
    """Build ``omega`` paired anchor sets from row-aligned source/target data.

    Each round draws a fresh seed from torch's global RNG, selects anchor rows
    with ``anchor_pruning`` on the target data, gathers the same rows from
    both inputs and extends the pair to ``anchors_no`` rows with
    ``anchor_completion``.

    Args:
        X_normalized: source samples, row-aligned with Y_normalized.
        Y_normalized: target samples; anchors are selected on these.
        anchors_no: number of anchors requested per set.
        omega: positive number of anchor sets to build.
        delta: distance threshold forwarded to ``anchor_pruning``.
    Returns:
        List of ``omega`` tuples ``(Ax, Ay)`` of source/target anchor matrices.
    """
    assert isinstance(omega, int) and omega > 0

    # Pruning ran on 'cuda' unconditionally, which crashes on CPU-only
    # machines. Keep using the GPU when one exists, otherwise stay on the
    # tensor's own device. The transfer does not depend on the loop variable,
    # so it is done once instead of once per round.
    if torch.cuda.is_available():
        prune_device = torch.device('cuda')
    else:
        prune_device = Y_normalized.device
    Y_for_pruning = Y_normalized.to(prune_device)

    anchors = []

    for _ in range(omega):
        seed = torch.randint(0, 2**32, (1,)).item()
        indices = anchor_pruning(Y_for_pruning, anchors_no, delta, seed).cpu()

        Ax = X_normalized[indices]
        Ay = Y_normalized[indices]

        Ax, Ay = anchor_completion(Ax, Ay, anchors_no)

        anchors.append((Ax, Ay))

    return anchors


def reconstruct_source(X_normalized: torch.Tensor, subspaces: list[tuple[torch.Tensor, torch.Tensor]]):
    """Project source samples onto each source anchor set and map them back.

    For every anchor matrix Ax the samples are turned into anchor coordinates
    (``X @ Ax.T``) and brought back with the pseudo-inverse of ``Ax.T``; the
    reconstructions of all anchor sets are averaged.

    Args:
        X_normalized: (n, f) source samples.
        subspaces: non-empty list of ``(Ax, Ay)`` pairs; only Ax is used here.
    Returns:
        (n, f) tensor on the same device as X_normalized.
    Raises:
        ValueError: if ``subspaces`` is empty (previously this silently
            produced a NaN tensor through a division by zero).
    """
    if len(subspaces) == 0:
        raise ValueError("subspaces must contain at least one (Ax, Ay) pair")

    # Accumulate on the input's own device; forcing 'cuda' here crashed on
    # CPU-only machines and on CPU inputs.
    X_reconstruct = torch.zeros_like(X_normalized)

    for Ax, _ in subspaces:
        X_rel = X_normalized @ Ax.T
        X_reconstruct += X_rel @ torch.linalg.pinv(Ax.T)

    return X_reconstruct / len(subspaces)


def reconstruct_target(Y_normalized: torch.Tensor, subspaces: list[tuple[torch.Tensor, torch.Tensor]]):
    """Project target samples onto each target anchor set and map them back.

    For every anchor matrix Ay the samples are turned into anchor coordinates
    (``Y @ Ay.T``) and brought back with the pseudo-inverse of ``Ay.T``; the
    reconstructions of all anchor sets are averaged.

    Args:
        Y_normalized: (n, f) target samples.
        subspaces: non-empty list of ``(Ax, Ay)`` pairs; only Ay is used here.
    Returns:
        (n, f) tensor on the same device as Y_normalized.
    Raises:
        ValueError: if ``subspaces`` is empty (previously this silently
            produced a NaN tensor through a division by zero).
    """
    if len(subspaces) == 0:
        raise ValueError("subspaces must contain at least one (Ax, Ay) pair")

    # Accumulate on the input's own device; forcing 'cuda' here crashed on
    # CPU-only machines and on CPU inputs.
    Y_reconstruct = torch.zeros_like(Y_normalized)

    for _, Ay in subspaces:
        Y_rel = Y_normalized @ Ay.T
        Y_reconstruct += Y_rel @ torch.linalg.pinv(Ay.T)

    return Y_reconstruct / len(subspaces)


def reconstruct_target_from_data(data: torch.Tensor, subspaces: list[tuple[torch.Tensor, torch.Tensor]]):
    """Map anchor coordinates back to the target space.

    ``data`` is multiplied by the pseudo-inverse of ``Ay.T`` for every target
    anchor matrix Ay and the results are averaged over the anchor sets.

    Args:
        data: (n, d) coordinates relative to the target anchors.
        subspaces: non-empty list of ``(Ax, Ay)`` pairs; only Ay is used here.
    Returns:
        The averaged reconstruction, one row per row of ``data``.
    Raises:
        ValueError: if ``subspaces`` is empty (previously an opaque TypeError
            from dividing None by zero).
    """
    if len(subspaces) == 0:
        raise ValueError("subspaces must contain at least one (Ax, Ay) pair")

    Y_reconstruct = None

    for _, Ay in subspaces:
        term = data @ torch.linalg.pinv(Ay.T)
        # The first term seeds the accumulator so no shape has to be guessed.
        Y_reconstruct = term if Y_reconstruct is None else Y_reconstruct + term

    return Y_reconstruct / len(subspaces)


def center(X: torch.Tensor):
    """Subtract the mean over dim 0 from X.

    Returns:
        (centered, mean): the shifted tensor and the mean that was removed,
        with the reduced dimension kept so it broadcasts back onto X.
    """
    column_mean = torch.mean(X, dim=0, keepdim=True)
    centered = X - column_mean
    return centered, column_mean

def normalize(X: torch.Tensor):
    """Scale every row of X to unit L2 norm.

    Returns:
        (unit_rows, norms): the rescaled tensor and the row norms that were
        divided out, expanded to the shape of X.
    """
    # Clamp so that all-zero rows do not trigger a division by zero.
    row_norms = X.norm(p=2, dim=1, keepdim=True).clamp(min=1e-8)
    unit_rows = X / row_norms
    return unit_rows, row_norms.expand_as(X)

def pad(x: torch.Tensor, M: int) -> torch.Tensor:
    """Right-pad the last dimension of a 2-D tensor with zeros up to width M.

    Raises:
        ValueError: if M is smaller than the current width of x.
    """
    _, D = x.shape
    if D > M:
        raise ValueError(f"M={M} must be >= D={D}")

    # (left, right) padding amounts for the last dimension.
    widths = (0, M - D)
    return F.pad(x, widths, mode='constant', value=0.0)

In [5]:
# Caption embeddings (X) and, row-aligned with them, the image embedding of each caption (Y).
X, Y = get_data(data_path)

X = X.to(device)
Y = Y.to(device)

# Zero-pad the caption features up to the image feature width so both sides share one dimensionality.
X = pad(X, Y.size(1))

# Remove the mean over the samples; the means are kept so the shift can be undone later.
X_centered, X_center = center(X)
Y_centered, Y_center = center(Y)

# Scale every row to unit L2 norm; the expanded norms are kept so the scaling can be undone later.
X_normalized, X_norm = normalize(X_centered)
Y_normalized, Y_norm = normalize(Y_centered)

In [6]:
# One anchor set (omega=1) with Y.size(1) anchors requested and pruning threshold delta=0.65.
subspaces = anchor_subspaces(X_normalized, Y_normalized, Y.size(1), 1, 0.65)

In [7]:
# Sanity check: go to anchor coordinates and back, undo normalisation and centering, compare with the input.
X_reconstruct = reconstruct_source(X_normalized, subspaces)
analyze_tensor_distances(X_reconstruct * X_norm + X_center, X)

# Same round trip on the target side.
Y_reconstruct = reconstruct_target(Y_normalized, subspaces)
analyze_tensor_distances(Y_reconstruct * Y_norm + Y_center, Y)

----------------------------------------
üìè Mean Euclidean distance (per row): 0.0024
----------------------------------------
üìê Mean Cosine similarity (per row): 1.0000
----------------------------------------

----------------------------------------
üìè Mean Euclidean distance (per row): 0.2150
----------------------------------------
üìê Mean Cosine similarity (per row): 1.0000
----------------------------------------



In [None]:
# NOTE(review): this import sits mid-notebook; in the visible cells `pk` is only used here.
import pickle as pk

# Persist what is needed to undo the preprocessing (norms, centers) together with the anchor sets.
with open('data.pkl', 'wb') as f:
    data = {
        'X_norm': X_norm,
        'Y_norm': Y_norm,
        'X_center': X_center,
        'Y_center': Y_center,
        'subspaces': subspaces
    }
    pk.dump(data, f)

In [8]:
# Stack the anchor matrices of all anchor sets along a new leading axis.
Ax_batch = torch.stack([Ax for Ax, Ay in subspaces], dim=0)  # [num_subspaces x d x f]
Ay_batch = torch.stack([Ay for Ax, Ay in subspaces], dim=0)  # [num_subspaces x d x f]

# Relative representation: dot product of every sample with every anchor, per anchor set.
X_proj = torch.einsum('nf,sdf->snd', X_normalized, Ax_batch)
Y_proj = torch.einsum('nf,sdf->snd', Y_normalized, Ay_batch)

# Average the relative representations over the anchor sets.
X_rel = X_proj.mean(dim=0)  # [n x d]
Y_rel = Y_proj.mean(dim=0)  # [n x d]

# Baseline: distance between the two relative spaces before any alignment.
analyze_tensor_distances(X_rel, Y_rel)

# Ridge regression in closed form: solve (X^T X + lambda * I) W = X^T Y for W.
lambda_reg = 0.05
d = X_rel.shape[1]
I = torch.eye(d, device=X_rel.device)
W = torch.linalg.solve(X_rel.T @ X_rel + lambda_reg*I, X_rel.T @ Y_rel)

# Same comparison after mapping the source relative space through W.
analyze_tensor_distances(X_rel @ W, Y_rel)

----------------------------------------
üìè Mean Euclidean distance (per row): 5.7006
----------------------------------------
üìê Mean Cosine similarity (per row): 0.2416
----------------------------------------

----------------------------------------
üìè Mean Euclidean distance (per row): 2.6045
----------------------------------------
üìê Mean Cosine similarity (per row): 0.6623
----------------------------------------



In [19]:
# Map the captions into the target relative space with the fitted W.
data_changed_space = (X_rel @ W)
# Invert the anchor projection, then undo the target-side normalisation and centering.
# NOTE(review): Y_norm holds the row norms of the ground-truth targets (see normalize), so this
# step uses target information that would not be available for captions without a known image.
translation = reconstruct_target_from_data(data_changed_space, subspaces) * Y_norm + Y_center
# Compare the raw outputs with Y, then the same rows after rescaling both sides to unit norm.
analyze_tensor_distances(translation, Y)
analyze_tensor_distances(F.normalize(translation, dim=1), F.normalize(Y, dim=1))

----------------------------------------
üìè Mean Euclidean distance (per row): 14.5563
----------------------------------------
üìê Mean Cosine similarity (per row): 0.8255
----------------------------------------

----------------------------------------
üìè Mean Euclidean distance (per row): 0.5858
----------------------------------------
üìê Mean Cosine similarity (per row): 0.8255
----------------------------------------



In [20]:
# Row i of Y is the target built for caption i in get_data, so the ground truth is the identity mapping.
gt_indices = torch.arange(len(Y_normalized))
# Evaluated on the training pairs themselves, with the default max_indices and batch_size.
evaluate_retrieval(translation.cpu().numpy(), Y.cpu().numpy(), gt_indices)

{'mrr': 0.3435692022718723,
 'ndcg': 0.4919091671844142,
 'recall_at_1': 0.137784,
 'recall_at_3': 0.415616,
 'recall_at_5': 0.693944,
 'recall_at_10': 0.840224,
 'recall_at_50': 0.98788,
 'l2_dist': 14.556254386901855}

In [None]:
# test_data = np.load(test_path)
# sample_ids = test_data['captions/ids']
# test_embds = test_data['captions/embeddings']
# test_embds = torch.from_numpy(test_embds).float()

# test_data.close()

In [None]:
# test_embds = test_embds.to(device)
# test_embds = pad(test_embds, Y.size(1))
# test_embds, test_center = center(test_embds)
# test_normalized, test_norm = normalize(test_embds)

# print(test_normalized.shape)