# 17. Uncertainty Quantification

Quantify prediction uncertainty using Monte Carlo Dropout and calibration analysis.

## Contents
1. [Setup](#1-setup)
2. [Monte Carlo Dropout](#2-monte-carlo-dropout)
3. [Confidence Calibration](#3-confidence-calibration)
4. [Uncertainty Metrics](#4-uncertainty-metrics)
5. [Selective Prediction](#5-selective-prediction)
6. [Temperature Scaling](#6-temperature-scaling)
7. [Visualization](#7-visualization)

---

## 1. Setup

In [None]:
import sys
from pathlib import Path

# NOTE(review): assumes the notebook is executed from a direct subdirectory of
# the project root (e.g. notebooks/) — confirm if the notebook is moved.
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / 'src'))

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional
import json
from tqdm.notebook import tqdm

print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

plt.style.use('seaborn-v0_8-whitegrid')

# Every later cell writes figures / the JSON report into reports/; create it
# up front so plt.savefig and open() do not fail on a fresh checkout.
(project_root / 'reports').mkdir(parents=True, exist_ok=True)

In [None]:
# Load model and data
from miracle.model.backbone import MMDTAELSTMBackbone
from miracle.model.multihead_lm import MultiHeadGCodeLM

VOCAB_PATH = project_root / 'data' / 'gcode_vocab_v2.json'
CHECKPOINT_PATH = project_root / 'outputs' / 'final_model' / 'checkpoint_best.pt'
DATA_DIR = project_root / 'outputs' / 'processed_v2'

with open(VOCAB_PATH, encoding='utf-8') as f:
    vocab = json.load(f)

# NOTE: weights_only=False unpickles arbitrary Python objects — only load
# checkpoints from a trusted source (this project's own training output).
checkpoint = None
if CHECKPOINT_PATH.exists():
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
    config = checkpoint.get('config', {})
else:
    config = {'hidden_dim': 256, 'num_layers': 4, 'num_heads': 8, 'dropout': 0.1}

# Backbone and LM head must agree on the hidden width.
d_model = config.get('hidden_dim', 256)

backbone = MMDTAELSTMBackbone(
    continuous_dim=155,
    categorical_dims=[10, 10, 50, 50],
    d_model=d_model,
    num_layers=config.get('num_layers', 4),
    num_heads=config.get('num_heads', 8),
    dropout=config.get('dropout', 0.1),
).to(device)

lm = MultiHeadGCodeLM(
    d_model=d_model,
    vocab_sizes=vocab.get('head_vocab_sizes', {'type': 10, 'command': 50, 'param_type': 30, 'param_value': 100}),
).to(device)

if checkpoint is not None:
    backbone.load_state_dict(checkpoint['backbone_state_dict'])
    lm.load_state_dict(checkpoint['lm_state_dict'])
    print("Models loaded")
else:
    # Without a checkpoint every result below comes from randomly initialized weights.
    print(f"WARNING: checkpoint not found at {CHECKPOINT_PATH}; using randomly initialized models")

In [None]:
# Load test data (first N_TEST sequences); fall back to random tensors so the
# notebook still runs without preprocessed data.
N_TEST = 50
test_path = DATA_DIR / 'test.pt'
if test_path.exists():
    # map_location='cpu' lets a file saved from an accelerator load on any machine;
    # tensors are moved to `device` explicitly below.
    test_data = torch.load(test_path, map_location='cpu', weights_only=False)
    # as_tensor accepts numpy arrays and tensors alike (torch.tensor on a tensor warns and copies).
    test_continuous = torch.as_tensor(test_data['continuous'][:N_TEST], dtype=torch.float32).to(device)
    test_categorical = torch.as_tensor(test_data['categorical'][:N_TEST], dtype=torch.long).to(device)
else:
    print(f"WARNING: {test_path} not found; using random synthetic test data")
    test_continuous = torch.randn(N_TEST, 64, 155).to(device)
    test_categorical = torch.randint(0, 10, (N_TEST, 64, 4)).to(device)

print(f"Test data: {test_continuous.shape}")

## 2. Monte Carlo Dropout

Keep dropout active at inference time and run several stochastic forward passes; the spread of the resulting predictions estimates the model's uncertainty.

In [None]:
def enable_dropout(model):
    """Switch every dropout layer of ``model`` to training mode for MC Dropout.

    Only dropout modules are touched; every other submodule keeps its current
    mode, so call ``model.eval()`` first.

    Matches ``nn.Dropout`` and the channel/alpha dropout variants
    (``Dropout1d/2d/3d``, ``AlphaDropout``, ``FeatureAlphaDropout``), which
    do not subclass ``nn.Dropout``.

    Note: dropout built into fused modules such as ``nn.LSTM`` (inter-layer
    dropout) or ``nn.MultiheadAttention`` (attention dropout) is governed by
    those modules' own ``training`` flag, not by a dropout submodule, and is
    therefore NOT enabled here.
    """
    dropout_types = (
        nn.Dropout,
        nn.Dropout1d,
        nn.Dropout2d,
        nn.Dropout3d,
        nn.AlphaDropout,
        nn.FeatureAlphaDropout,
    )
    for module in model.modules():
        if isinstance(module, dropout_types):
            module.train()

def mc_dropout_predictions(backbone, lm, continuous, categorical,
                           n_samples=30, head='command'):
    """Run ``n_samples`` stochastic forward passes with dropout active (MC Dropout).

    Both models are put in eval mode with only their dropout layers switched
    back on (via ``enable_dropout``), so each pass samples a new dropout mask.
    Every submodule's original train/eval flag is restored before returning,
    even if a forward pass raises, so the MC mode does not leak to callers.

    Args:
        backbone: encoder, called as ``backbone(continuous, categorical)``.
        lm: head module called on the backbone output; its result must be a
            mapping that contains ``head``.
        continuous, categorical: model inputs, passed through unchanged.
        n_samples: number of MC passes (>= 1; the variance needs >= 2 to be finite).
        head: key of the output head whose logits are converted to probabilities.

    Returns:
        (mean_probs, var_probs, all_probs), all on CPU. ``all_probs`` stacks the
        per-pass softmax outputs along a new leading dim; ``mean_probs`` and
        ``var_probs`` are the mean and (unbiased) variance over that dim.

    Raises:
        ValueError: if ``n_samples`` < 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    # Snapshot each submodule's mode so it can be restored exactly afterwards.
    saved_modes = [(m, m.training) for model in (backbone, lm) for m in model.modules()]

    backbone.eval()
    lm.eval()
    enable_dropout(backbone)
    enable_dropout(lm)

    all_probs = []
    try:
        with torch.no_grad():
            for _ in range(n_samples):
                hidden = backbone(continuous, categorical)
                preds = lm(hidden)
                probs = F.softmax(preds[head], dim=-1)
                all_probs.append(probs.cpu())
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    # Stack: [n_samples, *probs.shape], e.g. [n_samples, B, T, V]
    all_probs = torch.stack(all_probs)

    mean_probs = all_probs.mean(dim=0)
    var_probs = all_probs.var(dim=0)

    return mean_probs, var_probs, all_probs

# Run MC Dropout on the first 10 test sequences, 30 stochastic passes each
n_mc_passes = 30
n_mc_sequences = 10
print(f"Running MC Dropout ({n_mc_passes} samples)...")
mean_probs, var_probs, all_probs = mc_dropout_predictions(
    backbone,
    lm,
    test_continuous[:n_mc_sequences],
    test_categorical[:n_mc_sequences],
    n_samples=n_mc_passes,
)

for name, result in (("Mean probs", mean_probs), ("Variance", var_probs)):
    print(f"{name} shape: {result.shape}")

In [None]:
# Visualize MC Dropout predictions for a single sequence
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

sample_idx = 0
# Probe position 32 by default, clamped so sequences with 32 or fewer
# positions do not raise an IndexError.
pos_idx = min(32, mean_probs.shape[1] - 1)

# Left: per-pass probabilities of the 5 classes with the highest mean probability
sample_probs = all_probs[:, sample_idx, pos_idx, :]  # [n_samples, V]
top_classes = mean_probs[sample_idx, pos_idx].argsort(descending=True)[:5]

for i, cls in enumerate(top_classes):
    class_probs = sample_probs[:, cls].numpy()
    axes[0].boxplot(class_probs, positions=[i], widths=0.6)

axes[0].set_xticks(range(len(top_classes)))
axes[0].set_xticklabels([f'Class {c.item()}' for c in top_classes])
axes[0].set_ylabel('Probability')
axes[0].set_title('MC Dropout: Prediction Distribution')

# Middle: maximum mean probability at each position
mean_conf = mean_probs[sample_idx].max(dim=-1)[0].numpy()
axes[1].plot(mean_conf)
axes[1].set_xlabel('Position')
axes[1].set_ylabel('Mean Confidence')
axes[1].set_title('Confidence Across Sequence')

# Right: MC variance of the predicted (argmax-of-mean) class at each position,
# selected with a single gather instead of a per-position Python loop.
pred_classes = mean_probs[sample_idx].argmax(dim=-1)  # [T]
uncertainties = var_probs[sample_idx].gather(-1, pred_classes.unsqueeze(-1)).squeeze(-1).numpy()
axes[2].plot(uncertainties, color='coral')
axes[2].set_xlabel('Position')
axes[2].set_ylabel('Prediction Variance')
axes[2].set_title('Uncertainty Across Sequence')

plt.tight_layout()
plt.savefig(project_root / 'reports' / 'mc_dropout_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Confidence Calibration

Analyze how well confidence matches accuracy.

In [None]:
def compute_calibration(predictions, targets, n_bins=10):
    """Compute a reliability (calibration) curve and the Expected Calibration Error.

    Args:
        predictions: probability tensor whose last dim indexes classes
            (e.g. a softmax output of shape [..., V]).
        targets: integer class indices with one entry per prediction
            (same number of elements as ``predictions.shape[:-1]``).
        n_bins: number of equal-width confidence bins on [0, 1].

    Returns:
        dict with numpy arrays ``bin_centers``, ``bin_accuracies``,
        ``bin_confidences``, ``bin_counts`` (length ``n_bins``) and a float
        ``ece``. Empty bins report accuracy 0, confidence = bin center and
        count 0, and do not contribute to the ECE. ``ece`` is NaN when there
        are no predictions.

    Raises:
        ValueError: if ``predictions`` and ``targets`` describe different
            numbers of predictions.
    """
    confidences = predictions.max(dim=-1)[0].flatten().cpu().numpy()
    pred_classes = predictions.argmax(dim=-1).flatten().cpu().numpy()
    true_classes = targets.flatten().cpu().numpy()

    if pred_classes.size != true_classes.size:
        raise ValueError(
            f"predictions describe {pred_classes.size} items but targets has {true_classes.size}"
        )

    accuracies = (pred_classes == true_classes).astype(float)

    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_centers = (bin_boundaries[:-1] + bin_boundaries[1:]) / 2

    bin_accuracies = []
    bin_confidences = []
    bin_counts = []

    for i in range(n_bins):
        lo, hi = bin_boundaries[i], bin_boundaries[i + 1]
        # Bins are half-open [lo, hi) except the last one, which is closed
        # [lo, 1] so a confidence of exactly 1.0 (saturated softmax) is counted.
        if i == n_bins - 1:
            mask = (confidences >= lo) & (confidences <= hi)
        else:
            mask = (confidences >= lo) & (confidences < hi)
        count = int(mask.sum())
        if count > 0:
            bin_accuracies.append(accuracies[mask].mean())
            bin_confidences.append(confidences[mask].mean())
        else:
            bin_accuracies.append(0.0)
            bin_confidences.append(bin_centers[i])
        bin_counts.append(count)

    bin_accuracies = np.array(bin_accuracies, dtype=float)
    bin_confidences = np.array(bin_confidences, dtype=float)
    bin_counts = np.array(bin_counts)

    # ECE = sum_b (n_b / N) * |acc_b - conf_b|; undefined (NaN) for N == 0.
    total = int(bin_counts.sum())
    if total > 0:
        ece = float(np.sum(np.abs(bin_accuracies - bin_confidences) * bin_counts) / total)
    else:
        ece = float('nan')

    return {
        'bin_centers': bin_centers,
        'bin_accuracies': bin_accuracies,
        'bin_confidences': bin_confidences,
        'bin_counts': bin_counts,
        'ece': ece
    }

# Generate predictions with normal (deterministic) inference
backbone.eval()
lm.eval()

with torch.no_grad():
    hidden = backbone(test_continuous, test_categorical)
    predictions = lm(hidden)

command_logits = predictions['command']
n_command_classes = command_logits.shape[-1]

# DEMO ONLY: the model's own argmax is used as a pseudo-target.
# In practice, use the actual ground-truth targets.
pseudo_targets = command_logits.argmax(dim=-1)

# Replace ~20% of pseudo-targets with random classes to simulate errors.
# Classes are drawn from the command head's actual vocabulary size.
noise_mask = torch.rand_like(pseudo_targets.float()) < 0.2
noisy_targets = pseudo_targets.clone()
noisy_targets[noise_mask] = torch.randint(
    0, n_command_classes, noisy_targets[noise_mask].shape, device=pseudo_targets.device
)

# Compute calibration
probs = F.softmax(command_logits, dim=-1)
calibration = compute_calibration(probs, noisy_targets)

print(f"Expected Calibration Error (ECE): {calibration['ece']:.4f}")

In [None]:
# Visualize calibration
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Reliability diagram. Only non-empty bins are drawn: an empty bin has no
# accuracy, and a zero-height bar would read as severe over-confidence.
n_calib_bins = len(calibration['bin_centers'])
occupied = calibration['bin_counts'] > 0
axes[0].bar(calibration['bin_centers'][occupied], calibration['bin_accuracies'][occupied],
            width=0.8 / n_calib_bins, alpha=0.7, label='Accuracy')
axes[0].plot([0, 1], [0, 1], 'r--', label='Perfect calibration')
axes[0].set_xlabel('Confidence')
axes[0].set_ylabel('Accuracy')
axes[0].set_title(f'Reliability Diagram (ECE={calibration["ece"]:.4f})')
axes[0].legend()
axes[0].set_xlim(0, 1)
axes[0].set_ylim(0, 1)

# Confidence histogram
all_confs = probs.max(dim=-1)[0].flatten().cpu().numpy()
axes[1].hist(all_confs, bins=20, alpha=0.7, edgecolor='black')
axes[1].set_xlabel('Confidence')
axes[1].set_ylabel('Count')
axes[1].set_title('Confidence Distribution')
axes[1].axvline(x=np.mean(all_confs), color='red', linestyle='--', label=f'Mean: {np.mean(all_confs):.3f}')
axes[1].legend()

plt.tight_layout()
plt.savefig(project_root / 'reports' / 'calibration_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Uncertainty Metrics

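Three ways to measure prediction uncertainty from the MC Dropout samples: predictive entropy (total uncertainty), mutual information (epistemic/model uncertainty), and variation ratio (disagreement between samples).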

In [None]:
def compute_entropy(probs):
    """Return the Shannon entropy (in nats) of the distribution along the last dim.

    A constant 1e-10 is added inside the logarithm so zero-probability
    entries contribute 0 instead of producing NaN.
    """
    log_probs = torch.log(probs + 1e-10)
    return -torch.sum(probs * log_probs, dim=-1)

def compute_mutual_information(all_probs):
    """Estimate epistemic (model) uncertainty from stacked MC samples.

    Computes H[mean_s p_s] - mean_s H[p_s]: the entropy of the averaged
    distribution minus the average per-sample entropy. ``all_probs`` stacks
    MC samples along dim 0, and classes lie on the last dim.
    """
    entropy_of_mean = compute_entropy(all_probs.mean(dim=0))
    mean_of_entropies = compute_entropy(all_probs).mean(dim=0)
    return entropy_of_mean - mean_of_entropies

def compute_variation_ratio(all_probs):
    """Return the variation ratio: 1 - (frequency of the most common predicted class).

    ``all_probs`` stacks MC samples along dim 0 with classes on the last dim
    (e.g. ``[n_samples, B, T, V]``, though any rank >= 2 works). Each
    sample's prediction is its argmax over the last dim. The result has
    shape ``all_probs.shape[1:-1]``; 0 means all samples agreed, and the
    value is at most ``1 - 1/n_samples``.
    """
    predictions = all_probs.argmax(dim=-1)  # [n_samples, ...]
    # Vectorized over every position at once. If several classes tie for the
    # mode they share the same count, so the frequency does not depend on
    # which one torch.mode returns.
    mode_class = predictions.mode(dim=0).values
    mode_freq = (predictions == mode_class.unsqueeze(0)).float().mean(dim=0)
    return 1 - mode_freq

# Compute all uncertainty metrics from the MC Dropout outputs
entropy = compute_entropy(mean_probs)
mutual_info = compute_mutual_information(all_probs)
variation_ratio = compute_variation_ratio(all_probs)

for metric_label, metric_values in (
    ("Entropy", entropy),
    ("Mutual Info", mutual_info),
    ("Variation Ratio", variation_ratio),
):
    print(f"{metric_label} - mean: {metric_values.mean():.4f}, std: {metric_values.std():.4f}")

In [None]:
# Visualize uncertainty metrics for one sequence
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

sample_idx = 0
entropy_seq = entropy[sample_idx].numpy()
mi_seq = mutual_info[sample_idx].numpy()
vr_seq = variation_ratio[sample_idx].numpy()

# Entropy
axes[0, 0].plot(entropy_seq)
axes[0, 0].set_xlabel('Position')
axes[0, 0].set_ylabel('Entropy')
axes[0, 0].set_title('Predictive Entropy (Total Uncertainty)')

# Mutual Information
axes[0, 1].plot(mi_seq, color='coral')
axes[0, 1].set_xlabel('Position')
axes[0, 1].set_ylabel('Mutual Information')
axes[0, 1].set_title('Epistemic Uncertainty (Model Uncertainty)')

# Variation Ratio
axes[1, 0].plot(vr_seq, color='forestgreen')
axes[1, 0].set_xlabel('Position')
axes[1, 0].set_ylabel('Variation Ratio')
axes[1, 0].set_title('Prediction Disagreement')

# All metrics together. Scale by max(value, 1e-6) so a fully confident
# sequence (all-zero entropy or MI) does not produce 0/0 = NaN.
axes[1, 1].plot(entropy_seq / max(float(entropy_seq.max()), 1e-6), label='Entropy (normalized)')
axes[1, 1].plot(mi_seq / max(float(mi_seq.max()), 1e-6), label='Mutual Info (normalized)')
axes[1, 1].plot(vr_seq, label='Variation Ratio')
axes[1, 1].set_xlabel('Position')
axes[1, 1].set_ylabel('Normalized Value')
axes[1, 1].set_title('Uncertainty Metrics Comparison')
axes[1, 1].legend()

plt.tight_layout()
plt.savefig(project_root / 'reports' / 'uncertainty_metrics.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Selective Prediction

Abstain on the most uncertain predictions and measure how accuracy (and error rate) changes as coverage decreases.

In [None]:
def selective_prediction_curve(probs, targets, uncertainty_metric):
    """Compute accuracy as a function of coverage when abstaining on uncertain inputs.

    Predictions are ranked from least to most uncertain. For each k, the
    curve reports the accuracy of the k least-uncertain predictions.

    Args:
        probs: class probabilities with classes on the last dim ([..., V]).
        targets: integer class indices, one per prediction.
        uncertainty_metric: per-prediction uncertainty (higher = less certain),
            one value per prediction.

    Returns:
        (coverages, cumulative_acc): numpy arrays of length N, where
        ``coverages[k-1] = k / N`` and ``cumulative_acc[k-1]`` is the
        accuracy over the k least-uncertain predictions.

    Raises:
        ValueError: if the three inputs describe different numbers of predictions.
    """
    predictions = probs.argmax(dim=-1).flatten().cpu().numpy()
    targets_flat = targets.flatten().cpu().numpy()
    uncertainty_flat = uncertainty_metric.flatten().cpu().numpy()

    if not (predictions.size == targets_flat.size == uncertainty_flat.size):
        raise ValueError(
            "probs, targets and uncertainty_metric must describe the same number of "
            f"predictions; got {predictions.size}, {targets_flat.size} and {uncertainty_flat.size}"
        )

    correct = predictions == targets_flat

    # Most certain first. A stable sort keeps tied items in their original
    # order, so tie-breaking is reproducible.
    sorted_indices = np.argsort(uncertainty_flat, kind='stable')
    correct_sorted = correct[sorted_indices]

    ranks = np.arange(1, correct_sorted.size + 1)
    coverages = ranks / correct_sorted.size
    cumulative_acc = np.cumsum(correct_sorted) / ranks

    return coverages, cumulative_acc

# Compute selective prediction curves on the MC Dropout subset.
# mean_probs / entropy / mutual_info cover only the sequences that were run
# through MC Dropout, while noisy_targets covers the whole test batch. Slice
# the targets so all three inputs describe the same predictions.
n_mc_seqs = mean_probs.shape[0]
mc_targets = noisy_targets[:n_mc_seqs].cpu()
confidence = mean_probs.max(dim=-1)[0]  # Higher = more confident

# Each metric is used as an uncertainty score (higher = more uncertain)
coverages_ent, acc_ent = selective_prediction_curve(mean_probs, mc_targets, entropy)
coverages_mi, acc_mi = selective_prediction_curve(mean_probs, mc_targets, mutual_info)
coverages_conf, acc_conf = selective_prediction_curve(mean_probs, mc_targets, 1 - confidence)  # Negate confidence

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Coverage vs Accuracy
axes[0].plot(coverages_ent, acc_ent, label='Entropy')
axes[0].plot(coverages_mi, acc_mi, label='Mutual Information')
axes[0].plot(coverages_conf, acc_conf, label='Confidence')
axes[0].axhline(y=acc_conf[-1], color='gray', linestyle='--', alpha=0.5, label='Full coverage acc')
axes[0].set_xlabel('Coverage')
axes[0].set_ylabel('Accuracy')
axes[0].set_title('Selective Prediction: Coverage vs Accuracy')
axes[0].legend()
axes[0].set_xlim(0, 1)

# Risk-coverage curve (risk = error rate of the retained predictions)
risk_ent = 1 - acc_ent
risk_mi = 1 - acc_mi
risk_conf = 1 - acc_conf

axes[1].plot(coverages_ent, risk_ent, label='Entropy')
axes[1].plot(coverages_mi, risk_mi, label='Mutual Information')
axes[1].plot(coverages_conf, risk_conf, label='Confidence')
axes[1].set_xlabel('Coverage')
axes[1].set_ylabel('Risk (Error Rate)')
axes[1].set_title('Risk-Coverage Curve')
axes[1].legend()
axes[1].set_xlim(0, 1)

plt.tight_layout()
plt.savefig(project_root / 'reports' / 'selective_prediction.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Temperature Scaling

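Post-hoc calibration with temperature scaling: divide the logits by a single scalar temperature $T$ fitted by minimizing cross-entropy. Because $T > 0$ does not change the argmax, accuracy is unaffected; only the confidence is rescaled.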

In [None]:
class TemperatureScaling(nn.Module):
    """Divide logits by one learnable scalar temperature T (initialized to 1.0).

    For T > 0 the argmax is unchanged; T > 1 softens the softmax and T < 1
    sharpens it.
    """

    def __init__(self):
        super().__init__()
        # Shape (1,) so it broadcasts against logits of any shape.
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, logits):
        return torch.div(logits, self.temperature)

def train_temperature(logits, targets, max_iter=100):
    """Fit a temperature T by minimizing the cross-entropy of ``logits / T``.

    Fit on a held-out validation split: fitting and evaluating on the same
    data overstates the calibration gain.

    Args:
        logits: unnormalized scores with classes on the last dim ([..., V]).
        targets: integer class indices, one per row of ``logits[..., :]``.
        max_iter: maximum number of L-BFGS iterations.

    Returns:
        The fitted temperature as a Python float.
    """
    # Only T is optimized. Detaching prevents a live autograd graph on
    # `logits` from being freed by the first backward() and then failing on
    # L-BFGS's repeated closure evaluations.
    logits = logits.detach()
    num_classes = logits.shape[-1]
    # reshape (not view) also accepts non-contiguous logits of any rank.
    flat_logits = logits.reshape(-1, num_classes)
    flat_targets = targets.detach().reshape(-1).to(logits.device)

    # Create the parameter next to the data rather than on a global device.
    temp_model = TemperatureScaling().to(logits.device)
    optimizer = torch.optim.LBFGS([temp_model.temperature], lr=0.01, max_iter=max_iter)

    criterion = nn.CrossEntropyLoss()

    def closure():
        optimizer.zero_grad()
        loss = criterion(temp_model(flat_logits), flat_targets)
        loss.backward()
        return loss

    optimizer.step(closure)

    return temp_model.temperature.item()

# Train temperature on one half of the sequences and measure calibration on the
# other half, so the reported ECE change is not measured on the data T was fit to.
logits = predictions['command']
n_seqs = logits.shape[0]
if n_seqs >= 2:
    n_fit = n_seqs // 2
    fit_slice, eval_slice = slice(0, n_fit), slice(n_fit, None)
else:
    # Too little data to split: fall back to fitting and evaluating on the same sequence.
    n_fit = n_seqs
    fit_slice, eval_slice = slice(None), slice(None)

optimal_temp = train_temperature(logits[fit_slice], noisy_targets[fit_slice])
print(f"Optimal temperature: {optimal_temp:.4f} (fit on {n_fit} sequences)")

# Apply temperature scaling to the evaluation split
eval_logits = logits[eval_slice]
eval_targets = noisy_targets[eval_slice]
scaled_logits = eval_logits / optimal_temp
scaled_probs = F.softmax(scaled_logits, dim=-1)

# Compare calibration on the evaluation split
calibration_before = compute_calibration(F.softmax(eval_logits, dim=-1), eval_targets)
calibration_after = compute_calibration(scaled_probs, eval_targets)

print(f"ECE before: {calibration_before['ece']:.4f}")
print(f"ECE after:  {calibration_after['ece']:.4f}")

In [None]:
# Visualize temperature scaling effect: reliability diagrams before / after
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

temperature_panels = [
    (axes[0], calibration_before, None,
     f'Before Scaling (ECE={calibration_before["ece"]:.4f})'),
    (axes[1], calibration_after, 'coral',
     f'After Scaling (ECE={calibration_after["ece"]:.4f}, T={optimal_temp:.2f})'),
]

for ax, calib, bar_color, title in temperature_panels:
    # Only non-empty bins are drawn: an empty bin has no accuracy, and a
    # zero-height bar would read as severe over-confidence.
    occupied = calib['bin_counts'] > 0
    ax.bar(calib['bin_centers'][occupied], calib['bin_accuracies'][occupied],
           width=0.8 / len(calib['bin_centers']), alpha=0.7, color=bar_color)
    ax.plot([0, 1], [0, 1], 'r--')
    ax.set_xlabel('Confidence')
    ax.set_ylabel('Accuracy')
    ax.set_title(title)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

plt.tight_layout()
plt.savefig(project_root / 'reports' / 'temperature_scaling.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Visualization

Heatmaps of confidence, entropy, mutual information, and variation ratio across all analyzed samples and sequence positions.

In [None]:
# Heatmaps of each uncertainty measure over (sample, position)
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

heatmap_panels = [
    (axes[0, 0], mean_probs.max(dim=-1)[0].cpu().numpy(), 'RdYlGn', 'Confidence Heatmap'),
    (axes[0, 1], entropy.numpy(), 'YlOrRd', 'Entropy Heatmap'),
    (axes[1, 0], mutual_info.numpy(), 'YlOrRd', 'Mutual Information Heatmap'),
    (axes[1, 1], variation_ratio.numpy(), 'YlOrRd', 'Variation Ratio Heatmap'),
]

for ax, matrix, cmap_name, panel_title in heatmap_panels:
    image = ax.imshow(matrix, aspect='auto', cmap=cmap_name)
    ax.set_xlabel('Position')
    ax.set_ylabel('Sample')
    ax.set_title(panel_title)
    plt.colorbar(image, ax=ax)

plt.tight_layout()
plt.savefig(project_root / 'reports' / 'uncertainty_heatmaps.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Save uncertainty report
report = {
    # Read the pass count from the stacked MC tensor (dim 0) instead of
    # hard-coding it, so the report cannot drift from the actual run.
    'mc_dropout_samples': int(all_probs.shape[0]),
    'optimal_temperature': float(optimal_temp),
    'calibration': {
        'ece_before': float(calibration_before['ece']),
        'ece_after': float(calibration_after['ece']),
    },
    'uncertainty_stats': {
        'entropy_mean': float(entropy.mean()),
        'entropy_std': float(entropy.std()),
        'mutual_info_mean': float(mutual_info.mean()),
        'mutual_info_std': float(mutual_info.std()),
        'variation_ratio_mean': float(variation_ratio.mean()),
        'variation_ratio_std': float(variation_ratio.std()),
    }
}

report_path = project_root / 'reports' / 'uncertainty_report.json'
report_path.parent.mkdir(parents=True, exist_ok=True)
with open(report_path, 'w', encoding='utf-8') as f:
    json.dump(report, f, indent=2)

print(f"Report saved to: {report_path}")

---

## Summary

This notebook covers uncertainty quantification:

1. **MC Dropout**: Keep dropout active at inference as an approximation to Bayesian inference
2. **Calibration**: Analyze confidence vs accuracy relationship
3. **Metrics**: Entropy, mutual information, variation ratio
4. **Selective Prediction**: Reject uncertain predictions
5. **Temperature Scaling**: Post-hoc calibration
6. **Visualization**: Heatmaps and curves

---

**Navigation:**
← [Previous: 16_architecture_comparison](16_architecture_comparison.ipynb) |
[Next: 18_transfer_learning](18_transfer_learning.ipynb) →