# 03 - Training Models

**Learn how to train the Sensor Multi-Head Decoder for G-code generation.**

## Learning Objectives
- Understand the two-stage training architecture
- Load preprocessed data and frozen encoder
- Configure the SensorMultiHeadDecoder
- Run training with curriculum learning and focal loss
- Monitor loss and token accuracy during training (per-head metrics are left to the production code)
- Save and evaluate checkpoints

## Table of Contents
1. [Architecture Overview](#1.-Architecture-Overview)
2. [Environment Setup](#2.-Environment-Setup)
3. [Load Data & Encoder](#3.-Load-Data-&-Encoder)
4. [Model Configuration](#4.-Model-Configuration)
5. [Training Configuration](#5.-Training-Configuration)
6. [Training Loop](#6.-Training-Loop)
7. [Results & Checkpoints](#7.-Results-&-Checkpoints)
8. [Production Training](#8.-Production-Training)

---
## 1. Architecture Overview

### Two-Stage Training

The G-code fingerprinting model uses a **two-stage architecture**:

| Stage | Component | Status | Purpose |
|-------|-----------|--------|--------|
| 1 | MM-DTAE-LSTM Encoder | **FROZEN** | Extract sensor embeddings + classify operations (~100% accuracy) |
| 2 | SensorMultiHeadDecoder | **TRAINED** | Generate G-code tokens from embeddings |

### Why Freeze the Encoder?

- The encoder achieves ~100% operation classification accuracy
- Freezing prevents catastrophic forgetting
- Faster training (only decoder gradients)
- Stable sensor embeddings for decoder learning
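
As a rough illustration of what freezing means in PyTorch, the sketch below uses a small `nn.LSTM` and an `nn.Linear` as stand-ins for the encoder and decoder. Both stand-ins and their sizes are assumptions for the example, not the project's actual modules.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins; the real encoder and decoder live in the project code.
encoder = nn.LSTM(input_size=16, hidden_size=128, batch_first=True)
decoder = nn.Linear(128, 10)

# Freeze the encoder: its weights no longer track gradients.
for p in encoder.parameters():
    p.requires_grad_(False)
encoder.eval()  # also switches off dropout inside the encoder

# Hand only trainable parameters to the optimizer.
trainable = [p for p in decoder.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=2e-4)

x = torch.randn(4, 20, 16)       # [batch, time, features]
with torch.no_grad():            # no autograd graph is built for the encoder
    emb, _ = encoder(x)          # [4, 20, 128]

loss = decoder(emb).sum()
loss.backward()                  # gradients reach the decoder only
optimizer.step()
```

After `backward()`, every encoder parameter still has `grad is None`, which is where the speed-up in the list above comes from.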

### Multi-Head Outputs

```
SensorMultiHeadDecoder Outputs:
├── type_logits      [B, L, 4]      → SPECIAL, COMMAND, PARAM, NUMERIC
├── command_logits   [B, L, 6]      → G0, G1, G2, G3, G53, OTHER
├── param_type_logits[B, L, 10]     → X, Y, Z, F, R, S, I, J, K, OTHER
├── sign_logits      [B, L, 3]      → Positive, Negative, Zero
├── digit_logits     [B, L, 6, 10]  → 6 digits (2 int + 4 decimal)
└── aux_value        [B, L, 1]      → Auxiliary regression target
```
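
To make the sign and digit heads concrete, here is a minimal sketch of how a numeric value could be split into a sign class plus six digits (2 integer + 4 decimal). The helper names `decompose_value` and `recompose_value`, the clamping of out-of-range values, and the treatment of exactly `0.0` as "Zero" are assumptions for illustration; the project's own tokenizer may differ.

```python
def decompose_value(value, int_digits=2, dec_digits=4):
    """Split a number into a sign class and a fixed-width digit list."""
    if value > 0:
        sign = "Positive"
    elif value < 0:
        sign = "Negative"
    else:
        sign = "Zero"
    n_digits = int_digits + dec_digits
    # Scale so the decimals become integer digits, then clamp to the
    # largest value that fits in n_digits.
    scaled = min(round(abs(value) * 10 ** dec_digits), 10 ** n_digits - 1)
    digits = [int(ch) for ch in str(scaled).zfill(n_digits)]
    return sign, digits


def recompose_value(sign, digits, dec_digits=4):
    """Inverse of decompose_value (up to the 4-decimal rounding)."""
    magnitude = int("".join(str(d) for d in digits)) / 10 ** dec_digits
    return -magnitude if sign == "Negative" else magnitude


print(decompose_value(12.3456))  # ('Positive', [1, 2, 3, 4, 5, 6])
print(decompose_value(-0.5))     # ('Negative', [0, 0, 5, 0, 0, 0])
```

Each of the six digit positions is then a 10-way classification, which matches the `[B, L, 6, 10]` shape of `digit_logits`.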

---
## 2. Environment Setup

In [None]:
# Setup
import sys
import os
from pathlib import Path

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / 'src'))

print(f"Project root: {project_root}")

In [None]:
# Imports
import json
import math
import time
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Device
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")
print("✓ Imports successful")

---
## 3. Load Data & Encoder

In [None]:
# Load data splits
split_dir = project_root / 'outputs' / 'stratified_splits_v2'

if not split_dir.exists():
    # Try alternative paths
    alt_dirs = [
        project_root / 'outputs' / 'multilabel_stratified_splits',
        project_root / 'outputs' / 'sample_stratified_splits',
    ]
    for d in alt_dirs:
        if d.exists():
            split_dir = d
            break

print(f"Split directory: {split_dir}")

# Load training data
train_data = np.load(split_dir / 'train_sequences.npz', allow_pickle=True)
val_data = np.load(split_dir / 'val_sequences.npz', allow_pickle=True)

print(f"\nData loaded:")
print(f"  Train samples: {len(train_data['continuous'])}")
print(f"  Val samples:   {len(val_data['continuous'])}")

print(f"\nData shapes:")
for key in ['continuous', 'categorical', 'tokens', 'operation_type']:
    if key in train_data:
        print(f"  {key}: {train_data[key].shape}")

In [None]:
# Create PyTorch datasets
class SensorTokenDataset(Dataset):
    """Dataset for sensor-to-token training."""
    
    def __init__(self, data):
        self.continuous = torch.FloatTensor(data['continuous'])
        self.categorical = torch.LongTensor(data['categorical'])
        self.tokens = torch.LongTensor(data['tokens'])
        self.operation_type = torch.LongTensor(data['operation_type'])
        
        # Optional: raw values for auxiliary regression
        if 'param_value_raw' in data:
            self.param_value_raw = torch.FloatTensor(data['param_value_raw'])
        else:
            self.param_value_raw = torch.zeros_like(self.tokens, dtype=torch.float32)
    
    def __len__(self):
        return len(self.continuous)
    
    def __getitem__(self, idx):
        return {
            'continuous': self.continuous[idx],
            'categorical': self.categorical[idx],
            'tokens': self.tokens[idx],
            'operation_type': self.operation_type[idx],
            'param_value_raw': self.param_value_raw[idx],
        }

# Create datasets
train_dataset = SensorTokenDataset(train_data)
val_dataset = SensorTokenDataset(val_data)

# Create dataloaders
BATCH_SIZE = 32

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"Dataloaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches:   {len(val_loader)}")

In [None]:
# Load frozen encoder
encoder_path = project_root / 'outputs' / 'mm_dtae_lstm_v2' / 'best_model.pt'

if encoder_path.exists():
    print(f"Loading encoder from: {encoder_path}")
    encoder_checkpoint = torch.load(encoder_path, map_location=device, weights_only=False)
    
    print(f"\nEncoder checkpoint keys: {list(encoder_checkpoint.keys())}")
    
    # The encoder embeddings are pre-computed in the split data.
    # This demo does not run the encoder; the training loop below uses the sensor data directly.
    print(f"\n✓ Encoder checkpoint loaded")
else:
    print(f"⚠ Encoder not found at: {encoder_path}")
    print(f"  Will use raw sensor data instead of embeddings")

---
## 4. Model Configuration

### SensorMultiHeadDecoder Configuration

| Parameter | Value | Description |
|-----------|-------|-------------|
| d_model | 192 | Hidden dimension |
| n_heads | 8 | Attention heads |
| n_layers | 4 | Transformer layers |
| sensor_dim | 128 | Encoder output dimension |
| n_operations | 9 | Operation type classes |
| n_types | 4 | Token type classes |
| n_commands | 6 | Command classes |
| n_param_types | 10 | Parameter type classes |
| dropout | 0.3 | Dropout rate |
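
One constraint worth checking before changing these values: standard multi-head attention splits `d_model` evenly across heads, so `d_model` should be divisible by `n_heads`. The sketch below assumes the decoder follows that convention (the source does not confirm it) and uses a hypothetical helper `head_dim`.

```python
def head_dim(d_model: int, n_heads: int) -> int:
    """Return the per-head width, or raise if the split is uneven."""
    if d_model % n_heads != 0:
        raise ValueError(
            f"d_model={d_model} is not divisible by n_heads={n_heads}"
        )
    return d_model // n_heads


config = {"d_model": 192, "n_heads": 8}
per_head = head_dim(config["d_model"], config["n_heads"])
print(per_head)  # 24
```

With the table's values each head works on 24 dimensions; a combination such as 192 and 7 would be rejected.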

In [None]:
# Model configuration
model_config = {
    'vocab_size': 668,  # Total vocabulary size
    'd_model': 192,
    'n_heads': 8,
    'n_layers': 4,
    'sensor_dim': 128,
    'n_operations': 9,
    'n_types': 4,
    'n_commands': 6,
    'n_param_types': 10,
    'max_seq_len': 32,
    'dropout': 0.3,
    'embed_dropout': 0.1,
}

print("Model Configuration:")
print("="*50)
for k, v in model_config.items():
    print(f"  {k:20s}: {v}")

In [None]:
# Create model
from miracle.model.sensor_multihead_decoder import SensorMultiHeadDecoder

model = SensorMultiHeadDecoder(
    vocab_size=model_config['vocab_size'],
    d_model=model_config['d_model'],
    n_heads=model_config['n_heads'],
    n_layers=model_config['n_layers'],
    sensor_dim=model_config['sensor_dim'],
    n_operations=model_config['n_operations'],
    n_types=model_config['n_types'],
    n_commands=model_config['n_commands'],
    n_param_types=model_config['n_param_types'],
    max_seq_len=model_config['max_seq_len'],
    dropout=model_config['dropout'],
    embed_dropout=model_config['embed_dropout'],
).to(device)

# Count parameters
param_counts = model.count_parameters()

print("\nModel Parameter Counts:")
print("-" * 50)
for name, count in param_counts.items():
    print(f"  {name:25s}: {count:>10,}")

---
## 5. Training Configuration

### Key Training Features

1. **Focal Loss**: Handles class imbalance (gamma=3.0)
2. **Label Smoothing**: Improves generalization (0.1)
3. **Curriculum Learning**: Structure → Coarse Digits → Full Precision
4. **Cosine LR Schedule**: With warmup epochs
5. **Gradient Clipping**: Prevents exploding gradients (1.0)
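
For intuition about item 1, the following is a minimal focal-loss sketch built on `F.cross_entropy`. It is not the project's `FocalLoss` class (whose internals are not shown in this notebook), and it leaves out label smoothing to keep the formula readable.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, targets, gamma=3.0, ignore_index=-100):
    """Scale each example's cross-entropy by (1 - p_t) ** gamma."""
    ce = F.cross_entropy(
        logits, targets, reduction="none", ignore_index=ignore_index
    )
    p_t = torch.exp(-ce)                # probability of the true class
    loss = (1.0 - p_t) ** gamma * ce    # easy examples (p_t near 1) shrink
    mask = targets != ignore_index
    # Ignored positions contribute 0, so dividing by the valid count is a mean.
    return loss.sum() / mask.sum().clamp(min=1)


logits = torch.tensor([[4.0, 0.0, 0.0],    # confident and correct (easy)
                       [0.0, 0.2, 0.0]])   # uncertain (hard)
targets = torch.tensor([0, 1])

ce_value = F.cross_entropy(logits, targets)
focal_value = focal_loss(logits, targets, gamma=3.0)
print(f"cross-entropy: {ce_value:.4f}, focal: {focal_value:.4f}")
```

With `gamma=0` the function reduces to ordinary cross-entropy; larger `gamma` shifts more of the loss toward hard, often minority-class, examples.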

In [None]:
# Training configuration
train_config = {
    # Training
    'max_epochs': 5,  # Short demo (use 150 for production)
    'batch_size': BATCH_SIZE,
    'patience': 30,
    
    # Optimizer
    'learning_rate': 2e-4,
    'weight_decay': 0.05,
    'grad_clip': 1.0,
    
    # LR Schedule
    'warmup_epochs': 1,  # Short demo
    
    # Loss weights
    'type_weight': 1.0,
    'command_weight': 2.5,
    'param_type_weight': 1.5,
    'digit_weight': 1.0,
    
    # Focal loss
    'use_focal_loss': True,
    'focal_gamma': 3.0,
    
    # Label smoothing
    'label_smoothing': 0.1,
    
    # Curriculum (disabled for short demo)
    'use_curriculum': False,
}

print("Training Configuration:")
print("="*50)
for k, v in train_config.items():
    print(f"  {k:25s}: {v}")

In [None]:
# Create optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=train_config['learning_rate'],
    weight_decay=train_config['weight_decay'],
    betas=(0.9, 0.999),
)

# Learning rate scheduler (cosine with warmup)
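def get_lr_lambda(epoch, warmup_epochs=1, total_epochs=5):
    """Return the LR multiplier: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    # max(1, ...) avoids division by zero when warmup covers every epoch
    decay_epochs = max(1, total_epochs - warmup_epochs)
    progress = (epoch - warmup_epochs) / decay_epochs
    return 0.5 * (1 + math.cos(math.pi * progress))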

scheduler = optim.lr_scheduler.LambdaLR(
    optimizer, 
    lr_lambda=lambda e: get_lr_lambda(e, train_config['warmup_epochs'], train_config['max_epochs'])
)

print(f"Optimizer: AdamW (lr={train_config['learning_rate']}, wd={train_config['weight_decay']})")
print(f"Scheduler: Cosine with {train_config['warmup_epochs']} warmup epochs")

In [None]:
# Loss functions
from miracle.training.losses import FocalLoss

# Focal loss for classification heads
# (defined for reference; the demo loss below uses plain cross-entropy on legacy_logits)
if train_config['use_focal_loss']:
    type_criterion = FocalLoss(gamma=train_config['focal_gamma'], label_smoothing=train_config['label_smoothing'])
    command_criterion = FocalLoss(gamma=train_config['focal_gamma'], label_smoothing=train_config['label_smoothing'])
    param_type_criterion = FocalLoss(gamma=train_config['focal_gamma'], label_smoothing=train_config['label_smoothing'])
    digit_criterion = FocalLoss(gamma=train_config['focal_gamma'], label_smoothing=train_config['label_smoothing'])
else:
    type_criterion = nn.CrossEntropyLoss(label_smoothing=train_config['label_smoothing'])
    command_criterion = nn.CrossEntropyLoss(label_smoothing=train_config['label_smoothing'])
    param_type_criterion = nn.CrossEntropyLoss(label_smoothing=train_config['label_smoothing'])
    digit_criterion = nn.CrossEntropyLoss(label_smoothing=train_config['label_smoothing'])

print(f"Loss: {'Focal Loss (gamma=' + str(train_config['focal_gamma']) + ')' if train_config['use_focal_loss'] else 'Cross Entropy'}")
print(f"Label smoothing: {train_config['label_smoothing']}")

---
## 6. Training Loop

In [None]:
# Training functions
def compute_loss(outputs, batch, train_config):
    """Compute the demo loss: token-level cross-entropy on the legacy head."""
    tokens = batch['tokens']
    
    # For this demo, we use the legacy token prediction
    # Production code uses full multi-head decomposition
    legacy_logits = outputs['legacy_logits']  # [B, L, vocab_size]
    
    # Flatten to [B*L, vocab_size] and [B*L]; reshape also handles non-contiguous tensors
    legacy_logits_flat = legacy_logits.reshape(-1, legacy_logits.size(-1))
    tokens_flat = tokens.reshape(-1)
    
    # Ignore padding (assuming PAD=0)
    loss = F.cross_entropy(legacy_logits_flat, tokens_flat, ignore_index=0)
    
    return loss

def compute_accuracy(outputs, batch):
    """Compute token accuracy."""
    tokens = batch['tokens']
    legacy_logits = outputs['legacy_logits']
    
    # Get predictions
    preds = legacy_logits.argmax(dim=-1)  # [B, L]
    
    # Mask padding
    mask = tokens != 0
    
    correct = ((preds == tokens) & mask).sum().item()
    total = mask.sum().item()
    
    return correct / max(total, 1)

print("✓ Training functions defined")

In [None]:
def train_epoch(model, train_loader, optimizer, train_config, device):
    """Train for one epoch."""
    model.train()
    
    total_loss = 0
    total_acc = 0
    n_batches = 0
    
    pbar = tqdm(train_loader, desc='Training', leave=False)
    for batch in pbar:
        # Move to device
        continuous = batch['continuous'].to(device)
        tokens = batch['tokens'].to(device)
        operation_type = batch['operation_type'].to(device)
        
        # For demo: use continuous as sensor embeddings (normally from encoder)
        # Mean-pool over time, then broadcast back to every step: [B, T, D]
        sensor_emb = continuous.mean(dim=1, keepdim=True).expand(-1, continuous.size(1), -1)
        sensor_emb = sensor_emb[:, :, :128]  # Keep at most the first 128 features (sensor_dim)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(tokens, sensor_emb, operation_type)
        
        # Compute loss
        batch_dict = {'tokens': tokens}
        loss = compute_loss(outputs, batch_dict, train_config)
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), train_config['grad_clip'])
        
        # Update
        optimizer.step()
        
        # Metrics
        acc = compute_accuracy(outputs, batch_dict)
        total_loss += loss.item()
        total_acc += acc
        n_batches += 1
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{acc:.2%}'})
    
    return total_loss / n_batches, total_acc / n_batches

@torch.no_grad()
def validate(model, val_loader, train_config, device):
    """Validate the model."""
    model.eval()
    
    total_loss = 0
    total_acc = 0
    n_batches = 0
    
    for batch in tqdm(val_loader, desc='Validating', leave=False):
        continuous = batch['continuous'].to(device)
        tokens = batch['tokens'].to(device)
        operation_type = batch['operation_type'].to(device)
        
        # Sensor embeddings (same as training)
        sensor_emb = continuous.mean(dim=1, keepdim=True).expand(-1, continuous.size(1), -1)
        sensor_emb = sensor_emb[:, :, :128]
        
        # Forward pass
        outputs = model(tokens, sensor_emb, operation_type)
        
        # Metrics
        batch_dict = {'tokens': tokens}
        loss = compute_loss(outputs, batch_dict, train_config)
        acc = compute_accuracy(outputs, batch_dict)
        
        total_loss += loss.item()
        total_acc += acc
        n_batches += 1
    
    return total_loss / n_batches, total_acc / n_batches

print("✓ Train/validate functions defined")

In [None]:
# Run training loop
history = {
    'train_loss': [], 'val_loss': [],
    'train_acc': [], 'val_acc': [],
    'lr': []
}

best_val_acc = 0
best_epoch = 0

print(f"Starting training for {train_config['max_epochs']} epochs...")
print("="*60)

for epoch in range(train_config['max_epochs']):
    epoch_start = time.time()
    
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, train_config, device)
    
    # Validate
    val_loss, val_acc = validate(model, val_loader, train_config, device)
    
    # Record the LR used for this epoch, then advance the scheduler
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()
    
    # Record history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)
    history['lr'].append(current_lr)
    
    # Check best
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch
    
    epoch_time = time.time() - epoch_start
    
    print(f"Epoch {epoch+1}/{train_config['max_epochs']} ({epoch_time:.1f}s)")
    print(f"  Train: loss={train_loss:.4f}, acc={train_acc:.2%}")
    print(f"  Val:   loss={val_loss:.4f}, acc={val_acc:.2%}")
    print(f"  LR:    {current_lr:.2e}")

print("\n" + "="*60)
print(f"Training complete!")
print(f"Best val accuracy: {best_val_acc:.2%} (epoch {best_epoch+1})")

In [None]:
# Visualize training curves
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

epochs = range(1, len(history['train_loss']) + 1)

# Loss
ax1 = axes[0]
ax1.plot(epochs, history['train_loss'], 'b-o', label='Train', linewidth=2, markersize=6)
ax1.plot(epochs, history['val_loss'], 'r-s', label='Val', linewidth=2, markersize=6)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss', fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy
ax2 = axes[1]
ax2.plot(epochs, [a*100 for a in history['train_acc']], 'b-o', label='Train', linewidth=2, markersize=6)
ax2.plot(epochs, [a*100 for a in history['val_acc']], 'r-s', label='Val', linewidth=2, markersize=6)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Token Accuracy', fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Learning Rate
ax3 = axes[2]
ax3.plot(epochs, history['lr'], 'g-o', linewidth=2, markersize=6)
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Learning Rate')
ax3.set_title('LR Schedule', fontweight='bold')
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---
## 7. Results & Checkpoints
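
The demo loop tracks `best_val_acc` but saves the final-epoch weights, and the `patience` entry in `train_config` is not used. A minimal sketch of how best-state tracking and patience-based early stopping could be added is shown below; the `EarlyStopping` class is hypothetical, and the accuracies and the dict standing in for `model.state_dict()` are made-up values.

```python
import copy


class EarlyStopping:
    """Keep the best state seen so far; stop after `patience` stale epochs."""

    def __init__(self, patience=30):
        self.patience = patience
        self.best_score = float("-inf")
        self.best_epoch = -1
        self.best_state = None
        self.stale_epochs = 0

    def step(self, score, epoch, state):
        """Record one epoch; return True when training should stop."""
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            self.best_state = copy.deepcopy(state)  # snapshot, not a reference
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


# Toy run: made-up validation accuracies, dict in place of a state_dict.
stopper = EarlyStopping(patience=2)
val_accs = [0.50, 0.62, 0.61, 0.60, 0.70]
for epoch, acc in enumerate(val_accs):
    if stopper.step(acc, epoch, {"weights": [acc]}):
        break

print(stopper.best_epoch, stopper.best_score, epoch)  # 1 0.62 3
```

In the real loop, `state` would be `model.state_dict()`, and `stopper.best_state` is what would go into the checkpoint instead of the last-epoch weights.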

In [None]:
# Save checkpoint
checkpoint_dir = project_root / 'outputs' / 'notebook_training'
checkpoint_dir.mkdir(parents=True, exist_ok=True)

checkpoint = {
    'epoch': train_config['max_epochs'],
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'config': {**model_config, **train_config},
    'history': history,
    'best_val_acc': best_val_acc,
}

checkpoint_path = checkpoint_dir / 'checkpoint_demo.pt'
torch.save(checkpoint, checkpoint_path)

print(f"✓ Checkpoint saved to: {checkpoint_path}")
print(f"  Size: {checkpoint_path.stat().st_size / (1024*1024):.2f} MB")

In [None]:
# Load and inspect production checkpoint
prod_checkpoint_path = project_root / 'outputs' / 'sensor_multihead_v3' / 'best_model.pt'

if prod_checkpoint_path.exists():
    prod_checkpoint = torch.load(prod_checkpoint_path, map_location='cpu', weights_only=False)
    
    print("Production Checkpoint:")
    print("="*50)
    print(f"  Path: {prod_checkpoint_path}")
    print(f"  Size: {prod_checkpoint_path.stat().st_size / (1024*1024):.1f} MB")
    
    print(f"\nCheckpoint contents:")
    for key in prod_checkpoint.keys():
        if 'state_dict' in key:
            print(f"  {key}: {len(prod_checkpoint[key])} parameters")
        elif isinstance(prod_checkpoint[key], dict):
            print(f"  {key}: dict")
        else:
            print(f"  {key}: {type(prod_checkpoint[key]).__name__}")
else:
    print(f"Production checkpoint not found at: {prod_checkpoint_path}")

---
## 8. Production Training

For production training, use the command-line script `scripts/train_sensor_multihead.py`, which enables the features this demo skips, such as curriculum learning.

In [None]:
# Production training commands
print("="*70)
print("PRODUCTION TRAINING COMMANDS")
print("="*70)

print("\n1. Full Training (Recommended):")
print("-"*50)
# Raw string so the trailing backslashes are printed instead of acting as line continuations
cmd1 = r"""
PYTORCH_ENABLE_MPS_FALLBACK=1 PYTHONPATH=src .venv/bin/python \
    scripts/train_sensor_multihead.py \
    --split-dir outputs/stratified_splits_v2 \
    --vocab-path data/vocabulary_4digit_hybrid.json \
    --encoder-path outputs/mm_dtae_lstm_v2/best_model.pt \
    --output-dir outputs/sensor_multihead_v4 \
    --d-model 192 \
    --n-heads 8 \
    --n-layers 4 \
    --max-epochs 150 \
    --batch-size 32 \
    --learning-rate 2e-4 \
    --use-focal-loss \
    --curriculum \
    --use-wandb
"""
print(cmd1)

print("\n2. Quick Training (Testing):")
print("-"*50)
# Raw string so the trailing backslashes are printed instead of acting as line continuations
cmd2 = r"""
PYTORCH_ENABLE_MPS_FALLBACK=1 PYTHONPATH=src .venv/bin/python \
    scripts/train_sensor_multihead.py \
    --split-dir outputs/stratified_splits_v2 \
    --vocab-path data/vocabulary_4digit_hybrid.json \
    --encoder-path outputs/mm_dtae_lstm_v2/best_model.pt \
    --output-dir outputs/quick_test \
    --max-epochs 10 \
    --batch-size 16
"""
print(cmd2)

print("\n3. Resume Training:")
print("-"*50)
# Raw string so the trailing backslashes are printed instead of acting as line continuations
cmd3 = r"""
PYTHONPATH=src .venv/bin/python scripts/train_sensor_multihead.py \
    --split-dir outputs/stratified_splits_v2 \
    --vocab-path data/vocabulary_4digit_hybrid.json \
    --encoder-path outputs/mm_dtae_lstm_v2/best_model.pt \
    --output-dir outputs/sensor_multihead_v3 \
    --resume outputs/sensor_multihead_v3/best_model.pt \
    --max-epochs 200
"""
print(cmd3)

---
## Summary

In this notebook, you learned:

1. **Two-Stage Architecture**: Frozen encoder + trainable decoder
2. **Data Loading**: NPZ files with sensor data and tokens
3. **Model Configuration**: SensorMultiHeadDecoder with 192 hidden dim, 8 heads, 4 layers
4. **Training Features**: Focal loss, curriculum learning, cosine LR schedule
5. **Monitoring**: Per-epoch loss and accuracy tracking
6. **Checkpoints**: Saving and loading model state

### Key Takeaways

| Feature | Setting | Purpose |
|---------|---------|--------|
| Focal Loss | gamma=3.0 | Handle class imbalance |
| Label Smoothing | 0.1 | Improve generalization |
| Curriculum | 3 phases | Structure → Digits → Full |
| Command Weight | 2.5x | Prioritize command accuracy |
| Dropout | 0.3 | Regularization |
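
The loss weights from `train_config` only matter once per-head losses are combined. The sketch below assumes a plain weighted sum, which is a common choice but is not confirmed by this notebook; `combine_head_losses` is a hypothetical helper and the per-head loss values are made up.

```python
# Loss weights copied from train_config in this notebook.
loss_weights = {"type": 1.0, "command": 2.5, "param_type": 1.5, "digit": 1.0}


def combine_head_losses(head_losses, weights):
    """Weighted sum of per-head losses; unknown heads default to weight 1.0."""
    return sum(weights.get(name, 1.0) * value
               for name, value in head_losses.items())


# Made-up per-head loss values, for illustration only.
example_losses = {"type": 0.2, "command": 0.4, "param_type": 0.3, "digit": 1.0}
total = combine_head_losses(example_losses, loss_weights)
print(round(total, 2))  # 0.2*1.0 + 0.4*2.5 + 0.3*1.5 + 1.0*1.0 = 2.65
```

The same function works unchanged with PyTorch scalar tensors, so the 2.5x command weight simply makes command errors count 2.5 times as much in the gradient.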

### Production Results (sensor_multihead_v3)

| Metric | Value |
|--------|-------|
| Val Token Accuracy | ~90% |
| Test Token Accuracy | ~90% |
| Training Time | ~2-3 hours (150 epochs) |

---

**Navigation:**
← [Previous: 02_data_preprocessing](02_data_preprocessing.ipynb) |
[Next: 04_inference_prediction](04_inference_prediction.ipynb) →

**Related:** [07_hyperparameter_sweeps](07_hyperparameter_sweeps.ipynb) | [08_model_evaluation](08_model_evaluation.ipynb)