### model.py

In [43]:
import torch
from typing import Optional
from torch import nn
from torch.nn import functional as F

class Translator(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        dir_hidden_dims: list[int],
        scale_hidden_dims: list[int],
        activation=nn.ReLU,
        dropout_rate: float=0.3
    ):
        super().__init__()

        def build_mlp(hidden_dims, out_dim, apply_softplus=False):
            layers = []
            last_dim = input_dim
            for hidden in hidden_dims:
                layers += [
                    nn.Linear(last_dim, hidden),
                    activation(),
                    #nn.LayerNorm(hidden),
                    nn.Dropout(dropout_rate)
                ]
                last_dim = hidden
            layers.append(nn.Linear(last_dim, out_dim))
            
            if apply_softplus:
                layers.append(nn.Softplus())
            
            return nn.Sequential(*layers)

        self.dir_head = build_mlp(dir_hidden_dims, output_dim, apply_softplus=False)
        self.scale_head = build_mlp(scale_hidden_dims, 1, apply_softplus=True)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))

        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        direction = F.normalize(self.dir_head(x), dim=-1)
        scale = self.scale_head(x)
        return direction * scale



def procrustes_align(X, Y, scale=True):
    mu_X = X.mean(dim=0, keepdim=True)
    mu_Y = Y.mean(dim=0, keepdim=True)

    X_centered = X - mu_X
    Y_centered = Y - mu_Y

    C = X_centered.T @ Y_centered

    U, S, Vt = torch.linalg.svd(C, full_matrices=True)
    R = U @ Vt

    if torch.det(R) < 0:
        Vt[-1, :] *= -1
        R = U @ Vt

    if scale:
        s = S.sum() / (X_centered ** 2).sum()
    else:
        s = 1.0

    t = mu_Y.squeeze() - s * (mu_X @ R)

    return R, s, t

def align_matrix(m: torch.Tensor, R, s, t):
    return s * (m @ R) + t

### eval.py

In [62]:
from pathlib import Path
import numpy as np
import torch
import pandas as pd

'''Code from https://github.com/Mamiglia/challenge'''

def mrr(pred_indices: np.ndarray, gt_indices: np.ndarray) -> float:
    """
    Compute Mean Reciprocal Rank (MRR)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries (top-K)
        gt_indices: (N,) array of ground truth indices
    Returns:
        mrr: Mean Reciprocal Rank
    """
    reciprocal_ranks = []
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i] == gt_indices[i])[0]
        if matches.size > 0:
            reciprocal_ranks.append(1.0 / (matches[0] + 1))
        else:
            reciprocal_ranks.append(0.0)
    return np.mean(reciprocal_ranks)


def recall_at_k(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int) -> float:
    """Compute Recall@k
    Args:
        pred_indices: (N, N) array of top indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        recall: Recall@k
    """
    recall = 0
    for i in range(len(gt_indices)):
        if gt_indices[i] in pred_indices[i, :k]:
            recall += 1
    recall /= len(gt_indices)
    return recall

import numpy as np

def ndcg(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int = 100) -> float:
    """
    Compute Normalized Discounted Cumulative Gain (NDCG@k)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        ndcg: NDCG@k
    """
    ndcg_total = 0.0
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i, :k] == gt_indices[i])[0]
        if matches.size > 0:
            rank = matches[0] + 1
            ndcg_total += 1.0 / np.log2(rank + 1)  # DCG (IDCG = 1)
    return ndcg_total / len(gt_indices)



