# VMAT Dose Prediction: Results Analysis

Post-training analysis notebook for evaluating model performance, comparing models, and generating publication-ready figures.

**Usage:** Run after training and inference are complete.

## Setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pathlib import Path
import json
from typing import Dict, List, Optional, Tuple
import warnings
# Silence every warning so notebook output stays readable.
# NOTE(review): this also hides numpy/matplotlib RuntimeWarnings (e.g. means of
# empty selections) that can flag real problems — consider narrowing it.
warnings.filterwarnings(action='ignore')

# Shared palette used by the plots in this notebook.
COLORS = dict(
    gt='#2E86AB',      # blue   - ground truth
    pred='#E94F37',    # red    - prediction
    diff='#7B2D8E',    # purple - difference
    # Fixed: comment previously said "green"; the hex value is already green.
    pass_='#4CAF50',   # green  - pass
    fail='#F44336',    # red    - fail
)
COLORS['pass'] = COLORS.pop('pass_')  # 'pass' is a keyword, so it cannot be a dict() kwarg

# Publication-quality matplotlib defaults: consistent font sizes, 150 dpi
# on-screen preview, and 300 dpi tightly-cropped output on savefig.
_PUBLICATION_RCPARAMS = {
    'font.size': 12,
    'font.family': 'sans-serif',
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
}
plt.rcParams.update(_PUBLICATION_RCPARAMS)

print("Setup complete!")

## 1. Configuration

**Edit these paths to match your setup:**

In [None]:
# ============================================================================
# EDIT THESE PATHS
# ============================================================================

# Directories
TEST_DATA_DIR = Path("./test_npz")           # Ground truth test cases
DDPM_PRED_DIR = Path("./predictions/ddpm")   # Diffusion predictions
BASELINE_PRED_DIR = Path("./predictions/baseline")  # Baseline predictions

# Results files
DDPM_RESULTS = DDPM_PRED_DIR / "evaluation_results.json"
BASELINE_RESULTS = BASELINE_PRED_DIR / "baseline_evaluation_results.json"

# Output directory for figures.
# parents=True so a nested location (e.g. ./results/figures) is created too.
FIGURES_DIR = Path("./figures")
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

# Prescription dose in Gy. Every dose array loaded below is multiplied by this
# value before plotting (presumably the stored doses are normalized to the
# prescription — confirm against the preprocessing script).
RX_DOSE_GY = 70.0

# ============================================================================

# Verify paths exist. The first group is required; the rest are optional
# (baseline comparisons and result summaries are skipped when missing).
print("Checking paths...")
for p in [TEST_DATA_DIR, DDPM_PRED_DIR]:
    status = "✓" if p.exists() else "✗ NOT FOUND"
    print(f"  {p}: {status}")

for p in [BASELINE_PRED_DIR, DDPM_RESULTS, BASELINE_RESULTS]:
    status = "✓" if p.exists() else "⚠ Not found (optional)"
    print(f"  {p}: {status}")

## 2. Load Results

In [None]:
def load_results(results_path: Path) -> Optional[Dict]:
    """Load an evaluation-results JSON file.

    Args:
        results_path: Path to the JSON file written by the evaluation script.

    Returns:
        The parsed JSON object, or ``None`` (after printing a warning) when the
        file does not exist.
    """
    if not results_path.exists():
        print(f"Warning: {results_path} not found")
        return None
    # JSON is UTF-8 by specification; don't depend on the platform default
    # encoding (e.g. cp1252 on Windows), which would garble non-ASCII text.
    with open(results_path, encoding='utf-8') as f:
        return json.load(f)

# Load both result files (either may be None if its JSON is missing).
ddpm_results = load_results(DDPM_RESULTS)
baseline_results = load_results(BASELINE_RESULTS)

# Headline summary for each result set that was found: case count, MAE, and
# gamma pass rate when the evaluation computed it.
for _label, _res in (("DDPM", ddpm_results), ("Baseline", baseline_results)):
    if not _res:
        continue
    _agg = _res['aggregate_metrics']
    print(f"\n{_label} Results: {_res['n_cases']} cases")
    print(f"  MAE: {_agg['mae_gy_mean']:.2f} ± {_agg['mae_gy_std']:.2f} Gy")
    if 'gamma_pass_rate_mean' in _agg:
        print(f"  Gamma: {_agg['gamma_pass_rate_mean']:.1f} ± {_agg['gamma_pass_rate_std']:.1f}%")

## 3. Model Comparison Table

In [None]:
def create_comparison_table(ddpm: Dict, baseline: Dict) -> None:
    """Print a side-by-side DDPM vs. baseline metrics table.

    Rows:
      * MAE (Gy), lower is better.
      * Gamma 3%/3mm pass rate (%), higher is better. Shown only when the DDPM
        results contain it.
      * Cases passing all clinical constraints. Shown only when the DDPM
        results contain it.

    A metric present for DDPM but missing from the baseline is shown as
    ``N/A`` instead of raising ``KeyError``. The trailing marker column reads
    ``←`` (DDPM better), ``→`` (baseline better) or ``=`` (tie).

    Args:
        ddpm: DDPM evaluation results (``n_cases``, ``aggregate_metrics`` and
            optionally ``clinical_constraints``).
        baseline: Baseline evaluation results with the same schema.
    """

    def _row(label: str, ddpm_text: str, base_text: str, winner: Optional[str] = None) -> None:
        # Fixed-width row; the winner marker is appended only when comparable.
        line = f"{label:<30} {ddpm_text:<18} {base_text:<18}"
        if winner:
            line += f" {winner}"
        print(line)

    def _winner(ddpm_val: float, base_val: float, higher_is_better: bool) -> str:
        if ddpm_val == base_val:
            return "="
        ddpm_better = ddpm_val > base_val if higher_is_better else ddpm_val < base_val
        return "←" if ddpm_better else "→"

    d_agg = ddpm['aggregate_metrics']
    b_agg = baseline['aggregate_metrics']

    print("\n" + "="*70)
    print("MODEL COMPARISON")
    print("="*70)
    print(f"{'Metric':<30} {'DDPM':<18} {'Baseline':<18}")
    print("-"*70)

    # MAE (lower is better)
    _row(
        'MAE (Gy)',
        f"{d_agg['mae_gy_mean']:.2f} ± {d_agg['mae_gy_std']:.2f}",
        f"{b_agg['mae_gy_mean']:.2f} ± {b_agg['mae_gy_std']:.2f}",
        _winner(d_agg['mae_gy_mean'], b_agg['mae_gy_mean'], higher_is_better=False),
    )

    # Gamma (higher is better); the baseline may not have computed it.
    if 'gamma_pass_rate_mean' in d_agg:
        ddpm_gamma = f"{d_agg['gamma_pass_rate_mean']:.1f} ± {d_agg['gamma_pass_rate_std']:.1f}"
        if 'gamma_pass_rate_mean' in b_agg:
            base_gamma = f"{b_agg['gamma_pass_rate_mean']:.1f} ± {b_agg['gamma_pass_rate_std']:.1f}"
            winner = _winner(d_agg['gamma_pass_rate_mean'], b_agg['gamma_pass_rate_mean'],
                             higher_is_better=True)
        else:
            base_gamma, winner = "N/A", None
        _row('Gamma 3%/3mm (%)', ddpm_gamma, base_gamma, winner)

    # Clinical constraints; the baseline may not have evaluated them.
    if 'clinical_constraints' in ddpm:
        ddpm_pass = f"{ddpm['clinical_constraints']['cases_all_passed']}/{ddpm['n_cases']}"
        if 'clinical_constraints' in baseline:
            base_pass = f"{baseline['clinical_constraints']['cases_all_passed']}/{baseline['n_cases']}"
        else:
            base_pass = "N/A"
        _row('Cases passing all constraints', ddpm_pass, base_pass)

    print("-"*70)
    print("← = DDPM better, → = Baseline better, '=' = tie")
    print("="*70)

