In [1]:
from typing import Optional
from torch import nn
from torch.nn import functional as F
import torch

class _SpaceTranslator(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_layers,
        activation,
        dropout_rate
    ):
        super().__init__()

        layers = []
        last = input_dim

        for hidden in hidden_layers:
            layers += [
                nn.Linear(last, hidden),
                # nn.LayerNorm(hidden),
                nn.BatchNorm1d(hidden),
                activation(),
                nn.Dropout(dropout_rate)
            ]
            last = hidden
        self.body = nn.Sequential(*layers)

        self.final = nn.Linear(last, output_dim)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))

        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight, a=0, mode='fan_in', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm) or isinstance(module, nn.BatchNorm1d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        x = self.body(x)
        return self.final(x)



class SpaceTranslator(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_layers,
        activation,
        dropout_rate
    ):
        super().__init__()

        layers = []
        last = input_dim

        for hidden in hidden_layers:
            layers += [
                nn.Linear(last, hidden),
                nn.LayerNorm(hidden),
                activation(),
                nn.Dropout(dropout_rate)
            ]
            last = hidden

        layers.append(nn.Linear(last, output_dim))
        self.net = nn.Sequential(*layers)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))

        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)


    def forward(self, x):
      return F.normalize(self.net(x), p=2, dim=1)



In [9]:
from pathlib import Path
import numpy as np
import torch
import pandas as pd

'''Code from https://github.com/Mamiglia/challenge'''

def mrr(pred_indices: np.ndarray, gt_indices: np.ndarray) -> float:
    """
    Compute Mean Reciprocal Rank (MRR)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries (top-K)
        gt_indices: (N,) array of ground truth indices
    Returns:
        mrr: Mean Reciprocal Rank
    """
    reciprocal_ranks = []
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i] == gt_indices[i])[0]
        if matches.size > 0:
            reciprocal_ranks.append(1.0 / (matches[0] + 1))
        else:
            reciprocal_ranks.append(0.0)
    return np.mean(reciprocal_ranks)


def recall_at_k(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int) -> float:
    """Compute Recall@k
    Args:
        pred_indices: (N, N) array of top indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        recall: Recall@k
    """
    recall = 0
    for i in range(len(gt_indices)):
        if gt_indices[i] in pred_indices[i, :k]:
            recall += 1
    recall /= len(gt_indices)
    return recall

import numpy as np

def ndcg(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int = 100) -> float:
    """
    Compute Normalized Discounted Cumulative Gain (NDCG@k)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        ndcg: NDCG@k
    """
    ndcg_total = 0.0
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i, :k] == gt_indices[i])[0]
        if matches.size > 0:
            rank = matches[0] + 1
            ndcg_total += 1.0 / np.log2(rank + 1)  # DCG (IDCG = 1)
    return ndcg_total / len(gt_indices)



@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """Evaluate retrieval performance using cosine similarity
    Args:
        translated_embd: (N_captions, D) translated caption embeddings
        image_embd: (N_images, D) image embeddings
        gt_indices: (N_captions,) ground truth image indices for each caption
        max_indices: number of top predictions to consider
    Returns:
        results: dict of evaluation metrics
    
    """
    # Compute similarity matrix
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()
    
    n_queries = translated_embd.shape[0]
    device = translated_embd.device
    
    # Prepare containers for the fragments to be reassembled
    all_sorted_indices = []
    l2_distances = []
    
    # Process in batches - the narrow gate approach
    for start_idx in range(0, n_queries, batch_size):
        batch_slice = slice(start_idx, min(start_idx + batch_size, n_queries))
        batch_translated = translated_embd[batch_slice]
        batch_img_embd = image_embd[batch_slice]
        
        # Compute similarity only for this batch
        batch_similarity = batch_translated @ batch_img_embd.T

        # Get top-k predictions for this batch
        batch_indices = batch_similarity.topk(k=max_indices, dim=1, sorted=True).indices.numpy()
        all_sorted_indices.append(gt_indices[batch_slice][batch_indices])

        # Compute L2 distance for this batch
        batch_gt = gt_indices[batch_slice]
        batch_gt_embeddings = image_embd[batch_gt]
        batch_l2 = (batch_translated - batch_gt_embeddings).norm(dim=1)
        l2_distances.append(batch_l2)
    
    # Reassemble the fragments
    sorted_indices = np.concatenate(all_sorted_indices, axis=0)
    
    # Apply the sacred metrics to the whole
    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }
    
    results = {
        name: func(sorted_indices, gt_indices)
        for name, func in metrics.items()
    }
    
    l2_dist = torch.cat(l2_distances, dim=0).mean().item()
    results['l2_dist'] = l2_dist
    
    return results



def generate_submission(models: nn.Module, test_path: Path, ensemble, output_file="submission-dirmodel.csv", device=None):
    test_data = np.load(test_path)
    sample_ids = test_data['captions/ids']
    test_embds = test_data['captions/embeddings']
    test_embds = torch.from_numpy(test_embds).float()

    '''with torch.no_grad():
        pred_embds = model(test_embds.to(device)).cpu()'''
    with torch.no_grad():
        if ensemble:
            preds = []
            for model in models:
                model.eval()
                preds.append(model(test_embds.to(device)).to('cpu'))
                
            pred_embds = torch.stack(preds).mean(dim=0)
        else:
            pred_embds = model(test_embds.to(device)).cpu()
    print("Generating submission file...")

    if isinstance(pred_embds, torch.Tensor):
        pred_embds = pred_embds.cpu().numpy()

    df_submission = pd.DataFrame({'id': sample_ids, 'embedding': pred_embds.tolist()})

    df_submission.to_csv(output_file, index=False, float_format='%.17g')
    print(f"✓ Saved submission to {output_file}")

    return df_submission

In [3]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
from pathlib import Path
from tqdm import tqdm

def compute_targets_metrics(train_dataset, device):
    sum_vec = None
    sum_sq = None
    n = len(train_dataset)
    dummy_loader = DataLoader(train_dataset, len(train_dataset))
    for x,y in dummy_loader:
        b = y.shape[0]

    sum_vec = y.sum(dim=0).double().to(device)
    sum_sq = (y * y).sum(dim=0).double().to(device)
    
    target_mean = (sum_vec / n).float().to(device)
    var = (sum_sq / n - target_mean.double()**2).float().clamp(min=0.0)
    target_std = torch.sqrt(var + 1e-12).to(device)
    return target_mean, target_std

    
def moment_loss(preds, target_mean, target_std, mode='per_dim', coef=1e-2):
    if mode == 'per_dim':
        pm = preds.mean(dim=0)
        ps = preds.std(dim=0)
        loss_mean = F.mse_loss(pm, target_mean)
        loss_std  = F.mse_loss(ps,  target_std)
        return coef * (loss_mean + loss_std)
    elif mode == 'scalar':
        pm = preds.mean()
        ps = preds.std()
        tm = target_mean.mean()
        ts = target_std.mean()
        return coef * (F.mse_loss(pm, tm) + F.mse_loss(ps, ts))
    
