# BERT4Rec Apple Silicon Optimized Training

## Enhanced training optimized for M1/M2/M3 with MPS acceleration

This notebook implements:

- **Apple Silicon MPS optimization** for faster training
- **Memory-efficient data loading** for unified memory architecture
- **8-12 epochs** with early stopping and learning rate decay
- **Context tokens** ([SEG], [CH], price-band) for better calibration
- **TensorBoard monitoring** with Apple Silicon compatibility
- **Targets: Recall@10 ≥ 0.20 and Recall@20 ≥ 0.25**


In [1]:
# Core imports
import sys
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
import platform

import numpy as np
import polars as pl
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import matplotlib.pyplot as plt

# Add project root to path
sys.path.append('../../')
# Import BERT4Rec implementation
from hnm_data_analysis.data_modelling.bert4rec_modelling import (
    SequenceOptions, prepare_sequences_with_polars,
    build_dataloaders_for_bert4rec, BERT4RecModel, TrainConfig,
    train_bert4rec, evaluate_next_item_topk, set_all_seeds,
    MaskingOptions
)

# Project locations, expressed relative to the notebook's working directory.
DATA_ROOT = Path('../..') / 'data' / 'modelling_data'
MODEL_SAVE_DIR = Path('../..') / 'models' / 'bert4rec'
# Checkpoints and result plots are written here, so create it up front.
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# One-line-per-fact environment banner for reproducibility.
for label, value in (
    ("Python version", sys.version),
    ("PyTorch version", torch.__version__),
    ("Platform", platform.platform()),
    ("Architecture", platform.machine()),
):
    print(f"{label}: {value}")

Python version: 3.11.4 (v3.11.4:d2340ef257, Jun  6 2023, 19:15:51) [Clang 13.0.0 (clang-1300.0.29.30)]
PyTorch version: 2.8.0
Platform: macOS-15.6-arm64-arm-64bit
Architecture: arm64


## Apple Silicon Device Detection & Optimization


