In [15]:
from typing import Optional

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

class _SpaceTranslator(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_layers,
        activation,
        dropout_rate
    ):
        super().__init__()

        layers = []
        last = input_dim

        for hidden in hidden_layers:
            layers += [
                nn.Linear(last, hidden),
                # nn.LayerNorm(hidden),
                nn.BatchNorm1d(hidden),
                activation(),
                nn.Dropout(dropout_rate)
            ]
            last = hidden
        self.body = nn.Sequential(*layers)

        self.final = nn.Linear(last, output_dim)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))

        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight, a=0, mode='fan_in', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm) or isinstance(module, nn.BatchNorm1d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        x = self.body(x)
        return self.final(x)



class SpaceTranslator(nn.Module):
    """MLP translating caption embeddings into the image-embedding space.

    Every width ``h`` in ``hidden_layers`` contributes a
    Linear -> LayerNorm -> activation -> Dropout stage, followed by a final
    Linear to ``output_dim``; all of it lives in one ``nn.Sequential``
    (``self.net``). The output is L2-normalised along dim 1.

    Args:
        input_dim: width of the incoming vectors.
        output_dim: width of the produced vectors.
        hidden_layers: iterable of hidden widths, in order.
        activation: zero-argument callable returning an activation module
            (e.g. ``nn.GELU``); called once per hidden stage.
        dropout_rate: dropout probability used in every hidden stage.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_layers,
        activation,
        dropout_rate
    ):
        super().__init__()

        widths = [input_dim, *hidden_layers]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            # Submodules are created in forward-pass order (this also fixes RNG use order).
            modules.extend((
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                activation(),
                nn.Dropout(dropout_rate),
            ))
        modules.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*modules)

        # Learnable log-temperature initialised to log(1 / 0.07); read by the
        # training losses via `model.logit_scale`, not used in forward().
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))

        self.apply(self.init_weights)

    def init_weights(self, module):
        """Initialise a single submodule; meant to be passed to ``self.apply``.

        LayerNorm: unit weight and zero bias. Linear: Xavier-uniform weight and
        zero bias. Other modules are left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)

    def forward(self, x):
        """Translate ``x`` and return unit-L2-norm rows."""
        projected = self.net(x)
        return F.normalize(projected, p=2, dim=1)


In [16]:
from pathlib import Path
import numpy as np
import torch
import pandas as pd

'''Code from https://github.com/Mamiglia/challenge'''

def mrr(pred_indices: np.ndarray, gt_indices: np.ndarray) -> float:
    """
    Compute Mean Reciprocal Rank (MRR).

    For each query the score is 1 / (1-based position of the first prediction
    equal to the ground truth), or 0 when the ground truth is absent.

    Args:
        pred_indices: (N, K) array of ranked predicted indices for N queries
        gt_indices: (N,) array of ground truth indices
    Returns:
        mrr: mean of the per-query reciprocal ranks
    """
    scores = []
    for row in range(len(gt_indices)):
        hit_positions = np.where(pred_indices[row] == gt_indices[row])[0]
        scores.append(1.0 / (hit_positions[0] + 1) if hit_positions.size > 0 else 0.0)
    return np.mean(scores)


def recall_at_k(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int) -> float:
    """Compute Recall@k: the fraction of queries whose ground truth is in the top k.

    Args:
        pred_indices: (N, K) array of ranked predicted indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        recall: Recall@k in [0, 1]
    """
    hits = sum(
        gt_indices[row] in pred_indices[row, :k]
        for row in range(len(gt_indices))
    )
    return hits / len(gt_indices)

import numpy as np

def ndcg(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int = 100) -> float:
    """
    Compute Normalized Discounted Cumulative Gain (NDCG@k) with one relevant item per query.

    A hit at 1-based rank r contributes 1 / log2(r + 1); misses contribute 0.

    Args:
        pred_indices: (N, K) array of ranked predicted indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        ndcg: NDCG@k averaged over queries
    """
    gains = 0.0
    for row in range(len(gt_indices)):
        hit_positions = np.where(pred_indices[row, :k] == gt_indices[row])[0]
        if hit_positions.size == 0:
            continue
        # Only one relevant item, so the ideal DCG is 1 and DCG equals NDCG.
        # 0-based position p corresponds to rank p + 1, hence log2(p + 2).
        gains += 1.0 / np.log2(hit_positions[0] + 2)
    return gains / len(gt_indices)



@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """Evaluate retrieval performance using dot-product similarity.

    Queries are processed in consecutive chunks of ``batch_size``. Each query is
    ranked only against the images of its own chunk (``image_embd[chunk]``), so
    the candidate gallery of a query has at most ``batch_size`` items. This
    assumes ``image_embd`` is aligned row-wise with ``translated_embd`` (query
    ``i`` is paired with image ``i``, as when ``gt_indices == arange(N)``).
    The dot product equals cosine similarity only if both inputs are unit-norm.

    Args:
        translated_embd: (N_captions, D) translated caption embeddings (tensor or ndarray)
        image_embd: (N_images, D) image embeddings (tensor or ndarray)
        gt_indices: (N_captions,) ground truth image indices for each caption
            (tensor, ndarray or sequence)
        max_indices: number of top predictions kept per query; when a chunk has
            fewer candidates, the row is padded with -1 (never a valid index)
        batch_size: number of queries (and candidate images) per chunk
    Returns:
        results: dict with 'mrr', 'ndcg', 'recall_at_{1,3,5,10,50}' and 'l2_dist'
            (mean L2 distance between each query and its ground-truth image)
    Raises:
        ValueError: if there are no queries.
    """
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()

    n_queries = translated_embd.shape[0]
    if n_queries == 0:
        raise ValueError("evaluate_retrieval: no queries to evaluate")

    # Compute on the queries' device; top-k indices are brought back to CPU below.
    device = translated_embd.device
    image_embd = image_embd.to(device)

    # Ground truth as a CPU numpy array: used for index mapping and by the metric
    # helpers, whose per-row comparisons are much faster on numpy scalars.
    if torch.is_tensor(gt_indices):
        gt_np = gt_indices.detach().cpu().numpy()
    else:
        gt_np = np.asarray(gt_indices)

    all_sorted_indices = []
    l2_distances = []

    for start_idx in range(0, n_queries, batch_size):
        batch_slice = slice(start_idx, min(start_idx + batch_size, n_queries))
        batch_translated = translated_embd[batch_slice]
        batch_img_embd = image_embd[batch_slice]
        batch_gt = gt_np[batch_slice]

        # Similarity of each query in the chunk against the chunk's images only.
        batch_similarity = batch_translated @ batch_img_embd.T

        # topk raises if k exceeds the candidate count (e.g. a short final chunk).
        k = min(max_indices, batch_similarity.shape[1])
        local_top = batch_similarity.topk(k=k, dim=1, sorted=True).indices.cpu().numpy()
        # Map chunk-local candidate positions to global image indices.
        global_top = batch_gt[local_top]
        if k < max_indices:
            # Keep a fixed width so all chunks can be concatenated.
            global_top = np.pad(global_top, ((0, 0), (0, max_indices - k)), constant_values=-1)
        all_sorted_indices.append(global_top)

        # L2 distance between each query and its ground-truth image embedding.
        batch_gt_embeddings = image_embd[torch.as_tensor(batch_gt, device=device)]
        l2_distances.append((batch_translated - batch_gt_embeddings).norm(dim=1))

    sorted_indices = np.concatenate(all_sorted_indices, axis=0)

    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }

    results = {
        name: func(sorted_indices, gt_np)
        for name, func in metrics.items()
    }
    results['l2_dist'] = torch.cat(l2_distances, dim=0).mean().item()

    return results



def generate_submission(models: nn.Module, test_path: Path, ensemble, output_file="submission-dirmodel.csv", device=None):
    """Embed the test captions with a model (or an ensemble) and write a submission CSV.

    Args:
        models: a single model when ``ensemble`` is False; otherwise an iterable
            of models whose outputs are averaged element-wise.
        test_path: ``.npz`` archive providing ``captions/ids`` and
            ``captions/embeddings``.
        ensemble: whether ``models`` is an iterable of models to average.
        output_file: destination CSV path.
        device: device to run inference on; ``None`` keeps inputs on the CPU.
    Returns:
        The submission DataFrame with columns ``id`` and ``embedding``
        (each embedding a list of floats).
    """
    # The context manager closes the underlying archive once the arrays are read.
    with np.load(test_path) as test_data:
        sample_ids = test_data['captions/ids']
        test_embds = torch.from_numpy(test_data['captions/embeddings']).float()
    if device is not None:
        test_embds = test_embds.to(device)

    with torch.no_grad():
        if ensemble:
            preds = []
            for model in models:
                model.eval()
                preds.append(model(test_embds).cpu())
            pred_embds = torch.stack(preds).mean(dim=0)
        else:
            # Use the single model passed as `models`; eval() disables dropout.
            models.eval()
            pred_embds = models(test_embds).cpu()
    print("Generating submission file...")

    pred_embds = pred_embds.numpy()

    df_submission = pd.DataFrame({'id': sample_ids, 'embedding': pred_embds.tolist()})

    df_submission.to_csv(output_file, index=False, float_format='%.17g')
    print(f"✓ Saved submission to {output_file}")

    return df_submission

In [17]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
from pathlib import Path
from tqdm import tqdm
import os
import shutil

def compute_targets_metrics(train_dataset, device, batch_size=8192):
    """Per-dimension mean and population standard deviation of the dataset targets.

    Streams ``train_dataset`` (which yields ``(x, y)`` pairs) in chunks of
    ``batch_size`` rows and accumulates the first and second moments of ``y``
    in float64.

    Args:
        train_dataset: dataset of ``(x, y)`` pairs; only ``y`` is used.
        device: device on which the returned tensors are placed.
        batch_size: rows per chunk; ``None`` processes the whole dataset at once.
    Returns:
        ``(target_mean, target_std)``: float32 tensors with the shape of one
        target row; ``target_std = sqrt(var + 1e-12)`` with ``var`` clamped at 0.
    Raises:
        ValueError: if ``train_dataset`` is empty.
    """
    n = len(train_dataset)
    if n == 0:
        raise ValueError("compute_targets_metrics: train_dataset is empty")

    loader = DataLoader(train_dataset, batch_size=batch_size or n)
    sum_vec = None
    sum_sq = None
    for _, y in loader:
        # float64 accumulation keeps long sums (and E[y^2] - E[y]^2) accurate.
        y64 = y.double()
        batch_sum = y64.sum(dim=0)
        batch_sq = (y64 * y64).sum(dim=0)
        sum_vec = batch_sum if sum_vec is None else sum_vec + batch_sum
        sum_sq = batch_sq if sum_sq is None else sum_sq + batch_sq

    mean64 = sum_vec / n
    var = (sum_sq / n - mean64 ** 2).clamp(min=0.0)
    target_mean = mean64.float().to(device)
    target_std = torch.sqrt(var.float() + 1e-12).to(device)
    return target_mean, target_std

# NOTE(review): the string literal below is never executed; it only preserves
# earlier, disabled versions of `moment_loss` and `info_nce_loss`. The active
# `info_nce_loss` is the definition further down in this cell.
'''
def moment_loss(preds, target_mean, target_std, mode='per_dim', coef=5e-2):
    if mode == 'per_dim':
        pm = preds.mean(dim=0)
        ps = preds.std(dim=0)
        loss_mean = F.mse_loss(pm, target_mean)
        loss_std  = F.mse_loss(ps,  target_std)
        return coef * (loss_mean + loss_std)
    elif mode == 'scalar':
        pm = preds.mean()
        ps = preds.std()
        tm = target_mean.mean()
        ts = target_std.mean()
        return coef * (F.mse_loss(pm, tm) + F.mse_loss(ps, ts))
    
def info_nce_loss(dir_preds, img_targets, logit_scale):
    dir_preds = F.normalize(dir_preds, dim=-1)
    img_targets = F.normalize(img_targets, dim=-1)
    
    logit_scale = torch.clamp(logit_scale, max=np.log(100))
    logits = (dir_preds @ img_targets.T) * logit_scale.exp()

    labels = torch.arange(logits.size(0), device=logits.device)

    loss_t2i = F.cross_entropy(logits, labels)
    loss_i2t = F.cross_entropy(logits.T, labels)

    return 0.5 * (loss_t2i + loss_i2t)
'''
def rank_aware_loss(logits, labels, top_k=10, margin=0.1):
    """Rank-weighted hinge loss on negatives that score near or above the positive.

    For row ``i`` every candidate ``j != labels[i]`` contributes
    ``relu(logits[i, j] - logits[i, labels[i]] + margin)``. The row's summed
    violations are weighted by ``1 / min(rank, top_k)``, where ``rank`` is
    1 + the number of candidates scoring strictly above the positive.

    Args:
        logits: (B, C) score matrix; row i scores query i against C candidates.
        labels: (B,) index of the positive candidate for each row.
        top_k: upper bound applied to the rank before taking ``1 / rank``.
        margin: hinge margin.
    Returns:
        Scalar loss (mean over rows).
    """
    batch_size = logits.size(0)
    rows = torch.arange(batch_size, device=logits.device)

    # Score of the correct candidate for each row, shape (B, 1).
    correct_scores = logits[rows, labels].unsqueeze(1)

    # How far each candidate is from beating the positive by `margin`.
    margins = logits - correct_scores + margin

    # Exclude each row's own positive at (i, labels[i]) rather than assuming the
    # diagonal, so non-identity labels and non-square logits are handled.
    negative_mask = torch.ones_like(logits, dtype=torch.bool)
    negative_mask[rows, labels] = False

    # Keep only negatives that violate the margin.
    violations = F.relu(margins) * negative_mask.float()

    # 1-based rank of the positive: 1 + number of candidates scoring strictly higher.
    ranks = (logits > correct_scores).float().sum(dim=1) + 1
    ranks = torch.clamp(ranks, max=top_k)

    weights = 1.0 / ranks

    return (violations.sum(dim=1) * weights).mean()
   

def combined_loss(dir_preds, img_targets, logit_scale, alpha=0.7, beta=0.3,
                  rank_top_k=10, rank_margin=0.1, normalize_rank_logits=False):
    """Contrastive (InfoNCE + hard-negative) loss plus a weighted rank-aware term.

    total = info_nce_loss(dir_preds, img_targets, logit_scale, alpha=alpha)
            + beta * 0.5 * (rank_aware_loss(logits) + rank_aware_loss(logits.T))

    Args:
        dir_preds: (B, D) predicted embeddings; row i is paired with ``img_targets[i]``.
        img_targets: (B, D) target embeddings.
        logit_scale: learnable log-temperature (scalar tensor).
        alpha: weight of the hard-negative term inside ``info_nce_loss``.
        beta: weight of the rank-aware term.
        rank_top_k: rank cap passed to ``rank_aware_loss``.
        rank_margin: hinge margin passed to ``rank_aware_loss``.
        normalize_rank_logits: if True, the rank term uses L2-normalised inputs
            and the same clamped scale as ``info_nce_loss``. The default (False)
            keeps the original behaviour: raw ``img_targets`` and an unclamped
            ``logit_scale.exp()``, so rank logits grow with the target norm.
    Returns:
        Scalar loss.
    """
    if normalize_rank_logits:
        rank_preds = F.normalize(dir_preds, dim=-1)
        rank_targets = F.normalize(img_targets, dim=-1)
        scale = torch.clamp(logit_scale, min=np.log(0.01), max=np.log(100)).exp()
    else:
        # NOTE(review): info_nce_loss normalises both sides, this path does not;
        # with unnormalised targets the fixed `rank_margin` is small relative to the logits.
        rank_preds, rank_targets = dir_preds, img_targets
        scale = logit_scale.exp()

    logits = rank_preds @ rank_targets.T * scale
    labels = torch.arange(logits.size(0), device=logits.device)

    loss_contrastive = info_nce_loss(dir_preds, img_targets, logit_scale, alpha=alpha)

    loss_rank_t2i = rank_aware_loss(logits, labels, top_k=rank_top_k, margin=rank_margin)
    loss_rank_i2t = rank_aware_loss(logits.T, labels, top_k=rank_top_k, margin=rank_margin)
    loss_rank = 0.5 * (loss_rank_t2i + loss_rank_i2t)

    return loss_contrastive + beta * loss_rank
    