'''
def info_nce_loss(dir_preds, img_targets, logit_scale):
    dir_preds = F.normalize(dir_preds, dim=-1)
    img_targets = F.normalize(img_targets, dim=-1)
    
    logit_scale = torch.clamp(logit_scale, max=np.log(100))
    logits = (dir_preds @ img_targets.T) * logit_scale.exp()

    labels = torch.arange(logits.size(0), device=logits.device)

    loss_t2i = F.cross_entropy(logits, labels)
    loss_i2t = F.cross_entropy(logits.T, labels)

    return 0.5 * (loss_t2i + loss_i2t)
'''

def info_nce_loss(
    dir_preds,
    img_targets,
    logit_scale: torch.Tensor,
    margin: float = 0.3,
    alpha: float = 0.7
):
    """
    InfoNCE simmetrico + hard-negative Margin Ranking Loss su entrambe le direzioni.
    """
    dir_preds = F.normalize(dir_preds, dim=-1)
    img_targets = F.normalize(img_targets, dim=-1)

    # Clamp logit scale per stabilità
    logit_scale = torch.clamp(logit_scale, min=np.log(0.01), max=np.log(100))

    # --- InfoNCE simmetrico ---
    logits = dir_preds @ img_targets.T * logit_scale.exp()
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_t2i = F.cross_entropy(logits, labels)
    loss_i2t = F.cross_entropy(logits.T, labels)
    loss_nce = 0.5 * (loss_t2i + loss_i2t)

    # --- Hard negative Margin Ranking Loss per testo → immagine ---
    mask = torch.eye(logits.size(0), device=logits.device)
    logits_no_pos = logits - mask * 1e9
    hardest_neg_t2i = logits_no_pos.max(dim=1).values
    positive_sim_t2i = torch.diag(logits)
    loss_hard_t2i = F.relu(hardest_neg_t2i - positive_sim_t2i + margin).mean()

    # --- Hard negative Margin Ranking Loss per immagine → testo ---
    logits_no_pos_i2t = logits.T - mask * 1e9
    hardest_neg_i2t = logits_no_pos_i2t.max(dim=1).values
    positive_sim_i2t = torch.diag(logits.T)
    loss_hard_i2t = F.relu(hardest_neg_i2t - positive_sim_i2t + margin).mean()

    # Loss finale combinata
    loss_hard_total = 0.5 * (loss_hard_t2i + loss_hard_i2t)
    loss = loss_nce + alpha * loss_hard_total

    return loss




def train_model_direction(
    model: SpaceTranslator,
    model_path: Path,
    train_dataset: TensorDataset,
    val_dataset: TensorDataset,
    batch_size: int,
    epochs: int,
    lr: float,
    patience: int,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    print(f"Using device: {device}")

    # Dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Optimizer and scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3, threshold=0.001, min_lr=1e-9
    )
    
    ema_state = {k: v.clone().detach().cpu() for k, v in model.state_dict().items()}
    ema_decay = 0.9995
    
    best_mrr = float('-inf')
    best_val_loss = float('inf')
    no_improvements = 0

    for epoch in range(1, epochs + 1):
        model.train()
        running_train_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)
        for X_batch, y_batch in progress_bar:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            # X_batch = F.normalize(X_batch, dim=-1)
            # y_batch = F.normalize(y_batch, dim=-1)

            optimizer.zero_grad()
            outputs = model(X_batch)
            
            loss = info_nce_loss(outputs, y_batch, model.logit_scale) #+ moment_loss(outputs, target_mean, target_std)
            loss.backward()
            optimizer.step()
            ema_update(model, ema_state, ema_decay)

            running_train_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        train_loss = running_train_loss / len(train_loader)

        # Validation phase
        model.eval()
        running_val_loss = 0.0

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                # X_batch = F.normalize(X_batch, dim=-1)
                # y_batch = F.normalize(y_batch, dim=-1)

                outputs = model(X_batch)
                loss = info_nce_loss(outputs, y_batch, model.logit_scale) #+ moment_loss(outputs, target_mean, target_std)

                running_val_loss += loss.item()

        val_loss = running_val_loss / len(val_loader)

        

        print(f"Epoch {epoch:03d} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | LR: {optimizer.param_groups[0]['lr']:.2e}")

        # Optional: external validation/test
        results = test(val_dataset, model, device)
        print(results)
        mrr = results['mrr']

        # Step the scheduler
        scheduler.step(mrr)
        
        # Save best model and early stopping
        if mrr > best_mrr:
            best_mrr = mrr
            no_improvements = 0

            Path(model_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), model_path)
            print(f"Saved new best model (mrr={mrr:.6f})")
        else:
            no_improvements += 1
            if no_improvements >= patience:
                print("Early stopping triggered.")
                break

    print(f"Training complete. Best mrr: {mrr:.6f}")

    return model, ema_state

def ema_update(model, ema_state, decay):
    with torch.no_grad():
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                ema_state[k].mul_(decay).add_(v.detach().cpu(), alpha=1 - decay)


def get_data(data_path: Path):
    data = np.load(data_path)
    caption_embeddings = data['captions/embeddings']
    image_embeddings = data['images/embeddings']
    caption_labels = data['captions/label']
    data.close()

    X_abs, y_abs = torch.tensor(caption_embeddings), torch.tensor(image_embeddings[np.argmax(caption_labels, axis=1)])

    return X_abs, y_abs

    
def get_datasets(X_abs, y_abs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:    
    print('Texts shape', X_abs.shape)
    print('Images shape', y_abs.shape)
    
    dataset = TensorDataset(X_abs, y_abs)
    train_dataset, val_dataset = random_split(dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42))
    
    return train_dataset, val_dataset

def test(val_dataset: TensorDataset, model, device, scale=None):
    val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
    for x_val, y_val in val_loader:
        results = eval_on_val(x_val, y_val, model=model, device=device, scale=scale)
    return results
    
def eval_on_val(x_val: np.ndarray, y_val: np.ndarray, model, device, scale=None) -> dict:
    gt_indices = torch.arange(len(y_val))
    
    model.eval()

    with torch.no_grad():
        translated = model(x_val.to(device)).to('cpu')

    results = evaluate_retrieval(translated, y_val, gt_indices)
    
    return results


def test_ensemble(val_dataset: TensorDataset, models, device, scale=None):
    val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
    for x_val, y_val in val_loader:
        results = eval_ensemble(x_val, y_val, models=models, device=device, scale=scale)
    return results
    