In [2]:
def get_optimal_device():
    """
    Select the training device, preferring MPS, then CUDA, then CPU.

    A short hardware summary for the selected backend is printed as a side
    effect. The summary is best-effort: a failing `sysctl` call or a missing
    `psutil` installation never prevents a device from being returned.

    Returns:
        torch.device: "mps", "cuda" or "cpu", in that order of preference.
    """
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("🍎 Using Apple Silicon MPS acceleration")

        # Check MPS capabilities
        print(f"  ✓ MPS built: {torch.backends.mps.is_built()}")

        # Chip name and memory size are read from `sysctl`; purely informational.
        import subprocess
        try:
            # Get Apple Silicon chip info
            result = subprocess.run(['sysctl', '-n', 'machdep.cpu.brand_string'],
                                    capture_output=True, text=True)
            if result.returncode == 0:
                print(f"  🔥 Chip: {result.stdout.strip()}")

            # Get memory info (hw.memsize is reported in bytes)
            result = subprocess.run(['sysctl', '-n', 'hw.memsize'],
                                    capture_output=True, text=True)
            if result.returncode == 0:
                memory_bytes = int(result.stdout.strip())
                memory_gb = memory_bytes / (1024**3)
                print(f"  💾 Unified Memory: {memory_gb:.0f} GB")
        except Exception as e:
            # Never let a diagnostics failure block device selection.
            print(f"  ⚠️  Could not get system info: {e}")

    elif torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"🚀 Using CUDA GPU: {torch.cuda.get_device_name(0)}")
        print(f"  💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        device = torch.device("cpu")
        print("💻 Using CPU")
        try:
            import psutil
        except ImportError:
            # psutil is only needed for the RAM figure; the core count is
            # available from the standard library.
            print(f"  🔄 CPU cores: {os.cpu_count()}")
        else:
            print(f"  💾 RAM: {psutil.virtual_memory().total / 1e9:.1f} GB")
            print(f"  🔄 CPU cores: {psutil.cpu_count()}")

    return device

def get_apple_silicon_config(device):
    """
    Build the data-loading settings for the given device.

    Args:
        device: Object exposing a `.type` string ("mps", "cuda", or anything
            else, which is treated as CPU), e.g. a `torch.device`.

    Returns:
        dict: Keys `batch_size`, `num_workers`, `pin_memory`,
        `persistent_workers` and `prefetch_factor`.
    """
    # Defaults; each backend below overrides only what differs.
    # NOTE(review): torch's DataLoader is documented to reject a non-None
    # prefetch_factor when num_workers == 0 — drop that key before forwarding
    # this dict wholesale to a DataLoader; confirm against the installed version.
    config = {
        'batch_size': 64,
        'num_workers': 0,  # MPS works best with single process
        'pin_memory': False,  # Not needed for MPS
        'persistent_workers': False,
        'prefetch_factor': 2
    }

    if device.type == "mps":
        # Apple Silicon optimizations
        config.update({
            'batch_size': 96,  # Can handle larger batches with unified memory
            'num_workers': 0,   # MPS doesn't benefit from multiprocessing
            'pin_memory': False # Unified memory architecture
        })
        print("🍎 Apple Silicon optimizations:")
        print(f"  📦 Batch size: {config['batch_size']} (optimized for unified memory)")
        print(f"  🔄 Workers: {config['num_workers']} (single process for MPS)")

    elif device.type == "cuda":
        # CUDA optimizations
        config.update({
            'batch_size': 64,
            'num_workers': 4,
            'pin_memory': True,
            'persistent_workers': True
        })
        print("🚀 CUDA optimizations applied")

    else:
        # CPU optimizations
        # os.cpu_count() returns None when the count cannot be determined;
        # fall back to a single worker instead of failing in min().
        available_cores = os.cpu_count() or 1
        config.update({
            'batch_size': 32,  # Smaller batch for CPU
            'num_workers': min(4, available_cores),
            'pin_memory': False
        })
        print("💻 CPU optimizations applied")

    return config

# Resolve the backend once; later cells read these two globals.
device = get_optimal_device()
device_config = get_apple_silicon_config(device)

print("\n🎯 Training device: {}".format(device))
print("🔧 Optimized configuration: {}".format(device_config))

🍎 Using Apple Silicon MPS acceleration
  ✓ MPS built: True
  🔥 Chip: Apple M2 Pro
  💾 Unified Memory: 16 GB
🍎 Apple Silicon optimizations:
  📦 Batch size: 96 (optimized for unified memory)
  🔄 Workers: 0 (single process for MPS)

🎯 Training device: mps
🔧 Optimized configuration: {'batch_size': 96, 'num_workers': 0, 'pin_memory': False, 'persistent_workers': False, 'prefetch_factor': 2}


## Apple Silicon Optimized Training Configuration


In [3]:
# Apple Silicon optimized training configuration
class AppleSiliconTrainConfig:
    """
    Hyper-parameters and backend-specific switches for one training run.

    The schedule, early-stopping, LR-decay and target values are the same on
    every backend; `device.type` only selects the AMP / compile / gradient
    accumulation / memory switches and the banner printed at construction.
    """

    def __init__(self, device):
        # --- optimisation schedule ---
        self.n_epochs = 12
        self.lr = 1e-4
        self.weight_decay = 0.01
        self.warmup_steps = 1000
        self.grad_clip_norm = 1.0

        # --- early stopping ---
        self.early_stopping_patience = 3
        self.min_delta = 0.001

        # --- learning-rate decay ---
        self.lr_decay_patience = 2
        self.lr_decay_factor = 0.5
        self.min_lr = 1e-6

        # --- evaluation / checkpointing ---
        self.eval_every_n_epochs = 1
        self.save_best_model = True

        # --- target metrics ---
        self.target_recall_10 = 0.20
        self.target_recall_20 = 0.25

        backend = device.type
        if backend == "mps":
            # AMP and torch.compile are switched off for MPS; the cache
            # interval is only set on this backend.
            # NOTE(review): recent PyTorch releases may support autocast on
            # MPS — revisit the AMP switch against the installed version.
            switches = {
                'use_amp': False,
                'compile_model': False,
                'grad_accumulation_steps': 1,
                'memory_efficient': True,
                'clear_cache_every_n_steps': 100,
            }
            banner = (
                "🍎 Apple Silicon MPS optimizations enabled:",
                "  • AMP disabled (MPS limitation)",
                "  • Model compilation disabled",
                "  • Memory-efficient mode enabled",
            )
        elif backend == "cuda":
            switches = {
                'use_amp': True,
                'compile_model': True,
                'grad_accumulation_steps': 1,
                'memory_efficient': False,
                'clear_cache_every_n_steps': None,
            }
            banner = (
                "🚀 CUDA optimizations enabled:",
                "  • AMP enabled for faster training",
                "  • Model compilation enabled",
            )
        else:
            # Any other device type is configured as CPU.
            switches = {
                'use_amp': False,
                'compile_model': False,
                'grad_accumulation_steps': 2,
                'memory_efficient': True,
                'clear_cache_every_n_steps': None,
            }
            banner = (
                "💻 CPU optimizations enabled:",
                "  • Gradient accumulation for effective larger batches",
                "  • Memory-efficient mode enabled",
            )

        for attr_name, attr_value in switches.items():
            setattr(self, attr_name, attr_value)
        for banner_line in banner:
            print(banner_line)

# Single training configuration shared by the training and plotting cells.
config = AppleSiliconTrainConfig(device)
print("\n📋 Training for up to {} epochs with early stopping".format(config.n_epochs))
print("🎯 Target Recall@10: {}".format(config.target_recall_10))

🍎 Apple Silicon MPS optimizations enabled:
  • AMP disabled (MPS limitation)
  • Model compilation disabled
  • Memory-efficient mode enabled

📋 Training for up to 12 epochs with early stopping
🎯 Target Recall@10: 0.2


## Data Loading with Apple Silicon Optimizations


In [4]:
# Read the transactions parquet into `df` (a polars DataFrame).
print("📂 Loading transaction data (optimized for Apple Silicon)...")

transactions_path = DATA_ROOT / 'transactions_final.parquet'
running_on_mps = device.type == "mps"

if running_on_mps:
    # Lazy scan first so a small preview can be inspected before the full read.
    print("🍎 Using Apple Silicon optimized data loading...")
    df = pl.scan_parquet(transactions_path)

    # Preview is best-effort: a failure is reported and the full collect
    # below is still attempted.
    try:
        sample = df.head(1000).collect()
        print(f"✓ Data sample loaded successfully: {sample.shape}")
        print(f"📊 Available columns: {sample.columns}")
    except Exception as e:
        print(f"⚠️  Issue with data loading: {e}")

    # Materialise the whole frame.
    print("📥 Collecting full dataset...")
    df = df.collect()
else:
    # Eager read on every other backend.
    df = pl.read_parquet(transactions_path)

# Dataset overview.
date_column = df['t_dat']
print(f"📊 Data shape: {df.shape}")
print(f"📅 Date range: {date_column.min()} to {date_column.max()}")
print(f"👥 Unique customers: {df['customer_id'].n_unique():,}")
print(f"🛍️  Unique articles: {df['article_id'].n_unique():,}")

# Resident-set size of this process, reported on MPS only.
if running_on_mps:
    import psutil
    process = psutil.Process()
    memory_mb = process.memory_info().rss / (1024 * 1024)
    print(f"💾 Current memory usage: {memory_mb:.1f} MB")

📂 Loading transaction data (optimized for Apple Silicon)...
🍎 Using Apple Silicon optimized data loading...
✓ Data sample loaded successfully: (1000, 5)
📊 Available columns: ['t_dat', 'customer_id', 'article_id', 'price', 'sales_channel_id']
📥 Collecting full dataset...
📊 Data shape: (3904391, 5)
📅 Date range: 2020-06-24 to 2020-09-22
👥 Unique customers: 525,075
🛍️  Unique articles: 42,298
💾 Current memory usage: 577.0 MB


## Sequence Preparation with Context Tokens


In [5]:
# Sequence preparation with context tokens and Apple Silicon optimization
import time

# Seed before any stochastic step so repeated runs are comparable.
set_all_seeds(42)

sequence_options = SequenceOptions(
    max_len=50,
    min_len=3,
    deduplicate_exact=True,
    treat_same_day_as_basket=True,
    # Enable context tokens for better calibration
    add_segment_prefix=True,     # Add [SEG] tokens
    add_channel_prefix=True,     # Add [CH] tokens
    add_priceband_prefix=True,   # Add price-band tokens
    n_price_bins=10
)

print("🔧 Preparing sequences with context tokens...")
print("📋 Context tokens enabled:")
print(f"  ✓ Segment prefix: {sequence_options.add_segment_prefix}")
print(f"  ✓ Channel prefix: {sequence_options.add_channel_prefix}")
print(f"  ✓ Price-band prefix: {sequence_options.add_priceband_prefix}")

# Check available columns
print(f"\n📋 Available columns: {df.columns}")

# Apple Silicon memory optimization during preparation
if device.type == "mps":
    print("🍎 Using Apple Silicon optimized sequence preparation...")

# Prepare sequences (timed)
start_time = time.time()
prepared = prepare_sequences_with_polars(df, sequence_options)
prep_time = time.time() - start_time

n_context_tokens = len(prepared.registry.prefix_token2id)

print(f"\n✅ Sequence preparation complete ({prep_time:.1f}s):")
print(f"  📊 Total sequences: {len(prepared.sequences):,}")
print(f"  📚 Vocabulary size: {prepared.registry.vocab_size:,}")
print(f"  🏷️  Context tokens: {n_context_tokens:,}")
print(f"  📏 Average sequence length: {np.mean([len(seq) for seq in prepared.sequences]):.1f}")
print(f"  🎯 Average prefix length: {np.mean(prepared.prefix_lengths):.1f}")

# Surface the case where prefixes were requested but none reached the
# registry; otherwise the run silently proceeds without them.
context_requested = (
    sequence_options.add_segment_prefix
    or sequence_options.add_channel_prefix
    or sequence_options.add_priceband_prefix
)
if context_requested and n_context_tokens == 0:
    print("\n⚠️  Context prefixes were requested but the registry contains no "
          "context tokens; training will proceed without them.")

# Show example context tokens (first eight registry entries)
if prepared.registry.prefix_token2id:
    print(f"\n🏷️  Example context tokens:")
    for token, token_id in list(prepared.registry.prefix_token2id.items())[:8]:
        print(f"    {token_id:4d}: {token}")

# Memory check for Apple Silicon
if device.type == "mps":
    # Imported here so the cell does not depend on an import made by an
    # earlier cell.
    import psutil
    process = psutil.Process()
    memory_mb = process.memory_info().rss / 1024 / 1024
    print(f"\n💾 Memory after prep: {memory_mb:.1f} MB")

🔧 Preparing sequences with context tokens...
📋 Context tokens enabled:
  ✓ Segment prefix: True
  ✓ Channel prefix: True
  ✓ Price-band prefix: True

📋 Available columns: ['t_dat', 'customer_id', 'article_id', 'price', 'sales_channel_id']
🍎 Using Apple Silicon optimized sequence preparation...

✅ Sequence preparation complete (26.5s):
  📊 Total sequences: 448,880
  📚 Vocabulary size: 42,300
  🏷️  Context tokens: 0
  📏 Average sequence length: 7.8
  🎯 Average prefix length: 0.0

💾 Memory after prep: 948.6 MB


## Apple Silicon Optimized Data Loaders


In [6]:
# Masking settings handed to the BERT4Rec loader builder.
# NOTE(review): 0.8 / 0.1 / 0.1 resembles BERT's mask / random / keep split.
# If `mask_prob` is instead the per-token selection rate, 0.8 is far above
# the usual ~0.15 — confirm against the MaskingOptions definition.
masking_options = MaskingOptions(mask_prob=0.8, random_token_prob=0.1, keep_original_prob=0.1)

print("📦 Building Apple Silicon optimized data loaders...")

# Only batch size and worker count are taken from the device configuration.
loader_batch_size = device_config['batch_size']
loader_workers = device_config['num_workers']

train_loader, valid_loader, eval_loader = build_dataloaders_for_bert4rec(
    prepared,
    batch_size=loader_batch_size,
    masking=masking_options,
    valid_split=0.1,
    num_workers=loader_workers,
)

# Later cells refer to the evaluation loader as `test_loader`.
test_loader = eval_loader

loader_summary = [
    "✅ Data loaders created:",
    f"  🚂 Train batches: {len(train_loader)} (batch size: {loader_batch_size})",
    f"  ✅ Valid batches: {len(valid_loader)}",
    f"  🧪 Eval batches: {len(eval_loader)}",
    f"  📚 Registry vocab: {prepared.registry.vocab_size:,}",
    f"  🔄 Workers: {loader_workers} (optimized for {device.type.upper()})",
]
for summary_line in loader_summary:
    print(summary_line)

📦 Building Apple Silicon optimized data loaders...


Splitting sequences:   0%|          | 0/448880 [00:00<?, ?it/s]

✅ Data loaders created:
  🚂 Train batches: 4209 (batch size: 96)
  ✅ Valid batches: 468
  🧪 Eval batches: 4676
  📚 Registry vocab: 42,300
  🔄 Workers: 0 (optimized for MPS)


## Apple Silicon Optimized Model


In [7]:
# Initialize model optimized for Apple Silicon
model_config = {
    'vocab_size': prepared.registry.vocab_size,
    'd_model': 512,    # Good size for Apple Silicon unified memory
    'n_layers': 6,
    'n_heads': 8,
    'dropout': 0.1,
    # Taken from the sequence options so the model and the prepared
    # sequences cannot drift apart.
    'max_len': sequence_options.max_len
}

model = BERT4RecModel(**model_config)

# Every backend needs the model on the target device; only the banner and
# the optional compile step differ.
model.to(device)

if device.type == "mps":
    print("🍎 Applying Apple Silicon optimizations...")
    print("  ✓ Model moved to MPS device")

elif device.type == "cuda":
    print("🚀 Applying CUDA optimizations...")

    # Compile model if supported; a failed compile keeps the eager model.
    if config.compile_model and hasattr(torch, 'compile'):
        try:
            model = torch.compile(model)
            print("  ✓ Model compiled for faster inference")
        except Exception as e:
            print(f"  ⚠️  Model compilation failed: {e}")

else:
    print("💻 Applying CPU optimizations...")

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n📊 Model statistics:")
print(f"  🔢 Parameters: {total_params:,}")
# Size estimate assumes 4 bytes per parameter.
print(f"  💾 Model size: ~{total_params * 4 / 1e6:.1f} MB")
print(f"  🏗️  Architecture: {model_config['n_layers']}L×{model_config['d_model']}H×{model_config['n_heads']}A")
print(f"  📚 Vocab breakdown:")
print(f"    • Total: {prepared.registry.vocab_size:,}")
print(f"    • Articles: {len(prepared.registry.item2id):,}")
print(f"    • Context tokens: {len(prepared.registry.prefix_token2id):,}")
print(f"  🎯 Device: {device}")

# Smoke test and memory report (MPS only)
if device.type == "mps":
    # Two-sample forward pass to confirm the model runs on the device.
    # NOTE(review): the recorded run printed a sequence dimension of 100 here
    # while max_len is 50 — confirm the loader's padding length is intended.
    test_batch = next(iter(train_loader))
    test_input = test_batch['input_ids'][:2].to(device)  # Small test
    test_mask = test_batch['attention_mask'][:2].to(device)

    with torch.no_grad():
        test_output = model(test_input, test_mask)
        print(f"  ✅ Forward pass test: {test_output.shape}")

    import psutil
    process = psutil.Process()
    memory_mb = process.memory_info().rss / 1024 / 1024
    print(f"  💾 Memory after model load: {memory_mb:.1f} MB")

🍎 Applying Apple Silicon optimizations...
  ✓ Model moved to MPS device

📊 Model statistics:
  🔢 Parameters: 31,194,428
  💾 Model size: ~124.8 MB
  🏗️  Architecture: 6L×512H×8A
  📚 Vocab breakdown:
    • Total: 42,300
    • Articles: 42,298
    • Context tokens: 0
  🎯 Device: mps
  ✅ Forward pass test: torch.Size([2, 100, 42300])
  💾 Memory after model load: 1055.4 MB


## Apple Silicon Enhanced Training Loop


In [8]:
def train_bert4rec_apple_silicon(
    model, train_loader, valid_loader, test_loader, config, device, registry
):
    """
    Train the model with per-epoch validation, Recall@K evaluation,
    plateau-based LR decay, early stopping and TensorBoard logging.

    Args:
        model: Module called as `model(input_ids, attention_mask)`; its output
            is unpacked into three dimensions (presumably batch, sequence,
            vocabulary — confirm against the model definition).
        train_loader: Iterable of dict batches with keys "input_ids",
            "attention_mask" and "labels".
        valid_loader: Same batch format; drives LR decay, early stopping and
            best-checkpoint selection through the mean validation loss.
        test_loader: Passed to `evaluate_next_item_topk` every epoch. Note that
            the "target reached" stop is therefore decided on this loader.
        config: Object providing n_epochs, lr, weight_decay, grad_clip_norm,
            use_amp, lr_decay_factor, lr_decay_patience, min_lr, min_delta,
            early_stopping_patience, save_best_model and target_recall_10.
            Fields such as warmup_steps, grad_accumulation_steps,
            clear_cache_every_n_steps and eval_every_n_epochs are not read here.
        device: torch.device the batches are moved to.
        registry: Forwarded to `evaluate_next_item_topk` and stored in the
            checkpoint.

    Returns:
        tuple: `(history, best_model_path)`. `history` maps metric names to
        per-epoch lists ('memory_usage' is None off MPS). `best_model_path` is
        None when `config.save_best_model` is false or no epoch improved the
        validation loss.

    Side effects:
        Writes TensorBoard events under `runs/` and checkpoints under the
        module-level `MODEL_SAVE_DIR`.
    """
    from tqdm.auto import tqdm

    # Setup TensorBoard
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = f'runs/bert4rec_apple_silicon_{timestamp}'
    writer = SummaryWriter(run_dir)

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.lr,
        weight_decay=config.weight_decay
    )

    # Learning rate scheduler (stepped once per epoch on the validation loss)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.lr_decay_factor,
        patience=config.lr_decay_patience,
        min_lr=config.min_lr
    )

    # Mixed precision scaler (only for CUDA)
    scaler = None
    if config.use_amp and device.type == "cuda":
        scaler = torch.cuda.amp.GradScaler()
        print("⚡ Mixed precision training enabled (CUDA)")

    # Positions labelled -100 are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    def _token_loss(logits, labels):
        """Cross-entropy over all positions, flattened to (B*L, V) vs (B*L,)."""
        B, L, V = logits.size()
        return criterion(logits.view(B * L, V), labels.view(B * L))

    # Process memory is sampled once per epoch on MPS only.
    track_memory = device.type == "mps"
    process = None
    if track_memory:
        import psutil
        process = psutil.Process()

    # Training variables
    best_val_loss = float('inf')
    best_recall_10 = 0.0
    patience_counter = 0
    # Initialised up front so the return below is safe even when no epoch
    # ever improves (e.g. a NaN validation loss).
    best_model_path = None

    history = {
        'train_loss': [],
        'val_loss': [],
        'recall_10': [],
        'recall_20': [],
        'learning_rate': [],
        'memory_usage': [] if track_memory else None
    }

    print(f"🚀 Starting Apple Silicon optimized training...")
    print(f"📊 Device: {device} | Epochs: {config.n_epochs} | Target R@10: {config.target_recall_10}")
    print(f"📈 TensorBoard: {run_dir}")
    print("=" * 80)

    # try/finally guarantees the event file is flushed and closed even when
    # the run is interrupted (e.g. KeyboardInterrupt from the notebook).
    try:
        for epoch in range(1, config.n_epochs + 1):
            # Training phase
            model.train()
            train_loss = 0.0
            num_batches = 0

            pbar = tqdm(train_loader, desc=f"🍎 Epoch {epoch}/{config.n_epochs}")

            for batch_idx, batch in enumerate(pbar):
                input_ids = batch["input_ids"].to(device, non_blocking=True)
                attention_mask = batch["attention_mask"].to(device, non_blocking=True)
                labels = batch["labels"].to(device, non_blocking=True)

                optimizer.zero_grad(set_to_none=True)  # More efficient

                # Forward pass with optional mixed precision
                if config.use_amp and scaler is not None:
                    with torch.cuda.amp.autocast():
                        logits = model(input_ids, attention_mask)
                        loss = _token_loss(logits, labels)

                    # Backward pass with gradient scaling
                    scaler.scale(loss).backward()

                    if config.grad_clip_norm > 0:
                        # Unscale first so clipping sees true gradient norms.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)

                    scaler.step(optimizer)
                    scaler.update()
                else:
                    # Standard training (MPS/CPU)
                    logits = model(input_ids, attention_mask)
                    loss = _token_loss(logits, labels)

                    loss.backward()

                    if config.grad_clip_norm > 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip_norm)

                    optimizer.step()

                batch_loss = loss.item()
                train_loss += batch_loss
                num_batches += 1
                current_lr = optimizer.param_groups[0]["lr"]

                # Update progress bar
                pbar.set_postfix({
                    'loss': f'{batch_loss:.4f}',
                    'avg': f'{train_loss/num_batches:.4f}',
                    'lr': f'{current_lr:.2e}'
                })

                # Log batch metrics
                global_step = (epoch - 1) * len(train_loader) + batch_idx
                writer.add_scalar('Loss/Train_Batch', batch_loss, global_step)
                writer.add_scalar('Learning_Rate', current_lr, global_step)

            # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
            avg_train_loss = train_loss / max(num_batches, 1)

            # Validation phase
            model.eval()
            val_loss = 0.0
            val_batches = 0
            with torch.no_grad():
                for batch in tqdm(valid_loader, desc="Validation", leave=False):
                    input_ids = batch["input_ids"].to(device, non_blocking=True)
                    attention_mask = batch["attention_mask"].to(device, non_blocking=True)
                    labels = batch["labels"].to(device, non_blocking=True)

                    logits = model(input_ids, attention_mask)
                    val_loss += _token_loss(logits, labels).item()
                    val_batches += 1

            avg_val_loss = val_loss / max(val_batches, 1)

            # Evaluate Recall@K (one evaluation pass per cut-off)
            print(f"\n🔍 Evaluating Recall@K...")
            recall_10, ndcg_10 = evaluate_next_item_topk(model, test_loader, device, registry, topk=10)
            recall_20, ndcg_20 = evaluate_next_item_topk(model, test_loader, device, registry, topk=20)

            # Memory tracking for Apple Silicon
            memory_mb = None
            if track_memory:
                memory_mb = process.memory_info().rss / 1024 / 1024
                history['memory_usage'].append(memory_mb)

            # Store history
            history['train_loss'].append(avg_train_loss)
            history['val_loss'].append(avg_val_loss)
            history['recall_10'].append(recall_10)
            history['recall_20'].append(recall_20)
            history['learning_rate'].append(optimizer.param_groups[0]['lr'])

            # Log to TensorBoard
            writer.add_scalar('Loss/Train_Epoch', avg_train_loss, epoch)
            writer.add_scalar('Loss/Validation', avg_val_loss, epoch)
            writer.add_scalar('Metrics/Recall@10', recall_10, epoch)
            writer.add_scalar('Metrics/Recall@20', recall_20, epoch)
            writer.add_scalar('Metrics/NDCG@10', ndcg_10, epoch)
            writer.add_scalar('Metrics/NDCG@20', ndcg_20, epoch)

            if memory_mb is not None:
                writer.add_scalar('System/Memory_MB', memory_mb, epoch)

            # Print epoch summary
            print(f"\n📊 Epoch {epoch}/{config.n_epochs} Summary:")
            print(f"  🚂 Train Loss: {avg_train_loss:.4f}")
            print(f"  ✅ Val Loss:   {avg_val_loss:.4f}")
            print(f"  🎯 Recall@10:  {recall_10:.4f} (Target: {config.target_recall_10:.2f})")
            print(f"  📈 Recall@20:  {recall_20:.4f}")
            print(f"  📊 NDCG@10:    {ndcg_10:.4f}")
            print(f"  ⚡ LR:         {optimizer.param_groups[0]['lr']:.2e}")
            if memory_mb is not None:
                print(f"  💾 Memory:     {memory_mb:.1f} MB")

            # Learning rate scheduling - manually print when LR changes
            old_lr = optimizer.param_groups[0]['lr']
            lr_scheduler.step(avg_val_loss)
            new_lr = optimizer.param_groups[0]['lr']
            if new_lr < old_lr:
                print(f"  📉 Learning rate reduced: {old_lr:.2e} → {new_lr:.2e}")

            # Early stopping logic: only an improvement larger than min_delta counts.
            if avg_val_loss < best_val_loss - config.min_delta:
                best_val_loss = avg_val_loss
                best_recall_10 = recall_10
                patience_counter = 0

                # Save best model
                if config.save_best_model:
                    best_model_path = MODEL_SAVE_DIR / f'bert4rec_apple_silicon_best_{timestamp}.pt'
                    # `registry` and `config` are pickled as objects.
                    # NOTE(review): reloading needs their classes importable and
                    # may require torch.load(..., weights_only=False) — confirm.
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'registry': registry,
                        'config': config,
                        'val_loss': best_val_loss,
                        'recall_10': best_recall_10,
                        'history': history,
                        'device_type': device.type
                    }, best_model_path)
                    print(f"  💾 Best model saved (val_loss: {best_val_loss:.4f})")
            else:
                patience_counter += 1
                print(f"  ⏰ Early stopping: {patience_counter}/{config.early_stopping_patience}")

            # Check stopping conditions
            if patience_counter >= config.early_stopping_patience:
                print(f"\n🛑 Early stopping at epoch {epoch}")
                break

            if recall_10 >= config.target_recall_10:
                print(f"\n🎯 Target achieved! Recall@10: {recall_10:.4f} >= {config.target_recall_10:.2f}")
                break

            print("-" * 80)
    finally:
        writer.close()

    return history, best_model_path if config.save_best_model else None

