In [1]:
# setup
import os
import wandb
import torch
import torch.nn as nn
from nrms import NRMS
from typing import List, Dict
from torch.optim import Adam
from torch_optimizer import AdamP
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast
from data import load_and_tokenize_news, load_behaviors, MindDataset, mind_collate_fn

# Set up dataloaders

In [2]:
# Path prefix shared by both splits; the split name is appended directly,
# e.g. './data/MIND_' + 'train/news.tsv' -> './data/MIND_train/news.tsv'.
BASE_DATA_DIR = './data/MIND_'

# Length limits handed to the loaders in data.py.
# NOTE(review): the previous comment claimed 30 WordPiece IDs while the value
# is 100; the value is kept and the comment corrected. Truncation/padding is
# assumed to happen inside load_and_tokenize_news / load_behaviors -- confirm.
MAX_TITLE_LEN = 100   # max WordPiece IDs per headline
MAX_HISTORY = 50      # max clicked articles kept per user history
BATCH_SIZE = 6
BACTH_SIZE = BATCH_SIZE  # backward-compatible alias for the original misspelled name

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
PAD_ID = tokenizer.pad_token_id

# Each split is tokenized against its own news file, then its behaviors are
# resolved against that split's news dictionary.
train_news_dict = load_and_tokenize_news(BASE_DATA_DIR + 'train/news.tsv', tokenizer, MAX_TITLE_LEN)
train_samples = load_behaviors(BASE_DATA_DIR + 'train/behaviors.tsv', train_news_dict, MAX_HISTORY)

val_news_dict = load_and_tokenize_news(BASE_DATA_DIR + 'val/news.tsv', tokenizer, MAX_TITLE_LEN)
val_samples = load_behaviors(BASE_DATA_DIR + 'val/behaviors.tsv', val_news_dict, MAX_HISTORY)

train_dataset = MindDataset(train_samples)
val_dataset = MindDataset(val_samples)

train_dl = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=mind_collate_fn,
)

# Validation is not shuffled: the aggregate metrics do not depend on order,
# and a fixed order keeps evaluation deterministic across runs.
valid_dl = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=mind_collate_fn,
)

# Define model

In [3]:
# Hyperparameters for the NRMS model, collected in one mapping so they can be
# inspected or logged before the model is built.
nrms_hparams = dict(
    # Vocabulary size comes from the tokenizer used to encode the news titles.
    vocab_size=tokenizer.vocab_size,
    d_embed=128,
    n_heads=4,
    d_mlp=256,
    news_layers=1,
    user_layers=1,
    dropout=0.1,
    # Same length limit that was passed to load_and_tokenize_news.
    pad_max_len=MAX_TITLE_LEN,
)

model = NRMS(**nrms_hparams)

# Train loop

In [4]:
def _batch_to_device(batch, device):
    """Return the tensors of one collated batch moved onto `device`, in order."""
    return tuple(tensor.to(device) for tensor in batch)


def _safe_mean(total, count):
    """Return `total / count`, or NaN when `count` is zero (empty dataloader)."""
    return total / count if count else float("nan")


def _save_checkpoint(model, save_dir, filename):
    """Save `model.state_dict()` as `filename` inside `save_dir`; return the path."""
    checkpoint_path = os.path.join(save_dir, filename)
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")
    return checkpoint_path


def _evaluate(model, dataloader, criterion, device, desc):
    """Run one no-grad pass over `dataloader`; return (mean loss, accuracy)."""
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            clicked_ids, clicked_mask, cand_ids, cand_mask, labels = _batch_to_device(batch, device)

            # Same mask convention as the training loop (see the note there).
            scores = model(clicked_ids, ~clicked_mask, cand_ids, cand_mask)
            loss = criterion(scores, labels)

            n_samples = labels.size(0)
            total_loss += loss.item() * n_samples
            total_correct += (scores.argmax(dim=1) == labels).sum().item()
            total_samples += n_samples

    return _safe_mean(total_loss, total_samples), _safe_mean(total_correct, total_samples)