def eval_ensemble(x_val: np.ndarray, y_val: np.ndarray, models, device, scale=None) -> dict:
    gt_indices = torch.arange(len(y_val))
    preds = []
    with torch.no_grad():
        for model in models:
            model.eval()
            preds.append(model(x_val.to(device)).to('cpu'))
            
    avg_pred = torch.stack(preds).mean(dim=0)

    results = evaluate_retrieval(avg_pred, y_val, gt_indices)
    
    return results


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_dim = 1024
output_dim = 1536
hidden_layers=[1256, 1536]
dropout_rate = 0.5

batch_size= 4096
lr=0.01
epochs= 250
patience = 10


data_path= '/kaggle/input/amlcomp/data/train/train.npz'
test_path= '/kaggle/input/amlcomp/data/test/test.clean.npz'

In [5]:
dir_x, dir_y = get_data(data_path)

target_norm_mean = dir_y.norm(dim=1).mean().item()

dir_train_dataset, dir_val_dataset = get_datasets(dir_x, dir_y)

print('Target norm mean', target_norm_mean)

Texts shape torch.Size([125000, 1024])
Images shape torch.Size([125000, 1536])
Target norm mean 25.93919563293457


In [7]:

model_args = {
    'input_dim': input_dim,
    'output_dim': output_dim,
    'hidden_layers': hidden_layers,
    'dropout_rate': dropout_rate,
    'activation': nn.GELU,
}

last_model = SpaceTranslator(**model_args).to(device)
last_model, ema_state = train_model_direction(last_model, './models/best_model.pth', dir_train_dataset, dir_val_dataset, batch_size, epochs, lr, patience)

print('Finished training. Now testing using best model...')

ema_model = SpaceTranslator(**model_args).to(device)
sd = {k: v.to(device) for k, v in ema_state.items()}
ema_model.load_state_dict(sd)

best_model = SpaceTranslator(**model_args).to(device)
state = torch.load('./models/best_model.pth')
best_model.load_state_dict(state)

models = [best_model, last_model, ema_model]
results = test_ensemble(dir_val_dataset, models, device)
print("Test Results:", results)


Using device: cuda


                                                                       

Epoch 001 | Train Loss: 8.380910 | Val Loss: 7.383249 | LR: 1.00e-02
{'mrr': 0.35580608111744055, 'ndcg': 0.49583899478748644, 'recall_at_1': 0.18932, 'recall_at_3': 0.41204, 'recall_at_5': 0.5488, 'recall_at_10': 0.7406, 'recall_at_50': 0.98944, 'l2_dist': 25.613479614257812}
Saved new best model (mrr=0.355806)


                                                                       

Epoch 002 | Train Loss: 7.389417 | Val Loss: 6.538291 | LR: 1.00e-02
{'mrr': 0.5540739395520453, 'ndcg': 0.657135108267455, 'recall_at_1': 0.37944, 'recall_at_3': 0.66484, 'recall_at_5': 0.78676, 'recall_at_10': 0.9048, 'recall_at_50': 0.9974, 'l2_dist': 25.63945770263672}
Saved new best model (mrr=0.554074)


                                                                       

Epoch 003 | Train Loss: 6.656401 | Val Loss: 5.798542 | LR: 1.00e-02
{'mrr': 0.6746046666404142, 'ndcg': 0.7517534664516107, 'recall_at_1': 0.5212, 'recall_at_3': 0.79124, 'recall_at_5': 0.87868, 'recall_at_10': 0.95372, 'recall_at_50': 0.9986, 'l2_dist': 25.6640567779541}
Saved new best model (mrr=0.674605)


                                                                       

Epoch 004 | Train Loss: 6.003433 | Val Loss: 5.101874 | LR: 1.00e-02
{'mrr': 0.7597227165049936, 'ndcg': 0.8175587733763083, 'recall_at_1': 0.63316, 'recall_at_3': 0.86472, 'recall_at_5': 0.92772, 'recall_at_10': 0.97404, 'recall_at_50': 0.99896, 'l2_dist': 25.67354965209961}
Saved new best model (mrr=0.759723)


                                                                       

Epoch 005 | Train Loss: 5.427717 | Val Loss: 4.548949 | LR: 1.00e-02
{'mrr': 0.8108942331647567, 'ndcg': 0.8568531328198261, 'recall_at_1': 0.70336, 'recall_at_3': 0.90488, 'recall_at_5': 0.95016, 'recall_at_10': 0.98324, 'recall_at_50': 0.99932, 'l2_dist': 25.693864822387695}
Saved new best model (mrr=0.810894)


                                                                       

Epoch 006 | Train Loss: 4.946915 | Val Loss: 4.133269 | LR: 1.00e-02
{'mrr': 0.8464540934941884, 'ndcg': 0.8839154049114717, 'recall_at_1': 0.75652, 'recall_at_3': 0.92568, 'recall_at_5': 0.96276, 'recall_at_10': 0.98748, 'recall_at_50': 0.99932, 'l2_dist': 25.71617317199707}
Saved new best model (mrr=0.846454)


                                                                       

Epoch 007 | Train Loss: 4.577795 | Val Loss: 3.851835 | LR: 1.00e-02
{'mrr': 0.8648983891531877, 'ndcg': 0.8979857975738033, 'recall_at_1': 0.78276, 'recall_at_3': 0.93856, 'recall_at_5': 0.96868, 'recall_at_10': 0.9896, 'recall_at_50': 0.99952, 'l2_dist': 25.73394203186035}
Saved new best model (mrr=0.864898)


                                                                       

Epoch 008 | Train Loss: 4.273684 | Val Loss: 3.642885 | LR: 1.00e-02
{'mrr': 0.8789125186376634, 'ndcg': 0.9086085171807898, 'recall_at_1': 0.80408, 'recall_at_3': 0.9474, 'recall_at_5': 0.97256, 'recall_at_10': 0.99048, 'recall_at_50': 0.99952, 'l2_dist': 25.753223419189453}
Saved new best model (mrr=0.878913)


                                                                       

Epoch 009 | Train Loss: 4.017218 | Val Loss: 3.463404 | LR: 1.00e-02
{'mrr': 0.8903844087722886, 'ndcg': 0.9172789375365563, 'recall_at_1': 0.82248, 'recall_at_3': 0.95156, 'recall_at_5': 0.97564, 'recall_at_10': 0.99148, 'recall_at_50': 0.99948, 'l2_dist': 25.76553726196289}
Saved new best model (mrr=0.890384)


                                                                        

Epoch 010 | Train Loss: 3.805119 | Val Loss: 3.343525 | LR: 1.00e-02
{'mrr': 0.8966019973236101, 'ndcg': 0.9220252132956973, 'recall_at_1': 0.83096, 'recall_at_3': 0.957, 'recall_at_5': 0.97808, 'recall_at_10': 0.99228, 'recall_at_50': 0.9994, 'l2_dist': 25.778079986572266}
Saved new best model (mrr=0.896602)


                                                                        