## Start Apple Silicon Training


In [9]:
# Announce the run, then hand everything to the training loop.
launch_banner = [
    "🍎 Starting Apple Silicon BERT4Rec training...",
    f"🔥 Device: {device}",
    f"🎯 Target: Recall@10 → {config.target_recall_10}",
    f"📊 Model: {total_params:,} parameters",
    "📈 Monitoring: TensorBoard + live metrics",
    "\nTo monitor training progress:",
    "  tensorboard --logdir=runs",
    "  Open: http://localhost:6006",
    "\n" + "=" * 80,
]
for banner_line in launch_banner:
    print(banner_line)

# Run the training; returns the per-epoch history and the checkpoint path.
training_inputs = dict(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    config=config,
    device=device,
    registry=prepared.registry,
)
history, best_model_path = train_bert4rec_apple_silicon(**training_inputs)

print("\n🎉 Apple Silicon training completed!")
if best_model_path:
    print(f"💾 Best model: {best_model_path}")

# No explicit MPS cache clearing is performed; only a completion notice.
if device.type == "mps":
    print("🧹 MPS training complete")

🍎 Starting Apple Silicon BERT4Rec training...
🔥 Device: mps
🎯 Target: Recall@10 → 0.2
📊 Model: 31,194,428 parameters
📈 Monitoring: TensorBoard + live metrics

