# Experiment: Baseline U-Net for VMAT Dose Prediction

**Date:** 2026-01-19  
**Experiment ID:** `baseline_unet_run1`  
**Status:** Complete  

---

## 1. Overview

### 1.1 Objective
Train a baseline 3D U-Net model for direct dose prediction from CT and structure data. This serves as:
1. Validation that the preprocessing pipeline produces usable training data
2. A baseline for comparison against diffusion-based models (DDPM)
3. Proof of concept for the dose prediction task

### 1.2 Key Results

| Metric | Value |
|--------|-------|
| **Best Validation MAE** | **3.73 Gy** |
| Training Duration | 2.55 hours |
| Epochs (early stopped) | 62/200 |
| Best Epoch | 12 |

### 1.3 Conclusion
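The baseline U-Net achieves a best validation MAE of 3.73 Gy on the 2 validation cases, a reasonable baseline accuracy demonstrating that:
- The preprocessed data is suitable for training
- The model architecture can learn the dose distribution
- Validation MAE plateaued early: early stopping triggered at epoch 62, after 50 epochs (the configured patience) without improvement on the best epoch (12)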

---

## 2. Reproducibility Information

### 2.1 Environment

In [None]:
# Environment snapshot at time of experiment
REPRODUCIBILITY_INFO = {
    'git_commit': '0e2fedc74fd75899cf1ead5488b63a96e0bbf455',
    'git_message': 'Fix SDF computation bug: cast uint8 mask to bool before bitwise NOT',
    'python_version': '3.12.12',
    'pytorch_version': '2.4.1',
    'cuda_version': '12.4',
    'gpu': 'NVIDIA GeForce RTX 3090',
    'gpu_memory_gb': 24,
    'random_seed': 42,
    'experiment_date': '2026-01-19',
    'preprocessing_version': '2.2.0',
}

for k, v in REPRODUCIBILITY_INFO.items():
    print(f'{k}: {v}')
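
The values in `REPRODUCIBILITY_INFO` appear to be recorded by hand. Below is a minimal sketch of collecting most of them at run time. It assumes `git` is on the PATH and that PyTorch may or may not be installed. `capture_environment` and `_git` are hypothetical helpers, not part of the project's scripts.

```python
import platform
import subprocess


def _git(*args):
    """Run a git command and return its stripped stdout, or None on failure."""
    try:
        result = subprocess.run(['git', *args], capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return None  # git not installed, or not inside a repository
    return result.stdout.strip()


def capture_environment(seed):
    """Hypothetical helper: build a reproducibility snapshot at run time."""
    info = {
        'git_commit': _git('rev-parse', 'HEAD'),
        'git_message': _git('log', '-1', '--pretty=%s'),
        'python_version': platform.python_version(),
        'random_seed': seed,
    }
    try:
        import torch
    except ImportError:
        return info  # PyTorch fields are skipped when torch is unavailable
    info['pytorch_version'] = torch.__version__
    info['cuda_version'] = torch.version.cuda  # None for CPU-only builds
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        info['gpu'] = props.name
        info['gpu_memory_gb'] = round(props.total_memory / 1024**3)
    return info
```

Writing the result into the run directory alongside the checkpoints keeps the snapshot tied to the run that produced them.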

### 2.2 Command to Reproduce

```bash
# Ensure git is at the correct commit
git checkout 0e2fedc74fd75899cf1ead5488b63a96e0bbf455

# Activate environment
conda activate vmat-diffusion

# Run training
python scripts/train_baseline_unet.py \
    --data_dir /mnt/i/processed_npz \
    --epochs 200 \
    --exp_name baseline_unet_run1 \
    --seed 42
```
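
`--seed 42` fixes the random number generators. Below is a minimal sketch of what seeding usually covers, assuming the script seeds Python, NumPy and PyTorch. `seed_all` is a hypothetical name. If the script uses PyTorch Lightning, as the `LightningModule` reference in Section 8.2 suggests, `lightning.pytorch.seed_everything(42, workers=True)` does the equivalent.

```python
import random


def seed_all(seed, deterministic=False):
    """Hypothetical helper: seed Python, NumPy and PyTorch RNGs if installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
    if deterministic:
        # warn_only=True: ops without a deterministic CUDA kernel (such as the
        # trilinear upsampling backward pass) warn instead of raising an error
        torch.use_deterministic_algorithms(True, warn_only=True)
```

Even with `deterministic=True`, a fixed seed only makes runs repeatable up to such nondeterministic kernels. This is consistent with the determinism warnings noted in Section 10.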

---

## 3. Dataset

### 3.1 Data Summary

In [None]:
import json
from pathlib import Path

DATASET_INFO = {
    'total_cases': 23,
    'skipped_cases': 1,  # case_0013 - missing PTV56 (non-SIB)
    'total_size_gb': 4.6,
    'preprocessing_script': 'preprocess_dicom_rt_v2.2.py',
    'volume_shape': (512, 512, 256),
    'voxel_spacing_mm': (1.0, 1.0, 2.0),
    'structures': ['PTV70', 'PTV56', 'Prostate', 'Rectum', 'Bladder', 'Femur_L', 'Femur_R', 'Bowel'],
    'input_channels': 9,  # 1 CT + 8 SDF
    'constraint_dim': 13,
}

print('Dataset Summary:')
for k, v in DATASET_INFO.items():
    print(f'  {k}: {v}')
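
The 8 SDF channels encode each structure mask as a signed distance map. The commit recorded in Section 2.1 fixes a bug where a `uint8` mask was inverted with `~` before being cast to `bool`. The NumPy sketch below illustrates both points. It assumes a negative-inside / positive-outside sign convention and distances in millimetres via the voxel spacing. The preprocessing script's actual convention, clipping, or normalization may differ. The brute-force search only suits tiny arrays. `scipy.ndimage.distance_transform_edt(..., sampling=spacing)` is the usual fast route for full volumes. `signed_distance` is a hypothetical helper.

```python
import numpy as np

# Why the cast matters: ~ on uint8 flips all 8 bits, so 0 -> 255 and 1 -> 254.
# Both results are nonzero, so the "inverted" mask would be all True.
m = np.array([0, 1], dtype=np.uint8)
assert (~m).tolist() == [255, 254]
assert (~m.astype(bool)).tolist() == [True, False]


def signed_distance(mask, spacing):
    """Brute-force signed distance map in mm: negative inside, positive outside.

    Each voxel gets the Euclidean distance to the nearest voxel of the other class,
    which is what distance_transform_edt(~mask) - distance_transform_edt(mask) gives.
    """
    mask = np.asarray(mask).astype(bool)  # cast BEFORE any bitwise NOT
    if mask.all() or not mask.any():
        raise ValueError('mask must contain both inside and outside voxels')
    # Physical coordinates (mm) of every voxel, shape (N, ndim)
    coords = np.indices(mask.shape).reshape(mask.ndim, -1).T * np.asarray(spacing, float)
    flat = mask.ravel()
    inside, outside = coords[flat], coords[~flat]
    # Pairwise distances between the two classes, shape (N_out, N_in)
    dist = np.linalg.norm(outside[:, None, :] - inside[None, :, :], axis=-1)
    sdf = np.empty(flat.shape, float)
    sdf[~flat] = dist.min(axis=1)  # outside voxel -> nearest inside voxel
    sdf[flat] = -dist.min(axis=0)  # inside voxel -> nearest outside voxel
    return sdf.reshape(mask.shape)
```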

In [None]:
# Load test case list (held out during training)
test_cases_path = Path('../runs/baseline_unet_run1/test_cases.json')
if test_cases_path.exists():
    with open(test_cases_path) as f:
        test_cases = json.load(f)
    print('Test cases (held out):')
    for case in test_cases:
        print(f'  - {case}')
else:
    print(f'Test case list not found: {test_cases_path}')

### 3.2 Data Split

| Split | Cases | Percentage |
|-------|-------|------------|
| Train | 19 | 82.6% |
| Validation | 2 | 8.7% |
| Test | 2 | 8.7% |

**Note:** Random split with seed=42 for reproducibility.
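
Below is a minimal sketch of a seeded three-way split of this size (19/2/2). It uses a local `random.Random(42)`, so the split does not depend on any other RNG calls. The helper `split_cases` and the case IDs are illustrative. The sketch is not expected to reproduce the actual held-out cases (case_0009, case_0022), which depend on the training script's own split code.

```python
import random


def split_cases(case_ids, n_val=2, n_test=2, seed=42):
    """Hypothetical helper: deterministic train/val/test split of case IDs."""
    ids = sorted(case_ids)            # sort first so the input order does not matter
    random.Random(seed).shuffle(ids)  # local RNG: leaves the global random state alone
    test = ids[:n_test]
    val = ids[n_test:n_test + n_val]
    train = ids[n_test + n_val:]
    return train, val, test


# Illustrative IDs only; the real case list comes from the preprocessed data directory
cases = [f'case_{i:04d}' for i in range(23)]
train, val, test = split_cases(cases)
print(len(train), len(val), len(test))  # 19 2 2
```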

---

## 4. Model Architecture

### 4.1 Model Summary

In [None]:
MODEL_CONFIG = {
    'architecture': 'BaselineUNet3D',
    'type': 'Direct Regression (not diffusion)',
    'parameters': 23_732_801,
    'parameters_human': '23.7M',
    'input_channels': 9,  # 1 CT + 8 SDF channels
    'output_channels': 1,  # Dose
    'base_channels': 48,
    'encoder_channels': [48, 96, 192, 384],
    'bottleneck_channels': 768,
    'constraint_conditioning': 'FiLM (Feature-wise Linear Modulation)',
    'constraint_dim': 13,
    'normalization': 'GroupNorm',
    'activation': 'SiLU',
    'upsampling': 'Trilinear + Conv',
}

print('Model Configuration:')
for k, v in MODEL_CONFIG.items():
    print(f'  {k}: {v}')

### 4.2 Architecture Diagram

```
Input: CT (1) + SDF (8) = 9 channels, 128³ patches
       ↓
┌─────────────────────────────────────────────────────────┐
│  Encoder                                                │
│  Conv3D(9→48) → Conv3D(48→96) → Conv3D(96→192) → (384) │
│  + MaxPool3D at each level                              │
└─────────────────────────────────────────────────────────┘
       ↓
┌─────────────────────────────────────────────────────────┐
│  Bottleneck (384→768→384)                               │
│  + FiLM conditioning from constraints (13-dim)          │
└─────────────────────────────────────────────────────────┘
       ↓
┌─────────────────────────────────────────────────────────┐
│  Decoder with Skip Connections                          │
│  Upsample + Concat + Conv3D at each level               │
└─────────────────────────────────────────────────────────┘
       ↓
Output: Dose (1 channel), 128³
```
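
FiLM conditioning maps the 13-dim constraint vector to a per-channel scale γ and shift β that modulate a feature map: y = γ·x + β. Below is a minimal NumPy sketch of that operation. It assumes a single linear layer produces both γ and β and that the result is applied to a 5D feature map, as in the bottleneck above. The layer sizes, the initialization, and any variant used in `BaselineUNet3D` (for example, scaling by 1 + γ) are assumptions. The `FiLM` class is hypothetical, not the project's module.

```python
import numpy as np


class FiLM:
    """Hypothetical FiLM layer: condition vector -> per-channel scale and shift."""

    def __init__(self, cond_dim, num_channels, rng):
        # One linear map whose output is split into [gamma | beta]
        self.weight = rng.normal(0.0, 0.02, size=(cond_dim, 2 * num_channels))
        # Bias chosen so gamma = 1 and beta = 0 when the condition is all zeros
        self.bias = np.concatenate([np.ones(num_channels), np.zeros(num_channels)])

    def __call__(self, x, cond):
        # x: (B, C, D, H, W) feature map; cond: (B, cond_dim) constraint vector
        params = cond @ self.weight + self.bias    # (B, 2C)
        gamma, beta = np.split(params, 2, axis=1)  # each (B, C)
        # Add singleton spatial axes so each channel is scaled/shifted uniformly
        return gamma[:, :, None, None, None] * x + beta[:, :, None, None, None]


rng = np.random.default_rng(0)
film = FiLM(cond_dim=13, num_channels=768, rng=rng)
x = rng.normal(size=(2, 768, 4, 4, 4))  # small spatial size for illustration
y = film(x, rng.normal(size=(2, 13)))
print(y.shape)  # (2, 768, 4, 4, 4)
```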

---

## 5. Training Configuration

In [None]:
TRAINING_CONFIG = {
    'max_epochs': 200,
    'actual_epochs': 62,
    'early_stopping': True,
    'early_stopping_patience': 50,
    'batch_size': 2,
    'patch_size': 128,
    'patches_per_volume': 4,
    'samples_per_epoch': 76,  # 19 train cases × 4 patches
    'optimizer': 'AdamW',
    'learning_rate': 1e-4,
    'weight_decay': 0.01,
    'lr_scheduler': 'CosineAnnealingLR',
    'loss_function': 'MSE + Gradient Loss',
    'precision': '16-mixed (AMP)',
    'gradient_clip': 1.0,
    'num_workers': 4,
}

print('Training Configuration:')
for k, v in TRAINING_CONFIG.items():
    print(f'  {k}: {v}')
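
The loss is listed as 'MSE + Gradient Loss'. A common form of gradient loss penalizes differences between the finite-difference spatial gradients of the predicted and reference dose. It is intended to discourage blurring of steep dose fall-off. Below is a minimal NumPy sketch under that assumption. The exact definition in `train_baseline_unet.py` is not recorded in this notebook, and the weight `grad_weight` (0.1 here) is a hypothetical value.

```python
import numpy as np


def gradient_loss(pred, target):
    """Mean squared difference of forward finite differences along the last 3 axes."""
    total = 0.0
    for axis in (-3, -2, -1):  # D, H, W
        diff = np.diff(pred, axis=axis) - np.diff(target, axis=axis)
        total += np.mean(diff ** 2)
    return total / 3.0


def combined_loss(pred, target, grad_weight=0.1):
    """MSE plus weighted gradient loss; grad_weight = 0.1 is a hypothetical value."""
    mse = np.mean((pred - target) ** 2)
    return mse + grad_weight * gradient_loss(pred, target)
```

A constant dose offset is penalized by the MSE term but not by the gradient term. The differences are per voxel, so with 1 × 1 × 2 mm spacing the slice axis is not weighted as it would be in mm. Dividing each difference by its spacing is an alternative.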

---

## 6. Training Results

### 6.1 Load Metrics

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Set publication-quality defaults
plt.rcParams.update({
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 11,
    'figure.figsize': (8, 6),
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'axes.grid': True,
    'grid.alpha': 0.3,
})

# Load metrics
metrics_path = Path('../runs/baseline_unet_run1/version_3/metrics.csv')
df = pd.read_csv(metrics_path)

# Extract epoch-level validation metrics (rows where val/mae_gy was logged)
epoch_metrics = df.loc[df['val/mae_gy'].notna(), ['epoch', 'val/loss', 'val/mae_gy']].copy()
epoch_metrics['epoch'] = epoch_metrics['epoch'].astype(int)

# Also get training loss per epoch
train_loss = df.loc[df['train/loss_epoch'].notna(), ['epoch', 'train/loss_epoch']].copy()
train_loss['epoch'] = train_loss['epoch'].astype(int)

# Merge on epoch and rename columns by name rather than by position
metrics = epoch_metrics.merge(train_loss, on='epoch', how='left').rename(columns={
    'val/loss': 'val_loss',
    'val/mae_gy': 'val_mae_gy',
    'train/loss_epoch': 'train_loss',
})

print(f'Loaded {len(metrics)} epochs of metrics')
metrics.head(10)

### 6.2 Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Loss curves
ax1 = axes[0]
ax1.plot(metrics['epoch'], metrics['train_loss'], 'b-', label='Train Loss', linewidth=2)
ax1.plot(metrics['epoch'], metrics['val_loss'], 'r-', label='Val Loss', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss (MSE + gradient)')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.set_xlim(0, metrics['epoch'].max())

# Plot 2: Validation MAE
ax2 = axes[1]
ax2.plot(metrics['epoch'], metrics['val_mae_gy'], 'g-', linewidth=2)
best_epoch = metrics.loc[metrics['val_mae_gy'].idxmin(), 'epoch']
best_mae = metrics['val_mae_gy'].min()
ax2.axhline(y=best_mae, color='r', linestyle='--', alpha=0.7, label=f'Best: {best_mae:.2f} Gy')
ax2.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7)
ax2.scatter([best_epoch], [best_mae], color='r', s=100, zorder=5)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Validation MAE (Gy)')
ax2.set_title('Validation Mean Absolute Error')
ax2.legend()
ax2.set_xlim(0, metrics['epoch'].max())

plt.tight_layout()
plt.savefig('../runs/baseline_unet_run1/training_curves.png', dpi=300)
plt.show()

print(f'\nBest validation MAE: {best_mae:.3f} Gy at epoch {best_epoch}')

### 6.3 Results Summary

In [None]:
# Load training summary
with open('../runs/baseline_unet_run1/training_summary.json') as f:
    summary = json.load(f)

RESULTS = {
    'training_time_hours': round(summary['total_time_hours'], 2),
    'total_epochs': summary['final_metrics']['epoch'],
    'best_val_mae_gy': round(summary['best_val_mae_gy'], 3),
    'final_val_mae_gy': round(summary['final_metrics']['val_mae_gy'], 3),
    'final_train_loss': round(summary['final_metrics']['train_loss'], 6),
    'final_val_loss': round(summary['final_metrics']['val_loss'], 6),
    'early_stopped': summary['final_metrics']['epoch'] < TRAINING_CONFIG['max_epochs'] - 1,  # epochs are 0-indexed
    'best_checkpoint': 'best-epoch=012-val/mae_gy=3.735.ckpt',
}

print('=' * 60)
print('FINAL RESULTS')
print('=' * 60)
for k, v in RESULTS.items():
    print(f'  {k}: {v}')

---

## 7. Analysis

### 7.1 Convergence Analysis

In [None]:
# Analyze convergence
print('Convergence Analysis:')
print(f'  - Training stopped at epoch {RESULTS["total_epochs"]} (max: 200)')
print(f'  - Best MAE achieved at epoch {best_epoch}')
print(f'  - Epochs after best: {RESULTS["total_epochs"] - best_epoch}')
print(f'  - Early stopping patience: {TRAINING_CONFIG["early_stopping_patience"]} epochs')
print(f'  - Final MAE ({RESULTS["final_val_mae_gy"]:.2f} Gy) vs Best ({RESULTS["best_val_mae_gy"]:.2f} Gy)')
print(f'  - Degradation: {RESULTS["final_val_mae_gy"] - RESULTS["best_val_mae_gy"]:.2f} Gy')

# Check for overfitting (heuristic: the 0.01 threshold below is in loss units and depends on dose scaling)
train_val_gap = metrics['val_loss'].iloc[-1] - metrics['train_loss'].iloc[-1]
print('\nOverfitting check:')
print(f'  - Final train loss: {metrics["train_loss"].iloc[-1]:.6f}')
print(f'  - Final val loss: {metrics["val_loss"].iloc[-1]:.6f}')
print(f'  - Gap: {train_val_gap:.6f}')
print(f'  - Assessment: {"Minimal overfitting" if train_val_gap < 0.01 else "Some overfitting observed"}')
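
The stopping epoch follows directly from the patience rule. With the best MAE at epoch 12 and a patience of 50, training stops at epoch 12 + 50 = 62. Below is a minimal sketch of that rule for a metric where lower is better. It assumes 0-indexed epochs, validation every epoch, and a wait counter that resets on improvement, as in Lightning's `EarlyStopping` with `mode='min'`. `early_stop_epoch` is a hypothetical helper.

```python
def early_stop_epoch(val_scores, patience, min_delta=0.0):
    """Return (best_epoch, stop_epoch); stop_epoch is None if patience never runs out."""
    best = float('inf')
    best_epoch = None
    wait = 0
    for epoch, score in enumerate(val_scores):
        if score < best - min_delta:  # improvement: remember it and reset the counter
            best, best_epoch, wait = score, epoch, 0
        else:
            wait += 1
            if wait >= patience:  # `patience` epochs in a row without improvement
                return best_epoch, epoch
    return best_epoch, None


# Synthetic curve: improves until epoch 12, then never beats that value again
scores = [5.0 - 0.1 * i for i in range(13)] + [4.5] * 100
print(early_stop_epoch(scores, patience=50))  # (12, 62)
```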

### 7.2 Clinical Interpretation

**MAE of 3.73 Gy in context:**
- Prescription dose: 70 Gy to PTV70
- 3.73 Gy ≈ 5.3% of prescription dose
- A single voxel-averaged MAE is a useful baseline figure, but judging clinical acceptability typically requires:
  - DVH-based metrics
  - Gamma analysis (3%/3mm)
  - Structure-specific dose accuracy
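
Structure-specific accuracy and DVH points can be computed per structure from the predicted dose, the reference dose, and a structure mask. Below is a minimal NumPy sketch. It assumes doses in Gy on the same grid, equal voxel volumes, and a boolean-castable mask (for example, one of the 8 structure masks before SDF conversion). `structure_metrics` is a hypothetical helper. D95 is taken as the 5th percentile of dose inside the structure, i.e. the dose received by at least 95% of its volume.

```python
import numpy as np


def structure_metrics(pred, gt, mask, threshold_gy=None):
    """Hypothetical helper: masked MAE plus simple DVH points for one structure."""
    m = np.asarray(mask).astype(bool)
    p, g = pred[m], gt[m]
    out = {
        'mae_gy': float(np.mean(np.abs(p - g))),
        'dmean_pred_gy': float(p.mean()),
        'dmean_gt_gy': float(g.mean()),
        # D95: dose received by at least 95% of the structure = 5th percentile
        'd95_pred_gy': float(np.percentile(p, 5)),
        'd95_gt_gy': float(np.percentile(g, 5)),
    }
    if threshold_gy is not None:
        # V_x: fraction of the structure receiving at least x Gy
        out['vx_pred'] = float(np.mean(p >= threshold_gy))
        out['vx_gt'] = float(np.mean(g >= threshold_gy))
    return out
```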

**Next steps for clinical validation:**
1. Run inference on held-out test cases
2. Compute DVH comparisons
3. Gamma analysis
4. Visualize dose distributions
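
The model was trained on 128³ patches, so predicting a full volume for the held-out cases needs some form of tiled inference. Below is a minimal sketch of overlapping sliding-window inference with uniform averaging. It assumes each spatial dimension is at least the patch size. It also assumes `predict_fn` wraps the trained model (and the case's constraint vector) and maps a `(C, p, p, p)` input patch to a `(p, p, p)` dose patch. The function names and the stride of 64 are hypothetical. Gaussian-weighted blending is a common alternative to uniform averaging.

```python
import numpy as np


def window_starts(size, patch, stride):
    """Start indices of windows of length `patch` that cover [0, size)."""
    if size < patch:
        raise ValueError('pad the volume so every dimension is >= patch')
    starts = list(range(0, size - patch + 1, stride))
    if starts[-1] != size - patch:
        starts.append(size - patch)  # extra window flush with the far edge
    return starts


def sliding_window_predict(volume, predict_fn, patch=128, stride=64):
    """Average overlapping patch predictions over a (C, D, H, W) input volume."""
    spatial = volume.shape[1:]
    out = np.zeros(spatial)
    count = np.zeros(spatial)
    zs, ys, xs = (window_starts(s, patch, stride) for s in spatial)
    for z in zs:
        for y in ys:
            for x in xs:
                region = (slice(z, z + patch), slice(y, y + patch), slice(x, x + patch))
                out[region] += predict_fn(volume[(slice(None),) + region])
                count[region] += 1
    return out / count  # every voxel is covered at least once
```

For a 512 × 512 × 256 volume with 128³ patches and stride 64, this gives 7 × 7 × 3 = 147 forward passes per case.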

---

## 8. Artifacts and Outputs

### 8.1 Saved Files

In [None]:
from pathlib import Path

run_dir = Path('../runs/baseline_unet_run1')

print('Saved artifacts:')
print(f'\nDirectory: {run_dir.absolute()}')
print('\nFiles:')
for f in sorted(run_dir.rglob('*')):
    if f.is_file():
        size = f.stat().st_size
        if size > 1e6:
            size_str = f'{size/1e6:.1f} MB'
        elif size > 1e3:
            size_str = f'{size/1e3:.1f} KB'
        else:
            size_str = f'{size} B'
        rel_path = f.relative_to(run_dir)
        print(f'  {rel_path}: {size_str}')

### 8.2 How to Load Best Checkpoint

In [None]:
# Example code to load the trained model
LOAD_MODEL_CODE = '''
import torch
from pathlib import Path

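# Path to best checkpoint (the '/' in the metric name 'val/mae_gy' places the file in a subdirectory)
ckpt_path = Path('runs/baseline_unet_run1/checkpoints/best-epoch=012-val/mae_gy=3.735.ckpt')

# Load checkpoint on CPU so this also works without a GPU.
# weights_only=False keeps the pre-2.6 default for full Lightning checkpoints; only use it on trusted files.
checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)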

# If using the LightningModule directly:
# model = BaselineUNetModule.load_from_checkpoint(ckpt_path)
# model.eval()

# Or extract just the state dict:
# state_dict = checkpoint['state_dict']
'''
print(LOAD_MODEL_CODE)

---

## 9. Next Steps

### Immediate:
1. [ ] Run inference on test cases (case_0009, case_0022)
2. [ ] Generate DVH comparisons
3. [ ] Compute gamma pass rates
4. [ ] Visualize predicted vs ground truth doses

### Future experiments:
1. [ ] Train DDPM model for comparison
2. [ ] Ablation study: with/without SDF features
3. [ ] Ablation study: with/without constraint conditioning
4. [ ] Increase dataset size (process case_0013 with --relax_filter)
5. [ ] Hyperparameter tuning (learning rate, batch size)

---

## 10. Notes and Observations

### Issues encountered:
1. **Gamma computation failed** - the pymedphys gamma function failed due to a module import issue (non-critical)
2. **Determinism warnings** - the backward pass of trilinear upsampling has no deterministic CUDA implementation, so runs are not bitwise reproducible even with a fixed seed
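
Until the pymedphys import issue is resolved, a brute-force gamma index on small profiles can serve as a sanity check of the 3%/3mm criterion. Below is a minimal 1D sketch. It assumes global normalization (dose criterion = 3% of the maximum reference dose), a 10% low-dose cutoff, and both profiles on the same grid. `gamma_1d` and `pass_rate` are hypothetical helpers. Unlike dedicated implementations such as pymedphys, this sketch does not interpolate the evaluated dose between grid points. It is therefore only a rough check and far too slow for full 3D volumes.

```python
import math


def gamma_1d(ref, evl, spacing_mm, dose_pct=3.0, dta_mm=3.0, cutoff_pct=10.0):
    """Brute-force global gamma index for two 1D dose profiles on the same grid.

    Returns one gamma value per reference point, or None where the reference dose
    is below the low-dose cutoff (those points are not evaluated).
    """
    if len(ref) != len(evl):
        raise ValueError('profiles must have the same length')
    d_max = max(ref)
    dose_crit = dose_pct / 100.0 * d_max  # global dose criterion in Gy
    cutoff = cutoff_pct / 100.0 * d_max
    gammas = []
    for i, d_ref in enumerate(ref):
        if d_ref < cutoff:
            gammas.append(None)
            continue
        best = math.inf
        for j, d_evl in enumerate(evl):
            dist_mm = (j - i) * spacing_mm
            g2 = (dist_mm / dta_mm) ** 2 + ((d_evl - d_ref) / dose_crit) ** 2
            best = min(best, g2)
        gammas.append(math.sqrt(best))
    return gammas


def pass_rate(gammas):
    """Fraction of evaluated points with gamma <= 1."""
    evaluated = [g for g in gammas if g is not None]
    return sum(g <= 1.0 for g in evaluated) / len(evaluated)
```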

### Observations:
1. Best MAE was reached early (epoch 12), suggesting the model learns quickly
2. Validation MAE fluctuates substantially between epochs (3.7 to 9.4 Gy)
3. The small validation set (2 cases) may contribute to this variance, since each epoch's MAE averages over only two patients
4. Training loss continues to decrease after validation MAE plateaus, a typical sign of overfitting
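
With validation MAE swinging between roughly 3.7 and 9.4 Gy from epoch to epoch, a trailing moving average can make the underlying trend easier to judge than the raw curve. Below is a minimal sketch in plain Python. `moving_average` is a hypothetical helper. With the `metrics` DataFrame from Section 6.1, the pandas equivalent is `metrics['val_mae_gy'].rolling(5, min_periods=1).mean()`. The window of 5 epochs is an arbitrary choice.

```python
def moving_average(values, window):
    """Trailing moving average; early points average over what is available."""
    if window < 1:
        raise ValueError('window must be >= 1')
    out = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # drop the value leaving the window
        out.append(running / min(i + 1, window))
    return out


print(moving_average([1, 2, 3, 4], window=2))  # [1.0, 1.5, 2.5, 3.5]
```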

---

*Notebook generated: 2026-01-19*  
*Last updated: 2026-01-19*