# VMAT Diffusion: NPZ Verification Notebook

This notebook provides visual and quantitative verification of preprocessed `.npz` files for the VMAT diffusion model.

## Purpose
- Visual inspection of CT, dose, and mask alignment
- Quantitative validation of dose distributions
- Detection of preprocessing errors before training

## Expected .npz Structure
| Key | Shape | Description |
|-----|-------|-------------|
| `ct` | (512, 512, 256) | CT volume, normalized to [0, 1] |
| `dose` | (512, 512, 256) | Dose volume, normalized so that 1.0 = 70 Gy |
| `masks` | (8, 512, 512, 256) | Binary masks for structures |
| `constraints` | (13,) | AAPM constraints + PTV type vector |

## Mask Channel Mapping
| Channel | Structure |
|---------|----------|
| 0 | PTV70 |
| 1 | PTV56 |
| 2 | Prostate |
| 3 | Rectum |
| 4 | Bladder |
| 5 | Femur_L |
| 6 | Femur_R |
| 7 | Bowel |

In [None]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from pathlib import Path
import warnings

# For interactive slice viewing (optional)
try:
    from ipywidgets import interact, IntSlider
    HAS_WIDGETS = True
except ImportError:
    HAS_WIDGETS = False
    print("ipywidgets not available - interactive sliders disabled")

# Render matplotlib figures inline in the notebook (IPython magic, not plain Python).
%matplotlib inline
# Defaults for figures created without an explicit figsize/dpi.
plt.rcParams['figure.figsize'] = [16, 10]
plt.rcParams['figure.dpi'] = 100

In [None]:
# Configuration - EDIT THESE PATHS
# Directory containing one preprocessed .npz archive per case.
NPZ_DIR = Path("~/vmat-diffusion-project/processed_npz").expanduser()
# Or specify a single file directly:
# NPZ_FILE = Path("~/vmat-diffusion-project/processed_npz/case_0001.npz").expanduser()

# Display name for each mask channel (list position == channel index).
STRUCTURE_NAMES = dict(enumerate([
    'PTV70',
    'PTV56',
    'Prostate',
    'Rectum',
    'Bladder',
    'Femur_L',
    'Femur_R',
    'Bowel',
]))

# Expected normalized prescription level per target (dose / 70 Gy), for SIB checks.
EXPECTED_DOSES = {
    'PTV70': 1.0,      # 70 Gy / 70 Gy
    'PTV56': 0.8,      # 56 Gy / 70 Gy
    'PTV50.4': 0.72,   # 50.4 Gy / 70 Gy -- only relevant if such a channel is added
}

# RGBA overlay/contour colour per mask channel (list position == channel index).
STRUCTURE_COLORS = dict(enumerate([
    (1.0, 0.0, 0.0, 0.5),   # PTV70 - Red
    (1.0, 0.5, 0.0, 0.5),   # PTV56 - Orange
    (0.0, 1.0, 0.0, 0.4),   # Prostate - Green
    (0.6, 0.3, 0.0, 0.5),   # Rectum - Brown
    (1.0, 1.0, 0.0, 0.5),   # Bladder - Yellow
    (0.0, 0.0, 1.0, 0.4),   # Femur_L - Blue
    (0.0, 0.5, 1.0, 0.4),   # Femur_R - Light Blue
    (1.0, 0.0, 1.0, 0.4),   # Bowel - Magenta
]))

In [None]:
# Helper functions

def load_npz(filepath):
    """Load .npz file and return as dict with basic validation.

    Every array in the archive is read eagerly into a plain dict. A warning
    is issued for each expected key ('ct', 'dose', 'masks', 'constraints')
    that is missing, and the shapes of those arrays are printed.

    Args:
        filepath: ``Path`` or path-like string pointing at the ``.npz`` file.

    Returns:
        dict mapping archive key -> ``np.ndarray``.
    """
    # Accept plain strings too; `.name` is used in the printout below.
    filepath = Path(filepath)

    # np.load returns an NpzFile that keeps the zip archive open; the context
    # manager closes it as soon as every array has been materialized.
    with np.load(filepath) as data:
        result = {key: data[key] for key in data.files}

    # Basic key validation
    expected_keys = ['ct', 'dose', 'masks', 'constraints']
    for key in expected_keys:
        if key not in result:
            warnings.warn(f"Missing key: {key}")

    empty = np.array([])  # placeholder so a missing key prints shape (0,)
    print(f"Loaded: {filepath.name}")
    print(f"  CT shape: {result.get('ct', empty).shape}")
    print(f"  Dose shape: {result.get('dose', empty).shape}")
    print(f"  Masks shape: {result.get('masks', empty).shape}")
    print(f"  Constraints shape: {result.get('constraints', empty).shape}")

    return result