@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """Evaluate retrieval performance using cosine similarity
    Args:
        translated_embd: (N_captions, D) translated caption embeddings
        image_embd: (N_images, D) image embeddings
        gt_indices: (N_captions,) ground truth image indices for each caption
        max_indices: number of top predictions to consider
    Returns:
        results: dict of evaluation metrics
    
    """
    # Compute similarity matrix
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()
    
    n_queries = translated_embd.shape[0]
    device = translated_embd.device
    
    # Prepare containers for the fragments to be reassembled
    all_sorted_indices = []
    l2_distances = []
    
    # Process in batches - the narrow gate approach
    for start_idx in range(0, n_queries, batch_size):
        batch_slice = slice(start_idx, min(start_idx + batch_size, n_queries))
        batch_translated = translated_embd[batch_slice]
        batch_img_embd = image_embd[batch_slice]
        
        # Compute similarity only for this batch
        batch_similarity = batch_translated @ batch_img_embd.T

        # Get top-k predictions for this batch
        batch_indices = batch_similarity.topk(k=max_indices, dim=1, sorted=True).indices.numpy()
        all_sorted_indices.append(gt_indices[batch_slice][batch_indices])

        # Compute L2 distance for this batch
        batch_gt = gt_indices[batch_slice]
        batch_gt_embeddings = image_embd[batch_gt]
        batch_l2 = (batch_translated - batch_gt_embeddings).norm(dim=1)
        l2_distances.append(batch_l2)
    
    # Reassemble the fragments
    sorted_indices = np.concatenate(all_sorted_indices, axis=0)
    
    # Apply the sacred metrics to the whole
    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }
    
    results = {
        name: func(sorted_indices, gt_indices)
        for name, func in metrics.items()
    }
    
    l2_dist = torch.cat(l2_distances, dim=0).mean().item()
    results['l2_dist'] = l2_dist
    
    return results

def eval_on_val(x_val: np.ndarray, y_val: np.ndarray, model: Translator, device) -> dict:
    gt_indices = torch.arange(len(y_val))
    
    model.eval()

    with torch.inference_mode():
        translated = model(x_val.to(device)).to('cpu')

    results = evaluate_retrieval(translated, y_val, gt_indices)
    
    return results

def generate_submission(model: Translator, test_path: Path, output_file="submission.csv", device=None, procrustes_data: tuple=()):
    test_data = np.load(test_path)
    sample_ids = test_data['captions/ids']
    test_embds = test_data['captions/embeddings']
    test_embds = torch.from_numpy(test_embds).float()
    test_data.close()

    with torch.no_grad():
        pred_embds = model(test_embds.to(device)).cpu()

    if procrustes_data:
        print('Using Procrusted')
        pred_embds = align_matrix(pred_embds, R, s, t)

    print("Generating submission file...")

    if isinstance(pred_embds, torch.Tensor):
        pred_embds = pred_embds.cpu().numpy()

    df_submission = pd.DataFrame({'id': sample_ids, 'embedding': pred_embds.tolist()})

    df_submission.to_csv(output_file, index=False, float_format='%.17g')
    print(f"✓ Saved submission to {output_file}")

    return df_submission

### configs

In [26]:

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split
from pathlib import Path
from tqdm import tqdm

def info_nce_loss(dir_preds, img_targets, temp: float):
    logits = (dir_preds @ img_targets.T) / temp
    labels = torch.arange(logits.size(0), device=logits.device)
    #loss = F.cross_entropy(logits, labels)    
    #return loss
    loss_t2i = F.cross_entropy(logits, labels)          
    loss_i2t = F.cross_entropy(logits.T, labels)        
    return 0.5 * (loss_t2i + loss_i2t)


def train_model(
    model: Translator,
    model_path: Path,
    train_dataset: TensorDataset,
    val_dataset: TensorDataset,
    batch_size: int,
    epochs: int,
    lr: float,
    patience: int
) -> Translator:    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Using device: {device}")
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    best_val_loss = float('inf')
    no_improvements = 0

    for epoch in range(epochs):
        model.train()

        train_loss = 0
        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_batch = F.normalize(y_batch, dim=-1)

            optimizer.zero_grad()

            outputs = model(X_batch)

            loss = info_nce_loss(outputs, y_batch, temp=model.logit_scale)

            loss.backward()

            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()

        val_loss = 0

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                outputs = model(X_batch)

                loss = info_nce_loss(outputs, y_batch, temp=model.logit_scale)

                val_loss += loss.item()

        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")
        test(val_dataset, model, device)


        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improvements = 0

            Path(model_path).parent.mkdir(parents=True, exist_ok=True)

            torch.save(model.state_dict(), model_path)

            print(f"✓ Saved best model (val_loss={val_loss:.6f})")
        elif no_improvements >= patience:
            return model
        else:
            no_improvements += 1

    return model