# The comparison table needs both result sets; otherwise explain why it is skipped.
if not (ddpm_results and baseline_results):
    print("Need both DDPM and baseline results for comparison")
else:
    create_comparison_table(ddpm=ddpm_results, baseline=baseline_results)

## 4. Per-Case Metrics Visualization

In [None]:
def plot_per_case_metrics(ddpm: Dict, baseline: Dict = None) -> plt.Figure:
    """Bar charts comparing per-case MAE and gamma pass rate.

    The x-axis is the DDPM case list in its stored order. Baseline values are
    matched to those cases by ``case_id``, not by list position. A case with no
    value for a model (missing from the baseline, or without a gamma result) is
    plotted as a gap (NaN) rather than shifting every later bar onto the wrong
    case. A gamma pass rate of 0% is plotted, not dropped.

    Args:
        ddpm: DDPM evaluation results. Reads ``per_case_results[*]['case_id']``,
            ``['dose_metrics']['mae_gy']`` and optionally
            ``['gamma']['gamma_pass_rate']``.
        baseline: Optional baseline results with the same schema.

    Returns:
        Figure with the MAE panel (left) and the gamma panel (right).
    """

    def _mae(record: Optional[Dict]) -> float:
        # NaN marks "no result for this case" and renders as an empty slot.
        return np.nan if record is None else record['dose_metrics']['mae_gy']

    def _gamma(record: Optional[Dict]) -> float:
        if record is None:
            return np.nan
        rate = (record.get('gamma') or {}).get('gamma_pass_rate')
        return np.nan if rate is None else rate

    ddpm_records = ddpm['per_case_results']
    cases = [r['case_id'] for r in ddpm_records]
    ddpm_mae = np.array([_mae(r) for r in ddpm_records], dtype=float)
    ddpm_gamma = np.array([_gamma(r) for r in ddpm_records], dtype=float)

    if baseline:
        # Align baseline results to the DDPM case order by case_id.
        base_by_case = {r['case_id']: r for r in baseline['per_case_results']}
        base_records = [base_by_case.get(c) for c in cases]
        base_mae = np.array([_mae(r) for r in base_records], dtype=float)
        base_gamma = np.array([_gamma(r) for r in base_records], dtype=float)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    x = np.arange(len(cases))
    width = 0.35
    mae_target = 2.0

    # MAE plot
    ax = axes[0]
    ax.bar(x - width/2, ddpm_mae, width, label='DDPM', color=COLORS['gt'])
    if baseline:
        ax.bar(x + width/2, base_mae, width, label='Baseline', color=COLORS['pred'])
    ax.axhline(y=mae_target, color='green', linestyle='--', label='Target (< 2 Gy)')
    ax.set_xlabel('Case')
    ax.set_ylabel('MAE (Gy)')
    ax.set_title('Mean Absolute Error by Case')
    ax.set_xticks(x)
    ax.set_xticklabels(cases, rotation=45, ha='right')
    ax.legend()
    all_mae = np.concatenate([ddpm_mae, base_mae]) if baseline else ddpm_mae
    finite_mae = all_mae[np.isfinite(all_mae)]
    # Include the target in the limit so its reference line is never clipped
    # when every case beats it.
    top = max(finite_mae.max(), mae_target) if finite_mae.size else mae_target
    ax.set_ylim(0, top * 1.2)

    # Gamma plot (only when at least one DDPM case has a gamma result)
    ax = axes[1]
    if np.isfinite(ddpm_gamma).any():
        ax.bar(x - width/2, ddpm_gamma, width, label='DDPM', color=COLORS['gt'])
        if baseline and np.isfinite(base_gamma).any():
            ax.bar(x + width/2, base_gamma, width, label='Baseline', color=COLORS['pred'])
        ax.axhline(y=95.0, color='green', linestyle='--', label='Target (> 95%)')
        ax.set_xlabel('Case')
        ax.set_ylabel('Gamma Pass Rate (%)')
        ax.set_title('Gamma (3%/3mm) by Case')
        ax.set_xticks(x)
        ax.set_xticklabels(cases, rotation=45, ha='right')
        ax.legend()
        ax.set_ylim(0, 105)

    plt.tight_layout()
    return fig

if ddpm_results:
    # Per-case MAE / gamma bar charts (baseline bars added when available).
    per_case_png = FIGURES_DIR / 'per_case_metrics.png'
    fig = plot_per_case_metrics(ddpm_results, baseline=baseline_results)
    fig.savefig(per_case_png)
    plt.show()
    print(f"Saved: {per_case_png}")

## 5. Load Case Data for Visualization

