# VMAT Diffusion: NPZ Verification Notebook (v2.2)

This notebook provides visual and quantitative verification of preprocessed `.npz` files for the VMAT diffusion model.

## Purpose
- Visual inspection of CT, dose, and mask alignment
- SDF (Signed Distance Field) visualization and validation
- Beam geometry and **MLC leaf position** review (Phase 2 data)
- Quantitative validation of dose distributions
- Truncation analysis for boundary issues
- Detection of preprocessing errors before training

## Expected .npz Structure (v2.2)
| Key | Shape | Description |
|-----|-------|-------------|
| `ct` | (512, 512, 256) | CT volume, normalized [0, 1] |
| `dose` | (512, 512, 256) | Dose volume, normalized to the prescription dose (1.0 = 100% of Rx) |
| `masks` | (8, 512, 512, 256) | Binary masks for structures |
| `masks_sdf` | (8, 512, 512, 256) | Signed distance fields [-1, 1] |
| `constraints` | (13,) | Prescription targets + OAR constraints |
| `beam0_mlc_a` | (n_cp, n_leaves) | MLC bank A positions for beam 0; additional beams use `beam{i}_mlc_a` (v2.2) |
| `beam0_mlc_b` | (n_cp, n_leaves) | MLC bank B positions for beam 0; additional beams use `beam{i}_mlc_b` (v2.2) |
| `metadata` | dict | Case info, beam geometry, validation |

## Mask Channel Mapping
| Channel | Structure |
|---------|----------|
| 0 | PTV70 |
| 1 | PTV56 |
| 2 | Prostate |
| 3 | Rectum |
| 4 | Bladder |
| 5 | Femur_L |
| 6 | Femur_R |
| 7 | Bowel |

In [None]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from pathlib import Path
import warnings
import json

# For interactive slice viewing (optional)
try:
    from ipywidgets import interact, IntSlider
    HAS_WIDGETS = True
except ImportError:
    HAS_WIDGETS = False
    print("ipywidgets not available - interactive sliders disabled")

%matplotlib inline
plt.rcParams['figure.figsize'] = [16, 10]
plt.rcParams['figure.dpi'] = 100

In [None]:
# Configuration - EDIT THESE PATHS
# Directory holding the preprocessed cases.
NPZ_DIR = Path("./processed_npz").expanduser()
# To pin the notebook to one file instead, uncomment and edit:
# NPZ_FILE = Path("./processed_npz/case_0001.npz").expanduser()

# Display names keyed by mask channel index
# (list position == channel number).
STRUCTURE_NAMES = dict(enumerate([
    'PTV70',
    'PTV56',
    'Prostate',
    'Rectum',
    'Bladder',
    'Femur_L',
    'Femur_R',
    'Bowel',
]))

# RGBA overlay colors keyed by mask channel index
# (list position == channel number).
STRUCTURE_COLORS = dict(enumerate([
    (1.0, 0.0, 0.0, 0.5),   # 0: PTV70 - Red
    (1.0, 0.5, 0.0, 0.5),   # 1: PTV56 - Orange
    (0.0, 1.0, 0.0, 0.4),   # 2: Prostate - Green
    (0.6, 0.3, 0.0, 0.5),   # 3: Rectum - Brown
    (1.0, 1.0, 0.0, 0.5),   # 4: Bladder - Yellow
    (0.0, 0.0, 1.0, 0.4),   # 5: Femur_L - Blue
    (0.0, 0.5, 1.0, 0.4),   # 6: Femur_R - Light Blue
    (1.0, 0.0, 1.0, 0.4),   # 7: Bowel - Magenta
]))

In [None]:
# Helper functions

def load_npz(filepath):
    """Load a preprocessed ``.npz`` case into a plain dict and print a summary.

    The format version is inferred from the keys present: any
    ``beam0_mlc_a``/``beam0_mlc_b`` entry means v2.2, otherwise ``masks_sdf``
    means v2.1, otherwise the file is reported as v2.0 or earlier.

    Args:
        filepath: Path (``pathlib.Path`` or ``str``) of the ``.npz`` file.

    Returns:
        dict mapping every key stored in the archive to its array. The
        ``metadata`` entry, when present, is returned exactly as stored.
    """
    filepath = Path(filepath)

    # np.load returns a lazily-reading NpzFile that holds the archive open.
    # Read every array inside the context manager so the handle is closed
    # before returning.
    # NOTE(review): allow_pickle=True unpickles object arrays (presumably
    # needed for the 'metadata' entry) and can run arbitrary code - only open
    # files produced by a trusted preprocessing run.
    with np.load(filepath, allow_pickle=True) as archive:
        result = {key: archive[key] for key in archive.files}

    # Detect version based on available keys
    has_sdf = 'masks_sdf' in result
    has_mlc = 'beam0_mlc_a' in result or 'beam0_mlc_b' in result

    if has_mlc:
        version = 'v2.2'
    elif has_sdf:
        version = 'v2.1'
    else:
        version = 'v2.0 or earlier'

    print(f"Loaded: {filepath.name}")
    print(f"  Detected version: {version}")
    print(f"  Keys: {list(result.keys())}")
    print(f"  CT shape: {result.get('ct', np.array([])).shape}")
    print(f"  Dose shape: {result.get('dose', np.array([])).shape}")
    print(f"  Masks shape: {result.get('masks', np.array([])).shape}")
    if has_sdf:
        print(f"  Masks SDF shape: {result.get('masks_sdf', np.array([])).shape}")
    if has_mlc:
        mlc_a = result.get('beam0_mlc_a', np.array([]))
        print(f"  MLC data shape: {mlc_a.shape} (beam0_mlc_a)")
    print(f"  Constraints shape: {result.get('constraints', np.array([])).shape}")

    # Metadata summary. Unwrap a single-element array; anything that is not a
    # dict afterwards is skipped rather than crashing the load.
    if 'metadata' in result:
        metadata = result['metadata']
        if isinstance(metadata, np.ndarray) and metadata.size == 1:
            metadata = metadata.item()
        if isinstance(metadata, dict):
            # 'case_type' may be stored as None; only a dict carries a 'type'.
            case_type = metadata.get('case_type')
            case_label = case_type.get('type', 'unknown') if isinstance(case_type, dict) else 'unknown'
            print(f"  Case type: {case_label}")
            print(f"  Script version: {metadata.get('script_version', 'unknown')}")

    return result