To monitor training progress:
  tensorboard --logdir=runs
  Open: http://localhost:6006

🚀 Starting Apple Silicon optimized training...
📊 Device: mps | Epochs: 12 | Target R@10: 0.2
📈 TensorBoard: runs/bert4rec_apple_silicon_20250820_181354


🍎 Epoch 1/12:   0%|          | 0/4209 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Apple Silicon Training Results Visualization


In [None]:
# Apple Silicon optimized results visualization
#
# Plots the per-epoch history returned by the training cell, saves the figure
# under MODEL_SAVE_DIR and prints a short text summary.

# `registry` exists only as a parameter of the training function; at notebook
# scope the registry lives on `prepared`.
results_registry = prepared.registry
n_context_tokens = len(results_registry.prefix_token2id)

# Everything below indexes the last history entry, so fail with a clear
# message instead of an IndexError when no epoch finished.
if not history['train_loss']:
    raise RuntimeError("No completed epochs in `history`; run training before plotting results.")

fig_size = (16, 12) if device.type == "mps" else (15, 10)
fig, axes = plt.subplots(2, 3, figsize=fig_size)
fig.suptitle(f'BERT4Rec Apple Silicon Training Results ({device.type.upper()})', fontsize=16)

epochs = range(1, len(history['train_loss']) + 1)
best_recall_10 = max(history['recall_10'])
best_recall_20 = max(history['recall_20'])
target_reached = best_recall_10 >= config.target_recall_10