In [None]:
def load_case_data(case_id: str) -> Dict:
    """Load ground truth, predictions, and masks for a case.

    Reads ``TEST_DATA_DIR/<case_id>.npz`` (keys 'ct', 'dose', 'masks' and
    optionally 'masks_sdf'). It also reads any available
    ``<DDPM_PRED_DIR|BASELINE_PRED_DIR>/<case_id>_pred.npz`` (key 'dose').

    Returns:
        Dict with 'case_id', 'ct', 'gt_dose', 'masks', 'masks_sdf' (or None),
        plus 'ddpm_dose' / 'baseline_dose' only for the prediction files that
        exist.

    Raises:
        FileNotFoundError: if the ground-truth file is missing.
    """
    # Find ground truth
    gt_path = TEST_DATA_DIR / f"{case_id}.npz"
    if not gt_path.exists():
        raise FileNotFoundError(f"Ground truth not found: {gt_path}")

    # Use NpzFile as a context manager so the underlying zip handle is closed.
    # Indexing an NpzFile reads the whole array into memory, so the arrays stay
    # valid after the file is closed.
    # NOTE(review): allow_pickle=True is kept from the original — presumably some
    # entries are object arrays. Only load trusted files with it.
    with np.load(gt_path, allow_pickle=True) as gt_data:
        result = {
            'case_id': case_id,
            'ct': gt_data['ct'],
            'gt_dose': gt_data['dose'],
            'masks': gt_data['masks'],
            'masks_sdf': gt_data['masks_sdf'] if 'masks_sdf' in gt_data.files else None,
        }

    # Load whichever model predictions exist for this case.
    for key, pred_dir in (('ddpm_dose', DDPM_PRED_DIR), ('baseline_dose', BASELINE_PRED_DIR)):
        pred_path = pred_dir / f"{case_id}_pred.npz"
        if pred_path.exists():
            with np.load(pred_path) as pred_data:
                result[key] = pred_data['dose']

    return result

# Get list of test cases. Sorted because Path.glob order depends on the
# filesystem; sorting makes "the first case" below reproducible.
test_cases = sorted(f.stem for f in TEST_DATA_DIR.glob("*.npz")) if TEST_DATA_DIR.exists() else []
print(f"Found {len(test_cases)} test cases: {test_cases}")

# Load first case for visualization
if test_cases:
    case_data = load_case_data(test_cases[0])
    print(f"\nLoaded case: {case_data['case_id']}")
    print(f"  CT shape: {case_data['ct'].shape}")
    print(f"  GT dose shape: {case_data['gt_dose'].shape}")
    if 'ddpm_dose' in case_data:
        print("  DDPM prediction: loaded")
    if 'baseline_dose' in case_data:
        print("  Baseline prediction: loaded")
else:
    print("No test cases found. Check TEST_DATA_DIR path.")
    case_data = None

## 6. Dose Distribution Comparison

In [None]:
def plot_dose_comparison(
    case_data: Dict,
    slice_idx: Optional[int] = None,
    view: str = 'axial',  # 'axial', 'sagittal', 'coronal'
    model: str = 'ddpm',  # 'ddpm' or 'baseline'
) -> plt.Figure:
    """
    Plot ground truth, prediction, and difference maps for one 2D slice.

    Args:
        case_data: Dict from ``load_case_data``. Reads 'case_id', 'ct',
            'gt_dose' and '<model>_dose'. Doses are scaled by RX_DOSE_GY.
        slice_idx: Index along the view axis. ``None`` selects the middle
            slice; ``0`` is honoured as the first slice.
        view: 'axial' (array axis 2), 'sagittal' (axis 0) or 'coronal' (axis 1).
        model: Which prediction to plot ('ddpm' or 'baseline').

    Returns:
        2x3 figure: CT+dose overlays on top, dose-only and |error| maps below.

    Raises:
        ValueError: if the prediction is missing or ``view`` is not recognised.
    """
    gt_dose = case_data['gt_dose'] * RX_DOSE_GY
    pred_key = f'{model}_dose'

    if pred_key not in case_data:
        raise ValueError(f"No {model} prediction available")

    view_axis = {'axial': 2, 'sagittal': 0, 'coronal': 1}
    if view not in view_axis:
        raise ValueError(f"Unknown view {view!r}; expected 'axial', 'sagittal' or 'coronal'")
    axis = view_axis[view]

    pred_dose = case_data[pred_key] * RX_DOSE_GY
    diff = pred_dose - gt_dose

    # `slice_idx or default` would silently replace a requested 0 with the middle slice.
    idx = gt_dose.shape[axis] // 2 if slice_idx is None else slice_idx

    def _take(volume: np.ndarray) -> np.ndarray:
        """Return slice ``idx`` of ``volume`` along ``axis``."""
        index = [slice(None)] * volume.ndim
        index[axis] = idx
        return volume[tuple(index)]

    gt_slice = _take(gt_dose)
    pred_slice = _take(pred_dose)
    diff_slice = _take(diff)
    ct_slice = _take(case_data['ct'])

    # Create figure
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Row 1: CT + Dose overlays (shared color scale for GT and prediction)
    dose_vmax = max(gt_slice.max(), pred_slice.max())

    # GT
    axes[0, 0].imshow(ct_slice, cmap='gray', aspect='auto')
    im = axes[0, 0].imshow(gt_slice, cmap='jet', alpha=0.6, vmin=0, vmax=dose_vmax, aspect='auto')
    axes[0, 0].set_title('Ground Truth Dose')
    axes[0, 0].axis('off')
    plt.colorbar(im, ax=axes[0, 0], label='Dose (Gy)', shrink=0.8)

    # Prediction
    axes[0, 1].imshow(ct_slice, cmap='gray', aspect='auto')
    im = axes[0, 1].imshow(pred_slice, cmap='jet', alpha=0.6, vmin=0, vmax=dose_vmax, aspect='auto')
    axes[0, 1].set_title(f'Predicted Dose ({model.upper()})')
    axes[0, 1].axis('off')
    plt.colorbar(im, ax=axes[0, 1], label='Dose (Gy)', shrink=0.8)

    # Difference (symmetric color range, at least ±5 Gy)
    diff_max = max(abs(diff_slice.min()), abs(diff_slice.max()), 5)
    axes[0, 2].imshow(ct_slice, cmap='gray', aspect='auto')
    im = axes[0, 2].imshow(diff_slice, cmap='RdBu_r', alpha=0.7, vmin=-diff_max, vmax=diff_max, aspect='auto')
    axes[0, 2].set_title('Difference (Pred - GT)')
    axes[0, 2].axis('off')
    plt.colorbar(im, ax=axes[0, 2], label='Δ Dose (Gy)', shrink=0.8)

    # Row 2: Dose only (no CT background)
    axes[1, 0].imshow(gt_slice, cmap='jet', vmin=0, vmax=dose_vmax, aspect='auto')
    axes[1, 0].set_title('Ground Truth')
    axes[1, 0].axis('off')

    axes[1, 1].imshow(pred_slice, cmap='jet', vmin=0, vmax=dose_vmax, aspect='auto')
    axes[1, 1].set_title('Prediction')
    axes[1, 1].axis('off')

    # Absolute error (fixed 0-5 Gy scale)
    abs_diff = np.abs(diff_slice)
    im = axes[1, 2].imshow(abs_diff, cmap='hot', vmin=0, vmax=5, aspect='auto')
    axes[1, 2].set_title('Absolute Error')
    axes[1, 2].axis('off')
    plt.colorbar(im, ax=axes[1, 2], label='|Error| (Gy)', shrink=0.8)

    fig.suptitle(f"Case: {case_data['case_id']} | View: {view} | Slice: {idx}", fontsize=14)
    plt.tight_layout()

    return fig