def get_metadata(data):
    """Return the case metadata as a dict, or ``{}`` when it is unusable.

    Args:
        data: Mapping of loaded arrays; the metadata is read from its
            ``'metadata'`` key.

    Returns:
        The metadata dict. An empty dict is returned when the key is absent,
        when the stored array does not hold exactly one element, or when the
        stored payload is not a dict (e.g. ``None``). Callers can therefore
        always use ``.get`` on the result.
    """
    if 'metadata' not in data:
        return {}
    metadata = data['metadata']
    if isinstance(metadata, np.ndarray):
        # .item() only works on single-element arrays.
        if metadata.size != 1:
            return {}
        metadata = metadata.item()
    return metadata if isinstance(metadata, dict) else {}


def get_structure_stats(masks, dose, spacing_mm=(1.0, 1.0, 2.0), structure_names=None):
    """Calculate dose statistics for each structure.

    Args:
        masks: Array indexed as ``masks[channel]``; a voxel belongs to a
            structure where its mask value is > 0.
        dose: Dose array with the same shape as one mask channel.
        spacing_mm: Voxel size in millimetres along the three axes, used to
            convert voxel counts to cubic centimetres. Defaults to the
            1 x 1 x 2 mm grid this notebook assumes.
        structure_names: Mapping of channel index -> display name. Defaults
            to the module-level ``STRUCTURE_NAMES``.

    Returns:
        dict keyed by structure name. Non-empty structures map to a dict with
        ``voxels``, ``volume_cc``, ``dose_mean``, ``dose_min``, ``dose_max``,
        ``dose_std``, ``D95`` and ``D5``. Empty structures map to
        ``{'voxels': 0, 'volume_cc': 0.0, 'dose_mean': None}``. Channels
        beyond ``masks.shape[0]`` are omitted.
    """
    if structure_names is None:
        structure_names = STRUCTURE_NAMES
    size_x, size_y, size_z = spacing_mm

    stats = {}
    for ch, name in structure_names.items():
        if ch >= masks.shape[0]:
            continue
        mask = masks[ch] > 0
        voxel_count = mask.sum()
        if voxel_count == 0:
            stats[name] = {'voxels': 0, 'volume_cc': 0.0, 'dose_mean': None}
            continue

        dose_in_struct = dose[mask]
        stats[name] = {
            'voxels': int(voxel_count),
            # mm^3 -> cc
            'volume_cc': float(voxel_count * size_x * size_y * size_z / 1000),
            'dose_mean': float(dose_in_struct.mean()),
            'dose_min': float(dose_in_struct.min()),
            'dose_max': float(dose_in_struct.max()),
            'dose_std': float(dose_in_struct.std()),
            # D95 = dose covering 95% of the volume = 5th percentile
            'D95': float(np.percentile(dose_in_struct, 5)),
            # D5 = dose covering the hottest 5% = 95th percentile
            'D5': float(np.percentile(dose_in_struct, 95)),
        }
    return stats

