In [None]:
# setup
import os
import wandb
import torch
import torch.nn as nn
from nrms import NRMS
from typing import List, Dict
from torch.optim import AdamW
from tqdm import tqdm
import math
from torch_optimizer import AdamP
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast
from data import load_and_tokenize_news, load_behaviors, MindDataset, mind_collate_fn

# Setup dataloader

In [2]:
# Path prefix for the MIND splits; '<split>/<file>.tsv' is appended to it below,
# giving e.g. './data/MIND_train/news.tsv'.
BASE_DATA_DIR = './data/MIND_'


MAX_TITLE_LEN = 100   # each headline -> exactly MAX_TITLE_LEN tokens (truncated/padded)
MAX_HISTORY  = 50     # each user's clicked history -> exactly MAX_HISTORY articles
BATCH_SIZE = 6
BACTH_SIZE = BATCH_SIZE  # backward-compatible alias for the original (misspelled) name


tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
PAD_ID = tokenizer.pad_token_id

# Each split is tokenized with its own news.tsv, and its behaviors are resolved
# against that split's news dictionary.
train_news_dict = load_and_tokenize_news(BASE_DATA_DIR+'train/news.tsv', tokenizer, MAX_TITLE_LEN)
train_samples   = load_behaviors(BASE_DATA_DIR+'train/behaviors.tsv', train_news_dict, MAX_HISTORY)

val_news_dict = load_and_tokenize_news(BASE_DATA_DIR+'val/news.tsv', tokenizer, MAX_TITLE_LEN)
val_samples   = load_behaviors(BASE_DATA_DIR+'val/behaviors.tsv', val_news_dict, MAX_HISTORY)


train_dataset = MindDataset(train_samples)
val_dataset = MindDataset(val_samples)

train_dl = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=mind_collate_fn
)

# Validation is not shuffled: the epoch-level metrics are sample-weighted averages,
# so order does not affect them, and a fixed order keeps evaluation reproducible.
valid_dl = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=mind_collate_fn
)

# Define model

In [3]:
# Build the NRMS model. All arguments are passed by keyword, so they are grouped
# here by the part of the model their name refers to.
model = NRMS(
    # vocabulary and word-level settings
    vocab_size=tokenizer.vocab_size,
    pad_max_len=MAX_TITLE_LEN,
    d_embed_word=128,
    # settings named *_news
    d_embed_news=256,
    n_heads_news=8,
    d_mlp_news=512,
    news_layers=1,
    # settings named *_user
    n_heads_user=8,
    d_mlp_user=512,
    user_layers=1,
    # regularisation
    dropout=0.1,
)

In [4]:
# Total number of model parameters, rendered with thousands separators
# (the ' ,' format spec also reserves a leading space for the sign).
format(sum(param.numel() for param in model.parameters()), ' ,')

' 5,092,480'

# Train loop