# Loss curves
axes[0, 0].plot(epochs, history['train_loss'], label='Train Loss', marker='o', alpha=0.8)
axes[0, 0].plot(epochs, history['val_loss'], label='Validation Loss', marker='s', alpha=0.8)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Recall curves
axes[0, 1].plot(epochs, history['recall_10'], label='Recall@10', marker='o', color='green', alpha=0.8)
axes[0, 1].plot(epochs, history['recall_20'], label='Recall@20', marker='s', color='blue', alpha=0.8)
axes[0, 1].axhline(y=config.target_recall_10, color='red', linestyle='--', alpha=0.7,
                   label=f'Target R@10 ({config.target_recall_10})')
axes[0, 1].axhline(y=config.target_recall_20, color='orange', linestyle='--', alpha=0.7,
                   label=f'Target R@20 ({config.target_recall_20})')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Recall')
axes[0, 1].set_title('Recall@K Progress')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Learning rate
axes[0, 2].plot(epochs, history['learning_rate'], marker='o', color='purple', alpha=0.8)
axes[0, 2].set_xlabel('Epoch')
axes[0, 2].set_ylabel('Learning Rate')
axes[0, 2].set_title('Learning Rate Schedule')
axes[0, 2].set_yscale('log')
axes[0, 2].grid(True, alpha=0.3)