In [None]:
def validate_case(data):
    """Run sanity checks on one loaded case and return the results as a dict.

    Args:
        data: Mapping of array name -> array for one case. Every entry is
            optional; checks whose input is missing are left out of the
            result instead of being reported as failed.

    Returns:
        dict that always contains ``has_sdf``, ``has_mlc``,
        ``has_beam_geometry``, ``script_version``, ``truncation_info`` and
        ``boundary_warnings``. Depending on the available inputs it also
        contains:

        * ``ct_min``, ``ct_max``, ``ct_range_valid`` (CT within [-0.01, 1.01])
        * ``dose_min``, ``dose_max``, ``dose_nonneg``, ``dose_reasonable``
          (maximum below 1.5)
        * ``sdf_min``, ``sdf_max``, ``sdf_range_valid``
        * ``mlc_shape``, ``mlc_range``, ``mlc_banks_match``,
          ``n_control_points``, ``n_leaves`` (beam 0 only)
        * ``ptv70_exists``, ``ptv56_exists``, ``ptv70_dose_mean``,
          ``ptv70_dose_adequate`` (mean within [0.85, 1.15]), ``dose_in_ptv``,
          ``dose_outside_ptv``, ``dose_higher_in_ptv``
    """
    checks = {}
    metadata = get_metadata(data)

    # Version detection
    checks['has_sdf'] = 'masks_sdf' in data
    checks['has_mlc'] = 'beam0_mlc_a' in data
    checks['has_beam_geometry'] = 'beam_geometry' in metadata
    checks['script_version'] = metadata.get('script_version', 'unknown')

    # CT validation
    ct = data.get('ct', np.array([]))
    if ct.size > 0:
        ct_min, ct_max = ct.min(), ct.max()
        checks['ct_min'] = float(ct_min)
        checks['ct_max'] = float(ct_max)
        checks['ct_range_valid'] = bool(ct_min >= -0.01 and ct_max <= 1.01)

    # Dose validation
    dose = data.get('dose', np.array([]))
    if dose.size > 0:
        dose_min, dose_max = dose.min(), dose.max()
        checks['dose_min'] = float(dose_min)
        checks['dose_max'] = float(dose_max)
        checks['dose_nonneg'] = bool(dose_min >= -0.001)
        checks['dose_reasonable'] = bool(dose_max < 1.5)  # <150% of Rx

    # SDF validation
    if checks['has_sdf']:
        sdf = data['masks_sdf']
        sdf_min, sdf_max = sdf.min(), sdf.max()
        checks['sdf_min'] = float(sdf_min)
        checks['sdf_max'] = float(sdf_max)
        checks['sdf_range_valid'] = bool(sdf_min >= -1.01 and sdf_max <= 1.01)

    # MLC validation (v2.2)
    if checks['has_mlc']:
        mlc_a = data['beam0_mlc_a']
        # Bank B can be absent even when bank A exists; min()/max() of an
        # empty array raise, so only non-empty banks feed the range.
        mlc_b = data.get('beam0_mlc_b', np.array([]))
        checks['mlc_shape'] = mlc_a.shape
        checks['mlc_banks_match'] = mlc_a.shape == mlc_b.shape
        banks = [bank for bank in (mlc_a, mlc_b) if bank.size > 0]
        if banks:
            # Overall travel range across both banks.
            checks['mlc_range'] = (
                float(min(bank.min() for bank in banks)),
                float(max(bank.max() for bank in banks)),
            )
        checks['n_control_points'] = mlc_a.shape[0] if mlc_a.ndim > 0 else 0
        checks['n_leaves'] = mlc_a.shape[1] if mlc_a.ndim > 1 else 0

    # Structure validation
    masks = data.get('masks', np.array([]))
    if masks.size > 0:
        checks['ptv70_exists'] = bool(masks[0].sum() > 0)
        checks['ptv56_exists'] = bool(masks[1].sum() > 0) if masks.shape[0] > 1 else False

        if checks['ptv70_exists'] and dose.size > 0:
            ptv70_mask = masks[0] > 0
            ptv70_mean = float(dose[ptv70_mask].mean())
            checks['ptv70_dose_mean'] = ptv70_mean
            checks['ptv70_dose_adequate'] = 0.85 <= ptv70_mean <= 1.15

            # Registration check: a correctly aligned dose grid is hotter
            # inside the target than outside it.
            checks['dose_in_ptv'] = ptv70_mean
            dose_outside = dose[~ptv70_mask]
            if dose_outside.size > 0:  # mean of an empty selection is NaN
                checks['dose_outside_ptv'] = float(dose_outside.mean())
                checks['dose_higher_in_ptv'] = ptv70_mean > checks['dose_outside_ptv']

    # Truncation info (the stored value may be None rather than missing)
    truncation_info = metadata.get('truncation_info') or {}
    checks['truncation_info'] = truncation_info
    checks['boundary_warnings'] = []
    if isinstance(truncation_info, dict):
        for name, info in truncation_info.items():
            if not isinstance(info, dict) or not info.get('truncated'):
                continue
            percent = info.get('truncation_percent') or 0
            if percent > 5:
                checks['boundary_warnings'].append(f"{name}: {percent:.1f}% truncated")

    return checks

---
## 1. Load Data

In [None]:
# Find available files
# A configured NPZ_FILE that exists on disk takes priority; otherwise every
# .npz file in NPZ_DIR is listed in sorted order.
if 'NPZ_FILE' in globals() and NPZ_FILE.exists():
    npz_files = [NPZ_FILE]
else:
    npz_files = sorted(NPZ_DIR.glob("*.npz"))

print(f"Found {len(npz_files)} .npz files")
# Preview at most the first ten entries, then summarize the remainder.
for idx, npz_path in enumerate(npz_files):
    if idx >= 10:
        print(f"  ... and {len(npz_files) - 10} more")
        break
    print(f"  [{idx}] {npz_path.name}")

In [None]:
# Select a file to analyze
FILE_INDEX = 0  # Change this to analyze different files

if not npz_files:
    # Reset the case variables: otherwise later cells fail with a NameError
    # or, worse, keep analyzing a case loaded by an earlier run of this cell.
    selected_file, data, metadata = None, {}, {}
    print("No files found! Check NPZ_DIR path.")
elif not -len(npz_files) <= FILE_INDEX < len(npz_files):
    raise IndexError(
        f"FILE_INDEX={FILE_INDEX} is out of range: only {len(npz_files)} file(s) were found"
    )
else:
    selected_file = npz_files[FILE_INDEX]
    data = load_npz(selected_file)
    metadata = get_metadata(data)

---
## 2. Beam Geometry & MLC Data (v2.2)

