In [8]:
import torch
import torch.nn.functional as F
import numpy as np

from torch import nn
from typing import Optional, Literal
from torch.utils.data import TensorDataset, DataLoader
from pathlib import Path
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {DEVICE}")

Using device: cuda


**Model definition**

In [9]:
class Translator(nn.Module):
    def __init__(self, input_dim=1024, output_dim=1536, mode='affine', use_relative=False, anchors: Optional[torch.Tensor] = None):
        super().__init__()
        assert mode in ['linear', 'affine', 'isometry'], f'Mode "{mode}" not supported'
        assert input_dim > 0 and output_dim > 0, "Expecting positive dimensions"
        assert not use_relative or isinstance(anchors, torch.Tensor) , 'Anchors must be set if using relative representations'
        assert anchors is None or (anchors.ndim == 2 and anchors.shape[0] > 0), '2D Anchors must be provided if using relative representations'
        
        self.mode = mode
        self.use_relative = use_relative
        self.anchors = anchors
        
        self.linear = nn.Linear(
            anchors.shape[0] if self.use_relative else input_dim,
            output_dim,
            bias=self.mode == 'affine'
        )
    
    def compute_relative(self, x):
        assert self.anchors is not None, 'Anchors must be set by calling "set_anchors"'
        
        return F.normalize(x, p=2, dim=1) @ F.normalize(self.anchors.T)
        
    def forward(self, x):
        if self.use_relative:
            x = self.compute_relative(x)
        
        return self.linear(x)

**Metrics functions**

In [10]:
'''Code from https://github.com/Mamiglia/challenge'''

def mrr(pred_indices: np.ndarray, gt_indices: np.ndarray) -> float:
    """
    Compute Mean Reciprocal Rank (MRR)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries (top-K)
        gt_indices: (N,) array of ground truth indices
    Returns:
        mrr: Mean Reciprocal Rank
    """
    reciprocal_ranks = []
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i] == gt_indices[i])[0]
        if matches.size > 0:
            reciprocal_ranks.append(1.0 / (matches[0] + 1))
        else:
            reciprocal_ranks.append(0.0)
    return np.mean(reciprocal_ranks)


def recall_at_k(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int) -> float:
    """Compute Recall@k
    Args:
        pred_indices: (N, N) array of top indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        recall: Recall@k
    """
    recall = 0
    for i in range(len(gt_indices)):
        if gt_indices[i] in pred_indices[i, :k]:
            recall += 1
    recall /= len(gt_indices)
    return recall

import numpy as np

def ndcg(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int = 100) -> float:
    """
    Compute Normalized Discounted Cumulative Gain (NDCG@k)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        ndcg: NDCG@k
    """
    ndcg_total = 0.0
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i, :k] == gt_indices[i])[0]
        if matches.size > 0:
            rank = matches[0] + 1
            ndcg_total += 1.0 / np.log2(rank + 1)  # DCG (IDCG = 1)
    return ndcg_total / len(gt_indices)



@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """Evaluate retrieval performance using cosine similarity
    Args:
        translated_embd: (N_captions, D) translated caption embeddings
        image_embd: (N_images, D) image embeddings
        gt_indices: (N_captions,) ground truth image indices for each caption
        max_indices: number of top predictions to consider
    Returns:
        results: dict of evaluation metrics
    
    """
    # Compute similarity matrix
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()
    
    n_queries = translated_embd.shape[0]
    device = translated_embd.device
    
    # Prepare containers for the fragments to be reassembled
    all_sorted_indices = []
    l2_distances = []
    
    # Process in batches - the narrow gate approach
    for start_idx in range(0, n_queries, batch_size):
        batch_slice = slice(start_idx, min(start_idx + batch_size, n_queries))
        batch_translated = translated_embd[batch_slice]
        batch_img_embd = image_embd[batch_slice]
        
        # Compute similarity only for this batch
        batch_similarity = batch_translated @ batch_img_embd.T

        # Get top-k predictions for this batch
        batch_indices = batch_similarity.topk(k=max_indices, dim=1, sorted=True).indices.numpy()
        all_sorted_indices.append(gt_indices[batch_slice][batch_indices])

        # Compute L2 distance for this batch
        batch_gt = gt_indices[batch_slice]
        batch_gt_embeddings = image_embd[batch_gt]
        batch_l2 = (batch_translated - batch_gt_embeddings).norm(dim=1)
        l2_distances.append(batch_l2)
    
    # Reassemble the fragments
    sorted_indices = np.concatenate(all_sorted_indices, axis=0)
    
    # Apply the sacred metrics to the whole
    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }
    
    results = {
        name: func(sorted_indices, gt_indices)
        for name, func in metrics.items()
    }
    
    l2_dist = torch.cat(l2_distances, dim=0).mean().item()
    results['l2_dist'] = l2_dist
    
    return results