In [None]:
def train(
    model,
    train_dataloader,
    val_dataloader,
    epochs: int = 2,
    lr: float = 1e-4,
    device: str = "cuda",
    log_interval: int = 100,
    checkpoint_interval: int = 10000,
    project_name: str = "NRMS",
    save_path: str = "./checkpoints/",
    ):
    """
    Train `model` with cross-entropy loss, tracking loss and top-1 accuracy.

    Each epoch runs a full pass over `train_dataloader`, then evaluates on
    `val_dataloader`, logs the metrics to W&B and saves a checkpoint of
    `model.state_dict()`.

    NOTE: a later cell in this notebook defines another `train` (the MRR variant);
    when cells are run in order that definition replaces this one.

    Args:
        model: a PyTorch nn.Module called as
               model(clicked_ids, ~clicked_mask, cand_ids, cand_mask) and expected
               to return logits of shape (B, K).
        train_dataloader: DataLoader yielding
               (clicked_ids, clicked_mask, cand_ids, cand_mask, labels) for training.
        val_dataloader: DataLoader yielding the same 5-tuple for validation.
        epochs: number of epochs to train.
        lr: learning rate for AdamW.
        device: device string passed to `.to(...)`, e.g. "cuda" or "cpu".
        log_interval: log batch loss/accuracy to W&B every this many steps.
        checkpoint_interval: save an intermediate checkpoint every this many steps
               (never at step 0, when the model is still untrained).
        project_name: W&B project name.
        save_path: directory where checkpoints are written; created if missing.

    Raises:
        Any exception raised during training (including KeyboardInterrupt) is
        re-raised after the W&B run has been closed with a non-zero exit code.
    """
    # Move model to device
    model.to(device)

    # torch.save does not create missing parent directories.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # Initialize W&B
    wandb.init(
        project=project_name,
        config={
            "epochs": epochs,
            "learning_rate": lr,
            "optimizer": "AdamW",
            "loss_fn": "CrossEntropyLoss",
        },
    )
    wandb.watch(model, log="parameters", log_freq=500)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(),
                      lr=lr,
                      betas=(0.9, 0.999),
                      eps=1e-8,
                      weight_decay=1e-4
                )

    step = 0 # global batch counter, not reset between epochs

    try:
        for epoch in range(1, epochs + 1):
            ##### Training Phase #####
            model.train()
            total_train_loss = 0.0
            total_train_correct = 0
            total_train_samples = 0

            for clicked_ids, clicked_mask, cand_ids, cand_mask, labels in tqdm(train_dataloader, desc=f"Epoch {epoch} [Train]"):
                clicked_ids = clicked_ids.to(device)
                clicked_mask = clicked_mask.to(device)
                cand_ids = cand_ids.to(device)
                cand_mask = cand_mask.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Forward pass.
                # NOTE(review): only the history mask is inverted here while cand_mask is
                # passed as-is -- confirm against the conventions of NRMS.forward and
                # mind_collate_fn.
                scores = model(clicked_ids, ~clicked_mask, cand_ids, cand_mask)

                # Compute training loss
                loss = criterion(scores, labels)
                loss.backward()
                optimizer.step()

                # Accumulate sample-weighted stats so the epoch average is exact even
                # when the last batch is smaller.
                batch_loss = loss.item()
                preds = scores.argmax(dim=1)
                total_train_loss += batch_loss * labels.size(0)
                total_train_correct += (preds == labels).sum().item()
                total_train_samples += labels.size(0)

                if step % log_interval == 0:
                    wandb.log(
                        {
                            "train/loss": batch_loss,
                            "train/accuracy": (preds == labels).float().mean().item(),
                            "step": step,
                        }
                    )

                # Skip step 0: the weights are still (almost) at initialization.
                if step > 0 and step % checkpoint_interval == 0:
                    print(f'Saving checkpoint at step {step}...')
                    checkpoint_path = os.path.join(save_path, f"checkpoint_epoch{epoch}_step{step}.pt")
                    torch.save(model.state_dict(), checkpoint_path)
                    print(f"Checkpoint saved to {checkpoint_path}")

                step += 1

            # An empty dataloader yields NaN instead of a ZeroDivisionError.
            avg_train_loss = total_train_loss / total_train_samples if total_train_samples else float("nan")
            train_accuracy = total_train_correct / total_train_samples if total_train_samples else float("nan")

            ##### Validation Phase #####
            model.eval()
            total_val_loss = 0.0
            total_val_correct = 0
            total_val_samples = 0

            with torch.no_grad():
                for clicked_ids, clicked_mask, cand_ids, cand_mask, labels in tqdm(val_dataloader, desc=f"Epoch {epoch} [Val]"):
                    clicked_ids = clicked_ids.to(device)
                    clicked_mask = clicked_mask.to(device)
                    cand_ids = cand_ids.to(device)
                    cand_mask = cand_mask.to(device)
                    labels = labels.to(device)

                    # Forward pass
                    scores = model(clicked_ids, ~clicked_mask, cand_ids, cand_mask)

                    # Compute validation loss
                    loss = criterion(scores, labels)
                    total_val_loss += loss.item() * labels.size(0)

                    preds = scores.argmax(dim=1)
                    total_val_correct += (preds == labels).sum().item()
                    total_val_samples += labels.size(0)

            avg_val_loss = total_val_loss / total_val_samples if total_val_samples else float("nan")
            val_accuracy = total_val_correct / total_val_samples if total_val_samples else float("nan")

            # Log to console
            print(
                f"Epoch {epoch:02d} | "
                f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f} | "
                f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}"
            )

            # Log to W&B
            wandb.log(
                {
                    "epoch": epoch,
                    "val/loss": avg_val_loss,
                    "val/accuracy": val_accuracy,
                }
            )

            # Save model checkpoint after epoch
            print(f"Saving model parameters for epoch {epoch}...")
            checkpoint_path = os.path.join(save_path, f"checkpoint_epoch{epoch}.pt")
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Checkpoint saved to {checkpoint_path}")
    except BaseException:
        # Also covers KeyboardInterrupt: close the run and mark it as failed
        # instead of leaving it open in the notebook kernel.
        wandb.finish(exit_code=1)
        raise

    # Finish the W&B run
    wandb.finish()


