In [14]:
from typing import Optional
from torch import nn
from torch.nn import functional as F
import torch

class SpaceTranslator(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_layers,
        activation,
        dropout_rate,
        scaling=None,
        centering=None
    ):
        super().__init__()

        self.scaling = scaling
        self.centering = centering

        layers = []
        last = input_dim

        for hidden in hidden_layers:
            layers += [
                nn.Linear(last, hidden),
                # nn.LayerNorm(hidden),
                nn.BatchNorm1d(hidden),
                activation(),
                nn.Dropout(dropout_rate)
            ]
            last = hidden

        layers.append(nn.Linear(last, output_dim))
        self.net = nn.Sequential(*layers)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))

        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight, a=0, mode='fan_in', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm) or isinstance(module, nn.BatchNorm1d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
      return self.net(x)



In [16]:
from pathlib import Path
import numpy as np
import pandas as pd

'''Code from https://github.com/Mamiglia/challenge'''

def mrr(pred_indices: np.ndarray, gt_indices: np.ndarray) -> float:
    """
    Compute Mean Reciprocal Rank (MRR)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries (top-K)
        gt_indices: (N,) array of ground truth indices
    Returns:
        mrr: Mean Reciprocal Rank
    """
    reciprocal_ranks = []
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i] == gt_indices[i])[0]
        if matches.size > 0:
            reciprocal_ranks.append(1.0 / (matches[0] + 1))
        else:
            reciprocal_ranks.append(0.0)
    return np.mean(reciprocal_ranks)


def recall_at_k(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int) -> float:
    """Compute Recall@k
    Args:
        pred_indices: (N, N) array of top indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        recall: Recall@k
    """
    recall = 0
    for i in range(len(gt_indices)):
        if gt_indices[i] in pred_indices[i, :k]:
            recall += 1
    recall /= len(gt_indices)
    return recall

import numpy as np

def ndcg(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int = 100) -> float:
    """
    Compute Normalized Discounted Cumulative Gain (NDCG@k)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        ndcg: NDCG@k
    """
    ndcg_total = 0.0
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i, :k] == gt_indices[i])[0]
        if matches.size > 0:
            rank = matches[0] + 1
            ndcg_total += 1.0 / np.log2(rank + 1)  # DCG (IDCG = 1)
    return ndcg_total / len(gt_indices)



@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """Evaluate retrieval performance using cosine similarity
    Args:
        translated_embd: (N_captions, D) translated caption embeddings
        image_embd: (N_images, D) image embeddings
        gt_indices: (N_captions,) ground truth image indices for each caption
        max_indices: number of top predictions to consider
    Returns:
        results: dict of evaluation metrics

    """
    # Compute similarity matrix
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()

    n_queries = translated_embd.shape[0]
    device = translated_embd.device

    # Prepare containers for the fragments to be reassembled
    all_sorted_indices = []
    l2_distances = []

    # Process in batches - the narrow gate approach
    for start_idx in range(0, n_queries, batch_size):
        batch_slice = slice(start_idx, min(start_idx + batch_size, n_queries))
        batch_translated = translated_embd[batch_slice]
        batch_img_embd = image_embd[batch_slice]

        # Compute similarity only for this batch
        batch_similarity = batch_translated @ batch_img_embd.T

        # Get top-k predictions for this batch
        batch_indices = batch_similarity.topk(k=max_indices, dim=1, sorted=True).indices.numpy()
        all_sorted_indices.append(gt_indices[batch_slice][batch_indices])

        # Compute L2 distance for this batch
        batch_gt = gt_indices[batch_slice]
        batch_gt_embeddings = image_embd[batch_gt]
        batch_l2 = (batch_translated - batch_gt_embeddings).norm(dim=1)
        l2_distances.append(batch_l2)

    # Reassemble the fragments
    sorted_indices = np.concatenate(all_sorted_indices, axis=0)

    # Apply the sacred metrics to the whole
    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }

    results = {
        name: func(sorted_indices, gt_indices)
        for name, func in metrics.items()
    }

    l2_dist = torch.cat(l2_distances, dim=0).mean().item()
    results['l2_dist'] = l2_dist

    return results

def eval_on_val(x_val: np.ndarray, y_val: np.ndarray, model: nn.Module, device) -> dict:
    gt_indices = torch.arange(len(y_val))

    model.eval()

    with torch.inference_mode():
        translated = model(x_val.to(device)).to('cpu')

    results = evaluate_retrieval(translated, y_val, gt_indices)

    return results


def generate_submission(model: nn.Module, test_path: Path, output_file="submission-dirmodel.csv", device=None):
    test_data = np.load(test_path)
    sample_ids = test_data['captions/ids']
    test_embds = test_data['captions/embeddings']
    test_embds = torch.from_numpy(test_embds).float()

    with torch.no_grad():
        pred_embds = model(test_embds.to(device)).cpu()

    print("Generating submission file...")

    if isinstance(pred_embds, torch.Tensor):
        pred_embds = pred_embds.cpu().numpy()

    df_submission = pd.DataFrame({'id': sample_ids, 'embedding': pred_embds.tolist()})

    df_submission.to_csv(output_file, index=False, float_format='%.17g')
    print(f"‚úì Saved submission to {output_file}")

    return df_submission