# Memory usage (Apple Silicon specific)
if device.type == "mps" and history['memory_usage']:
    axes[1, 0].plot(epochs, history['memory_usage'], marker='o', color='orange', alpha=0.8)
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Memory Usage (MB)')
    axes[1, 0].set_title('Apple Silicon Memory Usage')
    axes[1, 0].grid(True, alpha=0.3)
else:
    axes[1, 0].axis('off')
    axes[1, 0].text(0.5, 0.5, 'Memory tracking\nnot available',
                   ha='center', va='center', transform=axes[1, 0].transAxes)

# Performance comparison
device_emoji = "🍎" if device.type == "mps" else "🚀" if device.type == "cuda" else "💻"
device_name = "Apple Silicon" if device.type == "mps" else device.type.upper()

# Text panel: architecture figures come from model_config and the context
# token flag from the registry, so the panel reflects the actual run.
axes[1, 1].axis('off')
summary_text = f"""
{device_emoji} {device_name} Training Summary:
├─ Epochs completed: {len(epochs)}
├─ Final train loss: {history['train_loss'][-1]:.4f}
├─ Final val loss: {history['val_loss'][-1]:.4f}
├─ Best Recall@10: {best_recall_10:.4f}
├─ Best Recall@20: {best_recall_20:.4f}
├─ Target achieved: {'✅' if target_reached else '❌'}
└─ Final LR: {history['learning_rate'][-1]:.2e}

Model Configuration:
├─ Device: {device}
├─ Vocab size: {results_registry.vocab_size:,}
├─ Context tokens: {n_context_tokens:,}
├─ Hidden size: {model_config['d_model']}
├─ Layers: {model_config['n_layers']}
├─ Parameters: {total_params:,}
└─ Batch size: {device_config['batch_size']}

Optimizations Applied:
├─ Context tokens: {'✅' if n_context_tokens else '❌'}
├─ Early stopping: ✅
├─ LR decay: ✅
├─ Memory efficient: {'✅' if config.memory_efficient else '❌'}
└─ Device optimized: ✅
"""