**Training functions**

In [11]:
def train_model(model: Translator, model_path: Path, train_loader: DataLoader, val_loader: DataLoader, epochs: int, lr: float) -> Translator:
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()

        train_loss = 0
        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

            optimizer.zero_grad()

            outputs = model(X_batch)

            #loss = 1 - F.cosine_similarity(outputs, y_batch, dim=1).mean()
            loss = F.mse_loss(outputs, y_batch)

            loss.backward()

            optimizer.step()

            if model_args.get('mode', None) == 'isometry':
                model.orthogonalize()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()

        val_loss = 0

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
                outputs = model(X_batch)

                #loss = 1 - F.cosine_similarity(outputs, y_batch, dim=1).mean()
                loss = F.mse_loss(outputs, y_batch)

                val_loss += loss.item()

        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss

            Path(model_path).parent.mkdir(parents=True, exist_ok=True)

            torch.save(model.state_dict(), model_path)

            print(f"✓ Saved best model (val_loss={val_loss:.6f})")


def eval_on_val(X_val: np.ndarray, y_val: np.ndarray, model = None, model_args: dict = {}, model_path: Path = None) -> dict:
    gt_indices = torch.arange(len(y_val))
    
    if model_path:
        model = Translator(**model_args)
        state = torch.load(model_path)
        
        model.load_state_dict(state)
        
    model.eval()

    with torch.inference_mode():
        translated = model(torch.from_numpy(X_val).to(DEVICE))
    
    return evaluate_retrieval(translated.cpu().numpy(), y_val, gt_indices)


**Load training data**

In [12]:
data = np.load(Path('/kaggle/input/aml-competition/train/train/train.npz'))
caption_embeddings = data['captions/embeddings']
image_embeddings = data['images/embeddings']
caption_labels = data['captions/label']

**Anchors choice**

In [13]:
def extract_anchors(data: torch.Tensor, method: Literal['pca', 'k-means', 'random'], anchors_number: int):
    assert isinstance(data, torch.Tensor) and data.ndim == 2 and data.shape[0] > 0, "Expected a valid tensor"
    assert method in ['pca', 'k-means', 'random'], f'Method {method} not supported'
    assert isinstance(anchors_number, int) and anchors_number > 0, "Expected a natural positive number"

    data_np = data.cpu().numpy()

    if method == 'pca':
        # PCA already returns normalized anchors
        pca = PCA(n_components=anchors_number)
        pca.fit(data_np)
        
        anchors = torch.from_numpy(pca.components_).float()
    elif method == 'k-means':
        kmeans = KMeans(n_clusters=anchors_number, init='k-means++', n_init=10, random_state=42)
        kmeans.fit(data_np)
        
        anchors = torch.from_numpy(kmeans.cluster_centers_).float()
    else:
        anchors = data[torch.randperm(data.size(0))[:anchors_number]]

    return anchors

In [26]:
batch_size = 256
epochs = 30
lr = 0.001
anchors_number = 350

X_abs = torch.from_numpy(caption_embeddings) # captions space
y_abs = torch.from_numpy(image_embeddings[np.argmax(caption_labels, axis=1)]) # images space

X_anchors = extract_anchors(X_abs, 'pca', anchors_number).to(DEVICE)
# X_anchors = None

print('Texts shape', X_abs.shape)
print('Images shape', X_abs.shape)
print('Anchors shape', X_anchors.shape if X_anchors is not None else '')