In [None]:
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

def info_nce_loss(dir_preds, img_targets, logit_scale: float):
    dir_preds = F.normalize(dir_preds, dim=-1)
    img_targets = F.normalize(img_targets, dim=-1)

    logit_scale = torch.clamp(logit_scale, max=np.log(100))

    logits = (dir_preds @ img_targets.T) * logit_scale.exp()
    labels = torch.arange(logits.size(0), device=logits.device)

    loss_t2i = F.cross_entropy(logits, labels)
    loss_i2t = F.cross_entropy(logits.T, labels)

    return 0.5 * (loss_t2i + loss_i2t)

def info_nce_loss_hard(dir_preds, img_targets, logit_scale, k=256):
    # Normalize (important)
    dir_preds = F.normalize(dir_preds, dim=-1)
    img_targets = F.normalize(img_targets, dim=-1)

    logit_scale = torch.clamp(logit_scale, max=np.log(100))
    logits = torch.matmul(dir_preds, img_targets.t()) * logit_scale.exp()  # (B, B)
    B = logits.size(0)
    device = logits.device

    labels = torch.arange(B, device=device)

    # Mask out positives 
    diag_mask = torch.eye(B, device=device).bool()
    neg_logits = logits.masked_fill(diag_mask, float('-inf'))  


    # Get top-k hardest negatives per row 
    hard_vals_t2i, hard_idx_t2i = torch.topk(neg_logits, k, dim=1)  # (B, k)
    # Get positive logits
    pos_logits = logits[torch.arange(B, device=device), labels].unsqueeze(1)  # (B, 1)

    logits_t2i_hard = torch.cat([pos_logits, hard_vals_t2i], dim=1)  # (B, 1+k)
    labels_t2i = torch.zeros(B, dtype=torch.long, device=device)  # positive is index 0

    loss_t2i = F.cross_entropy(logits_t2i_hard, labels_t2i)

    neg_logits_T = logits.t().masked_fill(diag_mask, float('-inf'))  # (B, B)
    hard_vals_i2t, hard_idx_i2t = torch.topk(neg_logits_T, k, dim=1)
    pos_logits_i2t = logits.t()[torch.arange(B, device=device), labels].unsqueeze(1)
    logits_i2t_hard = torch.cat([pos_logits_i2t, hard_vals_i2t], dim=1)
    labels_i2t = torch.zeros(B, dtype=torch.long, device=device)
    loss_i2t = F.cross_entropy(logits_i2t_hard, labels_i2t)

    return 0.5 * (loss_t2i + loss_i2t)


def l2_regularization(model, lambda_l2):
    l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
    return lambda_l2 * l2_norm


def train_model_direction(
    model: SpaceTranslator,
    model_path: Path,
    train_dataset: TensorDataset,
    val_dataset: TensorDataset,
    batch_size: int,
    epochs: int,
    lr: float,
    patience: int,
    reg_lambda: float = 0.03
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    print(f"üöÄ Using device: {device}")

    # Dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Optimizer and scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=3,
        threshold=0.005,
        min_lr=1e-6
    )

    best_mrr = float('-inf')
    no_improvements = 0

    for epoch in range(1, epochs + 1):
        model.train()
        running_train_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)
        for X_batch, y_batch in progress_bar:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            # X_batch = F.normalize(X_batch, p=2, dim=-1)
            # y_batch = F.normalize(y_batch, p=2, dim=-1)

            optimizer.zero_grad()
            outputs = model(X_batch)

            loss = info_nce_loss(outputs, y_batch, model.logit_scale)
            #loss += l2_regularization(model, reg_lambda)

            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        train_loss = running_train_loss / len(train_loader)

        # Validation phase
        model.eval()
        running_val_loss = 0.0

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)

                # X_batch = F.normalize(X_batch, p=2, dim=-1)
                # y_batch = F.normalize(y_batch, p=2, dim=-1)

                outputs = model(X_batch)

                loss = info_nce_loss(outputs, y_batch, model.logit_scale)
                #loss += l2_regularization(model, reg_lambda)

                running_val_loss += loss.item()

        val_loss = running_val_loss / len(val_loader)

        print(f"üìò Epoch {epoch:03d} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | LR: {optimizer.param_groups[0]['lr']:.2e}")

        # Optional: external validation/test
        results = test(val_dataset, model, device)
        print(results, 'Logit scale:', model.logit_scale.exp().item())
        mrr = results['mrr']

        scheduler.step(mrr)

        # Early stopping
        if mrr > best_mrr:
            best_mrr = mrr
            no_improvements = 0

            Path(model_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), model_path)
            print(f"üíæ Saved new best model (mrr={mrr:.6f})")
        else:
            no_improvements += 1
            if no_improvements >= patience:
                print("‚èπ Early stopping triggered.")
                break

    print(f"‚úÖ Training complete. Best mrr: {mrr:.6f}")

    return model


