### model.py

In [7]:
import torch
from typing import Optional
from torch import nn
from torch.nn import functional as F

class Translator(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        dir_hidden_dims: list[int],
        scale_hidden_dims: list[int],
        activation=nn.ReLU,
        dropout_rate: float=0.3
    ):
        super().__init__()

        def build_mlp(hidden_dims, out_dim, apply_softplus=False):
            layers = []
            last_dim = input_dim
            for hidden in hidden_dims:
                layers += [
                    nn.Linear(last_dim, hidden),
                    activation(),
                    nn.LayerNorm(hidden),
                    nn.Dropout(dropout_rate)
                ]
                last_dim = hidden
            layers.append(nn.Linear(last_dim, out_dim))
            
            if apply_softplus:
                layers.append(nn.Softplus())
            
            return nn.Sequential(*layers)

        self.dir_head = build_mlp(dir_hidden_dims, output_dim, apply_softplus=False)
        self.scale_head = build_mlp(scale_hidden_dims, 1, apply_softplus=True)

        # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/0.07))

        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, x):
        direction = self.dir_head(x)
        scale = self.scale_head(x)
        
        return F.normalize(direction, p=2, dim=-1) * scale

### eval.py

In [8]:
from pathlib import Path
import numpy as np
import pandas as pd

'''Code from https://github.com/Mamiglia/challenge'''

def mrr(pred_indices: np.ndarray, gt_indices: np.ndarray) -> float:
    """
    Compute Mean Reciprocal Rank (MRR)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries (top-K)
        gt_indices: (N,) array of ground truth indices
    Returns:
        mrr: Mean Reciprocal Rank
    """
    reciprocal_ranks = []
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i] == gt_indices[i])[0]
        if matches.size > 0:
            reciprocal_ranks.append(1.0 / (matches[0] + 1))
        else:
            reciprocal_ranks.append(0.0)
    return np.mean(reciprocal_ranks)


def recall_at_k(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int) -> float:
    """Compute Recall@k
    Args:
        pred_indices: (N, N) array of top indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        recall: Recall@k
    """
    recall = 0
    for i in range(len(gt_indices)):
        if gt_indices[i] in pred_indices[i, :k]:
            recall += 1
    recall /= len(gt_indices)
    return recall

import numpy as np

def ndcg(pred_indices: np.ndarray, gt_indices: np.ndarray, k: int = 100) -> float:
    """
    Compute Normalized Discounted Cumulative Gain (NDCG@k)
    Args:
        pred_indices: (N, K) array of predicted indices for N queries
        gt_indices: (N,) array of ground truth indices
        k: number of top predictions to consider
    Returns:
        ndcg: NDCG@k
    """
    ndcg_total = 0.0
    for i in range(len(gt_indices)):
        matches = np.where(pred_indices[i, :k] == gt_indices[i])[0]
        if matches.size > 0:
            rank = matches[0] + 1
            ndcg_total += 1.0 / np.log2(rank + 1)  # DCG (IDCG = 1)
    return ndcg_total / len(gt_indices)