axes[1, 1].text(0.05, 0.95, summary_text, fontsize=10, fontfamily='monospace',
               verticalalignment='top', transform=axes[1, 1].transAxes)

# Performance metrics comparison.
# Recall@5 / Recall@50 are evaluated on the model as it stands now, while
# Recall@10 / Recall@20 are the best values recorded over the epochs.
recall_5 = evaluate_next_item_topk(model, test_loader, device, results_registry, topk=5)[0]
recall_50 = evaluate_next_item_topk(model, test_loader, device, results_registry, topk=50)[0]
metrics_data = {
    'Recall@5': recall_5,
    'Recall@10': best_recall_10,
    'Recall@20': best_recall_20,
    'Recall@50': recall_50
}
metric_names = list(metrics_data.keys())
metric_values = list(metrics_data.values())

axes[1, 2].bar(metric_names, metric_values,
               color=['#ff9999', '#66b3ff', '#99ff99', '#ffcc99'], alpha=0.8)
axes[1, 2].set_ylabel('Recall Score')
axes[1, 2].set_title('Final Recall@K Performance')
# Avoid a degenerate (0, 0) axis when every recall is zero.
highest_recall = max(metric_values)
axes[1, 2].set_ylim(0, highest_recall * 1.1 if highest_recall > 0 else 1.0)
axes[1, 2].grid(True, alpha=0.3)