def get_data(data_path: Path):
    data = np.load(data_path)
    caption_embeddings = data['captions/embeddings']
    image_embeddings = data['images/embeddings']
    caption_labels = data['captions/label']
    data.close()

    X_abs, y_abs = torch.tensor(caption_embeddings), torch.tensor(image_embeddings[np.argmax(caption_labels, axis=1)])

    return X_abs, y_abs

def get_datasets(X_abs, y_abs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    print('Texts shape', X_abs.shape)
    print('Images shape', y_abs.shape)

    dataset = TensorDataset(X_abs, y_abs)
    train_dataset, val_dataset = random_split(dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42))

    return train_dataset, val_dataset

def test(val_dataset: TensorDataset, model: nn.Module, device):
    val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
    for x_val, y_val in val_loader:
        results = eval_on_val(x_val, y_val, model=model, device=device)
    return results

In [35]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_dim = 1024
output_dim = 1536
hidden_layers=[1256, 1656]
dropout_rate = 0.5

batch_size= 2048
lr=0.01
epochs= 250
patience = 8

data_path= '/content/drive/MyDrive/dav/train.npz'
test_path= '/content/drive/MyDrive/dav/test.clean.npz'

save_path = './models/dir-model.pth'

In [8]:
x, y = get_data(data_path)

train_dataset, val_dataset = get_datasets(x, y)

Texts shape torch.Size([125000, 1024])
Images shape torch.Size([125000, 1536])


In [36]:
model_args = {
    'input_dim': x.size(1),
    'output_dim': y.size(1),
    'hidden_layers': hidden_layers,
    'dropout_rate': dropout_rate,
    'activation': nn.ReLU
}

model = SpaceTranslator(**model_args).to(device)

train_model_direction(model, save_path, train_dataset, val_dataset, batch_size, epochs, lr, patience)

print('Finished training. Now testing using best model...')

state = torch.load(save_path)
model.load_state_dict(state)
results = test(val_dataset, model, device)
print("Test Results:", results)

üöÄ Using device: cuda




üìò Epoch 001 | Train Loss: 4.891643 | Val Loss: 3.718127 | LR: 1.00e-02
{'mrr': np.float64(0.7448052281678955), 'ndcg': np.float64(0.8060800068067063), 'recall_at_1': 0.6124, 'recall_at_3': 0.85448, 'recall_at_5': 0.91964, 'recall_at_10': 0.97016, 'recall_at_50': 0.99888, 'l2_dist': 277.0684814453125} Logit scale: 23.2886962890625
üíæ Saved new best model (mrr=0.744805)




üìò Epoch 002 | Train Loss: 3.565430 | Val Loss: 2.843952 | LR: 1.00e-02
{'mrr': np.float64(0.828902700497973), 'ndcg': np.float64(0.8705432968094536), 'recall_at_1': 0.7294, 'recall_at_3': 0.91752, 'recall_at_5': 0.95708, 'recall_at_10': 0.98412, 'recall_at_50': 0.99908, 'l2_dist': 260.9161682128906} Logit scale: 36.78339767456055
üíæ Saved new best model (mrr=0.828903)




üìò Epoch 003 | Train Loss: 2.977001 | Val Loss: 2.432778 | LR: 1.00e-02
{'mrr': np.float64(0.8639403120457971), 'ndcg': np.float64(0.8972404897036886), 'recall_at_1': 0.78128, 'recall_at_3': 0.93876, 'recall_at_5': 0.969, 'recall_at_10': 0.98888, 'recall_at_50': 0.99932, 'l2_dist': 280.1271057128906} Logit scale: 50.50896453857422
üíæ Saved new best model (mrr=0.863940)




üìò Epoch 004 | Train Loss: 2.642725 | Val Loss: 2.211477 | LR: 1.00e-02
{'mrr': np.float64(0.8805852670009757), 'ndcg': np.float64(0.9098967208232419), 'recall_at_1': 0.80644, 'recall_at_3': 0.94872, 'recall_at_5': 0.9734, 'recall_at_10': 0.99088, 'recall_at_50': 0.99952, 'l2_dist': 288.9298095703125} Logit scale: 62.784698486328125
üíæ Saved new best model (mrr=0.880585)




üìò Epoch 005 | Train Loss: 2.417431 | Val Loss: 2.081589 | LR: 1.00e-02
{'mrr': np.float64(0.8900533708694727), 'ndcg': np.float64(0.91710763895202), 'recall_at_1': 0.8202, 'recall_at_3': 0.95488, 'recall_at_5': 0.97656, 'recall_at_10': 0.99164, 'recall_at_50': 0.99952, 'l2_dist': 300.37689208984375} Logit scale: 73.79359436035156
üíæ Saved new best model (mrr=0.890053)




üìò Epoch 006 | Train Loss: 2.246962 | Val Loss: 1.974991 | LR: 1.00e-02
{'mrr': np.float64(0.897581012426187), 'ndcg': np.float64(0.9228032095624058), 'recall_at_1': 0.8324, 'recall_at_3': 0.95812, 'recall_at_5': 0.97916, 'recall_at_10': 0.9928, 'recall_at_50': 0.99932, 'l2_dist': 308.2652893066406} Logit scale: 83.82197570800781
üíæ Saved new best model (mrr=0.897581)




