In [None]:
# Colab-only setup: mount Google Drive at /content/drive so that the CSV paths
# configured in Config (which live under /content/drive/MyDrive/...) resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# github_recsys_twotower_with_robust_eval.py

import os
import math
import random
from typing import Dict

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# Optional: for AUC metric
try:
    from sklearn.metrics import roc_auc_score
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False


########################################
# 1. CONFIG AND FEATURE DEFINITIONS
########################################

class Config:
    """Paths and hyperparameters for the two-tower pipeline, held as class attributes.

    main() instantiates this class and reads the attributes directly; nothing
    here is validated or loaded from an external file.
    """

    # File paths: adapt these to your actual locations
    train_balanced_path = "/content/drive/MyDrive/Project_Work/RepoRecSys/data/train_balanced.csv"
    train_negative_path = "/content/drive/MyDrive/Project_Work/RepoRecSys/data/train_negative.csv"
    test_balanced_path = "/content/drive/MyDrive/Project_Work/RepoRecSys/data/test_balanced.csv"
    test_negative_path = "/content/drive/MyDrive/Project_Work/RepoRecSys/data/test_negative.csv"

    # Model hyperparameters
    user_id_emb_dim = 64   # width of the per-user ID embedding (UserTower)
    item_id_emb_dim = 64   # width of the per-item ID embedding (ItemTower)
    lang_emb_dim = 16      # width of the language-code embedding (ItemTower)
    hidden_dim = 128       # hidden width of each tower's MLP
    embedding_dim = 64  # final tower output dimension

    batch_size = 4096
    # NOTE(review): num_epochs is not read anywhere in the visible code; main()
    # hard-codes its own Stage A / Stage B epoch counts.
    num_epochs = 10
    lr = 1e-3              # Stage A learning rate; main() uses lr * 0.1 for Stage B
    weight_decay = 1e-5    # passed to torch.optim.Adam in both stages

    # Contrastive / pointwise loss config
    temperature = 0.1  # Scale factor for cosine similarity logits
    contrastive_weight = 0.7  # weight for contrastive (InfoNCE) loss
    bce_weight = 0.3         # weight for pointwise BCE loss
    max_contrastive_negatives = 256  # max sampled explicit negatives per batch for InfoNCE

    # Validation split
    train_user_fraction = 0.8  # 80% users for train, 20% for validation

    # Ranking evaluation config
    eval_k_list = [5, 10]      # cutoffs K for Recall@K / NDCG@K
    eval_num_negatives = 100  # per user for ranking eval

    # CUDA if available (evaluated once, when the class body is executed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Random seeds for reproducibility
    seed = 42


# Column names expected in the input CSVs.

# Numeric columns. They are z-scored (compute_numeric_scalers /
# normalize_numeric_features) and fed to the item tower as dense features, in
# exactly this order.
# NOTE(review): "weight", "cp", "avg_cp" and "stddev" are treated as repo-level
# features here (one row per project is kept for ranking evaluation) — confirm
# they are not per user–repo interaction values.
NUMERIC_REPO_COLS = [
    "watchers", "commits", "issues", "pull_requests",
    "mean_commits_language", "max_commits_language", "min_commits_language",
    "std_commits_language",
    "mean_pull_requests_language", "max_pull_requests_language",
    "min_pull_requests_language", "std_pull_requests_language",
    "mean_issues_language", "max_issues_language", "min_issues_language",
    "std_issues_language",
    "mean_watchers_language", "max_watchers_language",
    "min_watchers_language", "std_watchers_language",
    "events", "year",
    "weight", "cp", "avg_cp", "stddev",
]

# Categorical repo columns.
# NOTE(review): this list is not referenced in the visible code, which uses the
# literal "language_code" directly wherever the column is needed.
CATEGORICAL_REPO_COLS = [
    "language_code",
]

# Raw user identifier column.
USER_ID_COL = "id_user"
# Raw repository (item) identifier column.
ITEM_ID_COL = "project_id"
# Label column; the code treats 1 as a positive interaction and 0 as a negative.
TARGET_COL = "target"

In [None]:
########################################
# 2. DATA LOADING AND PREPROCESSING
########################################

def set_seed(seed: int):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def load_raw_data(cfg: Config):
    """Read the four input CSVs named in `cfg`.

    Args:
        cfg: configuration object providing the four `*_path` attributes.

    Returns:
        Tuple `(train_bal, train_neg, test_bal, test_neg)` of DataFrames, in
        that order.
    """
    csv_paths = (
        cfg.train_balanced_path,
        cfg.train_negative_path,
        cfg.test_balanced_path,
        cfg.test_negative_path,
    )
    train_bal, train_neg, test_bal, test_neg = (pd.read_csv(path) for path in csv_paths)
    return train_bal, train_neg, test_bal, test_neg


def build_id_mappings(train_pool: pd.DataFrame, test_df: pd.DataFrame):
    """Assign contiguous integer indices to every user, item and language code.

    The vocabularies cover the union of `train_pool` and `test_df`; indices
    follow first-appearance order (train rows first, then test rows).

    Returns:
        `(user2idx, item2idx, lang2idx, idx2user, idx2item, idx2lang)` where
        each `x2idx` maps a raw value to its index and each `idx2x` is the
        inverse mapping.
    """

    def _vocabulary(column):
        # Distinct raw values of `column` across both frames.
        distinct = pd.concat([train_pool[column], test_df[column]]).unique()
        forward = dict(zip(distinct, range(len(distinct))))
        backward = {position: raw for raw, position in forward.items()}
        return forward, backward

    user2idx, idx2user = _vocabulary(USER_ID_COL)
    item2idx, idx2item = _vocabulary(ITEM_ID_COL)
    lang2idx, idx2lang = _vocabulary("language_code")

    return user2idx, item2idx, lang2idx, idx2user, idx2item, idx2lang


