In [None]:
# Fetch the project repository. Presumably it provides the local `eval` and `dataset`
# modules imported in the next cell.
# NOTE(review): those imports only resolve if the working directory / sys.path points
# inside the cloned `aml-challenge` folder — confirm.
!git clone https://github.com/the-summoning/aml-challenge.git

In [None]:
from torch import nn
from torch.nn import functional as F
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from pathlib import Path
import optuna
from optuna.pruners import MedianPruner
from eval import generate_submission,test
from dataset import get_datasets, get_data
import gdown

Download dataset

In [None]:
# Fetch the challenge data folder from Google Drive with gdown (no progress output,
# no browser cookies; gdown's default output location is used).
url = "https://drive.google.com/drive/folders/1HWFHKCprFzR7H7TYhrE-W7v4bz2Vc7Ia"
gdown.download_folder(url=url, quiet=True, use_cookies=False)

A small MLP that maps text embeddings into the image-embedding space. Its output is L2-normalized, and the model includes a learnable logit-scale (inverse temperature) parameter.

In [None]:
class SpaceTranslator(nn.Module):
    """MLP that projects text embeddings into the image-embedding space.

    Every width in ``hidden_layers`` adds a Linear -> LayerNorm -> activation -> Dropout
    stage; a final Linear maps to ``output_dim``. ``forward`` returns rows with unit L2
    norm. The module also owns a learnable scalar ``logit_scale`` (the log of the
    inverse temperature, initialised to log(1/0.07)) that the training loss reads.

    Args:
        input_dim: width of the incoming embeddings.
        output_dim: width of the produced embeddings.
        hidden_layers: iterable of hidden widths; empty means a single Linear layer.
        activation: activation class, instantiated once per hidden stage.
        dropout_rate: dropout probability applied after every hidden stage.
        init_method: 'xavier' or 'kaiming' (case-insensitive) Linear weight init.

    Raises:
        ValueError: if ``init_method`` is not one of the supported schemes.
    """

    def __init__(self, input_dim, output_dim, hidden_layers, activation, dropout_rate, init_method: str = 'xavier'):
        super().__init__()

        method = init_method.lower()
        self.init_method = method
        if method != 'xavier' and method != 'kaiming':
            raise ValueError("Unsupported init_method")

        # Consecutive (fan_in, fan_out) pairs: input -> hidden_1 -> ... -> hidden_n.
        widths = [input_dim, *hidden_layers]
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                activation(),
                nn.Dropout(dropout_rate),
            ))
        stages.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*stages)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.apply(self.init_weights)

    def init_weights(self, module):
        """Re-initialise a single submodule in place (called through ``nn.Module.apply``).

        LayerNorm layers are reset to weight=1 / bias=0. Linear layers get Kaiming-uniform
        (ReLU gain) or Xavier-uniform weights depending on ``self.init_method``, and zero
        biases. Any other module is left as is.
        """
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return
        if not isinstance(module, nn.Linear):
            return

        if self.init_method == 'kaiming':
            nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')
        else:
            nn.init.xavier_uniform_(module.weight)

        if module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(self, x):
        """Project ``x`` through the MLP and L2-normalise each row (dim 1)."""
        projected = self.net(x)
        return F.normalize(projected, p=2, dim=1)

Training loop

In [None]:

def queue_info_nce_loss(q, k, queue, logit_scale):
    """InfoNCE loss with a single positive per query and a queue of shared negatives.

    For row i, the positive score is <q_i, k_i> and the negative scores are <q_i, queue_j>
    for every queue row j. The scores are multiplied by exp(logit_scale), where logit_scale
    is first clamped to [log(0.01), log(100)], and fed to a cross-entropy whose target is
    always column 0 (the positive).

    No normalisation happens here: reading the scores as cosine similarities assumes the
    caller passes unit-norm q, k and queue rows (NOTE(review): confirm for the image keys).

    Args:
        q: (batch, dim) query embeddings.
        k: (batch, dim) positive keys, row-aligned with ``q``.
        queue: (queue_size, dim) negative keys.
        logit_scale: scalar tensor, log of the inverse temperature.

    Returns:
        Scalar mean cross-entropy over the batch.
    """
    scale = logit_scale.clamp(min=np.log(0.01), max=np.log(100)).exp()

    positives = (q * k).sum(dim=1, keepdim=True)   # (batch, 1)
    negatives = torch.matmul(q, queue.t())         # (batch, queue_size)

    logits = torch.cat((positives, negatives), dim=1) * scale
    targets = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, targets)