üìò Epoch 007 | Train Loss: 2.107811 | Val Loss: 1.910707 | LR: 1.00e-02
{'mrr': np.float64(0.9032651542170819), 'ndcg': np.float64(0.9271036524105821), 'recall_at_1': 0.84104, 'recall_at_3': 0.96036, 'recall_at_5': 0.98072, 'recall_at_10': 0.99276, 'recall_at_50': 0.99948, 'l2_dist': 309.3061218261719} Logit scale: 92.93132781982422
üíæ Saved new best model (mrr=0.903265)




üìò Epoch 008 | Train Loss: 1.992101 | Val Loss: 1.850050 | LR: 1.00e-02
{'mrr': np.float64(0.9084615343303902), 'ndcg': np.float64(0.931014161246044), 'recall_at_1': 0.84944, 'recall_at_3': 0.9632, 'recall_at_5': 0.9816, 'recall_at_10': 0.993, 'recall_at_50': 0.99948, 'l2_dist': 311.7698974609375} Logit scale: 100.69995880126953
üíæ Saved new best model (mrr=0.908462)




üìò Epoch 009 | Train Loss: 1.889147 | Val Loss: 1.811159 | LR: 1.00e-02
{'mrr': np.float64(0.9111145997723354), 'ndcg': np.float64(0.9330271172842343), 'recall_at_1': 0.8536, 'recall_at_3': 0.96452, 'recall_at_5': 0.98224, 'recall_at_10': 0.99424, 'recall_at_50': 0.99952, 'l2_dist': 302.7626647949219} Logit scale: 101.06170654296875
üíæ Saved new best model (mrr=0.911115)




üìò Epoch 010 | Train Loss: 1.808816 | Val Loss: 1.774472 | LR: 1.00e-02
{'mrr': np.float64(0.9154171242385657), 'ndcg': np.float64(0.9362706310871305), 'recall_at_1': 0.85996, 'recall_at_3': 0.96648, 'recall_at_5': 0.98328, 'recall_at_10': 0.99336, 'recall_at_50': 0.9994, 'l2_dist': 296.2349853515625} Logit scale: 101.0403060913086
üíæ Saved new best model (mrr=0.915417)




üìò Epoch 011 | Train Loss: 1.737900 | Val Loss: 1.754196 | LR: 1.00e-02
{'mrr': np.float64(0.9157091669814397), 'ndcg': np.float64(0.936481907722531), 'recall_at_1': 0.86104, 'recall_at_3': 0.96708, 'recall_at_5': 0.98296, 'recall_at_10': 0.99376, 'recall_at_50': 0.9994, 'l2_dist': 285.4093322753906} Logit scale: 101.0167007446289
üíæ Saved new best model (mrr=0.915709)




üìò Epoch 012 | Train Loss: 1.676498 | Val Loss: 1.727795 | LR: 1.00e-02
{'mrr': np.float64(0.9180617768564787), 'ndcg': np.float64(0.9382519496798783), 'recall_at_1': 0.8646, 'recall_at_3': 0.96832, 'recall_at_5': 0.98296, 'recall_at_10': 0.99356, 'recall_at_50': 0.99952, 'l2_dist': 276.0736999511719} Logit scale: 100.99310302734375
üíæ Saved new best model (mrr=0.918062)




üìò Epoch 013 | Train Loss: 1.618933 | Val Loss: 1.706196 | LR: 1.00e-02
{'mrr': np.float64(0.9200216354640882), 'ndcg': np.float64(0.9397169031957647), 'recall_at_1': 0.86812, 'recall_at_3': 0.96864, 'recall_at_5': 0.98388, 'recall_at_10': 0.99328, 'recall_at_50': 0.9994, 'l2_dist': 273.0699768066406} Logit scale: 100.9695053100586
üíæ Saved new best model (mrr=0.920022)




üìò Epoch 014 | Train Loss: 1.573649 | Val Loss: 1.691857 | LR: 1.00e-02
{'mrr': np.float64(0.920184612270918), 'ndcg': np.float64(0.9398266743692593), 'recall_at_1': 0.86812, 'recall_at_3': 0.968, 'recall_at_5': 0.98312, 'recall_at_10': 0.9934, 'recall_at_50': 0.99956, 'l2_dist': 266.114990234375} Logit scale: 100.9459228515625
üíæ Saved new best model (mrr=0.920185)




üìò Epoch 015 | Train Loss: 1.526073 | Val Loss: 1.677204 | LR: 1.00e-02
{'mrr': np.float64(0.9217017193660212), 'ndcg': np.float64(0.9409671858890585), 'recall_at_1': 0.87092, 'recall_at_3': 0.96904, 'recall_at_5': 0.98312, 'recall_at_10': 0.99336, 'recall_at_50': 0.99952, 'l2_dist': 259.4028625488281} Logit scale: 100.92233276367188
üíæ Saved new best model (mrr=0.921702)