def compute_numeric_scalers(train_pool: pd.DataFrame):
    """Compute per-column standardisation statistics for NUMERIC_REPO_COLS.

    Args:
        train_pool: frame containing every column listed in NUMERIC_REPO_COLS.

    Returns:
        `(means, stds)`: two dicts keyed by column name. A standard deviation
        below 1e-6 is replaced by 1.0 so that dividing by it is safe for
        (near-)constant columns.
    """
    means, stds = {}, {}
    for name in NUMERIC_REPO_COLS:
        values = train_pool[name].astype(float).values
        spread = values.std()
        means[name] = values.mean()
        stds[name] = 1.0 if spread < 1e-6 else spread
    return means, stds


def normalize_numeric_features(df: pd.DataFrame, means: Dict[str, float], stds: Dict[str, float]):
    """Return a copy of `df` with every NUMERIC_REPO_COLS column z-scored.

    Args:
        df: input frame; it is not modified.
        means: column name -> mean to subtract.
        stds: column name -> standard deviation to divide by.

    Returns:
        A new DataFrame in which each numeric column is `(x - mean) / std`.
    """
    scaled = df.copy()
    for name in NUMERIC_REPO_COLS:
        centered = scaled[name].astype(float) - means[name]
        scaled[name] = centered / stds[name]
    return scaled


def split_train_val_by_user(
    train_pool: pd.DataFrame,
    train_user_fraction: float,
    seed: int,
):
    """Partition `train_pool` into train and validation frames at user level.

    Distinct users are shuffled with a `RandomState(seed)`; the first
    `int(n_users * train_user_fraction)` go to train, the remainder to
    validation, so no user has rows in both frames.

    Returns:
        `(train_df, val_df)`, each with a fresh 0..n-1 index.
    """
    users = train_pool[USER_ID_COL].unique()
    np.random.RandomState(seed).shuffle(users)

    cut = int(len(users) * train_user_fraction)
    train_users = set(users[:cut])
    val_users = set(users[cut:])

    user_column = train_pool[USER_ID_COL]
    train_rows = user_column.isin(train_users)
    val_rows = user_column.isin(val_users)

    return (
        train_pool[train_rows].reset_index(drop=True),
        train_pool[val_rows].reset_index(drop=True),
    )


def build_item_feature_table_norm(
    train_pool: pd.DataFrame,
    test_df: pd.DataFrame,
    means: Dict[str, float],
    stds: Dict[str, float],
):
    """Build a normalized feature table with exactly one row per project_id.

    Rows from `train_pool` and `test_df` are stacked, z-scored with the given
    statistics and sorted by the (normalized) "events" column; for each item
    the last row in that order is the one kept.

    Returns:
        DataFrame with a fresh 0..n-1 index and unique ITEM_ID_COL values.
    """
    combined = pd.concat([train_pool, test_df], ignore_index=True)
    return (
        normalize_numeric_features(combined, means, stds)
        .sort_values("events")
        .drop_duplicates(subset=[ITEM_ID_COL], keep="last")
        .reset_index(drop=True)
    )

In [None]:
########################################
# 3. DATASET AND DATALOADER
########################################