# Plot for first test case.
# The output path is built once: the previous print nested quotes inside the
# f-string, which is a SyntaxError before Python 3.12 and a KeyError after.
if case_data and 'ddpm_dose' in case_data:
    fig = plot_dose_comparison(case_data, view='axial', model='ddpm')
    out_path = FIGURES_DIR / f"{case_data['case_id']}_dose_comparison.png"
    fig.savefig(out_path)
    plt.show()
    print(f"Saved: {out_path}")

## 7. DVH Analysis

In [None]:
# Mask channel index -> structure name, in channel order.
# NOTE(review): assumes the 'masks' array in each test .npz stacks structures
# in exactly this order along its first axis — confirm against preprocessing.
STRUCTURE_NAMES = dict(enumerate([
    'PTV70',
    'PTV56',
    'Prostate',
    'Rectum',
    'Bladder',
    'Femur_L',
    'Femur_R',
    'Bowel',
]))

# Plot color per structure name (DVH curves).
STRUCTURE_COLORS = dict(
    PTV70='#E41A1C',
    PTV56='#FF7F00',
    Prostate='#984EA3',
    Rectum='#4DAF4A',
    Bladder='#377EB8',
    Femur_L='#A65628',
    Femur_R='#F781BF',
    Bowel='#999999',
)

def compute_dvh(dose: np.ndarray, mask: np.ndarray, bins: int = 200) -> Tuple[np.ndarray, np.ndarray]:
    """Compute a cumulative dose-volume histogram (DVH) for one structure.

    Args:
        dose: Dose array (same shape as ``mask``).
        mask: Structure mask; voxels with ``mask > 0`` belong to the structure.
        bins: Number of histogram bins spanning 0 .. 1.05 x max dose.

    Returns:
        ``(dose_axis, volume_pct)``. ``dose_axis`` holds the lower bin edges.
        ``volume_pct[i]`` is the percentage of the structure's voxels whose
        dose is at least ``dose_axis[i]``. Both arrays are empty when the mask
        selects no voxels.
    """
    dose_in_struct = dose[mask > 0]
    if dose_in_struct.size == 0:
        return np.array([]), np.array([])

    max_dose = float(dose_in_struct.max())
    # A non-positive maximum (e.g. an all-zero or all-negative prediction)
    # would give np.histogram an empty or inverted range.
    upper = max_dose * 1.05 if max_dose > 0 else 1.0
    hist, edges = np.histogram(dose_in_struct, bins=bins, range=(0, upper))
    cumulative = np.cumsum(hist[::-1])[::-1]
    # Normalize by the full structure volume rather than cumulative[0]. Voxels
    # with negative (predicted) dose fall outside the histogram range but are
    # still part of the structure, and an all-negative structure must not
    # divide by zero.
    volume_pct = cumulative / dose_in_struct.size * 100

    return edges[:-1], volume_pct

def plot_dvh_comparison(
    case_data: Dict,
    structures: List[str] = None,
    model: str = 'ddpm',
) -> plt.Figure:
    """Plot cumulative DVHs for ground truth (solid) vs. prediction (dashed).

    Args:
        case_data: Dict from ``load_case_data``. Reads 'case_id', 'gt_dose',
            'masks' and '<model>_dose'. Doses are scaled by RX_DOSE_GY.
        structures: Structure names from STRUCTURE_NAMES to plot. Defaults to
            PTV70, Rectum and Bladder. Unknown names, channels beyond
            ``masks.shape[0]`` and empty masks are skipped.
        model: Which prediction to compare ('ddpm' or 'baseline').

    Returns:
        The DVH figure.

    Raises:
        ValueError: if ``case_data`` has no '<model>_dose' prediction
            (consistent with ``plot_dose_comparison``).
    """
    pred_key = f'{model}_dose'
    if pred_key not in case_data:
        raise ValueError(f"No {model} prediction available")

    gt_dose = case_data['gt_dose'] * RX_DOSE_GY
    pred_dose = case_data[pred_key] * RX_DOSE_GY
    masks = case_data['masks']

    if structures is None:
        structures = ['PTV70', 'Rectum', 'Bladder']

    # Reverse lookup: structure name -> mask channel index.
    name_to_index = {name: idx for idx, name in STRUCTURE_NAMES.items()}

    fig, ax = plt.subplots(figsize=(10, 7))

    for struct_name in structures:
        struct_idx = name_to_index.get(struct_name)
        if struct_idx is None or struct_idx >= masks.shape[0]:
            continue

        mask = masks[struct_idx]
        if mask.sum() == 0:
            continue

        color = STRUCTURE_COLORS.get(struct_name, 'gray')

        # Ground truth DVH
        gt_x, gt_y = compute_dvh(gt_dose, mask)
        if len(gt_x) > 0:
            ax.plot(gt_x, gt_y, color=color, linestyle='-', linewidth=2,
                    label=f'{struct_name} (GT)')

        # Predicted DVH
        pred_x, pred_y = compute_dvh(pred_dose, mask)
        if len(pred_x) > 0:
            ax.plot(pred_x, pred_y, color=color, linestyle='--', linewidth=2,
                    label=f'{struct_name} (Pred)')

    ax.set_xlabel('Dose (Gy)', fontsize=12)
    ax.set_ylabel('Volume (%)', fontsize=12)
    ax.set_title(f'DVH Comparison: {case_data["case_id"]}', fontsize=14)
    # 80 Gy for the default 70 Gy prescription; widen if the prescription is higher.
    ax.set_xlim(0, max(80.0, RX_DOSE_GY * 1.1))
    ax.set_ylim(0, 105)
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper right', ncol=2)

    # Reference lines: the prescription dose and the 56 Gy (PTV56) level.
    ax.axvline(x=RX_DOSE_GY, color='red', linestyle=':', alpha=0.5)
    ax.axvline(x=56, color='orange', linestyle=':', alpha=0.5)

    plt.tight_layout()
    return fig