def load_data(data_path: Path):
    data = np.load(data_path)
    caption_embeddings = data['captions/embeddings']
    image_embeddings = data['images/embeddings']
    caption_labels = data['captions/label']

    X_abs, y_abs = torch.tensor(caption_embeddings), torch.tensor(image_embeddings[np.argmax(caption_labels, axis=1)])
    
    print('Texts shape', X_abs.shape)
    print('Images shape', X_abs.shape)
    
    dataset = TensorDataset(X_abs, y_abs)
    train_dataset, val_dataset = random_split(dataset, [0.9, 0.1], generator=torch.Generator().manual_seed(42))
    
    return train_dataset, val_dataset


def test(val_dataset: TensorDataset, model: Translator, device):
    val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
    for x_val, y_val in val_loader:
        results = eval_on_val(x_val, y_val, model=model, device=device)
    return results

In [24]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


batch_size= 1024
lr= 0.0001
epochs= 200
patience = 5

data_path= '/kaggle/input/aml-competition/train/train/train.npz'
test_path= '/kaggle/input/aml-competition/test/test/test.clean.npz'

model_save_path= './models/exp1.pth'

train_dataset, val_dataset = load_data(data_path)

Texts shape torch.Size([125000, 1024])
Images shape torch.Size([125000, 1024])


In [36]:
model_args = {
    'input_dim': 1024,
    'output_dim': 1536,
    'dir_hidden_dims': [1024, 2048, 1024],
    'scale_hidden_dims': [1024, 1024],
    'activation': nn.SiLU,
    'dropout_rate': 0.3
}
model = Translator(**model_args).to(device)

train_model(model, model_save_path, train_dataset, val_dataset, batch_size, epochs, lr, patience)

print('Finished training. Now testing using best model...')

state = torch.load(model_save_path)
model.load_state_dict(state)
results = test(val_dataset, model, device)

print("Test Results:", results)

Using device: cuda


Epoch 1/200: 100%|██████████| 110/110 [00:03<00:00, 32.78it/s]


Epoch 1: Train Loss = 4.667129, Val Loss = 41.517302
✓ Saved best model (val_loss=41.517302)


Epoch 2/200: 100%|██████████| 110/110 [00:03<00:00, 34.23it/s]


Epoch 2: Train Loss = 2.818234, Val Loss = 35.083778
✓ Saved best model (val_loss=35.083778)


Epoch 3/200: 100%|██████████| 110/110 [00:03<00:00, 32.47it/s]


Epoch 3: Train Loss = 2.440394, Val Loss = 31.997264
✓ Saved best model (val_loss=31.997264)


Epoch 4/200: 100%|██████████| 110/110 [00:03<00:00, 32.50it/s]


Epoch 4: Train Loss = 2.218421, Val Loss = 30.146695
✓ Saved best model (val_loss=30.146695)


Epoch 5/200: 100%|██████████| 110/110 [00:03<00:00, 32.39it/s]


Epoch 5: Train Loss = 2.067796, Val Loss = 29.011839
✓ Saved best model (val_loss=29.011839)


Epoch 6/200: 100%|██████████| 110/110 [00:03<00:00, 32.46it/s]


Epoch 6: Train Loss = 1.944504, Val Loss = 27.749160
✓ Saved best model (val_loss=27.749160)


Epoch 7/200: 100%|██████████| 110/110 [00:03<00:00, 32.47it/s]


Epoch 7: Train Loss = 1.844366, Val Loss = 27.214054
✓ Saved best model (val_loss=27.214054)


Epoch 8/200: 100%|██████████| 110/110 [00:03<00:00, 32.78it/s]


Epoch 8: Train Loss = 1.763212, Val Loss = 26.797282
✓ Saved best model (val_loss=26.797282)


Epoch 9/200: 100%|██████████| 110/110 [00:03<00:00, 32.93it/s]


Epoch 9: Train Loss = 1.691858, Val Loss = 26.009679
✓ Saved best model (val_loss=26.009679)


Epoch 10/200: 100%|██████████| 110/110 [00:03<00:00, 33.97it/s]


Epoch 10: Train Loss = 1.627888, Val Loss = 25.680139
✓ Saved best model (val_loss=25.680139)


