# VMAT Diffusion: NPZ Verification Notebook (v2.1)

This notebook provides visual and quantitative verification of preprocessed `.npz` files for the VMAT diffusion model.

## Purpose
- Visual inspection of CT, dose, and mask alignment
- SDF (Signed Distance Field) visualization and validation
- Beam geometry review (Phase 2 prep)
- Quantitative validation of dose distributions
- Truncation analysis for boundary issues
- Detection of preprocessing errors before training

## Expected .npz Structure (v2.1)
| Key | Shape | Description |
|-----|-------|-------------|
| `ct` | (512, 512, 256) | CT volume, normalized [0, 1] |
| `dose` | (512, 512, 256) | Dose volume, normalized to the prescription dose (1.0 = Rx) |
| `masks` | (8, 512, 512, 256) | Binary masks for structures |
| `masks_sdf` | (8, 512, 512, 256) | Signed distance fields [-1, 1] |
| `constraints` | (13,) | Prescription targets + OAR constraints |
| `metadata` | dict | Case info, beam geometry, validation |

## Mask Channel Mapping
| Channel | Structure |
|---------|----------|
| 0 | PTV70 |
| 1 | PTV56 |
| 2 | Prostate |
| 3 | Rectum |
| 4 | Bladder |
| 5 | Femur_L |
| 6 | Femur_R |
| 7 | Bowel |

In [None]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from pathlib import Path
import warnings
import json

# For interactive slice viewing (optional)
# ipywidgets is treated as an optional dependency: a failed import is reported
# with a message instead of stopping the notebook, and HAS_WIDGETS records the
# outcome.
# NOTE(review): HAS_WIDGETS, interact and IntSlider are not referenced in the
# visible part of this notebook - confirm they are still needed.
try:
    from ipywidgets import interact, IntSlider
    HAS_WIDGETS = True
except ImportError:
    HAS_WIDGETS = False
    print("ipywidgets not available - interactive sliders disabled")

%matplotlib inline
# Notebook-wide matplotlib defaults: 16x10 inch figures rendered at 100 DPI.
# Functions that pass an explicit ``figsize`` override the size per figure.
plt.rcParams['figure.figsize'] = [16, 10]
plt.rcParams['figure.dpi'] = 100

In [None]:
# Configuration - EDIT THESE PATHS
# NPZ_DIR is the directory that is globbed for ``*.npz`` cases further down.
NPZ_DIR = Path("~/vmat-diffusion-project/processed_npz").expanduser()
# Or specify a single file directly:
# NPZ_FILE = Path("~/vmat-diffusion-project/processed_npz/case_0001.npz").expanduser()

# Structure names for display
# Maps mask channel index -> structure label; the helper functions iterate this
# dict and skip channels that are not present in the loaded ``masks`` array.
STRUCTURE_NAMES = {
    0: 'PTV70',
    1: 'PTV56',
    2: 'Prostate',
    3: 'Rectum',
    4: 'Bladder',
    5: 'Femur_L',
    6: 'Femur_R',
    7: 'Bowel'
}

# Expected prescription doses for SIB validation
# Values are fractions of the 70 Gy primary prescription.
# NOTE(review): EXPECTED_DOSES is not read anywhere in the visible part of this
# notebook - the validation thresholds are hard-coded in validate_case().
EXPECTED_DOSES = {
    'PTV70': 1.0,      # 70/70 = 1.0
    'PTV56': 0.8,      # 56/70 = 0.8
    'PTV50.4': 0.72    # 50.4/70 = 0.72 (if you add this channel later)
}

# Colors for mask overlays (RGBA)
# Keyed by the same channel indices as STRUCTURE_NAMES. The plotting code
# slices off the alpha component ([:3]) and passes only RGB to matplotlib.
STRUCTURE_COLORS = {
    0: (1.0, 0.0, 0.0, 0.5),   # PTV70 - Red
    1: (1.0, 0.5, 0.0, 0.5),   # PTV56 - Orange
    2: (0.0, 1.0, 0.0, 0.4),   # Prostate - Green
    3: (0.6, 0.3, 0.0, 0.5),   # Rectum - Brown
    4: (1.0, 1.0, 0.0, 0.5),   # Bladder - Yellow
    5: (0.0, 0.0, 1.0, 0.4),   # Femur_L - Blue
    6: (0.0, 0.5, 1.0, 0.4),   # Femur_R - Light Blue
    7: (1.0, 0.0, 1.0, 0.4),   # Bowel - Magenta
}

In [None]:
# Helper functions

def load_npz(filepath):
    """Load a preprocessed ``.npz`` case into a plain dict and print a summary.

    Args:
        filepath: Location of the ``.npz`` file (``str`` or ``pathlib.Path``).

    Returns:
        dict mapping every key stored in the archive to its loaded array.
        ``metadata`` is returned as stored (a 0-d object array); use
        ``get_metadata`` to unwrap it.

    Warns:
        UserWarning: once per missing v2 key (``ct``, ``dose``, ``masks``,
        ``constraints``, ``metadata``). Missing keys do not raise, so an
        incomplete file can still be inspected.
    """
    filepath = Path(filepath)  # also accept plain string paths

    # SECURITY: allow_pickle=True is needed to read the pickled ``metadata``
    # dict, but unpickling can execute arbitrary code - only open trusted files.
    # The context manager closes the underlying zip handle once every array has
    # been read into memory.
    with np.load(filepath, allow_pickle=True) as archive:
        result = {key: archive[key] for key in archive.files}

    # Check for expected keys
    v2_keys = ['ct', 'dose', 'masks', 'constraints', 'metadata']
    for key in v2_keys:
        if key not in result:
            warnings.warn(f"Missing key: {key}")

    # Detect version: signed distance fields were introduced in v2.1
    has_sdf = 'masks_sdf' in result
    version = 'v2.1' if has_sdf else 'v2.0 or earlier'

    empty = np.array([])  # placeholder so missing arrays print as shape (0,)
    print(f"Loaded: {filepath.name}")
    print(f"  Detected version: {version}")
    print(f"  CT shape: {result.get('ct', empty).shape}")
    print(f"  Dose shape: {result.get('dose', empty).shape}")
    print(f"  Masks shape: {result.get('masks', empty).shape}")
    if has_sdf:
        print(f"  Masks SDF shape: {result.get('masks_sdf', empty).shape}")
    print(f"  Constraints shape: {result.get('constraints', empty).shape}")

    # Extract metadata (stored as a 0-d object array wrapping a dict)
    if 'metadata' in result:
        metadata = result['metadata']
        if isinstance(metadata, np.ndarray):
            metadata = metadata.item()
        # ``or {}`` tolerates a case_type that was stored as None
        case_type = metadata.get('case_type') or {}
        print(f"  Case type: {case_type.get('type', 'unknown')}")
        print(f"  Script version: {metadata.get('script_version', 'unknown')}")

    return result