def info_nce_loss(
    dir_preds,
    img_targets,
    logit_scale: torch.Tensor,
    margin: float = 0.3,
    alpha: float = 0.7
):
    """Symmetric InfoNCE loss plus a hardest-negative hinge term.

    Row i of ``dir_preds`` is the positive for row i of ``img_targets``. Both are
    L2-normalised and their cosine similarities are scaled by
    ``exp(clamp(logit_scale, log 0.01, log 100))``.

    Returns:
        ``0.5 * (CE_t2i + CE_i2t) + alpha * 0.5 * (hinge_t2i + hinge_i2t)``, where
        each hinge is ``mean(relu(hardest_negative - positive + margin))``.
    """
    def _hardest_negative_hinge(scores, pos_mask):
        # Push positives far down so max() selects the strongest negative per row.
        hardest = (scores - pos_mask * 1e9).max(dim=1).values
        return F.relu(hardest - torch.diag(scores) + margin).mean()

    preds_unit = F.normalize(dir_preds, dim=-1)
    targets_unit = F.normalize(img_targets, dim=-1)

    bounded_scale = torch.clamp(logit_scale, min=np.log(0.01), max=np.log(100))
    sim = preds_unit @ targets_unit.T * bounded_scale.exp()  # (B, B)

    n = sim.size(0)
    targets = torch.arange(n, device=sim.device)
    symmetric_ce = 0.5 * (F.cross_entropy(sim, targets) + F.cross_entropy(sim.T, targets))

    eye = torch.eye(n, device=sim.device)
    hard_term = 0.5 * (_hardest_negative_hinge(sim, eye) + _hardest_negative_hinge(sim.T, eye))

    return symmetric_ce + alpha * hard_term




def train_model_direction(
    model: SpaceTranslator,
    model_path: Path,
    train_dataset: TensorDataset,
    val_dataset: TensorDataset,
    batch_size: int,
    epochs: int,
    lr: float,
    patience: int,
):
    """Train ``model`` with ``combined_loss``, checkpointing on validation MRR.

    Each time validation MRR improves, the state dict is saved as
    ``<epoch>.pth`` in ``Path(model_path).parent``; ``model_path`` itself is only
    used to locate that directory. An exponential moving average (decay 0.998)
    of the floating-point state is updated after every optimiser step.

    WARNING: the whole ``Path(model_path).parent`` directory is deleted first
    (presumably so ``average_checkpoints`` only sees this run's files).

    Args:
        model: module exposing a ``logit_scale`` parameter.
        model_path: path whose parent directory receives the checkpoints.
        train_dataset / val_dataset: datasets of ``(x, y)`` pairs.
        batch_size: mini-batch size for both loaders.
        epochs: maximum number of epochs.
        lr: AdamW learning rate.
        patience: epochs without MRR improvement before early stopping.
    Returns:
        ``(model, ema_state)``: the model after the last epoch run (not the best
        checkpoint) and the EMA state dict (tensors on CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    checkpoint_dir = Path(model_path).parent
    shutil.rmtree(checkpoint_dir, ignore_errors=True)
    print(f"Using device: {device}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)
    # Driven by validation MRR (mode='max'): halve the LR after 3 stagnant epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3, threshold=0.001, min_lr=1e-9
    )

    ema_state = {k: v.clone().detach().cpu() for k, v in model.state_dict().items()}
    ema_decay = 0.998

    best_mrr = float('-inf')
    no_improvements = 0

    for epoch in range(1, epochs + 1):
        model.train()
        running_train_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)
        for X_batch, y_batch in progress_bar:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            outputs = model(X_batch)

            loss = combined_loss(outputs, y_batch, model.logit_scale)
            loss.backward()
            optimizer.step()
            ema_update(model, ema_state, ema_decay)

            running_train_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        train_loss = running_train_loss / len(train_loader)

        # Validation loss.
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                outputs = model(X_batch)
                loss = combined_loss(outputs, y_batch, model.logit_scale)
                running_val_loss += loss.item()

        val_loss = running_val_loss / len(val_loader)

        print(f"Epoch {epoch:03d} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | LR: {optimizer.param_groups[0]['lr']:.2e}")

        # Retrieval metrics on the full validation set.
        results = test(val_dataset, model, device, ensemble=False)
        print(results)
        # Named val_mrr so it does not shadow the module-level `mrr` metric.
        val_mrr = results['mrr']

        scheduler.step(val_mrr)

        if val_mrr > best_mrr:
            best_mrr = val_mrr
            no_improvements = 0

            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            # pathlib join is valid for both relative and absolute model paths.
            torch.save(model.state_dict(), checkpoint_dir / f'{epoch}.pth')
            print(f"Saved new best model (mrr={val_mrr:.6f})")
        else:
            no_improvements += 1
            if no_improvements >= patience:
                print("Early stopping triggered.")
                break

    print(f"Training complete. Best mrr: {best_mrr:.6f}")

    return model, ema_state

def ema_update(model, ema_state, decay):
    """Blend the model's current floating-point state into ``ema_state`` in place.

    For every floating-point entry: ``ema = decay * ema + (1 - decay) * current``
    (current values are copied to CPU first). Non-floating entries such as
    integer counters are left untouched.
    """
    keep, mix = decay, 1 - decay
    with torch.no_grad():
        for name, tensor in model.state_dict().items():
            if not tensor.dtype.is_floating_point:
                continue
            ema_state[name].mul_(keep).add_(tensor.detach().cpu(), alpha=mix)


def get_data(data_path: Path):
    """Load caption embeddings and, row-aligned, the embedding of each caption's image.

    Reads ``captions/embeddings``, ``images/embeddings`` and ``captions/label``
    from an ``.npz`` archive. The image paired with caption ``i`` is
    ``argmax(captions/label[i])`` (presumably a one-hot caption-to-image
    matrix; TODO confirm).

    Returns:
        ``(X_abs, y_abs)``: caption embeddings and the paired image embeddings.
    """
    # The context manager closes the archive even if a key lookup fails.
    with np.load(data_path) as data:
        caption_embeddings = data['captions/embeddings']
        image_embeddings = data['images/embeddings']
        caption_labels = data['captions/label']

    paired_image_embeddings = image_embeddings[np.argmax(caption_labels, axis=1)]
    # from_numpy shares memory with the freshly loaded arrays instead of copying.
    X_abs = torch.from_numpy(caption_embeddings)
    y_abs = torch.from_numpy(paired_image_embeddings)

    return X_abs, y_abs

    
def get_datasets(X_abs, y_abs, val_fraction: float = 0.2, seed: int = 42) -> tuple[torch.utils.data.Subset, torch.utils.data.Subset]:
    """Pair inputs with targets and split them into train/validation subsets.

    Args:
        X_abs: input tensor; row i is paired with ``y_abs[i]``.
        y_abs: target tensor with the same number of rows as ``X_abs``.
        val_fraction: fraction of rows assigned to validation, strictly in (0, 1).
        seed: seed of the split generator, making the split reproducible.
    Returns:
        ``(train_dataset, val_dataset)`` as ``Subset`` views of one ``TensorDataset``.
    Raises:
        ValueError: if ``val_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError(f"val_fraction must be in (0, 1), got {val_fraction}")

    print('Texts shape', X_abs.shape)
    print('Images shape', y_abs.shape)

    dataset = TensorDataset(X_abs, y_abs)
    train_dataset, val_dataset = random_split(
        dataset, [1.0 - val_fraction, val_fraction], generator=torch.Generator().manual_seed(seed)
    )

    return train_dataset, val_dataset

def test(val_dataset: TensorDataset, model, device, ensemble):
    """Evaluate ``model`` (or a list of models when ``ensemble``) on the whole
    validation set, loaded as a single batch, and return the metrics dict."""
    full_batch_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
    # batch_size == len(dataset), so the loader yields exactly one batch.
    x_val, y_val = next(iter(full_batch_loader))
    return eval_on_val(x_val, y_val, models=model, device=device, ensemble=ensemble)
    
def eval_on_val(x_val: torch.Tensor, y_val: torch.Tensor, models, device, ensemble) -> dict:
    """Compute retrieval metrics for a model (or an ensemble) on a validation set.

    Query ``x_val[i]`` is paired with target ``y_val[i]``, so the ground-truth
    indices are ``arange(len(y_val))``.

    Args:
        x_val: caption embeddings, as a tensor or numpy array (numpy input is
            converted to a float32 tensor).
        y_val: target image embeddings aligned row-wise with ``x_val``.
        models: a model, or an iterable of models when ``ensemble`` is True.
        device: device used for the forward pass.
        ensemble: if True, average the outputs of all ``models``.
    Returns:
        The metrics dict produced by ``evaluate_retrieval``.
    """
    if isinstance(x_val, np.ndarray):
        x_val = torch.from_numpy(x_val).float()
    gt_indices = torch.arange(len(y_val))
    inputs = x_val.to(device)

    with torch.no_grad():
        if ensemble:
            preds = []
            for model in models:
                model.eval()
                preds.append(model(inputs).to('cpu'))
            translated = torch.stack(preds).mean(dim=0)
        else:
            models.eval()
            translated = models(inputs).to('cpu')

    return evaluate_retrieval(translated, y_val, gt_indices)

    
def average_checkpoints(models_path, last_n=3, map_location=None):
    """Uniformly average the state dicts of the most recent numbered checkpoints.

    Only files whose name before the first ``.`` is an integer (the
    ``<epoch>.pth`` files written by ``train_model_direction``) are considered;
    other files in the directory, e.g. ``ema_model.pth``, are ignored.

    Args:
        models_path: directory containing the checkpoints.
        last_n: number of highest-numbered checkpoints to average (>= 1).
        map_location: forwarded to ``torch.load``.
    Returns:
        dict mapping each state-dict key to the float32 average tensor.
    Raises:
        ValueError: if ``last_n < 1`` or no numbered checkpoint is found.
    """
    if last_n < 1:
        raise ValueError(f"last_n must be >= 1, got {last_n}")

    checkpoint_dir = Path(models_path)
    numbered = [name for name in os.listdir(checkpoint_dir) if name.split('.')[0].isdigit()]
    models = sorted(numbered, key=lambda name: int(name.split('.')[0]))[-last_n:]
    print(models)
    if not models:
        raise ValueError(f"No numbered checkpoints found in {checkpoint_dir}")

    avg_state = None
    for name in models:
        state = torch.load(checkpoint_dir / name, map_location=map_location)
        if avg_state is None:
            avg_state = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg_state:
                avg_state[k] += state[k].float()
    for k in avg_state:
        avg_state[k] /= len(models)
    return avg_state


In [18]:
# Run configuration.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture: caption embeddings (1024-d) -> image embeddings (1536-d).
input_dim = 1024
output_dim = 1536
hidden_layers = [1256, 1536]
dropout_rate = 0.5

# Optimisation and early stopping.
batch_size = 4096
lr = 0.01
epochs = 200
patience = 10  # epochs without validation-MRR improvement before stopping
model_path = './models/best_model.pth'  # checkpoints go to its parent directory

# Kaggle input locations.
data_path = '/kaggle/input/aml-competition/train/train/train.npz'
test_path = '/kaggle/input/aml-competition/test/test/test.clean.npz'

In [19]:
# Load the caption embeddings and, row-aligned, their paired image embeddings.
dir_x, dir_y = get_data(data_path)

# Mean L2 norm of the target (image) embeddings, printed below for reference.
target_norm_mean = dir_y.norm(dim=1).mean().item()

# Seeded 80/20 train/validation split (see get_datasets).
dir_train_dataset, dir_val_dataset = get_datasets(dir_x, dir_y)

print('Target norm mean', target_norm_mean)

Texts shape torch.Size([125000, 1024])
Images shape torch.Size([125000, 1536])
Target norm mean 25.93919563293457


In [20]:
# Per-dimension target statistics. No active loss in this notebook section uses
# them (the moment loss that did is disabled). TODO(review): confirm whether later
# cells still need them.
target_mean, target_std = compute_targets_metrics(dir_train_dataset, device)

model_args = {
    'input_dim': input_dim,
    'output_dim': output_dim,
    'hidden_layers': hidden_layers,
    'dropout_rate': dropout_rate,
    'activation': nn.GELU,
}
# Derive every path from the `model_path` config value instead of repeating literals.
models_dir = Path(model_path).parent

# Model 1: weights as they are after the final training epoch.
last_model = SpaceTranslator(**model_args).to(device)
last_model, ema_state = train_model_direction(last_model, model_path, dir_train_dataset, dir_val_dataset, batch_size, epochs, lr, patience)

print('Finished training. Now testing using best model...')

# Model 2: exponential moving average of the weights tracked during training.
ema_model = SpaceTranslator(**model_args).to(device)
sd = {k: v.to(device) for k, v in ema_state.items()}
ema_model.load_state_dict(sd)

# Model 3: uniform average of the most recent best-MRR checkpoints.
avg_model = SpaceTranslator(**model_args).to(device)
avg_state = average_checkpoints(models_dir)
avg_model.load_state_dict(avg_state)

# Evaluate the averaged-output ensemble of the three models.
models = [avg_model, last_model, ema_model]
torch.save(ema_model.state_dict(), models_dir / 'ema_model.pth')
torch.save(last_model.state_dict(), models_dir / 'last_model.pth')
torch.save(avg_model.state_dict(), models_dir / 'avg_model.pth')
results = test(dir_val_dataset, models, device, ensemble=True)
print("Test Results:", results)


Using device: cuda


                                                                         

Epoch 001 | Train Loss: 210.356959 | Val Loss: 35.736558 | LR: 1.00e-02
{'mrr': 0.34877361304261884, 'ndcg': 0.48941262328763196, 'recall_at_1': 0.18564, 'recall_at_3': 0.40172, 'recall_at_5': 0.53464, 'recall_at_10': 0.7224, 'recall_at_50': 0.98676, 'l2_dist': 25.85254669189453}
Saved new best model (mrr=0.348774)


                                                                       

Epoch 002 | Train Loss: 35.288618 | Val Loss: 24.271111 | LR: 1.00e-02
{'mrr': 0.39042217431178217, 'ndcg': 0.5242987210737555, 'recall_at_1': 0.2206, 'recall_at_3': 0.45708, 'recall_at_5': 0.5966, 'recall_at_10': 0.77388, 'recall_at_50': 0.99008, 'l2_dist': 25.857450485229492}
Saved new best model (mrr=0.390422)


                                                                       

Epoch 003 | Train Loss: 26.790894 | Val Loss: 20.305132 | LR: 1.00e-02
{'mrr': 0.4207454957178096, 'ndcg': 0.5492788624216108, 'recall_at_1': 0.24816, 'recall_at_3': 0.49536, 'recall_at_5': 0.63504, 'recall_at_10': 0.80476, 'recall_at_50': 0.99236, 'l2_dist': 25.859094619750977}
Saved new best model (mrr=0.420745)


                                                                       

Epoch 004 | Train Loss: 22.908970 | Val Loss: 17.941972 | LR: 1.00e-02
{'mrr': 0.44564420041696057, 'ndcg': 0.569706427576807, 'recall_at_1': 0.2708, 'recall_at_3': 0.53016, 'recall_at_5': 0.66552, 'recall_at_10': 0.83044, 'recall_at_50': 0.99332, 'l2_dist': 25.86042022705078}
Saved new best model (mrr=0.445644)


                                                                       

Epoch 005 | Train Loss: 20.292939 | Val Loss: 16.324618 | LR: 1.00e-02
{'mrr': 0.4680773722517407, 'ndcg': 0.5880094738570555, 'recall_at_1': 0.2916, 'recall_at_3': 0.55832, 'recall_at_5': 0.6936, 'recall_at_10': 0.84932, 'recall_at_50': 0.99444, 'l2_dist': 25.861726760864258}
Saved new best model (mrr=0.468077)


                                                                       

Epoch 006 | Train Loss: 18.563441 | Val Loss: 15.124334 | LR: 1.00e-02
{'mrr': 0.48827609607066486, 'ndcg': 0.6043349164683561, 'recall_at_1': 0.31168, 'recall_at_3': 0.58184, 'recall_at_5': 0.71684, 'recall_at_10': 0.86408, 'recall_at_50': 0.99568, 'l2_dist': 25.862689971923828}
Saved new best model (mrr=0.488276)


                                                                       