Epoch 011 | Train Loss: 3.615846 | Val Loss: 3.237405 | LR: 1.00e-02
{'mrr': 0.9005923167356508, 'ndcg': 0.9250760432655348, 'recall_at_1': 0.837, 'recall_at_3': 0.95924, 'recall_at_5': 0.97948, 'recall_at_10': 0.99312, 'recall_at_50': 0.9994, 'l2_dist': 25.787567138671875}
Saved new best model (mrr=0.900592)


                                                                        

Epoch 012 | Train Loss: 3.443288 | Val Loss: 3.142991 | LR: 1.00e-02
{'mrr': 0.9076585661014327, 'ndcg': 0.9304228724765103, 'recall_at_1': 0.84804, 'recall_at_3': 0.96232, 'recall_at_5': 0.982, 'recall_at_10': 0.99356, 'recall_at_50': 0.99952, 'l2_dist': 25.794769287109375}
Saved new best model (mrr=0.907659)


                                                                        

Epoch 013 | Train Loss: 3.301806 | Val Loss: 3.067406 | LR: 1.00e-02
{'mrr': 0.9123512656359155, 'ndcg': 0.933922459051361, 'recall_at_1': 0.85608, 'recall_at_3': 0.96456, 'recall_at_5': 0.9822, 'recall_at_10': 0.99328, 'recall_at_50': 0.9994, 'l2_dist': 25.801929473876953}
Saved new best model (mrr=0.912351)


                                                                        

Epoch 014 | Train Loss: 3.168260 | Val Loss: 3.019124 | LR: 1.00e-02
{'mrr': 0.9132785006389482, 'ndcg': 0.9346240239411833, 'recall_at_1': 0.85764, 'recall_at_3': 0.96468, 'recall_at_5': 0.98268, 'recall_at_10': 0.9936, 'recall_at_50': 0.9994, 'l2_dist': 25.805822372436523}
Saved new best model (mrr=0.913279)


                                                                        

Epoch 015 | Train Loss: 3.049540 | Val Loss: 2.962376 | LR: 1.00e-02
{'mrr': 0.9180012910369498, 'ndcg': 0.9381988702052382, 'recall_at_1': 0.8648, 'recall_at_3': 0.9672, 'recall_at_5': 0.98332, 'recall_at_10': 0.9938, 'recall_at_50': 0.99944, 'l2_dist': 25.80666160583496}
Saved new best model (mrr=0.918001)


                                                                        

Epoch 016 | Train Loss: 2.929943 | Val Loss: 2.927990 | LR: 1.00e-02
{'mrr': 0.91831103037989, 'ndcg': 0.938412932853889, 'recall_at_1': 0.86556, 'recall_at_3': 0.9678, 'recall_at_5': 0.98364, 'recall_at_10': 0.99348, 'recall_at_50': 0.99936, 'l2_dist': 25.80826759338379}
Saved new best model (mrr=0.918311)


                                                                        

Epoch 017 | Train Loss: 2.833404 | Val Loss: 2.903053 | LR: 1.00e-02
{'mrr': 0.9202249908921978, 'ndcg': 0.9398535508632168, 'recall_at_1': 0.86892, 'recall_at_3': 0.96836, 'recall_at_5': 0.98348, 'recall_at_10': 0.99432, 'recall_at_50': 0.99936, 'l2_dist': 25.807594299316406}
Saved new best model (mrr=0.920225)


                                                                        

Epoch 018 | Train Loss: 2.746961 | Val Loss: 2.860410 | LR: 1.00e-02
{'mrr': 0.9212542752833387, 'ndcg': 0.9406448614973074, 'recall_at_1': 0.87016, 'recall_at_3': 0.969, 'recall_at_5': 0.98388, 'recall_at_10': 0.99428, 'recall_at_50': 0.99936, 'l2_dist': 25.807334899902344}
Saved new best model (mrr=0.921254)


                                                                        

Epoch 019 | Train Loss: 2.661004 | Val Loss: 2.830985 | LR: 1.00e-02
{'mrr': 0.9228837394119696, 'ndcg': 0.9418572410116838, 'recall_at_1': 0.87296, 'recall_at_3': 0.96904, 'recall_at_5': 0.98364, 'recall_at_10': 0.994, 'recall_at_50': 0.99928, 'l2_dist': 25.806922912597656}
Saved new best model (mrr=0.922884)


                                                                        

Epoch 020 | Train Loss: 2.587376 | Val Loss: 2.829229 | LR: 1.00e-02
{'mrr': 0.9244149224415631, 'ndcg': 0.9429819913265683, 'recall_at_1': 0.87608, 'recall_at_3': 0.96884, 'recall_at_5': 0.98424, 'recall_at_10': 0.99396, 'recall_at_50': 0.99932, 'l2_dist': 25.80609703063965}
Saved new best model (mrr=0.924415)


                                                                        

Epoch 021 | Train Loss: 2.519988 | Val Loss: 2.798049 | LR: 1.00e-02
{'mrr': 0.9252509688307761, 'ndcg': 0.9436459330507627, 'recall_at_1': 0.8768, 'recall_at_3': 0.97072, 'recall_at_5': 0.98488, 'recall_at_10': 0.99456, 'recall_at_50': 0.9994, 'l2_dist': 25.806129455566406}
Saved new best model (mrr=0.925251)


                                                                        

Epoch 022 | Train Loss: 2.458097 | Val Loss: 2.790547 | LR: 1.00e-02
{'mrr': 0.9258039490496521, 'ndcg': 0.9440449247795694, 'recall_at_1': 0.87796, 'recall_at_3': 0.97072, 'recall_at_5': 0.9846, 'recall_at_10': 0.994, 'recall_at_50': 0.99948, 'l2_dist': 25.805112838745117}
Saved new best model (mrr=0.925804)


                                                                        

Epoch 023 | Train Loss: 2.404535 | Val Loss: 2.761713 | LR: 1.00e-02
{'mrr': 0.9274354969005641, 'ndcg': 0.9452582276845849, 'recall_at_1': 0.8808, 'recall_at_3': 0.97088, 'recall_at_5': 0.98496, 'recall_at_10': 0.99368, 'recall_at_50': 0.99932, 'l2_dist': 25.803247451782227}
Saved new best model (mrr=0.927435)


                                                                        

Epoch 024 | Train Loss: 2.347515 | Val Loss: 2.755812 | LR: 1.00e-02
{'mrr': 0.928383452663581, 'ndcg': 0.9459725736936131, 'recall_at_1': 0.88232, 'recall_at_3': 0.97028, 'recall_at_5': 0.98504, 'recall_at_10': 0.9936, 'recall_at_50': 0.99948, 'l2_dist': 25.80314064025879}
Saved new best model (mrr=0.928383)


                                                                        