def train(
    model,
    train_dataloader,
    val_dataloader,
    epochs: int = 2,
    lr: float = 1e-4,
    device: str = "cuda",
    log_interval: int = 100,
    checkpoint_interval: int = 10000,
    project_name: str = "NRMS",
    save_path: str = "./checkpoints/",
    ):
    """
    Trains `model` on train_dataloader, evaluates on val_dataloader after every
    epoch, logs metrics to W&B, and writes checkpoints (model.state_dict())
    into the `save_path` directory.

    Checkpoints are written every `checkpoint_interval` steps as
    `checkpoint_epoch{E}_step{S}.pt` and after each epoch as
    `checkpoint_epoch{E}.pt`. The W&B run is always closed, even if training
    raises or is interrupted.

    Args:
        model: a PyTorch nn.Module that returns logits of shape (B, K) given:
               (clicked_ids, clicked_mask, cand_ids, cand_mask)
        train_dataloader: torch.utils.data.DataLoader for training; each batch
               must unpack to (clicked_ids, clicked_mask, cand_ids, cand_mask, labels)
        val_dataloader: torch.utils.data.DataLoader for validation (same batch layout)
        epochs: number of epochs to train
        lr: learning rate for Adam
        device: "cuda" or "cpu"
        log_interval: log batch loss/accuracy to W&B every this many steps;
               0 or None disables per-step logging
        checkpoint_interval: save a checkpoint every this many steps;
               0 or None disables per-step checkpoints
        project_name: W&B project name
        save_path: directory that receives the checkpoints; created if missing

    Returns:
        None
    """
    # torch.save does not create missing directories, so make sure the target
    # exists before the first checkpoint is written.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # Move model to device
    model.to(device)

    # Initialize W&B
    wandb.init(
        project=project_name,
        config={
            "epochs": epochs,
            "learning_rate": lr,
            "optimizer": "Adam",
            "loss_fn": "CrossEntropyLoss",
        },
    )

    try:
        wandb.watch(model, log="parameters", log_freq=500)

        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = Adam(model.parameters(), lr=lr)

        step = 0  # global batch counter; keeps increasing across epochs

        for epoch in range(1, epochs + 1):
            ##### Training Phase #####
            model.train()
            total_train_loss = 0.0
            total_train_correct = 0
            total_train_samples = 0

            for batch in tqdm(train_dataloader, desc=f"Epoch {epoch} [Train]"):
                clicked_ids, clicked_mask, cand_ids, cand_mask, labels = _batch_to_device(batch, device)

                optimizer.zero_grad()

                # Forward pass
                # NOTE(review): only clicked_mask is inverted here while cand_mask
                # is passed unchanged. Confirm against NRMS.forward and
                # mind_collate_fn that the two masks really use opposite conventions.
                scores = model(clicked_ids, ~clicked_mask, cand_ids, cand_mask)

                # Compute training loss
                loss = criterion(scores, labels)
                loss.backward()
                optimizer.step()

                # Read scalars off the device once per batch and reuse them below.
                n_samples = labels.size(0)
                batch_loss = loss.item()
                batch_correct = (scores.argmax(dim=1) == labels).sum().item()

                # Accumulate stats (loss is weighted by batch size so the epoch
                # average is per-sample even when the last batch is smaller).
                total_train_loss += batch_loss * n_samples
                total_train_correct += batch_correct
                total_train_samples += n_samples

                if log_interval and step % log_interval == 0:
                    wandb.log(
                        {
                            "train/loss": batch_loss,
                            "train/accuracy": _safe_mean(batch_correct, n_samples),
                            "step": step,
                        }
                    )

                if checkpoint_interval and step % checkpoint_interval == 0:
                    print(f'Saving checkpoint at step {step}...')
                    _save_checkpoint(model, save_path, f"checkpoint_epoch{epoch}_step{step}.pt")

                step += 1

            avg_train_loss = _safe_mean(total_train_loss, total_train_samples)
            train_accuracy = _safe_mean(total_train_correct, total_train_samples)

            ##### Validation Phase #####
            avg_val_loss, val_accuracy = _evaluate(
                model, val_dataloader, criterion, device, desc=f"Epoch {epoch} [Val]"
            )

            # Log to console
            print(
                f"Epoch {epoch:02d} | "
                f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f} | "
                f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}"
            )

            # Log to W&B (epoch-level train metrics included so they can be
            # compared with validation on the same chart).
            wandb.log(
                {
                    "epoch": epoch,
                    "train/epoch_loss": avg_train_loss,
                    "train/epoch_accuracy": train_accuracy,
                    "val/loss": avg_val_loss,
                    "val/accuracy": val_accuracy,
                }
            )

            # Save model checkpoint after epoch
            print(f"Saving model parameters for epoch {epoch}...")
            _save_checkpoint(model, save_path, f"checkpoint_epoch{epoch}.pt")
    finally:
        # Finish the W&B run even when training fails or is interrupted, so the
        # next wandb.init() in this kernel starts from a clean state.
        wandb.finish()


