# BERT4Rec Advanced Training

## Enhanced training with early stopping, learning rate decay, and comprehensive monitoring

This notebook implements:

- 8-12 epochs with early stopping
- Learning rate decay when validation loss plateaus
- TensorBoard monitoring
- Recall@10 evaluation during training
- Target: Recall@10 of roughly 0.20–0.25 (training stops once Recall@10 reaches 0.20; the Recall@20 target is 0.25)


In [None]:
# Core imports
import sys
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import polars as pl
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import matplotlib.pyplot as plt

# Add project root to path
sys.path.append('../../')

# Import BERT4Rec implementation
from hnm_data_analysis.data_modelling.bert4rec_modelling import (
    SequenceOptions, prepare_sequences_with_polars,
    build_dataloaders_for_bert4rec, BERT4RecModel, TrainConfig,
    train_bert4rec, evaluate_next_item_topk, set_all_seeds,
    MaskingOptions
)

# Filesystem locations, resolved relative to the notebook's working directory
DATA_ROOT = Path('../..') / 'data' / 'modelling_data'
MODEL_SAVE_DIR = Path('../..') / 'models' / 'bert4rec'
# Create the checkpoint directory up front; a no-op when it already exists
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
print(f"Using device: {device}")
if cuda_available:
    # Report the first CUDA device's name and total memory
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.1f} GB")

ModuleNotFoundError: No module named 'hnm_data_analysis'

## Enhanced Training Configuration

**Context Tokens for Better Calibration:**

- **[SEG] tokens**: Customer segment information helps model understand user preferences
- **[CH] tokens**: Sales channel context (online vs store) affects purchase patterns
- **Price-band tokens**: Price sensitivity information improves recommendation quality
- **Expected improvement**: +0.02-0.05 boost in Recall@10 and NDCG through better calibration


In [None]:
# Enhanced training configuration
from dataclasses import dataclass


@dataclass
class AdvancedTrainConfig:
    """Hyperparameters and stopping criteria for the advanced training run.

    Every field has a default, so ``AdvancedTrainConfig()`` yields the
    notebook's standard settings; individual values can be overridden by
    keyword, e.g. ``AdvancedTrainConfig(n_epochs=8, lr=5e-5)``.
    """

    # Training parameters
    n_epochs: int = 12  # upper bound on epochs; early stopping may end sooner
    lr: float = 1e-4  # initial AdamW learning rate
    weight_decay: float = 0.01
    # NOTE(review): warmup_steps is not read by train_bert4rec_advanced in
    # this notebook - confirm whether warmup is still intended.
    warmup_steps: int = 1000
    grad_clip_norm: float = 1.0  # values <= 0 disable gradient clipping

    # Early stopping
    early_stopping_patience: int = 3  # epochs without improvement before stopping
    min_delta: float = 0.001  # validation loss must drop by more than this to count

    # Learning rate decay (forwarded to ReduceLROnPlateau)
    lr_decay_patience: int = 2
    lr_decay_factor: float = 0.5
    min_lr: float = 1e-6

    # Evaluation
    # NOTE(review): eval_every_n_epochs is not read by train_bert4rec_advanced,
    # which evaluates after every epoch.
    eval_every_n_epochs: int = 1
    save_best_model: bool = True

    # Target metrics
    target_recall_10: float = 0.20  # training stops once Recall@10 reaches this
    target_recall_20: float = 0.25  # reporting only


config = AdvancedTrainConfig()
print(f"Training for up to {config.n_epochs} epochs with early stopping")
print(f"Target Recall@10: {config.target_recall_10}")

## Data Loading and Preprocessing


In [None]:
# Read the prepared transaction table and report its basic dimensions
print("Loading transaction data...")
transactions_path = DATA_ROOT / 'transactions_final.parquet'
df = pl.read_parquet(transactions_path)

print(f"Data shape: {df.shape}")
date_column = df['t_dat']
print(f"Date range: {date_column.min()} to {date_column.max()}")
# Distinct-value counts for the two identifier columns
for label, column_name in (("customers", "customer_id"), ("articles", "article_id")):
    print(f"Unique {label}: {df[column_name].n_unique():,}")