@torch.inference_mode()
def evaluate_retrieval(translated_embd, image_embd, gt_indices, max_indices = 99, batch_size=100):
    """Evaluate retrieval performance using cosine similarity
    Args:
        translated_embd: (N_captions, D) translated caption embeddings
        image_embd: (N_images, D) image embeddings
        gt_indices: (N_captions,) ground truth image indices for each caption
        max_indices: number of top predictions to consider
    Returns:
        results: dict of evaluation metrics
    
    """
    # Compute similarity matrix
    if isinstance(translated_embd, np.ndarray):
        translated_embd = torch.from_numpy(translated_embd).float()
    if isinstance(image_embd, np.ndarray):
        image_embd = torch.from_numpy(image_embd).float()
    
    n_queries = translated_embd.shape[0]
    device = translated_embd.device
    
    # Prepare containers for the fragments to be reassembled
    all_sorted_indices = []
    l2_distances = []
    
    # Process in batches - the narrow gate approach
    for start_idx in range(0, n_queries, batch_size):
        batch_slice = slice(start_idx, min(start_idx + batch_size, n_queries))
        batch_translated = translated_embd[batch_slice]
        batch_img_embd = image_embd[batch_slice]
        
        # Compute similarity only for this batch
        batch_similarity = batch_translated @ batch_img_embd.T

        # Get top-k predictions for this batch
        batch_indices = batch_similarity.topk(k=max_indices, dim=1, sorted=True).indices.numpy()
        all_sorted_indices.append(gt_indices[batch_slice][batch_indices])

        # Compute L2 distance for this batch
        batch_gt = gt_indices[batch_slice]
        batch_gt_embeddings = image_embd[batch_gt]
        batch_l2 = (batch_translated - batch_gt_embeddings).norm(dim=1)
        l2_distances.append(batch_l2)
    
    # Reassemble the fragments
    sorted_indices = np.concatenate(all_sorted_indices, axis=0)
    
    # Apply the sacred metrics to the whole
    metrics = {
        'mrr': mrr,
        'ndcg': ndcg,
        'recall_at_1': lambda preds, gt: recall_at_k(preds, gt, 1),
        'recall_at_3': lambda preds, gt: recall_at_k(preds, gt, 3),
        'recall_at_5': lambda preds, gt: recall_at_k(preds, gt, 5),
        'recall_at_10': lambda preds, gt: recall_at_k(preds, gt, 10),
        'recall_at_50': lambda preds, gt: recall_at_k(preds, gt, 50),
    }
    
    results = {
        name: func(sorted_indices, gt_indices)
        for name, func in metrics.items()
    }
    
    l2_dist = torch.cat(l2_distances, dim=0).mean().item()
    results['l2_dist'] = l2_dist
    
    return results

def eval_on_val(x_val: np.ndarray, y_val: np.ndarray, model: Translator, device) -> dict:
    gt_indices = torch.arange(len(y_val))
    
    model.eval()

    with torch.inference_mode():
        translated = model(x_val.to(device)).to('cpu')

    results = evaluate_retrieval(translated, y_val, gt_indices)
    
    return results

def generate_submission(model: Translator, test_path: Path, output_file="submission.csv", device=None):
    test_data = np.load(test_path)
    sample_ids = test_data['captions/ids']
    test_embds = test_data['captions/embeddings']
    test_embds = torch.from_numpy(test_embds).float()

    with torch.no_grad():
        pred_embds = model(test_embds.to(device)).cpu()

    print("Generating submission file...")

    if isinstance(pred_embds, torch.Tensor):
        pred_embds = pred_embds.cpu().numpy()

    df_submission = pd.DataFrame({'id': sample_ids, 'embedding': pred_embds.tolist()})

    df_submission.to_csv(output_file, index=False, float_format='%.17g')
    print(f"✓ Saved submission to {output_file}")

    return df_submission

### configs

### main.py

In [11]:
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import random_split
from tqdm import tqdm


def info_nce_loss(preds_norm, targets_norm, temp=0.07):
    logits = (preds_norm @ targets_norm.T) / temp
    labels = torch.arange(logits.size(0), device=logits.device)

    loss_t2i = F.cross_entropy(logits, labels)          
    loss_i2t = F.cross_entropy(logits.T, labels)        
    
    return 0.5 * (loss_t2i + loss_i2t)


def mse_loss(preds, targets):
    pred_norms = preds.norm(dim=-1)
    target_norms = targets.norm(dim=-1)
    
    return F.mse_loss(pred_norms, target_norms)


def combined_loss(preds: torch.Tensor, targets: torch.Tensor, temp: float, lamb: float = 1.0):
    preds_norm = F.normalize(preds, p=2, dim=1)
    targets_norm = F.normalize(targets, p=2, dim=1)

    l1 = info_nce_loss(preds_norm, targets_norm, temp)
    l2 = mse_loss(preds, targets)

    return l1 + lamb * l2