def get_metadata(data):
    """Return the case metadata as a dict, or ``{}`` when none is stored.

    ``np.load`` hands the pickled dict back wrapped in a 0-d object array, so
    array values are unwrapped with ``.item()``; anything else is returned
    unchanged.
    """
    raw = data['metadata'] if 'metadata' in data else {}
    if isinstance(raw, np.ndarray):
        return raw.item()
    return raw


def get_structure_stats(masks, dose, spacing_mm=(1.0, 1.0, 2.0), structure_names=None):
    """Calculate dose statistics for each structure.

    Args:
        masks: Array of per-structure masks indexed as ``masks[channel]``;
            voxels with a value > 0 belong to the structure.
        dose: Dose volume with the same spatial shape as one mask channel.
        spacing_mm: Voxel edge lengths in millimetres, used for ``volume_cc``.
            Defaults to the previously hard-coded 1 x 1 x 2 mm grid.
        structure_names: Optional ``{channel: name}`` mapping. Defaults to the
            notebook-level ``STRUCTURE_NAMES``.

    Returns:
        dict keyed by structure name. Non-empty structures carry ``voxels``,
        ``volume_cc``, ``dose_mean``, ``dose_min``, ``dose_max``, ``dose_std``,
        ``D95`` and ``D5``. Empty structures carry only ``voxels`` (0),
        ``volume_cc`` (0.0) and ``dose_mean`` (None), so callers should test
        ``dose_mean is not None`` before reading the other keys. Channels that
        are not present in ``masks`` are omitted.
    """
    names = STRUCTURE_NAMES if structure_names is None else structure_names
    voxel_volume_mm3 = float(np.prod(spacing_mm))
    n_channels = masks.shape[0]

    stats = {}
    for ch, name in names.items():
        if ch >= n_channels:
            continue

        inside = masks[ch] > 0
        voxel_count = int(np.count_nonzero(inside))
        if voxel_count == 0:
            stats[name] = {'voxels': 0, 'volume_cc': 0.0, 'dose_mean': None}
            continue

        dose_in_struct = dose[inside]
        # D95 is the dose covering 95% of the volume, i.e. the 5th percentile
        # of the voxel doses; D5 is correspondingly the 95th percentile.
        d95, d5 = np.percentile(dose_in_struct, [5, 95])
        stats[name] = {
            'voxels': voxel_count,
            'volume_cc': float(voxel_count * voxel_volume_mm3 / 1000),  # mm^3 -> cc
            'dose_mean': float(dose_in_struct.mean()),
            'dose_min': float(dose_in_struct.min()),
            'dose_max': float(dose_in_struct.max()),
            'dose_std': float(dose_in_struct.std()),
            'D95': float(d95),
            'D5': float(d5),
        }
    return stats


def check_boundary_truncation(masks):
    """Report structures whose mask reaches any face of the 3D grid.

    A structure touching a face may have been cut off by the cropped grid.

    Args:
        masks: Per-structure masks indexed as ``masks[channel]``, each a 3D
            volume.

    Returns:
        list of strings of the form ``"<name>: touches <face>, <face>"``, one
        per structure that has voxels on at least one face. Empty structures
        and channels missing from ``masks`` are skipped.
    """
    whole = slice(None)
    # (label, index selecting that face of a 3D volume)
    # NOTE(review): the anatomical labels assume a (Y, X, Z) axis order with
    # patient orientation as named - confirm against the preprocessing script.
    face_indices = (
        ('Y-min (anterior)', (0, whole, whole)),
        ('Y-max (posterior)', (-1, whole, whole)),
        ('X-min (right)', (whole, 0, whole)),
        ('X-max (left)', (whole, -1, whole)),
        ('Z-min (inferior)', (whole, whole, 0)),
        ('Z-max (superior)', (whole, whole, -1)),
    )

    n_channels = masks.shape[0]
    truncation_warnings = []
    for ch, name in STRUCTURE_NAMES.items():
        if ch >= n_channels:
            continue
        volume = masks[ch]
        if volume.sum() == 0:
            continue

        touching = [label for label, index in face_indices if volume[index].sum() > 0]
        if touching:
            truncation_warnings.append(f"{name}: touches {', '.join(touching)}")

    return truncation_warnings