Epoch 007 | Train Loss: 17.245295 | Val Loss: 14.237216 | LR: 1.00e-02
{'mrr': 0.5067044241321723, 'ndcg': 0.6191919410211442, 'recall_at_1': 0.33024, 'recall_at_3': 0.60364, 'recall_at_5': 0.7368, 'recall_at_10': 0.8782, 'recall_at_50': 0.9962, 'l2_dist': 25.863325119018555}
Saved new best model (mrr=0.506704)


                                                                       

Epoch 008 | Train Loss: 16.202508 | Val Loss: 13.533494 | LR: 1.00e-02
{'mrr': 0.5220217209005125, 'ndcg': 0.6314892865895978, 'recall_at_1': 0.34648, 'recall_at_3': 0.62272, 'recall_at_5': 0.7544, 'recall_at_10': 0.88784, 'recall_at_50': 0.9968, 'l2_dist': 25.864139556884766}
Saved new best model (mrr=0.522022)


                                                                       

Epoch 009 | Train Loss: 15.361560 | Val Loss: 12.969336 | LR: 1.00e-02
{'mrr': 0.5363675143713611, 'ndcg': 0.6429445344192745, 'recall_at_1': 0.36172, 'recall_at_3': 0.63916, 'recall_at_5': 0.7684, 'recall_at_10': 0.89588, 'recall_at_50': 0.99704, 'l2_dist': 25.864116668701172}
Saved new best model (mrr=0.536368)


                                                                        

Epoch 010 | Train Loss: 14.674307 | Val Loss: 12.509549 | LR: 1.00e-02
{'mrr': 0.5486409233456253, 'ndcg': 0.6527671615528109, 'recall_at_1': 0.37468, 'recall_at_3': 0.65516, 'recall_at_5': 0.78076, 'recall_at_10': 0.90196, 'recall_at_50': 0.99748, 'l2_dist': 25.864364624023438}
Saved new best model (mrr=0.548641)


                                                                        

Epoch 011 | Train Loss: 14.038562 | Val Loss: 12.117681 | LR: 1.00e-02
{'mrr': 0.5608706011037617, 'ndcg': 0.6625131349398954, 'recall_at_1': 0.38752, 'recall_at_3': 0.66936, 'recall_at_5': 0.79312, 'recall_at_10': 0.90896, 'recall_at_50': 0.99764, 'l2_dist': 25.8647403717041}
Saved new best model (mrr=0.560871)


                                                                        

Epoch 012 | Train Loss: 13.572878 | Val Loss: 11.786315 | LR: 1.00e-02
{'mrr': 0.5711998690075255, 'ndcg': 0.6707708238953514, 'recall_at_1': 0.39808, 'recall_at_3': 0.68264, 'recall_at_5': 0.80356, 'recall_at_10': 0.91612, 'recall_at_50': 0.9978, 'l2_dist': 25.8648624420166}
Saved new best model (mrr=0.571200)


                                                                        

Epoch 013 | Train Loss: 13.164705 | Val Loss: 11.505875 | LR: 1.00e-02
{'mrr': 0.5818324382165228, 'ndcg': 0.6792059000463424, 'recall_at_1': 0.40932, 'recall_at_3': 0.69644, 'recall_at_5': 0.81276, 'recall_at_10': 0.92064, 'recall_at_50': 0.99792, 'l2_dist': 25.864990234375}
Saved new best model (mrr=0.581832)


                                                                        

Epoch 014 | Train Loss: 12.784986 | Val Loss: 11.259336 | LR: 1.00e-02
{'mrr': 0.5921224799341613, 'ndcg': 0.6873067348836344, 'recall_at_1': 0.42184, 'recall_at_3': 0.70572, 'recall_at_5': 0.82236, 'recall_at_10': 0.92564, 'recall_at_50': 0.99796, 'l2_dist': 25.86508560180664}
Saved new best model (mrr=0.592122)


                                                                        

Epoch 015 | Train Loss: 12.484209 | Val Loss: 11.039776 | LR: 1.00e-02
{'mrr': 0.6014904970174199, 'ndcg': 0.6947177153806606, 'recall_at_1': 0.4322, 'recall_at_3': 0.71612, 'recall_at_5': 0.8298, 'recall_at_10': 0.93044, 'recall_at_50': 0.99812, 'l2_dist': 25.865079879760742}
Saved new best model (mrr=0.601490)


                                                                        

Epoch 016 | Train Loss: 12.172440 | Val Loss: 10.848015 | LR: 1.00e-02
{'mrr': 0.6106687119355172, 'ndcg': 0.7019350030021027, 'recall_at_1': 0.4432, 'recall_at_3': 0.7256, 'recall_at_5': 0.83808, 'recall_at_10': 0.93408, 'recall_at_50': 0.99828, 'l2_dist': 25.865314483642578}
Saved new best model (mrr=0.610669)


                                                                        

Epoch 017 | Train Loss: 11.893793 | Val Loss: 10.681971 | LR: 1.00e-02
{'mrr': 0.6184543675716054, 'ndcg': 0.7080774375459754, 'recall_at_1': 0.45192, 'recall_at_3': 0.73404, 'recall_at_5': 0.8448, 'recall_at_10': 0.93728, 'recall_at_50': 0.99852, 'l2_dist': 25.865150451660156}
Saved new best model (mrr=0.618454)


                                                                        

Epoch 018 | Train Loss: 11.693225 | Val Loss: 10.530918 | LR: 1.00e-02
{'mrr': 0.6270212814334672, 'ndcg': 0.714806661141468, 'recall_at_1': 0.46212, 'recall_at_3': 0.7436, 'recall_at_5': 0.8504, 'recall_at_10': 0.94152, 'recall_at_50': 0.99852, 'l2_dist': 25.865198135375977}
Saved new best model (mrr=0.627021)


                                                                        

Epoch 019 | Train Loss: 11.503738 | Val Loss: 10.397065 | LR: 1.00e-02
{'mrr': 0.6342342408070779, 'ndcg': 0.7204851286048964, 'recall_at_1': 0.47068, 'recall_at_3': 0.7518, 'recall_at_5': 0.85628, 'recall_at_10': 0.94472, 'recall_at_50': 0.99856, 'l2_dist': 25.865339279174805}
Saved new best model (mrr=0.634234)


                                                                        

Epoch 020 | Train Loss: 11.320509 | Val Loss: 10.275093 | LR: 1.00e-02
{'mrr': 0.6415318909604202, 'ndcg': 0.7262316914396556, 'recall_at_1': 0.47884, 'recall_at_3': 0.75912, 'recall_at_5': 0.86244, 'recall_at_10': 0.94752, 'recall_at_50': 0.99856, 'l2_dist': 25.865400314331055}
Saved new best model (mrr=0.641532)


                                                                        

Epoch 021 | Train Loss: 11.136556 | Val Loss: 10.165616 | LR: 1.00e-02
{'mrr': 0.6475219729802468, 'ndcg': 0.7309425558995117, 'recall_at_1': 0.4858, 'recall_at_3': 0.7666, 'recall_at_5': 0.86692, 'recall_at_10': 0.95028, 'recall_at_50': 0.99864, 'l2_dist': 25.864776611328125}
Saved new best model (mrr=0.647522)


                                                                        

Epoch 022 | Train Loss: 10.977596 | Val Loss: 10.065384 | LR: 1.00e-02
{'mrr': 0.6553078133586042, 'ndcg': 0.7369893413380989, 'recall_at_1': 0.4958, 'recall_at_3': 0.77276, 'recall_at_5': 0.87176, 'recall_at_10': 0.95256, 'recall_at_50': 0.99876, 'l2_dist': 25.86504554748535}
Saved new best model (mrr=0.655308)


                                                                        

Epoch 023 | Train Loss: 10.847889 | Val Loss: 9.971278 | LR: 1.00e-02
{'mrr': 0.6613660579342835, 'ndcg': 0.7417658300558547, 'recall_at_1': 0.50232, 'recall_at_3': 0.78044, 'recall_at_5': 0.87716, 'recall_at_10': 0.9552, 'recall_at_50': 0.9988, 'l2_dist': 25.864723205566406}
Saved new best model (mrr=0.661366)


                                                                        

Epoch 024 | Train Loss: 10.722004 | Val Loss: 9.882383 | LR: 1.00e-02
{'mrr': 0.6677908640693652, 'ndcg': 0.7467806740657235, 'recall_at_1': 0.51044, 'recall_at_3': 0.7856, 'recall_at_5': 0.88216, 'recall_at_10': 0.9568, 'recall_at_50': 0.99892, 'l2_dist': 25.86505699157715}
Saved new best model (mrr=0.667791)


                                                                        

Epoch 025 | Train Loss: 10.616902 | Val Loss: 9.807582 | LR: 1.00e-02
{'mrr': 0.6738468768985345, 'ndcg': 0.7514809518482916, 'recall_at_1': 0.51812, 'recall_at_3': 0.79212, 'recall_at_5': 0.88552, 'recall_at_10': 0.9582, 'recall_at_50': 0.99896, 'l2_dist': 25.86491584777832}
Saved new best model (mrr=0.673847)


                                                                        

Epoch 026 | Train Loss: 10.514111 | Val Loss: 9.734742 | LR: 1.00e-02
{'mrr': 0.6803789858812286, 'ndcg': 0.7565783691914064, 'recall_at_1': 0.52604, 'recall_at_3': 0.79948, 'recall_at_5': 0.88996, 'recall_at_10': 0.96024, 'recall_at_50': 0.99904, 'l2_dist': 25.86466407775879}
Saved new best model (mrr=0.680379)


                                                                        

Epoch 027 | Train Loss: 10.403069 | Val Loss: 9.664301 | LR: 1.00e-02
{'mrr': 0.6848934200408667, 'ndcg': 0.7601279338401109, 'recall_at_1': 0.53072, 'recall_at_3': 0.80404, 'recall_at_5': 0.8938, 'recall_at_10': 0.9616, 'recall_at_50': 0.99896, 'l2_dist': 25.864599227905273}
Saved new best model (mrr=0.684893)


                                                                        

Epoch 028 | Train Loss: 10.301274 | Val Loss: 9.600293 | LR: 1.00e-02
{'mrr': 0.6910497358208493, 'ndcg': 0.7649012672159193, 'recall_at_1': 0.5386, 'recall_at_3': 0.81064, 'recall_at_5': 0.89696, 'recall_at_10': 0.96336, 'recall_at_50': 0.99904, 'l2_dist': 25.86449432373047}
Saved new best model (mrr=0.691050)


                                                                        

Epoch 029 | Train Loss: 10.212447 | Val Loss: 9.548667 | LR: 1.00e-02
{'mrr': 0.696331961432998, 'ndcg': 0.769012088418852, 'recall_at_1': 0.54516, 'recall_at_3': 0.81552, 'recall_at_5': 0.9006, 'recall_at_10': 0.96524, 'recall_at_50': 0.99904, 'l2_dist': 25.86447525024414}
Saved new best model (mrr=0.696332)


                                                                        

Epoch 030 | Train Loss: 10.139976 | Val Loss: 9.493090 | LR: 1.00e-02
{'mrr': 0.7016808874218388, 'ndcg': 0.7731652092116216, 'recall_at_1': 0.55196, 'recall_at_3': 0.82084, 'recall_at_5': 0.90428, 'recall_at_10': 0.96648, 'recall_at_50': 0.99908, 'l2_dist': 25.86417579650879}
Saved new best model (mrr=0.701681)


                                                                        

Epoch 031 | Train Loss: 10.053629 | Val Loss: 9.442781 | LR: 1.00e-02
{'mrr': 0.7061084992814336, 'ndcg': 0.7766056709439453, 'recall_at_1': 0.55736, 'recall_at_3': 0.82544, 'recall_at_5': 0.90632, 'recall_at_10': 0.96768, 'recall_at_50': 0.99908, 'l2_dist': 25.86440086364746}
Saved new best model (mrr=0.706108)


                                                                        

Epoch 032 | Train Loss: 9.999943 | Val Loss: 9.396892 | LR: 1.00e-02
{'mrr': 0.710577275134074, 'ndcg': 0.7800868647309288, 'recall_at_1': 0.56268, 'recall_at_3': 0.83024, 'recall_at_5': 0.91, 'recall_at_10': 0.96892, 'recall_at_50': 0.99908, 'l2_dist': 25.86409568786621}
Saved new best model (mrr=0.710577)


                                                                        

Epoch 033 | Train Loss: 9.925764 | Val Loss: 9.351394 | LR: 1.00e-02
{'mrr': 0.7160230155563707, 'ndcg': 0.7842881324668319, 'recall_at_1': 0.56984, 'recall_at_3': 0.83548, 'recall_at_5': 0.91316, 'recall_at_10': 0.97036, 'recall_at_50': 0.9992, 'l2_dist': 25.863929748535156}
Saved new best model (mrr=0.716023)


                                                                        

Epoch 034 | Train Loss: 9.857046 | Val Loss: 9.312234 | LR: 1.00e-02
{'mrr': 0.7197352410145846, 'ndcg': 0.7871500755160256, 'recall_at_1': 0.57484, 'recall_at_3': 0.8382, 'recall_at_5': 0.9146, 'recall_at_10': 0.9712, 'recall_at_50': 0.99924, 'l2_dist': 25.863922119140625}
Saved new best model (mrr=0.719735)


                                                                        

Epoch 035 | Train Loss: 9.809270 | Val Loss: 9.272635 | LR: 1.00e-02
{'mrr': 0.7250683693141633, 'ndcg': 0.7912782048554682, 'recall_at_1': 0.5814, 'recall_at_3': 0.84304, 'recall_at_5': 0.9178, 'recall_at_10': 0.97216, 'recall_at_50': 0.99924, 'l2_dist': 25.86406707763672}
Saved new best model (mrr=0.725068)


                                                                        

Epoch 036 | Train Loss: 9.734685 | Val Loss: 9.236198 | LR: 1.00e-02
{'mrr': 0.7294214466916015, 'ndcg': 0.7946476226335751, 'recall_at_1': 0.58704, 'recall_at_3': 0.84768, 'recall_at_5': 0.92032, 'recall_at_10': 0.97236, 'recall_at_50': 0.99924, 'l2_dist': 25.864089965820312}
Saved new best model (mrr=0.729421)


                                                                        

Epoch 037 | Train Loss: 9.691483 | Val Loss: 9.206135 | LR: 1.00e-02
{'mrr': 0.7336941487000487, 'ndcg': 0.797933255565202, 'recall_at_1': 0.59268, 'recall_at_3': 0.8508, 'recall_at_5': 0.9224, 'recall_at_10': 0.97376, 'recall_at_50': 0.99928, 'l2_dist': 25.8636417388916}
Saved new best model (mrr=0.733694)


                                                                        

Epoch 038 | Train Loss: 9.642693 | Val Loss: 9.171164 | LR: 1.00e-02
{'mrr': 0.7375329366780221, 'ndcg': 0.8009354921923227, 'recall_at_1': 0.59684, 'recall_at_3': 0.8554, 'recall_at_5': 0.92532, 'recall_at_10': 0.97472, 'recall_at_50': 0.99924, 'l2_dist': 25.864004135131836}
Saved new best model (mrr=0.737533)


                                                                        

Epoch 039 | Train Loss: 9.588188 | Val Loss: 9.139769 | LR: 1.00e-02
{'mrr': 0.7419958088720762, 'ndcg': 0.8043539647856982, 'recall_at_1': 0.60352, 'recall_at_3': 0.85776, 'recall_at_5': 0.92752, 'recall_at_10': 0.9762, 'recall_at_50': 0.99932, 'l2_dist': 25.863847732543945}
Saved new best model (mrr=0.741996)


                                                                        

Epoch 040 | Train Loss: 9.558002 | Val Loss: 9.114191 | LR: 1.00e-02
{'mrr': 0.7451040692043837, 'ndcg': 0.8067514510240792, 'recall_at_1': 0.60752, 'recall_at_3': 0.86104, 'recall_at_5': 0.92916, 'recall_at_10': 0.97604, 'recall_at_50': 0.99932, 'l2_dist': 25.863685607910156}
Saved new best model (mrr=0.745104)


                                                                        

Epoch 041 | Train Loss: 9.504546 | Val Loss: 9.088155 | LR: 1.00e-02
{'mrr': 0.7496327719070607, 'ndcg': 0.8102235425867619, 'recall_at_1': 0.61396, 'recall_at_3': 0.86476, 'recall_at_5': 0.93188, 'recall_at_10': 0.97724, 'recall_at_50': 0.9994, 'l2_dist': 25.863616943359375}
Saved new best model (mrr=0.749633)


                                                                        

