### Model definition

In [2]:
from torch import nn
import torch

class Translator(nn.Module):
    def __init__(self, pad: bool, dim_imgs: int = 1536, dim_text: int = 1024,  mode: str ='linear'):
        super().__init__()
        assert mode in ['linear', 'affine', 'isometry'], f'Mode "{mode}" not supported'

        self.mode = mode
        use_bias = mode == 'affine'
        if pad:
            dim = max(dim_imgs, dim_text)
            self.linear = nn.Linear(dim, dim, bias=use_bias)

        else:
            self.linear = nn.Linear(dim_text, dim_imgs, bias=use_bias)

    def forward(self, x):
        return self.linear(x)

    @torch.no_grad()
    def orthogonalize(self):
        assert self.mode == 'isometry', 'Cannot be called for modes != isometry'

        W = self.linear.weight.data
        U, _, Vh = torch.linalg.svd(W, full_matrices=False)
        self.linear.weight.data.copy_(U @ Vh)


### Metrics functions

In [3]:
import numpy as np
import torch

'''Code from https://github.com/Mamiglia/challenge'''

def mrr(pred_indices: np.ndarray, gt_indices: np.ndarray) -> float:
    """
    Compute Mean Reciprocal Rank (MRR)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries (top-K)
        gt_indices: (N,) array of ground truth indices
    Returns:
        mrr: Mean Reciprocal Rank
    """
    reciprocal_ranks = []
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i] == gt_indices[i])[0]
        if matches.size > 0:
            reciprocal_ranks.append(1.0 / (matches[0] + 1))
        else:
            reciprocal_ranks.append(0.0)
    return np.mean(reciprocal_ranks)


def recall_at_k(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int) -> float:
    """Compute Recall@k
    Args:
        pred_indices: (N, N) array of top indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        recall: Recall@k
    """
    recall = 0
    for i in range(len(gt_indices)):
        if gt_indices[i] in pred_indices[i, :k]:
            recall += 1
    recall /= len(gt_indices)
    return recall

import numpy as np

def ndcg(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int = 100) -> float:
    """
    Compute Normalized Discounted Cumulative Gain (NDCG@k)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        ndcg: NDCG@k
    """
    ndcg_total = 0.0
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i, :k] == gt_indices[i])[0]
        if matches.size > 0:
            rank = matches[0] + 1
            ndcg_total += 1.0 / np.log2(rank + 1)  # DCG (IDCG = 1)
    return ndcg_total / len(gt_indices)



@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """Evaluate retrieval performance using cosine similarity
    Args:
        translated_embd: (N_captions, D) translated caption embeddings
        image_embd: (N_images, D) image embeddings
        gt_indices: (N_captions,) ground truth image indices for each caption
        max_indices: number of top predictions to consider
    Returns:
        results: dict of evaluation metrics
    
    """
    # Compute similarity matrix
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()
    
    n_queries = translated_embd.shape[0]
    device = translated_embd.device
    
    # Prepare containers for the fragments to be reassembled
    all_sorted_indices = []
    l2_distances = []
    
    # Process in batches - the narrow gate approach
    for start_idx in range(0, n_queries, batch_size):
        batch_slice = slice(start_idx, min(start_idx + batch_size, n_queries))
        batch_translated = translated_embd[batch_slice]
        batch_img_embd = image_embd[batch_slice]
        
        # Compute similarity only for this batch
        batch_similarity = batch_translated @ batch_img_embd.T

        # Get top-k predictions for this batch
        batch_indices = batch_similarity.topk(k=max_indices, dim=1, sorted=True).indices.numpy()
        all_sorted_indices.append(gt_indices[batch_slice][batch_indices])

        # Compute L2 distance for this batch
        batch_gt = gt_indices[batch_slice]
        batch_gt_embeddings = image_embd[batch_gt]
        batch_l2 = (batch_translated - batch_gt_embeddings).norm(dim=1)
        l2_distances.append(batch_l2)
    
    # Reassemble the fragments
    sorted_indices = np.concatenate(all_sorted_indices, axis=0)
    
    # Apply the sacred metrics to the whole
    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }
    
    results = {
        name: func(sorted_indices, gt_indices)
        for name, func in metrics.items()
    }
    
    l2_dist = torch.cat(l2_distances, dim=0).mean().item()
    results['l2_dist'] = l2_dist
    
    return results

