## 1. Setup and Configuration

In [None]:
import sys
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path.cwd().parent))

import logging
import numpy as np
import matplotlib.pyplot as plt
import torch

from config import get_default_config, get_fast_config
from preprocessing import MRIPreprocessor
from dataset import BrainMRIDataset, create_data_loaders
from utils import (
    setup_logging,
    discover_dataset,
    create_patient_level_split,
    visualize_sample,
    set_random_seeds,
)

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# For better plots
%matplotlib inline
plt.rcParams['figure.figsize'] = (15, 5)

print("✓ Imports successful")

In [None]:
# Load configuration.
# USE_FAST_CONFIG=True selects the 'fast' preset for quick experimentation;
# set it to False to use the default (full-resolution) configuration instead.
USE_FAST_CONFIG = True
config = get_fast_config() if USE_FAST_CONFIG else get_default_config()

print("Configuration loaded:")
print(f"  Data root: {config.data.data_root}")
print(f"  Output root: {config.data.output_root}")
print(f"  Target spacing: {config.preprocessing.target_spacing}")
print(f"  Batch size: {config.training.batch_size}")
print(f"  Random seed: {config.split.random_seed}")

# Set random seeds for reproducibility (before any split or DataLoader shuffling)
set_random_seeds(config.split.random_seed)

## 2. Dataset Discovery

Discover all MRI volumes in the BIDS-formatted dataset.

In [None]:
# Discover every volume under the configured data root
data_list = discover_dataset(config.data.data_root, config.data)

print(f"\nFound {len(data_list)} volumes")

# Fail fast: later cells index into data_list (e.g. data_list[0]) and build
# DataFrames from it, which would break with a confusing error if it were empty.
assert data_list, f"No volumes found under {config.data.data_root}"

print(f"\nFirst 5 entries:")
for i, item in enumerate(data_list[:5]):
    print(f"{i+1}. {item['subject']}/{item['session']}/{item['modality']} - {item['field_strength']}")

## 3. Data Exploration

In [None]:
# Analyze dataset composition
import pandas as pd

# Tabulate the per-volume metadata: one row per discovered volume,
# columns in the order listed below.
metadata_keys = ('subject', 'session', 'modality', 'field_strength')
df = pd.DataFrame([{key: item[key] for key in metadata_keys} for item in data_list])

print("Dataset Summary:")
print("=" * 50)
print(f"Total volumes: {len(df)}")

subjects_col = df['subject']
sessions_col = df['session']
modalities_col = df['modality']

print(f"\nSubjects: {subjects_col.nunique()}")
print(subjects_col.unique())
print(f"\nSessions: {sessions_col.nunique()}")
print(sessions_col.value_counts())
print(f"\nModalities: {modalities_col.nunique()}")
print(modalities_col.value_counts())
print(f"\nField Strengths:")
print(df['field_strength'].value_counts())

# Volume counts of each session / subject broken down by modality
for heading, row_key in (
    ("\nCross-tabulation (Session x Modality):", 'session'),
    ("\nCross-tabulation (Subject x Modality):", 'subject'),
):
    print(heading)
    print(pd.crosstab(df[row_key], modalities_col))

## 4. Patient-Level Splitting

Create train/val/test splits at the subject (patient) level, so that no subject's volumes appear in more than one split. This prevents data leakage.

In [None]:
# Create patient-level split: no subject may appear in more than one split
train_data, val_data, test_data = create_patient_level_split(
    data_list,
    train_ratio=config.split.train_ratio,
    val_ratio=config.split.val_ratio,
    test_ratio=config.split.test_ratio,
    random_seed=config.split.random_seed,
)

# Verify splits
train_subjects = {item['subject'] for item in train_data}
val_subjects = {item['subject'] for item in val_data}
test_subjects = {item['subject'] for item in test_data}

print(f"\nSplit Verification:")
print(f"Train subjects: {sorted(train_subjects)}")
print(f"Val subjects: {sorted(val_subjects)}")
print(f"Test subjects: {sorted(test_subjects)}")
print(f"Volumes per split: train={len(train_data)}, val={len(val_data)}, test={len(test_data)}")

# Check for leakage: subject sets must be pairwise disjoint
assert train_subjects.isdisjoint(val_subjects), "Train-Val leakage!"
assert train_subjects.isdisjoint(test_subjects), "Train-Test leakage!"
assert val_subjects.isdisjoint(test_subjects), "Val-Test leakage!"