Epoch 042 | Train Loss: 9.467524 | Val Loss: 9.059880 | LR: 1.00e-02
{'mrr': 0.7527260498828706, 'ndcg': 0.8126223165448918, 'recall_at_1': 0.61752, 'recall_at_3': 0.86776, 'recall_at_5': 0.93364, 'recall_at_10': 0.97788, 'recall_at_50': 0.99932, 'l2_dist': 25.86366844177246}
Saved new best model (mrr=0.752726)


                                                                        

Epoch 043 | Train Loss: 9.432654 | Val Loss: 9.036482 | LR: 1.00e-02
{'mrr': 0.7556678466709184, 'ndcg': 0.8149038664694828, 'recall_at_1': 0.62132, 'recall_at_3': 0.87148, 'recall_at_5': 0.93524, 'recall_at_10': 0.97792, 'recall_at_50': 0.99928, 'l2_dist': 25.863605499267578}
Saved new best model (mrr=0.755668)


                                                                        

Epoch 044 | Train Loss: 9.395155 | Val Loss: 9.019772 | LR: 1.00e-02
{'mrr': 0.7592838355625994, 'ndcg': 0.8176824821346259, 'recall_at_1': 0.6258, 'recall_at_3': 0.87452, 'recall_at_5': 0.93688, 'recall_at_10': 0.9788, 'recall_at_50': 0.99936, 'l2_dist': 25.863439559936523}
Saved new best model (mrr=0.759284)


                                                                        

Epoch 045 | Train Loss: 9.360313 | Val Loss: 8.995063 | LR: 1.00e-02
{'mrr': 0.7633908272690991, 'ndcg': 0.8208377857003949, 'recall_at_1': 0.63136, 'recall_at_3': 0.87796, 'recall_at_5': 0.93824, 'recall_at_10': 0.9794, 'recall_at_50': 0.99932, 'l2_dist': 25.863285064697266}
Saved new best model (mrr=0.763391)


                                                                        

Epoch 046 | Train Loss: 9.327949 | Val Loss: 8.976937 | LR: 1.00e-02
{'mrr': 0.7668092652867095, 'ndcg': 0.8234611941344092, 'recall_at_1': 0.63612, 'recall_at_3': 0.88044, 'recall_at_5': 0.94092, 'recall_at_10': 0.9802, 'recall_at_50': 0.99932, 'l2_dist': 25.863475799560547}
Saved new best model (mrr=0.766809)


                                                                        

Epoch 047 | Train Loss: 9.297588 | Val Loss: 8.956117 | LR: 1.00e-02
{'mrr': 0.7693379580982384, 'ndcg': 0.8253811299327651, 'recall_at_1': 0.63984, 'recall_at_3': 0.88236, 'recall_at_5': 0.94088, 'recall_at_10': 0.98028, 'recall_at_50': 0.9994, 'l2_dist': 25.863649368286133}
Saved new best model (mrr=0.769338)


                                                                        

Epoch 048 | Train Loss: 9.270127 | Val Loss: 8.936843 | LR: 1.00e-02
{'mrr': 0.7728167537461033, 'ndcg': 0.8280557537511742, 'recall_at_1': 0.64464, 'recall_at_3': 0.88416, 'recall_at_5': 0.94296, 'recall_at_10': 0.98028, 'recall_at_50': 0.99936, 'l2_dist': 25.863662719726562}
Saved new best model (mrr=0.772817)


                                                                        

Epoch 049 | Train Loss: 9.241386 | Val Loss: 8.916855 | LR: 1.00e-02
{'mrr': 0.7758653725429038, 'ndcg': 0.8303903428074345, 'recall_at_1': 0.64884, 'recall_at_3': 0.88704, 'recall_at_5': 0.94432, 'recall_at_10': 0.9812, 'recall_at_50': 0.99936, 'l2_dist': 25.863616943359375}
Saved new best model (mrr=0.775865)


                                                                        

Epoch 050 | Train Loss: 9.210445 | Val Loss: 8.898814 | LR: 1.00e-02
{'mrr': 0.7800357719133535, 'ndcg': 0.8335622874277105, 'recall_at_1': 0.65476, 'recall_at_3': 0.88936, 'recall_at_5': 0.94576, 'recall_at_10': 0.98112, 'recall_at_50': 0.99932, 'l2_dist': 25.86370086669922}
Saved new best model (mrr=0.780036)


                                                                        

Epoch 051 | Train Loss: 9.186489 | Val Loss: 8.883049 | LR: 1.00e-02
{'mrr': 0.7827055493425288, 'ndcg': 0.8356284394974318, 'recall_at_1': 0.65808, 'recall_at_3': 0.89224, 'recall_at_5': 0.94788, 'recall_at_10': 0.98172, 'recall_at_50': 0.99936, 'l2_dist': 25.86377716064453}
Saved new best model (mrr=0.782706)


                                                                        

Epoch 052 | Train Loss: 9.165732 | Val Loss: 8.868834 | LR: 1.00e-02
{'mrr': 0.7845697297752329, 'ndcg': 0.8370520190617405, 'recall_at_1': 0.66096, 'recall_at_3': 0.8934, 'recall_at_5': 0.94852, 'recall_at_10': 0.98184, 'recall_at_50': 0.9994, 'l2_dist': 25.86334991455078}
Saved new best model (mrr=0.784570)


                                                                        

Epoch 053 | Train Loss: 9.139190 | Val Loss: 8.850796 | LR: 1.00e-02
{'mrr': 0.7876262479758819, 'ndcg': 0.8394087462144358, 'recall_at_1': 0.66464, 'recall_at_3': 0.89644, 'recall_at_5': 0.9496, 'recall_at_10': 0.98224, 'recall_at_50': 0.9994, 'l2_dist': 25.86369514465332}
Saved new best model (mrr=0.787626)


                                                                        

Epoch 054 | Train Loss: 9.117922 | Val Loss: 8.834616 | LR: 1.00e-02
{'mrr': 0.7897901057084208, 'ndcg': 0.8410599928759764, 'recall_at_1': 0.66808, 'recall_at_3': 0.89768, 'recall_at_5': 0.95048, 'recall_at_10': 0.98296, 'recall_at_50': 0.9994, 'l2_dist': 25.8636474609375}
Saved new best model (mrr=0.789790)


                                                                        

Epoch 055 | Train Loss: 9.097168 | Val Loss: 8.826574 | LR: 1.00e-02
{'mrr': 0.7933750839370854, 'ndcg': 0.8438051571659173, 'recall_at_1': 0.67268, 'recall_at_3': 0.90016, 'recall_at_5': 0.9516, 'recall_at_10': 0.98324, 'recall_at_50': 0.99944, 'l2_dist': 25.86373519897461}
Saved new best model (mrr=0.793375)


                                                                        

Epoch 056 | Train Loss: 9.066976 | Val Loss: 8.811369 | LR: 1.00e-02
{'mrr': 0.795422717748469, 'ndcg': 0.8453629765820695, 'recall_at_1': 0.67604, 'recall_at_3': 0.90184, 'recall_at_5': 0.95268, 'recall_at_10': 0.9834, 'recall_at_50': 0.99948, 'l2_dist': 25.86403465270996}
Saved new best model (mrr=0.795423)


                                                                        

Epoch 057 | Train Loss: 9.051930 | Val Loss: 8.800178 | LR: 1.00e-02
{'mrr': 0.7984553671837755, 'ndcg': 0.8476828810745082, 'recall_at_1': 0.68028, 'recall_at_3': 0.90328, 'recall_at_5': 0.95344, 'recall_at_10': 0.98404, 'recall_at_50': 0.9994, 'l2_dist': 25.863689422607422}
Saved new best model (mrr=0.798455)


                                                                        

Epoch 058 | Train Loss: 9.030934 | Val Loss: 8.787858 | LR: 1.00e-02
{'mrr': 0.8012176452087841, 'ndcg': 0.8497889629454042, 'recall_at_1': 0.684, 'recall_at_3': 0.90608, 'recall_at_5': 0.95436, 'recall_at_10': 0.98388, 'recall_at_50': 0.99936, 'l2_dist': 25.863359451293945}
Saved new best model (mrr=0.801218)


                                                                        

Epoch 059 | Train Loss: 9.008796 | Val Loss: 8.773514 | LR: 1.00e-02
{'mrr': 0.8036445598504766, 'ndcg': 0.8516495233235416, 'recall_at_1': 0.68752, 'recall_at_3': 0.90864, 'recall_at_5': 0.955, 'recall_at_10': 0.98444, 'recall_at_50': 0.99944, 'l2_dist': 25.863712310791016}
Saved new best model (mrr=0.803645)


                                                                        

Epoch 060 | Train Loss: 8.992390 | Val Loss: 8.762922 | LR: 1.00e-02
{'mrr': 0.8054698923992343, 'ndcg': 0.8530417966630865, 'recall_at_1': 0.69004, 'recall_at_3': 0.90956, 'recall_at_5': 0.95604, 'recall_at_10': 0.9846, 'recall_at_50': 0.99936, 'l2_dist': 25.863800048828125}
Saved new best model (mrr=0.805470)


                                                                        

Epoch 061 | Train Loss: 8.975786 | Val Loss: 8.753201 | LR: 1.00e-02
{'mrr': 0.8081989607026463, 'ndcg': 0.8551218317547551, 'recall_at_1': 0.69412, 'recall_at_3': 0.91156, 'recall_at_5': 0.95696, 'recall_at_10': 0.98472, 'recall_at_50': 0.99944, 'l2_dist': 25.863555908203125}
Saved new best model (mrr=0.808199)


                                                                        

Epoch 062 | Train Loss: 8.958698 | Val Loss: 8.742177 | LR: 1.00e-02
{'mrr': 0.8110097402391434, 'ndcg': 0.8572687392859679, 'recall_at_1': 0.69792, 'recall_at_3': 0.91316, 'recall_at_5': 0.95816, 'recall_at_10': 0.98504, 'recall_at_50': 0.99944, 'l2_dist': 25.863962173461914}
Saved new best model (mrr=0.811010)


                                                                        

Epoch 063 | Train Loss: 8.945007 | Val Loss: 8.729631 | LR: 1.00e-02
{'mrr': 0.8134996706486883, 'ndcg': 0.8591756806700174, 'recall_at_1': 0.70136, 'recall_at_3': 0.91528, 'recall_at_5': 0.95908, 'recall_at_10': 0.98524, 'recall_at_50': 0.9994, 'l2_dist': 25.86406707763672}
Saved new best model (mrr=0.813500)


                                                                        

Epoch 064 | Train Loss: 8.931388 | Val Loss: 8.723962 | LR: 1.00e-02
{'mrr': 0.8150324488334691, 'ndcg': 0.8603490272811962, 'recall_at_1': 0.70336, 'recall_at_3': 0.9174, 'recall_at_5': 0.95864, 'recall_at_10': 0.98584, 'recall_at_50': 0.99944, 'l2_dist': 25.864227294921875}
Saved new best model (mrr=0.815032)


                                                                        

Epoch 065 | Train Loss: 8.911401 | Val Loss: 8.714994 | LR: 1.00e-02
{'mrr': 0.8175928550216663, 'ndcg': 0.8622834904302139, 'recall_at_1': 0.70756, 'recall_at_3': 0.91788, 'recall_at_5': 0.95996, 'recall_at_10': 0.98592, 'recall_at_50': 0.99944, 'l2_dist': 25.86419677734375}
Saved new best model (mrr=0.817593)


                                                                        

Epoch 066 | Train Loss: 8.896591 | Val Loss: 8.704339 | LR: 1.00e-02
{'mrr': 0.8203357628661024, 'ndcg': 0.8643581841713707, 'recall_at_1': 0.71152, 'recall_at_3': 0.91904, 'recall_at_5': 0.96024, 'recall_at_10': 0.98596, 'recall_at_50': 0.99948, 'l2_dist': 25.864215850830078}
Saved new best model (mrr=0.820336)


                                                                        

Epoch 067 | Train Loss: 8.877470 | Val Loss: 8.694634 | LR: 1.00e-02
{'mrr': 0.8219581168568664, 'ndcg': 0.8656096802887264, 'recall_at_1': 0.7136, 'recall_at_3': 0.92124, 'recall_at_5': 0.9604, 'recall_at_10': 0.98616, 'recall_at_50': 0.99944, 'l2_dist': 25.864017486572266}
Saved new best model (mrr=0.821958)


                                                                        

Epoch 068 | Train Loss: 8.861064 | Val Loss: 8.681238 | LR: 1.00e-02
{'mrr': 0.8247619734230455, 'ndcg': 0.867742476922289, 'recall_at_1': 0.71784, 'recall_at_3': 0.92284, 'recall_at_5': 0.96192, 'recall_at_10': 0.98624, 'recall_at_50': 0.99948, 'l2_dist': 25.864234924316406}
Saved new best model (mrr=0.824762)


                                                                        

Epoch 069 | Train Loss: 8.850726 | Val Loss: 8.677458 | LR: 1.00e-02
{'mrr': 0.8261404539531323, 'ndcg': 0.8688053391910471, 'recall_at_1': 0.71952, 'recall_at_3': 0.92436, 'recall_at_5': 0.96184, 'recall_at_10': 0.98636, 'recall_at_50': 0.99948, 'l2_dist': 25.86395835876465}
Saved new best model (mrr=0.826140)


                                                                        

Epoch 070 | Train Loss: 8.841624 | Val Loss: 8.670968 | LR: 1.00e-02
{'mrr': 0.8278596826675609, 'ndcg': 0.8701035946033379, 'recall_at_1': 0.72232, 'recall_at_3': 0.92464, 'recall_at_5': 0.96296, 'recall_at_10': 0.98684, 'recall_at_50': 0.99948, 'l2_dist': 25.864185333251953}
Saved new best model (mrr=0.827860)


                                                                        

Epoch 071 | Train Loss: 8.821863 | Val Loss: 8.661323 | LR: 1.00e-02
{'mrr': 0.8301225611336894, 'ndcg': 0.8718075634272423, 'recall_at_1': 0.72608, 'recall_at_3': 0.92648, 'recall_at_5': 0.96312, 'recall_at_10': 0.98688, 'recall_at_50': 0.99948, 'l2_dist': 25.863927841186523}
Saved new best model (mrr=0.830123)


                                                                        

Epoch 072 | Train Loss: 8.815411 | Val Loss: 8.655657 | LR: 1.00e-02
{'mrr': 0.83235511640557, 'ndcg': 0.8735006338009601, 'recall_at_1': 0.72936, 'recall_at_3': 0.9272, 'recall_at_5': 0.96364, 'recall_at_10': 0.98696, 'recall_at_50': 0.99948, 'l2_dist': 25.86426544189453}
Saved new best model (mrr=0.832355)


                                                                        

Epoch 073 | Train Loss: 8.798954 | Val Loss: 8.648046 | LR: 1.00e-02
{'mrr': 0.8331536089656633, 'ndcg': 0.8741180504661344, 'recall_at_1': 0.7306, 'recall_at_3': 0.92852, 'recall_at_5': 0.96452, 'recall_at_10': 0.98724, 'recall_at_50': 0.99948, 'l2_dist': 25.864524841308594}
Saved new best model (mrr=0.833154)


                                                                        

Epoch 074 | Train Loss: 8.790113 | Val Loss: 8.638907 | LR: 1.00e-02
{'mrr': 0.8365892417777736, 'ndcg': 0.8767195603634117, 'recall_at_1': 0.73588, 'recall_at_3': 0.93008, 'recall_at_5': 0.96564, 'recall_at_10': 0.98716, 'recall_at_50': 0.99948, 'l2_dist': 25.86450958251953}
Saved new best model (mrr=0.836589)


                                                                        

Epoch 075 | Train Loss: 8.775919 | Val Loss: 8.634699 | LR: 1.00e-02
{'mrr': 0.837945625873825, 'ndcg': 0.8777514500102424, 'recall_at_1': 0.73764, 'recall_at_3': 0.93036, 'recall_at_5': 0.9658, 'recall_at_10': 0.9872, 'recall_at_50': 0.99948, 'l2_dist': 25.86415672302246}
Saved new best model (mrr=0.837946)


                                                                        