# Run training!

In [5]:
# Authenticate this kernel with Weights & Biases before train() calls wandb.init().
wandb.login()

wandb: Currently logged in as: danielvolkov (the_magnivim) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

In [6]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
run_device = "cuda" if torch.cuda.is_available() else "cpu"

# Launch the training run; every argument is passed by keyword so the
# configuration reads as a flat list of settings.
train(
    model=model,
    train_dataloader=train_dl,
    val_dataloader=valid_dl,
    epochs=2,
    lr=1e-4,
    device=run_device,
    log_interval=100,
    checkpoint_interval=1000,
    project_name="NRMS",
    save_path="./checkpoints/",
)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Saving checkpoint at step 0...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step0.pt


Epoch 1 [Train]:   4%|▍         | 1000/26161 [01:46<39:59, 10.49it/s] 

Saving checkpoint at step 1000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step1000.pt


Epoch 1 [Train]:   8%|▊         | 1999/26161 [03:30<40:04, 10.05it/s]

Saving checkpoint at step 2000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step2000.pt


Epoch 1 [Train]:  11%|█▏        | 2999/26161 [05:12<37:28, 10.30it/s]

Saving checkpoint at step 3000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step3000.pt


Epoch 1 [Train]:  15%|█▌        | 4000/26161 [06:59<40:33,  9.11it/s]

Saving checkpoint at step 4000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step4000.pt


Epoch 1 [Train]:  19%|█▉        | 5000/26161 [08:45<38:56,  9.06it/s]

Saving checkpoint at step 5000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step5000.pt


Epoch 1 [Train]:  23%|██▎       | 5999/26161 [10:30<32:28, 10.35it/s]

Saving checkpoint at step 6000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step6000.pt


Epoch 1 [Train]:  27%|██▋       | 7000/26161 [12:14<37:43,  8.47it/s]

Saving checkpoint at step 7000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step7000.pt


Epoch 1 [Train]:  31%|███       | 8000/26161 [13:58<37:15,  8.12it/s]

Saving checkpoint at step 8000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step8000.pt


Epoch 1 [Train]:  34%|███▍      | 8999/26161 [15:44<27:56, 10.23it/s]

Saving checkpoint at step 9000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step9000.pt


Epoch 1 [Train]:  38%|███▊      | 10000/26161 [17:32<30:35,  8.80it/s]

Saving checkpoint at step 10000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step10000.pt


Epoch 1 [Train]:  42%|████▏     | 10999/26161 [19:17<22:45, 11.10it/s]

Saving checkpoint at step 11000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step11000.pt


Epoch 1 [Train]:  46%|████▌     | 12000/26161 [21:04<26:30,  8.90it/s]

Saving checkpoint at step 12000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step12000.pt


Epoch 1 [Train]:  50%|████▉     | 13000/26161 [22:49<31:08,  7.04it/s]

Saving checkpoint at step 13000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step13000.pt


Epoch 1 [Train]:  54%|█████▎    | 14000/26161 [24:37<25:17,  8.01it/s]

Saving checkpoint at step 14000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step14000.pt