# Build the output path once: the previous print nested quotes inside the
# f-string, which is a SyntaxError before Python 3.12 and a KeyError after.
if case_data and 'ddpm_dose' in case_data:
    fig = plot_dvh_comparison(case_data, structures=['PTV70', 'Rectum', 'Bladder'], model='ddpm')
    out_path = FIGURES_DIR / f"{case_data['case_id']}_dvh.png"
    fig.savefig(out_path)
    plt.show()
    print(f"Saved: {out_path}")

## 8. DVH Metrics Table

In [None]:
def compute_dvh_metrics(dose: np.ndarray, mask: np.ndarray) -> Dict:
    """Summarize the dose inside one structure with standard DVH metrics.

    Voxels with ``mask > 0`` form the structure. ``Dxx`` is the minimum dose
    received by the hottest xx% of the volume (the (100 - xx)th percentile).
    ``Vyy`` is the percentage of the volume receiving at least yy (in the
    units of ``dose``). Returns an empty dict when the mask selects nothing.
    """
    values = dose[mask > 0]
    if len(values) == 0:
        return {}

    # Dxx -> percentile of the dose distribution that xx% of voxels exceed.
    percentile_for = {'D95': 5, 'D50': 50, 'D5': 95}
    metrics = {name: np.percentile(values, q) for name, q in percentile_for.items()}
    metrics['Dmax'] = values.max()
    metrics['Dmean'] = values.mean()
    for threshold in (70, 60, 50):
        metrics[f'V{threshold}'] = np.mean(values >= threshold) * 100
    return metrics

def print_dvh_metrics_table(case_data: Dict, model: str = 'ddpm') -> None:
    """Print a per-structure table of GT vs. predicted DVH metrics.

    Every structure in STRUCTURE_NAMES gets a section when its mask channel
    exists and is non-empty. Each section lists dose metrics (D95, D50, Dmean,
    Dmax, in Gy) and volume metrics (V70, V60, V50, in %) for ground truth,
    prediction, and prediction minus GT. Doses are scaled by RX_DOSE_GY
    first. A metric absent from ``compute_dvh_metrics`` output prints as 0.
    """
    gt_gy = case_data['gt_dose'] * RX_DOSE_GY
    pred_gy = case_data[f'{model}_dose'] * RX_DOSE_GY
    masks = case_data['masks']

    print(f"\nDVH Metrics: {case_data['case_id']} ({model.upper()})")
    print("=" * 80)

    dose_rows = ('D95', 'D50', 'Dmean', 'Dmax')
    volume_rows = ('V70', 'V60', 'V50')
    n_channels = masks.shape[0]

    for channel, name in STRUCTURE_NAMES.items():
        # Skip channels this case doesn't have, and empty structures.
        if channel >= n_channels or masks[channel].sum() == 0:
            continue
        struct_mask = masks[channel]

        gt_m = compute_dvh_metrics(gt_gy, struct_mask)
        pred_m = compute_dvh_metrics(pred_gy, struct_mask)

        print(f"\n{name}:")
        print(f"  {'Metric':<10} {'GT':>10} {'Pred':>10} {'Diff':>10}")
        print("  " + "-" * 42)

        for key in dose_rows:
            gt_v, pred_v = gt_m.get(key, 0), pred_m.get(key, 0)
            print(f"  {key:<10} {gt_v:>10.2f} {pred_v:>10.2f} {pred_v - gt_v:>+10.2f} Gy")

        for key in volume_rows:
            gt_v, pred_v = gt_m.get(key, 0), pred_m.get(key, 0)
            print(f"  {key:<10} {gt_v:>10.1f} {pred_v:>10.1f} {pred_v - gt_v:>+10.1f} %")

# Per-structure DVH metrics (GT vs. DDPM) for the case loaded in Section 5.
have_ddpm_prediction = bool(case_data) and 'ddpm_dose' in case_data
if have_ddpm_prediction:
    print_dvh_metrics_table(case_data=case_data, model='ddpm')

## 9. Clinical Constraints Analysis