def validate_case(data):
    """Run validation checks on a single case.

    Args:
        data: Mapping with at least ``ct``, ``dose`` and ``masks`` arrays.
            ``masks_sdf`` and ``metadata`` are optional.

    Returns:
        dict of check results. Always present: ``has_sdf``,
        ``has_beam_geometry``, ``script_version``, ``ct_range_valid``,
        ``ct_min``, ``ct_max``, ``dose_nonneg``, ``dose_min``, ``dose_max``,
        ``dose_reasonable``, ``ptv70_exists``, ``ptv70_dose_mean``,
        ``ptv70_dose_adequate``, ``ptv56_exists`` and ``boundary_warnings``.
        Conditionally present: ``ptv56_only_dose_mean`` and
        ``ptv56_dose_appropriate`` (when PTV56 exists), ``dose_higher_in_ptv``,
        ``dose_in_ptv`` and ``dose_outside_ptv`` (when voxels exist both inside
        and outside a PTV), ``sdf_range_valid``, ``sdf_min`` and ``sdf_max``
        (when SDFs are stored), and ``truncation_info`` (when stored in the
        metadata). Flags are plain ``bool`` and numbers plain ``float``.

    Raises:
        KeyError: if ``ct``, ``dose`` or ``masks`` is missing from ``data``.
    """
    ct = data['ct']
    dose = data['dose']
    masks = data['masks']
    metadata = get_metadata(data)
    n_channels = masks.shape[0]

    def _channel_mask(ch):
        """Boolean mask of channel ``ch``; all False when the channel is absent."""
        if ch < n_channels:
            return masks[ch] > 0
        return np.zeros(masks.shape[1:], dtype=bool)

    checks = {}

    # Version detection
    checks['has_sdf'] = 'masks_sdf' in data
    checks['has_beam_geometry'] = metadata.get('beam_geometry') is not None
    checks['script_version'] = metadata.get('script_version', 'unknown')

    # CT checks (each full-volume reduction is computed only once)
    ct_min, ct_max = float(ct.min()), float(ct.max())
    checks['ct_range_valid'] = bool(ct_min >= 0 and ct_max <= 1)
    checks['ct_min'] = ct_min
    checks['ct_max'] = ct_max

    # Dose checks
    dose_min, dose_max = float(dose.min()), float(dose.max())
    checks['dose_nonneg'] = bool(dose_min >= -0.01)  # Allow tiny negative from interpolation
    checks['dose_min'] = dose_min
    checks['dose_max'] = dose_max
    checks['dose_reasonable'] = bool(dose_max < 1.5)  # Max < 150% prescription

    # PTV70 checks
    ptv70_mask = _channel_mask(0)
    checks['ptv70_exists'] = bool(ptv70_mask.any())
    if checks['ptv70_exists']:
        ptv70_dose_mean = float(dose[ptv70_mask].mean())
        checks['ptv70_dose_mean'] = ptv70_dose_mean
        checks['ptv70_dose_adequate'] = bool(0.85 < ptv70_dose_mean < 1.15)
    else:
        checks['ptv70_dose_mean'] = None
        checks['ptv70_dose_adequate'] = False

    # PTV56 checks (if exists)
    ptv56_mask = _channel_mask(1)
    checks['ptv56_exists'] = bool(ptv56_mask.any())
    if checks['ptv56_exists']:
        # Judge the lower dose level only where it is not overlapped by PTV70
        ptv56_only = ptv56_mask & ~ptv70_mask
        if ptv56_only.any():
            ptv56_only_dose = float(dose[ptv56_only].mean())
            checks['ptv56_only_dose_mean'] = ptv56_only_dose
            checks['ptv56_dose_appropriate'] = bool(0.70 < ptv56_only_dose < 0.95)
        else:
            checks['ptv56_only_dose_mean'] = None
            checks['ptv56_dose_appropriate'] = None

    # Registration check: dose should be higher inside the targets than outside
    any_ptv = ptv70_mask | ptv56_mask
    n_inside = int(np.count_nonzero(any_ptv))
    if 0 < n_inside < any_ptv.size:
        dose_in_ptv = float(dose[any_ptv].mean())
        dose_outside_ptv = float(dose[~any_ptv].mean())
        checks['dose_higher_in_ptv'] = bool(dose_in_ptv > dose_outside_ptv)
        checks['dose_in_ptv'] = dose_in_ptv
        checks['dose_outside_ptv'] = dose_outside_ptv

    # SDF checks (v2.1)
    if checks['has_sdf']:
        sdf = data['masks_sdf']
        sdf_min, sdf_max = float(sdf.min()), float(sdf.max())
        checks['sdf_range_valid'] = bool(sdf_min >= -1.0 and sdf_max <= 1.0)
        checks['sdf_min'] = sdf_min
        checks['sdf_max'] = sdf_max

    # Boundary truncation
    checks['boundary_warnings'] = check_boundary_truncation(masks)

    # Use stored truncation info if available (more detailed)
    if 'truncation_info' in metadata:
        checks['truncation_info'] = metadata['truncation_info']

    return checks

---
## 1. Load and Inspect a Single Case

In [None]:
# List available .npz files (sorted, so CASE_INDEX refers to a reproducible order)
npz_files = sorted(NPZ_DIR.glob("*.npz"))
print(f"Found {len(npz_files)} .npz files in {NPZ_DIR}")

# Preview the first ten names and summarise the remainder in one line
preview, remainder = npz_files[:10], npz_files[10:]
for i, f in enumerate(preview):
    print(f"  [{i}] {f.name}")
if remainder:
    print(f"  ... and {len(remainder)} more")

In [None]:
# Select a case to inspect (change index as needed)
CASE_INDEX = 0  # <-- CHANGE THIS to inspect different cases

# Fail early with a clear message when the directory holds no cases
if not npz_files:
    raise FileNotFoundError(f"No .npz files found in {NPZ_DIR}")

selected_file = npz_files[CASE_INDEX]
data = load_npz(selected_file)

---
## 2. Metadata and Beam Geometry (v2.1)

In [None]:
# Display metadata
metadata = get_metadata(data)

banner = "=" * 60
print(banner)
print("CASE METADATA")
print(banner)

print("\n[Basic Info]")
for label, key in (
    ("Case ID", "case_id"),
    ("Processed", "processed_date"),
    ("Script version", "script_version"),
):
    print(f"  {label}: {metadata.get(key, 'N/A')}")

print("\n[Case Type]")
case_type = metadata.get('case_type', {})
print(f"  Type: {case_type.get('type', 'N/A')}")
for label, key in (
    ("PTV70", "ptv70_exists"),
    ("PTV56", "ptv56_exists"),
    ("PTV50.4", "ptv50_exists"),
):
    present = 'Yes' if case_type.get(key) else 'No'
    print(f"  {label}: {present}")