Epoch 1 [Train]:  57%|█████▋    | 15000/26161 [26:19<20:45,  8.96it/s]

Saving checkpoint at step 15000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step15000.pt


Epoch 1 [Train]:  61%|██████    | 15999/26161 [28:04<15:19, 11.06it/s]

Saving checkpoint at step 16000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step16000.pt


Epoch 1 [Train]:  65%|██████▍   | 17000/26161 [29:46<16:12,  9.42it/s]

Saving checkpoint at step 17000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step17000.pt


Epoch 1 [Train]:  69%|██████▉   | 17999/26161 [31:28<13:52,  9.81it/s]

Saving checkpoint at step 18000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step18000.pt


Epoch 1 [Train]:  73%|███████▎  | 19000/26161 [33:11<14:28,  8.24it/s]

Saving checkpoint at step 19000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step19000.pt


Epoch 1 [Train]:  76%|███████▋  | 20000/26161 [34:53<13:58,  7.35it/s]

Saving checkpoint at step 20000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step20000.pt


Epoch 1 [Train]:  80%|████████  | 21000/26161 [36:40<08:10, 10.53it/s]

Saving checkpoint at step 21000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step21000.pt


Epoch 1 [Train]:  84%|████████▍ | 22000/26161 [38:24<07:41,  9.01it/s]

Saving checkpoint at step 22000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step22000.pt


Epoch 1 [Train]:  88%|████████▊ | 23000/26161 [40:09<05:22,  9.80it/s]

Saving checkpoint at step 23000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step23000.pt


Epoch 1 [Train]:  92%|█████████▏| 24000/26161 [41:53<04:22,  8.23it/s]

Saving checkpoint at step 24000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step24000.pt


Epoch 1 [Train]:  96%|█████████▌| 25000/26161 [43:38<02:08,  9.04it/s]

Saving checkpoint at step 25000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step25000.pt


Epoch 1 [Train]:  99%|█████████▉| 26000/26161 [45:22<00:17,  9.17it/s]

Saving checkpoint at step 26000...
Checkpoint saved to ./checkpoints/checkpoint_epoch1_step26000.pt


Epoch 1 [Train]: 100%|██████████| 26161/26161 [45:39<00:00,  9.55it/s]
  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
Epoch 1 [Val]: 100%|██████████| 12192/12192 [05:45<00:00, 35.26it/s]


Epoch 01 | Train Loss: 473.1472, Train Acc: 0.1106 | Val Loss: 4.4230, Val Acc: 0.1033
Saving model parameters for epoch 1...
Checkpoint saved to ./checkpoints/checkpoint_epoch1.pt


Epoch 2 [Train]:   3%|▎         | 838/26161 [01:27<43:39,  9.67it/s]

Saving checkpoint at step 27000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step27000.pt


Epoch 2 [Train]:   7%|▋         | 1838/26161 [03:13<44:35,  9.09it/s]  

Saving checkpoint at step 28000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step28000.pt


Epoch 2 [Train]:  11%|█         | 2839/26161 [04:59<35:14, 11.03it/s]

Saving checkpoint at step 29000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step29000.pt


Epoch 2 [Train]:  15%|█▍        | 3839/26161 [06:45<44:12,  8.42it/s]

Saving checkpoint at step 30000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step30000.pt


Epoch 2 [Train]:  18%|█▊        | 4839/26161 [08:27<32:17, 11.00it/s]

Saving checkpoint at step 31000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step31000.pt


Epoch 2 [Train]:  22%|██▏       | 5839/26161 [10:12<35:45,  9.47it/s]

Saving checkpoint at step 32000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step32000.pt


Epoch 2 [Train]:  26%|██▌       | 6838/26161 [11:58<34:10,  9.42it/s]

Saving checkpoint at step 33000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step33000.pt


Epoch 2 [Train]:  30%|██▉       | 7839/26161 [13:41<33:16,  9.18it/s]

Saving checkpoint at step 34000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step34000.pt


Epoch 2 [Train]:  34%|███▍      | 8838/26161 [15:22<27:34, 10.47it/s]