@torch.no_grad()
def enqueue(queue, keys, queue_ptr):
    """Write ``keys`` into the ring buffer ``queue`` at ``queue_ptr[0]`` and advance the pointer.

    Rows are written in order starting at the current pointer and wrap to row 0 at the
    end of the buffer, exactly as if the keys were pushed one at a time. If the batch is
    larger than the queue, only the last ``queue.shape[0]`` keys can survive such a
    sequential write, so only those are written, at the positions they would occupy.
    A zero-length queue is left untouched.

    Args:
        queue: (queue_size, dim) buffer, modified in place.
        keys: (batch, dim) rows to insert.
        queue_ptr: one-element integer tensor holding the next insert index; updated in place.
    """
    queue_size = queue.shape[0]
    if queue_size == 0:
        return  # nothing can be stored (and the modulo below would divide by zero)

    batch_size = keys.shape[0]
    ptr = int(queue_ptr[0])  # current insert index

    if batch_size > queue_size:
        # Keys that would be overwritten within this same call are skipped; start where
        # a one-by-one write would have placed the first surviving key.
        ptr = (ptr + batch_size - queue_size) % queue_size
        keys = keys[batch_size - queue_size:]
        batch_size = queue_size

    first_part = min(batch_size, queue_size - ptr)        # rows that fit before the end
    queue[ptr:ptr + first_part] = keys[:first_part]
    queue[:batch_size - first_part] = keys[first_part:]   # wrapped remainder (may be empty)

    queue_ptr[0] = (ptr + batch_size) % queue_size