üìò Epoch 016 | Train Loss: 1.489168 | Val Loss: 1.667654 | LR: 1.00e-02
{'mrr': np.float64(0.9232213677659135), 'ndcg': np.float64(0.9421126226247271), 'recall_at_1': 0.87312, 'recall_at_3': 0.97032, 'recall_at_5': 0.98368, 'recall_at_10': 0.9934, 'recall_at_50': 0.99952, 'l2_dist': 256.88397216796875} Logit scale: 100.89875793457031
üíæ Saved new best model (mrr=0.923221)




üìò Epoch 017 | Train Loss: 1.451431 | Val Loss: 1.655464 | LR: 1.00e-02
{'mrr': np.float64(0.9231691622123692), 'ndcg': np.float64(0.9420509924874326), 'recall_at_1': 0.87352, 'recall_at_3': 0.96912, 'recall_at_5': 0.98352, 'recall_at_10': 0.99332, 'recall_at_50': 0.9994, 'l2_dist': 250.40480041503906} Logit scale: 100.87519073486328




üìò Epoch 018 | Train Loss: 1.341547 | Val Loss: 1.622047 | LR: 5.00e-03
{'mrr': np.float64(0.9246851534687374), 'ndcg': np.float64(0.9431888402702717), 'recall_at_1': 0.87628, 'recall_at_3': 0.97032, 'recall_at_5': 0.98364, 'recall_at_10': 0.99332, 'recall_at_50': 0.9994, 'l2_dist': 240.87246704101562} Logit scale: 100.8634033203125
üíæ Saved new best model (mrr=0.924685)




üìò Epoch 019 | Train Loss: 1.274832 | Val Loss: 1.610799 | LR: 5.00e-03
{'mrr': np.float64(0.9252885841736561), 'ndcg': np.float64(0.9436515272442634), 'recall_at_1': 0.877, 'recall_at_3': 0.97028, 'recall_at_5': 0.98392, 'recall_at_10': 0.99328, 'recall_at_50': 0.99944, 'l2_dist': 233.29104614257812} Logit scale: 100.85161590576172
üíæ Saved new best model (mrr=0.925289)




üìò Epoch 020 | Train Loss: 1.237424 | Val Loss: 1.601771 | LR: 5.00e-03
{'mrr': np.float64(0.9254553415946366), 'ndcg': np.float64(0.9437723496481311), 'recall_at_1': 0.87728, 'recall_at_3': 0.97048, 'recall_at_5': 0.98416, 'recall_at_10': 0.99316, 'recall_at_50': 0.99952, 'l2_dist': 226.0605926513672} Logit scale: 100.83983612060547
üíæ Saved new best model (mrr=0.925455)




üìò Epoch 021 | Train Loss: 1.211116 | Val Loss: 1.596574 | LR: 5.00e-03
{'mrr': np.float64(0.9265547539256684), 'ndcg': np.float64(0.9445954489543181), 'recall_at_1': 0.87912, 'recall_at_3': 0.9706, 'recall_at_5': 0.98408, 'recall_at_10': 0.99328, 'recall_at_50': 0.99952, 'l2_dist': 221.04371643066406} Logit scale: 100.82805633544922
üíæ Saved new best model (mrr=0.926555)




üìò Epoch 022 | Train Loss: 1.184829 | Val Loss: 1.595622 | LR: 5.00e-03
{'mrr': np.float64(0.9264834888277234), 'ndcg': np.float64(0.9445377598230831), 'recall_at_1': 0.87904, 'recall_at_3': 0.97076, 'recall_at_5': 0.9842, 'recall_at_10': 0.99304, 'recall_at_50': 0.99936, 'l2_dist': 216.5664825439453} Logit scale: 100.81627655029297




üìò Epoch 023 | Train Loss: 1.137863 | Val Loss: 1.579742 | LR: 2.50e-03
{'mrr': np.float64(0.9273002866107932), 'ndcg': np.float64(0.9451620504290176), 'recall_at_1': 0.88012, 'recall_at_3': 0.97096, 'recall_at_5': 0.98452, 'recall_at_10': 0.99308, 'recall_at_50': 0.99932, 'l2_dist': 211.97291564941406} Logit scale: 100.81156921386719
üíæ Saved new best model (mrr=0.927300)




üìò Epoch 024 | Train Loss: 1.108592 | Val Loss: 1.573265 | LR: 2.50e-03
{'mrr': np.float64(0.9282403796542763), 'ndcg': np.float64(0.9458558829241452), 'recall_at_1': 0.88188, 'recall_at_3': 0.97116, 'recall_at_5': 0.98416, 'recall_at_10': 0.99316, 'recall_at_50': 0.9994, 'l2_dist': 209.7266845703125} Logit scale: 100.80685424804688
üíæ Saved new best model (mrr=0.928240)




üìò Epoch 025 | Train Loss: 1.092943 | Val Loss: 1.572593 | LR: 2.50e-03
{'mrr': np.float64(0.9280009171716843), 'ndcg': np.float64(0.9456646536940462), 'recall_at_1': 0.88164, 'recall_at_3': 0.97128, 'recall_at_5': 0.98384, 'recall_at_10': 0.99316, 'recall_at_50': 0.99936, 'l2_dist': 206.2266387939453} Logit scale: 100.8021469116211