In [None]:
# Display beam geometry summary
def _fmt_number(value, spec=".1f"):
    """Format a numeric *value* with *spec*; render anything else as text.

    Missing values (None) become 'N/A' so that a format spec is never applied
    to a non-number, which would raise a TypeError.
    """
    if isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(value, bool):
        return format(value, spec)
    return 'N/A' if value is None else str(value)


# The stored value may be None rather than missing.
beam_geometry = metadata.get('beam_geometry') or {}

if beam_geometry:
    print("=" * 70)
    print("BEAM GEOMETRY")
    print("=" * 70)
    print(f"  Plan label: {beam_geometry.get('plan_label', 'N/A')}")
    print(f"  Number of beams: {beam_geometry.get('num_beams', 'N/A')}")
    print(f"  Total MU: {_fmt_number(beam_geometry.get('total_mu', 0))}")

    beams = beam_geometry.get('beams') or []
    if beams:
        print(f"\n  {'Beam':<15} {'Start':>8} {'Stop':>8} {'Direction':<12} {'Coll':>8} {'MU':>10} {'CPs':>6}")
        print("  " + "-" * 70)
        for i, beam in enumerate(beams):
            # Empty or missing names fall back to a 1-based label.
            name = str(beam.get('beam_name') or f'Beam {i+1}')
            direction = str(beam.get('arc_direction') or 'N/A')

            start_str = _fmt_number(beam.get('arc_start_angle', 'N/A'))
            stop_str = _fmt_number(beam.get('arc_stop_angle', 'N/A'))
            coll_str = _fmt_number(beam.get('collimator_angle', 'N/A'))
            mu_str = _fmt_number(beam.get('final_mu', 'N/A'))
            # The count is shown as stored, so convert to text before aligning.
            cps = beam.get('num_control_points', 'N/A')
            cps_str = 'N/A' if cps is None else str(cps)

            print(f"  {name:<15} {start_str:>8} {stop_str:>8} {direction:<12} {coll_str:>8} {mu_str:>10} {cps_str:>6}")
else:
    print("No beam geometry available (v2.1+ feature)")

In [None]:
# MLC Data Analysis (v2.2)
has_mlc = 'beam0_mlc_a' in data

if has_mlc:
    print("=" * 70)
    print("MLC DATA (v2.2)")
    print("=" * 70)

    # Find all MLC arrays
    mlc_keys = [k for k in data.keys() if k.startswith('beam') and 'mlc' in k]
    print(f"  Available MLC arrays: {mlc_keys}")

    # Take the beam list from the stored keys ('beam<N>_mlc_a') instead of
    # probing a fixed range, so no beam is silently skipped.
    prefix, suffix = 'beam', '_mlc_a'
    bank_a_keys = sorted(
        (k for k in mlc_keys if k.endswith(suffix) and k[len(prefix):-len(suffix)].isdigit()),
        key=lambda k: int(k[len(prefix):-len(suffix)]),
    )

    for key_a in bank_a_keys:
        beam_idx = int(key_a[len(prefix):-len(suffix)])
        key_b = key_a[:-1] + 'b'

        if key_b not in data:
            print(f"\n  Beam {beam_idx}: '{key_b}' is missing - skipped")
            continue

        mlc_a = data[key_a]
        mlc_b = data[key_b]

        print(f"\n  Beam {beam_idx}:")
        print(f"    MLC A shape: {mlc_a.shape}")
        print(f"    MLC B shape: {mlc_b.shape}")

        # The statistics below need two matching, non-empty 2-D arrays.
        if mlc_a.ndim != 2 or mlc_a.shape != mlc_b.shape or mlc_a.size == 0:
            print("    Unexpected MLC array shapes - statistics skipped")
            continue

        print(f"    Control points: {mlc_a.shape[0]}")
        print(f"    Leaves per bank: {mlc_a.shape[1]}")
        print(f"    MLC A range: [{mlc_a.min():.1f}, {mlc_a.max():.1f}] mm")
        print(f"    MLC B range: [{mlc_b.min():.1f}, {mlc_b.max():.1f}] mm")

        # Aperture statistics
        aperture = mlc_b - mlc_a  # Gap between banks
        print(f"    Aperture range: [{aperture.min():.1f}, {aperture.max():.1f}] mm")
        print(f"    Mean aperture: {aperture.mean():.1f} mm")
else:
    print("No MLC data available (v2.2 feature)")
    print("MLC positions are extracted in preprocessing v2.2+")

In [None]:
# Control Point Data (v2.2)
# For every beam that carries control-point data, print the number of samples
# and the min/max of each recorded quantity.
if beam_geometry and 'beams' in beam_geometry:
    for beam_no, beam_entry in enumerate(beam_geometry['beams']):
        cp_data = beam_entry.get('control_point_data', {})
        if not cp_data:
            continue

        rule = '=' * 70
        print(f"\n{rule}")
        print(f"CONTROL POINT DATA - Beam {beam_no}")
        print(rule)

        # Gantry angles
        gantry = cp_data.get('gantry_angles', [])
        if len(gantry) > 0:
            lowest, highest = min(gantry), max(gantry)
            print(f"  Gantry angles: {len(gantry)} points")
            print(f"    Range: [{lowest:.1f}°, {highest:.1f}°]")

        # Cumulative MU
        cum_mu = cp_data.get('cumulative_meterset_weight', [])
        if len(cum_mu) > 0:
            lowest, highest = min(cum_mu), max(cum_mu)
            print(f"  Cumulative MU: {len(cum_mu)} points")
            print(f"    Range: [{lowest:.3f}, {highest:.3f}]")

        # Dose rates (entries may be None; only real values enter the range)
        dose_rates = cp_data.get('dose_rates', [])
        valid_rates = [rate for rate in dose_rates if rate is not None]
        if valid_rates:
            lowest, highest = min(valid_rates), max(valid_rates)
            print(f"  Dose rates: {len(dose_rates)} points ({len(valid_rates)} valid)")
            print(f"    Range: [{lowest:.0f}, {highest:.0f}] MU/min")

        # Jaw positions
        for jaw in ('jaw_x1', 'jaw_x2', 'jaw_y1', 'jaw_y2'):
            positions = cp_data.get(jaw, [])
            if len(positions) == 0:
                continue
            print(f"  {jaw}: [{min(positions):.1f}, {max(positions):.1f}] mm")