print("\n[Prescription]")
rx_info = metadata.get('prescription_info', {})
print(f"  Primary dose: {rx_info.get('primary_dose', 'N/A')} Gy")
print(f"  Fractions: {rx_info.get('fractions', 'N/A')}")
# Only apply numeric formatting when a (non-zero) value is actually stored
dose_per_fraction = rx_info.get('dose_per_fraction')
if dose_per_fraction:
    print(f"  Dose/fraction: {dose_per_fraction:.2f} Gy")
else:
    print("  Dose/fraction: N/A")
print(f"  Source: {rx_info.get('source', 'N/A')}")

print("\n[Grid]")
print(f"  Shape: {metadata.get('target_shape', 'N/A')}")
print(f"  Spacing: {metadata.get('target_spacing_mm', 'N/A')} mm")
print(f"  SDF clip: {metadata.get('sdf_clip_mm', 'N/A')} mm")

In [None]:
# Display beam geometry (Phase 2 prep)
beam_geometry = metadata.get('beam_geometry')


def _fmt_beam_number(value):
    """Format a beam parameter for the table.

    Real numbers (including NumPy scalars) are shown with one decimal, a
    missing value (``None``) as 'N/A', and anything else via ``str``.
    """
    if value is None:
        return 'N/A'
    if isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(value, bool):
        return f"{value:.1f}"
    return str(value)


print("=" * 60)
print("BEAM GEOMETRY (Phase 2 Data)")
print("=" * 60)

if beam_geometry:
    print("\n[Plan Summary]")
    print(f"  Plan label: {beam_geometry.get('plan_label', 'N/A')}")
    print(f"  Number of beams: {beam_geometry.get('num_beams', 'N/A')}")
    print(f"  Total MU: {_fmt_beam_number(beam_geometry.get('total_mu'))}")

    beams = beam_geometry.get('beams', [])
    if beams:
        print("\n[Beam Details]")
        print(f"  {'Beam':<15} {'Start°':>8} {'Stop°':>8} {'Direction':<12} {'Coll°':>8} {'MU':>10} {'CPs':>6}")
        print("  " + "-" * 75)
        for i, beam in enumerate(beams):
            # Stored values may be None, so every field gets an explicit fallback
            name = beam.get('beam_name') or f'Beam {i+1}'
            direction = beam.get('arc_direction') or 'N/A'
            start_str = _fmt_beam_number(beam.get('arc_start_angle'))
            stop_str = _fmt_beam_number(beam.get('arc_stop_angle'))
            coll_str = _fmt_beam_number(beam.get('collimator_angle'))
            mu_str = _fmt_beam_number(beam.get('final_mu'))
            # Converted to text first: applying ':>6' to None raises TypeError
            cps = beam.get('num_control_points')
            cps_str = 'N/A' if cps is None else str(cps)

            print(f"  {str(name):<15} {start_str:>8} {stop_str:>8} {str(direction):<12} "
                  f"{coll_str:>8} {mu_str:>10} {cps_str:>6}")
else:
    print("\n  No beam geometry available.")
    print("  (Extracted from RP file in v2.1+ preprocessing)")

---
## 3. Validation Checks

In [None]:
# Run validation
checks = validate_case(data)


def _pass_fail(ok, fail_text='✗ FAIL'):
    """Return the PASS marker for a truthy check, otherwise ``fail_text``."""
    return '✓ PASS' if ok else fail_text


banner = "=" * 60
print(banner)
print("VALIDATION RESULTS")
print(banner)

# Version info
sdf_flag = '✓ Yes' if checks['has_sdf'] else '✗ No (v2.0 or earlier)'
beam_flag = '✓ Yes' if checks['has_beam_geometry'] else '✗ No'
print("\n[Version Detection]")
print(f"  Script version: {checks['script_version']}")
print(f"  Has SDFs: {sdf_flag}")
print(f"  Has beam geometry: {beam_flag}")

# CT
print("\n[CT Volume]")
print(f"  Range valid [0,1]: {_pass_fail(checks['ct_range_valid'])}")
print(f"  Actual range: [{checks['ct_min']:.4f}, {checks['ct_max']:.4f}]")

# Dose (normalized values are scaled back to Gy with the stored normalization dose)
rx_dose = metadata.get('normalization_dose_gy', 70.0)
dose_min_gy = checks['dose_min'] * rx_dose
dose_max_gy = checks['dose_max'] * rx_dose
print("\n[Dose Volume]")
print(f"  Non-negative: {_pass_fail(checks['dose_nonneg'])}")
print(f"  Max < 150%: {_pass_fail(checks['dose_reasonable'])}")
print(f"  Actual range: [{checks['dose_min']:.4f}, {checks['dose_max']:.4f}]")
print(f"  (In Gy: [{dose_min_gy:.1f}, {dose_max_gy:.1f}])")

# SDFs (v2.1)
if checks['has_sdf']:
    print("\n[Signed Distance Fields]")
    print(f"  Range valid [-1,1]: {_pass_fail(checks['sdf_range_valid'])}")
    print(f"  Actual range: [{checks['sdf_min']:.4f}, {checks['sdf_max']:.4f}]")

# PTV70
print("\n[PTV70]")
print(f"  Exists: {_pass_fail(checks['ptv70_exists'])}")
ptv70_mean = checks['ptv70_dose_mean']
if ptv70_mean is not None:
    print(f"  Mean dose: {ptv70_mean:.3f} (target: ~1.0)")
    print(f"  Mean dose (Gy): {ptv70_mean*rx_dose:.1f} Gy")
    print(f"  Dose adequate [0.85-1.15]: {_pass_fail(checks['ptv70_dose_adequate'])}")

# PTV56
ptv56_flag = '✓ YES' if checks['ptv56_exists'] else '- NO (single Rx case)'
print("\n[PTV56]")
print(f"  Exists: {ptv56_flag}")
ptv56_mean = checks.get('ptv56_only_dose_mean')
if ptv56_mean is not None:
    print(f"  Mean dose (PTV56 only, excl PTV70): {ptv56_mean:.3f} (target: ~0.8)")
    print(f"  Mean dose (Gy): {ptv56_mean*rx_dose:.1f} Gy")