In [None]:
# Sequence preparation with optimized settings + context tokens
set_all_seeds(42)  # seed before preparing sequences so reruns are comparable

sequence_options = SequenceOptions(
    max_len=50,           # Maximum sequence length
    min_len=3,            # Minimum sequence length
    deduplicate_exact=True,
    treat_same_day_as_basket=True,
    # Enable context tokens for better calibration
    add_segment_prefix=True,     # Add [SEG] tokens based on customer segments
    add_channel_prefix=True,     # Add [CH] tokens based on sales channel
    add_priceband_prefix=True,   # Add price-band tokens
    n_price_bins=10              # Number of price bands
)

print("Preparing sequences with context tokens...")
print("Context tokens enabled:")
print(f"  ✓ Segment prefix: {sequence_options.add_segment_prefix}")
print(f"  ✓ Channel prefix: {sequence_options.add_channel_prefix}")
print(f"  ✓ Price-band prefix: {sequence_options.add_priceband_prefix}")

# Columns this cell checks for before preparing sequences.
# optional_cols maps a column name to whether its context-token option is on.
# NOTE(review): the price-band prefix is enabled above, but no price column is
# checked here - confirm which column prepare_sequences_with_polars reads.
required_cols = ['customer_id', 'article_id', 't_dat']
optional_cols = {
    'customer_segment': sequence_options.add_segment_prefix,
    'sales_channel_id': sequence_options.add_channel_prefix
}

print(f"\nData columns available: {df.columns}")
missing_required = [col for col in required_cols if col not in df.columns]
if missing_required:
    # Reported only; the preparation call below still runs.
    print(f"❌ Missing required columns: {missing_required}")
else:
    print(f"✅ All required columns present")

for col, needed in optional_cols.items():
    if needed:
        if col in df.columns:
            print(f"✅ {col} available for context tokens")
        else:
            print(f"⚠️  {col} not found - will skip this context token type")

# Prepare sequences.
# NOTE(review): presumably the function tolerates missing optional columns;
# that cannot be verified from this notebook - confirm in bert4rec_modelling.
prepared = prepare_sequences_with_polars(df, sequence_options)

print(f"\nSequence preparation complete:")
print(f"  Total sequences: {len(prepared.sequences):,}")
print(f"  Vocabulary size: {prepared.registry.vocab_size:,}")
print(f"  Prefix tokens: {len(prepared.registry.prefix_token2id):,}")
print(f"  Average sequence length: {np.mean([len(seq) for seq in prepared.sequences]):.1f}")
print(f"  Average prefix length: {np.mean(prepared.prefix_lengths):.1f}")

# Show up to ten example prefix tokens; the slice alone bounds the loop
if prepared.registry.prefix_token2id:
    print(f"\nExample context tokens:")
    for token, token_id in list(prepared.registry.prefix_token2id.items())[:10]:
        print(f"  {token_id}: {token}")

## Enhanced Training Function with Monitoring


In [None]:
def _masked_lm_loss(model, batch, criterion, device):
    """Run one forward pass on ``batch`` and return the criterion's loss.

    The batch's ``input_ids``, ``attention_mask`` and ``labels`` tensors are
    moved to ``device``. The model's 3-D logits are flattened to 2-D (and the
    labels to 1-D) so the criterion scores every position as its own row.
    """
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)

    logits = model(input_ids, attention_mask)
    n_rows, n_positions, n_classes = logits.size()
    return criterion(
        logits.view(n_rows * n_positions, n_classes),
        labels.view(n_rows * n_positions),
    )