def train_model(model, save_path, train_dataset, val_dataset, batch_size, epochs, lr, patience, queue_size):
    """Train ``model`` with the queue-based InfoNCE loss, early-stopping on validation loss.

    Each step maps a text batch through ``model`` and scores it against the paired image
    embeddings (positives) and a FIFO queue of earlier image embeddings (negatives). The
    current image batch is then pushed into the queue. After every epoch, the validation
    loss (against the same queue) and the retrieval metrics returned by ``test`` are
    printed. The learning rate is halved on plateaus, and training stops after
    ``patience`` epochs without improvement.

    Args:
        model: module returning embeddings and exposing a ``logit_scale`` parameter.
        save_path: file the best (lowest validation loss) state_dict is written to;
            parent directories are created on demand.
        train_dataset, val_dataset: datasets yielding (text_embedding, image_embedding) pairs.
        batch_size: mini-batch size for both loaders.
        epochs: maximum number of epochs.
        lr: initial AdamW learning rate.
        patience: epochs without validation-loss improvement before stopping.
        queue_size: number of negative keys kept in the queue.

    Returns:
        ``model`` with the weights of the best validation epoch loaded back in (the same
        weights written to ``save_path``). If no epoch ran, the weights are unchanged.

    Raises:
        ValueError: if either dataset yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    # Fail fast instead of dividing by zero only after a full epoch of work.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_dataset and val_dataset must both be non-empty")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    scheduler = ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, threshold=0.003, min_lr=1e-5
    )

    # The queue width must equal the image-embedding width. It is allocated from the first
    # key batch rather than hard-coded, so other output dimensions work too.
    queue = None
    queue_ptr = torch.zeros(1, dtype=torch.long, device=device)

    best_val_loss = float('inf')
    best_state = None
    loss_no_improvements = 0

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)

        for text_batch, image_emb_batch in progress_bar:
            text_batch, image_emb_batch = text_batch.to(device), image_emb_batch.to(device)
            if queue is None:
                queue = torch.zeros(queue_size, image_emb_batch.shape[1], device=device)

            optimizer.zero_grad()

            q = model(text_batch)
            k = image_emb_batch

            loss = queue_info_nce_loss(q, k, queue, model.logit_scale)
            loss.backward()
            optimizer.step()

            # Enqueue only after the loss, so this step's positives are not also its negatives.
            with torch.no_grad():
                enqueue(queue, k, queue_ptr)

            running_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_train_loss = running_loss / len(train_loader)

        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for text_batch, image_emb_batch in val_loader:
                text_batch, image_emb_batch = text_batch.to(device), image_emb_batch.to(device)

                q = model(text_batch)
                k = image_emb_batch

                loss = queue_info_nce_loss(q, k, queue, model.logit_scale)
                running_val_loss += loss.item()

        avg_val_loss = running_val_loss / len(val_loader)

        results = test(val_dataset, model, device)
        mrr = results['mrr']

        scheduler.step(avg_val_loss)

        print(f"Epoch {epoch:03d} | Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f} | "
              f"MRR: {mrr:.6f} | Recall-1: {results['recall_at_1']:.6f} | LR: {optimizer.param_groups[0]['lr']:.2e}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            loss_no_improvements = 0
            # Keep an in-memory copy so the best weights (not the last ones) are returned.
            best_state = {name: t.detach().clone() for name, t in model.state_dict().items()}

            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), save_path)
            print(f"Saved new best model (Val Loss={avg_val_loss:.6f})")
        else:
            loss_no_improvements += 1
            if loss_no_improvements >= patience:
                print("Early stopping triggered based on Validation Loss.")
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    print(" Training complete")
    return model


Hyperparameters

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Input files. NOTE(review): these point at a Kaggle dataset mount, not at the folder
# fetched with gdown above — confirm which source is intended.
data_path = '/kaggle/input/challenge-dataset/train.npz'
test_path = '/kaggle/input/challenge-dataset/test.clean.npz'

# Best checkpoint written during training and read back before testing.
save_path = './models/best_model.pth'

# Architecture: text-embedding width -> image-embedding width.
input_dim, output_dim = 1024, 1536
hidden_layers = [1256, 2048]  # NOTE(review): 1256 is an unusual width — confirm it is not a typo
dropout_rate = 0.5

# Optimisation and early stopping.
batch_size, lr, epochs, patience = 256, 0.01, 250, 5

# Number of negative image embeddings kept in the InfoNCE queue.
queue_size = 10000

model_args = dict(
    input_dim=input_dim,
    output_dim=output_dim,
    hidden_layers=hidden_layers,
    dropout_rate=dropout_rate,
    activation=nn.GELU,
)

In [None]:
# Load the training file and build the train/validation datasets from it.
# NOTE(review): presumably x holds text embeddings and y the paired image embeddings,
# and the split (ratio, shuffling, seed) is decided inside dataset.get_datasets — confirm
# it is deterministic so validation scores are comparable across runs.
x, y = get_data(data_path)
train_dataset, val_dataset = get_datasets(x, y)

Training

In [None]:
model = SpaceTranslator(**model_args)

# Train only when no checkpoint exists yet. On a fresh "Restart & Run All" this produces the
# checkpoint the next cell loads (previously training was commented out, so that load failed);
# an existing checkpoint is reused instead of retraining.
if Path(save_path).exists():
    print(f'Found existing checkpoint at {save_path}; skipping training.')
else:
    train_model(model, save_path, train_dataset, val_dataset, batch_size, epochs, lr, patience, queue_size)
    print('Finished training. Now testing using best model...')

Testing

In [None]:

# Restore the best checkpoint and evaluate it on the validation split.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine; weights_only
# avoids unpickling arbitrary objects (the file only needs to hold a state_dict).
state = torch.load(save_path, map_location=device, weights_only=True)
model.load_state_dict(state)
# Put the model on the evaluation device with dropout disabled. If training did not run in
# this session, the freshly built model is still on the CPU and in train mode.
model.to(device)
model.eval()
results = test(val_dataset, model, device)
print("Test Results:", results)

In [None]:
# Write predictions for the test file at `test_path` to sub.csv using `model`.
generate_submission(
    model,
    Path(test_path),
    output_file="sub.csv",
    device=device,
)

Code to perform hyperparameter optimization using the Optuna library.

In [None]:
def objective(trial, train_dataset, val_dataset, epochs: int = 15, device=None, eval_num_negatives: int = 8192):
    """Optuna objective: train one sampled SpaceTranslator configuration and return its best validation loss.

    Sampled hyperparameters: weight decay, batch size, queue size, dropout, number and
    widths of hidden layers, and weight-init scheme. The learning rate is fixed at 1e-3.
    The per-epoch validation loss is reported to Optuna so the pruner can stop
    unpromising trials early.

    The validation loss uses only the first ``eval_num_negatives`` rows of the queue as
    negatives; training still uses the whole queue. InfoNCE loss grows with the number of
    negatives, so scoring each trial against its own ``queue_size`` would make losses
    incomparable across trials and bias the search toward the smallest queue. The default
    equals the smallest sampled queue size, so every trial is scored against the same
    number of negatives.

    Args:
        trial: Optuna trial used to sample hyperparameters and report progress.
        train_dataset, val_dataset: datasets yielding (text_embedding, image_embedding) pairs.
        epochs: training epochs per trial.
        device: torch device; defaults to CUDA when available, else CPU.
        eval_num_negatives: number of queue rows used as negatives when validating.

    Returns:
        The lowest per-epoch validation loss observed.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner stops the trial.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Optimizer params
    lr = 0.001
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)

    # Training params
    batch_size = trial.suggest_categorical("batch_size", [128, 256, 512])
    queue_size = trial.suggest_categorical("queue_size", [8192, 16384, 32768, 65536])

    # Model architecture params
    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)
    n_layers = trial.suggest_int("n_layers", 1, 3)
    hidden_dims = [
        trial.suggest_categorical(f"hidden_dim_layer_{i}", [1280, 1472, 1856, 2048])
        for i in range(n_layers)
    ]

    # Weights init
    init_method = trial.suggest_categorical("init_method", ['xavier', 'kaiming'])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Image-embedding width, shared by the model output and the negative queue.
    embed_dim = 1536
    model = SpaceTranslator(
        input_dim=1024,
        output_dim=embed_dim,
        hidden_layers=hidden_dims,
        dropout_rate=dropout_rate,
        activation=nn.GELU,
        init_method=init_method,
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    scheduler = ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2, threshold=0.003, min_lr=1e-6
    )

    queue = torch.zeros(queue_size, embed_dim, device=device)
    queue_ptr = torch.zeros(1, dtype=torch.long, device=device)

    best_val_loss = float('inf')
    mrr = float('nan')  # defined even when epochs == 0, for the summary line below

    for epoch in range(1, epochs + 1):
        model.train()
        for text_batch, image_emb_batch in train_loader:
            text_batch, image_emb_batch = text_batch.to(device), image_emb_batch.to(device)

            optimizer.zero_grad()
            q = model(text_batch)
            k = image_emb_batch
            loss = queue_info_nce_loss(q, k, queue, model.logit_scale)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                enqueue(queue, k, queue_ptr)

        model.eval()
        # Same number of negatives for every trial, whatever its queue_size (see docstring).
        eval_queue = queue[:eval_num_negatives]
        running_val_loss = 0.0
        with torch.no_grad():
            for text_batch, image_emb_batch in val_loader:
                text_batch, image_emb_batch = text_batch.to(device), image_emb_batch.to(device)
                q = model(text_batch)
                k = image_emb_batch
                loss = queue_info_nce_loss(q, k, eval_queue, model.logit_scale)
                running_val_loss += loss.item()

        avg_val_loss = running_val_loss / len(val_loader)
        results = test(val_dataset, model, device)
        mrr = results['mrr']

        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss

        trial.report(avg_val_loss, epoch)
        if trial.should_prune():
            print(f"⏹ Optuna pruning triggered for Trial {trial.number}.")
            raise optuna.exceptions.TrialPruned()

    print(f"✅ Trial {trial.number} complete. Best Val Loss: {best_val_loss}, MRR: {mrr:.6f}")
    return best_val_loss



