In [4]:
from typing import List, Tuple

import torch
from transformers import BertTokenizer

# Constants — make sure these match your training settings
MAX_HISTORY = 50      # max clicked titles kept per user (older ones are dropped)
MAX_TITLE_LEN = 100   # tokens per title after padding/truncation

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Token id used to fill all-padding history rows. Taken from the tokenizer
# (it is 0, i.e. [PAD], for bert-base-uncased) so it always agrees with the
# padding the tokenizer itself applies to titles in tokenize_titles().
PAD_ID = tokenizer.pad_token_id

def tokenize_titles(titles: List[str], max_len: int = MAX_TITLE_LEN) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Tokenizes and pads a list of article titles with the module-level BERT tokenizer.

    Every title is truncated/padded to exactly ``max_len`` tokens, and no
    [CLS]/[SEP] special tokens are added.

    Args:
        titles:  Article titles to encode.
        max_len: Number of tokens per title after padding/truncation.

    Returns:
        (token_ids, padding_mask), both of shape (N, max_len) with N = len(titles):
          - token_ids:    long tensor of token ids, padded with the tokenizer's pad id.
          - padding_mask: bool tensor, True where the position is padding.
        An empty ``titles`` list yields two empty (0, max_len) tensors.
    """
    if not titles:
        # Don't rely on the tokenizer's behaviour for an empty batch; build
        # correctly shaped empty tensors so callers can still torch.cat them.
        return (
            torch.empty((0, max_len), dtype=torch.long),
            torch.empty((0, max_len), dtype=torch.bool),
        )

    encodings = tokenizer(
        titles,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
        return_attention_mask=True,
        add_special_tokens=False  # NRMS does not expect [CLS] or [SEP]
    )
    token_ids = encodings["input_ids"]                   # (N, max_len)
    # attention_mask is 1 for real tokens; invert it so True marks padding.
    padding_mask = ~encodings["attention_mask"].bool()   # (N, max_len)
    return token_ids, padding_mask

def recommend_topk_from_titles(
    model: torch.nn.Module,
    history_titles: List[str],
    candidate_titles: List[str],
    topk: int = 5,
    device: torch.device = torch.device("cpu")
) -> List[str]:
    """
    Recommends top-k titles from a list of candidate article titles,
    given a user's clicked history (also as titles).

    The history is cut to its last MAX_HISTORY titles, or left-padded with
    all-padding rows up to MAX_HISTORY. It is then scored against every
    candidate in a single forward pass.

    Args:
        model:            Trained NRMS model, called as
                          ``model(clicked_ids, clicked_mask, cand_ids, cand_mask)``
                          and expected to return scores of shape (1, K).
                          It is moved to ``device`` and switched to eval mode;
                          both changes persist after this call.
        history_titles:   Clicked article titles (strings). Only the last
                          MAX_HISTORY entries are used, so presumably ordered
                          oldest -> newest.
        candidate_titles: Candidate article titles (strings).
        topk:             Number of top articles to return; values <= 0 give [].
        device:           Torch device to run the model on.

    Returns:
        Up to ``topk`` candidate titles, highest score first.

    Raises:
        ValueError: If ``history_titles`` is empty.
    """
    if not history_titles:
        # With no clicks the model would only see padding; fail loudly here
        # rather than with an obscure tokenizer / torch.cat error later.
        raise ValueError("history_titles must contain at least one title")

    k = min(topk, len(candidate_titles))
    if k <= 0:
        # Nothing to rank (no candidates or non-positive topk): skip the model.
        return []

    model.to(device)
    model.eval()

    # 1. Keep only the most recent MAX_HISTORY clicks *before* tokenizing,
    #    so titles that would be discarded are never tokenized.
    recent_history = history_titles[-MAX_HISTORY:]
    hist_tokens, hist_mask = tokenize_titles(recent_history, max_len=MAX_TITLE_LEN)
    cand_tokens, cand_mask = tokenize_titles(candidate_titles, max_len=MAX_TITLE_LEN)

    # 2. Left-pad the history to exactly MAX_HISTORY rows. Padded rows hold
    #    PAD_ID everywhere and are fully masked (True = padding).
    pad_len = MAX_HISTORY - len(recent_history)
    if pad_len > 0:
        pad_tokens = torch.full((pad_len, MAX_TITLE_LEN), PAD_ID, dtype=torch.long)
        pad_mask = torch.ones((pad_len, MAX_TITLE_LEN), dtype=torch.bool)
        hist_tokens = torch.cat([pad_tokens, hist_tokens], dim=0)
        hist_mask = torch.cat([pad_mask, hist_mask], dim=0)

    # 3. Add a batch dimension of 1 and move inputs to the model's device.
    clicked_ids = hist_tokens.unsqueeze(0).to(device)    # (1, MAX_HISTORY, MAX_TITLE_LEN)
    clicked_mask = hist_mask.unsqueeze(0).to(device)     # (1, MAX_HISTORY, MAX_TITLE_LEN)
    cand_ids = cand_tokens.unsqueeze(0).to(device)       # (1, K, MAX_TITLE_LEN)
    cand_mask = cand_mask.unsqueeze(0).to(device)        # (1, K, MAX_TITLE_LEN)

    # 4. Score all candidates at once, without tracking gradients.
    with torch.no_grad():
        logits = model(clicked_ids, clicked_mask, cand_ids, cand_mask)  # (1, K)

    scores = logits.squeeze(0)  # (K,)
    _, top_idxs = torch.topk(scores, k=min(k, scores.size(0)))

    return [candidate_titles[i] for i in top_idxs.tolist()]

# Load model from checkpoint

In [None]:
from nrms import NRMS

# Trained weights to restore (per the filename: epoch 1, step 1000).
CHECK_PATH = './checkpoints/checkpoint_epoch1_step1000.pt'

# Architecture hyperparameters. These have to be identical to the ones the
# checkpoint was trained with, otherwise loading its state_dict will
# typically fail on missing/unexpected keys or tensor shape mismatches.
NRMS_HPARAMS = dict(
    vocab_size=tokenizer.vocab_size,  # vocabulary of the BERT tokenizer above
    pad_max_len=MAX_TITLE_LEN,        # presumably the padded title length tokenize_titles() produces
    d_embed=512,
    n_heads=8,
    d_mlp=2048,
    news_layers=1,
    user_layers=1,
    dropout=0.1,
)
model = NRMS(**NRMS_HPARAMS)

In [22]:
# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state_dict needs, instead of executing arbitrary pickled objects
# from the checkpoint file. The warning echoed in this cell's output is
# presumably torch's FutureWarning about this default.
model.load_state_dict(torch.load(CHECK_PATH, map_location="cpu", weights_only=True))

# Toy click history: a user who has been reading AI / tech news.
history = [
    "OpenAI Unveils GPT-5 Major Leap in Multimodal AI Capabilities",
    "Apple Confirms WWDC 2025 Event Expected Focus on Vision Pro 2 and AI Tools",
    "Google Integrates Gemini AI Across Android 15 What It Means for Users",
]

# Mixed-topic candidates to rank against that history.
candidates = [
    "UN Urges Immediate Ceasefire in Sudan as Humanitarian Crisis Deepens",
    "Frog population increases in Amazon",
    "NVIDIA Surpasses $3 Trillion Market Cap Amid AI Chip Boom",
    "New 2025 Study Reveals Link Between Sleep Quality and Mental Health",
    "Global Inflation Slows, But Food Prices Remain Stubbornly High",
    "Israel and Hamas Agree to Extend Ceasefire for Humanitarian Aid",
    "TECH vibes lmao: Xiami's New iPhone 16 Pro Max Features Leaked",
    "Broadway Revival of The Phantom of the Opera Sells Out in Record Time",
]

top_titles = recommend_topk_from_titles(
    model=model,
    history_titles=history,
    candidate_titles=candidates,
    topk=3,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)

print("Top recommendations:")
for title in top_titles:
    print(" •", title)


Top recommendations:
 • TECH vibes lmao: Xiami's New iPhone 16 Pro Max Features Leaked
 • NVIDIA Surpasses $3 Trillion Market Cap Amid AI Chip Boom
 • Broadway Revival of The Phantom of the Opera Sells Out in Record Time


  model.load_state_dict(torch.load(CHECK_PATH, map_location="cpu"))