Epoch 11/200: 100%|██████████| 110/110 [00:03<00:00, 32.86it/s]


Epoch 11: Train Loss = 1.568089, Val Loss = 25.298645
✓ Saved best model (val_loss=25.298645)


Epoch 12/200: 100%|██████████| 110/110 [00:03<00:00, 32.66it/s]


Epoch 12: Train Loss = 1.507927, Val Loss = 24.786871
✓ Saved best model (val_loss=24.786871)


Epoch 13/200: 100%|██████████| 110/110 [00:03<00:00, 32.91it/s]


Epoch 13: Train Loss = 1.465404, Val Loss = 24.824306


Epoch 14/200: 100%|██████████| 110/110 [00:03<00:00, 32.86it/s]


Epoch 14: Train Loss = 1.424134, Val Loss = 24.343636
✓ Saved best model (val_loss=24.343636)


Epoch 15/200: 100%|██████████| 110/110 [00:03<00:00, 32.69it/s]


Epoch 15: Train Loss = 1.384076, Val Loss = 24.204970
✓ Saved best model (val_loss=24.204970)


Epoch 16/200: 100%|██████████| 110/110 [00:03<00:00, 32.78it/s]


Epoch 16: Train Loss = 1.341548, Val Loss = 24.077963
✓ Saved best model (val_loss=24.077963)


Epoch 17/200: 100%|██████████| 110/110 [00:03<00:00, 32.77it/s]


Epoch 17: Train Loss = 1.304856, Val Loss = 23.733538
✓ Saved best model (val_loss=23.733538)


Epoch 18/200: 100%|██████████| 110/110 [00:03<00:00, 34.08it/s]


Epoch 18: Train Loss = 1.270314, Val Loss = 23.857379


Epoch 19/200: 100%|██████████| 110/110 [00:03<00:00, 32.37it/s]


Epoch 19: Train Loss = 1.237943, Val Loss = 23.814359


Epoch 20/200: 100%|██████████| 110/110 [00:03<00:00, 32.89it/s]


Epoch 20: Train Loss = 1.208326, Val Loss = 23.569187
✓ Saved best model (val_loss=23.569187)


Epoch 21/200: 100%|██████████| 110/110 [00:03<00:00, 32.84it/s]


Epoch 21: Train Loss = 1.178924, Val Loss = 23.415336
✓ Saved best model (val_loss=23.415336)


Epoch 22/200: 100%|██████████| 110/110 [00:03<00:00, 32.91it/s]


Epoch 22: Train Loss = 1.150263, Val Loss = 23.454975


Epoch 23/200: 100%|██████████| 110/110 [00:03<00:00, 32.56it/s]


Epoch 23: Train Loss = 1.117995, Val Loss = 23.223691
✓ Saved best model (val_loss=23.223691)


Epoch 24/200: 100%|██████████| 110/110 [00:03<00:00, 32.51it/s]


Epoch 24: Train Loss = 1.094549, Val Loss = 23.437679


Epoch 25/200: 100%|██████████| 110/110 [00:03<00:00, 32.88it/s]


Epoch 25: Train Loss = 1.074457, Val Loss = 23.325184


Epoch 26/200: 100%|██████████| 110/110 [00:03<00:00, 32.44it/s]


Epoch 26: Train Loss = 1.050563, Val Loss = 23.311684


Epoch 27/200: 100%|██████████| 110/110 [00:03<00:00, 33.00it/s]


Epoch 27: Train Loss = 1.030228, Val Loss = 23.096148
✓ Saved best model (val_loss=23.096148)


Epoch 28/200: 100%|██████████| 110/110 [00:03<00:00, 33.87it/s]


Epoch 28: Train Loss = 1.003861, Val Loss = 23.115674


Epoch 29/200: 100%|██████████| 110/110 [00:03<00:00, 32.71it/s]


Epoch 29: Train Loss = 0.991370, Val Loss = 23.144019


Epoch 30/200: 100%|██████████| 110/110 [00:03<00:00, 32.78it/s]


Epoch 30: Train Loss = 0.965004, Val Loss = 23.351679


Epoch 31/200: 100%|██████████| 110/110 [00:03<00:00, 32.77it/s]