Epoch 076 | Train Loss: 8.766360 | Val Loss: 8.629706 | LR: 1.00e-02
{'mrr': 0.8399766476785439, 'ndcg': 0.8792782854120753, 'recall_at_1': 0.74104, 'recall_at_3': 0.93132, 'recall_at_5': 0.96608, 'recall_at_10': 0.9876, 'recall_at_50': 0.99948, 'l2_dist': 25.86467170715332}
Saved new best model (mrr=0.839977)


                                                                        

Epoch 077 | Train Loss: 8.749940 | Val Loss: 8.622226 | LR: 1.00e-02
{'mrr': 0.8410447067036843, 'ndcg': 0.8801008125041856, 'recall_at_1': 0.74268, 'recall_at_3': 0.93272, 'recall_at_5': 0.96664, 'recall_at_10': 0.98716, 'recall_at_50': 0.99948, 'l2_dist': 25.864748001098633}
Saved new best model (mrr=0.841045)


                                                                        

Epoch 078 | Train Loss: 8.731741 | Val Loss: 8.621512 | LR: 1.00e-02
{'mrr': 0.8428517810595006, 'ndcg': 0.8814663461292962, 'recall_at_1': 0.74532, 'recall_at_3': 0.93304, 'recall_at_5': 0.96684, 'recall_at_10': 0.98768, 'recall_at_50': 0.99948, 'l2_dist': 25.864646911621094}
Saved new best model (mrr=0.842852)


                                                                        

Epoch 079 | Train Loss: 8.730828 | Val Loss: 8.610981 | LR: 1.00e-02
{'mrr': 0.8443322712201379, 'ndcg': 0.8826075330512935, 'recall_at_1': 0.7472, 'recall_at_3': 0.93488, 'recall_at_5': 0.96744, 'recall_at_10': 0.98772, 'recall_at_50': 0.99948, 'l2_dist': 25.864959716796875}
Saved new best model (mrr=0.844332)


                                                                        

Epoch 080 | Train Loss: 8.718339 | Val Loss: 8.605654 | LR: 1.00e-02
{'mrr': 0.8463453326900547, 'ndcg': 0.8841179472418413, 'recall_at_1': 0.75076, 'recall_at_3': 0.9348, 'recall_at_5': 0.968, 'recall_at_10': 0.98788, 'recall_at_50': 0.99952, 'l2_dist': 25.864625930786133}
Saved new best model (mrr=0.846345)


                                                                        

Epoch 081 | Train Loss: 8.707945 | Val Loss: 8.602138 | LR: 1.00e-02
{'mrr': 0.8487402202385655, 'ndcg': 0.8859166923067121, 'recall_at_1': 0.7548, 'recall_at_3': 0.93564, 'recall_at_5': 0.96832, 'recall_at_10': 0.98836, 'recall_at_50': 0.99944, 'l2_dist': 25.864765167236328}
Saved new best model (mrr=0.848740)


                                                                        

Epoch 082 | Train Loss: 8.694968 | Val Loss: 8.597170 | LR: 1.00e-02
{'mrr': 0.8495444197844713, 'ndcg': 0.8865560251232001, 'recall_at_1': 0.75504, 'recall_at_3': 0.9374, 'recall_at_5': 0.96836, 'recall_at_10': 0.98808, 'recall_at_50': 0.99944, 'l2_dist': 25.864744186401367}
Saved new best model (mrr=0.849544)


                                                                        

Epoch 083 | Train Loss: 8.687545 | Val Loss: 8.591729 | LR: 1.00e-02
{'mrr': 0.8518434911695653, 'ndcg': 0.8882770556858044, 'recall_at_1': 0.75904, 'recall_at_3': 0.93788, 'recall_at_5': 0.969, 'recall_at_10': 0.98828, 'recall_at_50': 0.99948, 'l2_dist': 25.864770889282227}
Saved new best model (mrr=0.851843)


                                                                        

Epoch 084 | Train Loss: 8.672701 | Val Loss: 8.589050 | LR: 1.00e-02
{'mrr': 0.852907243564596, 'ndcg': 0.8890892993771442, 'recall_at_1': 0.76056, 'recall_at_3': 0.9388, 'recall_at_5': 0.96924, 'recall_at_10': 0.98868, 'recall_at_50': 0.99948, 'l2_dist': 25.864988327026367}
Saved new best model (mrr=0.852907)


                                                                        

Epoch 085 | Train Loss: 8.665377 | Val Loss: 8.584316 | LR: 1.00e-02
{'mrr': 0.8553793260614839, 'ndcg': 0.8909672825028212, 'recall_at_1': 0.76444, 'recall_at_3': 0.94032, 'recall_at_5': 0.96972, 'recall_at_10': 0.98884, 'recall_at_50': 0.99948, 'l2_dist': 25.86471176147461}
Saved new best model (mrr=0.855379)


                                                                        

Epoch 086 | Train Loss: 8.656486 | Val Loss: 8.582368 | LR: 1.00e-02
{'mrr': 0.8571201884891425, 'ndcg': 0.8922620910763276, 'recall_at_1': 0.76788, 'recall_at_3': 0.94084, 'recall_at_5': 0.97048, 'recall_at_10': 0.98884, 'recall_at_50': 0.99948, 'l2_dist': 25.8646297454834}
Saved new best model (mrr=0.857120)


                                                                        

Epoch 087 | Train Loss: 8.647351 | Val Loss: 8.576293 | LR: 1.00e-02
{'mrr': 0.8584113174541191, 'ndcg': 0.8932426952540287, 'recall_at_1': 0.76968, 'recall_at_3': 0.94156, 'recall_at_5': 0.9706, 'recall_at_10': 0.98912, 'recall_at_50': 0.99948, 'l2_dist': 25.864635467529297}
Saved new best model (mrr=0.858411)


                                                                        

Epoch 088 | Train Loss: 8.634130 | Val Loss: 8.574218 | LR: 1.00e-02
{'mrr': 0.8595153535275977, 'ndcg': 0.8940813708266678, 'recall_at_1': 0.7712, 'recall_at_3': 0.94212, 'recall_at_5': 0.97068, 'recall_at_10': 0.98924, 'recall_at_50': 0.99948, 'l2_dist': 25.864702224731445}
Saved new best model (mrr=0.859515)


                                                                        

Epoch 089 | Train Loss: 8.623721 | Val Loss: 8.567663 | LR: 1.00e-02
{'mrr': 0.8611775209435643, 'ndcg': 0.8953376356656629, 'recall_at_1': 0.77372, 'recall_at_3': 0.94308, 'recall_at_5': 0.97068, 'recall_at_10': 0.98936, 'recall_at_50': 0.99944, 'l2_dist': 25.864965438842773}
Saved new best model (mrr=0.861178)


                                                                        

Epoch 090 | Train Loss: 8.615987 | Val Loss: 8.563215 | LR: 1.00e-02
{'mrr': 0.8621099286227674, 'ndcg': 0.896067497181653, 'recall_at_1': 0.77476, 'recall_at_3': 0.94444, 'recall_at_5': 0.97172, 'recall_at_10': 0.98964, 'recall_at_50': 0.99948, 'l2_dist': 25.86488914489746}
Saved new best model (mrr=0.862110)


                                                                        

Epoch 091 | Train Loss: 8.608321 | Val Loss: 8.560117 | LR: 1.00e-02
{'mrr': 0.8635725904815222, 'ndcg': 0.897171189600055, 'recall_at_1': 0.77704, 'recall_at_3': 0.9444, 'recall_at_5': 0.97192, 'recall_at_10': 0.98968, 'recall_at_50': 0.99944, 'l2_dist': 25.86485481262207}
Saved new best model (mrr=0.863573)


                                                                        

Epoch 092 | Train Loss: 8.594999 | Val Loss: 8.559254 | LR: 1.00e-02
{'mrr': 0.8650389196671997, 'ndcg': 0.8982825428026626, 'recall_at_1': 0.77928, 'recall_at_3': 0.9452, 'recall_at_5': 0.97244, 'recall_at_10': 0.98976, 'recall_at_50': 0.99944, 'l2_dist': 25.864843368530273}
Saved new best model (mrr=0.865039)


                                                                        

Epoch 093 | Train Loss: 8.586975 | Val Loss: 8.559565 | LR: 1.00e-02
{'mrr': 0.8658863112689891, 'ndcg': 0.89893485364862, 'recall_at_1': 0.78052, 'recall_at_3': 0.94652, 'recall_at_5': 0.9728, 'recall_at_10': 0.99032, 'recall_at_50': 0.99944, 'l2_dist': 25.864748001098633}
Saved new best model (mrr=0.865886)


                                                                        

Epoch 094 | Train Loss: 8.578087 | Val Loss: 8.553418 | LR: 1.00e-02
{'mrr': 0.8677402273435489, 'ndcg': 0.9003209189876024, 'recall_at_1': 0.78368, 'recall_at_3': 0.94664, 'recall_at_5': 0.97256, 'recall_at_10': 0.99028, 'recall_at_50': 0.99956, 'l2_dist': 25.864994049072266}
Saved new best model (mrr=0.867740)


                                                                        

Epoch 095 | Train Loss: 8.568571 | Val Loss: 8.552704 | LR: 1.00e-02
{'mrr': 0.8689889570037173, 'ndcg': 0.901252090204635, 'recall_at_1': 0.78576, 'recall_at_3': 0.94716, 'recall_at_5': 0.97304, 'recall_at_10': 0.99012, 'recall_at_50': 0.99952, 'l2_dist': 25.864534378051758}
Saved new best model (mrr=0.868989)


                                                                        

Epoch 096 | Train Loss: 8.560708 | Val Loss: 8.545038 | LR: 1.00e-02
{'mrr': 0.869966284529269, 'ndcg': 0.9020106594868766, 'recall_at_1': 0.78684, 'recall_at_3': 0.94808, 'recall_at_5': 0.97368, 'recall_at_10': 0.99048, 'recall_at_50': 0.99948, 'l2_dist': 25.864748001098633}
Saved new best model (mrr=0.869966)


                                                                        

Epoch 097 | Train Loss: 8.546208 | Val Loss: 8.548457 | LR: 1.00e-02
{'mrr': 0.8705028695100112, 'ndcg': 0.902424110948287, 'recall_at_1': 0.7878, 'recall_at_3': 0.94816, 'recall_at_5': 0.97412, 'recall_at_10': 0.99052, 'recall_at_50': 0.99948, 'l2_dist': 25.864498138427734}
Saved new best model (mrr=0.870503)


                                                                        

Epoch 098 | Train Loss: 8.540303 | Val Loss: 8.546576 | LR: 1.00e-02
{'mrr': 0.8724688154927315, 'ndcg': 0.9039054737815042, 'recall_at_1': 0.79088, 'recall_at_3': 0.94904, 'recall_at_5': 0.97404, 'recall_at_10': 0.99076, 'recall_at_50': 0.99952, 'l2_dist': 25.864500045776367}
Saved new best model (mrr=0.872469)


                                                                        

Epoch 099 | Train Loss: 8.528478 | Val Loss: 8.547786 | LR: 1.00e-02
{'mrr': 0.8734213847021621, 'ndcg': 0.9046374922847354, 'recall_at_1': 0.79216, 'recall_at_3': 0.95016, 'recall_at_5': 0.97452, 'recall_at_10': 0.99084, 'recall_at_50': 0.99952, 'l2_dist': 25.864622116088867}
Saved new best model (mrr=0.873421)


                                                                         

Epoch 100 | Train Loss: 8.515743 | Val Loss: 8.546721 | LR: 1.00e-02
{'mrr': 0.8748704446283402, 'ndcg': 0.9057168315898575, 'recall_at_1': 0.7946, 'recall_at_3': 0.95052, 'recall_at_5': 0.9744, 'recall_at_10': 0.99076, 'recall_at_50': 0.99952, 'l2_dist': 25.86458969116211}
Saved new best model (mrr=0.874870)


                                                                         

Epoch 101 | Train Loss: 8.511141 | Val Loss: 8.551028 | LR: 1.00e-02
{'mrr': 0.8762466979973452, 'ndcg': 0.9067555494707903, 'recall_at_1': 0.79692, 'recall_at_3': 0.95068, 'recall_at_5': 0.97468, 'recall_at_10': 0.99076, 'recall_at_50': 0.9994, 'l2_dist': 25.86429214477539}
Saved new best model (mrr=0.876247)


                                                                         

Epoch 102 | Train Loss: 8.497430 | Val Loss: 8.546428 | LR: 1.00e-02
{'mrr': 0.8775003372386349, 'ndcg': 0.9076966520977922, 'recall_at_1': 0.799, 'recall_at_3': 0.95112, 'recall_at_5': 0.9754, 'recall_at_10': 0.99104, 'recall_at_50': 0.9994, 'l2_dist': 25.864429473876953}
Saved new best model (mrr=0.877500)


                                                                         

Epoch 103 | Train Loss: 8.484945 | Val Loss: 8.550936 | LR: 1.00e-02
{'mrr': 0.8790220052257282, 'ndcg': 0.9088590342638435, 'recall_at_1': 0.80124, 'recall_at_3': 0.95236, 'recall_at_5': 0.97556, 'recall_at_10': 0.99128, 'recall_at_50': 0.99936, 'l2_dist': 25.864185333251953}
Saved new best model (mrr=0.879022)


                                                                         

Epoch 104 | Train Loss: 8.478484 | Val Loss: 8.554836 | LR: 1.00e-02
{'mrr': 0.8794563746564507, 'ndcg': 0.9091903245227217, 'recall_at_1': 0.80184, 'recall_at_3': 0.95264, 'recall_at_5': 0.97592, 'recall_at_10': 0.99144, 'recall_at_50': 0.9994, 'l2_dist': 25.86420440673828}
Saved new best model (mrr=0.879456)


                                                                         

Epoch 105 | Train Loss: 8.471068 | Val Loss: 8.546347 | LR: 1.00e-02
{'mrr': 0.881127026477069, 'ndcg': 0.9104392268753718, 'recall_at_1': 0.80484, 'recall_at_3': 0.95236, 'recall_at_5': 0.97628, 'recall_at_10': 0.99172, 'recall_at_50': 0.99952, 'l2_dist': 25.86420440673828}
Saved new best model (mrr=0.881127)


                                                                         

Epoch 106 | Train Loss: 8.456992 | Val Loss: 8.546707 | LR: 1.00e-02
{'mrr': 0.8822545253862849, 'ndcg': 0.9112906704265155, 'recall_at_1': 0.8066, 'recall_at_3': 0.95292, 'recall_at_5': 0.97624, 'recall_at_10': 0.9916, 'recall_at_50': 0.99944, 'l2_dist': 25.864870071411133}
Saved new best model (mrr=0.882255)


                                                                         

Epoch 107 | Train Loss: 8.448037 | Val Loss: 8.548946 | LR: 1.00e-02
{'mrr': 0.8825030691216542, 'ndcg': 0.9114919589233167, 'recall_at_1': 0.8066, 'recall_at_3': 0.95368, 'recall_at_5': 0.9762, 'recall_at_10': 0.9916, 'recall_at_50': 0.99948, 'l2_dist': 25.864839553833008}
Saved new best model (mrr=0.882503)


                                                                         

Epoch 108 | Train Loss: 8.435445 | Val Loss: 8.546677 | LR: 1.00e-02
{'mrr': 0.88353894153147, 'ndcg': 0.9122805801001492, 'recall_at_1': 0.80816, 'recall_at_3': 0.95464, 'recall_at_5': 0.97668, 'recall_at_10': 0.99168, 'recall_at_50': 0.99944, 'l2_dist': 25.86404800415039}
Saved new best model (mrr=0.883539)


                                                                         

Epoch 109 | Train Loss: 8.428571 | Val Loss: 8.546384 | LR: 1.00e-02
{'mrr': 0.8848755256220445, 'ndcg': 0.9132770433717151, 'recall_at_1': 0.8106, 'recall_at_3': 0.9546, 'recall_at_5': 0.97644, 'recall_at_10': 0.99192, 'recall_at_50': 0.99948, 'l2_dist': 25.865020751953125}
Saved new best model (mrr=0.884876)


                                                                         

