# Set up the dataloader

In [1]:
import os
from typing import List, Dict
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast
from data import load_and_tokenize_news, load_behaviors, MindDataset, mind_collate_fn

# --- Configuration -----------------------------------------------------------
# Input files (paths are relative to the notebook's working directory).
NEWS_TSV_PATH = "./data/MIND_val/news.tsv"
BEHAVIORS_TSV_PATH = "./data/MIND_val/behaviors.tsv"

# Length cap handed to load_and_tokenize_news for each headline (WordPiece IDs,
# presumably truncated/padded to this length — confirm in data.py).
# NOTE(review): the previous comment said "exactly 30" while the value is 100;
# the comment now follows the code. Confirm that 100 is the intended cap.
MAX_TITLE_LEN = 100
# Cap handed to load_behaviors for each user's clicked history (number of articles).
MAX_HISTORY = 50

BATCH_SIZE = 8
# Seed for the DataLoader's shuffling so that batch order is reproducible
# across "Restart Kernel → Run All".
SEED = 42

# --- Tokenizer ---------------------------------------------------------------
# Load the pretrained BERT tokenizer.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
PAD_ID = tokenizer.pad_token_id  # usually 0

# --- Data --------------------------------------------------------------------
# news_dict is built from the news file with the tokenizer and title-length cap;
# samples are built from the behaviors file using news_dict and the history cap.
news_dict = load_and_tokenize_news(NEWS_TSV_PATH, tokenizer, MAX_TITLE_LEN)
samples   = load_behaviors(BEHAVIORS_TSV_PATH, news_dict, MAX_HISTORY)

print(f"  → Total news articles tokenized: {len(news_dict)}")
print(f"  → Total impression samples loaded: {len(samples)}")

# --- DataLoader --------------------------------------------------------------
dataset = MindDataset(samples)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=mind_collate_fn,
    # A dedicated, seeded generator makes the shuffle order deterministic
    # without altering the global torch RNG state.
    generator=torch.Generator().manual_seed(SEED),
)

  → Total news articles tokenized: 42416
  → Total impression samples loaded: 73152


# Define model

In [2]:
from nrms import NRMS


# Hyperparameters for the NRMS model, collected in one mapping so they can be
# inspected or tweaked in a single place before the model is built.
nrms_config = {
    # Vocabulary size comes from the tokenizer loaded earlier.
    "vocab_size": tokenizer.vocab_size,
    # NOTE(review): 768 / 12 / 3072 coincide with the bert-base hidden size,
    # head count and feed-forward width — presumably intentional; confirm.
    "d_embed": 768,
    "n_heads": 12,
    "d_mlp": 3072,
    "news_layers": 1,
    "user_layers": 1,
    "dropout": 0.1,
    # Same cap that was passed to load_and_tokenize_news for headlines.
    "pad_max_len": MAX_TITLE_LEN,
}

model = NRMS(**nrms_config)

In [None]:
# Smoke test: push a single batch through the model and print what comes back.
# Nothing is back-propagated here, so the forward pass runs under
# torch.no_grad() to avoid recording an autograd graph for a throwaway call.
# NOTE(review): the model's train/eval mode is left untouched; if NRMS is an
# nn.Module it is still in training mode here (dropout active) — call
# model.eval() first if a deterministic output is wanted.
with torch.no_grad():
    # `labels` is part of the collated batch but is not needed for the forward pass.
    for clicked_ids, clicked_mask, cand_ids, cand_mask, labels in loader:
        print(model(clicked_ids, clicked_mask, cand_ids, cand_mask))
        break  # one batch is enough for a sanity check

: 