In [None]:
def _batch_mrr(scores: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the mean reciprocal rank of the labelled candidate over one batch.

    `scores` is indexed as (B, K) and `labels` as (B,), each label being the column
    of the true candidate. The rank of a label is its 1-based position in the row's
    scores sorted in descending order.
    """
    order = scores.detach().argsort(dim=1, descending=True)   # (B, K) candidate ids by rank
    hits = order == labels.unsqueeze(1)                        # one True per row
    # argmax returns the position of the (single) True entry; a single .tolist()
    # replaces the per-sample .item() device syncs.
    ranks = (hits.float().argmax(dim=1) + 1).tolist()
    return sum(1.0 / rank for rank in ranks) / len(ranks)


def _save_checkpoint(model, save_path: str, filename: str) -> str:
    """Write `model.state_dict()` to `save_path/filename` and return the full path."""
    checkpoint_path = os.path.join(save_path, filename)
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")
    return checkpoint_path


def train(
    model,
    train_dataloader,
    val_dataloader,
    epochs: int = 2,
    lr: float = 1e-4,
    device: str = "cuda",
    log_interval: int = 100,
    checkpoint_interval: int = 10000,
    project_name: str = "NRMS",
    save_path: str = "./checkpoints/",
):
    """
    Train `model` with cross-entropy loss, tracking loss and MRR.

    Each epoch runs a full pass over `train_dataloader`, then evaluates on
    `val_dataloader`, logs loss + MRR to W&B (with the global step as the x-axis)
    and saves a checkpoint of `model.state_dict()`.

    Args:
        model: a PyTorch nn.Module called as
               model(clicked_ids, ~clicked_mask, cand_ids, cand_mask) and expected
               to return logits of shape (B, K).
        train_dataloader: DataLoader yielding
               (clicked_ids, clicked_mask, cand_ids, cand_mask, labels) for training.
        val_dataloader: DataLoader yielding the same 5-tuple for validation.
        epochs: number of epochs to train.
        lr: learning rate for AdamW.
        device: device string passed to `.to(...)`, e.g. "cuda" or "cpu".
        log_interval: log batch loss/MRR to W&B every this many steps.
        checkpoint_interval: save an intermediate checkpoint every this many steps
               (never at step 0).
        project_name: W&B project name.
        save_path: directory where checkpoints are written; created if missing.

    Raises:
        Any exception raised during training (including KeyboardInterrupt) is
        re-raised after the W&B run has been closed with a non-zero exit code.
    """
    model.to(device)

    # torch.save does not create missing parent directories.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # Initialize W&B
    wandb.init(
        project=project_name,
        config={
            "epochs": epochs,
            "learning_rate": lr,
            "optimizer": "AdamW",
            "loss_fn": "CrossEntropyLoss",
            "metric": "MRR",
        },
    )
    # Only log weight histograms (no gradients) to cut down on storage
    wandb.watch(model, log="parameters", log_freq=500)

    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(),
                      lr=lr,
                      betas=(0.9, 0.999),
                      eps=1e-8,
                      weight_decay=1e-4
                )

    step = 0  # global batch index, not reset between epochs

    try:
        for epoch in range(1, epochs + 1):
            ##### Training Phase #####
            model.train()
            total_train_loss = 0.0
            total_train_mrr = 0.0
            total_train_samples = 0

            for clicked_ids, clicked_mask, cand_ids, cand_mask, labels in tqdm(
                train_dataloader, desc=f"Epoch {epoch} [Train]"
            ):
                clicked_ids = clicked_ids.to(device)
                clicked_mask = clicked_mask.to(device)
                cand_ids = cand_ids.to(device)
                cand_mask = cand_mask.to(device)
                labels = labels.to(device)  # (B,)

                optimizer.zero_grad()

                # Forward pass -> logits of shape (B, K).
                # NOTE(review): only the history mask is inverted here while cand_mask is
                # passed as-is -- confirm against the conventions of NRMS.forward and
                # mind_collate_fn.
                scores: torch.Tensor = model(
                    clicked_ids, ~clicked_mask, cand_ids, cand_mask
                )  # (B, K)
                loss = criterion(scores, labels)
                loss.backward()
                optimizer.step()

                batch_size = labels.size(0)
                batch_loss = loss.item()
                with torch.no_grad():
                    batch_mrr = _batch_mrr(scores, labels)

                # Accumulate sample-weighted totals (for printing at end of epoch)
                total_train_loss += batch_loss * batch_size
                total_train_mrr += batch_mrr * batch_size
                total_train_samples += batch_size

                # Log scalars to W&B every log_interval steps
                if step % log_interval == 0:
                    wandb.log(
                        {
                            "train/loss": batch_loss,
                            "train/MRR": batch_mrr,
                            "epoch": epoch,
                        },
                        step=step,
                    )

                # Save checkpoint periodically
                if step % checkpoint_interval == 0 and step > 0:
                    print(f"Saving checkpoint at step {step}...")
                    _save_checkpoint(model, save_path, f"checkpoint_epoch{epoch}_step{step}.pt")

                step += 1

            # End of epoch: compute averages (NaN rather than ZeroDivisionError when
            # the dataloader was empty)
            avg_train_loss = total_train_loss / total_train_samples if total_train_samples else float("nan")
            avg_train_mrr = total_train_mrr / total_train_samples if total_train_samples else float("nan")
            print(
                f"Epoch {epoch:02d} | "
                f"Train Loss: {avg_train_loss:.4f}, Train MRR: {avg_train_mrr:.4f}"
            )

            ##### Validation Phase #####
            model.eval()
            total_val_loss = 0.0
            total_val_mrr = 0.0
            total_val_samples = 0

            with torch.no_grad():
                for clicked_ids, clicked_mask, cand_ids, cand_mask, labels in tqdm(
                    val_dataloader, desc=f"Epoch {epoch} [Val]"
                ):
                    clicked_ids = clicked_ids.to(device)
                    clicked_mask = clicked_mask.to(device)
                    cand_ids = cand_ids.to(device)
                    cand_mask = cand_mask.to(device)
                    labels = labels.to(device)

                    scores = model(clicked_ids, ~clicked_mask, cand_ids, cand_mask)
                    loss = criterion(scores, labels)

                    batch_size = labels.size(0)
                    batch_mrr = _batch_mrr(scores, labels)

                    total_val_loss += loss.item() * batch_size
                    total_val_mrr += batch_mrr * batch_size
                    total_val_samples += batch_size

            avg_val_loss = total_val_loss / total_val_samples if total_val_samples else float("nan")
            avg_val_mrr = total_val_mrr / total_val_samples if total_val_samples else float("nan")
            print(
                f"Epoch {epoch:02d} | "
                f"Val Loss:   {avg_val_loss:.4f}, Val MRR:   {avg_val_mrr:.4f}"
            )

            # Log validation metrics to W&B (use current step so they share the same x-axis if desired)
            wandb.log(
                {
                    "val/loss": avg_val_loss,
                    "val/MRR": avg_val_mrr,
                    "epoch": epoch,
                },
                step=step,
            )

            # Save model checkpoint at end of epoch
            print(f"Saving model parameters for epoch {epoch}...")
            _save_checkpoint(model, save_path, f"checkpoint_epoch{epoch}.pt")
    except BaseException:
        # Also covers KeyboardInterrupt: close the run and mark it as failed
        # instead of leaving it open in the notebook kernel.
        wandb.finish(exit_code=1)
        raise

    # Finish the W&B run
    wandb.finish()


# Run training!

In [10]:
# Authenticate with Weights & Biases before `train` opens a run via wandb.init().
wandb.login()

wandb: Currently logged in as: danielvolkov (the_magnivim) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

In [11]:
# Launch training on the loaders built above. The device falls back to CPU when
# CUDA is unavailable; batch metrics are logged every 100 steps and an
# intermediate checkpoint is written every 1000 steps.
train(
    model,
    train_dl,
    valid_dl,
    epochs=2,
    lr=1e-4,
    device="cuda" if torch.cuda.is_available() else "cpu",
    log_interval=100,
    checkpoint_interval=1000,
    project_name="NRMS",
    save_path="./checkpoints/"
)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Epoch 1 [Train]:   4%|▍         | 1000/26161 [02:27<1:01:43,  6.79it/s]

Saving checkpoint at step 1000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step1000.pt


Epoch 1 [Train]:   8%|▊         | 2000/26161 [04:57<1:10:47,  5.69it/s]

Saving checkpoint at step 2000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step2000.pt


Epoch 1 [Train]:  11%|█▏        | 3000/26161 [07:22<59:06,  6.53it/s]  

Saving checkpoint at step 3000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step3000.pt


Epoch 1 [Train]:  15%|█▌        | 4000/26161 [09:48<52:09,  7.08it/s]  

Saving checkpoint at step 4000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step4000.pt


Epoch 1 [Train]:  19%|█▊        | 4888/26161 [12:03<52:28,  6.76it/s]  


KeyboardInterrupt: 