class TwoTowerDataset(Dataset):
    """Row-wise view of an interaction DataFrame for the two-tower model.

    Each item is the tuple `(user_idx, item_idx, lang_idx, numerics, label)`:
    three scalar long tensors, one float32 vector holding the
    NUMERIC_REPO_COLS values, and a scalar float32 label.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        user2idx: Dict[int, int],
        item2idx: Dict[int, int],
        lang2idx: Dict[int, int],
    ):
        frame = df.reset_index(drop=True)
        self.df = frame
        self.user2idx = user2idx
        self.item2idx = item2idx
        self.lang2idx = lang2idx

        # Column values are pulled out as numpy arrays once, so __getitem__
        # never has to index into the DataFrame.
        self.user_ids = frame[USER_ID_COL].values
        self.item_ids = frame[ITEM_ID_COL].values
        self.lang_codes = frame["language_code"].values
        self.labels = frame[TARGET_COL].astype(float).values

        self.numeric_matrix = frame[NUMERIC_REPO_COLS].astype(float).values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        # Raw ids are translated to embedding indices; an id missing from a
        # mapping raises KeyError.
        user_index = self.user2idx[self.user_ids[idx]]
        item_index = self.item2idx[self.item_ids[idx]]
        lang_index = self.lang2idx[self.lang_codes[idx]]

        return (
            torch.tensor(user_index, dtype=torch.long),
            torch.tensor(item_index, dtype=torch.long),
            torch.tensor(lang_index, dtype=torch.long),
            torch.tensor(self.numeric_matrix[idx], dtype=torch.float32),
            torch.tensor(self.labels[idx], dtype=torch.float32),
        )


def make_dataloaders(
    cfg: Config,
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    user2idx: Dict[int, int],
    item2idx: Dict[int, int],
    lang2idx: Dict[int, int],
):
    """Wrap the train and validation frames in DataLoaders.

    Both loaders use `cfg.batch_size`, two worker processes and pinned memory;
    only the training loader shuffles.

    Returns:
        `(train_loader, val_loader)`.
    """
    mappings = (user2idx, item2idx, lang2idx)
    train_dataset = TwoTowerDataset(train_df, *mappings)
    val_dataset = TwoTowerDataset(val_df, *mappings)

    # Settings shared by both loaders.
    shared = dict(batch_size=cfg.batch_size, num_workers=2, pin_memory=True)

    train_loader = DataLoader(train_dataset, shuffle=True, **shared)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared)

    return train_loader, val_loader

In [None]:
########################################
# 4. MODEL DEFINITIONS
########################################

class UserTower(nn.Module):
    """User side of the model: ID embedding -> 2-layer MLP -> L2-normalized vector."""

    def __init__(self, num_users: int, cfg: Config):
        super().__init__()
        self.user_emb = nn.Embedding(num_users, cfg.user_id_emb_dim)

        layers = [
            nn.Linear(cfg.user_id_emb_dim, cfg.hidden_dim),
            nn.ReLU(),
            nn.Linear(cfg.hidden_dim, cfg.embedding_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, user_ids: torch.Tensor) -> torch.Tensor:
        """
        user_ids: [B] long tensor of user indices
        returns: [B, embedding_dim], each row scaled to unit L2 norm
        """
        hidden = self.mlp(self.user_emb(user_ids))
        return nn.functional.normalize(hidden, p=2, dim=-1)


class ItemTower(nn.Module):
    """Item side of the model.

    Concatenates an item-ID embedding, a language embedding and the dense
    numeric features, then applies a 2-layer MLP and L2-normalizes the result.
    """

    def __init__(self, num_items: int, num_langs: int, num_numeric_feats: int, cfg: Config):
        super().__init__()
        self.item_emb = nn.Embedding(num_items, cfg.item_id_emb_dim)
        self.lang_emb = nn.Embedding(num_langs, cfg.lang_emb_dim)

        # Width of the concatenated MLP input.
        concat_width = cfg.item_id_emb_dim + cfg.lang_emb_dim + num_numeric_feats

        layers = [
            nn.Linear(concat_width, cfg.hidden_dim),
            nn.ReLU(),
            nn.Linear(cfg.hidden_dim, cfg.embedding_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(
        self,
        item_ids: torch.Tensor,     # [B]
        lang_ids: torch.Tensor,     # [B]
        numeric_feats: torch.Tensor # [B, num_numeric_feats]
    ) -> torch.Tensor:
        """
        returns: [B, embedding_dim], each row scaled to unit L2 norm
        """
        features = torch.cat(
            (self.item_emb(item_ids), self.lang_emb(lang_ids), numeric_feats),
            dim=-1,
        )  # [B, item_id_emb_dim + lang_emb_dim + num_numeric_feats]
        return nn.functional.normalize(self.mlp(features), p=2, dim=-1)


class TwoTowerRecSys(nn.Module):
    """Two-tower recommender: scores a (user, item) pair by embedding dot product.

    Both towers emit unit-norm vectors, so the returned logits are cosine
    similarities.
    """

    def __init__(self, num_users: int, num_items: int, num_langs: int, num_numeric_feats: int, cfg: Config):
        super().__init__()
        self.user_tower = UserTower(num_users, cfg)
        self.item_tower = ItemTower(num_items, num_langs, num_numeric_feats, cfg)

    def forward(
        self,
        user_ids: torch.Tensor,
        item_ids: torch.Tensor,
        lang_ids: torch.Tensor,
        numeric_feats: torch.Tensor,
    ):
        """
        Inputs:
            user_ids: [B]
            item_ids: [B]
            lang_ids: [B]
            numeric_feats: [B, F]

        Returns a 3-tuple:
            logits: [B] row-wise dot products of the two embeddings
            user_embs: [B, D]
            item_embs: [B, D]
        """
        user_embs = self.user_tower(user_ids)
        item_embs = self.item_tower(item_ids, lang_ids, numeric_feats)

        logits = (user_embs * item_embs).sum(dim=-1)
        return logits, user_embs, item_embs

In [None]:
def train_one_epoch_contrastive(
    model,
    train_loader,
    optimizer,
    device,
    temperature: float = 0.1,
    mask_false_negatives: bool = True,
):
    """Stage A: one epoch of in-batch softmax (InfoNCE) training on positives.

    For every batch only the rows with label > 0.5 are kept. Each kept user is
    trained to score its own item above the items of the other kept rows.

    Args:
        model: callable returning `(logits, user_embs, item_embs)`.
        train_loader: iterable of `(user_ids, item_ids, lang_ids, numerics,
            labels)` batches.
        optimizer: optimizer over the model's parameters.
        device: device the tensors are moved to.
        temperature: similarities are divided by this value before softmax.
        mask_false_negatives: when True (default), a column is removed from a
            row's softmax denominator if it belongs to the same user (another
            known positive of that user) or carries the same item id as the
            row's own target. Without this, the loss pushes users away from
            items they are known to like. Pass False for the unmasked loss.

    Returns:
        Mean loss per positive example over the epoch (0.0 if every batch was
        skipped because it contained fewer than two positives).
    """
    model.train()
    total_loss = 0.0
    total_examples = 0
    ce_criterion = nn.CrossEntropyLoss()

    for user_ids, item_ids, lang_ids, numerics, labels in train_loader:
        # Use only positives for contrastive; at least two rows are needed to
        # have any in-batch negative at all.
        pos_mask = labels > 0.5
        if pos_mask.sum() < 2:
            continue

        user_ids = user_ids[pos_mask].to(device)
        item_ids = item_ids[pos_mask].to(device)
        lang_ids = lang_ids[pos_mask].to(device)
        numerics = numerics[pos_mask].to(device)

        optimizer.zero_grad()

        _, u_emb, v_emb = model(user_ids, item_ids, lang_ids, numerics)

        # In-batch negatives: [B, B]; entry (i, j) scores user i against item j.
        logits = torch.matmul(u_emb, v_emb.T) / temperature

        batch_size = user_ids.size(0)

        if mask_false_negatives:
            same_user = user_ids.unsqueeze(1) == user_ids.unsqueeze(0)
            same_item = item_ids.unsqueeze(1) == item_ids.unsqueeze(0)
            # The diagonal holds each row's target and must stay finite, which
            # also guarantees no row is entirely masked.
            off_diagonal = ~torch.eye(batch_size, dtype=torch.bool, device=logits.device)
            false_negative = (same_user | same_item) & off_diagonal
            logits = logits.masked_fill(false_negative, float("-inf"))

        # Row i's positive sits in column i.
        targets = torch.arange(batch_size, device=device, dtype=torch.long)

        loss = ce_criterion(logits, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch_size
        total_examples += batch_size

    return total_loss / max(total_examples, 1)


def train_one_epoch_mixed(
    model: TwoTowerRecSys,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: str,
    temperature: float = 0.1,
    contrastive_weight: float = 0.7,
    bce_weight: float = 0.3,
    max_contrastive_negatives: int = 256,
    mask_false_negatives: bool = True,
):
    """
    Stage B: mixed objective for finetuning.

    - BCE on ALL examples (positives + explicit negatives).
    - Contrastive InfoNCE on POSITIVE examples, with additional explicit
      negatives sampled from the batch and added to the denominator.

    Args:
        model: returns `(logits, user_embs, item_embs)` for a batch.
        train_loader: yields `(user_ids, item_ids, lang_ids, numerics, labels)`.
        optimizer: optimizer over the model's parameters.
        device: device the batch tensors are moved to.
        temperature: divisor applied to the contrastive similarities.
        contrastive_weight: weight of the InfoNCE term; the term is skipped
            when this is 0 or the batch has fewer than two positives.
        bce_weight: weight of the pointwise BCE term.
        max_contrastive_negatives: cap on explicit negatives sampled per batch
            for the InfoNCE denominator; 0 disables them.
        mask_false_negatives: when True (default), candidate columns that are
            not real negatives for a row are removed from its denominator:
            positives of the same user elsewhere in the batch, and any
            candidate (positive or sampled negative) with the same item id as
            the row's own target. Pass False for the unmasked loss.

    Returns:
        Mean combined loss per example over the epoch.

    NOTE(review): the BCE term is applied to the model's raw logits. With the
    towers in this file those are dot products of unit-norm vectors, i.e.
    values in [-1, 1], so sigmoid(logit) cannot leave roughly [0.27, 0.73].
    """
    model.train()
    total_loss = 0.0
    total_examples = 0

    ce_criterion = nn.CrossEntropyLoss()
    bce_criterion = nn.BCEWithLogitsLoss()

    for user_ids, item_ids, lang_ids, numerics, labels in train_loader:
        user_ids = user_ids.to(device)
        item_ids = item_ids.to(device)
        lang_ids = lang_ids.to(device)
        numerics = numerics.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        logits, u_emb, v_emb = model(user_ids, item_ids, lang_ids, numerics)

        # 1) Pointwise BCE on all examples
        bce_loss = bce_criterion(logits, labels)

        # 2) Contrastive on positives only
        pos_mask = labels > 0.5
        num_pos = int(pos_mask.sum().item())

        if num_pos > 1 and contrastive_weight > 0.0:
            u_pos = u_emb[pos_mask]            # [P, D]
            pos_user_ids = user_ids[pos_mask]  # [P]
            pos_item_ids = item_ids[pos_mask]  # [P]

            # Candidate items start as the positives themselves (in-batch
            # negatives); sampled explicit negatives are appended below.
            cand_embs = v_emb[pos_mask]        # [P, D]
            cand_item_ids = pos_item_ids       # [P]

            # All explicit negatives in the batch (label 0)
            neg_mask = labels < 0.5
            v_neg_all = v_emb[neg_mask]        # [N, D]

            if v_neg_all.size(0) > 0 and max_contrastive_negatives > 0:
                num_negs = min(max_contrastive_negatives, v_neg_all.size(0))
                perm = torch.randperm(v_neg_all.size(0), device=device)[:num_negs]
                cand_embs = torch.cat([cand_embs, v_neg_all[perm]], dim=0)
                cand_item_ids = torch.cat(
                    [cand_item_ids, item_ids[neg_mask][perm]], dim=0
                )  # [P + num_negs]

            # Similarity between positive users and all candidate items
            ce_logits = torch.matmul(u_pos, cand_embs.T) / temperature  # [P, C]

            # Each user i's positive is at index i
            targets = torch.arange(num_pos, device=device, dtype=torch.long)

            if mask_false_negatives:
                # Same item id as the row's target, anywhere among candidates.
                false_negative = pos_item_ids.unsqueeze(1) == cand_item_ids.unsqueeze(0)
                # Another known positive of the same user (positive block only).
                same_user = pos_user_ids.unsqueeze(1) == pos_user_ids.unsqueeze(0)
                false_negative[:, :num_pos] = false_negative[:, :num_pos] | same_user
                # Never mask a row's own target, so every row keeps a finite logit.
                false_negative[targets, targets] = False
                ce_logits = ce_logits.masked_fill(false_negative, float("-inf"))

            contrastive_loss = ce_criterion(ce_logits, targets)
        else:
            contrastive_loss = torch.tensor(0.0, device=device)

        # Combine losses
        loss = contrastive_weight * contrastive_loss + bce_weight * bce_loss

        loss.backward()
        optimizer.step()

        batch_size = user_ids.size(0)
        total_loss += loss.item() * batch_size
        total_examples += batch_size

    avg_loss = total_loss / max(total_examples, 1)
    return avg_loss




def evaluate_pointwise(
    model: TwoTowerRecSys,
    data_loader: DataLoader,
    device: str,
):
    """
    Simple pointwise evaluation:
    - average loss
    - ROC-AUC (if sklearn available)

    Args:
        model: returns `(logits, user_embs, item_embs)` for a batch.
        data_loader: yields `(user_ids, item_ids, lang_ids, numerics, labels)`.
        device: device the batch tensors are moved to.

    Returns:
        `(avg_loss, auc)`. `avg_loss` is the example-weighted mean of
        BCE-with-logits. `auc` is NaN when sklearn is unavailable or when
        `roc_auc_score` raises ValueError. If the loader yields no batches,
        both values are NaN.
    """
    model.eval()
    criterion = nn.BCEWithLogitsLoss()

    total_loss = 0.0
    total_examples = 0

    all_labels = []
    all_probs = []

    with torch.no_grad():
        for user_ids, item_ids, lang_ids, numerics, labels in data_loader:
            user_ids = user_ids.to(device)
            item_ids = item_ids.to(device)
            lang_ids = lang_ids.to(device)
            numerics = numerics.to(device)
            labels = labels.to(device)

            logits, _, _ = model(user_ids, item_ids, lang_ids, numerics)
            loss = criterion(logits, labels)

            probs = torch.sigmoid(logits)

            # criterion returns the batch mean; re-weight by batch size so a
            # smaller final batch does not skew the average.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_examples += batch_size

            all_labels.append(labels.detach().cpu().numpy())
            all_probs.append(probs.detach().cpu().numpy())

    if not all_labels:
        # Nothing was evaluated: np.concatenate would raise on an empty list,
        # and a loss of 0.0 would wrongly look like a perfect score.
        return float("nan"), float("nan")

    avg_loss = total_loss / max(total_examples, 1)
    all_labels = np.concatenate(all_labels, axis=0)
    all_probs = np.concatenate(all_probs, axis=0)

    if SKLEARN_AVAILABLE:
        try:
            auc = roc_auc_score(all_labels, all_probs)
        except ValueError:
            auc = float("nan")
    else:
        auc = float("nan")

    return avg_loss, auc

In [None]:
########################################
# 6. ROBUST RANKING METRICS (Recall@K, NDCG@K)
########################################

def compute_dcg(relevances: np.ndarray, k: int) -> float:
    """
    Discounted cumulative gain over the first `k` entries of `relevances`.

    relevances: binary or graded relevance scores, already in ranking order.
    Uses gain `2**rel - 1` and discount `log2(rank + 1)` with 1-based ranks;
    returns 0.0 when the truncated list is empty.
    """
    top = relevances[:k]
    n_ranked = len(top)
    if n_ranked == 0:
        return 0.0
    gains = 2**top - 1
    discounts = np.log2(np.arange(n_ranked) + 2)
    return float((gains / discounts).sum())


def evaluate_ranking_with_neg_sampling(
    model: TwoTowerRecSys,
    test_df: pd.DataFrame,
    train_pool: pd.DataFrame,
    item_table: pd.DataFrame,
    user2idx: Dict[int, int],
    item2idx: Dict[int, int],
    lang2idx: Dict[int, int],
    device: str,
    k_list=None,
    num_negatives: int = 100,
):
    """
    More robust ranking evaluation:
    For each user in test_df:
      - Positives = repos in test_df with target=1.
      - Negatives = sample 'num_negatives' items from full catalog that the user
        has never positively interacted with (train_pool + test_df).
      - Rank positive + sampled negatives with the model.
      - Compute Recall@K and NDCG@K for each K.

    Args:
        model: returns `(logits, user_embs, item_embs)` for a batch.
        test_df, train_pool: raw interaction frames (ids and target column).
        item_table: normalized item features, one row per project_id; this is
            the catalogue negatives are drawn from.
        user2idx, item2idx, lang2idx: raw value -> embedding index mappings.
        device: device used for scoring.
        k_list: cutoffs K; defaults to [5, 10].
        num_negatives: negatives sampled per user (fewer if the user's unseen
            catalogue is smaller).

    Returns: dict like
      {
        "Recall@5": value,
        "Recall@10": value,
        "NDCG@5": value,
        "NDCG@10": value,
        "num_users": n_users_evaluated
      }
    where "num_users" counts only users that were actually scored, and each
    metric is NaN if no user was scored.

    Negatives are drawn with the global `random` module, so results depend on
    its seed state.

    NOTE(review): positives precede negatives in the candidate list; if many
    scores tie, the argsort tie-breaking order could influence the metrics.
    """
    if k_list is None:
        k_list = [5, 10]

    model.eval()

    # Pre-build mapping from project_id -> row index in item_table
    all_item_ids = item_table[ITEM_ID_COL].values
    item_row_map = {pid: idx for idx, pid in enumerate(all_item_ids)}
    # Index wrapper used for fast per-user membership tests against the catalogue.
    catalog_index = pd.Index(all_item_ids)

    # For each user: set of items they have positively interacted with anywhere
    user_pos_all = (
        pd.concat([train_pool, test_df], ignore_index=True)
        .query(f"{TARGET_COL} == 1")
        .groupby(USER_ID_COL)[ITEM_ID_COL]
        .agg(lambda x: set(x.values))
    )

    # Positives only from test
    user_pos_test = (
        test_df.query(f"{TARGET_COL} == 1")
        .groupby(USER_ID_COL)[ITEM_ID_COL]
        .agg(lambda x: list(set(x.values)))
    )

    # Users we can evaluate on (have positives in test and exist in mapping)
    eval_users = [
        uid for uid in user_pos_test.index
        if (uid in user2idx) and (uid in user_pos_all.index)
    ]

    recalls = {k: [] for k in k_list}
    ndcgs = {k: [] for k in k_list}
    n_evaluated = 0

    with torch.no_grad():
        for uid in eval_users:
            pos_items = user_pos_test[uid]
            if len(pos_items) == 0:
                continue

            # Candidate negative pool = catalogue positions of all items the
            # user has no positive interaction with. Computed as a vectorised
            # mask instead of a Python scan over the whole catalogue per user.
            seen_mask = catalog_index.isin(user_pos_all[uid])
            unseen_positions = np.flatnonzero(~seen_mask)
            n_unseen = int(unseen_positions.size)
            if n_unseen == 0:
                continue

            # Sample negatives: draw positions within the unseen pool (in
            # catalogue order), then translate them back to item ids.
            num_neg = min(num_negatives, n_unseen)
            picked = random.sample(range(n_unseen), num_neg)
            neg_items = list(all_item_ids[unseen_positions[picked]])

            # Build candidate list: positives + negatives
            cand_items = list(pos_items) + neg_items
            cand_labels = np.array(
                [1] * len(pos_items) + [0] * len(neg_items),
                dtype=np.int32,
            )

            # Map to item_table rows
            cand_indices = [item_row_map[pid] for pid in cand_items]
            cand_rows = item_table.iloc[cand_indices]

            item_ids_raw = cand_rows[ITEM_ID_COL].values
            lang_raw = cand_rows["language_code"].values

            item_indices = torch.tensor(
                [item2idx[i] for i in item_ids_raw],
                dtype=torch.long,
                device=device,
            )
            lang_indices = torch.tensor(
                [lang2idx[l] for l in lang_raw],
                dtype=torch.long,
                device=device,
            )

            numerics_tensor = torch.tensor(
                cand_rows[NUMERIC_REPO_COLS].astype(float).values,
                dtype=torch.float32,
                device=device,
            )

            # Repeat the single user index once per candidate.
            u_idx = user2idx[uid]
            user_tensor = torch.tensor([u_idx], dtype=torch.long, device=device)
            user_batch = user_tensor.expand(len(cand_items))

            logits, _, _ = model(user_batch, item_indices, lang_indices, numerics_tensor)
            scores = torch.sigmoid(logits).detach().cpu().numpy()

            # Sort candidates by score
            ranking = np.argsort(-scores)  # descending
            sorted_labels = cand_labels[ranking]

            num_pos = float(len(pos_items))
            # Ideal ordering (all positives ranked at top); independent of K.
            ideal_labels = np.sort(cand_labels)[::-1]

            for k in k_list:
                k_eff = min(k, len(sorted_labels))
                topk_labels = sorted_labels[:k_eff]

                # Recall@K: how many positives in top-K divided by total positives
                recall_k = float(topk_labels.sum() / num_pos)
                recalls[k].append(recall_k)

                # NDCG@K
                dcg_k = compute_dcg(sorted_labels, k_eff)
                idcg_k = compute_dcg(ideal_labels, k_eff)
                ndcg_k = dcg_k / idcg_k if idcg_k > 0 else 0.0
                ndcgs[k].append(ndcg_k)

            n_evaluated += 1

    metrics = {}
    # Report the users that contributed to the averages, not the users that
    # were merely eligible (some may have been skipped above).
    metrics["num_users"] = n_evaluated
    for k in k_list:
        if len(recalls[k]) > 0:
            metrics[f"Recall@{k}"] = float(np.mean(recalls[k]))
        else:
            metrics[f"Recall@{k}"] = float("nan")
        if len(ndcgs[k]) > 0:
            metrics[f"NDCG@{k}"] = float(np.mean(ndcgs[k]))
        else:
            metrics[f"NDCG@{k}"] = float("nan")

    return metrics

In [None]:
########################################
# 7. MAIN SCRIPT
########################################

def main():
    """Run the whole pipeline: load data, train in two stages, evaluate on test.

    Steps:
      1. Load the four CSVs and build the train pool / test frame.
      2. Split the train pool into train and validation by user.
      3. Build id mappings and numeric scalers, normalize all frames.
      4. Stage A: contrastive pretraining on positives only.
      5. Stage B: mixed contrastive + BCE finetuning at a 10x smaller LR.
      6. Report pointwise loss/AUC and sampled-negative ranking metrics on test.

    Checkpoints are written to the current working directory. Only checkpoints
    saved during THIS run are ever loaded back; a file left behind by an
    earlier run (possibly with different weights or shapes) is ignored.

    NOTE(review): validation users are disjoint from training users and the
    user tower is driven only by a per-user ID embedding, so the validation
    rows' user embeddings receive no gradient from the training loss. The
    validation AUC used for checkpoint selection may therefore sit near
    chance regardless of training progress; consider an interaction-level
    split.
    """
    cfg = Config()
    set_seed(cfg.seed)

    stage_a_ckpt = "best_twotower_stageA.pt"
    stage_b_ckpt = "best_twotower_stageAB.pt"

    print("Loading raw data...")
    train_bal, train_neg, test_bal, test_neg = load_raw_data(cfg)

    # Combine balanced + negative splits
    train_pool = pd.concat([train_bal, train_neg], ignore_index=True)
    test_df = pd.concat([test_bal, test_neg], ignore_index=True)
    print(f"Train pool size: {len(train_pool)}, Test size: {len(test_df)}")

    # Split train pool into train/val *by user*
    print("Splitting train pool into train and validation by user...")
    train_df_all, val_df = split_train_val_by_user(
        train_pool,
        train_user_fraction=cfg.train_user_fraction,
        seed=cfg.seed,
    )

    num_train_pos = int((train_df_all[TARGET_COL] == 1).sum())
    num_train_neg = int((train_df_all[TARGET_COL] == 0).sum())
    print(
        f"Train size (pos+neg): {len(train_df_all)} "
        f"[pos={num_train_pos}, neg={num_train_neg}], "
        f"Val size (mixed): {len(val_df)}"
    )

    # Build mappings over all users/items/langs present in train+test
    print("Building ID mappings (users/items/langs)...")
    user2idx, item2idx, lang2idx, _, _, _ = build_id_mappings(train_pool, test_df)
    num_users = len(user2idx)
    num_items = len(item2idx)
    num_langs = len(lang2idx)
    print(f"Num users: {num_users}, num items: {num_items}, num langs: {num_langs}")

    # Fit numeric feature scalers on the entire training pool
    # (this includes the validation rows).
    print("Fitting numeric scalers on training pool...")
    means, stds = compute_numeric_scalers(train_pool)

    # Create POSITIVE-ONLY and FULL views of the training data
    train_df_pos = train_df_all[train_df_all[TARGET_COL] == 1].reset_index(drop=True)
    train_df_full = train_df_all.reset_index(drop=True)

    print("Normalizing numeric features...")
    train_df_pos_norm = normalize_numeric_features(train_df_pos, means, stds)
    train_df_full_norm = normalize_numeric_features(train_df_full, means, stds)
    val_df_norm = normalize_numeric_features(val_df, means, stds)
    test_df_norm = normalize_numeric_features(test_df, means, stds)

    # Create TWO train loaders:
    #  - train_loader_contrastive: positive-only for Stage A
    #  - train_loader_full: pos+neg for Stage B (mixed loss)
    print("Creating dataloaders for Stage A (contrastive) and Stage B (mixed)...")
    train_loader_contrastive, _ = make_dataloaders(
        cfg, train_df_pos_norm, val_df_norm, user2idx, item2idx, lang2idx
    )
    train_loader_full, val_loader = make_dataloaders(
        cfg, train_df_full_norm, val_df_norm, user2idx, item2idx, lang2idx
    )

    # Initialize model
    print("Initializing model...")
    model = TwoTowerRecSys(
        num_users=num_users,
        num_items=num_items,
        num_langs=num_langs,
        num_numeric_feats=len(NUMERIC_REPO_COLS),
        cfg=cfg,
    ).to(cfg.device)

    # -------------------------
    # Stage A: pure contrastive
    # -------------------------
    pretrain_lr = cfg.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=pretrain_lr, weight_decay=cfg.weight_decay)

    num_pretrain_epochs = 5  # you can tune this
    best_stageA_auc = -1.0
    stage_a_saved = False  # True once this run has written a Stage A checkpoint

    for epoch in range(1, num_pretrain_epochs + 1):
        print(f"\n[Stage A] Contrastive pretraining - Epoch {epoch}/{num_pretrain_epochs}")

        train_loss = train_one_epoch_contrastive(
            model=model,
            train_loader=train_loader_contrastive,
            optimizer=optimizer,
            device=cfg.device,
            temperature=cfg.temperature,
        )
        print(f"Train loss (contrastive): {train_loss:.4f}")

        # Pointwise val AUC as a sanity check
        val_loss, val_auc = evaluate_pointwise(model, val_loader, cfg.device)
        print(f"Val loss (pointwise): {val_loss:.4f}, Val AUC: {val_auc:.4f}")

        if SKLEARN_AVAILABLE and not math.isnan(val_auc) and val_auc > best_stageA_auc:
            best_stageA_auc = val_auc
            torch.save(model.state_dict(), stage_a_ckpt)
            stage_a_saved = True
            print("Saved Stage A best model based on validation AUC.")

    # Load best Stage A model (only if it was produced by this run)
    if stage_a_saved:
        print("\nLoading Stage A best checkpoint for finetuning...")
        model.load_state_dict(torch.load(stage_a_ckpt, map_location=cfg.device))
    else:
        print("\nWARNING: Stage A checkpoint not found, proceeding with current model.")

    # -------------------------
    # Stage B: mixed finetuning
    # -------------------------
    # Use smaller LR for finetuning
    finetune_lr = cfg.lr * 0.1  # e.g., 1e-4 if cfg.lr is 1e-3
    optimizer = torch.optim.Adam(
        model.parameters(), lr=finetune_lr, weight_decay=cfg.weight_decay
    )

    num_finetune_epochs = 2  # short finetune
    best_stageB_auc = -1.0
    stage_b_saved = False  # True once this run has written a Stage B checkpoint

    for epoch in range(1, num_finetune_epochs + 1):
        print(f"\n[Stage B] Mixed finetune - Epoch {epoch}/{num_finetune_epochs}")

        train_loss = train_one_epoch_mixed(
            model=model,
            train_loader=train_loader_full,
            optimizer=optimizer,
            device=cfg.device,
            temperature=cfg.temperature,
            contrastive_weight=cfg.contrastive_weight,
            bce_weight=cfg.bce_weight,
            max_contrastive_negatives=cfg.max_contrastive_negatives,
        )
        print(f"Train loss (mixed contrastive+BCE): {train_loss:.4f}")

        val_loss, val_auc = evaluate_pointwise(model, val_loader, cfg.device)
        print(f"Val loss (pointwise): {val_loss:.4f}, Val AUC: {val_auc:.4f}")

        if SKLEARN_AVAILABLE and not math.isnan(val_auc) and val_auc > best_stageB_auc:
            best_stageB_auc = val_auc
            torch.save(model.state_dict(), stage_b_ckpt)
            stage_b_saved = True
            print("Saved Stage B best model based on validation AUC.")

    # Decide which checkpoint to use for final test: prefer Stage B if this
    # run saved one, otherwise fall back to this run's Stage A checkpoint.
    if stage_b_saved:
        final_ckpt_path = stage_b_ckpt
    elif stage_a_saved:
        final_ckpt_path = stage_a_ckpt
    else:
        final_ckpt_path = None

    if final_ckpt_path is not None:
        print(f"\nLoading final best model from checkpoint: {final_ckpt_path}")
        model.load_state_dict(torch.load(final_ckpt_path, map_location=cfg.device))
    else:
        print("\nWARNING: No checkpoint found; using current in-memory model for test.")

    # -------------------------
    # Final test evaluation
    # -------------------------
    print("\nCreating test DataLoader for pointwise metrics...")
    test_dataset = TwoTowerDataset(test_df_norm, user2idx, item2idx, lang2idx)
    test_loader = DataLoader(
        test_dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    test_loss, test_auc = evaluate_pointwise(model, test_loader, cfg.device)
    print(f"Test loss: {test_loss:.4f}, Test AUC: {test_auc:.4f}")

    print("\nBuilding normalized item feature table for ranking evaluation...")
    item_table_norm = build_item_feature_table_norm(train_pool, test_df, means, stds)

    print("Evaluating ranking metrics with negative sampling...")
    rank_metrics = evaluate_ranking_with_neg_sampling(
        model=model,
        test_df=test_df,
        train_pool=train_pool,
        item_table=item_table_norm,
        user2idx=user2idx,
        item2idx=item2idx,
        lang2idx=lang2idx,
        device=cfg.device,
        k_list=cfg.eval_k_list,
        num_negatives=cfg.eval_num_negatives,
    )

    print(f"Ranking evaluation over {rank_metrics['num_users']} users:")
    for k in cfg.eval_k_list:
        print(
            f"  Recall@{k}: {rank_metrics.get(f'Recall@{k}', float('nan')):.4f}, "
            f"NDCG@{k}: {rank_metrics.get(f'NDCG@{k}', float('nan')):.4f}"
        )


if __name__ == "__main__":
    main()

Loading raw data...
Train pool size: 567339, Test size: 141832
Splitting train pool into train and validation by user...
Train size (pos+neg): 453817 [pos=293817, neg=160000], Val size (mixed): 113522
Building ID mappings (users/items/langs)...
Num users: 10000, num items: 365628, num langs: 119
Fitting numeric scalers on training pool...
Normalizing numeric features...
Creating dataloaders for Stage A (contrastive) and Stage B (mixed)...
Initializing model...

[Stage A] Contrastive pretraining - Epoch 1/5
Train loss (contrastive): 8.3370
Val loss (pointwise): 0.6789, Val AUC: 0.5001
Saved Stage A best model based on validation AUC.

[Stage A] Contrastive pretraining - Epoch 2/5
Train loss (contrastive): 8.2596
Val loss (pointwise): 0.6860, Val AUC: 0.4984

[Stage A] Contrastive pretraining - Epoch 3/5
Train loss (contrastive): 7.9308
Val loss (pointwise): 0.6801, Val AUC: 0.4986

[Stage A] Contrastive pretraining - Epoch 4/5
Train loss (contrastive): 7.3453
Val loss (pointwise): 0.684