# Registration
print("\n[Registration Check]")
if 'dose_higher_in_ptv' in checks:
    print(f"  Dose higher in PTV: {_pass_fail(checks['dose_higher_in_ptv'], '✗ FAIL - CRITICAL')}")
    print(f"  Mean dose in PTV: {checks['dose_in_ptv']:.3f}")
    print(f"  Mean dose outside PTV: {checks['dose_outside_ptv']:.3f}")

# Boundary truncation
print("\n[Boundary Truncation]")
if not checks['boundary_warnings']:
    print("  ✓ No structures touch boundaries")
else:
    print("  ⚠ WARNINGS:")
    for w in checks['boundary_warnings']:
        print(f"    - {w}")

In [None]:
# Detailed truncation info (v2.1)
truncation_info = checks.get('truncation_info', {})

if not truncation_info:
    print("\nDetailed truncation info not available (v2.1+ feature)")
else:
    # Structures for which more than 5% truncation is flagged with a marker
    critical_structures = ('PTV70', 'PTV56', 'Rectum', 'Bladder')

    print("=" * 60)
    print("DETAILED TRUNCATION ANALYSIS (v2.1)")
    print("=" * 60)
    print(f"\n{'Structure':<12} {'Exists':>8} {'Truncated':>10} {'% at Boundary':>15} {'Boundaries Touched'}")
    print("-" * 75)

    for name, info in truncation_info.items():
        # Entries that are not per-structure dicts are skipped
        if not isinstance(info, dict):
            continue

        exists = 'Yes' if info.get('exists') else 'No'
        truncated = 'Yes' if info.get('truncated') else 'No'
        pct = info.get('truncation_percent', 0)
        boundaries = ', '.join(info.get('touching_boundaries', [])) or 'None'

        flagged = pct > 5 and name in critical_structures
        warning = ' ⚠' if flagged else ''
        print(f"{name:<12} {exists:>8} {truncated:>10} {pct:>14.1f}% {boundaries}{warning}")

---
## 4. Structure Statistics

In [None]:
# Calculate and display structure statistics
stats = get_structure_stats(data['masks'], data['dose'])

double_rule = "=" * 80
single_rule = "-" * 80

print(double_rule)
print("STRUCTURE STATISTICS")
print(double_rule)
print(f"{'Structure':<12} {'Voxels':>10} {'Vol (cc)':>10} {'Mean':>8} {'Min':>8} {'Max':>8} {'D95':>8} {'D5':>8}")
print(single_rule)

# Dose columns in table order; empty structures only carry 'dose_mean' (None)
dose_columns = ('dose_mean', 'dose_min', 'dose_max', 'D95', 'D5')
for name, s in stats.items():
    if s['dose_mean'] is None:
        print(f"{name:<12} {'(empty)':>10}")
        continue
    dose_cells = ' '.join(f"{s[col]:>8.3f}" for col in dose_columns)
    print(f"{name:<12} {s['voxels']:>10,} {s['volume_cc']:>10.1f} {dose_cells}")

print(single_rule)
rx_dose = metadata.get('normalization_dose_gy', 70.0)
print(f"Note: Dose values normalized to {rx_dose:.0f} Gy. Multiply by {rx_dose:.0f} for absolute Gy.")

---
## 5. SDF Visualization (v2.1)

In [None]:
def plot_sdf_comparison(ct, masks_binary, masks_sdf, slice_idx, structures=(0, 3)):
    """
    Compare binary masks vs SDFs for selected structures.

    The figure has two rows: the top row overlays each binary mask on the CT,
    the bottom row shows the matching SDF with the binary outline dashed on
    top. The first column of both rows is the plain CT for reference.

    Args:
        ct: CT volume
        masks_binary: Binary masks (C, Y, X, Z)
        masks_sdf: SDF masks (C, Y, X, Z)
        slice_idx: Z slice to display
        structures: Sequence of channel indices to show

    Returns:
        The matplotlib Figure.
    """
    structures = list(structures)
    n_cols = len(structures) + 1
    # squeeze=False keeps ``axes`` 2-D even when ``structures`` is empty
    fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)

    ct_slice = ct[:, :, slice_idx]

    # CT reference (left column, both rows)
    for row in range(2):
        axes[row, 0].imshow(ct_slice, cmap='gray', vmin=0, vmax=1)
        axes[row, 0].set_title('CT')
        axes[row, 0].axis('off')

    # Each structure: binary (top) vs SDF (bottom)
    for col, ch in enumerate(structures, start=1):
        name = STRUCTURE_NAMES.get(ch, f'Ch {ch}')
        # Same white fallback as plot_axial_slice for channels without a colour
        color = STRUCTURE_COLORS.get(ch, (1.0, 1.0, 1.0, 0.5))[:3]
        mask_slice = masks_binary[ch, :, :, slice_idx]
        # Contouring a slice without the structure finds no 0.5 level, so both
        # rows skip the outline in that case.
        has_structure = mask_slice.sum() > 0

        # Binary mask
        ax_binary = axes[0, col]
        ax_binary.imshow(ct_slice, cmap='gray', vmin=0, vmax=1)
        if has_structure:
            ax_binary.contour(mask_slice, levels=[0.5], colors=[color], linewidths=2)
            ax_binary.contourf(mask_slice, levels=[0.5, 1], colors=[color], alpha=0.3)
        ax_binary.set_title(f'{name} - Binary')
        ax_binary.axis('off')

        # SDF
        ax_sdf = axes[1, col]
        im = ax_sdf.imshow(masks_sdf[ch, :, :, slice_idx], cmap='RdBu', vmin=-1, vmax=1)
        if has_structure:
            ax_sdf.contour(mask_slice, levels=[0.5],
                           colors='black', linewidths=1, linestyles='--')
        ax_sdf.set_title(f'{name} - SDF')
        ax_sdf.axis('off')
        plt.colorbar(im, ax=ax_sdf, fraction=0.046, pad=0.04)

    plt.suptitle(f'Binary Masks vs Signed Distance Fields (slice {slice_idx})', fontsize=14)
    plt.tight_layout()
    return fig


