In [1]:
# setup
import os
import wandb
import torch
import torch.nn as nn
from nrms import NRMS
from typing import List, Dict
from torch.optim.adamw import AdamW
from tqdm import tqdm
import math
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, get_linear_schedule_with_warmup
from data import load_and_tokenize_news, load_behaviors, MindDataset, mind_collate_fn
from sklearn.metrics import roc_auc_score

In [2]:
def evaluate_model(model: torch.nn.Module, 
                   dataloader: DataLoader, 
                   device: torch.device) -> Dict[str, float]:
    """
    Evaluates a news recommender model on a set of per-impression ranking metrics:
      - AUC (averaged per‐impression)
      - MRR (Mean Reciprocal Rank)
      - nDCG@5
      - nDCG@10

    Assumptions:
      * Each batch from dataloader unpacks into five tensors:
          clicked_ids:   torch.LongTensor of shape (B, L_click)
          clicked_mask:  torch.BoolTensor of  shape (B, L_click)
          cand_ids:      torch.LongTensor of shape (B, K, L_cand)
          cand_mask:     torch.BoolTensor of  shape (B, K, L_cand)
          labels:        torch.LongTensor of shape (B,)
            where labels[b] ∈ {0, …, K-1} is the index (in the candidate list)
            of the single “clicked” article for instance b.
      * The model’s forward pass is called as:
            scores: torch.Tensor = model(clicked_ids, ~clicked_mask, 
                                         cand_ids, cand_mask)
        and returns a FloatTensor of shape (B, K), where K is the number of candidates.

    Tie handling:
      The clicked item's rank is computed pessimistically: every candidate whose
      score is not strictly lower than the clicked item's score is ranked ahead of
      (or level with) it. A model that emits identical scores for all candidates
      therefore gets rank K rather than a perfect rank of 1. NaN scores are never
      counted as "lower", so they can only worsen the rank.

    Side effects:
      Puts the model in eval mode and moves it to `device`; neither is undone.

    Returns:
      A dict containing the four metrics:
        {
          "AUC": float,
          "MRR": float,
          "nDCG@5": float,
          "nDCG@10": float
        }
      If the dataloader yields no instances, every metric is float("nan").
    """
    model.eval()
    model.to(device)

    total_auc = 0.0
    total_mrr = 0.0
    total_ndcg_5 = 0.0
    total_ndcg_10 = 0.0
    total_instances = 0

    with torch.no_grad():
        for batch in tqdm(dataloader):
            # Unpack
            clicked_ids, clicked_mask, cand_ids, cand_mask, labels = batch
            # Move to device
            clicked_ids  = clicked_ids.to(device)    # (B, L_click)
            clicked_mask = clicked_mask.to(device)   # (B, L_click)
            cand_ids     = cand_ids.to(device)       # (B, K, L_cand)
            cand_mask    = cand_mask.to(device)      # (B, K, L_cand)
            labels       = labels.to(device)         # (B,)

            # Forward pass → scores of shape (B, K)
            scores: torch.Tensor = model(
                clicked_ids, 
                ~clicked_mask,   # note: model expects the bitwise inverse of clicked_mask
                cand_ids, 
                cand_mask
            )  # scores[b, j] is the predicted score/logit for candidate j of instance b

            B, K = scores.shape
            total_instances += B

            # Convert to CPU+numpy for metric computations
            scores_cpu = scores.cpu().numpy()      # shape (B, K)
            labels_cpu = labels.cpu().numpy()      # shape (B,)

            for b in range(B):
                true_index = int(labels_cpu[b])
                y_score = scores_cpu[b]  # length-K array of floats

                # One-hot relevance vector: only the clicked candidate is positive.
                y_true = [0] * K
                y_true[true_index] = 1

                # ----- AUC (per‐impression) -----
                try:
                    auc_b = roc_auc_score(y_true, y_score)
                except ValueError:
                    # roc_auc_score raises ValueError when it cannot compute a score
                    # (e.g. only one class present because K == 1). Fall back to the
                    # chance-level value instead of aborting the whole evaluation.
                    auc_b = 0.5
                total_auc += auc_b

                # ----- MRR (Mean Reciprocal Rank) -----
                # Pessimistic 1-based rank: K minus the number of candidates that score
                # strictly lower than the clicked one. Without ties this equals
                # 1 + #(score > target); with ties the clicked item gets the worst rank
                # within its tie group, so constant scores cannot look perfect.
                target_score = y_score[true_index]
                rank = K - int((y_score < target_score).sum())
                total_mrr += 1.0 / rank

                # ----- nDCG@5 and nDCG@10 -----
                # With exactly one relevant item, IDCG@k = 1, hence
                #   nDCG@k = 1 / log2(rank + 1)   if rank <= k
                #          = 0                    otherwise
                discount = 1.0 / math.log2(rank + 1)
                total_ndcg_5 += discount if rank <= 5 else 0.0
                total_ndcg_10 += discount if rank <= 10 else 0.0

    # An empty dataloader has no defined averages; report NaN instead of dividing by zero.
    if total_instances == 0:
        nan = float("nan")
        return {
            "AUC": nan,
            "MRR": nan,
            "nDCG@5": nan,
            "nDCG@10": nan
        }

    # Compute averages
    return {
        "AUC": total_auc / total_instances,
        "MRR": total_mrr / total_instances,
        "nDCG@5": total_ndcg_5 / total_instances,
        "nDCG@10": total_ndcg_10 / total_instances
    }


In [3]:
# ---- Configuration ----
# Path prefix shared by both splits; 'train/...' or 'val/...' is appended below,
# giving e.g. './data/MIND_train/news.tsv'.
BASE_DATA_DIR = './data/MIND_'

MAX_TITLE_LEN = 100   # each headline → exactly MAX_TITLE_LEN tokens (truncated/padded)
MAX_HISTORY  = 50     # each user’s clicked history → exactly MAX_HISTORY articles
BATCH_SIZE = 20
BACTH_SIZE = BATCH_SIZE  # backward-compatible alias for the original (misspelled) name

# ---- Tokenizer ----
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
PAD_ID = tokenizer.pad_token_id

# ---- Load news + behaviours for each split ----
train_news_dict = load_and_tokenize_news(BASE_DATA_DIR+'train/news.tsv', tokenizer, MAX_TITLE_LEN)
train_samples   = load_behaviors(BASE_DATA_DIR+'train/behaviors.tsv', train_news_dict, MAX_HISTORY)

val_news_dict = load_and_tokenize_news(BASE_DATA_DIR+'val/news.tsv', tokenizer, MAX_TITLE_LEN)
val_samples   = load_behaviors(BASE_DATA_DIR+'val/behaviors.tsv', val_news_dict, MAX_HISTORY)

# ---- Datasets and loaders ----
train_dataset = MindDataset(train_samples)
val_dataset = MindDataset(val_samples)

train_dl = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=mind_collate_fn
)

# Validation data is only ever scored, never trained on, so keep a fixed order:
# the averaged metrics do not depend on order, and a deterministic pass is easier
# to debug and compare across runs.
valid_dl = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=mind_collate_fn
)

# Load the model

In [4]:
from nrms import NRMS

CHECK_PATH = './checkpoints/checkpoint_epoch5.pt'

# The data was tokenized with the constants from the configuration cell. Per the
# original notes the checkpoint expects MAX_HISTORY=50, MAX_TITLE_LEN=100 and
# PAD_ID=0 ([PAD] for BERT). Check instead of silently re-assigning them, so a
# drifted configuration is caught here rather than producing a model/data mismatch.
assert (MAX_HISTORY, MAX_TITLE_LEN, PAD_ID) == (50, 100, 0), (
    "Configuration differs from the settings this checkpoint was trained with: "
    f"MAX_HISTORY={MAX_HISTORY}, MAX_TITLE_LEN={MAX_TITLE_LEN}, PAD_ID={PAD_ID}"
)

# This has to be the same as the trained model
model = NRMS(
    vocab_size=tokenizer.vocab_size,
    d_embed_word = 128,
    d_embed_news = 256,
    n_heads_news = 8,
    n_heads_user = 8,
    d_mlp_news = 512,
    d_mlp_user = 512,
    news_layers = 1,
    user_layers = 1,
    dropout = 0.1,
    pad_max_len = MAX_TITLE_LEN,
)

# weights_only=True restricts unpickling to tensors/primitive containers, which is
# all a state dict needs; it avoids executing arbitrary pickled code from the
# checkpoint file and silences the torch.load FutureWarning.
# Weights are loaded onto the CPU here; evaluate_model moves the model to its device.
state_dict = torch.load(CHECK_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

  model.load_state_dict(torch.load(CHECK_PATH, map_location="cpu"))


<All keys matched successfully>

# Evaluate on the training set (sanity check only)

In [5]:
# Score the model on the training loader, on the GPU when one is available.
eval_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
metrics = evaluate_model(model, train_dl, eval_device)

# Report every metric with four decimals.
print("---- Evaluation Metrics ----")
for metric_name, value in metrics.items():
    print(f"{metric_name} = {value:.4f}")

  return torch._transformer_encoder_layer_fwd(
  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
100%|██████████| 7849/7849 [20:51<00:00,  6.27it/s]

---- Evaluation Metrics ----
AUC = 0.9091
MRR = 0.3410
nDCG@5 = 0.3510
nDCG@10 = 0.4041





# Evaluate on the validation set
(TODO: consider evaluating on a held-out test set instead, e.g. an unseen subset of the large MIND dataset.)

In [5]:
# Score the model on the validation loader, on the GPU when one is available.
eval_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
val_metrics = evaluate_model(model, valid_dl, eval_device)

# Report every metric with four decimals.
print("---- Evaluation Metrics ----")
for metric_name, value in val_metrics.items():
    print(f"{metric_name} = {value:.4f}")

  return torch._transformer_encoder_layer_fwd(
  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
100%|██████████| 3658/3658 [10:14<00:00,  5.95it/s]

---- Evaluation Metrics ----
AUC = 0.8892
MRR = 0.2609
nDCG@5 = 0.2722
nDCG@10 = 0.3270





# Results & Baselines

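(All metrics: higher is better.)

**Caveat:** the paper row is not directly comparable to ours. `evaluate_model` scores candidate lists that contain exactly one clicked article, so the much higher AUC here most likely reflects a different evaluation protocol and dataset rather than a better model. TODO: confirm which dataset the paper's figures were reported on before citing this comparison.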


| Model | AUC | MRR | nDCG@5 | nDCG@10 |
| - | - | - | - | - |
| Epoch2 | 0.8893 | 0.2632 | 0.2728 | 0.3290 |
| Epoch5 | 0.8892 | 0.2609 | 0.2722 | 0.3270 |
| NRMS paper (on large MIND) | 0.6275 | 0.2985 | 0.3217 | 0.4139 |