def train_model(
    model: Translator,
    model_path: Path,
    train_dataset: TensorDataset,
    val_dataset: TensorDataset,
    batch_size: int,
    epochs: int,
    lr: float,
    patience: int,
    temp: float,
    lambda_mag: float
) -> Translator:    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Using device: {device}")
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    best_val_loss = float('inf')
    no_improvements = 0

    for epoch in range(epochs):
        model.train()

        train_loss = 0
        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()

            outputs = model(X_batch)

            loss = combined_loss(outputs, y_batch, temp, lambda_mag)
            
            loss.backward()

            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        model.eval()

        val_loss = 0

        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                outputs = model(X_batch)
                
                loss = combined_loss(outputs, y_batch, temp, lambda_mag)
                
                val_loss += loss.item()

        val_loss /= len(val_loader)

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}")
        test(val_dataset, model, device)


        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improvements = 0

            Path(model_path).parent.mkdir(parents=True, exist_ok=True)

            torch.save(model.state_dict(), model_path)

            print(f"✓ Saved best model (val_loss={val_loss:.6f})")
        elif no_improvements >= patience:
            return model
        else:
            no_improvements += 1

    return model



def load_data(data_path: Path):
    data = np.load(data_path)
    caption_embeddings = data['captions/embeddings']
    image_embeddings = data['images/embeddings']
    caption_labels = data['captions/label']

    X_abs, y_abs = torch.tensor(caption_embeddings), torch.tensor(image_embeddings[np.argmax(caption_labels, axis=1)])
    
    print('Texts shape', X_abs.shape)
    print('Images shape', X_abs.shape)

    def print_stats():
        mean_X = X_abs.mean(dim=0)
        std_X = X_abs.std(dim=0)
        
        mean_Y = y_abs.mean(dim=0)
        std_Y = y_abs.std(dim=0)

        print("X: mean of stds per dim =", std_X.mean().item(), ", max =", std_X.max().item(), ", min =", std_X.min().item())
        print("Y: mean of stds per dim =", std_Y.mean().item(), ", max =", std_Y.max().item(), ", min =", std_Y.min().item())

    print_stats()
    
    dataset = TensorDataset(X_abs, y_abs)
    train_dataset, val_dataset = random_split(dataset, [0.9, 0.1], generator=torch.Generator().manual_seed(42))
    
    return train_dataset, val_dataset


def test(val_dataset: TensorDataset, model: Translator, device):
    val_loader = DataLoader(val_dataset, batch_size=len(val_dataset))
    for x_val, y_val in val_loader:
        results = eval_on_val(x_val, y_val, model=model, device=device)
    return results

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size= 2048
lr= 0.0005
epochs= 200
patience = 10
temp = 0.011284474643610163
lambda_mag = 0.7763296874424117
dropout_rate = 0.25

data_path= '/kaggle/input/aml-competition/train/train/train.npz'
test_path= '/kaggle/input/aml-competition/test/test/test.clean.npz'

model_save_path= './models/exp1.pth'

train_dataset, val_dataset = load_data(data_path)

Texts shape torch.Size([125000, 1024])
Images shape torch.Size([125000, 1024])
X: mean of stds per dim = 0.788078248500824 , max = 3.573546886444092 , min = 0.3716050386428833
Y: mean of stds per dim = 0.4244377911090851 , max = 1.8597956895828247 , min = 0.08161858469247818


In [16]:
model_args = {
    'input_dim': 1024,
    'output_dim': 1536,
    'dir_hidden_dims': [2048, 4096],
    'scale_hidden_dims': [1024, 512, 256],
    'activation': nn.GELU,
    'dropout_rate': dropout_rate
}
model = Translator(**model_args).to(device)

train_model(model, model_save_path, train_dataset, val_dataset, batch_size, epochs, lr, patience, temp, lambda_mag)

print('Finished training. Now testing using best model...')

state = torch.load(model_save_path)
model.load_state_dict(state)
results = test(val_dataset, model, device)

print("Test Results:", results)

Using device: cuda


Epoch 1/200: 100%|██████████| 55/55 [00:05<00:00, 10.75it/s]