# Check completeness: no discovered volume was silently dropped by the split.
# NOTE(review): assumes create_patient_level_split is meant to assign every
# volume to some split — confirm if it intentionally filters anything.
n_assigned = len(train_data) + len(val_data) + len(test_data)
assert n_assigned == len(data_list), (
    f"Split assigned {n_assigned} volumes but {len(data_list)} were discovered"
)
print("\n✓ No data leakage detected!")

## 5. Preprocessing Pipeline

Preprocess a single volume to demonstrate the pipeline.

In [None]:
# Pick one discovered volume to push through the preprocessing pipeline
sample_idx = 0
sample_entry = data_list[sample_idx]
sample_path = sample_entry['image']

print(f"Preprocessing sample: {sample_path.name}")
print(f"Subject: {sample_entry['subject']}")
print(f"Session: {sample_entry['session']}")
print(f"Modality: {sample_entry['modality']}")
print(f"Field strength: {sample_entry['field_strength']}")

In [None]:
# Initialize preprocessor
preprocessor = MRIPreprocessor(config.preprocessing)

# Preprocess single volume; the result is written under output_root with a "test_" prefix
output_path = config.data.output_root / f"test_{sample_path.name}"
# Make sure the destination directory exists before the preprocessor writes into it
output_path.parent.mkdir(parents=True, exist_ok=True)

result = preprocessor.preprocess_single(
    sample_path,
    output_path,
    save_intermediate=True,  # Save intermediate steps
)

print("\nPreprocessing Results:")
print(f"Original shape: {result['original_shape']}")
print(f"Processed shape: {result['processed_shape']}")
print(f"\nStatistics:")
for key, value in result['statistics'].items():
    # Only numeric scalars take a fixed-point format; anything else (tuples,
    # strings, arrays) would raise on ':.4f', so print it as-is. np.number
    # covers NumPy scalars such as np.float32, which are not Python floats.
    if isinstance(value, (int, float, np.number)):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

## 6. DataLoader Creation

Create PyTorch DataLoaders for the train, validation, and test splits.

In [None]:
# Demo: the loaders are fed the raw (not pre-processed) volumes to keep this
# notebook fast; a production run should point them at preprocessed data.
train_loader, val_loader, test_loader = create_data_loaders(
    config, train_data, val_data, test_data
)

print(f"DataLoaders created:")
for split_label, split_loader in (
    ("Train", train_loader),
    ("Val", val_loader),
    ("Test", test_loader),
):
    print(f"  {split_label}: {len(split_loader.dataset)} samples, {len(split_loader)} batches")

In [None]:
# Pull a single batch from the training loader and summarize its contents
batch = next(iter(train_loader))
batch_images = batch['image']

print(f"Batch contents:")
print(f"  Images shape: {batch_images.shape}")
print(f"  Images dtype: {batch_images.dtype}")
print(f"  Images range: [{batch_images.min():.4f}, {batch_images.max():.4f}]")
print(f"  Subjects: {batch['subject']}")
print(f"  Sessions: {batch['session']}")
print(f"  Modalities: {batch['modality']}")

## 7. Visualization

Visualize samples from the dataset.

In [None]:
# Show the central slice of one batch volume along each of the three spatial axes
sample_idx = 0
image = batch['image'][sample_idx, 0].cpu().numpy()  # select the sample, drop the channel dim
subject = batch['subject'][sample_idx]