def get_structure_stats(masks, dose, structure_names=None, spacing_mm=(1.0, 1.0, 2.0)):
    """Calculate dose statistics for each structure.

    Args:
        masks: Mask volume indexed (C, Y, X, Z); voxels with value > 0 belong
            to the structure.
        dose: Dose volume indexed (Y, X, Z), in normalized dose units.
        structure_names: Optional ``{channel: name}`` mapping. Defaults to the
            module-level ``STRUCTURE_NAMES``. Channels >= ``masks.shape[0]``
            are skipped.
        spacing_mm: Voxel spacing in mm used for the volume estimate. The
            default (1, 1, 2) reproduces the previously hard-coded spacing.

    Returns:
        dict ``{name: stats}``. A non-empty structure gets 'voxels',
        'volume_cc', 'dose_mean', 'dose_min', 'dose_max', 'dose_std', 'D95'
        (5th percentile = dose received by at least 95% of the volume) and
        'D5' (95th percentile). An empty structure gets only 'voxels'=0,
        'volume_cc'=0.0 and 'dose_mean'=None.
    """
    names = STRUCTURE_NAMES if structure_names is None else structure_names
    voxel_volume_mm3 = float(np.prod(spacing_mm))

    stats = {}
    for ch, name in names.items():
        if ch >= masks.shape[0]:
            continue
        mask = masks[ch] > 0
        voxel_count = int(mask.sum())
        if voxel_count == 0:
            stats[name] = {'voxels': 0, 'volume_cc': 0.0, 'dose_mean': None}
            continue

        dose_in_struct = dose[mask]
        d95, d5 = np.percentile(dose_in_struct, [5, 95])
        stats[name] = {
            'voxels': voxel_count,
            'volume_cc': voxel_count * voxel_volume_mm3 / 1000,  # mm^3 -> cc
            'dose_mean': float(dose_in_struct.mean()),
            'dose_min': float(dose_in_struct.min()),
            'dose_max': float(dose_in_struct.max()),
            'dose_std': float(dose_in_struct.std()),
            'D95': float(d95),   # Dose to 95% volume
            'D5': float(d5),     # Dose to 5% volume
        }
    return stats


def check_boundary_truncation(masks, structure_names=None):
    """Check if structures touch the grid boundary (potential truncation).

    A structure "touches" a face when at least one of its voxels (mask > 0,
    the same threshold used elsewhere in this notebook) lies in the first or
    last plane along Y, X or Z.

    Args:
        masks: Mask volume indexed (C, Y, X, Z).
        structure_names: Optional ``{channel: name}`` mapping. Defaults to the
            module-level ``STRUCTURE_NAMES``. Channels >= ``masks.shape[0]``
            are skipped.

    Returns:
        list of strings such as ``"Bladder: touches Z-max (superior)"``, one
        per structure that touches any face. Empty structures are ignored.
    """
    names = STRUCTURE_NAMES if structure_names is None else structure_names
    truncation_warnings = []
    for ch, name in names.items():
        if ch >= masks.shape[0]:
            continue
        # Threshold first. Summing raw values lets non-binary masks cancel out
        # on a face and hide a touching structure.
        mask = masks[ch] > 0
        if not mask.any():
            continue

        # Faces of the (Y, X, Z) grid. NOTE(review): the anatomical labels
        # assume the preprocessing orientation; confirm against the pipeline.
        faces = {
            'Y-min (anterior)': mask[0, :, :],
            'Y-max (posterior)': mask[-1, :, :],
            'X-min (right)': mask[:, 0, :],
            'X-max (left)': mask[:, -1, :],
            'Z-min (inferior)': mask[:, :, 0],
            'Z-max (superior)': mask[:, :, -1],
        }

        touching = [face for face, plane in faces.items() if plane.any()]
        if touching:
            truncation_warnings.append(f"{name}: touches {', '.join(touching)}")

    return truncation_warnings


