In [1]:
# setup
import os
import wandb
import torch
import torch.nn as nn
from nrms import NRMS
from typing import List, Dict
from torch.optim import Adam
from torch_optimizer import AdamP
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast
from data import load_and_tokenize_news, load_behaviors, MindDataset, mind_collate_fn

# Setup dataloader

In [2]:
BASE_DATA_DIR = './data/MIND_'


MAX_TITLE_LEN = 100    # each headline → exactly 30 WordPiece IDs (truncated/padded)
MAX_HISTORY  = 50     # each user’s clicked history → exactly 50 articles
BACTH_SIZE = 3


tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
PAD_ID = tokenizer.pad_token_id

train_news_dict = load_and_tokenize_news(BASE_DATA_DIR+'train/news.tsv', tokenizer, MAX_TITLE_LEN)
train_samples   = load_behaviors(BASE_DATA_DIR+'train/behaviors.tsv', train_news_dict, MAX_HISTORY)

val_news_dict = load_and_tokenize_news(BASE_DATA_DIR+'val/news.tsv', tokenizer, MAX_TITLE_LEN)
val_samples   = load_behaviors(BASE_DATA_DIR+'val/behaviors.tsv', val_news_dict, MAX_HISTORY)


train_dataset = MindDataset(train_samples)
val_dataset = MindDataset(val_samples)

train_dl = DataLoader(
    train_dataset,
    batch_size=BACTH_SIZE,
    shuffle=True,
    collate_fn=mind_collate_fn
)

valid_dl = DataLoader(
    val_dataset,
    batch_size=BACTH_SIZE,
    shuffle=True,
    collate_fn=mind_collate_fn
)

# Define model

In [3]:
model = NRMS(
    vocab_size=tokenizer.vocab_size,
    d_embed=512,
    n_heads=8,
    d_mlp=2048,
    news_layers=1,
    user_layers=1,
    dropout=0.1,
    pad_max_len=MAX_TITLE_LEN 
)

# Train loop!

In [4]:
def train(
    model,
    train_dataloader,
    val_dataloader,
    epochs: int = 2,
    lr: float = 1e-4,
    device: str = "cuda",
    log_interval: int = 100,
    checkpoint_interval: int = 10000,
    project_name: str = "NRMS",
    save_path: str = "./checkpoints/",
    ):
    """
    Trains `model` using train_dataloader, evaluates on val_dataloader each epoch,
    logs metrics to W&B, and finally saves model parameters to `save_path`.

    Args:
        model: a PyTorch nn.Module that returns logits of shape (B, K) given:
               (clicked_ids, clicked_mask, cand_ids, cand_mask)
        train_dataloader: torch.utils.data.DataLoader for training
        val_dataloader: torch.utils.data.DataLoader for validation
        epochs: number of epochs to train
        lr: learning rate for Adam
        device: "cuda" or "cpu"
        project_name: W&B project name
        save_path: where to save the final model.state_dict()
    """
    # Move model to device
    model.to(device)

    # Initialize W&B
    wandb.init(
        project=project_name,
        config={
            "epochs": epochs,
            "learning_rate": lr,
            "optimizer": "Adam",
            "loss_fn": "CrossEntropyLoss",
        },
    )
    wandb.watch(model, log="parameters", log_freq=500)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=lr)

    step = 0 # increases each batch

    for epoch in range(1, epochs + 1):
        ##### Training Phase #####
        model.train()
        total_train_loss = 0.0
        total_train_correct = 0
        total_train_samples = 0

        for clicked_ids, clicked_mask, cand_ids, cand_mask, labels in tqdm(train_dataloader, desc=f"Epoch {epoch} [Train]"):
            clicked_ids = clicked_ids.to(device)
            clicked_mask = clicked_mask.to(device)
            cand_ids = cand_ids.to(device)
            cand_mask = cand_mask.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Forward pass
            scores = model(clicked_ids, ~clicked_mask, cand_ids, cand_mask)

            # Compute training loss
            loss = criterion(scores, labels)
            loss.backward()
            optimizer.step()

            # Accumulate stats
            total_train_loss += loss.item() * labels.size(0)
            preds = scores.argmax(dim=1)
            total_train_correct += (preds == labels).sum().item()
            total_train_samples += labels.size(0)

            if step % log_interval == 0:
                # Log to console
                print(
                    f"Step {step} | "
                    f"Train Loss: {loss.item():.4f}, "
                    f"Train Acc: {(preds == labels).float().mean().item():.4f}"
                )
                # Log to W&B
                wandb.log(
                    {
                        "train/loss": loss.item(),
                        "train/accuracy": (preds == labels).float().mean().item(),
                        "step": step,
                    }
                )

            if step % checkpoint_interval == 0:
                print(f'Saving checkpoint at step {step}...')
                checkpoint_path = save_path + f"checkpoint_epoch{epoch}_step{step}.pt"
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Checkpoint saved to {checkpoint_path}")


            step += 1


        avg_train_loss = total_train_loss / total_train_samples
        train_accuracy = total_train_correct / total_train_samples

        ##### Validation Phase #####
        model.eval()
        total_val_loss = 0.0
        total_val_correct = 0
        total_val_samples = 0

        with torch.no_grad():
            for clicked_ids, clicked_mask, cand_ids, cand_mask, labels in tqdm(val_dataloader, desc=f"Epoch {epoch} [Val]"):
                clicked_ids = clicked_ids.to(device)
                clicked_mask = clicked_mask.to(device)
                cand_ids = cand_ids.to(device)
                cand_mask = cand_mask.to(device)
                labels = labels.to(device)

                # Forward pass
                scores = model(clicked_ids, ~clicked_mask, cand_ids, cand_mask)

                # Compute validation loss
                loss = criterion(scores, labels)
                total_val_loss += loss.item() * labels.size(0)

                preds = scores.argmax(dim=1)
                total_val_correct += (preds == labels).sum().item()
                total_val_samples += labels.size(0)

        avg_val_loss = total_val_loss / total_val_samples
        val_accuracy = total_val_correct / total_val_samples

        # Log to console
        print(
            f"Epoch {epoch:02d} | "
            f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f} | "
            f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}"
        )

        # Log to W&B
        wandb.log(
            {
                "epoch": epoch,
                "val/loss": avg_val_loss,
                "val/accuracy": val_accuracy,
            }
        )

        # Save model checkpoint after epoch
        print(f"Saving model parameters for epoch {epoch}...")
        checkpoint_path = save_path + f"checkpoint_epoch{epoch}.pt"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

    # Finish the W&B run
    wandb.finish()