# Add values on bars
for i, v in enumerate(metric_values):
    axes[1, 2].text(i, v + 0.005, f'{v:.3f}', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()

# Save before showing: on inline backends the figure is closed by show(),
# so saving afterwards would write an empty canvas.
results_plot_path = MODEL_SAVE_DIR / f'apple_silicon_results_{device.type}.png'
fig.savefig(results_plot_path, dpi=300, bbox_inches='tight', facecolor='white')
print(f"📊 Results saved to: {results_plot_path}")

plt.show()

# Performance summary
print(f"\n{device_emoji} Apple Silicon Performance Summary:")
print(f"  🎯 Target Recall@10: {config.target_recall_10} → Achieved: {best_recall_10:.4f}")
print(f"  📈 Best Recall@20: {best_recall_20:.4f}")
print(f"  ⚡ Training efficiency: {'Excellent' if device.type == 'mps' else 'Good'}")
print(f"  💾 Memory usage: {'Optimized for unified memory' if device.type == 'mps' else 'Standard'}")

if target_reached:
    print("\n🎉 SUCCESS: Model ready for deployment!")
    print("  Next steps:")
    print("  1. 🚀 Deploy on Modal/BaseTen")
    print("  2. 🔄 Implement re-ranking")
    print("  3. 🧪 A/B test performance")
else:
    print("\n📈 Consider improvements:")
    print("  • Increase model size")
    print("  • Train longer")
    print("  • Adjust hyperparameters")