Epoch 1: Train Loss = 63.403510, Val Loss = 4.756323
✓ Saved best model (val_loss=4.756323)


Epoch 2/200: 100%|██████████| 55/55 [00:05<00:00, 10.76it/s]


Epoch 2: Train Loss = 6.506368, Val Loss = 4.519658
✓ Saved best model (val_loss=4.519658)


Epoch 3/200: 100%|██████████| 55/55 [00:05<00:00, 10.90it/s]


Epoch 3: Train Loss = 6.042053, Val Loss = 4.303803
✓ Saved best model (val_loss=4.303803)


Epoch 4/200: 100%|██████████| 55/55 [00:05<00:00, 10.90it/s]


Epoch 4: Train Loss = 5.694650, Val Loss = 4.141370
✓ Saved best model (val_loss=4.141370)


Epoch 5/200: 100%|██████████| 55/55 [00:05<00:00, 10.83it/s]


Epoch 5: Train Loss = 5.365307, Val Loss = 3.938814
✓ Saved best model (val_loss=3.938814)


Epoch 6/200: 100%|██████████| 55/55 [00:05<00:00, 10.77it/s]


Epoch 6: Train Loss = 5.052535, Val Loss = 3.782082
✓ Saved best model (val_loss=3.782082)


Epoch 7/200: 100%|██████████| 55/55 [00:05<00:00, 10.74it/s]


Epoch 7: Train Loss = 4.808344, Val Loss = 3.658976
✓ Saved best model (val_loss=3.658976)


Epoch 8/200: 100%|██████████| 55/55 [00:05<00:00, 10.84it/s]


Epoch 8: Train Loss = 4.574966, Val Loss = 3.574367
✓ Saved best model (val_loss=3.574367)


Epoch 9/200: 100%|██████████| 55/55 [00:05<00:00, 10.78it/s]


Epoch 9: Train Loss = 4.367667, Val Loss = 3.506182
✓ Saved best model (val_loss=3.506182)


Epoch 10/200: 100%|██████████| 55/55 [00:05<00:00, 10.71it/s]


Epoch 10: Train Loss = 4.224765, Val Loss = 3.456625
✓ Saved best model (val_loss=3.456625)


Epoch 11/200: 100%|██████████| 55/55 [00:05<00:00, 10.83it/s]


Epoch 11: Train Loss = 4.062249, Val Loss = 3.424075
✓ Saved best model (val_loss=3.424075)


Epoch 12/200: 100%|██████████| 55/55 [00:05<00:00, 10.71it/s]


Epoch 12: Train Loss = 3.934872, Val Loss = 3.410171
✓ Saved best model (val_loss=3.410171)


Epoch 13/200: 100%|██████████| 55/55 [00:05<00:00, 10.37it/s]


Epoch 13: Train Loss = 3.832203, Val Loss = 3.400214
✓ Saved best model (val_loss=3.400214)


Epoch 14/200: 100%|██████████| 55/55 [00:05<00:00, 10.65it/s]


Epoch 14: Train Loss = 3.740494, Val Loss = 3.397689
✓ Saved best model (val_loss=3.397689)


Epoch 15/200: 100%|██████████| 55/55 [00:05<00:00, 10.41it/s]


Epoch 15: Train Loss = 3.633952, Val Loss = 3.399666


Epoch 16/200: 100%|██████████| 55/55 [00:05<00:00, 10.52it/s]


Epoch 16: Train Loss = 3.561708, Val Loss = 3.343435
✓ Saved best model (val_loss=3.343435)


Epoch 17/200: 100%|██████████| 55/55 [00:05<00:00, 10.51it/s]


Epoch 17: Train Loss = 3.504278, Val Loss = 3.348626


Epoch 18/200: 100%|██████████| 55/55 [00:05<00:00, 10.50it/s]


Epoch 18: Train Loss = 3.429238, Val Loss = 3.320605
✓ Saved best model (val_loss=3.320605)


Epoch 19/200: 100%|██████████| 55/55 [00:05<00:00, 10.51it/s]


Epoch 19: Train Loss = 3.399527, Val Loss = 3.335627