üìò Epoch 026 | Train Loss: 1.076156 | Val Loss: 1.566753 | LR: 2.50e-03
{'mrr': np.float64(0.9288723113896328), 'ndcg': np.float64(0.9463196352355275), 'recall_at_1': 0.88296, 'recall_at_3': 0.97144, 'recall_at_5': 0.98416, 'recall_at_10': 0.99272, 'recall_at_50': 0.99936, 'l2_dist': 203.7335968017578} Logit scale: 100.79743194580078
üíæ Saved new best model (mrr=0.928872)




üìò Epoch 027 | Train Loss: 1.049442 | Val Loss: 1.563583 | LR: 1.25e-03
{'mrr': np.float64(0.9290771560158183), 'ndcg': np.float64(0.9464602342346443), 'recall_at_1': 0.88364, 'recall_at_3': 0.97104, 'recall_at_5': 0.9838, 'recall_at_10': 0.99316, 'recall_at_50': 0.99932, 'l2_dist': 202.90989685058594} Logit scale: 100.79508209228516
üíæ Saved new best model (mrr=0.929077)




üìò Epoch 028 | Train Loss: 1.035203 | Val Loss: 1.560664 | LR: 1.25e-03
{'mrr': np.float64(0.9291593542002589), 'ndcg': np.float64(0.9465352459436297), 'recall_at_1': 0.8834, 'recall_at_3': 0.97124, 'recall_at_5': 0.98412, 'recall_at_10': 0.9932, 'recall_at_50': 0.99936, 'l2_dist': 201.26904296875} Logit scale: 100.792724609375
üíæ Saved new best model (mrr=0.929159)




üìò Epoch 029 | Train Loss: 1.031067 | Val Loss: 1.557243 | LR: 1.25e-03
{'mrr': np.float64(0.9296569620717768), 'ndcg': np.float64(0.9468965950458721), 'recall_at_1': 0.88456, 'recall_at_3': 0.97104, 'recall_at_5': 0.9838, 'recall_at_10': 0.99304, 'recall_at_50': 0.9994, 'l2_dist': 200.1531219482422} Logit scale: 100.79037475585938
üíæ Saved new best model (mrr=0.929657)




üìò Epoch 030 | Train Loss: 1.020053 | Val Loss: 1.555266 | LR: 1.25e-03
{'mrr': np.float64(0.929291663006627), 'ndcg': np.float64(0.9466254861452352), 'recall_at_1': 0.8838, 'recall_at_3': 0.97088, 'recall_at_5': 0.98428, 'recall_at_10': 0.99312, 'recall_at_50': 0.9994, 'l2_dist': 197.5496063232422} Logit scale: 100.78801727294922




üìò Epoch 031 | Train Loss: 1.010725 | Val Loss: 1.553608 | LR: 1.25e-03
{'mrr': np.float64(0.9294624616483831), 'ndcg': np.float64(0.9467463089761186), 'recall_at_1': 0.88424, 'recall_at_3': 0.97136, 'recall_at_5': 0.98396, 'recall_at_10': 0.99292, 'recall_at_50': 0.9994, 'l2_dist': 196.97219848632812} Logit scale: 100.7856674194336




üìò Epoch 032 | Train Loss: 1.001863 | Val Loss: 1.553002 | LR: 1.25e-03
{'mrr': np.float64(0.9293408477714518), 'ndcg': np.float64(0.9466533207974729), 'recall_at_1': 0.88416, 'recall_at_3': 0.97076, 'recall_at_5': 0.98388, 'recall_at_10': 0.99316, 'recall_at_50': 0.9994, 'l2_dist': 194.99513244628906} Logit scale: 100.78330993652344




üìò Epoch 033 | Train Loss: 0.996867 | Val Loss: 1.552147 | LR: 1.25e-03
{'mrr': np.float64(0.9295561098138111), 'ndcg': np.float64(0.9468081309371825), 'recall_at_1': 0.88464, 'recall_at_3': 0.97096, 'recall_at_5': 0.98404, 'recall_at_10': 0.99308, 'recall_at_50': 0.99932, 'l2_dist': 193.18850708007812} Logit scale: 100.78095245361328




üìò Epoch 034 | Train Loss: 0.982834 | Val Loss: 1.548294 | LR: 6.25e-04
{'mrr': np.float64(0.929701539703521), 'ndcg': np.float64(0.9469235348543653), 'recall_at_1': 0.88472, 'recall_at_3': 0.97128, 'recall_at_5': 0.98408, 'recall_at_10': 0.99304, 'recall_at_50': 0.99936, 'l2_dist': 192.81277465820312} Logit scale: 100.77859497070312
üíæ Saved new best model (mrr=0.929702)