Epoch 025 | Train Loss: 2.303176 | Val Loss: 2.745455 | LR: 1.00e-02
{'mrr': 0.9280781939428003, 'ndcg': 0.9457310524575538, 'recall_at_1': 0.88184, 'recall_at_3': 0.97132, 'recall_at_5': 0.98468, 'recall_at_10': 0.99404, 'recall_at_50': 0.99932, 'l2_dist': 25.802305221557617}


                                                                        

Epoch 026 | Train Loss: 2.248813 | Val Loss: 2.735030 | LR: 1.00e-02
{'mrr': 0.9290460212519926, 'ndcg': 0.9464614525986059, 'recall_at_1': 0.88352, 'recall_at_3': 0.97196, 'recall_at_5': 0.98476, 'recall_at_10': 0.99392, 'recall_at_50': 0.99932, 'l2_dist': 25.80044937133789}
Saved new best model (mrr=0.929046)


                                                                        

Epoch 027 | Train Loss: 2.199817 | Val Loss: 2.730633 | LR: 1.00e-02
{'mrr': 0.9288574095491803, 'ndcg': 0.9463168970332081, 'recall_at_1': 0.88308, 'recall_at_3': 0.97132, 'recall_at_5': 0.98492, 'recall_at_10': 0.99384, 'recall_at_50': 0.99928, 'l2_dist': 25.80037498474121}


                                                                        

Epoch 028 | Train Loss: 2.156844 | Val Loss: 2.713791 | LR: 1.00e-02
{'mrr': 0.9293321298723217, 'ndcg': 0.9466679879865599, 'recall_at_1': 0.88448, 'recall_at_3': 0.97088, 'recall_at_5': 0.98464, 'recall_at_10': 0.99408, 'recall_at_50': 0.99932, 'l2_dist': 25.80062484741211}
Saved new best model (mrr=0.929332)


                                                                        

Epoch 029 | Train Loss: 2.115837 | Val Loss: 2.711139 | LR: 1.00e-02
{'mrr': 0.9300936436543854, 'ndcg': 0.947258838026343, 'recall_at_1': 0.88488, 'recall_at_3': 0.97192, 'recall_at_5': 0.98556, 'recall_at_10': 0.99396, 'recall_at_50': 0.99924, 'l2_dist': 25.799917221069336}
Saved new best model (mrr=0.930094)


                                                                        

Epoch 030 | Train Loss: 2.079249 | Val Loss: 2.708037 | LR: 1.00e-02
{'mrr': 0.9311285617090628, 'ndcg': 0.9480110060528593, 'recall_at_1': 0.88708, 'recall_at_3': 0.9716, 'recall_at_5': 0.98528, 'recall_at_10': 0.99376, 'recall_at_50': 0.9994, 'l2_dist': 25.798791885375977}
Saved new best model (mrr=0.931129)


                                                                        

Epoch 031 | Train Loss: 2.039774 | Val Loss: 2.703317 | LR: 1.00e-02
{'mrr': 0.9311532877025941, 'ndcg': 0.9480271596993, 'recall_at_1': 0.88724, 'recall_at_3': 0.9722, 'recall_at_5': 0.98492, 'recall_at_10': 0.99368, 'recall_at_50': 0.99928, 'l2_dist': 25.797462463378906}
Saved new best model (mrr=0.931153)


                                                                        

Epoch 032 | Train Loss: 2.011385 | Val Loss: 2.686060 | LR: 1.00e-02
{'mrr': 0.9307549318717766, 'ndcg': 0.9477416665384872, 'recall_at_1': 0.88648, 'recall_at_3': 0.97144, 'recall_at_5': 0.98512, 'recall_at_10': 0.99404, 'recall_at_50': 0.99932, 'l2_dist': 25.79694175720215}


                                                                        

Epoch 033 | Train Loss: 1.978727 | Val Loss: 2.686340 | LR: 1.00e-02
{'mrr': 0.9309662506376771, 'ndcg': 0.9479085206158799, 'recall_at_1': 0.88644, 'recall_at_3': 0.97232, 'recall_at_5': 0.98532, 'recall_at_10': 0.99376, 'recall_at_50': 0.99936, 'l2_dist': 25.797170639038086}


                                                                        

Epoch 034 | Train Loss: 1.946780 | Val Loss: 2.685623 | LR: 1.00e-02
{'mrr': 0.9311725753712627, 'ndcg': 0.9480304991446379, 'recall_at_1': 0.8874, 'recall_at_3': 0.9722, 'recall_at_5': 0.98472, 'recall_at_10': 0.99388, 'recall_at_50': 0.99944, 'l2_dist': 25.796310424804688}
Saved new best model (mrr=0.931173)


                                                                        

Epoch 035 | Train Loss: 1.838080 | Val Loss: 2.649760 | LR: 5.00e-03
{'mrr': 0.9326254943421306, 'ndcg': 0.9491000343471561, 'recall_at_1': 0.89024, 'recall_at_3': 0.97192, 'recall_at_5': 0.98456, 'recall_at_10': 0.99368, 'recall_at_50': 0.99932, 'l2_dist': 25.794919967651367}
Saved new best model (mrr=0.932625)


                                                                        

Epoch 036 | Train Loss: 1.752867 | Val Loss: 2.636541 | LR: 5.00e-03
{'mrr': 0.9333037091886133, 'ndcg': 0.949643514910053, 'recall_at_1': 0.8908, 'recall_at_3': 0.97324, 'recall_at_5': 0.98532, 'recall_at_10': 0.99416, 'recall_at_50': 0.99932, 'l2_dist': 25.794681549072266}
Saved new best model (mrr=0.933304)


                                                                        

Epoch 037 | Train Loss: 1.709147 | Val Loss: 2.633349 | LR: 5.00e-03
{'mrr': 0.9323623194403796, 'ndcg': 0.948930878278772, 'recall_at_1': 0.8892, 'recall_at_3': 0.97272, 'recall_at_5': 0.9854, 'recall_at_10': 0.99392, 'recall_at_50': 0.99928, 'l2_dist': 25.794340133666992}


                                                                        

Epoch 038 | Train Loss: 1.686249 | Val Loss: 2.627029 | LR: 5.00e-03
{'mrr': 0.9333713807460171, 'ndcg': 0.9496789123347132, 'recall_at_1': 0.89108, 'recall_at_3': 0.97244, 'recall_at_5': 0.98556, 'recall_at_10': 0.99384, 'recall_at_50': 0.99928, 'l2_dist': 25.793954849243164}
Saved new best model (mrr=0.933371)


                                                                        