Epoch 20/200: 100%|██████████| 55/55 [00:05<00:00, 10.38it/s]


Epoch 20: Train Loss = 3.332678, Val Loss = 3.318508
✓ Saved best model (val_loss=3.318508)


Epoch 21/200: 100%|██████████| 55/55 [00:05<00:00, 10.78it/s]


Epoch 21: Train Loss = 3.308346, Val Loss = 3.321895


Epoch 22/200: 100%|██████████| 55/55 [00:05<00:00, 10.44it/s]


Epoch 22: Train Loss = 3.267964, Val Loss = 3.318960


Epoch 23/200: 100%|██████████| 55/55 [00:05<00:00, 10.45it/s]


Epoch 23: Train Loss = 3.220179, Val Loss = 3.337986


Epoch 24/200: 100%|██████████| 55/55 [00:05<00:00, 10.34it/s]


Epoch 24: Train Loss = 3.212539, Val Loss = 3.301433
✓ Saved best model (val_loss=3.301433)


Epoch 25/200: 100%|██████████| 55/55 [00:05<00:00, 10.30it/s]


Epoch 25: Train Loss = 3.158418, Val Loss = 3.290306
✓ Saved best model (val_loss=3.290306)


Epoch 26/200: 100%|██████████| 55/55 [00:05<00:00, 10.77it/s]


Epoch 26: Train Loss = 3.138638, Val Loss = 3.332640


Epoch 27/200: 100%|██████████| 55/55 [00:05<00:00, 10.33it/s]


Epoch 27: Train Loss = 3.116332, Val Loss = 3.311292


Epoch 28/200: 100%|██████████| 55/55 [00:05<00:00, 10.43it/s]


Epoch 28: Train Loss = 3.089696, Val Loss = 3.294357


Epoch 29/200: 100%|██████████| 55/55 [00:05<00:00, 10.46it/s]


Epoch 29: Train Loss = 3.081923, Val Loss = 3.299584


Epoch 30/200: 100%|██████████| 55/55 [00:05<00:00, 10.44it/s]


Epoch 30: Train Loss = 3.043829, Val Loss = 3.294397


Epoch 31/200: 100%|██████████| 55/55 [00:05<00:00, 10.47it/s]


Epoch 31: Train Loss = 3.010462, Val Loss = 3.300718


Epoch 32/200: 100%|██████████| 55/55 [00:05<00:00, 10.46it/s]


Epoch 32: Train Loss = 2.981384, Val Loss = 3.292304


Epoch 33/200: 100%|██████████| 55/55 [00:05<00:00, 10.77it/s]


Epoch 33: Train Loss = 3.009797, Val Loss = 3.285048
✓ Saved best model (val_loss=3.285048)


Epoch 34/200: 100%|██████████| 55/55 [00:05<00:00, 10.46it/s]


Epoch 34: Train Loss = 2.971409, Val Loss = 3.318648


Epoch 35/200: 100%|██████████| 55/55 [00:05<00:00, 10.44it/s]


Epoch 35: Train Loss = 2.959987, Val Loss = 3.296684


Epoch 36/200: 100%|██████████| 55/55 [00:05<00:00, 10.51it/s]


Epoch 36: Train Loss = 2.937953, Val Loss = 3.294926


Epoch 37/200: 100%|██████████| 55/55 [00:05<00:00, 10.45it/s]


Epoch 37: Train Loss = 2.913658, Val Loss = 3.288248


Epoch 38/200: 100%|██████████| 55/55 [00:05<00:00, 10.59it/s]


Epoch 38: Train Loss = 2.909888, Val Loss = 3.288682


Epoch 39/200: 100%|██████████| 55/55 [00:05<00:00, 10.34it/s]


Epoch 39: Train Loss = 2.895155, Val Loss = 3.307273


Epoch 40/200: 100%|██████████| 55/55 [00:05<00:00, 10.48it/s]


Epoch 40: Train Loss = 2.877329, Val Loss = 3.287952


Epoch 41/200: 100%|██████████| 55/55 [00:05<00:00, 10.47it/s]