### Training functions

In [9]:

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from pathlib import Path
from tqdm import tqdm

def pad_and_standardize(data: np.array, pad: bool, pad_val: int) -> torch.Tensor:
    data_torch = torch.from_numpy(data).float()
    if pad:
        data_torch = F.pad(data_torch, (0, pad_val), mode="constant", value=0)

    mean = data_torch.mean(dim=0, keepdim=True)
    std = data_torch.std(dim=0, keepdim=True) + 1e-8
    data_standardized = (data_torch - mean) / std

    return data_standardized


def preprocess(X_abs: np.array, Y_abs: np.array, pad: bool, normalize: bool=True) -> tuple[torch.Tensor, torch.Tensor]:
    assert X_abs.ndim == 2 and Y_abs.ndim == 2, "Both data must be 2D"

    x_pad = max(Y_abs.shape[1] - X_abs.shape[1], 0)
    y_pad = max(X_abs.shape[1] - Y_abs.shape[1], 0)

    X_pre = pad_and_standardize(X_abs, pad, x_pad)
    Y_pre = pad_and_standardize(Y_abs, pad, y_pad)

    if normalize:
        X_pre = F.normalize(X_pre, dim=1)
        Y_pre = F.normalize(Y_pre, dim=1)

    return X_pre, Y_pre


def train_model(model_path: Path, mode: str, 
                train_loader: DataLoader, val_loader: DataLoader,
                pad: bool, epochs: int, lr: float) -> Translator:
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Using device: {device}")

    model = Translator(pad=pad, mode=mode).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float('inf')

    for epoch in range(epochs):
        model.train()

        train_loss = 0
        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()

            outputs = model(X_batch)

            loss = 1 - F.cosine_similarity(outputs, y_batch, dim=1).mean()
            #loss = F.mse_loss(outputs, y_batch)

            loss.backward()

            optimizer.step()

            if mode == 'isometry':
                model.orthogonalize()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()

        val_loss = 0

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                outputs = model(X_batch)

                loss = 1 - F.cosine_similarity(outputs, y_batch, dim=1).mean()
                #loss = F.mse_loss(outputs, y_batch)

                val_loss += loss.item()

        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss

            Path(model_path).parent.mkdir(parents=True, exist_ok=True)

            torch.save(model.state_dict(), model_path)

            print(f"✓ Saved best model (val_loss={val_loss:.6f})")

    return model

def eval_on_val(X_val: torch.Tensor, y_val: torch.Tensor, pad: bool, 
                normalize: bool, model = None, 
                model_path: Path = None) -> dict:
    gt_indices = torch.arange(len(y_val))
    
    X, y = preprocess(X_val, y_val, pad, normalize)
    model = Translator(pad=pad, mode='linear')

    if model_path:
        state = torch.load(model_path)
        model.load_state_dict(state)
        
    model.eval()

    with torch.inference_mode():
        translated = model(X)

    results = evaluate_retrieval(translated, y, gt_indices)
    
    return results

### Load training data

In [5]:
data = np.load(Path('data/train/train.npz'))
caption_embeddings = data['captions/embeddings']
image_embeddings = data['images/embeddings']
caption_labels = data['captions/label']


### Train

In [6]:
batch_size = 512
epochs = 15
lr = 0.0005

X_abs = caption_embeddings # captions space
y_abs = image_embeddings[np.argmax(caption_labels, axis=1)] # images space

X, y = preprocess(X_abs, y_abs, pad=False, normalize=False)


n_train = int(0.9 * len(X))
train_split = torch.zeros(len(X), dtype=torch.bool)
train_split[:n_train] = 1

X_train, X_val = X[train_split], X[~train_split]
y_train, y_val = y[train_split], y[~train_split]

print(X_train.shape, X_val.shape)
print(y_train.shape, y_val.shape)


train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)