def train_bert4rec_advanced(
    model, train_loader, valid_loader, test_loader, config, device, registry
):
    """
    Train ``model`` with early stopping, plateau LR decay, and TensorBoard logging.

    Each epoch runs one optimisation pass over ``train_loader``, computes the
    mean loss over ``valid_loader``, and scores Recall/NDCG at k=10 and k=20 on
    ``test_loader`` via ``evaluate_next_item_topk``. Training ends when
    ``config.n_epochs`` is reached, when validation loss has failed to improve
    by more than ``config.min_delta`` for ``config.early_stopping_patience``
    epochs, or when Recall@10 reaches ``config.target_recall_10``.

    Args:
        model: Called as ``model(input_ids, attention_mask)``; must return
            3-D logits.
        train_loader: Sized iterable of dict batches with ``input_ids``,
            ``attention_mask`` and ``labels`` tensors.
        valid_loader: Same batch format as ``train_loader``.
        test_loader: Passed unchanged to ``evaluate_next_item_topk``.
        config: Object exposing ``n_epochs``, ``lr``, ``weight_decay``,
            ``grad_clip_norm``, ``early_stopping_patience``, ``min_delta``,
            ``lr_decay_patience``, ``lr_decay_factor``, ``min_lr``,
            ``save_best_model``, ``target_recall_10`` and ``target_recall_20``.
        device: Device the batches are moved to.
        registry: Passed to ``evaluate_next_item_topk`` and stored in the
            checkpoint.

    Returns:
        ``(history, best_model_path)``. ``history`` maps metric names to
        per-epoch lists. ``best_model_path`` is the checkpoint path under the
        module-level ``MODEL_SAVE_DIR``, or ``None`` when
        ``config.save_best_model`` is false or no checkpoint was written.

    Raises:
        ValueError: If ``train_loader`` or ``valid_loader`` has no batches.
    """
    from tqdm.auto import tqdm  # imported once here rather than on every epoch

    # Fail early with a clear message instead of a ZeroDivisionError when
    # averaging the losses.
    if len(train_loader) == 0 or len(valid_loader) == 0:
        raise ValueError(
            "train_loader and valid_loader must each contain at least one batch"
        )

    # Setup TensorBoard
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = f'runs/bert4rec_advanced_{timestamp}'
    writer = SummaryWriter(log_dir)

    # Optimizers and schedulers
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.lr,
        weight_decay=config.weight_decay
    )

    # Learning rate scheduler (reduce on plateau of the validation loss)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.lr_decay_factor,
        patience=config.lr_decay_patience,
        min_lr=config.min_lr
    )

    # Positions labelled -100 do not contribute to the loss
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    # Early stopping variables
    best_val_loss = float('inf')
    best_recall_10 = 0.0
    patience_counter = 0
    # Stays None until a checkpoint is written, so the return below is always
    # defined (e.g. when the validation loss is NaN and never "improves").
    best_model_path = None

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'recall_10': [],
        'recall_20': [],
        'learning_rate': []
    }

    print(f"Starting training for up to {config.n_epochs} epochs...")
    print(f"TensorBoard logs: {log_dir}")

    steps_per_epoch = len(train_loader)

    try:
        for epoch in range(1, config.n_epochs + 1):
            # Training phase
            model.train()
            train_loss = 0.0
            num_batches = 0

            pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{config.n_epochs}")

            for batch_idx, batch in enumerate(pbar):
                optimizer.zero_grad()
                loss = _masked_lm_loss(model, batch, criterion, device)
                loss.backward()

                if config.grad_clip_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)

                optimizer.step()

                batch_loss = loss.item()
                train_loss += batch_loss
                num_batches += 1

                # Update progress bar
                pbar.set_postfix({
                    'loss': f'{batch_loss:.4f}',
                    'avg_loss': f'{train_loss/num_batches:.4f}'
                })

                # Log batch metrics to TensorBoard
                global_step = (epoch - 1) * steps_per_epoch + batch_idx
                writer.add_scalar('Loss/Train_Batch', batch_loss, global_step)
                writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], global_step)

            avg_train_loss = train_loss / num_batches

            # Validation phase
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for batch in tqdm(valid_loader, desc="Validation", leave=False):
                    val_loss += _masked_lm_loss(model, batch, criterion, device).item()

            avg_val_loss = val_loss / len(valid_loader)

            # Evaluate Recall@K every epoch.
            # NOTE(review): these metrics come from test_loader and also drive
            # the target-reached stop below - confirm that using the test split
            # for a stopping decision is intended.
            print(f"\nEvaluating Recall@K...")
            recall_10, ndcg_10 = evaluate_next_item_topk(model, test_loader, device, registry, topk=10)
            recall_20, ndcg_20 = evaluate_next_item_topk(model, test_loader, device, registry, topk=20)

            # Store history (learning rate is the value used during this epoch)
            current_lr = optimizer.param_groups[0]['lr']
            history['train_loss'].append(avg_train_loss)
            history['val_loss'].append(avg_val_loss)
            history['recall_10'].append(recall_10)
            history['recall_20'].append(recall_20)
            history['learning_rate'].append(current_lr)

            # Log to TensorBoard
            writer.add_scalar('Loss/Train_Epoch', avg_train_loss, epoch)
            writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
            writer.add_scalar('Metrics/Recall@10', recall_10, epoch)
            writer.add_scalar('Metrics/Recall@20', recall_20, epoch)
            writer.add_scalar('Metrics/NDCG@10', ndcg_10, epoch)
            writer.add_scalar('Metrics/NDCG@20', ndcg_20, epoch)

            # Print epoch summary
            print(f"\nEpoch {epoch}/{config.n_epochs} Summary:")
            print(f"  Train Loss: {avg_train_loss:.4f}")
            print(f"  Val Loss:   {avg_val_loss:.4f}")
            print(f"  Recall@10:  {recall_10:.4f} (Target: {config.target_recall_10:.2f})")
            print(f"  Recall@20:  {recall_20:.4f} (Target: {config.target_recall_20:.2f})")
            print(f"  NDCG@10:    {ndcg_10:.4f}")
            print(f"  LR:         {current_lr:.2e}")

            # Learning rate decay on plateau - compare before/after to report it
            old_lr = current_lr
            lr_scheduler.step(avg_val_loss)
            new_lr = optimizer.param_groups[0]['lr']
            if new_lr < old_lr:
                print(f"  📉 Learning rate reduced: {old_lr:.2e} → {new_lr:.2e}")

            # Early stopping check: only a drop larger than min_delta counts
            if avg_val_loss < best_val_loss - config.min_delta:
                best_val_loss = avg_val_loss
                best_recall_10 = recall_10
                patience_counter = 0

                # Save best model. The checkpoint also holds the registry and
                # config objects, so it is a pickle of project classes, not
                # tensors only.
                if config.save_best_model:
                    best_model_path = MODEL_SAVE_DIR / f'bert4rec_best_{timestamp}.pt'
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'registry': registry,
                        'config': config,
                        'val_loss': best_val_loss,
                        'recall_10': best_recall_10,
                        'history': history
                    }, best_model_path)
                    print(f"  💾 Saved best model (val_loss: {best_val_loss:.4f})")
            else:
                patience_counter += 1
                print(f"  ⏰ Early stopping patience: {patience_counter}/{config.early_stopping_patience}")

            # Check early stopping
            if patience_counter >= config.early_stopping_patience:
                print(f"\n🛑 Early stopping at epoch {epoch}")
                print(f"   Best validation loss: {best_val_loss:.4f}")
                print(f"   Best Recall@10: {best_recall_10:.4f}")
                break

            # Check if target achieved
            if recall_10 >= config.target_recall_10:
                print(f"\n🎯 Target Recall@10 achieved: {recall_10:.4f} >= {config.target_recall_10:.2f}")
                break

            print("-" * 80)
    finally:
        # Flush and release the event file even when training raises
        writer.close()

    return history, best_model_path