# Central slice index along axes 0, 1 and 2
mid_sag, mid_cor, mid_ax = (image.shape[axis] // 2 for axis in range(3))

# Display window: 1st-99th percentile of the nonzero voxels, or the full
# intensity range when the volume contains no nonzero voxel at all
foreground = image[image != 0]
if foreground.size == 0:
    vmin, vmax = image.min(), image.max()
else:
    vmin, vmax = np.percentile(foreground, [1, 99])

views = (
    (image[mid_sag, :, :], f'Sagittal (x={mid_sag})'),
    (image[:, mid_cor, :], f'Coronal (y={mid_cor})'),
    (image[:, :, mid_ax], f'Axial (z={mid_ax})'),
)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (plane, plane_title) in zip(axes, views):
    ax.imshow(plane.T, cmap='gray', vmin=vmin, vmax=vmax, origin='lower')
    ax.set_title(plane_title)
    ax.axis('off')

fig.suptitle(f'Subject: {subject}', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Visualize intensity distribution of the nonzero voxels of `image`
nonzero_values = image[image != 0]  # boolean-mask indexing already yields a 1-D array

if nonzero_values.size == 0:
    # An all-zero volume has no foreground: .min()/.max() raise on an empty
    # array, so skip the plots and statistics instead of crashing the cell.
    print("No nonzero voxels in this sample; skipping intensity distribution.")
else:
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Histogram
    axes[0].hist(nonzero_values, bins=100, alpha=0.7, color='blue')
    axes[0].set_title('Intensity Distribution')
    axes[0].set_xlabel('Intensity')
    axes[0].set_ylabel('Frequency')
    axes[0].grid(True, alpha=0.3)

    # Box plot
    axes[1].boxplot(nonzero_values)
    axes[1].set_title('Intensity Box Plot')
    axes[1].set_ylabel('Intensity')
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"Statistics:")
    print(f"  Mean: {nonzero_values.mean():.4f}")
    print(f"  Std: {nonzero_values.std():.4f}")
    print(f"  Min: {nonzero_values.min():.4f}")
    print(f"  Max: {nonzero_values.max():.4f}")
    print(f"  Median: {np.median(nonzero_values):.4f}")

## 8. Statistics Analysis

Estimate intensity statistics from a subset of the training set (capped by `max_samples=10` to keep this step fast).

In [None]:
from dataset import compute_dataset_statistics

# Compute statistics on a capped number of training samples (max_samples=10) to keep this cell quick
print("Computing training set statistics...")
train_stats = compute_dataset_statistics(train_loader, max_samples=10)

print("\nTraining Set Statistics:")
for key, value in train_stats.items():
    # np.number catches NumPy scalars (np.float32, np.int64, ...), which are
    # not subclasses of Python float/int and would otherwise skip the
    # fixed-point format.
    if isinstance(value, (int, float, np.number)):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

## 9. Demo Training Loop

Demonstrate how to use the DataLoader in a training loop. The loop uses a dummy forward pass and loss; no model or optimizer is involved.

In [None]:
# Simple demo training loop: dummy "forward pass" and loss, no model or optimizer
from itertools import islice

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

num_epochs = 1
num_batches_per_epoch = 3  # Limit for demo

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    print("-" * 50)

    # islice stops after num_batches_per_epoch without requesting an extra batch
    # from the loader (enumerate + break would fetch one more before breaking).
    # The loop variable is `train_batch`, not `batch`, so re-running this cell
    # does not overwrite the batch inspected and visualized in earlier cells.
    for batch_idx, train_batch in enumerate(islice(train_loader, num_batches_per_epoch)):
        # Get data
        images = train_batch['image'].to(device)
        subjects = train_batch['subject']

        # Simulate forward pass (just compute mean)
        output = images.mean(dim=[2, 3, 4])  # Global average pooling

        # Simulate loss computation
        dummy_loss = output.mean()

        print(f"  Batch {batch_idx + 1}:")
        print(f"    Shape: {images.shape}")
        print(f"    Subjects: {subjects}")
        print(f"    Loss: {dummy_loss.item():.4f}")

        # Check for NaN/Inf
        if torch.isnan(images).any():
            print("    WARNING: NaN detected!")
        if torch.isinf(images).any():
            print("    WARNING: Inf detected!")

print("\n✓ Training loop demo completed successfully!")

## Summary

This notebook demonstrated:
1. ✅ Dataset discovery and exploration
2. ✅ Patient-level splitting (no data leakage)
3. ✅ Preprocessing pipeline
4. ✅ DataLoader creation
5. ✅ Visualization and statistics
6. ✅ Training loop integration

### Next Steps:
- Run full preprocessing: `python example_pipeline.py --preprocess`
- Implement your model architecture
- Customize augmentation strategies
- Generate proper brain masks using HD-BET or SynthStrip
- Experiment with different normalization methods

### Key Takeaways:
- **Always use patient-level splits** to prevent data leakage
- **Set random seeds** for reproducibility
- **Monitor statistics** to ensure proper normalization
- **Visualize samples** to verify preprocessing
- **Save split information** for reproducibility