Epoch 039 | Train Loss: 1.662844 | Val Loss: 2.627471 | LR: 5.00e-03
{'mrr': 0.9328818330534576, 'ndcg': 0.9493081420743241, 'recall_at_1': 0.89036, 'recall_at_3': 0.97192, 'recall_at_5': 0.98548, 'recall_at_10': 0.99356, 'recall_at_50': 0.99924, 'l2_dist': 25.79345703125}


                                                                         

Epoch 040 | Train Loss: 1.610342 | Val Loss: 2.613528 | LR: 2.50e-03
{'mrr': 0.9337026438760349, 'ndcg': 0.9499266877208143, 'recall_at_1': 0.89156, 'recall_at_3': 0.97272, 'recall_at_5': 0.98528, 'recall_at_10': 0.99372, 'recall_at_50': 0.99924, 'l2_dist': 25.79306411743164}
Saved new best model (mrr=0.933703)


                                                                         

Epoch 041 | Train Loss: 1.584051 | Val Loss: 2.608236 | LR: 2.50e-03
{'mrr': 0.9340878944514277, 'ndcg': 0.9502237904025213, 'recall_at_1': 0.89232, 'recall_at_3': 0.97292, 'recall_at_5': 0.98508, 'recall_at_10': 0.99416, 'recall_at_50': 0.99924, 'l2_dist': 25.792781829833984}
Saved new best model (mrr=0.934088)


                                                                        

Epoch 042 | Train Loss: 1.554579 | Val Loss: 2.608389 | LR: 2.50e-03
{'mrr': 0.9338970597341559, 'ndcg': 0.9500709501897607, 'recall_at_1': 0.89212, 'recall_at_3': 0.97256, 'recall_at_5': 0.985, 'recall_at_10': 0.99376, 'recall_at_50': 0.99924, 'l2_dist': 25.79255485534668}


                                                                         

Epoch 043 | Train Loss: 1.542355 | Val Loss: 2.600090 | LR: 2.50e-03
{'mrr': 0.9336916145116207, 'ndcg': 0.9499340549498992, 'recall_at_1': 0.89152, 'recall_at_3': 0.97308, 'recall_at_5': 0.98512, 'recall_at_10': 0.99404, 'recall_at_50': 0.99932, 'l2_dist': 25.7926082611084}


                                                                         

Epoch 044 | Train Loss: 1.530987 | Val Loss: 2.601380 | LR: 2.50e-03
{'mrr': 0.933487868403298, 'ndcg': 0.9497687165715075, 'recall_at_1': 0.8914, 'recall_at_3': 0.97284, 'recall_at_5': 0.98476, 'recall_at_10': 0.99388, 'recall_at_50': 0.99932, 'l2_dist': 25.792112350463867}


                                                                         

Epoch 045 | Train Loss: 1.496105 | Val Loss: 2.599255 | LR: 1.25e-03
{'mrr': 0.9332075602392791, 'ndcg': 0.9495637211543897, 'recall_at_1': 0.89072, 'recall_at_3': 0.9726, 'recall_at_5': 0.985, 'recall_at_10': 0.99388, 'recall_at_50': 0.99924, 'l2_dist': 25.79210090637207}


                                                                         

Epoch 046 | Train Loss: 1.482428 | Val Loss: 2.594173 | LR: 1.25e-03
{'mrr': 0.9339258830502695, 'ndcg': 0.9501031700342107, 'recall_at_1': 0.89196, 'recall_at_3': 0.97268, 'recall_at_5': 0.98544, 'recall_at_10': 0.9938, 'recall_at_50': 0.9992, 'l2_dist': 25.79204559326172}


                                                                         

Epoch 047 | Train Loss: 1.473732 | Val Loss: 2.594404 | LR: 1.25e-03
{'mrr': 0.9344063211820803, 'ndcg': 0.950462712048782, 'recall_at_1': 0.89272, 'recall_at_3': 0.973, 'recall_at_5': 0.98528, 'recall_at_10': 0.99376, 'recall_at_50': 0.99928, 'l2_dist': 25.79180908203125}
Saved new best model (mrr=0.934406)


                                                                         

Epoch 048 | Train Loss: 1.458167 | Val Loss: 2.596526 | LR: 1.25e-03
{'mrr': 0.9343364166791048, 'ndcg': 0.9504002462953953, 'recall_at_1': 0.89276, 'recall_at_3': 0.97288, 'recall_at_5': 0.98524, 'recall_at_10': 0.99396, 'recall_at_50': 0.9992, 'l2_dist': 25.791662216186523}


                                                                         

Epoch 049 | Train Loss: 1.449331 | Val Loss: 2.594361 | LR: 6.25e-04
{'mrr': 0.9342122152278471, 'ndcg': 0.9503100994620056, 'recall_at_1': 0.89252, 'recall_at_3': 0.97244, 'recall_at_5': 0.98524, 'recall_at_10': 0.99392, 'recall_at_50': 0.99924, 'l2_dist': 25.79153823852539}


                                                                         

Epoch 050 | Train Loss: 1.445197 | Val Loss: 2.593144 | LR: 6.25e-04
{'mrr': 0.9343138969267267, 'ndcg': 0.9503861738864958, 'recall_at_1': 0.89272, 'recall_at_3': 0.9726, 'recall_at_5': 0.98516, 'recall_at_10': 0.99392, 'recall_at_50': 0.9992, 'l2_dist': 25.79148292541504}


                                                                         

Epoch 051 | Train Loss: 1.440666 | Val Loss: 2.591266 | LR: 6.25e-04
{'mrr': 0.9347131891900016, 'ndcg': 0.9506799312825273, 'recall_at_1': 0.8936, 'recall_at_3': 0.97268, 'recall_at_5': 0.98544, 'recall_at_10': 0.99404, 'recall_at_50': 0.99924, 'l2_dist': 25.791385650634766}
Saved new best model (mrr=0.934713)


                                                                         

Epoch 052 | Train Loss: 1.432362 | Val Loss: 2.590660 | LR: 6.25e-04
{'mrr': 0.9346984444919628, 'ndcg': 0.9506746384038447, 'recall_at_1': 0.8934, 'recall_at_3': 0.97316, 'recall_at_5': 0.9852, 'recall_at_10': 0.994, 'recall_at_50': 0.99924, 'l2_dist': 25.7912654876709}


                                                                         

Epoch 053 | Train Loss: 1.430760 | Val Loss: 2.589856 | LR: 6.25e-04
{'mrr': 0.9347953094810894, 'ndcg': 0.9507543758988949, 'recall_at_1': 0.89336, 'recall_at_3': 0.97308, 'recall_at_5': 0.98556, 'recall_at_10': 0.99396, 'recall_at_50': 0.99924, 'l2_dist': 25.791284561157227}
Saved new best model (mrr=0.934795)


                                                                         