Epoch 31: Train Loss = 0.942836, Val Loss = 23.545061


Epoch 32/200: 100%|██████████| 110/110 [00:03<00:00, 32.63it/s]


Epoch 32: Train Loss = 0.931619, Val Loss = 23.215234


Epoch 33/200: 100%|██████████| 110/110 [00:03<00:00, 32.74it/s]


Epoch 33: Train Loss = 0.912538, Val Loss = 23.517752
Finished training. Now testing using best model...
Test Results: {'mrr': 0.9325880964633875, 'ndcg': 0.9491635656130114, 'recall_at_1': 0.88944, 'recall_at_3': 0.97312, 'recall_at_5': 0.98648, 'recall_at_10': 0.99496, 'recall_at_50': 0.99928, 'l2_dist': 516.6212158203125}


In [37]:
generate_submission(model, Path(test_path), device=device)

Generating submission file...
✓ Saved submission to submission.csv


Unnamed: 0,id,embedding
0,1,"[14.116164207458496, 4.051480293273926, 14.437..."
1,2,"[-8.252291679382324, -5.295539379119873, -5.83..."
2,3,"[-6.238521099090576, 3.012822151184082, 28.372..."
3,4,"[19.693193435668945, -25.15383529663086, 1.899..."
4,5,"[20.222675323486328, 24.787092208862305, 0.767..."
...,...,...
1495,1496,"[8.458430290222168, -6.62744665145874, 25.4995..."
1496,1497,"[8.686895370483398, 10.717864036560059, 40.116..."
1497,1498,"[-1.205926775932312, -0.44514358043670654, 4.7..."
1498,1499,"[-15.145681381225586, -0.44480353593826294, 9...."


In [42]:
from torchsummary import summary
summary(model, input_size=(1024,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 1024]       1,049,600
              SiLU-2                 [-1, 1024]               0
           Dropout-3                 [-1, 1024]               0
            Linear-4                 [-1, 2048]       2,099,200
              SiLU-5                 [-1, 2048]               0
           Dropout-6                 [-1, 2048]               0
            Linear-7                 [-1, 1024]       2,098,176
              SiLU-8                 [-1, 1024]               0
           Dropout-9                 [-1, 1024]               0
           Linear-10                 [-1, 1536]       1,574,400
           Linear-11                 [-1, 1024]       1,049,600
             SiLU-12                 [-1, 1024]               0
          Dropout-13                 [-1, 1024]               0
           Linear-14                 [-

In [66]:
all_X = torch.stack([train_dataset[i][0] for i in range(len(train_dataset))])
all_y = torch.stack([train_dataset[i][1] for i in range(len(train_dataset))])

with torch.inference_mode():
    translated = model(all_X.to(device)).to('cpu')

print("Mean squared distance before alignment:", ((translated - all_y) ** 2).mean().item())

R, s, t = procrustes_align(translated, all_y)

translated_align = align_matrix(translated, R, s, t)

print("Mean squared distance after alignment:", ((translated_align - all_y) ** 2).mean().item())

Mean squared distance before alignment: 175.68894958496094
Mean squared distance after alignment: 0.15851448476314545


In [64]:
generate_submission(model, Path(test_path), output_file="submission-aligned.csv", device=device, procrustes_data=(R, s, t))


Using Procrusted
Generating submission file...
✓ Saved submission to submission-aligned.csv


Unnamed: 0,id,embedding
0,1,"[0.6429033875465393, 0.11905805766582489, -0.0..."
1,2,"[0.34022626280784607, 0.06264924257993698, -0...."
2,3,"[0.38776659965515137, 0.10746285319328308, 0.2..."
3,4,"[0.5170254707336426, 0.17834457755088806, -0.0..."
4,5,"[0.7246093153953552, 0.2914654016494751, -0.04..."
...,...,...
1495,1496,"[0.4948069751262665, -0.05043584108352661, 0.0..."
1496,1497,"[0.7435898184776306, 0.17692360281944275, 0.06..."
1497,1498,"[0.8017455339431763, 0.019157446920871735, -0...."
1498,1499,"[0.6866579651832581, 0.12024844437837601, -0.3..."