model = train_model('models/exp1.pth', 'affine', train_loader, val_loader, False,epochs, lr)


torch.Size([112500, 1024]) torch.Size([12500, 1024])
torch.Size([112500, 1536]) torch.Size([12500, 1536])
Using device: cuda


Epoch 1/15: 100%|██████████| 220/220 [00:02<00:00, 83.42it/s] 


Epoch 1: Train Loss = 0.618765, Val Loss = 0.581936
✓ Saved best model (val_loss=0.581936)


Epoch 2/15: 100%|██████████| 220/220 [00:01<00:00, 131.50it/s]


Epoch 2: Train Loss = 0.560854, Val Loss = 0.566835
✓ Saved best model (val_loss=0.566835)


Epoch 3/15: 100%|██████████| 220/220 [00:01<00:00, 129.89it/s]


Epoch 3: Train Loss = 0.548437, Val Loss = 0.560744
✓ Saved best model (val_loss=0.560744)


Epoch 4/15: 100%|██████████| 220/220 [00:01<00:00, 121.82it/s]


Epoch 4: Train Loss = 0.542065, Val Loss = 0.557830
✓ Saved best model (val_loss=0.557830)


Epoch 5/15: 100%|██████████| 220/220 [00:01<00:00, 129.01it/s]


Epoch 5: Train Loss = 0.538090, Val Loss = 0.555744
✓ Saved best model (val_loss=0.555744)


Epoch 6/15: 100%|██████████| 220/220 [00:01<00:00, 129.90it/s]


Epoch 6: Train Loss = 0.535366, Val Loss = 0.554358
✓ Saved best model (val_loss=0.554358)


Epoch 7/15: 100%|██████████| 220/220 [00:01<00:00, 126.69it/s]


Epoch 7: Train Loss = 0.533368, Val Loss = 0.553725
✓ Saved best model (val_loss=0.553725)


Epoch 8/15: 100%|██████████| 220/220 [00:01<00:00, 120.25it/s]


Epoch 8: Train Loss = 0.531854, Val Loss = 0.553100
✓ Saved best model (val_loss=0.553100)


Epoch 9/15: 100%|██████████| 220/220 [00:01<00:00, 127.38it/s]


Epoch 9: Train Loss = 0.530674, Val Loss = 0.552447
✓ Saved best model (val_loss=0.552447)


Epoch 10/15: 100%|██████████| 220/220 [00:01<00:00, 122.43it/s]


Epoch 10: Train Loss = 0.529700, Val Loss = 0.551818
✓ Saved best model (val_loss=0.551818)


Epoch 11/15: 100%|██████████| 220/220 [00:01<00:00, 132.29it/s]


Epoch 11: Train Loss = 0.528884, Val Loss = 0.551753
✓ Saved best model (val_loss=0.551753)


Epoch 12/15: 100%|██████████| 220/220 [00:01<00:00, 131.86it/s]


Epoch 12: Train Loss = 0.528271, Val Loss = 0.551557
✓ Saved best model (val_loss=0.551557)


Epoch 13/15: 100%|██████████| 220/220 [00:01<00:00, 130.98it/s]


Epoch 13: Train Loss = 0.527630, Val Loss = 0.551278
✓ Saved best model (val_loss=0.551278)


Epoch 14/15: 100%|██████████| 220/220 [00:01<00:00, 122.04it/s]


Epoch 14: Train Loss = 0.527160, Val Loss = 0.551113
✓ Saved best model (val_loss=0.551113)


Epoch 15/15: 100%|██████████| 220/220 [00:01<00:00, 129.04it/s]


Epoch 15: Train Loss = 0.526697, Val Loss = 0.551160


# Test

In [10]:
results = eval_on_val(X_val.numpy(), y_val.numpy(), pad=False, normalize=False, model=model)
print("Test Results:", results)

Test Results: {'mrr': 0.05214059919609994, 'ndcg': 0.20843233014038676, 'recall_at_1': 0.00984, 'recall_at_3': 0.03064, 'recall_at_5': 0.05048, 'recall_at_10': 0.1028, 'recall_at_50': 0.49816, 'l2_dist': 45.048789978027344}