Epoch 110 | Train Loss: 8.414532 | Val Loss: 8.548757 | LR: 1.00e-02
{'mrr': 0.8853596674037084, 'ndcg': 0.913656290662212, 'recall_at_1': 0.81124, 'recall_at_3': 0.95452, 'recall_at_5': 0.977, 'recall_at_10': 0.99204, 'recall_at_50': 0.99948, 'l2_dist': 25.864124298095703}
Saved new best model (mrr=0.885360)


                                                                         

Epoch 111 | Train Loss: 8.406815 | Val Loss: 8.549977 | LR: 1.00e-02
{'mrr': 0.8868853371200756, 'ndcg': 0.914797884404584, 'recall_at_1': 0.81384, 'recall_at_3': 0.95536, 'recall_at_5': 0.97748, 'recall_at_10': 0.9922, 'recall_at_50': 0.99952, 'l2_dist': 25.865360260009766}
Saved new best model (mrr=0.886885)


                                                                         

Epoch 112 | Train Loss: 8.395442 | Val Loss: 8.549442 | LR: 1.00e-02
{'mrr': 0.8877660929485381, 'ndcg': 0.9154655765048577, 'recall_at_1': 0.81512, 'recall_at_3': 0.95612, 'recall_at_5': 0.9776, 'recall_at_10': 0.99232, 'recall_at_50': 0.99944, 'l2_dist': 25.864370346069336}
Saved new best model (mrr=0.887766)


                                                                         

Epoch 113 | Train Loss: 8.381469 | Val Loss: 8.554197 | LR: 1.00e-02
{'mrr': 0.8887156867735224, 'ndcg': 0.9161744652798747, 'recall_at_1': 0.81672, 'recall_at_3': 0.95616, 'recall_at_5': 0.97756, 'recall_at_10': 0.99196, 'recall_at_50': 0.99944, 'l2_dist': 25.864017486572266}
Saved new best model (mrr=0.888716)


                                                                         

Epoch 114 | Train Loss: 8.371555 | Val Loss: 8.551076 | LR: 1.00e-02
{'mrr': 0.8897287699547028, 'ndcg': 0.9169326729671178, 'recall_at_1': 0.81848, 'recall_at_3': 0.95692, 'recall_at_5': 0.9776, 'recall_at_10': 0.99216, 'recall_at_50': 0.9994, 'l2_dist': 25.86452293395996}
Saved new best model (mrr=0.889729)


                                                                         

Epoch 115 | Train Loss: 8.361910 | Val Loss: 8.560104 | LR: 1.00e-02
{'mrr': 0.889622972145786, 'ndcg': 0.9168605595608875, 'recall_at_1': 0.8184, 'recall_at_3': 0.95712, 'recall_at_5': 0.97804, 'recall_at_10': 0.9926, 'recall_at_50': 0.9994, 'l2_dist': 25.8649845123291}


                                                                         

Epoch 116 | Train Loss: 8.347069 | Val Loss: 8.559645 | LR: 1.00e-02
{'mrr': 0.891501938615648, 'ndcg': 0.9182812260294715, 'recall_at_1': 0.8214, 'recall_at_3': 0.95736, 'recall_at_5': 0.97868, 'recall_at_10': 0.99272, 'recall_at_50': 0.99952, 'l2_dist': 25.864513397216797}
Saved new best model (mrr=0.891502)


                                                                         

Epoch 117 | Train Loss: 8.336854 | Val Loss: 8.552369 | LR: 1.00e-02
{'mrr': 0.8920185073149952, 'ndcg': 0.9186706720302482, 'recall_at_1': 0.82212, 'recall_at_3': 0.957, 'recall_at_5': 0.97892, 'recall_at_10': 0.99272, 'recall_at_50': 0.99948, 'l2_dist': 25.865184783935547}
Saved new best model (mrr=0.892019)


                                                                         

Epoch 118 | Train Loss: 8.318837 | Val Loss: 8.560990 | LR: 1.00e-02
{'mrr': 0.8921275502436841, 'ndcg': 0.9187663763732533, 'recall_at_1': 0.82212, 'recall_at_3': 0.9578, 'recall_at_5': 0.97896, 'recall_at_10': 0.99292, 'recall_at_50': 0.99952, 'l2_dist': 25.865476608276367}
Saved new best model (mrr=0.892128)


                                                                         

Epoch 119 | Train Loss: 8.306911 | Val Loss: 8.561553 | LR: 1.00e-02
{'mrr': 0.8935919537981588, 'ndcg': 0.9198504448658801, 'recall_at_1': 0.82488, 'recall_at_3': 0.9584, 'recall_at_5': 0.97888, 'recall_at_10': 0.99296, 'recall_at_50': 0.99952, 'l2_dist': 25.865886688232422}
Saved new best model (mrr=0.893592)


                                                                         

Epoch 120 | Train Loss: 8.296496 | Val Loss: 8.566355 | LR: 1.00e-02
{'mrr': 0.8935276146214256, 'ndcg': 0.9198190562944409, 'recall_at_1': 0.82432, 'recall_at_3': 0.95808, 'recall_at_5': 0.97896, 'recall_at_10': 0.99304, 'recall_at_50': 0.99952, 'l2_dist': 25.865785598754883}


                                                                         

Epoch 121 | Train Loss: 8.286322 | Val Loss: 8.575978 | LR: 1.00e-02
{'mrr': 0.8940918149701368, 'ndcg': 0.9202370218636483, 'recall_at_1': 0.82544, 'recall_at_3': 0.95844, 'recall_at_5': 0.97904, 'recall_at_10': 0.99324, 'recall_at_50': 0.99948, 'l2_dist': 25.8656005859375}
Saved new best model (mrr=0.894092)


                                                                         

Epoch 122 | Train Loss: 8.271580 | Val Loss: 8.569868 | LR: 1.00e-02
{'mrr': 0.895887782126177, 'ndcg': 0.9215768520620315, 'recall_at_1': 0.82836, 'recall_at_3': 0.95924, 'recall_at_5': 0.97964, 'recall_at_10': 0.99292, 'recall_at_50': 0.99944, 'l2_dist': 25.86683464050293}
Saved new best model (mrr=0.895888)


                                                                         

Epoch 123 | Train Loss: 8.257798 | Val Loss: 8.570758 | LR: 1.00e-02
{'mrr': 0.8969585067497086, 'ndcg': 0.922405502152219, 'recall_at_1': 0.8298, 'recall_at_3': 0.96024, 'recall_at_5': 0.97992, 'recall_at_10': 0.993, 'recall_at_50': 0.99948, 'l2_dist': 25.866622924804688}
Saved new best model (mrr=0.896959)


                                                                         

Epoch 124 | Train Loss: 8.245111 | Val Loss: 8.579288 | LR: 1.00e-02
{'mrr': 0.8969009317390966, 'ndcg': 0.9223542498413662, 'recall_at_1': 0.83016, 'recall_at_3': 0.95908, 'recall_at_5': 0.98024, 'recall_at_10': 0.99304, 'recall_at_50': 0.99952, 'l2_dist': 25.865989685058594}


                                                                         

Epoch 125 | Train Loss: 8.230098 | Val Loss: 8.579191 | LR: 1.00e-02
{'mrr': 0.8979441115075475, 'ndcg': 0.9231397885954041, 'recall_at_1': 0.83176, 'recall_at_3': 0.95956, 'recall_at_5': 0.98024, 'recall_at_10': 0.99316, 'recall_at_50': 0.99948, 'l2_dist': 25.867252349853516}
Saved new best model (mrr=0.897944)


                                                                         

Epoch 126 | Train Loss: 8.221044 | Val Loss: 8.579188 | LR: 1.00e-02
{'mrr': 0.8987208825769741, 'ndcg': 0.9237183072616276, 'recall_at_1': 0.8332, 'recall_at_3': 0.96024, 'recall_at_5': 0.98028, 'recall_at_10': 0.99332, 'recall_at_50': 0.99948, 'l2_dist': 25.867145538330078}
Saved new best model (mrr=0.898721)


                                                                         

Epoch 127 | Train Loss: 8.204498 | Val Loss: 8.592960 | LR: 1.00e-02
{'mrr': 0.8988822191599132, 'ndcg': 0.9238410196639838, 'recall_at_1': 0.8334, 'recall_at_3': 0.96012, 'recall_at_5': 0.9802, 'recall_at_10': 0.9934, 'recall_at_50': 0.99944, 'l2_dist': 25.867464065551758}
Saved new best model (mrr=0.898882)


                                                                         

Epoch 128 | Train Loss: 8.188705 | Val Loss: 8.605139 | LR: 1.00e-02
{'mrr': 0.8995209047204961, 'ndcg': 0.9243149564985065, 'recall_at_1': 0.8344, 'recall_at_3': 0.95984, 'recall_at_5': 0.98084, 'recall_at_10': 0.99324, 'recall_at_50': 0.99944, 'l2_dist': 25.867855072021484}
Saved new best model (mrr=0.899521)


                                                                         

Epoch 129 | Train Loss: 8.174120 | Val Loss: 8.605315 | LR: 1.00e-02
{'mrr': 0.9001114415662838, 'ndcg': 0.9247577310156048, 'recall_at_1': 0.83544, 'recall_at_3': 0.96092, 'recall_at_5': 0.98048, 'recall_at_10': 0.9932, 'recall_at_50': 0.99944, 'l2_dist': 25.867469787597656}
Saved new best model (mrr=0.900111)


                                                                         

Epoch 130 | Train Loss: 8.161413 | Val Loss: 8.595733 | LR: 1.00e-02
{'mrr': 0.9009094215634448, 'ndcg': 0.92536848058878, 'recall_at_1': 0.8366, 'recall_at_3': 0.9614, 'recall_at_5': 0.981, 'recall_at_10': 0.99316, 'recall_at_50': 0.99948, 'l2_dist': 25.867446899414062}
Saved new best model (mrr=0.900909)


                                                                         

Epoch 131 | Train Loss: 8.147729 | Val Loss: 8.601658 | LR: 1.00e-02
{'mrr': 0.901408976987742, 'ndcg': 0.9257338132403337, 'recall_at_1': 0.83744, 'recall_at_3': 0.96092, 'recall_at_5': 0.98072, 'recall_at_10': 0.99312, 'recall_at_50': 0.99948, 'l2_dist': 25.867076873779297}
Saved new best model (mrr=0.901409)


                                                                         

Epoch 132 | Train Loss: 8.132400 | Val Loss: 8.605450 | LR: 1.00e-02
{'mrr': 0.9020983832674704, 'ndcg': 0.9262612329538037, 'recall_at_1': 0.83848, 'recall_at_3': 0.96152, 'recall_at_5': 0.9812, 'recall_at_10': 0.99348, 'recall_at_50': 0.99944, 'l2_dist': 25.8660831451416}
Saved new best model (mrr=0.902098)


                                                                         

Epoch 133 | Train Loss: 8.119427 | Val Loss: 8.606718 | LR: 1.00e-02
{'mrr': 0.9016146786639437, 'ndcg': 0.9259117850883165, 'recall_at_1': 0.83736, 'recall_at_3': 0.96148, 'recall_at_5': 0.98064, 'recall_at_10': 0.99344, 'recall_at_50': 0.99948, 'l2_dist': 25.867446899414062}


                                                                         

Epoch 134 | Train Loss: 8.104056 | Val Loss: 8.611420 | LR: 1.00e-02
{'mrr': 0.9034603671183951, 'ndcg': 0.9272771141915405, 'recall_at_1': 0.84096, 'recall_at_3': 0.96184, 'recall_at_5': 0.98144, 'recall_at_10': 0.99364, 'recall_at_50': 0.99952, 'l2_dist': 25.866666793823242}
Saved new best model (mrr=0.903460)


                                                                         

Epoch 135 | Train Loss: 8.082631 | Val Loss: 8.599972 | LR: 1.00e-02
{'mrr': 0.9036133611577432, 'ndcg': 0.9274142429291975, 'recall_at_1': 0.84112, 'recall_at_3': 0.96216, 'recall_at_5': 0.981, 'recall_at_10': 0.9936, 'recall_at_50': 0.99948, 'l2_dist': 25.866046905517578}
Saved new best model (mrr=0.903613)


                                                                         

Epoch 136 | Train Loss: 8.065851 | Val Loss: 8.628601 | LR: 1.00e-02
{'mrr': 0.9038123778191686, 'ndcg': 0.927543791598708, 'recall_at_1': 0.8416, 'recall_at_3': 0.96196, 'recall_at_5': 0.98096, 'recall_at_10': 0.99356, 'recall_at_50': 0.99948, 'l2_dist': 25.865991592407227}
Saved new best model (mrr=0.903812)


                                                                         

Epoch 137 | Train Loss: 8.053449 | Val Loss: 8.622343 | LR: 1.00e-02
{'mrr': 0.9045504681407127, 'ndcg': 0.9281128176622501, 'recall_at_1': 0.84264, 'recall_at_3': 0.9624, 'recall_at_5': 0.98168, 'recall_at_10': 0.99332, 'recall_at_50': 0.99952, 'l2_dist': 25.86632537841797}
Saved new best model (mrr=0.904550)


                                                                         

Epoch 138 | Train Loss: 8.037707 | Val Loss: 8.625389 | LR: 1.00e-02
{'mrr': 0.904785913031901, 'ndcg': 0.9282716509505328, 'recall_at_1': 0.8432, 'recall_at_3': 0.96276, 'recall_at_5': 0.98124, 'recall_at_10': 0.99332, 'recall_at_50': 0.99952, 'l2_dist': 25.865060806274414}
Saved new best model (mrr=0.904786)


                                                                         

Epoch 139 | Train Loss: 8.019913 | Val Loss: 8.650460 | LR: 1.00e-02
{'mrr': 0.9050334738578814, 'ndcg': 0.9284791261436067, 'recall_at_1': 0.84324, 'recall_at_3': 0.96328, 'recall_at_5': 0.98204, 'recall_at_10': 0.99332, 'recall_at_50': 0.99948, 'l2_dist': 25.863149642944336}
Saved new best model (mrr=0.905033)


                                                                         

Epoch 140 | Train Loss: 8.003783 | Val Loss: 8.637961 | LR: 1.00e-02
{'mrr': 0.9059400494276124, 'ndcg': 0.9291440154065888, 'recall_at_1': 0.84496, 'recall_at_3': 0.96288, 'recall_at_5': 0.982, 'recall_at_10': 0.9932, 'recall_at_50': 0.99948, 'l2_dist': 25.8637752532959}
Saved new best model (mrr=0.905940)


                                                                         

Epoch 141 | Train Loss: 7.984996 | Val Loss: 8.640611 | LR: 1.00e-02
{'mrr': 0.9060086252681905, 'ndcg': 0.9292012663354318, 'recall_at_1': 0.84508, 'recall_at_3': 0.96388, 'recall_at_5': 0.9818, 'recall_at_10': 0.99348, 'recall_at_50': 0.99948, 'l2_dist': 25.861970901489258}
Saved new best model (mrr=0.906009)


                                                                         

Epoch 142 | Train Loss: 7.964607 | Val Loss: 8.653389 | LR: 1.00e-02
{'mrr': 0.9067503375178677, 'ndcg': 0.9297748757822509, 'recall_at_1': 0.84588, 'recall_at_3': 0.96432, 'recall_at_5': 0.98192, 'recall_at_10': 0.99332, 'recall_at_50': 0.99956, 'l2_dist': 25.86153793334961}
Saved new best model (mrr=0.906750)


                                                                         

Epoch 143 | Train Loss: 7.946808 | Val Loss: 8.667174 | LR: 1.00e-02
{'mrr': 0.9070313057256595, 'ndcg': 0.9299664549225585, 'recall_at_1': 0.84684, 'recall_at_3': 0.96396, 'recall_at_5': 0.98196, 'recall_at_10': 0.99348, 'recall_at_50': 0.9996, 'l2_dist': 25.860885620117188}
Saved new best model (mrr=0.907031)


                                                                         

Epoch 144 | Train Loss: 7.929769 | Val Loss: 8.677611 | LR: 1.00e-02
{'mrr': 0.9075922646309968, 'ndcg': 0.9303767669248405, 'recall_at_1': 0.84808, 'recall_at_3': 0.96344, 'recall_at_5': 0.98216, 'recall_at_10': 0.99352, 'recall_at_50': 0.99952, 'l2_dist': 25.859384536743164}
Saved new best model (mrr=0.907592)


                                                                         