In [5]:
wandb.login()

wandb: Currently logged in as: danielvolkov (the_magnivim) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

In [6]:
train(
    model,
    train_dl,
    valid_dl,
    epochs=2,
    lr=1e-4,
    device="cuda" if torch.cuda.is_available() else "cpu",
    log_interval=100,
    checkpoint_interval=1000,
    project_name="NRMS",
    save_path="./checkpoints/"
)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Step 0 | Train Loss: 1387607.8750, Train Acc: 0.0000
Saving checkpoint at step 0...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step0.pt


Epoch 1 [Train]:   0%|          | 100/52322 [00:46<9:14:59,  1.57it/s]

Step 100 | Train Loss: 5662.2017, Train Acc: 0.3333


Epoch 1 [Train]:   0%|          | 200/52322 [01:54<6:43:07,  2.15it/s] 

Step 200 | Train Loss: 10791.1758, Train Acc: 0.0000


Epoch 1 [Train]:   1%|          | 300/52322 [03:19<4:43:44,  3.06it/s] 

Step 300 | Train Loss: 5539.2075, Train Acc: 0.3333


Epoch 1 [Train]:   1%|          | 400/52322 [04:18<5:58:09,  2.42it/s] 

Step 400 | Train Loss: 3117.8674, Train Acc: 0.3333


Epoch 1 [Train]:   1%|          | 500/52322 [05:32<15:29:14,  1.08s/it]

Step 500 | Train Loss: 2573.4370, Train Acc: 0.0000


Epoch 1 [Train]:   1%|          | 600/52322 [06:50<12:14:09,  1.17it/s]

Step 600 | Train Loss: 5861.9907, Train Acc: 0.0000


Epoch 1 [Train]:   1%|▏         | 700/52322 [07:34<7:24:01,  1.94it/s] 

Step 700 | Train Loss: 4711.4434, Train Acc: 0.0000


Epoch 1 [Train]:   2%|▏         | 800/52322 [08:24<5:42:25,  2.51it/s] 

Step 800 | Train Loss: 2749.4280, Train Acc: 0.0000


Epoch 1 [Train]:   2%|▏         | 900/52322 [09:30<14:25:16,  1.01s/it]

Step 900 | Train Loss: 1032.6036, Train Acc: 0.3333


Epoch 1 [Train]:   2%|▏         | 1000/52322 [10:45<5:54:36,  2.41it/s]

Step 1000 | Train Loss: 3739.7471, Train Acc: 0.0000
Saving checkpoint at step 1000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step1000.pt


Epoch 1 [Train]:   2%|▏         | 1100/52322 [12:05<7:31:18,  1.89it/s] 

Step 1100 | Train Loss: 2864.2708, Train Acc: 0.0000


Epoch 1 [Train]:   2%|▏         | 1200/52322 [13:53<5:06:39,  2.78it/s] 

Step 1200 | Train Loss: 2004.3756, Train Acc: 0.3333


Epoch 1 [Train]:   2%|▏         | 1206/52322 [13:58<9:52:31,  1.44it/s] 


KeyboardInterrupt: 