In [None]:
# Visualize MLC Positions
# Plot only when both banks of beam 0 exist as matching, non-empty 2-D arrays;
# anything else takes the "nothing to visualize" branch.
mlc_a = data.get('beam0_mlc_a') if has_mlc else None
mlc_b = data.get('beam0_mlc_b') if has_mlc else None
mlc_plottable = (
    mlc_a is not None
    and mlc_b is not None
    and mlc_a.ndim == 2
    and mlc_a.size > 0
    and mlc_a.shape == mlc_b.shape
)

if mlc_plottable:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    aperture = mlc_b - mlc_a
    n_cp, n_leaves = aperture.shape

    # 1. MLC positions over control points (heatmap)
    ax1 = axes[0, 0]
    # origin='lower' puts array row 0 (leaf pair 0) at the bottom, matching
    # the ascending y extent. With the default origin='upper' the image is
    # flipped relative to the 'Leaf Pair' axis labels.
    im1 = ax1.imshow(aperture.T, aspect='auto', cmap='viridis', origin='lower',
                     extent=[0, n_cp, 0, n_leaves])
    ax1.set_xlabel('Control Point')
    ax1.set_ylabel('Leaf Pair')
    ax1.set_title('MLC Aperture (Bank B - Bank A) [mm]')
    plt.colorbar(im1, ax=ax1, label='Aperture (mm)')

    # 2. Aperture profile at different control points
    ax2 = axes[0, 1]
    # A set removes the duplicates that occur for very short arcs.
    cp_indices = sorted({0, n_cp // 4, n_cp // 2, 3 * n_cp // 4, n_cp - 1})
    for cp in cp_indices:
        ax2.plot(aperture[cp], label=f'CP {cp}')
    ax2.set_xlabel('Leaf Pair')
    ax2.set_ylabel('Aperture (mm)')
    ax2.set_title('Aperture Profile at Selected Control Points')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # 3. MLC bank positions
    ax3 = axes[1, 0]
    central_leaf = n_leaves // 2
    ax3.plot(mlc_a[:, central_leaf], label='Bank A', color='blue')
    ax3.plot(mlc_b[:, central_leaf], label='Bank B', color='red')
    ax3.fill_between(range(n_cp), mlc_a[:, central_leaf], mlc_b[:, central_leaf],
                     alpha=0.3, color='green', label='Aperture')
    ax3.set_xlabel('Control Point')
    ax3.set_ylabel('Position (mm)')
    ax3.set_title(f'Central Leaf Pair ({central_leaf}) Position vs Control Point')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # 4. Mean aperture over time
    ax4 = axes[1, 1]
    mean_aperture = aperture.mean(axis=1)
    ax4.plot(mean_aperture, color='green', linewidth=2)
    ax4.fill_between(range(n_cp), 0, mean_aperture, alpha=0.3, color='green')
    ax4.set_xlabel('Control Point')
    ax4.set_ylabel('Mean Aperture (mm)')
    ax4.set_title('Mean Aperture Across All Leaves')
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print("No MLC data to visualize")

---
## 3. Validation Checks

In [None]:
# Run validation
checks = validate_case(data)


def _verdict(flag, fail_text='✗ FAIL'):
    """Return the PASS marker for a truthy check result, else *fail_text*."""
    return '✓ PASS' if flag else fail_text


def _presence(flag):
    """Return the Yes/No marker used in the version section."""
    return '✓ Yes' if flag else '✗ No'


banner = "=" * 60
print(banner)
print("VALIDATION RESULTS")
print(banner)

# Version info
print("\n[Version Detection]")
print(f"  Script version: {checks['script_version']}")
for label, key in (
    ('Has SDFs', 'has_sdf'),
    ('Has MLC data', 'has_mlc'),
    ('Has beam geometry', 'has_beam_geometry'),
):
    print(f"  {label}: {_presence(checks[key])}")

if checks.get('has_mlc'):
    for label, key in (
        ('MLC shape', 'mlc_shape'),
        ('Control points', 'n_control_points'),
        ('Leaves/bank', 'n_leaves'),
    ):
        print(f"  {label}: {checks.get(key)}")

# CT
print("\n[CT Volume]")
print(f"  Range valid [0,1]: {_verdict(checks.get('ct_range_valid'))}")
print(f"  Actual range: [{checks.get('ct_min', 0):.4f}, {checks.get('ct_max', 0):.4f}]")

# Dose
print("\n[Dose Volume]")
print(f"  Non-negative: {_verdict(checks.get('dose_nonneg'))}")
print(f"  Max < 150%: {_verdict(checks.get('dose_reasonable'))}")
print(f"  Actual range: [{checks.get('dose_min', 0):.4f}, {checks.get('dose_max', 0):.4f}]")

# SDFs
if checks['has_sdf']:
    print("\n[Signed Distance Fields]")
    print(f"  Range valid [-1,1]: {_verdict(checks.get('sdf_range_valid'))}")
    print(f"  Actual range: [{checks.get('sdf_min', 0):.4f}, {checks.get('sdf_max', 0):.4f}]")

# PTV70
print("\n[PTV70]")
print(f"  Exists: {_verdict(checks.get('ptv70_exists'))}")
ptv70_mean = checks.get('ptv70_dose_mean')
if ptv70_mean is not None:
    # Stored dose is relative; scale by the normalization dose to report Gy.
    rx_dose = metadata.get('normalization_dose_gy', 70.0)
    print(f"  Mean dose: {ptv70_mean:.3f} (target: ~1.0)")
    print(f"  Mean dose (Gy): {ptv70_mean*rx_dose:.1f} Gy")
    print(f"  Dose adequate [0.85-1.15]: {_verdict(checks.get('ptv70_dose_adequate'))}")

# Registration
print("\n[Registration Check]")
if 'dose_higher_in_ptv' in checks:
    print(f"  Dose higher in PTV: {_verdict(checks['dose_higher_in_ptv'], '✗ FAIL - CRITICAL')}")

---
## 4. Visual Inspection

In [None]:
# Quick visualization of CT, dose, and structures
def plot_overview(ct, dose, masks, slice_idx=None):
    """Draw a three-panel view of one slice: CT, dose overlay, structure contours.

    Args:
        ct: CT volume, indexed ``[:, :, slice]``.
        dose: Dose volume with the same indexing; values below 0.1 are hidden
            in the overlay.
        masks: Structure masks indexed ``[channel, :, :, slice]``.
        slice_idx: Slice to show along the last axis; defaults to the middle
            slice of ``ct``.

    Returns:
        The matplotlib figure holding the three panels.
    """
    if slice_idx is None:
        slice_idx = ct.shape[2] // 2

    ct_slice = ct[:, :, slice_idx]
    dose_slice = dose[:, :, slice_idx]

    fig, (ax_ct, ax_dose, ax_struct) = plt.subplots(1, 3, figsize=(18, 6))

    # Every panel uses the same grayscale CT backdrop.
    for panel in (ax_ct, ax_dose, ax_struct):
        panel.imshow(ct_slice, cmap='gray', vmin=0, vmax=1)

    # Dose: mask out low values so the CT stays visible underneath.
    visible_dose = np.ma.masked_where(dose_slice < 0.1, dose_slice)
    ax_dose.imshow(visible_dose, cmap='jet', alpha=0.6, vmin=0, vmax=1.1)

    # Structures: outline each channel that has voxels in this slice.
    n_channels = masks.shape[0]
    for ch in STRUCTURE_NAMES:
        if ch >= n_channels:
            continue
        mask_slice = masks[ch, :, :, slice_idx]
        if not mask_slice.sum() > 0:
            continue
        rgb = STRUCTURE_COLORS.get(ch, (1, 1, 1, 0.5))[:3]
        ax_struct.contour(mask_slice, levels=[0.5],
                          colors=[rgb], linewidths=2)

    panel_titles = (
        (ax_ct, f'CT (slice {slice_idx})'),
        (ax_dose, 'Dose Overlay'),
        (ax_struct, 'Structure Contours'),
    )
    for panel, heading in panel_titles:
        panel.set_title(heading)
        panel.axis('off')

    plt.tight_layout()
    return fig

# Find slice with most PTV70 content
ptv70_per_slice = data['masks'][0].sum(axis=(0, 1))
if ptv70_per_slice.max() > 0:
    best_slice = int(np.argmax(ptv70_per_slice))
else:
    # argmax of an all-zero profile is 0, i.e. the first slice of the volume;
    # the middle slice is a more useful view when PTV70 is empty.
    best_slice = data['masks'].shape[-1] // 2
    print("PTV70 mask is empty - showing the middle slice instead")

fig = plot_overview(data['ct'], data['dose'], data['masks'], slice_idx=best_slice)
plt.show()

In [None]:
# SDF Visualization
if 'masks_sdf' in data:
    sdf_volume = data['masks_sdf']
    mask_volume = data['masks']

    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    axes = axes.flatten()

    slice_idx = best_slice

    # Hide all axes up front so panels without a matching channel stay blank.
    for ax in axes:
        ax.axis('off')

    for ch, name in STRUCTURE_NAMES.items():
        # Skip channels with no panel or no SDF data in this file.
        if ch >= len(axes) or ch >= sdf_volume.shape[0]:
            continue
        ax = axes[ch]
        sdf_slice = sdf_volume[ch, :, :, slice_idx]
        im = ax.imshow(sdf_slice, cmap='RdBu', vmin=-1, vmax=1)

        # Outline the binary mask only where a 0.5 level exists: contour()
        # warns on a slice that is entirely empty or entirely filled.
        if ch < mask_volume.shape[0]:
            mask_slice = mask_volume[ch, :, :, slice_idx]
            if mask_slice.any() and not mask_slice.all():
                ax.contour(mask_slice, levels=[0.5], colors='black', linewidths=1)

        ax.set_title(f'{name} SDF')
        plt.colorbar(im, ax=ax, shrink=0.8)

    plt.suptitle(f'Signed Distance Fields (slice {slice_idx})', fontsize=14)
    plt.tight_layout()
    plt.show()
else:
    print("No SDF data available")

---
## 5. Structure Statistics

In [None]:
# Calculate and display structure statistics
stats = get_structure_stats(data['masks'], data['dose'])
rx_dose = metadata.get('normalization_dose_gy', 70.0)

wide_rule = "=" * 90
thin_rule = "-" * 90

print(wide_rule)
print("STRUCTURE STATISTICS")
print(wide_rule)
print(f"{'Structure':<12} {'Voxels':>10} {'Vol (cc)':>10} {'Mean':>8} {'Min':>8} {'Max':>8} {'D95':>8} {'D5':>8}")
print(thin_rule)

for name, s in stats.items():
    # Structures without voxels carry no dose figures; mark them and move on.
    if s['dose_mean'] is None:
        print(f"{name:<12} {'(empty)':>10}")
        continue
    size_cols = f"{s['voxels']:>10,} {s['volume_cc']:>10.1f}"
    dose_cols = " ".join(
        f"{s[key]:>8.3f}"
        for key in ('dose_mean', 'dose_min', 'dose_max', 'D95', 'D5')
    )
    print(f"{name:<12} {size_cols} {dose_cols}")

print(thin_rule)
print(f"Note: Dose values normalized to {rx_dose:.0f} Gy. Multiply by {rx_dose:.0f} for absolute Gy.")

---
## 6. DVH Plot

In [None]:
def plot_dvh(dose, masks, rx_dose=70.0, title='DVH', max_dose_gy=None):
    """Plot cumulative dose-volume histograms for the structures in STRUCTURE_NAMES.

    Args:
        dose: Relative dose volume; multiplied by ``rx_dose`` to obtain Gy.
        masks: Structure masks indexed ``masks[channel]``; a voxel belongs to
            a structure where its mask value is > 0.
        rx_dose: Dose in Gy that a relative dose of 1.0 corresponds to.
        title: Figure title.
        max_dose_gy: Upper end of the dose axis in Gy. ``None`` (default)
            uses 80 Gy, widened when the volume contains higher doses so the
            curves are not clipped.

    Returns:
        The matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    dose_gy = dose * rx_dose
    if max_dose_gy is None:
        max_dose_gy = 80.0
        if dose_gy.size > 0:
            max_dose_gy = max(max_dose_gy, float(np.ceil(np.nanmax(dose_gy))))
    dose_bins = np.linspace(0, max_dose_gy, 200)

    colors = ['red', 'orange', 'green', 'brown', 'gold', 'blue', 'cyan', 'magenta']

    for ch, name in STRUCTURE_NAMES.items():
        if ch >= masks.shape[0]:
            continue
        struct_dose = dose_gy[masks[ch] > 0]
        if struct_dose.size == 0:
            continue

        # Cumulative DVH: percentage of the structure receiving >= each bin
        # dose. One sort plus a binary search per bin replaces a full pass
        # over the voxels for every bin. NaN voxels never satisfy '>=' so
        # they only count in the denominator.
        finite_sorted = np.sort(struct_dose[~np.isnan(struct_dose)])
        n_at_least = finite_sorted.size - np.searchsorted(finite_sorted, dose_bins, side='left')
        dvh = n_at_least / struct_dose.size * 100

        # Wrap around the palette if more channels than colors are defined.
        ax.plot(dose_bins, dvh, label=name, color=colors[ch % len(colors)], linewidth=2)

    ax.set_xlabel('Dose (Gy)', fontsize=12)
    ax.set_ylabel('Volume (%)', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.set_xlim(0, max_dose_gy)
    ax.set_ylim(0, 105)
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper right', fontsize=10)

    # Reference lines at the 70 Gy and 56 Gy target levels.
    ax.axvline(70, color='red', linestyle='--', alpha=0.5)
    ax.axvline(56, color='orange', linestyle='--', alpha=0.5)

    plt.tight_layout()
    return fig

fig = plot_dvh(data['dose'], data['masks'], rx_dose=rx_dose, title=f"DVH: {selected_file.stem}")
plt.show()

---
## 7. Constraints Vector

In [None]:
# Display constraints vector
constraints = data['constraints']

# Label for each position of the constraints vector, in storage order.
constraint_names = [
    'PTV70 target',
    'PTV56 target', 
    'PTV50.4 target',
    'Rectum V50',
    'Rectum V60',
    'Rectum V70',
    'Rectum Dmax',
    'Bladder V65',
    'Bladder V70',
    'Bladder V75',
    'Femur V50',
    'Bowel V45',
    'Cord Dmax'
]

print("=" * 50)
print(f"CONSTRAINTS VECTOR ({len(constraints)} values)")
print("=" * 50)
print(f"{'Index':<6} {'Name':<20} {'Value':>10}")
print("-" * 50)

# Walk the stored values (not the shorter of the two sequences) so surplus
# values are still shown instead of being silently dropped.
for i, val in enumerate(constraints):
    name = constraint_names[i] if i < len(constraint_names) else f'(unnamed {i})'
    print(f"{i:<6} {name:<20} {val:>10.4f}")

if len(constraints) != len(constraint_names):
    print(f"\nWARNING: expected {len(constraint_names)} constraint values, found {len(constraints)}")

---
## 8. Summary Report

In [None]:
def generate_summary_report(filepath, data, checks, stats):
    """Generate a text summary report for a case.

    Args:
        filepath: ``pathlib.Path`` of the case file; only its name is shown.
        data: Loaded case arrays. Currently unused; kept for interface
            compatibility.
        checks: Result dict of ``validate_case``.
        stats: Structure statistics. Currently unused; kept for interface
            compatibility.

    Returns:
        The report as one newline-joined string. The status is PASSED only
        when every critical check is present and truthy. A critical check
        that is False is listed under "Failed checks"; one that is absent
        (its input was missing, so it never ran) is listed under
        "Checks not run". Either condition yields NEEDS REVIEW.
    """
    report = []
    report.append("=" * 70)
    report.append(f"VERIFICATION REPORT: {filepath.name}")
    report.append("=" * 70)

    # Overall status
    critical_checks = ['ct_range_valid', 'dose_nonneg', 'ptv70_exists',
                       'ptv70_dose_adequate', 'dose_higher_in_ptv']
    failures = [c for c in critical_checks if c in checks and not checks[c]]
    # A check that never ran proves nothing, so it must not count as a pass.
    not_run = [c for c in critical_checks if c not in checks]

    if failures or not_run:
        report.append("\n⚠️  STATUS: NEEDS REVIEW")
        if failures:
            report.append(f"   Failed checks: {', '.join(failures)}")
        if not_run:
            report.append(f"   Checks not run: {', '.join(not_run)}")
    else:
        report.append("\n✓  STATUS: PASSED")

    # Version info
    report.append(f"\n   Script version: {checks.get('script_version', 'unknown')}")
    report.append(f"   Has SDFs: {'Yes' if checks.get('has_sdf') else 'No'}")
    report.append(f"   Has MLC data: {'Yes' if checks.get('has_mlc') else 'No'}")
    report.append(f"   Has beam geometry: {'Yes' if checks.get('has_beam_geometry') else 'No'}")

    # MLC summary
    if checks.get('has_mlc'):
        report.append(f"   MLC: {checks.get('n_control_points', 0)} CPs × {checks.get('n_leaves', 0)} leaves")

    report.append("\n" + "=" * 70)
    return "\n".join(report)

# Generate and print report
# Reads the notebook globals `selected_file`, `data`, `checks` and `stats`,
# so the cells that define them must have been run first.
report = generate_summary_report(selected_file, data, checks, stats)
print(report)

---
## 9. Batch Validation (Optional)

In [None]:
# Set to True to run batch validation on all files
RUN_BATCH = False

if RUN_BATCH and npz_files:
    print(f"Running batch validation on {len(npz_files)} files...\n")

    results = []
    for npz_file in npz_files:
        try:
            # Close each archive as soon as its arrays are read; without this
            # a long batch keeps one open file handle per case.
            with np.load(npz_file, allow_pickle=True) as archive:
                data_tmp = {key: archive[key] for key in archive.files}
            checks_tmp = validate_case(data_tmp)

            # A check that did not run (missing input) counts as a failure.
            critical_pass = all(
                checks_tmp.get(check_name, False)
                for check_name in ('ct_range_valid', 'dose_nonneg', 'ptv70_exists')
            )

            # Compare against None so a genuine mean of 0.0 is still shown.
            ptv70_mean = checks_tmp.get('ptv70_dose_mean')

            results.append({
                'file': npz_file.name,
                'status': '✓' if critical_pass else '✗',
                'version': str(checks_tmp.get('script_version', '?')),
                'sdf': '✓' if checks_tmp.get('has_sdf') else '-',
                'mlc': '✓' if checks_tmp.get('has_mlc') else '-',
                'ptv70': f"{ptv70_mean:.3f}" if ptv70_mean is not None else 'N/A',
                'error': '',
            })
        except Exception as e:
            # Best effort: one unreadable file must not abort the batch, but
            # keep the reason so it can be reported below.
            results.append({
                'file': npz_file.name,
                'status': '✗ ERR',
                'version': '?',
                'sdf': '?',
                'mlc': '?',
                'ptv70': 'ERR',
                'error': f"{type(e).__name__}: {e}",
            })

    # Print summary table
    print(f"{'File':<30} {'Status':>6} {'Ver':<8} {'SDF':>4} {'MLC':>4} {'PTV70':>8}")
    print("-" * 70)
    for r in results:
        print(f"{r['file']:<30} {r['status']:>6} {r['version']:<8} {r['sdf']:>4} {r['mlc']:>4} {r['ptv70']:>8}")

    # Error details for files that could not be validated
    errored = [r for r in results if r['error']]
    if errored:
        print("\nErrors:")
        for r in errored:
            print(f"  {r['file']}: {r['error']}")

    # Summary
    passed = sum(1 for r in results if r['status'] == '✓')
    with_mlc = sum(1 for r in results if r['mlc'] == '✓')

    print(f"\n{'='*70}")
    print(f"BATCH SUMMARY: {passed}/{len(results)} passed, {with_mlc}/{len(results)} have MLC data")
elif RUN_BATCH:
    print("No files found! Check NPZ_DIR path.")
else:
    print("Set RUN_BATCH = True to run batch validation")