## Setup Data Loaders


In [None]:
# Masking settings handed to the BERT4Rec data loaders.
# NOTE(review): 0.8 / 0.1 / 0.1 looks like BERT's mask / random / keep
# replacement split - confirm mask_prob is not the fraction of positions
# selected for masking.
masking_options = MaskingOptions(
    mask_prob=0.8, random_token_prob=0.1, keep_original_prob=0.1
)

print("Building data loaders...")
# NOTE(review): how the split is made (e.g. by time) is decided inside
# build_dataloaders_for_bert4rec; only valid_split=0.1 is visible here.
loader_settings = dict(
    batch_size=64,
    masking=masking_options,
    valid_split=0.1,
    num_workers=4,
)
train_loader, valid_loader, eval_loader = build_dataloaders_for_bert4rec(
    prepared, **loader_settings
)

# The evaluation loader doubles as the test loader in later cells
test_loader = eval_loader

print(f"Train batches: {len(train_loader)}")
print(f"Valid batches: {len(valid_loader)}")
print(f"Eval batches: {len(eval_loader)}")
print(f"Registry vocab size: {prepared.registry.vocab_size}")

## Initialize and Train Model


In [None]:
# Sanity-check token IDs against the vocabulary size, then build the model
print("🔍 Debugging vocabulary and token IDs...")