def validate_case(data):
    """Run validation checks on a single case.

    Args:
        data: dict with 'ct' (Y, X, Z), 'dose' (Y, X, Z) and 'masks'
            (C, Y, X, Z) arrays, as returned by ``load_npz``. Mask channel 0
            is PTV70; channel 1, when present, is PTV56.

    Returns:
        dict of check name -> value. Pass/fail entries are bools, numeric
        entries are floats, and ``None`` marks a quantity that could not be
        assessed. 'ptv56_only_dose_mean'/'ptv56_dose_appropriate' exist only
        when PTV56 exists. 'dose_higher_in_ptv'/'dose_in_ptv'/'dose_outside_ptv'
        exist only when both a PTV region and a non-PTV region exist.
        'boundary_warnings' is the list returned by ``check_boundary_truncation``.
    """
    ct = data['ct']
    dose = data['dose']
    masks = data['masks']

    checks = {}

    # CT checks. Each reduction is a full pass over a large volume, so every
    # min/max is computed once and reused.
    ct_min, ct_max = ct.min(), ct.max()
    checks['ct_range_valid'] = bool(ct_min >= 0 and ct_max <= 1)
    checks['ct_min'] = float(ct_min)
    checks['ct_max'] = float(ct_max)

    # Dose checks (normalized units, 1.0 = prescription)
    dose_min, dose_max = dose.min(), dose.max()
    checks['dose_nonneg'] = bool(dose_min >= -0.01)  # Allow tiny negative from interpolation
    checks['dose_min'] = float(dose_min)
    checks['dose_max'] = float(dose_max)
    checks['dose_reasonable'] = bool(dose_max < 1.5)  # Max < 150% prescription

    # PTV70 checks (channel 0)
    ptv70_mask = masks[0] > 0
    checks['ptv70_exists'] = bool(ptv70_mask.any())
    if checks['ptv70_exists']:
        ptv70_dose_mean = dose[ptv70_mask].mean()
        checks['ptv70_dose_mean'] = float(ptv70_dose_mean)
        checks['ptv70_dose_adequate'] = bool(0.90 < ptv70_dose_mean < 1.10)
    else:
        checks['ptv70_dose_mean'] = None
        checks['ptv70_dose_adequate'] = False

    # PTV56 checks (channel 1). A single-channel mask volume is treated as
    # having no PTV56 instead of raising IndexError.
    if masks.shape[0] > 1:
        ptv56_mask = masks[1] > 0
    else:
        ptv56_mask = np.zeros_like(ptv70_mask)
    checks['ptv56_exists'] = bool(ptv56_mask.any())
    if checks['ptv56_exists']:
        # For SIB the PTV56-only region should sit near 0.8 (56/70). Voxels
        # shared with PTV70 are excluded because they receive the higher dose.
        ptv56_only = ptv56_mask & ~ptv70_mask
        if ptv56_only.any():
            ptv56_only_dose = dose[ptv56_only].mean()
            checks['ptv56_only_dose_mean'] = float(ptv56_only_dose)
            checks['ptv56_dose_appropriate'] = bool(0.70 < ptv56_only_dose < 0.95)
        else:
            checks['ptv56_only_dose_mean'] = None
            checks['ptv56_dose_appropriate'] = None  # Can't assess

    # Registration check: dose should be higher inside the PTVs than outside.
    # Reuses the masks above rather than re-thresholding the channels.
    any_ptv = ptv70_mask | ptv56_mask
    if any_ptv.any() and not any_ptv.all():
        dose_in_ptv = dose[any_ptv].mean()
        dose_outside_ptv = dose[~any_ptv].mean()
        checks['dose_higher_in_ptv'] = bool(dose_in_ptv > dose_outside_ptv)
        checks['dose_in_ptv'] = float(dose_in_ptv)
        checks['dose_outside_ptv'] = float(dose_outside_ptv)

    # Boundary truncation
    checks['boundary_warnings'] = check_boundary_truncation(masks)

    return checks

---
## 1. Load and Inspect a Single Case

In [None]:
# Discover the preprocessed cases and echo a short preview of the listing
npz_files = sorted(NPZ_DIR.glob("*.npz"))
print(f"Found {len(npz_files)} .npz files in {NPZ_DIR}")

_preview_count = 10  # only the first few file names are printed
for idx, npz_path in enumerate(npz_files[:_preview_count]):
    print(f"  [{idx}] {npz_path.name}")

remaining = len(npz_files) - _preview_count
if remaining > 0:
    print(f"  ... and {remaining} more")

In [None]:
# Select a case to inspect (change index as needed)
CASE_INDEX = 0  # <-- CHANGE THIS to inspect different cases

# Guard clause: nothing to inspect without at least one archive
if not npz_files:
    raise FileNotFoundError(f"No .npz files found in {NPZ_DIR}")

selected_file = npz_files[CASE_INDEX]
data = load_npz(selected_file)

---
## 2. Validation Checks

In [None]:
# Run validation
checks = validate_case(data)

print("=" * 60)
print("VALIDATION RESULTS")
print("=" * 60)

# CT
print("\n[CT Volume]")
print(f"  Range valid [0,1]: {'✓ PASS' if checks['ct_range_valid'] else '✗ FAIL'}")
print(f"  Actual range: [{checks['ct_min']:.4f}, {checks['ct_max']:.4f}]")

# Dose (normalized; multiply by 70 for Gy)
print("\n[Dose Volume]")
print(f"  Non-negative: {'✓ PASS' if checks['dose_nonneg'] else '✗ FAIL'}")
print(f"  Max < 150%: {'✓ PASS' if checks['dose_reasonable'] else '✗ FAIL'}")
print(f"  Actual range: [{checks['dose_min']:.4f}, {checks['dose_max']:.4f}]")
print(f"  (In Gy: [{checks['dose_min']*70:.1f}, {checks['dose_max']*70:.1f}])")

# PTV70
print("\n[PTV70]")
print(f"  Exists: {'✓ PASS' if checks['ptv70_exists'] else '✗ FAIL'}")
if checks['ptv70_dose_mean'] is not None:
    print(f"  Mean dose: {checks['ptv70_dose_mean']:.3f} (target: ~1.0)")
    print(f"  Mean dose (Gy): {checks['ptv70_dose_mean']*70:.1f} Gy")
    print(f"  Dose adequate [0.9-1.1]: {'✓ PASS' if checks['ptv70_dose_adequate'] else '✗ FAIL'}")

# PTV56
print("\n[PTV56]")
print(f"  Exists: {'✓ YES' if checks['ptv56_exists'] else '- NO (single Rx case)'}")
if checks.get('ptv56_only_dose_mean') is not None:
    print(f"  Mean dose (PTV56 only, excl PTV70): {checks['ptv56_only_dose_mean']:.3f} (target: ~0.8)")
    print(f"  Mean dose (Gy): {checks['ptv56_only_dose_mean']*70:.1f} Gy")
    # validate_case grades this region against 0.70-0.95; show its verdict too
    print(f"  Dose appropriate [0.70-0.95]: "
          f"{'✓ PASS' if checks['ptv56_dose_appropriate'] else '✗ FAIL'}")