Epoch 145 | Train Loss: 7.911709 | Val Loss: 8.678746 | LR: 1.00e-02
{'mrr': 0.9081859581787547, 'ndcg': 0.9308241532286277, 'recall_at_1': 0.84904, 'recall_at_3': 0.96444, 'recall_at_5': 0.98196, 'recall_at_10': 0.9936, 'recall_at_50': 0.99948, 'l2_dist': 25.858800888061523}
Saved new best model (mrr=0.908186)


                                                                         

Epoch 146 | Train Loss: 7.888996 | Val Loss: 8.711145 | LR: 1.00e-02
{'mrr': 0.9083485094552078, 'ndcg': 0.9309395553839246, 'recall_at_1': 0.8492, 'recall_at_3': 0.96416, 'recall_at_5': 0.98216, 'recall_at_10': 0.99316, 'recall_at_50': 0.99952, 'l2_dist': 25.857032775878906}
Saved new best model (mrr=0.908349)


                                                                         

Epoch 147 | Train Loss: 7.872118 | Val Loss: 8.701007 | LR: 1.00e-02
{'mrr': 0.9087153192838503, 'ndcg': 0.9312257491747366, 'recall_at_1': 0.8494, 'recall_at_3': 0.96384, 'recall_at_5': 0.98184, 'recall_at_10': 0.99328, 'recall_at_50': 0.99944, 'l2_dist': 25.855812072753906}
Saved new best model (mrr=0.908715)


                                                                         

Epoch 148 | Train Loss: 7.844815 | Val Loss: 8.726366 | LR: 1.00e-02
{'mrr': 0.9092838983253061, 'ndcg': 0.9316296283591409, 'recall_at_1': 0.85092, 'recall_at_3': 0.96384, 'recall_at_5': 0.98184, 'recall_at_10': 0.9934, 'recall_at_50': 0.99956, 'l2_dist': 25.85511016845703}
Saved new best model (mrr=0.909284)


                                                                         

Epoch 149 | Train Loss: 7.821288 | Val Loss: 8.699593 | LR: 1.00e-02
{'mrr': 0.9093719547136119, 'ndcg': 0.9317347598074304, 'recall_at_1': 0.85056, 'recall_at_3': 0.96568, 'recall_at_5': 0.98196, 'recall_at_10': 0.9936, 'recall_at_50': 0.9996, 'l2_dist': 25.854169845581055}
Saved new best model (mrr=0.909372)


                                                                         

Epoch 150 | Train Loss: 7.807897 | Val Loss: 8.715131 | LR: 1.00e-02
{'mrr': 0.9099768073822565, 'ndcg': 0.9321864458317304, 'recall_at_1': 0.85156, 'recall_at_3': 0.96616, 'recall_at_5': 0.98232, 'recall_at_10': 0.99372, 'recall_at_50': 0.99952, 'l2_dist': 25.85320472717285}
Saved new best model (mrr=0.909977)


                                                                         

Epoch 151 | Train Loss: 7.778905 | Val Loss: 8.740828 | LR: 1.00e-02
{'mrr': 0.910693092894599, 'ndcg': 0.9327142358684251, 'recall_at_1': 0.85296, 'recall_at_3': 0.96552, 'recall_at_5': 0.98248, 'recall_at_10': 0.9936, 'recall_at_50': 0.99956, 'l2_dist': 25.851505279541016}
Saved new best model (mrr=0.910693)


                                                                         

Epoch 152 | Train Loss: 7.758256 | Val Loss: 8.725563 | LR: 1.00e-02
{'mrr': 0.9102909585997337, 'ndcg': 0.9324089341523676, 'recall_at_1': 0.8522, 'recall_at_3': 0.96552, 'recall_at_5': 0.98172, 'recall_at_10': 0.9934, 'recall_at_50': 0.99956, 'l2_dist': 25.85041046142578}


                                                                         

Epoch 153 | Train Loss: 7.726067 | Val Loss: 8.782766 | LR: 1.00e-02
{'mrr': 0.911035460786248, 'ndcg': 0.9329618678360959, 'recall_at_1': 0.8536, 'recall_at_3': 0.96544, 'recall_at_5': 0.98268, 'recall_at_10': 0.99356, 'recall_at_50': 0.99956, 'l2_dist': 25.849010467529297}
Saved new best model (mrr=0.911035)


                                                                         

Epoch 154 | Train Loss: 7.701685 | Val Loss: 8.781732 | LR: 1.00e-02
{'mrr': 0.9116257696458977, 'ndcg': 0.9333986222397688, 'recall_at_1': 0.85464, 'recall_at_3': 0.96616, 'recall_at_5': 0.98212, 'recall_at_10': 0.9934, 'recall_at_50': 0.99956, 'l2_dist': 25.847566604614258}
Saved new best model (mrr=0.911626)


                                                                         

Epoch 155 | Train Loss: 7.678189 | Val Loss: 8.825734 | LR: 1.00e-02
{'mrr': 0.9112410540178736, 'ndcg': 0.9331051145391687, 'recall_at_1': 0.85396, 'recall_at_3': 0.96568, 'recall_at_5': 0.98248, 'recall_at_10': 0.99328, 'recall_at_50': 0.99952, 'l2_dist': 25.847230911254883}


                                                                         

Epoch 156 | Train Loss: 7.646727 | Val Loss: 8.803892 | LR: 1.00e-02
{'mrr': 0.912535540165627, 'ndcg': 0.9340906399630889, 'recall_at_1': 0.85592, 'recall_at_3': 0.9654, 'recall_at_5': 0.9826, 'recall_at_10': 0.99384, 'recall_at_50': 0.99956, 'l2_dist': 25.844907760620117}
Saved new best model (mrr=0.912536)


                                                                         

Epoch 157 | Train Loss: 7.616019 | Val Loss: 8.835546 | LR: 1.00e-02
{'mrr': 0.911810819565494, 'ndcg': 0.9335465666456121, 'recall_at_1': 0.85464, 'recall_at_3': 0.96604, 'recall_at_5': 0.98276, 'recall_at_10': 0.99336, 'recall_at_50': 0.99952, 'l2_dist': 25.844581604003906}


                                                                         

Epoch 158 | Train Loss: 7.592397 | Val Loss: 8.882018 | LR: 1.00e-02
{'mrr': 0.912778926475431, 'ndcg': 0.934269241637707, 'recall_at_1': 0.85636, 'recall_at_3': 0.96576, 'recall_at_5': 0.98228, 'recall_at_10': 0.99356, 'recall_at_50': 0.99956, 'l2_dist': 25.842479705810547}
Saved new best model (mrr=0.912779)


                                                                         

Epoch 159 | Train Loss: 7.562647 | Val Loss: 8.903275 | LR: 1.00e-02
{'mrr': 0.9132830617834757, 'ndcg': 0.9346571167356517, 'recall_at_1': 0.85684, 'recall_at_3': 0.9674, 'recall_at_5': 0.98236, 'recall_at_10': 0.99356, 'recall_at_50': 0.99956, 'l2_dist': 25.841089248657227}
Saved new best model (mrr=0.913283)


                                                                         

Epoch 160 | Train Loss: 7.530540 | Val Loss: 8.907273 | LR: 1.00e-02
{'mrr': 0.9132226227583026, 'ndcg': 0.9345954924697151, 'recall_at_1': 0.85736, 'recall_at_3': 0.9664, 'recall_at_5': 0.98212, 'recall_at_10': 0.994, 'recall_at_50': 0.99944, 'l2_dist': 25.84023094177246}


                                                                         

Epoch 161 | Train Loss: 7.506746 | Val Loss: 8.915393 | LR: 1.00e-02
{'mrr': 0.9133039527991696, 'ndcg': 0.9346522184840201, 'recall_at_1': 0.85764, 'recall_at_3': 0.96628, 'recall_at_5': 0.98244, 'recall_at_10': 0.99368, 'recall_at_50': 0.9996, 'l2_dist': 25.83821678161621}
Saved new best model (mrr=0.913304)


                                                                         

Epoch 162 | Train Loss: 7.473185 | Val Loss: 8.936068 | LR: 1.00e-02
{'mrr': 0.9131085209753046, 'ndcg': 0.9345162277404513, 'recall_at_1': 0.85732, 'recall_at_3': 0.96632, 'recall_at_5': 0.98208, 'recall_at_10': 0.99364, 'recall_at_50': 0.99952, 'l2_dist': 25.837215423583984}


                                                                         

Epoch 163 | Train Loss: 7.412302 | Val Loss: 8.928278 | LR: 5.00e-03
{'mrr': 0.9147758304345504, 'ndcg': 0.9357676662228415, 'recall_at_1': 0.85968, 'recall_at_3': 0.9672, 'recall_at_5': 0.98224, 'recall_at_10': 0.99348, 'recall_at_50': 0.99948, 'l2_dist': 25.835920333862305}
Saved new best model (mrr=0.914776)


                                                                         

Epoch 164 | Train Loss: 7.368843 | Val Loss: 8.923306 | LR: 5.00e-03
{'mrr': 0.9155444939961487, 'ndcg': 0.9363479288082309, 'recall_at_1': 0.86096, 'recall_at_3': 0.9676, 'recall_at_5': 0.98312, 'recall_at_10': 0.99368, 'recall_at_50': 0.99956, 'l2_dist': 25.83478546142578}
Saved new best model (mrr=0.915544)


                                                                         

Epoch 165 | Train Loss: 7.332259 | Val Loss: 8.935398 | LR: 5.00e-03
{'mrr': 0.9149281519941473, 'ndcg': 0.9358955174104316, 'recall_at_1': 0.85976, 'recall_at_3': 0.968, 'recall_at_5': 0.98244, 'recall_at_10': 0.9936, 'recall_at_50': 0.99952, 'l2_dist': 25.833282470703125}


                                                                         

Epoch 166 | Train Loss: 7.309873 | Val Loss: 8.979349 | LR: 5.00e-03
{'mrr': 0.9155956757823107, 'ndcg': 0.936400506322549, 'recall_at_1': 0.86072, 'recall_at_3': 0.96784, 'recall_at_5': 0.98292, 'recall_at_10': 0.99352, 'recall_at_50': 0.99952, 'l2_dist': 25.832454681396484}
Saved new best model (mrr=0.915596)


                                                                         

Epoch 167 | Train Loss: 7.268614 | Val Loss: 9.015142 | LR: 5.00e-03
{'mrr': 0.9160914800091144, 'ndcg': 0.9367580826652152, 'recall_at_1': 0.86192, 'recall_at_3': 0.967, 'recall_at_5': 0.98252, 'recall_at_10': 0.9938, 'recall_at_50': 0.99956, 'l2_dist': 25.831497192382812}
Saved new best model (mrr=0.916091)


                                                                         

Epoch 168 | Train Loss: 7.241962 | Val Loss: 9.035503 | LR: 5.00e-03
{'mrr': 0.9153484958698936, 'ndcg': 0.936207841025957, 'recall_at_1': 0.8604, 'recall_at_3': 0.9676, 'recall_at_5': 0.98288, 'recall_at_10': 0.9938, 'recall_at_50': 0.99948, 'l2_dist': 25.8305606842041}


                                                                         

Epoch 169 | Train Loss: 7.213369 | Val Loss: 9.043328 | LR: 5.00e-03
{'mrr': 0.9161562289398906, 'ndcg': 0.9368161511363519, 'recall_at_1': 0.86184, 'recall_at_3': 0.96736, 'recall_at_5': 0.983, 'recall_at_10': 0.9938, 'recall_at_50': 0.99952, 'l2_dist': 25.829832077026367}
Saved new best model (mrr=0.916156)


                                                                         

Epoch 170 | Train Loss: 7.186179 | Val Loss: 9.081370 | LR: 5.00e-03
{'mrr': 0.9159212813577707, 'ndcg': 0.9366150626533238, 'recall_at_1': 0.8618, 'recall_at_3': 0.9668, 'recall_at_5': 0.9826, 'recall_at_10': 0.99372, 'recall_at_50': 0.99956, 'l2_dist': 25.8288516998291}


                                                                         

Epoch 171 | Train Loss: 7.156657 | Val Loss: 9.055746 | LR: 5.00e-03
{'mrr': 0.9164389746690763, 'ndcg': 0.9370184505123024, 'recall_at_1': 0.8626, 'recall_at_3': 0.96744, 'recall_at_5': 0.98264, 'recall_at_10': 0.99408, 'recall_at_50': 0.9996, 'l2_dist': 25.828317642211914}
Saved new best model (mrr=0.916439)


                                                                         

Epoch 172 | Train Loss: 7.117363 | Val Loss: 9.088496 | LR: 2.50e-03
{'mrr': 0.9166643207427244, 'ndcg': 0.9371852078985311, 'recall_at_1': 0.86296, 'recall_at_3': 0.96732, 'recall_at_5': 0.9824, 'recall_at_10': 0.9938, 'recall_at_50': 0.99956, 'l2_dist': 25.82755470275879}
Saved new best model (mrr=0.916664)


                                                                         

Epoch 173 | Train Loss: 7.079164 | Val Loss: 9.102704 | LR: 2.50e-03
{'mrr': 0.9171800732851357, 'ndcg': 0.9375663241052696, 'recall_at_1': 0.86392, 'recall_at_3': 0.9674, 'recall_at_5': 0.98248, 'recall_at_10': 0.9938, 'recall_at_50': 0.99964, 'l2_dist': 25.826940536499023}
Saved new best model (mrr=0.917180)


                                                                         

Epoch 174 | Train Loss: 7.058874 | Val Loss: 9.136932 | LR: 2.50e-03
{'mrr': 0.9172335747593042, 'ndcg': 0.9376071830695032, 'recall_at_1': 0.86404, 'recall_at_3': 0.96728, 'recall_at_5': 0.98268, 'recall_at_10': 0.9938, 'recall_at_50': 0.9996, 'l2_dist': 25.82636260986328}
Saved new best model (mrr=0.917234)


                                                                         

Epoch 175 | Train Loss: 7.039483 | Val Loss: 9.162484 | LR: 2.50e-03
{'mrr': 0.9173447039818012, 'ndcg': 0.9376864168712649, 'recall_at_1': 0.86416, 'recall_at_3': 0.96744, 'recall_at_5': 0.98248, 'recall_at_10': 0.99396, 'recall_at_50': 0.9996, 'l2_dist': 25.82564926147461}
Saved new best model (mrr=0.917345)


                                                                         

Epoch 176 | Train Loss: 7.021986 | Val Loss: 9.169529 | LR: 2.50e-03
{'mrr': 0.9175668036730751, 'ndcg': 0.9378669646581477, 'recall_at_1': 0.86416, 'recall_at_3': 0.968, 'recall_at_5': 0.98296, 'recall_at_10': 0.99384, 'recall_at_50': 0.9996, 'l2_dist': 25.824954986572266}
Saved new best model (mrr=0.917567)


                                                                         

Epoch 177 | Train Loss: 6.994974 | Val Loss: 9.177043 | LR: 2.50e-03
{'mrr': 0.9178867301843457, 'ndcg': 0.9380849437470066, 'recall_at_1': 0.86532, 'recall_at_3': 0.96768, 'recall_at_5': 0.98256, 'recall_at_10': 0.99396, 'recall_at_50': 0.9996, 'l2_dist': 25.82479476928711}
Saved new best model (mrr=0.917887)


                                                                         

Epoch 178 | Train Loss: 6.965443 | Val Loss: 9.181867 | LR: 1.25e-03
{'mrr': 0.9181062536613935, 'ndcg': 0.9382496185508981, 'recall_at_1': 0.86572, 'recall_at_3': 0.96756, 'recall_at_5': 0.9828, 'recall_at_10': 0.99396, 'recall_at_50': 0.9996, 'l2_dist': 25.824474334716797}
Saved new best model (mrr=0.918106)


                                                                         

Epoch 179 | Train Loss: 6.954677 | Val Loss: 9.195402 | LR: 1.25e-03
{'mrr': 0.9183427726465152, 'ndcg': 0.9384239235276786, 'recall_at_1': 0.86604, 'recall_at_3': 0.9676, 'recall_at_5': 0.9826, 'recall_at_10': 0.99392, 'recall_at_50': 0.99952, 'l2_dist': 25.82390594482422}
Saved new best model (mrr=0.918343)


                                                                         