n_train = int(0.9 * X_abs.shape[0])
train_split = torch.zeros(X_abs.shape[0], dtype=torch.bool)
train_split[:n_train] = 1

X_train, X_val = X_abs[train_split], X_abs[~train_split]
y_train, y_val = y_abs[train_split], y_abs[~train_split]

train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

model_args = {
    'input_dim': X_train.shape[1],
    'output_dim': y_train.shape[1],
    'mode': 'affine',
    'use_relative': True,
    'anchors': X_anchors
}

model = Translator(**model_args).to(DEVICE)

train_model(model, 'models/exp1.pth', train_loader, val_loader, epochs, lr)

Texts shape torch.Size([125000, 1024])
Images shape torch.Size([125000, 1024])
Anchors shape torch.Size([350, 1024])


Epoch 1/30: 100%|██████████| 440/440 [00:01<00:00, 251.20it/s]


Epoch 1: Train Loss = 0.258852, Val Loss = 0.194873
✓ Saved best model (val_loss=0.194873)


Epoch 2/30: 100%|██████████| 440/440 [00:01<00:00, 246.19it/s]


Epoch 2: Train Loss = 0.173934, Val Loss = 0.164740
✓ Saved best model (val_loss=0.164740)


Epoch 3/30: 100%|██████████| 440/440 [00:01<00:00, 232.70it/s]


Epoch 3: Train Loss = 0.157544, Val Loss = 0.156923
✓ Saved best model (val_loss=0.156923)


Epoch 4/30: 100%|██████████| 440/440 [00:01<00:00, 250.94it/s]


Epoch 4: Train Loss = 0.152834, Val Loss = 0.154366
✓ Saved best model (val_loss=0.154366)


Epoch 5/30: 100%|██████████| 440/440 [00:01<00:00, 248.31it/s]


Epoch 5: Train Loss = 0.151044, Val Loss = 0.153176
✓ Saved best model (val_loss=0.153176)


Epoch 6/30: 100%|██████████| 440/440 [00:01<00:00, 234.58it/s]


Epoch 6: Train Loss = 0.150098, Val Loss = 0.152499
✓ Saved best model (val_loss=0.152499)


Epoch 7/30: 100%|██████████| 440/440 [00:01<00:00, 248.95it/s]


Epoch 7: Train Loss = 0.149462, Val Loss = 0.151989
✓ Saved best model (val_loss=0.151989)


Epoch 8/30: 100%|██████████| 440/440 [00:01<00:00, 255.28it/s]


Epoch 8: Train Loss = 0.149024, Val Loss = 0.151612
✓ Saved best model (val_loss=0.151612)


Epoch 9/30: 100%|██████████| 440/440 [00:01<00:00, 230.22it/s]


Epoch 9: Train Loss = 0.148674, Val Loss = 0.151332
✓ Saved best model (val_loss=0.151332)


Epoch 10/30: 100%|██████████| 440/440 [00:01<00:00, 251.26it/s]


Epoch 10: Train Loss = 0.148409, Val Loss = 0.151079
✓ Saved best model (val_loss=0.151079)


Epoch 11/30: 100%|██████████| 440/440 [00:01<00:00, 254.22it/s]


Epoch 11: Train Loss = 0.148179, Val Loss = 0.150918
✓ Saved best model (val_loss=0.150918)


Epoch 12/30: 100%|██████████| 440/440 [00:01<00:00, 232.82it/s]


Epoch 12: Train Loss = 0.147986, Val Loss = 0.150761
✓ Saved best model (val_loss=0.150761)


Epoch 13/30: 100%|██████████| 440/440 [00:01<00:00, 253.14it/s]


Epoch 13: Train Loss = 0.147834, Val Loss = 0.150592
✓ Saved best model (val_loss=0.150592)


Epoch 14/30: 100%|██████████| 440/440 [00:01<00:00, 251.98it/s]


Epoch 14: Train Loss = 0.147682, Val Loss = 0.150478
✓ Saved best model (val_loss=0.150478)