Epoch 41: Train Loss = 2.858980, Val Loss = 3.281878
✓ Saved best model (val_loss=3.281878)


Epoch 42/200: 100%|██████████| 55/55 [00:05<00:00, 10.44it/s]


Epoch 42: Train Loss = 2.844573, Val Loss = 3.267917
✓ Saved best model (val_loss=3.267917)


Epoch 43/200: 100%|██████████| 55/55 [00:05<00:00, 10.29it/s]


Epoch 43: Train Loss = 2.843100, Val Loss = 3.256049
✓ Saved best model (val_loss=3.256049)


Epoch 44/200: 100%|██████████| 55/55 [00:05<00:00, 10.30it/s]


Epoch 44: Train Loss = 2.805466, Val Loss = 3.260415


Epoch 45/200: 100%|██████████| 55/55 [00:05<00:00, 10.75it/s]


Epoch 45: Train Loss = 2.813081, Val Loss = 3.269445


Epoch 46/200: 100%|██████████| 55/55 [00:05<00:00, 10.42it/s]


Epoch 46: Train Loss = 2.765885, Val Loss = 3.286369


Epoch 47/200: 100%|██████████| 55/55 [00:05<00:00, 10.32it/s]


Epoch 47: Train Loss = 2.761727, Val Loss = 3.257748


Epoch 48/200: 100%|██████████| 55/55 [00:05<00:00, 10.31it/s]


Epoch 48: Train Loss = 2.739805, Val Loss = 3.250500
✓ Saved best model (val_loss=3.250500)


Epoch 49/200: 100%|██████████| 55/55 [00:05<00:00, 10.26it/s]


Epoch 49: Train Loss = 2.742031, Val Loss = 3.294099


Epoch 50/200: 100%|██████████| 55/55 [00:05<00:00, 10.76it/s]


Epoch 50: Train Loss = 2.724362, Val Loss = 3.265693


Epoch 51/200: 100%|██████████| 55/55 [00:05<00:00, 10.42it/s]


Epoch 51: Train Loss = 2.717613, Val Loss = 3.277410


Epoch 52/200: 100%|██████████| 55/55 [00:05<00:00, 10.30it/s]


Epoch 52: Train Loss = 2.686120, Val Loss = 3.270539


Epoch 53/200: 100%|██████████| 55/55 [00:05<00:00, 10.27it/s]


Epoch 53: Train Loss = 2.673258, Val Loss = 3.285613


Epoch 54/200: 100%|██████████| 55/55 [00:05<00:00, 10.39it/s]


Epoch 54: Train Loss = 2.641142, Val Loss = 3.284139


Epoch 55/200: 100%|██████████| 55/55 [00:05<00:00, 10.30it/s]


Epoch 55: Train Loss = 2.644398, Val Loss = 3.263867


Epoch 56/200: 100%|██████████| 55/55 [00:05<00:00, 10.28it/s]


Epoch 56: Train Loss = 2.637483, Val Loss = 3.273742


Epoch 57/200: 100%|██████████| 55/55 [00:05<00:00, 10.64it/s]


Epoch 57: Train Loss = 2.597506, Val Loss = 3.273258


Epoch 58/200: 100%|██████████| 55/55 [00:05<00:00, 10.29it/s]


Epoch 58: Train Loss = 2.603845, Val Loss = 3.272361


Epoch 59/200: 100%|██████████| 55/55 [00:05<00:00, 10.32it/s]


Epoch 59: Train Loss = 2.589099, Val Loss = 3.255983
Finished training. Now testing using best model...
Test Results: {'mrr': 0.8955046637625308, 'ndcg': 0.9201528960664349, 'recall_at_1': 0.84144, 'recall_at_3': 0.94104, 'recall_at_5': 0.962, 'recall_at_10': 0.97936, 'recall_at_50': 0.99808, 'l2_dist': 30.50058937072754}


In [6]:
generate_submission(model, Path(test_path), device=device)

Generating submission file...
✓ Saved submission to submission.csv