In [None]:
def plot_clinical_constraints(results: Dict) -> plt.Figure:
    """Visualize clinical-constraint outcomes across test cases.

    Left panel: number of violations per case (green when zero, red
    otherwise), read from ``per_case_results[*]['clinical_constraints']['violations']``.
    Right panel: pie of cases passing vs. failing all constraints, from
    ``results['clinical_constraints']['cases_all_passed']`` and ``results['n_cases']``.

    Returns:
        The figure, or ``None`` (after printing a notice) when ``results`` has
        no 'clinical_constraints' entry.
    """
    if 'clinical_constraints' not in results:
        print("No clinical constraints data available")
        return None

    per_case = results['per_case_results']
    case_ids = [record['case_id'] for record in per_case]
    n_violations = [
        len(record.get('clinical_constraints', {}).get('violations', []))
        for record in per_case
    ]

    fig, (bar_ax, pie_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Violations per case
    positions = range(len(case_ids))
    bar_colors = [COLORS['fail'] if count else COLORS['pass'] for count in n_violations]
    bar_ax.bar(positions, n_violations, color=bar_colors)
    bar_ax.set_xticks(positions)
    bar_ax.set_xticklabels(case_ids, rotation=45, ha='right')
    bar_ax.set_xlabel('Case')
    bar_ax.set_ylabel('Number of Violations')
    bar_ax.set_title('Clinical Constraint Violations by Case')
    bar_ax.axhline(y=0, color='green', linestyle='--', alpha=0.5)

    # Overall pass/fail split
    summary = results['clinical_constraints']
    n_passed = summary['cases_all_passed']
    n_cases = results['n_cases']

    pie_ax.pie(
        [n_passed, n_cases - n_passed],
        labels=['Passed', 'Failed'],
        colors=[COLORS['pass'], COLORS['fail']],
        autopct='%1.0f%%',
        startangle=90,
        explode=(0.05, 0),
    )
    pie_ax.set_title(f'Cases Meeting All Constraints\n({n_passed}/{n_cases})')

    plt.tight_layout()
    return fig

if ddpm_results:
    # plot_clinical_constraints returns None when no constraint data exists.
    constraints_png = FIGURES_DIR / 'clinical_constraints.png'
    fig = plot_clinical_constraints(ddpm_results)
    if fig is not None:
        fig.savefig(constraints_png)
        plt.show()

In [None]:
def print_violation_summary(results: Dict) -> None:
    """Print a text summary of clinical-constraint violations.

    Shows the number of cases passing every constraint and the total
    violation count. It then lists each violation type by descending
    count; equal counts keep their original order. Prints a notice when
    ``results`` has no 'clinical_constraints' entry.
    """
    if 'clinical_constraints' not in results:
        print("No clinical constraints data available")
        return

    constraints = results['clinical_constraints']
    rule = "=" * 60

    print("\n" + rule)
    print("CLINICAL CONSTRAINTS SUMMARY")
    print(rule)
    print(f"Cases passing ALL constraints: {constraints['cases_all_passed']}/{results['n_cases']}")
    print(f"Total violations: {constraints['total_violations']}")

    counts = constraints['violation_counts']
    if not counts:
        print("\n✓ No violations!")
    else:
        print("\nViolation breakdown:")
        # reverse=True keeps sort stability, so ties stay in insertion order.
        for name, n in sorted(counts.items(), key=lambda item: item[1], reverse=True):
            print(f"  {name}: {n} case(s)")

    print(rule)

if ddpm_results:
    # Text counterpart of the constraint plot above.
    print_violation_summary(results=ddpm_results)

## 10. Error Analysis

In [None]:
def plot_error_histogram(
    case_data: Dict,
    model: str = 'ddpm',
    dose_threshold_gy: float = 5.0,
) -> plt.Figure:
    """Plot the distribution of voxel-wise dose errors (prediction - GT).

    Only voxels whose ground-truth dose exceeds ``dose_threshold_gy`` are
    used, so the large near-zero background does not dominate.

    Args:
        case_data: Dict from ``load_case_data``. Reads 'case_id', 'gt_dose'
            and '<model>_dose'. Both doses are scaled by RX_DOSE_GY.
        model: Which prediction to evaluate ('ddpm' or 'baseline').
        dose_threshold_gy: GT dose (Gy) a voxel must exceed to be included.
            Defaults to the previously hard-coded 5 Gy.

    Returns:
        Figure with the error histogram (left) and statistics (right). When no
        voxel exceeds the threshold, a notice is drawn instead, because the
        statistics are undefined for an empty set.

    Raises:
        ValueError: if ``case_data`` has no '<model>_dose' prediction.
    """
    pred_key = f'{model}_dose'
    if pred_key not in case_data:
        raise ValueError(f"No {model} prediction available")

    gt_dose = case_data['gt_dose'] * RX_DOSE_GY
    pred_dose = case_data[pred_key] * RX_DOSE_GY

    # Only consider voxels with significant dose
    mask = gt_dose > dose_threshold_gy
    errors = pred_dose[mask] - gt_dose[mask]
    thr = f"{dose_threshold_gy:g}"  # 5.0 -> "5", matching the original labels
    title = f'Dose Error Distribution (Dose > {thr} Gy)\nCase: {case_data["case_id"]}'

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    if errors.size == 0:
        # min/max/percentile raise on an empty array; report instead of crashing.
        axes[0].text(0.5, 0.5, f"No voxels with GT dose > {thr} Gy",
                     transform=axes[0].transAxes, ha='center', va='center')
        axes[0].set_title(title)
        axes[0].axis('off')
        axes[1].axis('off')
        axes[1].set_title('Error Statistics')
        plt.tight_layout()
        return fig

    mean_err = errors.mean()

    # Error histogram
    axes[0].hist(errors, bins=100, color=COLORS['diff'], alpha=0.7, edgecolor='black')
    axes[0].axvline(x=0, color='black', linestyle='-', linewidth=2)
    axes[0].axvline(x=mean_err, color='red', linestyle='--', label=f'Mean: {mean_err:.2f} Gy')
    axes[0].set_xlabel('Error (Gy)')
    axes[0].set_ylabel('Frequency')
    axes[0].set_title(title)
    axes[0].legend()

    # Statistics
    stats_text = f"""
    Statistics (Dose > {thr} Gy region):
    
    Mean Error: {mean_err:.3f} Gy
    Std Error:  {errors.std():.3f} Gy
    MAE:        {np.abs(errors).mean():.3f} Gy
    
    Percentiles:
      5th:  {np.percentile(errors, 5):.2f} Gy
      50th: {np.percentile(errors, 50):.2f} Gy
      95th: {np.percentile(errors, 95):.2f} Gy
    
    Max Underdose: {errors.min():.2f} Gy
    Max Overdose:  {errors.max():.2f} Gy
    """

    axes[1].text(0.1, 0.5, stats_text, transform=axes[1].transAxes,
                 fontsize=11, verticalalignment='center', fontfamily='monospace')
    axes[1].axis('off')
    axes[1].set_title('Error Statistics')

    plt.tight_layout()
    return fig

if case_data and 'ddpm_dose' in case_data:
    # Voxel-wise error distribution for the DDPM prediction of the loaded case.
    hist_png = FIGURES_DIR / f"{case_data['case_id']}_error_histogram.png"
    fig = plot_error_histogram(case_data, model='ddpm')
    fig.savefig(hist_png)
    plt.show()

## 11. Publication Figure: Combined Results

In [None]:
def create_publication_figure(case_data: Dict, model: str = 'ddpm') -> plt.Figure:
    """Create a publication-ready combined figure for one case.

    Layout (2x4 grid):
      * Row 1: the middle slice along array axis 2 of CT, ground-truth dose,
        predicted dose, and their difference (fixed ±5 Gy color range).
      * Row 2: DVHs for PTV70 / Rectum / Bladder, and an error-statistics
        panel computed over voxels with ground-truth dose > 5 Gy.

    Doses are scaled by RX_DOSE_GY.

    Raises:
        ValueError: if ``case_data`` has no '<model>_dose' prediction.
    """
    pred_key = f'{model}_dose'
    if pred_key not in case_data:
        raise ValueError(f"No {model} prediction available")

    gt_dose = case_data['gt_dose'] * RX_DOSE_GY
    pred_dose = case_data[pred_key] * RX_DOSE_GY
    diff = pred_dose - gt_dose
    masks = case_data['masks']
    ct = case_data['ct']

    # Get mid-slices
    z_mid = gt_dose.shape[2] // 2
    gt_mid = gt_dose[:, :, z_mid]
    pred_mid = pred_dose[:, :, z_mid]

    fig = plt.figure(figsize=(16, 10))
    gs = gridspec.GridSpec(2, 4, figure=fig, wspace=0.3, hspace=0.3)

    dose_vmax = max(gt_mid.max(), pred_mid.max())

    # Row 1: CT, GT Dose, Pred Dose, Difference
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.imshow(ct[:, :, z_mid], cmap='gray')
    ax1.set_title('CT')
    ax1.axis('off')

    ax2 = fig.add_subplot(gs[0, 1])
    im2 = ax2.imshow(gt_mid, cmap='jet', vmin=0, vmax=dose_vmax)
    ax2.set_title('Ground Truth')
    ax2.axis('off')
    plt.colorbar(im2, ax=ax2, label='Gy', shrink=0.8)

    ax3 = fig.add_subplot(gs[0, 2])
    im3 = ax3.imshow(pred_mid, cmap='jet', vmin=0, vmax=dose_vmax)
    ax3.set_title('Prediction')
    ax3.axis('off')
    plt.colorbar(im3, ax=ax3, label='Gy', shrink=0.8)

    ax4 = fig.add_subplot(gs[0, 3])
    im4 = ax4.imshow(diff[:, :, z_mid], cmap='RdBu_r', vmin=-5, vmax=5)
    ax4.set_title('Difference')
    ax4.axis('off')
    plt.colorbar(im4, ax=ax4, label='Gy', shrink=0.8)

    # Row 2: DVH and Statistics
    ax5 = fig.add_subplot(gs[1, :2])

    # Reverse lookup: structure name -> mask channel index.
    name_to_index = {name: idx for idx, name in STRUCTURE_NAMES.items()}
    for struct_name in ['PTV70', 'Rectum', 'Bladder']:
        struct_idx = name_to_index.get(struct_name)
        if struct_idx is None or struct_idx >= masks.shape[0]:
            continue

        mask = masks[struct_idx]
        if mask.sum() == 0:
            continue

        color = STRUCTURE_COLORS.get(struct_name, 'gray')
        gt_x, gt_y = compute_dvh(gt_dose, mask)
        pred_x, pred_y = compute_dvh(pred_dose, mask)

        if len(gt_x) > 0:
            ax5.plot(gt_x, gt_y, color=color, linestyle='-', linewidth=2, label=f'{struct_name} (GT)')
        if len(pred_x) > 0:
            ax5.plot(pred_x, pred_y, color=color, linestyle='--', linewidth=2, label=f'{struct_name} (Pred)')

    ax5.set_xlabel('Dose (Gy)')
    ax5.set_ylabel('Volume (%)')
    ax5.set_title('DVH Comparison')
    ax5.set_xlim(0, 80)
    ax5.set_ylim(0, 105)
    ax5.grid(True, alpha=0.3)
    ax5.legend(loc='upper right', fontsize=9)

    # Statistics panel (voxels with GT dose > 5 Gy only)
    ax6 = fig.add_subplot(gs[1, 2:])

    mask_high = gt_dose > 5
    errors = pred_dose[mask_high] - gt_dose[mask_high]

    if errors.size == 0:
        # min/max/percentile raise on an empty array; report instead of crashing.
        stats_text = f"""
    Case: {case_data['case_id']}
    Model: {model.upper()}
    
    No voxels with GT dose > 5 Gy;
    error statistics unavailable.
    """
    else:
        stats_text = f"""
    Case: {case_data['case_id']}
    Model: {model.upper()}
    
    Overall Metrics (Dose > 5 Gy):
    ---------------------
    MAE:  {np.abs(errors).mean():.2f} Gy
    RMSE: {np.sqrt((errors**2).mean()):.2f} Gy
    
    Error Range:
    ---------------------
    Min:  {errors.min():.2f} Gy
    Max:  {errors.max():.2f} Gy
    
    Error Percentiles:
    ---------------------
    5th:  {np.percentile(errors, 5):.2f} Gy
    95th: {np.percentile(errors, 95):.2f} Gy
    """

    ax6.text(0.1, 0.5, stats_text, transform=ax6.transAxes,
             fontsize=11, verticalalignment='center', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    ax6.axis('off')
    ax6.set_title('Statistics')

    fig.suptitle('VMAT Dose Prediction Results', fontsize=16, fontweight='bold')

    return fig

if case_data and 'ddpm_dose' in case_data:
    # Combined figure for the loaded case, written as both raster and vector output.
    fig = create_publication_figure(case_data, model='ddpm')
    pub_stem = f"{case_data['case_id']}_publication"
    fig.savefig(FIGURES_DIR / f"{pub_stem}.png", dpi=300)
    fig.savefig(FIGURES_DIR / f"{pub_stem}.pdf")
    plt.show()
    print("Saved publication figure (PNG and PDF)")

## 12. Generate All Figures for All Cases

In [None]:
def generate_all_figures(test_cases: List[str], model: str = 'ddpm') -> None:
    """Generate the standard figure set for every test case.

    For each case, the dose-comparison, DVH and publication figures are written
    to FIGURES_DIR as ``<case>_dose.png``, ``<case>_dvh.png`` and
    ``<case>_publication.png``. This is a best-effort batch: a failing case is
    reported and skipped. Any figures its plotting calls opened are closed, so
    they do not accumulate across the loop.

    Args:
        test_cases: Case IDs to process (as accepted by ``load_case_data``).
        model: Which prediction to plot ('ddpm' or 'baseline').
    """
    print(f"Generating figures for {len(test_cases)} cases...")

    n_succeeded = 0
    for case_id in test_cases:
        print(f"\n  Processing {case_id}...")
        # Snapshot open figures so we can close only the ones this case leaks.
        open_before = set(plt.get_fignums())

        try:
            case_data = load_case_data(case_id)

            if f'{model}_dose' not in case_data:
                print(f"    ⚠ No {model} prediction found, skipping")
                continue

            figure_jobs = (
                (plot_dose_comparison, f"{case_id}_dose.png", {}),
                (plot_dvh_comparison, f"{case_id}_dvh.png", {}),
                (create_publication_figure, f"{case_id}_publication.png", {'dpi': 300}),
            )
            for make_figure, filename, save_kwargs in figure_jobs:
                fig = make_figure(case_data, model=model)
                try:
                    fig.savefig(FIGURES_DIR / filename, **save_kwargs)
                finally:
                    plt.close(fig)

            print(f"    ✓ Generated {len(figure_jobs)} figures")
            n_succeeded += 1

        except Exception as e:  # deliberate best-effort: report and continue
            # A plot call that raised mid-way may have created a figure it
            # never returned; close it so figures don't pile up in memory.
            for num in set(plt.get_fignums()) - open_before:
                plt.close(num)
            print(f"    ✗ Error: {type(e).__name__}: {e}")

    print(f"\n{n_succeeded}/{len(test_cases)} cases completed")
    print(f"All figures saved to: {FIGURES_DIR}")

# Uncomment to generate all figures:
# generate_all_figures(test_cases, model='ddpm')

## 13. Export Summary Report

In [None]:
def generate_summary_report(
    ddpm_results: Dict,
    baseline_results: Dict = None,
    output_path: Path = None,
) -> str:
    """Generate a markdown summary report and write it to disk.

    The report has:
      * a model-performance table (MAE, and gamma when DDPM has it), with a
        baseline column when ``baseline_results`` is given;
      * goal assessment from ``ddpm_results['goal_assessment']`` (lines are
        skipped for missing keys);
      * a clinical-constraints section when DDPM results contain one.

    A metric missing from the baseline is written as ``N/A``, so every table
    row keeps as many cells as the header.

    Args:
        ddpm_results: DDPM evaluation results.
        baseline_results: Optional baseline results with the same schema.
        output_path: Destination file. Defaults to
            ``FIGURES_DIR / 'summary_report.md'``.

    Returns:
        The report text (also written to ``output_path`` as UTF-8).
    """
    if output_path is None:
        output_path = FIGURES_DIR / 'summary_report.md'

    ddpm_agg = ddpm_results['aggregate_metrics']
    base_agg = baseline_results['aggregate_metrics'] if baseline_results else None

    lines = [
        "# VMAT Dose Prediction: Results Summary",
        "",
        f"Generated: {np.datetime64('now')}",
        "",
        "## Model Performance",
        "",
        "| Metric | DDPM |" + (" Baseline |" if baseline_results else ""),
        "|--------|------|" + ("----------|" if baseline_results else ""),
    ]

    # MAE
    ddpm_mae = f"{ddpm_agg['mae_gy_mean']:.2f} ± {ddpm_agg['mae_gy_std']:.2f}"
    line = f"| MAE (Gy) | {ddpm_mae} |"
    if base_agg is not None:
        base_mae = f"{base_agg['mae_gy_mean']:.2f} ± {base_agg['mae_gy_std']:.2f}"
        line += f" {base_mae} |"
    lines.append(line)

    # Gamma. Always emit the baseline cell when the table has a baseline
    # column, otherwise the markdown row is one cell short.
    if 'gamma_pass_rate_mean' in ddpm_agg:
        ddpm_gamma = f"{ddpm_agg['gamma_pass_rate_mean']:.1f} ± {ddpm_agg['gamma_pass_rate_std']:.1f}"
        line = f"| Gamma 3%/3mm (%) | {ddpm_gamma} |"
        if base_agg is not None:
            if 'gamma_pass_rate_mean' in base_agg:
                base_gamma = f"{base_agg['gamma_pass_rate_mean']:.1f} ± {base_agg['gamma_pass_rate_std']:.1f}"
                line += f" {base_gamma} |"
            else:
                line += " N/A |"
        lines.append(line)

    goals = ddpm_results.get('goal_assessment', {})
    lines.extend([
        "",
        "## Goal Assessment",
        "",
    ])
    if 'mae_goal_met' in goals:
        lines.append(f"- MAE < 2.0 Gy: {'✓ MET' if goals['mae_goal_met'] else '✗ NOT MET'}")

    if goals.get('gamma_goal_met') is not None:
        lines.append(f"- Gamma > 95%: {'✓ MET' if goals['gamma_goal_met'] else '✗ NOT MET'}")

    if 'clinical_constraints' in ddpm_results:
        cc = ddpm_results['clinical_constraints']
        lines.extend([
            "",
            "## Clinical Constraints",
            "",
            f"- Cases passing all constraints: {cc['cases_all_passed']}/{ddpm_results['n_cases']}",
            f"- Total violations: {cc['total_violations']}",
        ])

        if cc['violation_counts']:
            lines.append("")
            lines.append("### Most Common Violations")
            lines.append("")
            for violation, count in sorted(cc['violation_counts'].items(), key=lambda x: -x[1])[:5]:
                lines.append(f"- {violation}: {count} case(s)")

    report = "\n".join(lines)

    # Explicit UTF-8: the report contains '±', '✓' and '✗', which the platform
    # default encoding (e.g. cp1252 on Windows) cannot encode.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(report)

    print(f"Summary report saved to: {output_path}")
    return report

if ddpm_results:
    # Write the markdown report (baseline column included when available) and echo it.
    report = generate_summary_report(ddpm_results, baseline_results=baseline_results)
    print(f"\n{report}")

---

## Summary

This notebook provides:

1. **Model comparison** - Side-by-side metrics tables
2. **Per-case analysis** - Identify best/worst cases
3. **Dose visualization** - Multi-view, overlay, difference maps
4. **DVH analysis** - Curves and metrics tables
5. **Clinical constraints** - Pass/fail tracking
6. **Error analysis** - Histograms, statistics
7. **Publication figures** - Combined, high-DPI outputs
8. **Batch processing** - Generate figures for all cases (disabled by default; uncomment the call at the end of Section 12)
9. **Summary report** - Markdown export

All figures are saved to `FIGURES_DIR` (default: `./figures/`), configured in Section 1.