üìò Epoch 035 | Train Loss: 0.974833 | Val Loss: 1.547715 | LR: 6.25e-04
{'mrr': np.float64(0.9296114281649918), 'ndcg': np.float64(0.9468575401490896), 'recall_at_1': 0.88456, 'recall_at_3': 0.97096, 'recall_at_5': 0.98416, 'recall_at_10': 0.99316, 'recall_at_50': 0.9994, 'l2_dist': 192.4840545654297} Logit scale: 100.77623748779297




üìò Epoch 036 | Train Loss: 0.973461 | Val Loss: 1.547908 | LR: 6.25e-04
{'mrr': np.float64(0.9289045654493877), 'ndcg': np.float64(0.9463311071950766), 'recall_at_1': 0.88324, 'recall_at_3': 0.9712, 'recall_at_5': 0.98424, 'recall_at_10': 0.99328, 'recall_at_50': 0.99936, 'l2_dist': 191.2478790283203} Logit scale: 100.77388763427734




üìò Epoch 037 | Train Loss: 0.967326 | Val Loss: 1.547713 | LR: 6.25e-04
{'mrr': np.float64(0.9290907158968797), 'ndcg': np.float64(0.9464606794291943), 'recall_at_1': 0.88372, 'recall_at_3': 0.97136, 'recall_at_5': 0.98384, 'recall_at_10': 0.99296, 'recall_at_50': 0.9994, 'l2_dist': 191.04635620117188} Logit scale: 100.77153778076172




üìò Epoch 038 | Train Loss: 0.958739 | Val Loss: 1.545992 | LR: 3.13e-04
{'mrr': np.float64(0.9291883104469748), 'ndcg': np.float64(0.9465377966958666), 'recall_at_1': 0.88376, 'recall_at_3': 0.97136, 'recall_at_5': 0.98376, 'recall_at_10': 0.993, 'recall_at_50': 0.9994, 'l2_dist': 190.48269653320312} Logit scale: 100.76918029785156




üìò Epoch 039 | Train Loss: 0.954723 | Val Loss: 1.545706 | LR: 3.13e-04
{'mrr': np.float64(0.9292880102777151), 'ndcg': np.float64(0.9466120576423411), 'recall_at_1': 0.884, 'recall_at_3': 0.97124, 'recall_at_5': 0.9838, 'recall_at_10': 0.99308, 'recall_at_50': 0.99936, 'l2_dist': 190.5919189453125} Logit scale: 100.7668228149414




üìò Epoch 040 | Train Loss: 0.956957 | Val Loss: 1.544283 | LR: 3.13e-04
{'mrr': np.float64(0.9296174182082967), 'ndcg': np.float64(0.9468581847871166), 'recall_at_1': 0.88456, 'recall_at_3': 0.97124, 'recall_at_5': 0.98392, 'recall_at_10': 0.99308, 'recall_at_50': 0.9994, 'l2_dist': 190.02696228027344} Logit scale: 100.76447296142578




üìò Epoch 041 | Train Loss: 0.950188 | Val Loss: 1.544157 | LR: 3.13e-04
{'mrr': np.float64(0.92990678480899), 'ndcg': np.float64(0.9470688892323228), 'recall_at_1': 0.8852, 'recall_at_3': 0.9712, 'recall_at_5': 0.98392, 'recall_at_10': 0.99304, 'recall_at_50': 0.99936, 'l2_dist': 189.5748748779297} Logit scale: 100.76211547851562
üíæ Saved new best model (mrr=0.929907)




üìò Epoch 042 | Train Loss: 0.947687 | Val Loss: 1.543070 | LR: 1.56e-04
{'mrr': np.float64(0.9295358064570918), 'ndcg': np.float64(0.9467957780708097), 'recall_at_1': 0.88448, 'recall_at_3': 0.97104, 'recall_at_5': 0.98412, 'recall_at_10': 0.99312, 'recall_at_50': 0.99936, 'l2_dist': 189.65907287597656} Logit scale: 100.76211547851562




üìò Epoch 043 | Train Loss: 0.945150 | Val Loss: 1.543259 | LR: 1.56e-04
{'mrr': np.float64(0.9297849687200863), 'ndcg': np.float64(0.9469785170244117), 'recall_at_1': 0.88496, 'recall_at_3': 0.97116, 'recall_at_5': 0.98404, 'recall_at_10': 0.993, 'recall_at_50': 0.99936, 'l2_dist': 189.41714477539062} Logit scale: 100.76211547851562




üìò Epoch 044 | Train Loss: 0.945692 | Val Loss: 1.542981 | LR: 1.56e-04
{'mrr': np.float64(0.9299668128234295), 'ndcg': np.float64(0.9471116568627964), 'recall_at_1': 0.8854, 'recall_at_3': 0.97124, 'recall_at_5': 0.98416, 'recall_at_10': 0.993, 'recall_at_50': 0.9994, 'l2_dist': 189.603515625} Logit scale: 100.76211547851562
üíæ Saved new best model (mrr=0.929967)




üìò Epoch 045 | Train Loss: 0.943841 | Val Loss: 1.543500 | LR: 1.56e-04
{'mrr': np.float64(0.9294940605506691), 'ndcg': np.float64(0.9467586443798195), 'recall_at_1': 0.88452, 'recall_at_3': 0.97108, 'recall_at_5': 0.98412, 'recall_at_10': 0.99308, 'recall_at_50': 0.9994, 'l2_dist': 189.244873046875} Logit scale: 100.76211547851562