Unnamed: 0,id,embedding
0,1,"[-0.13813145458698273, -0.5371889472007751, 0...."
1,2,"[-0.6189366579055786, -0.3590053915977478, 0.2..."
2,3,"[-0.7931116819381714, -0.1779468059539795, 0.4..."
3,4,"[0.45219361782073975, -1.561142086982727, -0.6..."
4,5,"[1.2697274684906006, 1.481618046760559, -0.018..."
...,...,...
1495,1496,"[0.2991337180137634, -0.4530474543571472, 0.47..."
1496,1497,"[-0.0717444121837616, 0.21529829502105713, -0...."
1497,1498,"[0.2169097661972046, -1.007057547569275, -0.54..."
1498,1499,"[0.27533355355262756, -0.3522982895374298, 0.1..."


In [28]:
from torchsummary import summary
summary(model, input_size=(1024,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 2048]       2,099,200
              SiLU-2                 [-1, 2048]               0
         LayerNorm-3                 [-1, 2048]           4,096
           Dropout-4                 [-1, 2048]               0
            Linear-5                 [-1, 4096]       8,392,704
              SiLU-6                 [-1, 4096]               0
         LayerNorm-7                 [-1, 4096]           8,192
           Dropout-8                 [-1, 4096]               0
            Linear-9                 [-1, 1536]       6,292,992
           Linear-10                  [-1, 512]         524,800
             SiLU-11                  [-1, 512]               0
        LayerNorm-12                  [-1, 512]           1,024
          Dropout-13                  [-1, 512]               0
           Linear-14                  [

In [30]:
import optuna
from optuna.pruners import MedianPruner
from pathlib import Path

ACTIVATIONS = {
    #"relu": nn.ReLU,
    "gelu": nn.GELU,
    "silu": nn.SiLU,
    'selu': nn.SELU,
    'celu': nn.CELU
    #"leakyrelu": nn.LeakyReLU
}

def objective(trial, train_dataset, val_dataset, input_dim, output_dim, epochs: int = 10, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Hyperparametri ---
    dir_n_layers = trial.suggest_int("dir_n_layers", 1, 4)
    dir_hidden_dims_choices = [1024, 1536, 2048, 4096]
    dir_hidden_dims = [trial.suggest_categorical(f"dir_l{i}_units", dir_hidden_dims_choices) for i in range(dir_n_layers)]

    scale_n_layers = trial.suggest_int("scale_n_layers", 1, 4)
    scale_hidden_dims_choices = [128, 256, 512, 1024, 1536]
    scale_hidden_dims = [trial.suggest_categorical(f"scale_l{i}_units", scale_hidden_dims_choices) for i in range(scale_n_layers)]

    activation_name = trial.suggest_categorical("activation", list(ACTIVATIONS.keys()))
    activation_fn = ACTIVATIONS[activation_name]

    batch_size = trial.suggest_categorical("batch_size", [2048, 4096])
    lr = trial.suggest_float("lr", 1e-6, 1e-3, log=True)
    dropout_rate = trial.suggest_categorical('dropout_rate', [0.2, 0.25, 0.3])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    temp = trial.suggest_float("temp", 0.01, 0.2, log=True)
    
    lambda_mag = trial.suggest_float("lambda_mag", 0.2, 1.3)

    # --- Modello ---
    model_args = {
        'input_dim': input_dim,
        'output_dim': output_dim,
        'dir_hidden_dims': dir_hidden_dims,
        'scale_hidden_dims': scale_hidden_dims,
        'activation': activation_fn,
        'dropout_rate': dropout_rate
    }
    model = Translator(**model_args).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            
            outputs = model(X_batch)
            
            loss = combined_loss(outputs, y_batch, temp, lambda_mag)
            loss.backward()
            
            optimizer.step()
            train_loss += loss.item()
        
        train_loss /= len(train_loader)

        # --- Validation ---
        model.eval()
        
        val_loss = 0.0
        
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        
                outputs = model(X_batch)
                val_loss += combined_loss(outputs, y_batch, temp, lambda_mag).item()
        
        val_loss /= len(val_loader)

        results = test(val_dataset, model, device)
        trial.report(results['recall_at_1'], epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return results['recall_at_1']


def run_optuna_search(data_path: Path, n_trials: int = 30, epochs: int = 30, n_jobs: int = 1, sampler=None, pruner=None):
    if pruner is None:
        pruner = MedianPruner(n_startup_trials=5, n_warmup_steps=1)

    train_dataset, val_dataset = load_data(data_path)
    input_dim = train_dataset[0][0].shape[0]
    output_dim = train_dataset[0][1].shape[0]

    study = optuna.create_study(direction="maximize", pruner=pruner)
    func = lambda trial: objective(trial, train_dataset=train_dataset, val_dataset=val_dataset,
                                   input_dim=input_dim, output_dim=output_dim,
                                   epochs=epochs)
    study.optimize(func, n_trials=n_trials, n_jobs=n_jobs)

    print("Study statistics:")
    print("  Number of finished trials: ", len(study.trials))
    print("  Best trial:")
    trial = study.best_trial
    print("    Value: ", trial.value)
    print("    Params: ")
    for k, v in trial.params.items():
        print(f"      {k}: {v}")

    return study


In [31]:
study = run_optuna_search(data_path=data_path, n_trials=100, epochs=10, n_jobs=1)
study.trials_dataframe().to_csv("optuna_trials.csv", index=False)

best_trial_number = study.best_trial.number
print("Best params:", study.best_params)
print("Best trial number:", study.best_trial.number)

Texts shape torch.Size([125000, 1024])
Images shape torch.Size([125000, 1024])


[I 2025-11-02 21:32:45,301] A new study created in memory with name: no-name-dc48b929-0cc8-4de7-a056-aa3a39b6e3dd


X: mean of stds per dim = 0.788078248500824 , max = 3.573546886444092 , min = 0.3716050386428833
Y: mean of stds per dim = 0.4244377911090851 , max = 1.8597956895828247 , min = 0.08161858469247818


[I 2025-11-02 21:34:29,714] Trial 0 finished with value: 0.86928 and parameters: {'dir_n_layers': 4, 'dir_l0_units': 1024, 'dir_l1_units': 2048, 'dir_l2_units': 4096, 'dir_l3_units': 2048, 'scale_n_layers': 3, 'scale_l0_units': 256, 'scale_l1_units': 1024, 'scale_l2_units': 1024, 'activation': 'gelu', 'batch_size': 4096, 'lr': 0.0006374960046354015, 'dropout_rate': 0.3, 'temp': 0.010505979831525658, 'lambda_mag': 0.8413356793088904}. Best is trial 0 with value: 0.86928.
[I 2025-11-02 21:35:32,120] Trial 1 finished with value: 0.188 and parameters: {'dir_n_layers': 2, 'dir_l0_units': 1024, 'dir_l1_units': 2048, 'scale_n_layers': 3, 'scale_l0_units': 128, 'scale_l1_units': 512, 'scale_l2_units': 512, 'activation': 'silu', 'batch_size': 4096, 'lr': 2.030535103415105e-06, 'dropout_rate': 0.3, 'temp': 0.07212758127415753, 'lambda_mag': 1.2920576865786182}. Best is trial 0 with value: 0.86928.
[I 2025-11-02 21:37:23,009] Trial 2 finished with value: 0.3492 and parameters: {'dir_n_layers': 3,

Study statistics:
  Number of finished trials:  100
  Best trial:
    Value:  0.87928
    Params: 
      dir_n_layers: 4
      dir_l0_units: 2048
      dir_l1_units: 2048
      dir_l2_units: 1536
      dir_l3_units: 4096
      scale_n_layers: 2
      scale_l0_units: 128
      scale_l1_units: 256
      activation: gelu
      batch_size: 2048
      lr: 0.0004944305652643576
      dropout_rate: 0.25
      temp: 0.011284474643610163
      lambda_mag: 0.7763296874424117


AttributeError: module 'optuna.study.study' has no attribute 'Storage'

In [None]:
print(1)