### model.py

In [60]:
import torch
from typing import Optional
from torch import nn
from torch.nn import functional as F

class Translator(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        dir_hidden_dims: list[int],
        scale_hidden_dims: list[int],
        activation=nn.ReLU,
        dropout_rate: float=0.3
    ):
        super().__init__()

        def build_mlp(hidden_dims, out_dim, apply_softplus=False):
            layers = []
            last_dim = input_dim
            for hidden in hidden_dims:
                layers += [
                    nn.Linear(last_dim, hidden),
                    activation(),
                    nn.LayerNorm(hidden),
                    nn.Dropout(dropout_rate)
                ]
                last_dim = hidden
            layers.append(nn.Linear(last_dim, out_dim))
            
            if apply_softplus:
                layers.append(nn.Softplus())
            
            return nn.Sequential(*layers)

        self.dir_head = build_mlp(dir_hidden_dims, output_dim, apply_softplus=False)
        self.scale_head = build_mlp(scale_hidden_dims, 1, apply_softplus=True)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))

        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        direction = F.normalize(self.dir_head(x), dim=-1)
        scale = self.scale_head(x)
        return direction * scale

### eval.py

In [61]:
from pathlib import Path
import numpy as np
import torch
import pandas as pd

'''Code from https://github.com/Mamiglia/challenge'''

def mrr(pred_indices: np.ndarray, gt_indices: np.ndarray) -> float:
    """
    Compute Mean Reciprocal Rank (MRR)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries (top-K)
        gt_indices: (N,) array of ground truth indices
    Returns:
        mrr: Mean Reciprocal Rank
    """
    reciprocal_ranks = []
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i] == gt_indices[i])[0]
        if matches.size > 0:
            reciprocal_ranks.append(1.0 / (matches[0] + 1))
        else:
            reciprocal_ranks.append(0.0)
    return np.mean(reciprocal_ranks)


def recall_at_k(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int) -> float:
    """Compute Recall@k
    Args:
        pred_indices: (N, N) array of top indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        recall: Recall@k
    """
    recall = 0
    for i in range(len(gt_indices)):
        if gt_indices[i] in pred_indices[i, :k]:
            recall += 1
    recall /= len(gt_indices)
    return recall

import numpy as np

def ndcg(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int = 100) -> float:
    """
    Compute Normalized Discounted Cumulative Gain (NDCG@k)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        ndcg: NDCG@k
    """
    ndcg_total = 0.0
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i, :k] == gt_indices[i])[0]
        if matches.size > 0:
            rank = matches[0] + 1
            ndcg_total += 1.0 / np.log2(rank + 1)  # DCG (IDCG = 1)
    return ndcg_total / len(gt_indices)



@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """Evaluate retrieval performance using cosine similarity
    Args:
        translated_embd: (N_captions, D) translated caption embeddings
        image_embd: (N_images, D) image embeddings
        gt_indices: (N_captions,) ground truth image indices for each caption
        max_indices: number of top predictions to consider
    Returns:
        results: dict of evaluation metrics
    
    """
    # Compute similarity matrix
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()
    
    n_queries = translated_embd.shape[0]
    device = translated_embd.device
    
    # Prepare containers for the fragments to be reassembled
    all_sorted_indices = []
    l2_distances = []
    
    # Process in batches - the narrow gate approach
    for start_idx in range(0, n_queries, batch_size):
        batch_slice = slice(start_idx, min(start_idx + batch_size, n_queries))
        batch_translated = translated_embd[batch_slice]
        batch_img_embd = image_embd[batch_slice]
        
        # Compute similarity only for this batch
        batch_similarity = batch_translated @ batch_img_embd.T

        # Get top-k predictions for this batch
        batch_indices = batch_similarity.topk(k=max_indices, dim=1, sorted=True).indices.numpy()
        all_sorted_indices.append(gt_indices[batch_slice][batch_indices])

        # Compute L2 distance for this batch
        batch_gt = gt_indices[batch_slice]
        batch_gt_embeddings = image_embd[batch_gt]
        batch_l2 = (batch_translated - batch_gt_embeddings).norm(dim=1)
        l2_distances.append(batch_l2)
    
    # Reassemble the fragments
    sorted_indices = np.concatenate(all_sorted_indices, axis=0)
    
    # Apply the sacred metrics to the whole
    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }
    
    results = {
        name: func(sorted_indices, gt_indices)
        for name, func in metrics.items()
    }
    
    l2_dist = torch.cat(l2_distances, dim=0).mean().item()
    results['l2_dist'] = l2_dist
    
    return results