üìò Epoch 046 | Train Loss: 0.945473 | Val Loss: 1.543117 | LR: 7.81e-05
{'mrr': np.float64(0.929883377366751), 'ndcg': np.float64(0.9470472764876444), 'recall_at_1': 0.88528, 'recall_at_3': 0.97108, 'recall_at_5': 0.98404, 'recall_at_10': 0.99308, 'recall_at_50': 0.9994, 'l2_dist': 189.17323303222656} Logit scale: 100.76211547851562




üìò Epoch 047 | Train Loss: 0.941299 | Val Loss: 1.542837 | LR: 7.81e-05
{'mrr': np.float64(0.9298137324093603), 'ndcg': np.float64(0.9469957192198211), 'recall_at_1': 0.88516, 'recall_at_3': 0.97108, 'recall_at_5': 0.98404, 'recall_at_10': 0.993, 'recall_at_50': 0.9994, 'l2_dist': 188.52926635742188} Logit scale: 100.76211547851562




üìò Epoch 048 | Train Loss: 0.941972 | Val Loss: 1.542526 | LR: 7.81e-05
{'mrr': np.float64(0.9297132672773586), 'ndcg': np.float64(0.9469215497453004), 'recall_at_1': 0.88492, 'recall_at_3': 0.971, 'recall_at_5': 0.98408, 'recall_at_10': 0.993, 'recall_at_50': 0.99936, 'l2_dist': 188.7519989013672} Logit scale: 100.76211547851562




üìò Epoch 049 | Train Loss: 0.937992 | Val Loss: 1.542267 | LR: 7.81e-05
{'mrr': np.float64(0.9297524359036611), 'ndcg': np.float64(0.9469500398656241), 'recall_at_1': 0.88504, 'recall_at_3': 0.97112, 'recall_at_5': 0.98408, 'recall_at_10': 0.99288, 'recall_at_50': 0.9994, 'l2_dist': 188.34521484375} Logit scale: 100.76211547851562




üìò Epoch 050 | Train Loss: 0.939952 | Val Loss: 1.542368 | LR: 3.91e-05
{'mrr': np.float64(0.929758754678339), 'ndcg': np.float64(0.9469551543782836), 'recall_at_1': 0.88504, 'recall_at_3': 0.97104, 'recall_at_5': 0.98408, 'recall_at_10': 0.99296, 'recall_at_50': 0.9994, 'l2_dist': 188.19631958007812} Logit scale: 100.76211547851562




üìò Epoch 051 | Train Loss: 0.940660 | Val Loss: 1.542644 | LR: 3.91e-05
{'mrr': np.float64(0.9297640232634562), 'ndcg': np.float64(0.9469627405692421), 'recall_at_1': 0.88496, 'recall_at_3': 0.97132, 'recall_at_5': 0.98412, 'recall_at_10': 0.99296, 'recall_at_50': 0.9994, 'l2_dist': 188.56350708007812} Logit scale: 100.76211547851562




üìò Epoch 052 | Train Loss: 0.936226 | Val Loss: 1.542225 | LR: 3.91e-05
{'mrr': np.float64(0.9298726110844082), 'ndcg': np.float64(0.9470423322592552), 'recall_at_1': 0.88512, 'recall_at_3': 0.97108, 'recall_at_5': 0.98404, 'recall_at_10': 0.99296, 'recall_at_50': 0.9994, 'l2_dist': 188.78956604003906} Logit scale: 100.76211547851562
‚èπ Early stopping triggered.
‚úÖ Training complete. Best mrr: 0.929873
Finished training. Now testing using best model...
Test Results: {'mrr': np.float64(0.9299668128234295), 'ndcg': np.float64(0.9471116568627964), 'recall_at_1': 0.8854, 'recall_at_3': 0.97124, 'recall_at_5': 0.98416, 'recall_at_10': 0.993, 'recall_at_50': 0.9994, 'l2_dist': 189.603515625}


In [37]:
generate_submission(model, Path(test_path), output_file="dav.csv", device=device)

Generating submission file...
‚úì Saved submission to dav.csv


Unnamed: 0,id,embedding
0,1,"[0.7686144709587097, 2.381528854370117, 1.0353..."
1,2,"[-4.1902289390563965, 0.8161473870277405, -0.6..."
2,3,"[-1.2787296772003174, -1.561556100845337, -2.2..."
3,4,"[11.185599327087402, -4.036659240722656, -3.34..."
4,5,"[6.494488716125488, 8.447858810424805, 1.51297..."
...,...,...
1495,1496,"[-0.43608617782592773, -5.116342067718506, 5.4..."
1496,1497,"[-0.33950138092041016, 2.236297130584717, 9.56..."
1497,1498,"[2.6319985389709473, -4.231359004974365, -1.67..."
1498,1499,"[-3.7497448921203613, -3.935755729675293, -2.3..."


In [38]:
torch.save(model, "full_model.pth")