elif checks['ptv56_exists']:
    print("  PTV56 lies entirely inside PTV70 - SIB dose level not assessed")

# Registration
print("\n[Registration Check]")
if 'dose_higher_in_ptv' in checks:
    print(f"  Dose higher in PTV: {'✓ PASS' if checks['dose_higher_in_ptv'] else '✗ FAIL - CRITICAL'}")
    print(f"  Mean dose in PTV: {checks['dose_in_ptv']:.3f}")
    print(f"  Mean dose outside PTV: {checks['dose_outside_ptv']:.3f}")
else:
    print("  Not assessed (no PTV voxels, or PTV covers the whole grid)")

# Boundary truncation
print("\n[Boundary Truncation]")
if checks['boundary_warnings']:
    print("  ⚠ WARNINGS:")
    for w in checks['boundary_warnings']:
        print(f"    - {w}")
else:
    print("  ✓ No structures touch boundaries")

---
## 3. Structure Statistics

In [None]:
# Per-structure dose statistics for the selected case, printed as a table
stats = get_structure_stats(data['masks'], data['dose'])

banner = "=" * 80
divider = "-" * 80
header = (f"{'Structure':<12} {'Voxels':>10} {'Vol (cc)':>10} {'Mean':>8} "
          f"{'Min':>8} {'Max':>8} {'D95':>8} {'D5':>8}")

for line in (banner, "STRUCTURE STATISTICS", banner, header, divider):
    print(line)

for name, s in stats.items():
    if s['dose_mean'] is None:
        # Empty structure: only the placeholder is shown
        print(f"{name:<12} {'(empty)':>10}")
        continue
    print(f"{name:<12} {s['voxels']:>10,} {s['volume_cc']:>10.1f} "
          f"{s['dose_mean']:>8.3f} {s['dose_min']:>8.3f} {s['dose_max']:>8.3f} "
          f"{s['D95']:>8.3f} {s['D5']:>8.3f}")

print(divider)
print("Note: Dose values normalized to 70 Gy. Multiply by 70 for absolute Gy.")

---
## 4. Visual Inspection: Axial Slices

In [None]:
def plot_axial_slice(ct, dose, masks, slice_idx, show_structures=(0, 1, 3, 4)):
    """
    Plot a single axial slice with CT, dose overlay, and structure contours.

    Four panels: CT alone, dose alone (Gy, 0-75 colour scale), CT with a
    semi-transparent dose wash, and CT with structure contours.

    Args:
        ct: CT volume (Y, X, Z), shown on a fixed [0, 1] gray scale
        dose: Dose volume (Y, X, Z), normalized so 1.0 = 70 Gy
        masks: Mask volume (C, Y, X, Z)
        slice_idx: Z index to display
        show_structures: Iterable of channel indices to contour; channels
            beyond ``masks.shape[0]`` are skipped

    Returns:
        The matplotlib Figure.
    """
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    ct_slice = ct[:, :, slice_idx]
    dose_slice = dose[:, :, slice_idx]
    # Only channels that exist in this mask volume are contoured and listed.
    channels = [ch for ch in show_structures if ch < masks.shape[0]]

    # 1. CT only
    axes[0].imshow(ct_slice, cmap='gray', vmin=0, vmax=1)
    axes[0].set_title(f'CT (slice {slice_idx})')
    axes[0].axis('off')

    # 2. Dose only, converted to Gy
    im_dose = axes[1].imshow(dose_slice * 70, cmap='jet', vmin=0, vmax=75)
    axes[1].set_title('Dose (Gy)')
    axes[1].axis('off')
    plt.colorbar(im_dose, ax=axes[1], fraction=0.046, pad=0.04)

    # 3. CT + Dose overlay; normalized dose below 0.1 (7 Gy) is left transparent
    axes[2].imshow(ct_slice, cmap='gray', vmin=0, vmax=1)
    dose_masked = np.ma.masked_where(dose_slice < 0.1, dose_slice)
    axes[2].imshow(dose_masked * 70, cmap='jet', alpha=0.5, vmin=0, vmax=75)
    axes[2].set_title('CT + Dose Overlay')
    axes[2].axis('off')

    # 4. CT + Structure contours
    axes[3].imshow(ct_slice, cmap='gray', vmin=0, vmax=1)
    legend_elements = []
    for ch in channels:
        # Channels without a display entry fall back to white / a generic label
        # in both the contour and the legend, instead of a KeyError.
        color = STRUCTURE_COLORS.get(ch, (1, 1, 1, 0.5))[:3]
        mask_slice = masks[ch, :, :, slice_idx]
        if mask_slice.sum() > 0:
            axes[3].contour(mask_slice, levels=[0.5], colors=[color], linewidths=2)
        legend_elements.append(plt.Line2D([0], [0], color=color, linewidth=2,
                                          label=STRUCTURE_NAMES.get(ch, f'Channel {ch}')))
    if legend_elements:  # avoid drawing an empty legend box
        axes[3].legend(handles=legend_elements, loc='upper right', fontsize=8)
    axes[3].set_title('CT + Structures')
    axes[3].axis('off')

    plt.tight_layout()
    return fig