def eval_on_val(x_val: np.ndarray, y_val: np.ndarray, model: Translator, device) -> dict:
    gt_indices = torch.arange(len(y_val))
    
    model.eval()

    with torch.inference_mode():
        translated = model(x_val.to(device)).to('cpu')

    results = evaluate_retrieval(translated, y_val, gt_indices)
    
    return results

def generate_submission(model: Translator, test_path: Path, output_file="submission.csv", device=None):
    test_data = np.load(test_path)
    sample_ids = test_data['captions/ids']
    test_embds = test_data['captions/embeddings']
    test_embds = torch.from_numpy(test_embds).float()

    with torch.no_grad():
        pred_embds = model(test_embds.to(device)).cpu()

    print("Generating submission file...")

    if isinstance(pred_embds, torch.Tensor):
        pred_embds = pred_embds.cpu().numpy()

    df_submission = pd.DataFrame({'id': sample_ids, 'embedding': pred_embds.tolist()})

    df_submission.to_csv(output_file, index=False, float_format='%.17g')
    print(f"✓ Saved submission to {output_file}")

    return df_submission

### configs

### main.py

In [62]:

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split
from pathlib import Path
from tqdm import tqdm

# def info_nce_loss(preds, img_targets, temp: float = 0.07, lambda_mag: float = 0.1):
#     preds_norm = F.normalize(preds, dim=1)
#     targets_norm = F.normalize(img_targets, dim=1)

#     direction_logits = (preds_norm @ targets_norm.T) / temp

#     preds_mag = preds.norm(dim=1).unsqueeze(1)
#     targets_mag = img_targets.norm(dim=1).unsqueeze(0
#     magnitude_logits = -(preds_mag - targets_mag)**2

#     combined_logits = direction_logits + lambda_mag * magnitude_logits

#     labels = torch.arange(preds.size(0), device=preds.device)

#     loss_t2i = F.cross_entropy(combined_logits, labels)
#     loss_i2t = F.cross_entropy(combined_logits.T, labels)

#     return 0.5 * (loss_t2i + loss_i2t)

def info_nce_loss(preds_norm, targets_norm, temp=0.07):
    logits = (preds_norm @ targets_norm.T) / temp
    labels = torch.arange(logits.size(0), device=logits.device)

    loss_t2i = F.cross_entropy(logits, labels)          
    loss_i2t = F.cross_entropy(logits.T, labels)        
    
    return 0.5 * (loss_t2i + loss_i2t)


def train_model(
    model: Translator,
    model_path: Path,
    train_dataset: TensorDataset,
    val_dataset: TensorDataset,
    batch_size: int,
    epochs: int,
    lr: float,
    patience: int
) -> Translator:    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Using device: {device}")
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    best_val_loss = float('inf')
    no_improvements = 0

    for epoch in range(epochs):
        model.train()

        train_loss = 0
        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_batch = F.normalize(y_batch, dim=-1)

            optimizer.zero_grad()

            outputs = model(X_batch)

            loss = info_nce_loss(outputs, y_batch, temp=model.logit_scale)

            loss.backward()

            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()

        val_loss = 0

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                outputs = model(X_batch)
                
                y_batch = F.normalize(y_batch, dim=-1)

                loss = info_nce_loss(outputs, y_batch, temp=model.logit_scale)

                val_loss += loss.item()

        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")
        test(val_dataset, model, device)


        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improvements = 0

            Path(model_path).parent.mkdir(parents=True, exist_ok=True)

            torch.save(model.state_dict(), model_path)

            print(f"✓ Saved best model (val_loss={val_loss:.6f})")
        elif no_improvements >= patience:
            return model
        else:
            no_improvements += 1

    return model



def load_data(data_path: Path):
    data = np.load(data_path)
    caption_embeddings = data['captions/embeddings']
    image_embeddings = data['images/embeddings']
    caption_labels = data['captions/label']

    X_abs, y_abs = torch.tensor(caption_embeddings), torch.tensor(image_embeddings[np.argmax(caption_labels, axis=1)])
    
    print('Texts shape', X_abs.shape)
    print('Images shape', X_abs.shape)

    def print_stats():
        mean_X = X_abs.mean(dim=0)
        std_X = X_abs.std(dim=0)
        
        mean_Y = y_abs.mean(dim=0)
        std_Y = y_abs.std(dim=0)

        print("X: mean of stds per dim =", std_X.mean().item(), ", max =", std_X.max().item(), ", min =", std_X.min().item())
        print("Y: mean of stds per dim =", std_Y.mean().item(), ", max =", std_Y.max().item(), ", min =", std_Y.min().item())

    print_stats()
    
    dataset = TensorDataset(X_abs, y_abs)
    train_dataset, val_dataset = random_split(dataset, [0.9, 0.1], generator=torch.Generator().manual_seed(42))
    
    return train_dataset, val_dataset