def run_optuna_search(data_path: Path, n_trials: int = 150, epochs: int = 30, n_jobs: int = 1, sampler=None, pruner=None):
    """Run an Optuna study that minimises the validation loss returned by ``objective``.

    Args:
        data_path: training file passed to ``get_data``; the result is split with ``get_datasets``.
        n_trials: maximum number of trials.
        epochs: training epochs per trial.
        n_jobs: parallel trials passed to ``study.optimize``.
        sampler: optional Optuna sampler (Optuna's default when None).
        pruner: optional Optuna pruner; defaults to ``MedianPruner(n_startup_trials=5, n_warmup_steps=3)``.

    Returns:
        The ``optuna.Study``. A KeyboardInterrupt during the search is caught, and the
        partial study is still summarised and returned.
    """
    if pruner is None:
        pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=3)

    X, y = get_data(data_path)
    train_dataset, val_dataset = get_datasets(X, y)

    study = optuna.create_study(direction="minimize", pruner=pruner, sampler=sampler)

    def func(trial):
        return objective(trial, train_dataset=train_dataset, val_dataset=val_dataset, epochs=epochs)

    try:
        study.optimize(func, n_trials=n_trials, n_jobs=n_jobs)
    except KeyboardInterrupt:
        print("Search stopped manually.")

    TrialState = optuna.trial.TrialState
    completed = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))
    pruned = study.get_trials(deepcopy=False, states=(TrialState.PRUNED,))

    print("Study statistics:")
    print("  Number of finished trials: ", len(study.trials))
    print(f"  Completed: {len(completed)} | Pruned: {len(pruned)}")

    if not completed:
        # study.best_trial raises ValueError when no trial has completed
        # (e.g. every trial was pruned or the search was interrupted early).
        print("  No completed trial; no best trial to report.")
        return study

    print("  Best trial:")
    trial = study.best_trial
    print(f"    Value (Min Val Loss): {trial.value:.6f}")
    print("    Params: ")
    for k, v in trial.params.items():
        print(f"      {k}: {v}")

    return study

In [None]:
# Set to True to launch the (long-running) Optuna search and export every trial to CSV.
RUN_HPO = False

if RUN_HPO:
    study = run_optuna_search(data_path=data_path, n_trials=150, epochs=35, n_jobs=1)
    study.trials_dataframe().to_csv("optuna_trials.csv", index=False)

    print("Best params:", study.best_params)
    print("Best trial number:", study.best_trial.number)