# Vocabulary composition as reported by the registry
print(f"Registry vocab size: {prepared.registry.vocab_size:,}")
print(f"Article items: {len(prepared.registry.item2id):,}")
print(f"Context tokens: {len(prepared.registry.prefix_token2id):,}")

# Flatten the first 100 sequences into one list of token IDs
all_token_ids = [token for sequence in prepared.sequences[:100] for token in sequence]

if all_token_ids:
    lowest_id, highest_id = min(all_token_ids), max(all_token_ids)
    print(f"Token ID range in data: {lowest_id} to {highest_id}")
    print(f"Model vocab size: {prepared.registry.vocab_size}")

    # Only the upper bound is checked: an ID must be < vocab size to index
    # the embedding table
    if highest_id < prepared.registry.vocab_size:
        print(f"✅ Token IDs are within vocabulary range")
    else:
        print(f"❌ ERROR: Max token ID ({highest_id}) >= vocab size ({prepared.registry.vocab_size})")
        print("This will cause IndexError!")

# Repeat the check on one real batch from the training loader
print("\n🔍 Checking sample batch from train_loader...")
sample_batch = next(iter(train_loader))
input_ids = sample_batch["input_ids"]
batch_low = input_ids.min().item()
batch_high = input_ids.max().item()
print(f"Batch input_ids shape: {input_ids.shape}")
print(f"Batch token ID range: {batch_low} to {batch_high}")

if batch_high < prepared.registry.vocab_size:
    print(f"✅ Batch token IDs are within vocabulary range")
else:
    print(f"❌ ERROR: Batch max token ID ({batch_high}) >= vocab size ({prepared.registry.vocab_size})")

# Build the model; the vocabulary covers items plus context tokens.
# NOTE(review): max_len=50 repeats the literal given to SequenceOptions -
# confirm the prefix tokens cannot push a batch past this length.
model_settings = dict(
    vocab_size=prepared.registry.vocab_size,
    d_model=512,
    n_layers=6,
    n_heads=8,
    dropout=0.1,
    max_len=50,
)
model = BERT4RecModel(**model_settings)

model.to(device)
# Count trainable parameters only
total_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"\n📊 Model initialized successfully:")
print(f"Model parameters: {total_params:,}")
# Size estimate uses 4 bytes per parameter
print(f"Model size: ~{total_params * 4 / 1e6:.1f} MB")
print(f"Model vocab size: {model.token_emb.num_embeddings}")
print(f"Vocab breakdown:")
print(f"  Total vocab size: {prepared.registry.vocab_size:,}")
print(f"  Article items: {len(prepared.registry.item2id):,}")
print(f"  Context tokens: {len(prepared.registry.prefix_token2id):,}")
# Whatever is left after items and context tokens is reported as special tokens
special_tokens = (
    prepared.registry.vocab_size
    - len(prepared.registry.item2id)
    - len(prepared.registry.prefix_token2id)
)
print(f"  Special tokens: {special_tokens:,}")