Epoch 054 | Train Loss: 1.424198 | Val Loss: 2.588831 | LR: 6.25e-04
{'mrr': 0.9346199669911136, 'ndcg': 0.9506177373231574, 'recall_at_1': 0.89324, 'recall_at_3': 0.97304, 'recall_at_5': 0.9854, 'recall_at_10': 0.99392, 'recall_at_50': 0.99928, 'l2_dist': 25.79121208190918}


                                                                         

Epoch 055 | Train Loss: 1.423280 | Val Loss: 2.587293 | LR: 6.25e-04
{'mrr': 0.9343762840633089, 'ndcg': 0.9504412761956023, 'recall_at_1': 0.89268, 'recall_at_3': 0.97308, 'recall_at_5': 0.9852, 'recall_at_10': 0.99392, 'recall_at_50': 0.99928, 'l2_dist': 25.791088104248047}


                                                                         

Epoch 056 | Train Loss: 1.412294 | Val Loss: 2.587547 | LR: 3.13e-04
{'mrr': 0.9343690761909182, 'ndcg': 0.9504355636762454, 'recall_at_1': 0.89264, 'recall_at_3': 0.97296, 'recall_at_5': 0.98508, 'recall_at_10': 0.99384, 'recall_at_50': 0.99932, 'l2_dist': 25.791000366210938}


                                                                         

Epoch 057 | Train Loss: 1.410353 | Val Loss: 2.586589 | LR: 3.13e-04
{'mrr': 0.9345267091185141, 'ndcg': 0.9505551556495302, 'recall_at_1': 0.89292, 'recall_at_3': 0.97296, 'recall_at_5': 0.98508, 'recall_at_10': 0.994, 'recall_at_50': 0.99928, 'l2_dist': 25.790964126586914}


                                                                         

Epoch 058 | Train Loss: 1.408190 | Val Loss: 2.586844 | LR: 3.13e-04
{'mrr': 0.934510752606386, 'ndcg': 0.950544772791889, 'recall_at_1': 0.89288, 'recall_at_3': 0.97336, 'recall_at_5': 0.9852, 'recall_at_10': 0.99392, 'recall_at_50': 0.99928, 'l2_dist': 25.79092025756836}


                                                                         

Epoch 059 | Train Loss: 1.401390 | Val Loss: 2.586945 | LR: 3.13e-04
{'mrr': 0.9345887974873985, 'ndcg': 0.9505959970593953, 'recall_at_1': 0.8932, 'recall_at_3': 0.97296, 'recall_at_5': 0.98508, 'recall_at_10': 0.99392, 'recall_at_50': 0.99928, 'l2_dist': 25.79092025756836}


                                                                         

Epoch 060 | Train Loss: 1.398437 | Val Loss: 2.586512 | LR: 1.56e-04
{'mrr': 0.9346348233335966, 'ndcg': 0.950627299291487, 'recall_at_1': 0.89332, 'recall_at_3': 0.97284, 'recall_at_5': 0.98516, 'recall_at_10': 0.99392, 'recall_at_50': 0.99928, 'l2_dist': 25.79091453552246}


                                                                         

Epoch 061 | Train Loss: 1.398751 | Val Loss: 2.586562 | LR: 1.56e-04
{'mrr': 0.9346321484254573, 'ndcg': 0.9506236985005841, 'recall_at_1': 0.8934, 'recall_at_3': 0.97268, 'recall_at_5': 0.98528, 'recall_at_10': 0.99392, 'recall_at_50': 0.99924, 'l2_dist': 25.79088592529297}


                                                                         

Epoch 062 | Train Loss: 1.398613 | Val Loss: 2.586545 | LR: 1.56e-04
{'mrr': 0.9347112034793268, 'ndcg': 0.9506829478831946, 'recall_at_1': 0.89352, 'recall_at_3': 0.9726, 'recall_at_5': 0.98516, 'recall_at_10': 0.99388, 'recall_at_50': 0.99924, 'l2_dist': 25.790864944458008}


                                                                         

Epoch 063 | Train Loss: 1.400603 | Val Loss: 2.586180 | LR: 1.56e-04
{'mrr': 0.9348013972912316, 'ndcg': 0.950747921056819, 'recall_at_1': 0.89372, 'recall_at_3': 0.9726, 'recall_at_5': 0.98508, 'recall_at_10': 0.99384, 'recall_at_50': 0.99924, 'l2_dist': 25.79084014892578}
Saved new best model (mrr=0.934801)


                                                                         

Epoch 064 | Train Loss: 1.397282 | Val Loss: 2.586412 | LR: 7.81e-05
{'mrr': 0.9348855541679426, 'ndcg': 0.9508114014478065, 'recall_at_1': 0.89384, 'recall_at_3': 0.97272, 'recall_at_5': 0.98504, 'recall_at_10': 0.99384, 'recall_at_50': 0.99924, 'l2_dist': 25.79082489013672}
Saved new best model (mrr=0.934886)


                                                                        

Epoch 065 | Train Loss: 1.401224 | Val Loss: 2.586147 | LR: 7.81e-05
{'mrr': 0.9348400612951752, 'ndcg': 0.9507770592371252, 'recall_at_1': 0.89376, 'recall_at_3': 0.97272, 'recall_at_5': 0.98508, 'recall_at_10': 0.99388, 'recall_at_50': 0.99924, 'l2_dist': 25.790813446044922}


                                                                         

Epoch 066 | Train Loss: 1.392488 | Val Loss: 2.586265 | LR: 7.81e-05
{'mrr': 0.9348046586330465, 'ndcg': 0.950748652325498, 'recall_at_1': 0.89376, 'recall_at_3': 0.97276, 'recall_at_5': 0.985, 'recall_at_10': 0.99384, 'recall_at_50': 0.99924, 'l2_dist': 25.79079818725586}


                                                                        

Epoch 067 | Train Loss: 1.391595 | Val Loss: 2.586392 | LR: 7.81e-05
{'mrr': 0.934781808090187, 'ndcg': 0.9507341701978439, 'recall_at_1': 0.89364, 'recall_at_3': 0.97288, 'recall_at_5': 0.98504, 'recall_at_10': 0.99388, 'recall_at_50': 0.99924, 'l2_dist': 25.79077911376953}


                                                                         

Epoch 068 | Train Loss: 1.398097 | Val Loss: 2.586240 | LR: 3.91e-05
{'mrr': 0.9348148499636745, 'ndcg': 0.9507587198158036, 'recall_at_1': 0.89372, 'recall_at_3': 0.97292, 'recall_at_5': 0.98504, 'recall_at_10': 0.99388, 'recall_at_50': 0.99924, 'l2_dist': 25.790769577026367}


                                                                         