def test(val_dataset: TensorDataset, model: Translator, device):
    val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
    for x_val, y_val in val_loader:
        results = eval_on_val(x_val, y_val, model=model, device=device)
    return results

In [63]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size= 2048
lr= 0.01
epochs= 200
patience = 5

data_path= '/kaggle/input/aml-competition/train/train/train.npz'
test_path= '/kaggle/input/aml-competition/test/test/test.clean.npz'

model_save_path= './models/exp1.pth'

train_dataset, val_dataset = load_data(data_path)

Texts shape torch.Size([125000, 1024])
Images shape torch.Size([125000, 1024])
X: mean of stds per dim = 0.788078248500824 , max = 3.573546886444092 , min = 0.3716050386428833
Y: mean of stds per dim = 0.4244377911090851 , max = 1.8597956895828247 , min = 0.08161858469247818


In [65]:
model_args = {
    'input_dim': 1024,
    'output_dim': 1536,
    'dir_hidden_dims': [1256, 2048],
    'scale_hidden_dims': [256, 128],
    'activation': nn.SiLU,
    'dropout_rate': 0.25
}
model = Translator(**model_args).to(device)

train_model(model, model_save_path, train_dataset, val_dataset, batch_size, epochs, lr, patience)

print('Finished training. Now testing using best model...')

state = torch.load(model_save_path)
model.load_state_dict(state)
results = test(val_dataset, model, device)

print("Test Results:", results)

Using device: cuda


Epoch 1/200: 100%|██████████| 55/55 [00:03<00:00, 16.33it/s]


Epoch 1: Train Loss = 4.365358, Val Loss = 2.574099
✓ Saved best model (val_loss=2.574099)


Epoch 2/200: 100%|██████████| 55/55 [00:03<00:00, 16.21it/s]


Epoch 2: Train Loss = 2.557806, Val Loss = 2.049838
✓ Saved best model (val_loss=2.049838)


Epoch 3/200: 100%|██████████| 55/55 [00:03<00:00, 16.16it/s]


Epoch 3: Train Loss = 2.044762, Val Loss = 1.859903
✓ Saved best model (val_loss=1.859903)


Epoch 4/200: 100%|██████████| 55/55 [00:03<00:00, 16.07it/s]


Epoch 4: Train Loss = 1.725505, Val Loss = 1.736459
✓ Saved best model (val_loss=1.736459)


Epoch 5/200: 100%|██████████| 55/55 [00:03<00:00, 15.82it/s]


Epoch 5: Train Loss = 1.477817, Val Loss = 1.674191
✓ Saved best model (val_loss=1.674191)


Epoch 6/200: 100%|██████████| 55/55 [00:03<00:00, 15.95it/s]


Epoch 6: Train Loss = 1.300288, Val Loss = 1.637229
✓ Saved best model (val_loss=1.637229)


Epoch 7/200: 100%|██████████| 55/55 [00:03<00:00, 16.30it/s]


Epoch 7: Train Loss = 1.151557, Val Loss = 1.623874
✓ Saved best model (val_loss=1.623874)


Epoch 8/200: 100%|██████████| 55/55 [00:03<00:00, 16.34it/s]


Epoch 8: Train Loss = 1.022859, Val Loss = 1.625003


Epoch 9/200: 100%|██████████| 55/55 [00:03<00:00, 16.43it/s]


Epoch 9: Train Loss = 0.923676, Val Loss = 1.624463


Epoch 10/200: 100%|██████████| 55/55 [00:03<00:00, 16.13it/s]


Epoch 10: Train Loss = 0.844318, Val Loss = 1.637386


Epoch 11/200: 100%|██████████| 55/55 [00:03<00:00, 15.32it/s]


Epoch 11: Train Loss = 0.768754, Val Loss = 1.649925


Epoch 12/200: 100%|██████████| 55/55 [00:03<00:00, 16.49it/s]


Epoch 12: Train Loss = 0.712459, Val Loss = 1.638953


Epoch 13/200: 100%|██████████| 55/55 [00:03<00:00, 16.44it/s]