In [None]:
# Find slice with maximum PTV70 area (good for visualization)
ptv70_per_slice = (data['masks'][0] > 0).sum(axis=(0, 1))
if ptv70_per_slice.any():
    best_slice = int(np.argmax(ptv70_per_slice))
    print(f"Slice with max PTV70 area: {best_slice}")
else:
    # np.argmax of an all-zero profile returns 0 (the edge slice), which shows
    # nothing useful; fall back to the central slice instead.
    best_slice = data['ct'].shape[2] // 2
    print(f"PTV70 is empty - showing central slice {best_slice} instead")

# Plot that slice
fig = plot_axial_slice(data['ct'], data['dose'], data['masks'], best_slice,
                       show_structures=[0, 1, 3, 4])  # PTV70, PTV56, Rectum, Bladder
plt.show()

In [None]:
# Axial slices at the quarter points (25%, 50%, 75%) along Z
z_dim = data['ct'].shape[2]
slices_to_show = [quarter * z_dim // 4 for quarter in (1, 2, 3)]

for z_index in slices_to_show:
    fig = plot_axial_slice(data['ct'], data['dose'], data['masks'], z_index,
                           show_structures=[0, 1, 3, 4])
    plt.show()

---
## 5. Interactive Slice Viewer (if ipywidgets available)

In [None]:
if not HAS_WIDGETS:
    print("Interactive viewer requires ipywidgets. Install with: pip install ipywidgets")
else:
    def interactive_viewer(slice_idx):
        """Redraw the four-panel axial view for the chosen Z slice."""
        fig = plot_axial_slice(data['ct'], data['dose'], data['masks'], slice_idx,
                               show_structures=[0, 1, 2, 3, 4])
        plt.show()

    # Explicit call form of the @interact(...) decorator: displays the slider
    # widget bound to interactive_viewer.
    n_z = data['ct'].shape[2]
    interact(interactive_viewer,
             slice_idx=IntSlider(min=0, max=n_z - 1,
                                 value=n_z // 2,
                                 description='Z Slice:'))

---
## 6. Coronal and Sagittal Views

In [None]:
def plot_orthogonal_views(ct, dose, masks, title=""):
    """
    Plot axial, coronal, and sagittal views through the center of PTV70.

    Top row: CT with a semi-transparent dose wash (normalized dose < 0.1 is
    hidden). Bottom row: CT with PTV70/PTV56/Rectum/Bladder contours. Dashed
    yellow lines mark where the other two planes intersect each view.

    Args:
        ct: CT volume (Y, X, Z), shown on a fixed [0, 1] gray scale.
        dose: Dose volume (Y, X, Z), normalized so 1.0 = 70 Gy.
        masks: Mask volume (C, Y, X, Z); channel 0 is PTV70.
        title: Optional figure title.

    Returns:
        The matplotlib Figure.
    """
    # Integer-truncated centroid of PTV70; grid centre when PTV70 is empty
    ptv70 = masks[0]
    if ptv70.sum() > 0:
        coords = np.array(np.where(ptv70 > 0))
        center_y, center_x, center_z = coords.mean(axis=1).astype(int)
    else:
        center_y, center_x, center_z = [s // 2 for s in ct.shape]

    # PTV70, PTV56, Rectum, Bladder, restricted to channels present in `masks`
    structures_to_show = [ch for ch in (0, 1, 3, 4) if ch < masks.shape[0]]

    # Coronal/sagittal slices are transposed so Z runs along image rows.
    # origin='lower' draws Z=0 at the bottom, so the 'Z (Inf-Sup)' axis reads
    # bottom-to-top; the default origin='upper' flipped the patient vertically.
    # NOTE(review): assumes Z index increases inferior -> superior (as the axis
    # labels state); aspect=2 assumes 1x1x2 mm voxels.
    coronal_ct = ct[center_y, :, :].T
    coronal_dose = dose[center_y, :, :].T
    sagittal_ct = ct[:, center_x, :].T
    sagittal_dose = dose[:, center_x, :].T

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Row 1: CT with dose overlay
    # Axial (Z slice)
    axes[0, 0].imshow(ct[:, :, center_z], cmap='gray', vmin=0, vmax=1)
    dose_ax = np.ma.masked_where(dose[:, :, center_z] < 0.1, dose[:, :, center_z])
    axes[0, 0].imshow(dose_ax * 70, cmap='jet', alpha=0.5, vmin=0, vmax=75)
    axes[0, 0].axhline(center_y, color='yellow', linestyle='--', alpha=0.5)
    axes[0, 0].axvline(center_x, color='yellow', linestyle='--', alpha=0.5)
    axes[0, 0].set_title(f'Axial (Z={center_z})')
    axes[0, 0].set_xlabel('X (Left-Right)')
    axes[0, 0].set_ylabel('Y (Ant-Post)')

    # Coronal (Y slice)
    axes[0, 1].imshow(coronal_ct, cmap='gray', vmin=0, vmax=1, aspect=2, origin='lower')
    dose_cor = np.ma.masked_where(coronal_dose < 0.1, coronal_dose)
    axes[0, 1].imshow(dose_cor * 70, cmap='jet', alpha=0.5, vmin=0, vmax=75,
                      aspect=2, origin='lower')
    axes[0, 1].axhline(center_z, color='yellow', linestyle='--', alpha=0.5)
    axes[0, 1].axvline(center_x, color='yellow', linestyle='--', alpha=0.5)
    axes[0, 1].set_title(f'Coronal (Y={center_y})')
    axes[0, 1].set_xlabel('X (Left-Right)')
    axes[0, 1].set_ylabel('Z (Inf-Sup)')

    # Sagittal (X slice)
    axes[0, 2].imshow(sagittal_ct, cmap='gray', vmin=0, vmax=1, aspect=2, origin='lower')
    dose_sag = np.ma.masked_where(sagittal_dose < 0.1, sagittal_dose)
    axes[0, 2].imshow(dose_sag * 70, cmap='jet', alpha=0.5, vmin=0, vmax=75,
                      aspect=2, origin='lower')
    axes[0, 2].axhline(center_z, color='yellow', linestyle='--', alpha=0.5)
    axes[0, 2].axvline(center_y, color='yellow', linestyle='--', alpha=0.5)
    axes[0, 2].set_title(f'Sagittal (X={center_x})')
    axes[0, 2].set_xlabel('Y (Ant-Post)')
    axes[0, 2].set_ylabel('Z (Inf-Sup)')

    # Row 2: Structure contours (contour uses data coordinates, so it stays
    # aligned with imshow for either origin setting)
    # Axial contours
    axes[1, 0].imshow(ct[:, :, center_z], cmap='gray', vmin=0, vmax=1)
    for ch in structures_to_show:
        ax_mask = masks[ch, :, :, center_z]
        if ax_mask.sum() > 0:
            axes[1, 0].contour(ax_mask, levels=[0.5],
                               colors=[STRUCTURE_COLORS[ch][:3]], linewidths=2)
    axes[1, 0].set_title('Axial Structures')

    # Coronal contours
    axes[1, 1].imshow(coronal_ct, cmap='gray', vmin=0, vmax=1, aspect=2, origin='lower')
    for ch in structures_to_show:
        cor_mask = masks[ch, center_y, :, :].T
        if cor_mask.sum() > 0:
            axes[1, 1].contour(cor_mask, levels=[0.5],
                               colors=[STRUCTURE_COLORS[ch][:3]], linewidths=2)
    axes[1, 1].set_title('Coronal Structures')

    # Sagittal contours
    axes[1, 2].imshow(sagittal_ct, cmap='gray', vmin=0, vmax=1, aspect=2, origin='lower')
    for ch in structures_to_show:
        sag_mask = masks[ch, :, center_x, :].T
        if sag_mask.sum() > 0:
            axes[1, 2].contour(sag_mask, levels=[0.5],
                               colors=[STRUCTURE_COLORS[ch][:3]], linewidths=2)
    axes[1, 2].set_title('Sagittal Structures')

    # Add legend
    legend_elements = [plt.Line2D([0], [0], color=STRUCTURE_COLORS[ch][:3], linewidth=2,
                                  label=STRUCTURE_NAMES[ch]) for ch in structures_to_show]
    fig.legend(handles=legend_elements, loc='lower center', ncol=4, fontsize=10)

    if title:
        fig.suptitle(title, fontsize=14)

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.1)  # leave room for the figure-level legend
    return fig

# Orthogonal views through the PTV70 centre of the selected case
case_title = f"Case: {selected_file.stem}"
fig = plot_orthogonal_views(data['ct'], data['dose'], data['masks'], title=case_title)
plt.show()

---
## 7. Dose-Volume Histogram (DVH)

In [None]:
def compute_dvh(dose, mask, bins=100):
    """
    Compute cumulative DVH for a structure.

    Args:
        dose: Dose array (normalized; 1.0 = prescription).
        mask: Array of the same shape as ``dose``; voxels with mask > 0 belong
            to the structure.
        bins: Number of evenly spaced dose levels to sample.

    Returns:
        dose_bins: ``bins`` dose values (normalized) from 0 to
            max(structure max dose, 1.2).
        volume_percent: Percentage of structure voxels receiving at least
            the corresponding dose level.
        ``(None, None)`` when the mask selects no voxels.
    """
    in_struct = mask > 0
    if not in_struct.any():
        return None, None

    # Boolean indexing already yields a 1-D copy; sort it once so each level
    # can be located by binary search instead of rescanning every voxel.
    dose_in_struct = np.sort(dose[in_struct])
    n_voxels = dose_in_struct.size

    # Create dose bins from 0 to max dose, at least up to 120%
    max_dose = max(dose_in_struct[-1], 1.2)
    dose_bins = np.linspace(0, max_dose, bins)

    # First sorted index with dose >= level; every voxel from there on
    # receives at least that level (same count as (dose >= level).sum()).
    first_at_or_above = np.searchsorted(dose_in_struct, dose_bins, side='left')
    volume_percent = (n_voxels - first_at_or_above) / n_voxels * 100

    return dose_bins, volume_percent


def plot_dvh(dose, masks, title="Dose-Volume Histogram"):
    """
    Plot DVH for all structures.

    Args:
        dose: Dose volume (Y, X, Z), normalized so 1.0 = 70 Gy.
        masks: Mask volume (C, Y, X, Z); every channel listed in
            ``STRUCTURE_NAMES`` that exists and is non-empty gets a curve.
        title: Axes title.

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    hottest_gy = 0.0  # highest dose level (Gy) still reached by any structure

    for ch, name in STRUCTURE_NAMES.items():
        if ch >= masks.shape[0]:
            continue
        dose_bins, volume_percent = compute_dvh(dose, masks[ch])
        if dose_bins is None:
            continue
        dose_gy = dose_bins * 70  # Convert normalized dose to Gy
        ax.plot(dose_gy, volume_percent, label=name,
                color=STRUCTURE_COLORS[ch][:3], linewidth=2)
        reached = dose_gy[volume_percent > 0]
        if reached.size:
            hottest_gy = max(hottest_gy, float(reached[-1]))

    # Prescription lines. They must be drawn BEFORE legend(), otherwise their
    # labels are not included in the legend.
    ax.axvline(70, color='red', linestyle='--', alpha=0.5, label='Rx 70 Gy')
    ax.axvline(56, color='orange', linestyle='--', alpha=0.5, label='Rx 56 Gy')

    ax.set_xlabel('Dose (Gy)', fontsize=12)
    ax.set_ylabel('Volume (%)', fontsize=12)
    ax.set_title(title, fontsize=14)
    # Default 0-80 Gy window, widened (to the next multiple of 5 Gy) when a hot
    # spot lies beyond it, so overdoses stay visible.
    ax.set_xlim(0, max(80.0, float(np.ceil(hottest_gy / 5.0) * 5.0)))
    ax.set_ylim(0, 105)
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper right', fontsize=10)

    plt.tight_layout()
    return fig

# Cumulative DVH for the selected case
dvh_title = f"DVH: {selected_file.stem}"
fig = plot_dvh(data['dose'], data['masks'], title=dvh_title)
plt.show()

---
## 8. Dose Profile Through PTV

In [None]:
def plot_dose_profiles(ct, dose, masks):
    """
    Plot 1-D dose profiles (in Gy) along X, Y and Z through the PTV70 centre.

    Each panel shows the dose along one axis-aligned line, dashed 70 Gy and
    56 Gy reference levels, and a red fill where that line lies inside PTV70.
    Only the first (left-right) panel carries legend labels.

    Args:
        ct: CT volume (Y, X, Z); only its shape is used, as the fallback centre.
        dose: Dose volume (Y, X, Z), normalized so 1.0 = 70 Gy.
        masks: Mask volume (C, Y, X, Z); channel 0 is PTV70.

    Returns:
        The matplotlib Figure.
    """
    ptv70 = masks[0]
    if ptv70.sum() > 0:
        # Integer-truncated centroid of the PTV70 voxels
        center_y, center_x, center_z = np.array(np.where(ptv70 > 0)).mean(axis=1).astype(int)
    else:
        center_y, center_x, center_z = [extent // 2 for extent in ct.shape]

    # (dose along the line, PTV70 membership along the line, x-label, title)
    profiles = [
        (dose[center_y, :, center_z], masks[0, center_y, :, center_z], 'X (voxels)',
         f'Left-Right Profile (Y={center_y}, Z={center_z})'),
        (dose[:, center_x, center_z], masks[0, :, center_x, center_z], 'Y (voxels)',
         f'Ant-Post Profile (X={center_x}, Z={center_z})'),
        (dose[center_y, center_x, :], masks[0, center_y, center_x, :], 'Z (voxels)',
         f'Sup-Inf Profile (X={center_x}, Y={center_y})'),
    ]
    reference_levels = ((70, 'red', '70 Gy'), (56, 'orange', '56 Gy'))

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    for panel, (ax, (line_dose, line_mask, x_label, panel_title)) in enumerate(zip(axes, profiles)):
        labelled = panel == 0  # legend entries only on the first panel
        profile_gy = line_dose * 70

        ax.plot(profile_gy, 'b-', linewidth=2)
        for level, colour, text in reference_levels:
            ax.axhline(level, color=colour, linestyle='--', alpha=0.7,
                       **({'label': text} if labelled else {}))
        ax.fill_between(range(len(profile_gy)), 0, profile_gy,
                        where=line_mask > 0,
                        alpha=0.3, color='red',
                        **({'label': 'PTV70'} if labelled else {}))
        ax.set_xlabel(x_label)
        ax.set_ylabel('Dose (Gy)')
        ax.set_title(panel_title)
        if labelled:
            ax.legend()
        ax.set_ylim(0, 80)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig

# Line profiles through the PTV70 centre for the selected case
fig = plot_dose_profiles(ct=data['ct'], dose=data['dose'], masks=data['masks'])
plt.show()

---
## 9. Summary Report

In [None]:
def generate_summary_report(filepath, data, checks, stats):
    """
    Generate a text summary report for a case.

    Args:
        filepath: Path of the case file; only ``.name`` is used.
        data: Loaded case dict. Currently unused; kept for interface stability.
        checks: Check dict as produced by ``validate_case``.
        stats: Per-structure dict as produced by ``get_structure_stats``.

    Returns:
        The report as one newline-joined string. Doses are converted to Gy
        (x70).
    """
    rule = "=" * 70
    report = [rule, f"VERIFICATION REPORT: {filepath.name}", rule]

    # Overall status. A check fails only when it was computed and is falsy;
    # absent or None means "not assessed", not "failed".
    critical_checks = ['ct_range_valid', 'dose_nonneg', 'ptv70_exists',
                       'ptv70_dose_adequate', 'dose_higher_in_ptv']
    failures = [c for c in critical_checks
                if checks.get(c) is not None and not checks[c]]

    if failures:
        report.append("\n⚠️  STATUS: NEEDS REVIEW")
        report.append(f"   Failed checks: {', '.join(failures)}")
    else:
        report.append("\n✓  STATUS: PASSED")

    # Case type
    ptv56_exists = stats.get('PTV56', {}).get('voxels', 0) > 0
    case_type = "SIB (PTV70 + PTV56)" if ptv56_exists else "Single Rx (PTV70 only)"
    report.append(f"\n   Case type: {case_type}")

    # Key metrics. Compare against None: a mean dose of 0.0 (e.g. a blank dose
    # grid) is exactly what a reviewer needs to see, so it must not be skipped.
    report.append("\n--- Key Metrics ---")
    ptv70 = stats.get('PTV70', {})
    if ptv70.get('dose_mean') is not None:
        report.append(f"   PTV70 mean dose: {ptv70['dose_mean']*70:.1f} Gy "
                      f"(D95={ptv70['D95']*70:.1f}, D5={ptv70['D5']*70:.1f})")
    ptv56 = stats.get('PTV56', {})
    if ptv56_exists and ptv56.get('dose_mean') is not None:
        report.append(f"   PTV56 mean dose: {ptv56['dose_mean']*70:.1f} Gy")
    rectum = stats.get('Rectum', {})
    if rectum.get('dose_mean') is not None:
        report.append(f"   Rectum mean dose: {rectum['dose_mean']*70:.1f} Gy "
                      f"(max={rectum['dose_max']*70:.1f})")
    bladder = stats.get('Bladder', {})
    if bladder.get('dose_mean') is not None:
        report.append(f"   Bladder mean dose: {bladder['dose_mean']*70:.1f} Gy "
                      f"(max={bladder['dose_max']*70:.1f})")

    # Warnings
    if checks.get('boundary_warnings'):
        report.append("\n--- Warnings ---")
        for w in checks['boundary_warnings']:
            report.append(f"   ⚠ {w}")

    report.append("\n" + rule)

    return "\n".join(report)

# Build the text report for the selected case and display it
report = generate_summary_report(filepath=selected_file, data=data,
                                 checks=checks, stats=stats)
print(report)

---
## 10. Quick Batch Check (Optional)

Run basic validation on all files to identify cases needing manual review.

In [None]:
# Set to True to run batch validation on all files
RUN_BATCH = False

if RUN_BATCH and npz_files:
    print(f"Running batch validation on {len(npz_files)} files...\n")

    results = []
    for npz_file in npz_files:
        try:
            # The context manager closes each archive once its arrays are read,
            # rather than leaving handles open across many files.
            with np.load(npz_file) as d:
                data_tmp = {key: d[key] for key in d.files}
            checks_tmp = validate_case(data_tmp)

            # Determine pass/fail with the same critical checks that
            # generate_summary_report uses, so a case the per-case report
            # flags (e.g. PTV70 mean dose outside 0.9-1.1) is not PASS here.
            critical_pass = all([
                checks_tmp.get('ct_range_valid', False),
                checks_tmp.get('dose_nonneg', False),
                checks_tmp.get('ptv70_exists', False),
                checks_tmp.get('ptv70_dose_adequate', False),
                checks_tmp.get('dose_higher_in_ptv', True),  # Default True if not computed
            ])

            results.append({
                'file': npz_file.name,
                'status': '✓ PASS' if critical_pass else '✗ FAIL',
                'ptv70_dose': checks_tmp.get('ptv70_dose_mean'),
                'warnings': len(checks_tmp.get('boundary_warnings', [])),
            })
        except Exception as e:  # best-effort: one unreadable file must not abort the batch
            results.append({
                'file': npz_file.name,
                'status': f'✗ ERROR: {str(e)[:30]}',
                'ptv70_dose': None,
                'warnings': 0,
            })

    # Print summary table
    print(f"{'File':<30} {'Status':<15} {'PTV70 Dose':>12} {'Warnings':>10}")
    print("-" * 70)
    for r in results:
        # `is not None` so a (highly suspicious) mean dose of 0.0 is shown, not "N/A"
        dose_str = f"{r['ptv70_dose']:.3f}" if r['ptv70_dose'] is not None else "N/A"
        print(f"{r['file']:<30} {r['status']:<15} {dose_str:>12} {r['warnings']:>10}")

    # Summary
    passed = sum(1 for r in results if '✓' in r['status'])
    print(f"\n{'='*70}")
    print(f"BATCH SUMMARY: {passed}/{len(results)} cases passed")
else:
    print("Set RUN_BATCH = True to run batch validation")