Epoch 069 | Train Loss: 1.395823 | Val Loss: 2.586165 | LR: 3.91e-05
{'mrr': 0.9347534383783844, 'ndcg': 0.9507133765447778, 'recall_at_1': 0.8936, 'recall_at_3': 0.97288, 'recall_at_5': 0.98504, 'recall_at_10': 0.99384, 'recall_at_50': 0.99924, 'l2_dist': 25.790767669677734}


                                                                         

Epoch 070 | Train Loss: 1.394935 | Val Loss: 2.586106 | LR: 3.91e-05
{'mrr': 0.9347316142814743, 'ndcg': 0.9506960763373038, 'recall_at_1': 0.8936, 'recall_at_3': 0.97292, 'recall_at_5': 0.98504, 'recall_at_10': 0.99388, 'recall_at_50': 0.99924, 'l2_dist': 25.7907657623291}


                                                                        

Epoch 071 | Train Loss: 1.392418 | Val Loss: 2.586043 | LR: 3.91e-05
{'mrr': 0.9346701840416668, 'ndcg': 0.9506513588162451, 'recall_at_1': 0.89344, 'recall_at_3': 0.97304, 'recall_at_5': 0.985, 'recall_at_10': 0.99388, 'recall_at_50': 0.99924, 'l2_dist': 25.790761947631836}


                                                                         

Epoch 072 | Train Loss: 1.397355 | Val Loss: 2.585938 | LR: 1.95e-05
{'mrr': 0.9346766073732048, 'ndcg': 0.9506566267715785, 'recall_at_1': 0.89344, 'recall_at_3': 0.973, 'recall_at_5': 0.985, 'recall_at_10': 0.99388, 'recall_at_50': 0.99924, 'l2_dist': 25.790761947631836}


                                                                         

Epoch 073 | Train Loss: 1.394305 | Val Loss: 2.585942 | LR: 1.95e-05
{'mrr': 0.9347187862117953, 'ndcg': 0.9506882097997466, 'recall_at_1': 0.89352, 'recall_at_3': 0.973, 'recall_at_5': 0.98504, 'recall_at_10': 0.99388, 'recall_at_50': 0.99924, 'l2_dist': 25.790761947631836}


                                                                         

Epoch 074 | Train Loss: 1.390131 | Val Loss: 2.585994 | LR: 1.95e-05
{'mrr': 0.9347284655248365, 'ndcg': 0.9506938806527518, 'recall_at_1': 0.89356, 'recall_at_3': 0.97292, 'recall_at_5': 0.98504, 'recall_at_10': 0.99388, 'recall_at_50': 0.99924, 'l2_dist': 25.790760040283203}
Early stopping triggered.
Training complete. Best mrr: 0.934728
Finished training. Now testing using best model...
Test Results: {'mrr': 0.9354795373180256, 'ndcg': 0.9513033506208108, 'recall_at_1': 0.89412, 'recall_at_3': 0.97404, 'recall_at_5': 0.9862, 'recall_at_10': 0.9944, 'recall_at_50': 0.99932, 'l2_dist': 25.790616989135742}


In [10]:
generate_submission(models, Path(test_path), output_file="sub.csv", device=device, ensemble=True)

Generating submission file...
✓ Saved submission to sub.csv


Unnamed: 0,id,embedding
0,1,"[0.011552590876817703, 0.00943775288760662, -0..."
1,2,"[-0.017103342339396477, -0.0030133062973618507..."
2,3,"[0.0025980216450989246, -0.008530800230801105,..."
3,4,"[0.03638746961951256, -0.0017537014791741967, ..."
4,5,"[0.054213255643844604, 0.03685036301612854, 0...."
...,...,...
1495,1496,"[-0.008818816393613815, -0.008429779671132565,..."
1496,1497,"[0.017080791294574738, 0.03277495130896568, 0...."
1497,1498,"[0.051879268139600754, -0.026452703401446342, ..."
1498,1499,"[-0.01728557050228119, -0.0016986479749903083,..."


In [14]:
import optuna
from optuna.pruners import MedianPruner



def objective(trial, train_dataset, val_dataset, epochs: int = 10, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    n_layers = trial.suggest_int("n_layers", 1, 3)
    layer_choices = [1256, 1565, 2048]
    hidden_layers = [trial.suggest_categorical(f"n_units_l{i}", layer_choices) for i in range(n_layers)]

    activation_fn = nn.ReLU
    
    batch_size = trial.suggest_categorical("batch_size", [2048, 4096])
    lr = trial.suggest_categorical("lr", [1e-4, 1e-3, 1e-2])
    dropout_rate = trial.suggest_categorical('dropout_rate', [0.3, 0.4, 0.5])

    weight_decay = trial.suggest_categorical("weight_decay", [1e-4, 1e-3, 1e-2])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = SpaceTranslator(
        input_dim=1024, 
        output_dim=1536,
        hidden_layers=hidden_layers,
        activation=activation_fn,
        dropout_rate=dropout_rate
    )
    model = model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,           # reduce LR by half when plateau
        patience=2,           # wait 2 epochs before reducing LR
        min_lr=1e-6,          # don't go below this LR
    )

    best_val_loss = float('inf')
    no_improvements = 0

    for epoch in range(1, epochs + 1):
        model.train()
        running_train_loss = 0.0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = info_nce_loss(outputs, y_batch, model.logit_scale)

            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()

        train_loss = running_train_loss / len(train_loader)

        # Validation phase
        model.eval()
        running_val_loss = 0.0

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)

                outputs = model(X_batch)
                loss = info_nce_loss(outputs, y_batch, model.logit_scale)
                running_val_loss += loss.item()

        val_loss = running_val_loss / len(val_loader)

        scheduler.step(val_loss)

        results = test(val_dataset, model, device)

        trial.report(results['mrr'], epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()


    return results['mrr']


def run_optuna_search(data_path: Path, n_trials: int = 30, epochs: int = 10, n_jobs: int = 1, pruner=None):
    if pruner is None:
        pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=1)

    X, Y = get_data(data_path)    
    train_dataset, val_dataset = get_datasets(X, Y)

    study = optuna.create_study(direction="maximize", pruner=pruner)
    func = lambda trial: objective(trial, train_dataset=train_dataset, val_dataset=val_dataset,
                                   epochs=epochs)
    study.optimize(func, n_trials=n_trials, n_jobs=n_jobs)

    print("Study statistics:")
    print("  Number of finished trials: ", len(study.trials))
    print("  Best trial:")
    trial = study.best_trial
    print("    Value: ", trial.value)
    print("    Params: ")
    for k, v in trial.params.items():
        print(f"      {k}: {v}")

    return study

In [None]:
study = run_optuna_search(data_path=data_path, n_trials=200, epochs=30, n_jobs=1)
study.trials_dataframe().to_csv("optuna_trials.csv", index=False)

best_trial_number = study.best_trial.number
print("Best params:", study.best_params)
print("Best trial number:", study.best_trial.number)