# Check if SDFs are available
if 'masks_sdf' not in data:
    print("SDFs not available in this file (v2.1+ feature).")
    print("Re-run preprocessing with preprocess_dicom_rt_v2.1.py to generate SDFs.")
else:
    # Pick the axial slice where PTV70 (channel 0) covers the most voxels
    ptv70_per_slice = data['masks'][0].sum(axis=(0, 1))
    best_slice = np.argmax(ptv70_per_slice)

    print(f"Comparing Binary Masks vs SDFs at slice {best_slice}")
    print("SDF values: negative (blue) = inside, zero (white) = boundary, positive (red) = outside")

    compared_channels = [0, 1, 3, 4]  # PTV70, PTV56, Rectum, Bladder
    fig = plot_sdf_comparison(data['ct'], data['masks'], data['masks_sdf'],
                              best_slice, structures=compared_channels)
    plt.show()

In [None]:
# SDF profile through PTV70
if 'masks_sdf' in data:
    # Centre the profiles on the PTV70 centroid; use the grid centre when
    # PTV70 is empty
    ptv70 = data['masks'][0]
    if ptv70.sum() > 0:
        coords = np.array(np.where(ptv70 > 0))
        center_y, center_x, center_z = coords.mean(axis=1).astype(int)
    else:
        center_y, center_x, center_z = [s//2 for s in data['ct'].shape]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # One entry per panel: (index into the (C, Y, X, Z) arrays, x-label, title).
    # ``whole`` marks the axis the profile runs along.
    whole = slice(None)
    profile_specs = (
        ((0, center_y, whole, center_z), 'X (voxels)',
         f'PTV70 SDF Profile (L-R, Y={center_y}, Z={center_z})'),
        ((0, whole, center_x, center_z), 'Y (voxels)',
         f'PTV70 SDF Profile (A-P, X={center_x}, Z={center_z})'),
        ((0, center_y, center_x, whole), 'Z (voxels)',
         f'PTV70 SDF Profile (S-I, X={center_x}, Y={center_y})'),
    )

    for ax, (index, xlabel, title) in zip(axes, profile_specs):
        sdf_profile = data['masks_sdf'][index]
        binary_profile = data['masks'][index].astype(float)

        ax.plot(sdf_profile, 'b-', linewidth=2, label='SDF')
        # Shift the 0/1 mask to -0.5/+0.5 so its steps straddle the zero line
        ax.plot(binary_profile - 0.5, 'r--', linewidth=1, label='Binary (shifted)')
        ax.axhline(0, color='gray', linestyle=':', alpha=0.5)
        ax.fill_between(range(len(sdf_profile)), -1, 1,
                        where=binary_profile > 0, alpha=0.2, color='red')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('SDF value')
        ax.set_title(title)
        ax.set_ylim(-1.1, 1.1)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.suptitle('SDF Profiles Through PTV70 Center', fontsize=14)
    plt.tight_layout()
    plt.show()

    print("\nSDF Interpretation:")
    for note in (
        "  - SDF = 0 at structure boundary",
        "  - SDF < 0 inside structure (blue)",
        "  - SDF > 0 outside structure (red)",
        "  - Smooth transitions enable better gradient flow during training",
    ):
        print(note)

---
## 6. Visual Inspection: Axial Slices

In [None]:
def plot_axial_slice(ct, dose, masks, slice_idx, show_structures=(0, 1, 3, 4), rx_dose=None):
    """
    Plot a single axial slice with CT, dose overlay, and structure contours.

    Four panels are drawn: CT, dose in Gy, CT with the dose overlaid, and CT
    with structure contours.

    Args:
        ct: CT volume indexed as ``ct[:, :, slice_idx]``.
        dose: Normalized dose volume with the same shape as ``ct``.
        masks: Structure masks indexed as ``masks[channel, :, :, slice_idx]``.
        slice_idx: Index of the axial slice to display.
        show_structures: Channel indices to contour. Channels not present in
            ``masks`` are ignored.
        rx_dose: Dose in Gy that corresponds to a normalized value of 1.0.
            When None, it is read from the notebook-level ``metadata``
            (``normalization_dose_gy``, default 70.0).

    Returns:
        The matplotlib Figure.
    """
    if rx_dose is None:
        # Backward-compatible fallback: use the currently loaded case's metadata
        rx_dose = metadata.get('normalization_dose_gy', 70.0)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    ct_slice = ct[:, :, slice_idx]
    dose_slice = dose[:, :, slice_idx]

    # 1. CT only
    axes[0].imshow(ct_slice, cmap='gray', vmin=0, vmax=1)
    axes[0].set_title(f'CT (slice {slice_idx})')
    axes[0].axis('off')

    # 2. Dose only
    im_dose = axes[1].imshow(dose_slice * rx_dose, cmap='jet', vmin=0, vmax=75)
    axes[1].set_title('Dose (Gy)')
    axes[1].axis('off')
    plt.colorbar(im_dose, ax=axes[1], fraction=0.046, pad=0.04)

    # 3. CT + Dose overlay (doses below 10% of Rx are hidden so the CT shows through)
    axes[2].imshow(ct_slice, cmap='gray', vmin=0, vmax=1)
    dose_masked = np.ma.masked_where(dose_slice < 0.1, dose_slice)
    axes[2].imshow(dose_masked * rx_dose, cmap='jet', alpha=0.5, vmin=0, vmax=75)
    axes[2].set_title('CT + Dose Overlay')
    axes[2].axis('off')

    # 4. CT + Structure contours
    axes[3].imshow(ct_slice, cmap='gray', vmin=0, vmax=1)

    legend_elements = []
    for ch in show_structures:
        if ch >= masks.shape[0]:
            continue
        # Contour and legend share one colour/name lookup with a fallback, so a
        # channel missing from the lookup tables cannot raise KeyError.
        color = STRUCTURE_COLORS.get(ch, (1, 1, 1, 0.5))[:3]
        label = STRUCTURE_NAMES.get(ch, f'Ch {ch}')

        mask_slice = masks[ch, :, :, slice_idx]
        if mask_slice.sum() > 0:
            axes[3].contour(mask_slice, levels=[0.5], colors=[color], linewidths=2)

        # Listed even when absent from this slice, so the legend stays the
        # same across slices.
        legend_elements.append(plt.Line2D([0], [0], color=color, linewidth=2, label=label))

    axes[3].legend(handles=legend_elements, loc='upper right', fontsize=8)
    axes[3].set_title('CT + Structures')
    axes[3].axis('off')

    plt.tight_layout()
    return fig

In [None]:
# Find slice with maximum PTV70 area (voxel count of channel 0 per axial slice)
ptv70_per_slice = data['masks'][0].sum(axis=(0, 1))
best_slice = np.argmax(ptv70_per_slice)
print(f"Slice with max PTV70 area: {best_slice}")

displayed_channels = [0, 1, 3, 4]
fig = plot_axial_slice(data['ct'], data['dose'], data['masks'], best_slice,
                       show_structures=displayed_channels)
plt.show()

In [None]:
# Plot multiple slices through the volume: one at each quarter of the Z extent
z_dim = data['ct'].shape[2]
slices_to_show = [quarter * z_dim // 4 for quarter in (1, 2, 3)]

for sl in slices_to_show:
    fig = plot_axial_slice(data['ct'], data['dose'], data['masks'], sl,
                           show_structures=[0, 1, 3, 4])
    plt.show()

---
## 7. Dose-Volume Histogram (DVH)

In [None]:
def compute_dvh(dose, mask, bins=100):
    """Compute cumulative DVH for a structure.

    Args:
        dose: Dose volume.
        mask: Structure mask with the same shape as ``dose``; voxels with a
            value > 0 belong to the structure.
        bins: Number of evenly spaced dose levels at which the curve is sampled.

    Returns:
        ``(dose_bins, volume_percent)``: the sampled dose levels, running from
        0 to ``max(structure max dose, 1.2)``, and for each level the
        percentage of structure voxels receiving at least that dose. Returns
        ``(None, None)`` when the structure is empty.
    """
    inside = mask > 0
    if not inside.any():
        return None, None

    # Sort once, then locate every dose level by binary search instead of
    # re-scanning all voxels for each level.
    dose_sorted = np.sort(dose[inside], axis=None)
    n_voxels = dose_sorted.size

    # The axis spans at least 120% of the prescription, even for low-dose
    # structures.
    max_dose = max(dose_sorted[-1], 1.2)
    dose_bins = np.linspace(0, max_dose, bins)

    # searchsorted(..., side='left') counts the voxels strictly below each
    # level; the remainder receive at least that dose.
    n_at_or_above = n_voxels - np.searchsorted(dose_sorted, dose_bins, side='left')
    volume_percent = n_at_or_above / n_voxels * 100
    return dose_bins, volume_percent


def plot_dvh(dose, masks, title="Dose-Volume Histogram", rx_dose=None):
    """Plot DVH for all structures.

    Args:
        dose: Normalized dose volume.
        masks: Structure masks indexed as ``masks[channel]``.
        title: Figure title.
        rx_dose: Dose in Gy that corresponds to a normalized value of 1.0.
            When None, it is read from the notebook-level ``metadata``
            (``normalization_dose_gy``, default 70.0).

    Returns:
        The matplotlib Figure.
    """
    if rx_dose is None:
        # Backward-compatible fallback: use the currently loaded case's metadata
        rx_dose = metadata.get('normalization_dose_gy', 70.0)

    fig, ax = plt.subplots(figsize=(12, 8))

    for ch, name in STRUCTURE_NAMES.items():
        if ch >= masks.shape[0]:
            continue
        dose_bins, volume_percent = compute_dvh(dose, masks[ch])
        if dose_bins is None:  # empty structure
            continue
        color = STRUCTURE_COLORS[ch][:3]
        ax.plot(dose_bins * rx_dose, volume_percent, label=name,
                color=color, linewidth=2)

    # Prescription reference lines. These are added BEFORE the legend is
    # built: a legend only lists artists that exist when it is created, so
    # adding them afterwards silently dropped their labels.
    ax.axvline(70, color='red', linestyle='--', alpha=0.5, label='Rx 70 Gy')
    ax.axvline(56, color='orange', linestyle='--', alpha=0.5, label='Rx 56 Gy')

    ax.set_xlabel('Dose (Gy)', fontsize=12)
    ax.set_ylabel('Volume (%)', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.set_xlim(0, 80)
    ax.set_ylim(0, 105)
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper right', fontsize=10)

    plt.tight_layout()
    return fig

# Title the DVH with the file name (without extension) of the selected case
dvh_title = f"DVH: {selected_file.stem}"
fig = plot_dvh(data['dose'], data['masks'], title=dvh_title)
plt.show()

---
## 8. Summary Report

In [None]:
def generate_summary_report(filepath, data, checks, stats):
    """Generate a text summary report for a case.

    Args:
        filepath: ``pathlib.Path`` of the case file; only ``.name`` is used.
        data: Loaded case dict; only its metadata is read.
        checks: Result dict from ``validate_case``.
        stats: Result dict from ``get_structure_stats``.

    Returns:
        The report as a single newline-joined string.
    """
    metadata = get_metadata(data)
    rule = "=" * 70
    report = [rule, f"VERIFICATION REPORT: {filepath.name}", rule]

    # Overall status: a check counts as failed only when it is present and
    # false. Absent checks and None ("not applicable") do not fail the case.
    critical_checks = ['ct_range_valid', 'dose_nonneg', 'ptv70_exists',
                       'ptv70_dose_adequate', 'dose_higher_in_ptv']
    failures = [c for c in critical_checks
                if checks.get(c) is not None and not checks[c]]

    if failures:
        report.append("\n⚠️  STATUS: NEEDS REVIEW")
        report.append(f"   Failed checks: {', '.join(failures)}")
    else:
        report.append("\n✓  STATUS: PASSED")

    # Version info
    report.append(f"\n   Script version: {checks.get('script_version', 'unknown')}")
    report.append(f"   Has SDFs: {'Yes' if checks.get('has_sdf') else 'No'}")
    report.append(f"   Has beam geometry: {'Yes' if checks.get('has_beam_geometry') else 'No'}")

    # Case type
    case_type = metadata.get('case_type', {})
    report.append(f"\n   Case type: {case_type.get('type', 'unknown')}")

    def _has_dose(name):
        """True when ``stats`` holds dose figures for a non-empty structure."""
        # ``is not None`` rather than truthiness: a mean dose of exactly 0.0
        # is a legitimate value and must still be reported.
        return stats.get(name, {}).get('dose_mean') is not None

    # Key metrics (normalized doses are scaled back to Gy)
    rx_dose = metadata.get('normalization_dose_gy', 70.0)
    report.append("\n--- Key Metrics ---")
    if _has_dose('PTV70'):
        ptv70 = stats['PTV70']
        report.append(f"   PTV70 mean dose: {ptv70['dose_mean']*rx_dose:.1f} Gy "
                      f"(D95={ptv70['D95']*rx_dose:.1f}, D5={ptv70['D5']*rx_dose:.1f})")
    if case_type.get('ptv56_exists') and _has_dose('PTV56'):
        report.append(f"   PTV56 mean dose: {stats['PTV56']['dose_mean']*rx_dose:.1f} Gy")
    for oar in ('Rectum', 'Bladder'):
        if _has_dose(oar):
            report.append(f"   {oar} mean dose: {stats[oar]['dose_mean']*rx_dose:.1f} Gy "
                          f"(max={stats[oar]['dose_max']*rx_dose:.1f})")

    # Beam geometry summary
    beam_geometry = metadata.get('beam_geometry')
    if beam_geometry:
        report.append("\n--- Beam Geometry ---")
        report.append(f"   Beams: {beam_geometry.get('num_beams', 'N/A')}")
        # total_mu may be absent or stored as None; ':.1f' would raise on None
        total_mu = beam_geometry.get('total_mu')
        total_mu_text = 'N/A' if total_mu is None else f"{total_mu:.1f}"
        report.append(f"   Total MU: {total_mu_text}")

    # Warnings
    if checks.get('boundary_warnings'):
        report.append("\n--- Warnings ---")
        for w in checks['boundary_warnings']:
            report.append(f"   ⚠ {w}")

    report.append("\n" + rule)

    return "\n".join(report)

# Generate and print report for the case selected above
report = generate_summary_report(
    filepath=selected_file,
    data=data,
    checks=checks,
    stats=stats,
)
print(report)

---
## 9. Batch Validation (Optional)

Run basic validation on all files to identify cases needing manual review.

In [None]:
# Set to True to run batch validation on all files
RUN_BATCH = False

if RUN_BATCH and npz_files:
    print(f"Running batch validation on {len(npz_files)} files...\n")

    results = []
    for npz_file in npz_files:
        try:
            # Close each archive as soon as its arrays are in memory so a long
            # batch does not accumulate open file handles.
            # SECURITY: allow_pickle=True can execute code from the file -
            # only run this on trusted data.
            with np.load(npz_file, allow_pickle=True) as d:
                data_tmp = {key: d[key] for key in d.files}
            checks_tmp = validate_case(data_tmp)
            metadata_tmp = get_metadata(data_tmp)

            # 'dose_higher_in_ptv' is absent when the check could not be
            # computed; absence is treated as a pass.
            critical_pass = all((
                checks_tmp.get('ct_range_valid', False),
                checks_tmp.get('dose_nonneg', False),
                checks_tmp.get('ptv70_exists', False),
                checks_tmp.get('dose_higher_in_ptv', True),
            ))

            results.append({
                'file': npz_file.name,
                'status': '✓ PASS' if critical_pass else '✗ FAIL',
                'version': checks_tmp.get('script_version', '?'),
                'has_sdf': '✓' if checks_tmp.get('has_sdf') else '✗',
                'has_beams': '✓' if checks_tmp.get('has_beam_geometry') else '✗',
                'ptv70_dose': checks_tmp.get('ptv70_dose_mean'),
                'case_type': metadata_tmp.get('case_type', {}).get('type', '?'),
                'error': None,
            })
        except Exception as e:
            # Best effort: one unreadable or malformed file must not abort the
            # whole batch, but the reason is kept and reported below.
            results.append({
                'file': npz_file.name,
                'status': '✗ ERROR',
                'version': '?',
                'has_sdf': '?',
                'has_beams': '?',
                'ptv70_dose': None,
                'case_type': '?',
                'error': f"{type(e).__name__}: {e}",
            })

    # Print summary table
    print(f"{'File':<25} {'Status':<10} {'Ver':<8} {'SDF':>4} {'Beam':>5} {'PTV70':>8} {'Type':<12}")
    print("-" * 80)
    for r in results:
        # ``is not None`` so that a mean dose of 0.0 is still printed
        dose_str = f"{r['ptv70_dose']:.3f}" if r['ptv70_dose'] is not None else "N/A"
        # str() guards the alignment specs against None values from metadata
        print(f"{r['file']:<25} {r['status']:<10} {str(r['version']):<8} {r['has_sdf']:>4} "
              f"{r['has_beams']:>5} {dose_str:>8} {str(r['case_type']):<12}")

    # Summary
    passed = sum(1 for r in results if '✓' in r['status'])
    with_sdf = sum(1 for r in results if r['has_sdf'] == '✓')
    with_beams = sum(1 for r in results if r['has_beams'] == '✓')
    errored = [r for r in results if r['error']]

    print(f"\n{'='*80}")
    print("BATCH SUMMARY")
    print(f"  Passed: {passed}/{len(results)}")
    print(f"  With SDFs: {with_sdf}/{len(results)}")
    print(f"  With beam geometry: {with_beams}/{len(results)}")

    if errored:
        print(f"\nErrors ({len(errored)}):")
        for r in errored:
            print(f"  {r['file']}: {r['error']}")
elif RUN_BATCH:
    print(f"No .npz files found in {NPZ_DIR}")
else:
    print("Set RUN_BATCH = True to run batch validation")