Epoch 180 | Train Loss: 6.940185 | Val Loss: 9.209755 | LR: 1.25e-03
{'mrr': 0.9180008330610167, 'ndcg': 0.9381705571443268, 'recall_at_1': 0.86548, 'recall_at_3': 0.96784, 'recall_at_5': 0.98264, 'recall_at_10': 0.99376, 'recall_at_50': 0.99956, 'l2_dist': 25.823516845703125}


                                                                         

Epoch 181 | Train Loss: 6.925490 | Val Loss: 9.223111 | LR: 1.25e-03
{'mrr': 0.9179936389259791, 'ndcg': 0.9381644089579702, 'recall_at_1': 0.86544, 'recall_at_3': 0.96764, 'recall_at_5': 0.98276, 'recall_at_10': 0.99372, 'recall_at_50': 0.9996, 'l2_dist': 25.823131561279297}


                                                                         

Epoch 182 | Train Loss: 6.913460 | Val Loss: 9.230830 | LR: 1.25e-03
{'mrr': 0.9184438546692315, 'ndcg': 0.9385011558821748, 'recall_at_1': 0.86616, 'recall_at_3': 0.9676, 'recall_at_5': 0.9828, 'recall_at_10': 0.99388, 'recall_at_50': 0.99956, 'l2_dist': 25.822832107543945}
Saved new best model (mrr=0.918444)


                                                                         

Epoch 183 | Train Loss: 6.905278 | Val Loss: 9.235260 | LR: 6.25e-04
{'mrr': 0.9183052167433675, 'ndcg': 0.9384026185458906, 'recall_at_1': 0.86592, 'recall_at_3': 0.96772, 'recall_at_5': 0.9828, 'recall_at_10': 0.99396, 'recall_at_50': 0.99964, 'l2_dist': 25.82268524169922}


                                                                         

Epoch 184 | Train Loss: 6.899821 | Val Loss: 9.238097 | LR: 6.25e-04
{'mrr': 0.9183553218315881, 'ndcg': 0.9384377758115917, 'recall_at_1': 0.86608, 'recall_at_3': 0.96776, 'recall_at_5': 0.98292, 'recall_at_10': 0.99404, 'recall_at_50': 0.9996, 'l2_dist': 25.82258415222168}


                                                                         

Epoch 185 | Train Loss: 6.886190 | Val Loss: 9.239790 | LR: 6.25e-04
{'mrr': 0.9184115878711804, 'ndcg': 0.938477050764332, 'recall_at_1': 0.86608, 'recall_at_3': 0.96764, 'recall_at_5': 0.9826, 'recall_at_10': 0.99384, 'recall_at_50': 0.9996, 'l2_dist': 25.822425842285156}


                                                                         

Epoch 186 | Train Loss: 6.876360 | Val Loss: 9.250942 | LR: 6.25e-04
{'mrr': 0.9184467865241395, 'ndcg': 0.9385117081486276, 'recall_at_1': 0.86596, 'recall_at_3': 0.968, 'recall_at_5': 0.98288, 'recall_at_10': 0.9938, 'recall_at_50': 0.99956, 'l2_dist': 25.822187423706055}
Saved new best model (mrr=0.918447)


                                                                         

Epoch 187 | Train Loss: 6.866812 | Val Loss: 9.254520 | LR: 3.13e-04
{'mrr': 0.9182039284501953, 'ndcg': 0.9383247248356715, 'recall_at_1': 0.86572, 'recall_at_3': 0.96784, 'recall_at_5': 0.98312, 'recall_at_10': 0.99388, 'recall_at_50': 0.99956, 'l2_dist': 25.82209587097168}


                                                                         

Epoch 188 | Train Loss: 6.863809 | Val Loss: 9.263229 | LR: 3.13e-04
{'mrr': 0.9183447426144816, 'ndcg': 0.9384308207890604, 'recall_at_1': 0.86592, 'recall_at_3': 0.968, 'recall_at_5': 0.98312, 'recall_at_10': 0.9938, 'recall_at_50': 0.99956, 'l2_dist': 25.82196044921875}


                                                                         

Epoch 189 | Train Loss: 6.861477 | Val Loss: 9.267792 | LR: 3.13e-04
{'mrr': 0.9184688817853436, 'ndcg': 0.9385178896313036, 'recall_at_1': 0.86632, 'recall_at_3': 0.9678, 'recall_at_5': 0.98304, 'recall_at_10': 0.99388, 'recall_at_50': 0.99956, 'l2_dist': 25.8219051361084}
Saved new best model (mrr=0.918469)


                                                                         

Epoch 190 | Train Loss: 6.855020 | Val Loss: 9.267006 | LR: 3.13e-04
{'mrr': 0.9185795304122106, 'ndcg': 0.9386020158185565, 'recall_at_1': 0.86644, 'recall_at_3': 0.96796, 'recall_at_5': 0.98284, 'recall_at_10': 0.99392, 'recall_at_50': 0.99956, 'l2_dist': 25.82176971435547}
Saved new best model (mrr=0.918580)


                                                                         

Epoch 191 | Train Loss: 6.853622 | Val Loss: 9.268468 | LR: 1.56e-04
{'mrr': 0.9185960548050746, 'ndcg': 0.9386127520584251, 'recall_at_1': 0.86648, 'recall_at_3': 0.9678, 'recall_at_5': 0.98284, 'recall_at_10': 0.99392, 'recall_at_50': 0.99956, 'l2_dist': 25.821706771850586}
Saved new best model (mrr=0.918596)


                                                                         

Epoch 192 | Train Loss: 6.853867 | Val Loss: 9.270601 | LR: 1.56e-04
{'mrr': 0.9187493024967726, 'ndcg': 0.9387267547284248, 'recall_at_1': 0.86676, 'recall_at_3': 0.96788, 'recall_at_5': 0.983, 'recall_at_10': 0.99388, 'recall_at_50': 0.99956, 'l2_dist': 25.821670532226562}
Saved new best model (mrr=0.918749)


                                                                         

Epoch 193 | Train Loss: 6.845939 | Val Loss: 9.274498 | LR: 1.56e-04
{'mrr': 0.918577928059286, 'ndcg': 0.9386006967548485, 'recall_at_1': 0.8664, 'recall_at_3': 0.96792, 'recall_at_5': 0.98312, 'recall_at_10': 0.99392, 'recall_at_50': 0.99956, 'l2_dist': 25.821622848510742}


                                                                         

Epoch 194 | Train Loss: 6.842644 | Val Loss: 9.275593 | LR: 1.56e-04
{'mrr': 0.9185147344373612, 'ndcg': 0.9385544508358866, 'recall_at_1': 0.86628, 'recall_at_3': 0.96792, 'recall_at_5': 0.98308, 'recall_at_10': 0.99396, 'recall_at_50': 0.99956, 'l2_dist': 25.82158088684082}


                                                                         

Epoch 195 | Train Loss: 6.847499 | Val Loss: 9.274870 | LR: 7.81e-05
{'mrr': 0.918409030924572, 'ndcg': 0.9384733316042352, 'recall_at_1': 0.86612, 'recall_at_3': 0.96772, 'recall_at_5': 0.983, 'recall_at_10': 0.994, 'recall_at_50': 0.99956, 'l2_dist': 25.821575164794922}


                                                                         

Epoch 196 | Train Loss: 6.844411 | Val Loss: 9.278162 | LR: 7.81e-05
{'mrr': 0.9185054521357698, 'ndcg': 0.938545518320474, 'recall_at_1': 0.86632, 'recall_at_3': 0.96808, 'recall_at_5': 0.983, 'recall_at_10': 0.994, 'recall_at_50': 0.99956, 'l2_dist': 25.82154655456543}


                                                                         

Epoch 197 | Train Loss: 6.843208 | Val Loss: 9.278991 | LR: 7.81e-05
{'mrr': 0.9185594722148694, 'ndcg': 0.9385875509466799, 'recall_at_1': 0.86636, 'recall_at_3': 0.96812, 'recall_at_5': 0.98292, 'recall_at_10': 0.994, 'recall_at_50': 0.99956, 'l2_dist': 25.821502685546875}


                                                                         

Epoch 198 | Train Loss: 6.842658 | Val Loss: 9.279609 | LR: 7.81e-05
{'mrr': 0.9184858026852972, 'ndcg': 0.9385309771272154, 'recall_at_1': 0.86628, 'recall_at_3': 0.96804, 'recall_at_5': 0.98288, 'recall_at_10': 0.99396, 'recall_at_50': 0.99956, 'l2_dist': 25.82149314880371}


                                                                         

Epoch 199 | Train Loss: 6.839954 | Val Loss: 9.280280 | LR: 3.91e-05
{'mrr': 0.9184715574615667, 'ndcg': 0.93852107581172, 'recall_at_1': 0.86624, 'recall_at_3': 0.968, 'recall_at_5': 0.98292, 'recall_at_10': 0.99396, 'recall_at_50': 0.99956, 'l2_dist': 25.82147979736328}


                                                                         

Epoch 200 | Train Loss: 6.837841 | Val Loss: 9.280258 | LR: 3.91e-05
{'mrr': 0.9185031128318929, 'ndcg': 0.9385456425777864, 'recall_at_1': 0.86628, 'recall_at_3': 0.96804, 'recall_at_5': 0.98292, 'recall_at_10': 0.99396, 'recall_at_50': 0.99956, 'l2_dist': 25.82147216796875}
Training complete. Best mrr: 0.918749
Finished training. Now testing using best model...
['190.pth', '191.pth', '192.pth']
Test Results: {'mrr': 0.9185033708563725, 'ndcg': 0.9385494500356236, 'recall_at_1': 0.86628, 'recall_at_3': 0.96784, 'recall_at_5': 0.98296, 'recall_at_10': 0.994, 'recall_at_50': 0.99956, 'l2_dist': 25.823225021362305}


In [21]:
# Write the submission CSV using the trained model(s). `ensemble=True` asks
# generate_submission to combine every entry in `models` (presumably the
# checkpoints loaded after training -- verify). The call is deliberately the
# cell's last expression, so Jupyter renders whatever it returns as a preview.
generate_submission(
    models,
    Path(test_path),
    output_file="sub.csv",
    device=device,
    ensemble=True,
)

Generating submission file...
✓ Saved submission to sub.csv


Unnamed: 0,id,embedding
0,1,"[0.019891632720828056, 0.008660918101668358, 0..."
1,2,"[-0.03293107450008392, -0.015808546915650368, ..."
2,3,"[0.004772179294377565, -0.040211666375398636, ..."
3,4,"[0.026275502517819405, -0.00867441762238741, -..."
4,5,"[0.05914425477385521, 0.029865944758057594, 0...."
...,...,...
1495,1496,"[-0.001608845661394298, -0.0567314438521862, 0..."
1496,1497,"[0.02157767303287983, 0.01465622615069151, 0.0..."
1497,1498,"[0.05050719901919365, -0.04127898067235947, 0...."
1498,1499,"[-0.003923273179680109, -0.022004181519150734,..."


In [14]:
import optuna
from optuna.pruners import MedianPruner



def objective(trial, train_dataset, val_dataset, epochs: int = 10, device=None):
    """Optuna objective: train a ``SpaceTranslator`` with sampled hyperparameters.

    Samples depth/width, batch size, learning rate, dropout and weight decay from
    ``trial``, trains with AdamW (LR reduced on validation-loss plateaus), and after
    every epoch reports the validation MRR returned by ``test`` so the study's pruner
    can stop unpromising trials early. The latest train/val InfoNCE losses are stored
    on the trial as user attributes ("train_loss", "val_loss").

    Args:
        trial: The Optuna trial being evaluated.
        train_dataset: Dataset of ``(X, y)`` pairs used for training.
        val_dataset: Dataset of ``(X, y)`` pairs used for the validation loss and MRR.
        epochs: Number of training epochs; must be >= 1.
        device: Torch device to train on; defaults to CUDA when available, else CPU.

    Returns:
        The validation MRR measured after the final epoch.

    Raises:
        ValueError: If ``epochs`` < 1 (no metric would ever be computed).
        optuna.exceptions.TrialPruned: When the pruner decides to stop this trial.
    """
    if epochs < 1:
        # Without at least one epoch `results` is never assigned.
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Search space (suggestion order kept identical to earlier studies) ---
    n_layers = trial.suggest_int("n_layers", 1, 3)
    # NOTE(review): 1256 / 1565 look like they may be typos (1280 / 1536?); kept
    # unchanged so results stay comparable with previous runs. Confirm intent.
    layer_choices = [1256, 1565, 2048]
    hidden_layers = [
        trial.suggest_categorical(f"n_units_l{i}", layer_choices) for i in range(n_layers)
    ]

    activation_fn = nn.ReLU

    batch_size = trial.suggest_categorical("batch_size", [2048, 4096])
    lr = trial.suggest_categorical("lr", [1e-4, 1e-3, 1e-2])
    dropout_rate = trial.suggest_categorical("dropout_rate", [0.3, 0.4, 0.5])
    weight_decay = trial.suggest_categorical("weight_decay", [1e-4, 1e-3, 1e-2])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = SpaceTranslator(
        input_dim=1024,
        output_dim=1536,
        hidden_layers=hidden_layers,
        activation=activation_fn,
        dropout_rate=dropout_rate,
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,   # halve the LR when the validation loss plateaus
        patience=2,   # epochs without improvement before reducing the LR
        min_lr=1e-6,  # lower bound on the LR
    )

    for epoch in range(1, epochs + 1):
        # --- Training phase ---
        model.train()
        running_train_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = info_nce_loss(outputs, y_batch, model.logit_scale)
            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()
        train_loss = running_train_loss / len(train_loader)

        # --- Validation phase: the mean loss drives the LR scheduler ---
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                outputs = model(X_batch)
                loss = info_nce_loss(outputs, y_batch, model.logit_scale)
                running_val_loss += loss.item()
        val_loss = running_val_loss / len(val_loader)

        scheduler.step(val_loss)

        # Keep the latest losses on the trial so they appear in trials_dataframe().
        trial.set_user_attr("train_loss", train_loss)
        trial.set_user_attr("val_loss", val_loss)

        # The pruner compares this intermediate MRR against other trials at this epoch.
        results = test(val_dataset, model, device)
        trial.report(results['mrr'], epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return results['mrr']


def run_optuna_search(data_path: Path, n_trials: int = 30, epochs: int = 10, n_jobs: int = 1, pruner=None, device=None):
    """Run an Optuna hyperparameter search that maximizes validation MRR.

    Loads ``(X, Y)`` from ``data_path`` with ``get_data``, splits them with
    ``get_datasets``, then optimizes ``objective`` and prints a summary.

    Args:
        data_path: Location of the data, passed to ``get_data``.
        n_trials: Number of trials to run.
        epochs: Training epochs per trial.
        n_jobs: Forwarded to ``study.optimize``.
        pruner: Optuna pruner; defaults to ``MedianPruner(n_startup_trials=5, n_warmup_steps=1)``.
        device: Torch device forwarded to ``objective``; ``None`` lets it choose CUDA/CPU.

    Returns:
        The ``optuna.Study`` after optimization.
    """
    if pruner is None:
        pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=1)

    X, Y = get_data(data_path)
    train_dataset, val_dataset = get_datasets(X, Y)

    study = optuna.create_study(direction="maximize", pruner=pruner)

    def _objective(trial):
        # Bind datasets/settings so Optuna can call the objective with only `trial`.
        return objective(
            trial,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            epochs=epochs,
            device=device,
        )

    study.optimize(_objective, n_trials=n_trials, n_jobs=n_jobs)

    trial_state = optuna.trial.TrialState
    complete_trials = study.get_trials(deepcopy=False, states=(trial_state.COMPLETE,))
    pruned_trials = study.get_trials(deepcopy=False, states=(trial_state.PRUNED,))

    print("Study statistics:")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    if not complete_trials:
        # study.best_trial raises ValueError when no trial has completed.
        print("  No trial completed; there is no best trial to report.")
        return study

    print("  Best trial:")
    trial = study.best_trial
    print("    Value: ", trial.value)
    print("    Params: ")
    for k, v in trial.params.items():
        print(f"      {k}: {v}")

    return study

In [None]:
study = run_optuna_search(data_path=data_path, n_trials=200, epochs=30, n_jobs=1)
# Save the per-trial summary table (params, value, state, user attrs, ...) to CSV.
study.trials_dataframe().to_csv("optuna_trials.csv", index=False)

best_trial_number = study.best_trial.number
print("Best params:", study.best_params)
print("Best trial number:", best_trial_number)