In [None]:
# Announce the run and explain how to watch it in TensorBoard
startup_banner = [
    "🚀 Starting advanced BERT4Rec training...",
    f"⚡ Device: {device}",
    f"📊 Target Recall@10: {config.target_recall_10:.2f}",
    "📈 TensorBoard will be available during training",
    "\nTo monitor in real-time, run in another terminal:",
    "   tensorboard --logdir=runs",
    "   Then open: http://localhost:6006",
    "\n" + "=" * 80,
]
for banner_line in startup_banner:
    print(banner_line)

# Run training; returns the per-epoch history and the best checkpoint path
training_inputs = dict(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    config=config,
    device=device,
    registry=prepared.registry,
)
history, best_model_path = train_bert4rec_advanced(**training_inputs)

print("\n🎉 Training completed!")
if best_model_path:
    print(f"💾 Best model saved to: {best_model_path}")

## Training Analysis and Visualization


In [None]:
# Plot training history: loss, recall, learning rate and a text summary
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('BERT4Rec Advanced Training Results', fontsize=16)

# One x-axis tick per completed epoch
epochs = range(1, len(history['train_loss']) + 1)

# Loss curves
axes[0, 0].plot(epochs, history['train_loss'], label='Train Loss', marker='o')
axes[0, 0].plot(epochs, history['val_loss'], label='Validation Loss', marker='s')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Recall curves, with the configured targets as horizontal reference lines
axes[0, 1].plot(epochs, history['recall_10'], label='Recall@10', marker='o', color='green')
axes[0, 1].plot(epochs, history['recall_20'], label='Recall@20', marker='s', color='blue')
axes[0, 1].axhline(y=config.target_recall_10, color='red', linestyle='--', label=f'Target R@10 ({config.target_recall_10})')
axes[0, 1].axhline(y=config.target_recall_20, color='orange', linestyle='--', label=f'Target R@20 ({config.target_recall_20})')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Recall')
axes[0, 1].set_title('Recall@K Progress')
axes[0, 1].legend()
axes[0, 1].grid(True)

# Learning rate (log scale so halvings show as even steps)
axes[1, 0].plot(epochs, history['learning_rate'], marker='o', color='purple')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Learning Rate')
axes[1, 0].set_title('Learning Rate Schedule')
axes[1, 0].set_yscale('log')
axes[1, 0].grid(True)

# Performance summary.
# The vocabulary size is read from prepared.registry: no bare `registry`
# name exists at notebook scope (it is only a parameter of the training
# function), so the previous reference raised NameError.
axes[1, 1].axis('off')
summary_text = f"""
Training Summary:
├─ Epochs completed: {len(epochs)}
├─ Final train loss: {history['train_loss'][-1]:.4f}
├─ Final val loss: {history['val_loss'][-1]:.4f}
├─ Best Recall@10: {max(history['recall_10']):.4f}
├─ Best Recall@20: {max(history['recall_20']):.4f}
├─ Target R@10 achieved: {'✅' if max(history['recall_10']) >= config.target_recall_10 else '❌'}
└─ Final learning rate: {history['learning_rate'][-1]:.2e}

Model Configuration:
├─ Vocab size: {prepared.registry.vocab_size:,}
├─ Hidden size: 512
├─ Layers: 6
├─ Parameters: {total_params:,}
└─ Device: {device}
"""

axes[1, 1].text(0.1, 0.9, summary_text, fontsize=11, fontfamily='monospace',
               verticalalignment='top', transform=axes[1, 1].transAxes)

plt.tight_layout()

# Save before showing: once plt.show() has displayed the figure, a later
# plt.savefig() can write an empty canvas. Saving through `fig` also ties the
# file to this figure rather than to whichever figure is current.
plot_path = MODEL_SAVE_DIR / 'training_results.png'
fig.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"📊 Training plots saved to: {plot_path}")

## Final Model Evaluation