Epoch 15/30: 100%|██████████| 440/440 [00:01<00:00, 252.34it/s]


Epoch 15: Train Loss = 0.147548, Val Loss = 0.150342
✓ Saved best model (val_loss=0.150342)


Epoch 16/30: 100%|██████████| 440/440 [00:01<00:00, 235.76it/s]


Epoch 16: Train Loss = 0.147441, Val Loss = 0.150244
✓ Saved best model (val_loss=0.150244)


Epoch 17/30: 100%|██████████| 440/440 [00:01<00:00, 251.75it/s]


Epoch 17: Train Loss = 0.147325, Val Loss = 0.150155
✓ Saved best model (val_loss=0.150155)


Epoch 18/30: 100%|██████████| 440/440 [00:01<00:00, 251.30it/s]


Epoch 18: Train Loss = 0.147243, Val Loss = 0.150061
✓ Saved best model (val_loss=0.150061)


Epoch 19/30: 100%|██████████| 440/440 [00:01<00:00, 233.56it/s]


Epoch 19: Train Loss = 0.147158, Val Loss = 0.149972
✓ Saved best model (val_loss=0.149972)


Epoch 20/30: 100%|██████████| 440/440 [00:01<00:00, 255.81it/s]


Epoch 20: Train Loss = 0.147073, Val Loss = 0.149905
✓ Saved best model (val_loss=0.149905)


Epoch 21/30: 100%|██████████| 440/440 [00:01<00:00, 249.05it/s]


Epoch 21: Train Loss = 0.147010, Val Loss = 0.149827
✓ Saved best model (val_loss=0.149827)


Epoch 22/30: 100%|██████████| 440/440 [00:01<00:00, 228.72it/s]


Epoch 22: Train Loss = 0.146940, Val Loss = 0.149783
✓ Saved best model (val_loss=0.149783)


Epoch 23/30: 100%|██████████| 440/440 [00:01<00:00, 252.90it/s]


Epoch 23: Train Loss = 0.146895, Val Loss = 0.149728
✓ Saved best model (val_loss=0.149728)


Epoch 24/30: 100%|██████████| 440/440 [00:01<00:00, 254.27it/s]


Epoch 24: Train Loss = 0.146835, Val Loss = 0.149709
✓ Saved best model (val_loss=0.149709)


Epoch 25/30: 100%|██████████| 440/440 [00:01<00:00, 233.87it/s]


Epoch 25: Train Loss = 0.146802, Val Loss = 0.149645
✓ Saved best model (val_loss=0.149645)


Epoch 26/30: 100%|██████████| 440/440 [00:01<00:00, 253.10it/s]


Epoch 26: Train Loss = 0.146752, Val Loss = 0.149600
✓ Saved best model (val_loss=0.149600)


Epoch 27/30: 100%|██████████| 440/440 [00:01<00:00, 257.42it/s]


Epoch 27: Train Loss = 0.146715, Val Loss = 0.149567
✓ Saved best model (val_loss=0.149567)


Epoch 28/30: 100%|██████████| 440/440 [00:01<00:00, 233.13it/s]


Epoch 28: Train Loss = 0.146691, Val Loss = 0.149531
✓ Saved best model (val_loss=0.149531)


Epoch 29/30: 100%|██████████| 440/440 [00:01<00:00, 254.19it/s]


Epoch 29: Train Loss = 0.146656, Val Loss = 0.149521
✓ Saved best model (val_loss=0.149521)


Epoch 30/30: 100%|██████████| 440/440 [00:01<00:00, 249.25it/s]


Epoch 30: Train Loss = 0.146647, Val Loss = 0.149476
✓ Saved best model (val_loss=0.149476)


In [27]:
results = eval_on_val(X_val.numpy(), y_val.numpy(), model=model, model_args=model_args)
print("Test Results:", results)

Test Results: {'mrr': 0.31603659195619344, 'ndcg': 0.4671519153719432, 'recall_at_1': 0.1248, 'recall_at_3': 0.37264, 'recall_at_5': 0.62152, 'recall_at_10': 0.7832, 'recall_at_50': 0.978, 'l2_dist': 15.000469207763672}