Saving checkpoint at step 35000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step35000.pt


Epoch 2 [Train]:  38%|███▊      | 9839/26161 [17:04<30:12,  9.01it/s]

Saving checkpoint at step 36000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step36000.pt


Epoch 2 [Train]:  41%|████▏     | 10838/26161 [18:45<25:38,  9.96it/s]

Saving checkpoint at step 37000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step37000.pt


Epoch 2 [Train]:  45%|████▌     | 11839/26161 [20:27<25:04,  9.52it/s]

Saving checkpoint at step 38000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step38000.pt


Epoch 2 [Train]:  49%|████▉     | 12838/26161 [22:08<21:27, 10.35it/s]

Saving checkpoint at step 39000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step39000.pt


Epoch 2 [Train]:  53%|█████▎    | 13839/26161 [23:51<21:55,  9.37it/s]

Saving checkpoint at step 40000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step40000.pt


Epoch 2 [Train]:  57%|█████▋    | 14839/26161 [25:36<18:51, 10.01it/s]

Saving checkpoint at step 41000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step41000.pt


Epoch 2 [Train]:  61%|██████    | 15839/26161 [27:20<14:13, 12.09it/s]

Saving checkpoint at step 42000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step42000.pt


Epoch 2 [Train]:  64%|██████▍   | 16838/26161 [29:08<20:28,  7.59it/s]

Saving checkpoint at step 43000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step43000.pt


Epoch 2 [Train]:  68%|██████▊   | 17838/26161 [31:00<15:04,  9.20it/s]

Saving checkpoint at step 44000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step44000.pt


Epoch 2 [Train]:  72%|███████▏  | 18839/26161 [32:43<13:14,  9.21it/s]

Saving checkpoint at step 45000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step45000.pt


Epoch 2 [Train]:  76%|███████▌  | 19838/26161 [34:24<10:24, 10.13it/s]

Saving checkpoint at step 46000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step46000.pt


Epoch 2 [Train]:  80%|███████▉  | 20838/26161 [36:06<09:14,  9.59it/s]

Saving checkpoint at step 47000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step47000.pt


Epoch 2 [Train]:  83%|████████▎ | 21839/26161 [37:50<07:53,  9.12it/s]

Saving checkpoint at step 48000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step48000.pt


Epoch 2 [Train]:  87%|████████▋ | 22838/26161 [39:33<05:50,  9.49it/s]

Saving checkpoint at step 49000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step49000.pt


Epoch 2 [Train]:  91%|█████████ | 23838/26161 [41:24<03:51, 10.03it/s]

Saving checkpoint at step 50000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step50000.pt


Epoch 2 [Train]:  95%|█████████▍| 24838/26161 [43:13<02:13,  9.90it/s]

Saving checkpoint at step 51000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step51000.pt


Epoch 2 [Train]:  99%|█████████▉| 25839/26161 [45:02<00:32,  9.82it/s]

Saving checkpoint at step 52000...
Checkpoint saved to ./checkpoints/checkpoint_epoch2_step52000.pt


Epoch 2 [Train]: 100%|██████████| 26161/26161 [45:36<00:00,  9.56it/s]
Epoch 2 [Val]: 100%|██████████| 12192/12192 [05:42<00:00, 35.64it/s]


Epoch 02 | Train Loss: 8.3159, Train Acc: 0.1188 | Val Loss: 3.1634, Val Acc: 0.0947
Saving model parameters for epoch 2...
Checkpoint saved to ./checkpoints/checkpoint_epoch2.pt


0,1
epoch,▁█
step,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇█
train/accuracy,▃▃▃▃█▃▃▁▃▁▃▃▃█▃▆▁▁▁▃▃▃▃▆▃▁▁▃▁▃▁▁▆▆▆▆▃▁▁▃
train/loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/accuracy,█▁
val/loss,█▁

0,1
epoch,2.0
step,52300.0
train/accuracy,0.0
train/loss,4.08601
val/accuracy,0.09465
val/loss,3.16344