In [None]:
# Load best model for final evaluation
if best_model_path and best_model_path.exists():
    print("Loading best model for final evaluation...")
    # The checkpoint stores the pickled registry and config objects next to
    # the tensors, so the weights-only loader (the default since PyTorch 2.6)
    # rejects it; full unpickling has to be requested explicitly.
    # SECURITY: weights_only=False can execute arbitrary code while
    # unpickling - only load checkpoints written by this notebook.
    checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)

    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])

    print(f"Best model from epoch {checkpoint['epoch']}:")
    print(f"├─ Validation loss: {checkpoint['val_loss']:.4f}")
    print(f"└─ Recall@10: {checkpoint['recall_10']:.4f}")
else:
    print("⚠️  No best-model checkpoint found - evaluating the model's current weights")

# Evaluate in inference mode whether or not a checkpoint was loaded
model.eval()

# Comprehensive evaluation on test set
print("\n🔍 Final comprehensive evaluation...")

# Keys are 'recall_<k>' and 'ndcg_<k>' for each cutoff k
test_metrics = {}
for k in [5, 10, 20, 50]:
    recall_k, ndcg_k = evaluate_next_item_topk(model, test_loader, device, prepared.registry, topk=k)
    test_metrics[f'recall_{k}'] = recall_k
    test_metrics[f'ndcg_{k}'] = ndcg_k
    print(f"Recall@{k:2d}: {recall_k:.4f} | NDCG@{k:2d}: {ndcg_k:.4f}")

# Check if we achieved target performance
print("\n🎯 Target Achievement:")
print(f"Target Recall@10 ({config.target_recall_10:.2f}): {'✅ ACHIEVED' if test_metrics['recall_10'] >= config.target_recall_10 else '❌ Not achieved'}")
print(f"Target Recall@20 ({config.target_recall_20:.2f}): {'✅ ACHIEVED' if test_metrics['recall_20'] >= config.target_recall_20 else '❌ Not achieved'}")

# Next-step guidance depends only on the Recall@10 target
if test_metrics['recall_10'] >= config.target_recall_10:
    print("\n🎉 Model is ready for re-ranking stage!")
    print("   Next steps:")
    print("   1. Deploy for inference (Modal/BaseTen)")
    print("   2. Implement re-ranking with business rules")
    print("   3. A/B test against current system")
else:
    print("\n📈 Consider further improvements:")
    print("   1. Increase model size (hidden_size, num_layers)")
    print("   2. Train for more epochs")
    print("   3. Adjust sequence length or vocabulary")
    print("   4. Try different masking strategies")
## Save Final Results


In [None]:
# Write a JSON summary of the run: config, model shape, history and metrics
import json
from datetime import datetime

# Hyperparameters copied verbatim from the config object
recorded_config_fields = (
    'n_epochs',
    'lr',
    'weight_decay',
    'early_stopping_patience',
    'lr_decay_patience',
    'target_recall_10',
)
config_snapshot = {name: getattr(config, name) for name in recorded_config_fields}

# d_model and n_layers repeat the literals used when the model was built
model_snapshot = dict(
    d_model=512,
    n_layers=6,
    vocab_size=prepared.registry.vocab_size,
    total_parameters=total_params,
)

results = dict(
    timestamp=datetime.now().isoformat(),
    config=config_snapshot,
    model=model_snapshot,
    training_history=history,
    final_metrics=test_metrics,
    target_achieved=test_metrics['recall_10'] >= config.target_recall_10,
    # Paths are stored as strings; None when no checkpoint path exists
    best_model_path=str(best_model_path) if best_model_path else None,
)

results_path = MODEL_SAVE_DIR / 'training_results.json'
with results_path.open('w') as results_file:
    json.dump(results, results_file, indent=2)

print(f"📄 Complete results saved to: {results_path}")
print(f"📊 TensorBoard logs available in: runs/")
print(f"💾 Best model saved to: {best_model_path}")

separator = "=" * 80
print("\n" + separator)
print("🏁 BERT4Rec Advanced Training Complete!")
print(separator)