Epoch 13: Train Loss = 0.661334, Val Loss = 1.674062
Finished training. Now testing using best model...
Test Results: {'mrr': 0.9309442176248136, 'ndcg': 0.9478297591884011, 'recall_at_1': 0.8876, 'recall_at_3': 0.97072, 'recall_at_5': 0.98288, 'recall_at_10': 0.99344, 'recall_at_50': 0.99968, 'l2_dist': 241.52784729003906}


In [66]:
generate_submission(model, Path(test_path), device=device)

Generating submission file...
✓ Saved submission to submission.csv


Unnamed: 0,id,embedding
0,1,"[0.8063499927520752, -1.2960084676742554, 6.71..."
1,2,"[0.33625540137290955, 1.8729370832443237, -1.3..."
2,3,"[-0.25203096866607666, 1.7044193744659424, 3.7..."
3,4,"[8.706969261169434, -6.267874240875244, -1.426..."
4,5,"[8.63264274597168, 9.660758018493652, 0.019903..."
...,...,...
1495,1496,"[1.6442174911499023, 1.4861117601394653, 15.40..."
1496,1497,"[5.36964225769043, 9.188952445983887, 14.46781..."
1497,1498,"[1.1236704587936401, -10.325639724731445, 3.51..."
1498,1499,"[-3.5596485137939453, -1.9238801002502441, -2...."


In [67]:
from torchsummary import summary
summary(model, input_size=(1024,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 1256]       1,287,400
              SiLU-2                 [-1, 1256]               0
         LayerNorm-3                 [-1, 1256]           2,512
           Dropout-4                 [-1, 1256]               0
            Linear-5                 [-1, 2048]       2,574,336
              SiLU-6                 [-1, 2048]               0
         LayerNorm-7                 [-1, 2048]           4,096
           Dropout-8                 [-1, 2048]               0
            Linear-9                 [-1, 1536]       3,147,264
           Linear-10                  [-1, 256]         262,400
             SiLU-11                  [-1, 256]               0
        LayerNorm-12                  [-1, 256]             512
          Dropout-13                  [-1, 256]               0
           Linear-14                  [

In [68]:
import optuna
from optuna.pruners import MedianPruner

def objective(trial, train_dataset, val_dataset,
              epochs: int = 10, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    n_layers = trial.suggest_int("n_layers", 1, 4)
    layer_choices = [512, 1024, 2048, 4096]
    hidden_layers = [trial.suggest_categorical(f"n_units_l{i}", layer_choices) for i in range(n_layers)]

    batch_size = trial.suggest_categorical("batch_size", [256, 512, 1024, 2048, 4096])
    lr = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    temp = trial.suggest_float("temp", 0.05, 0.3, log=True)
    dropout_rate = trial.suggest_categorical('dropout_rate', [0.1,0.2,0.3,0.4,0.5])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = Translator(input_dim=input_dim, output_dim=output_dim, hidden_layers=hidden_layers, dropout_rate=dropout_rate)
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            y_batch = F.normalize(y_batch, dim=-1)

            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = info_nce_loss(outputs, y_batch, temp=model.logit_scale)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                y_batch = F.normalize(y_batch, dim=-1)
                
                outputs = model(X_batch)
                loss = info_nce_loss(outputs, y_batch, temp=model.logit_scale)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        results = test(val_dataset, model, device)

        trial.report(results['mrr'], epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()


    return results['mrr']


def run_optuna_search(data_path: Path,
                      n_trials: int = 30, epochs: int = 10,
                      n_jobs: int = 1, sampler=None, pruner=None):
    if pruner is None:
        pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=1)
    train_dataset, val_dataset = load_data(data_path)

    study = optuna.create_study(direction="maximize", pruner=pruner)
    func = lambda trial: objective(trial, train_dataset=train_dataset, val_dataset=val_dataset,
                                   epochs=epochs)
    study.optimize(func, n_trials=n_trials, n_jobs=n_jobs)

    print("Study statistics:")
    print("  Number of finished trials: ", len(study.trials))
    print("  Best trial:")
    trial = study.best_trial
    print("    Value: ", trial.value)
    print("    Params: ")
    for k, v in trial.params.items():
        print(f"      {k}: {v}")

    return study

In [None]:
study = run_optuna_search(data_path=data_path,
                          n_trials=100, epochs=10, n_jobs=1)
optuna.study.study.Storage  
study.trials_dataframe().to_csv("optuna_trials.csv", index=False)

best_trial_number = study.best_trial.number
print("Best params:", study.best_params)
print("Best trial number:", study.best_trial.number)