# MONAI Medical Imaging Notebook

This notebook provides a framework for working with MONAI for medical image analysis, particularly for the intracranial aneurysm detection project.

## 1. Environment Setup and Installation

In [None]:
# Install MONAI and dependencies
# Skip this cell if the packages are already installed in the kernel's environment
import sys
import subprocess

def pip_install(packages):
    """
    Install packages with pip, using the interpreter that runs this kernel.

    Args:
        packages: A requirement specifier or an iterable of requirement
            specifiers, forwarded verbatim to ``pip install``.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    # A bare string would otherwise be unpacked character by character.
    if isinstance(packages, str):
        packages = [packages]
    packages = list(packages)

    # ``pip install`` without any requirement exits with an error, so an
    # empty request is treated as a no-op.
    if not packages:
        return

    # Going through sys.executable targets the kernel's own environment.
    subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])

# Install everything in a single pip invocation: the requirements are resolved
# together and only one subprocess is started.
pip_install([
    "monai",
    "nibabel",
    "pydicom",
    "torch",
    "torchvision",
    "matplotlib",
    "pandas",
    "scikit-learn",
])

: 

## 2. Import Required Libraries

In [None]:
import os
import sys
import glob
import zipfile
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional

# Medical imaging libraries
import nibabel as nib
import pydicom

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# MONAI imports
import monai
from monai.config import print_config
from monai.data import (
    Dataset as MonaiDataset,
    CacheDataset,
    DataLoader as MonaiDataLoader,
    decollate_batch
)
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CropForegroundd,
    Resized,
    RandFlipd,
    RandRotate90d,
    RandShiftIntensityd,
    ToTensord,
    AddChanneld,
    EnsureTyped
)
from monai.networks.nets import (
    DenseNet121,
    ResNet,
    EfficientNetBN,
    BasicUNet
)
from monai.metrics import ROCAUCMetric
from monai.utils import set_determinism

# Print the report produced by monai.config.print_config so the environment
# used for a run is recorded in the notebook output
print_config()

## 3. Configuration and Paths

In [None]:
# Project configuration
CONFIG = {
    # Input locations. The defaults are site-specific absolute paths; set the
    # ANEURYSM_DATA_DIR / ANEURYSM_METADATA_CSV environment variables to run
    # elsewhere without editing this cell.
    # NOTE: the default 'data_dir' names a .zip archive, not a directory.
    'data_dir': os.environ.get(
        'ANEURYSM_DATA_DIR',
        '/lustre/work/sweeden/rsna-intracranial-aneurysm-detection.zip'
    ),
    'metadata_csv': os.environ.get(
        'ANEURYSM_METADATA_CSV',
        '/mnt/data/merged_medical_data_train.csv'
    ),
    'cache_dir': './cache',
    'output_dir': './outputs',
    'model_dir': './models',
    
    # Data parameters
    'image_size': (128, 128, 64),  # Target 3D volume size
    'spacing': (1.0, 1.0, 1.0),    # Isotropic spacing in mm
    
    # Brain windowing parameters (Hounsfield Units)
    'windows': {
        'brain': {'center': 40, 'width': 80},
        'subdural': {'center': 75, 'width': 215},
        'stroke': {'center': 40, 'width': 40}
    },
    
    # Training parameters
    'batch_size': 4,
    'num_workers': 4,
    'epochs': 100,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    'seed': 42,
    
    # Model parameters
    'in_channels': 1,
    'num_classes': 8,  # Multiple aneurysm location classes
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

# Create the working directories (no error if they already exist)
for dir_path in [CONFIG['cache_dir'], CONFIG['output_dir'], CONFIG['model_dir']]:
    Path(dir_path).mkdir(parents=True, exist_ok=True)

# Set deterministic training
set_determinism(seed=CONFIG['seed'])

print(f"Using device: {CONFIG['device']}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 4. Data Loading Functions

In [None]:
def apply_windowing(image: np.ndarray, center: float, width: float) -> np.ndarray:
    """
    Clip an image to an intensity window and rescale the window to [0, 1].

    Args:
        image: Array of intensities (Hounsfield Units for CT).
        center: Center of the window.
        width: Full width of the window; must be positive.

    Returns:
        Array with the same shape as ``image`` in which values at or below
        ``center - width / 2`` map to 0 and values at or above
        ``center + width / 2`` map to 1.

    Raises:
        ValueError: If ``width`` is not a positive number.
    """
    # A zero-width window would divide 0 by 0 below and fill the output with
    # NaN; ``not width > 0`` also rejects a NaN width.
    if not width > 0:
        raise ValueError(f"width must be positive, got {width}")

    lower = center - width / 2
    upper = center + width / 2
    clipped = np.clip(image, lower, upper)
    return (clipped - lower) / (upper - lower)

def _position_along_normal(ds) -> Optional[float]:
    """Project a slice's ImagePositionPatient onto the slice normal, or return None if unavailable."""
    ipp = getattr(ds, 'ImagePositionPatient', None)
    iop = getattr(ds, 'ImageOrientationPatient', None)
    if ipp is None or iop is None:
        return None
    try:
        position = np.asarray(ipp, dtype=np.float64)
        orientation = np.asarray(iop, dtype=np.float64)
    except (TypeError, ValueError):
        return None
    if position.shape != (3,) or orientation.shape != (6,):
        return None
    # The slice normal is the cross product of the row and column direction cosines.
    normal = np.cross(orientation[:3], orientation[3:])
    return float(np.dot(normal, position))


def _instance_number(ds) -> Optional[int]:
    """Return the slice's InstanceNumber as an int, or None if missing or unparsable."""
    try:
        return int(getattr(ds, 'InstanceNumber', None))
    except (TypeError, ValueError):
        return None


def _slice_order(datasets) -> List[int]:
    """Return indices into ``datasets`` in spatial order, falling back to the given order."""
    indices = list(range(len(datasets)))

    # A key is only used when every slice provides it, so keys are never mixed.
    positions = [_position_along_normal(ds) for ds in datasets]
    if all(p is not None for p in positions):
        return sorted(indices, key=lambda i: positions[i])

    numbers = [_instance_number(ds) for ds in datasets]
    if all(n is not None for n in numbers):
        return sorted(indices, key=lambda i: numbers[i])

    return indices


def load_dicom_series(dicom_dir: str) -> Optional[np.ndarray]:
    """
    Load the ``*.dcm`` files of a directory as one volume, converted to HU.

    Slices are ordered by their position along the slice normal when every
    file carries ImagePositionPatient and ImageOrientationPatient, otherwise
    by InstanceNumber, otherwise by file name. File names alone do not
    guarantee spatial order.

    Args:
        dicom_dir: Directory containing the ``*.dcm`` files of one series.

    Returns:
        A float32 array with the slices stacked along the last axis, or
        ``None`` when the directory contains no ``*.dcm`` file.

    Raises:
        ValueError: From ``np.stack`` when the slices do not share one shape.
    """
    paths = sorted(glob.glob(os.path.join(dicom_dir, '*.dcm')))
    if not paths:
        return None

    datasets = [pydicom.dcmread(path) for path in paths]

    slices = []
    for index in _slice_order(datasets):
        ds = datasets[index]
        pixel_array = ds.pixel_array.astype(np.float32)
        # Convert stored values to HU when the rescale tags are present
        if hasattr(ds, 'RescaleSlope') and hasattr(ds, 'RescaleIntercept'):
            pixel_array = pixel_array * float(ds.RescaleSlope) + float(ds.RescaleIntercept)
        slices.append(pixel_array)

    # NOTE(review): assumes one 2-D frame per file; multi-frame files would
    # produce a 4-D stack. Confirm against the dataset.
    return np.stack(slices, axis=-1)

def load_nifti_volume(nifti_path: str) -> np.ndarray:
    """
    Read a NIfTI file and return its voxel data.

    Args:
        nifti_path: Path of the NIfTI file to open with nibabel.

    Returns:
        The array returned by the image's ``get_fdata()``.
    """
    return nib.load(nifti_path).get_fdata()

def prepare_data_list(metadata_csv: str, data_dir: str, limit: Optional[int] = None) -> List[Dict]:
    """
    Build the list of sample dictionaries consumed by a MONAI Dataset.

    Args:
        metadata_csv: Path of the CSV file describing the samples.
        data_dir: Root of the image data. Currently unused: the ``image_path``
            column is passed through unchanged.
        limit: Keep only the first ``limit`` rows; ``None`` keeps all rows and
            ``0`` yields an empty list.

    Returns:
        One dict per row with the keys ``image``, ``label``, ``series_uid``
        and ``patient_id``. A column missing from the CSV falls back to
        ``''`` (``0`` for the label).
    """
    df = pd.read_csv(metadata_csv)
    # Compare against None so that limit=0 is honoured instead of being
    # treated as "no limit".
    if limit is not None:
        df = df.head(limit)

    # This is a placeholder - adjust based on actual data structure
    return [
        {
            'image': record.get('image_path', ''),  # Path to NIfTI or DICOM
            'label': record.get('label', 0),
            'series_uid': record.get('SeriesInstanceUID', ''),
            'patient_id': record.get('PatientID', '')
        }
        for record in df.to_dict(orient='records')
    ]

## 5. MONAI Transforms

In [None]:
def _deterministic_preprocessing():
    """Return fresh instances of the preprocessing steps shared by the training and validation pipelines."""
    intensity_scaling = ScaleIntensityRanged(
        keys=['image'],
        a_min=-1000,
        a_max=1000,
        b_min=0,
        b_max=1,
        clip=True
    )
    return [
        LoadImaged(keys=['image']),
        EnsureChannelFirstd(keys=['image']),
        Spacingd(keys=['image'], pixdim=CONFIG['spacing'], mode='bilinear'),
        Orientationd(keys=['image'], axcodes='RAS'),
        intensity_scaling,
        CropForegroundd(keys=['image'], source_key='image'),
        Resized(keys=['image'], spatial_size=CONFIG['image_size']),
    ]


# Training pipeline: shared preprocessing, then random augmentation
train_transforms = Compose(
    _deterministic_preprocessing()
    + [
        RandFlipd(keys=['image'], prob=0.5, spatial_axis=0),
        RandRotate90d(keys=['image'], prob=0.5, max_k=3),
        RandShiftIntensityd(keys=['image'], offsets=0.1, prob=0.5),
        ToTensord(keys=['image', 'label']),
    ]
)

# Validation pipeline: shared preprocessing only, no augmentation
val_transforms = Compose(
    _deterministic_preprocessing()
    + [ToTensord(keys=['image', 'label'])]
)

## 6. Multi-Modal 3D CNN Model

In [None]:
class MultiModal3DCNN(nn.Module):
    """
    Multi-modal 3D CNN for medical image analysis.
    Supports both NIfTI and DICOM inputs with late fusion.

    Each modality has its own convolutional branch producing a 256-dim
    feature vector. The two vectors are concatenated as [NIfTI | DICOM],
    passed through a shared fusion MLP, and fed to two heads.

    Args:
        in_channels: Number of channels of each input volume.
        num_classes: Number of outputs of the location head.
        dropout_prob: Dropout probability used in the fusion MLP.
    """
    def __init__(self, in_channels=1, num_classes=8, dropout_prob=0.5):
        super().__init__()

        # Both branches share one architecture but have separate weights.
        # Layer order and indices inside each Sequential are fixed (e.g. index
        # 12 is the last Conv3d) so state-dict keys stay stable.
        self.nifti_branch = self._make_branch(in_channels)
        self.dicom_branch = self._make_branch(in_channels)

        # Fusion and classification layers
        self.fusion = nn.Sequential(
            nn.Linear(512, 256),  # 256 from each branch
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_prob),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_prob)
        )

        # Multi-label classification heads
        self.location_classifier = nn.Linear(128, num_classes)  # Location classes
        self.aneurysm_classifier = nn.Linear(128, 1)  # Any aneurysm present

    @staticmethod
    def _make_branch(in_channels):
        """Build one branch: four Conv3d-BatchNorm-ReLU stages, max-pooled except for a final global average pool."""
        widths = (32, 64, 128, 256)
        layers = []
        previous = in_channels
        for stage, width in enumerate(widths):
            is_last = stage == len(widths) - 1
            layers.extend([
                nn.Conv3d(previous, width, kernel_size=3, padding=1),
                nn.BatchNorm3d(width),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool3d(1) if is_last else nn.MaxPool3d(2),
            ])
            previous = width
        return nn.Sequential(*layers)

    def forward(self, nifti_input=None, dicom_input=None):
        """
        Run the network on one or both modalities.

        Args:
            nifti_input: Batch for the NIfTI branch (fed to Conv3d), or None.
            dicom_input: Batch for the DICOM branch (fed to Conv3d), or None.

        Returns:
            Dict with sigmoid probabilities: ``'location'`` of shape
            (batch, num_classes) and ``'aneurysm'`` of shape (batch, 1).

        Raises:
            ValueError: If both inputs are None.
        """
        if nifti_input is None and dicom_input is None:
            raise ValueError("At least one of nifti_input or dicom_input must be provided")

        nifti_feat = None
        if nifti_input is not None:
            nifti_feat = self.nifti_branch(nifti_input).flatten(1)

        dicom_feat = None
        if dicom_input is not None:
            dicom_feat = self.dicom_branch(dicom_input).flatten(1)

        # Zero-fill a missing modality in its OWN slot. The fusion weights are
        # position-specific, so DICOM features must never be placed in the
        # NIfTI half of the concatenated vector.
        if nifti_feat is None:
            nifti_feat = torch.zeros_like(dicom_feat)
        if dicom_feat is None:
            dicom_feat = torch.zeros_like(nifti_feat)
        combined = torch.cat([nifti_feat, dicom_feat], dim=1)

        # Fusion
        fused = self.fusion(combined)

        # Multi-label outputs
        location_logits = self.location_classifier(fused)
        aneurysm_logits = self.aneurysm_classifier(fused)

        return {
            'location': torch.sigmoid(location_logits),
            'aneurysm': torch.sigmoid(aneurysm_logits)
        }

# Build the network from the configured channel/class counts and place it on
# the configured device
model = MultiModal3DCNN(
    in_channels=CONFIG['in_channels'],
    num_classes=CONFIG['num_classes'],
)
model = model.to(CONFIG['device'])

# Report the total number of parameter elements
num_parameters = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {num_parameters:,}")

## 7. Loss Functions and Metrics

In [None]:
class FocalLoss(nn.Module):
    """
    Focal Loss for addressing class imbalance in multi-label classification.

    Operates on probabilities (not logits): the per-element binary
    cross-entropy is scaled by ``alpha * (1 - pt) ** gamma`` with
    ``pt = exp(-bce)``.

    Args:
        alpha: Constant factor applied to every element.
        gamma: Focusing exponent.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-element loss.
    """
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Compute the focal loss.

        Args:
            inputs: Predicted probabilities in [0, 1].
            targets: Binary targets with the same shape as ``inputs``; any
                numeric or bool dtype is accepted.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, otherwise a tensor with
            the shape of ``inputs``.
        """
        # binary_cross_entropy rejects integer/bool targets, which is how
        # labels usually arrive from a CSV or from a comparison.
        targets = targets.to(dtype=inputs.dtype)

        bce_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-bce_loss)  # probability assigned to the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# Initialize loss functions
# Both are module-level objects that train_epoch and validate_epoch read as
# globals. They take probabilities, matching the sigmoid outputs of the model.
location_loss_fn = FocalLoss(alpha=1, gamma=2)
aneurysm_loss_fn = nn.BCELoss()

# Initialize metrics
# Shared module-level metric object used by validate_epoch.
auc_metric = ROCAUCMetric()

## 8. Training Loop

In [None]:
def train_epoch(model, dataloader, optimizer, device):
    """
    Run one optimisation pass over ``dataloader``.

    Each batch's ``'image'`` is fed to the model as ``nifti_input``. The loss
    is the location loss on the labels plus the aneurysm loss on a per-sample
    "any label positive" target. Uses the module-level ``location_loss_fn``
    and ``aneurysm_loss_fn``.

    Returns:
        The summed batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0

    for batch in dataloader:
        inputs = batch['image'].to(device)
        targets = batch['label'].to(device)
        # 1.0 for samples with at least one positive label, shape (batch, 1)
        any_positive = (targets.sum(dim=1, keepdim=True) > 0).float()

        optimizer.zero_grad()
        predictions = model(nifti_input=inputs)

        batch_loss = (
            location_loss_fn(predictions['location'], targets)
            + aneurysm_loss_fn(predictions['aneurysm'], any_positive)
        )
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(dataloader)

def validate_epoch(model, dataloader, device):
    """
    Evaluate the model on ``dataloader`` without gradient tracking.

    Uses the module-level ``location_loss_fn``, ``aneurysm_loss_fn`` and
    ``auc_metric``.

    Returns:
        Tuple ``(mean_loss, auc)``: ``mean_loss`` is the summed batch losses
        divided by the number of batches, and ``auc`` is the ROC AUC of the
        aneurysm head as a float. ``auc`` is NaN when the labels contain a
        single class, because AUC is undefined in that case.
    """
    model.eval()
    epoch_loss = 0
    all_outputs = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            # 1.0 for samples with at least one positive label, shape (batch, 1)
            aneurysm_targets = (labels.sum(dim=1, keepdim=True) > 0).float()

            outputs = model(nifti_input=images)

            location_loss = location_loss_fn(outputs['location'], labels)
            aneurysm_loss = aneurysm_loss_fn(outputs['aneurysm'], aneurysm_targets)

            total_loss = location_loss + aneurysm_loss
            epoch_loss += total_loss.item()

            all_outputs.append(outputs['aneurysm'].cpu())
            all_labels.append(aneurysm_targets.cpu())

    # Calculate metrics
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)

    if all_labels.min() == all_labels.max():
        # Only one class present: ROC AUC is undefined.
        auc = float('nan')
    else:
        # ROCAUCMetric is cumulative: calling it only buffers the data and
        # aggregate() computes the score. Reset before and after so that
        # epochs do not leak into each other through the shared object.
        auc_metric.reset()
        auc_metric(all_outputs, all_labels)
        auc = float(auc_metric.aggregate())
        auc_metric.reset()

    return epoch_loss / len(dataloader), auc

## 9. Main Training Pipeline

In [None]:
def main_training_loop():
    """
    Main training pipeline, run on synthetic data as a smoke test.

    Builds 20 random volumes (15 for training, 5 for validation), trains a
    fresh MultiModal3DCNN for CONFIG['epochs'] epochs with AdamW and cosine
    annealing, and saves the weights with the best validation AUC to
    CONFIG['model_dir']/best_model.pth.

    Returns:
        Tuple ``(train_losses, val_losses, val_aucs)`` with one entry per epoch.
    """
    # Prepare dummy data for demonstration
    # In practice, load from actual dataset
    # Odd-indexed samples get one positive location, even-indexed samples stay
    # all-negative, so both splits contain both classes and AUC is defined.
    dummy_data = []
    for sample_index in range(20):
        label = torch.zeros(CONFIG['num_classes'])
        if sample_index % 2 == 1:
            label[sample_index % CONFIG['num_classes']] = 1.0
        dummy_data.append({
            'image': torch.randn(1, *CONFIG['image_size']),
            'label': label
        })
    
    # Split data
    train_data = dummy_data[:15]
    val_data = dummy_data[15:]
    
    # Create datasets and dataloaders
    train_dataset = MonaiDataset(data=train_data, transform=None)
    val_dataset = MonaiDataset(data=val_data, transform=None)
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=0
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=0
    )
    
    # Initialize model, optimizer, scheduler
    # This local model is independent of any module-level `model`.
    model = MultiModal3DCNN(
        in_channels=CONFIG['in_channels'],
        num_classes=CONFIG['num_classes']
    ).to(CONFIG['device'])
    
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay']
    )
    
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=CONFIG['epochs']
    )
    
    # Training loop
    # Start below any attainable AUC so the first finite score is always
    # checkpointed; a NaN score never compares greater and is skipped.
    best_auc = float('-inf')
    train_losses = []
    val_losses = []
    val_aucs = []
    
    for epoch in range(CONFIG['epochs']):
        # Train
        train_loss = train_epoch(model, train_loader, optimizer, CONFIG['device'])
        
        # Validate
        val_loss, val_auc = validate_epoch(model, val_loader, CONFIG['device'])
        
        # Update scheduler
        scheduler.step()
        
        # Save metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_aucs.append(val_auc)
        
        # Save best model
        if val_auc > best_auc:
            best_auc = val_auc
            torch.save(
                model.state_dict(),
                os.path.join(CONFIG['model_dir'], 'best_model.pth')
            )
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{CONFIG['epochs']}")
            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Loss: {val_loss:.4f}")
            print(f"Val AUC: {val_auc:.4f}")
            print("-" * 50)
    
    return train_losses, val_losses, val_aucs

# Uncomment to run training
# train_losses, val_losses, val_aucs = main_training_loop()

## 10. Visualization Functions

In [None]:
def plot_training_curves(train_losses, val_losses, val_aucs):
    """
    Show per-epoch training history as two side-by-side panels.

    The left panel plots the training and validation losses, the right panel
    plots the validation AUC. Each sequence is plotted against its index.
    """
    fig, (loss_ax, auc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: losses
    loss_ax.plot(train_losses, label='Train Loss')
    loss_ax.plot(val_losses, label='Val Loss')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')

    # Right panel: AUC
    auc_ax.plot(val_aucs, label='Val AUC')
    auc_ax.set_ylabel('AUC')
    auc_ax.set_title('Validation AUC')

    # Decoration common to both panels
    for panel in (loss_ax, auc_ax):
        panel.set_xlabel('Epoch')
        panel.legend()
        panel.grid(True)

    plt.tight_layout()
    plt.show()

def visualize_3d_volume(volume, slice_idx=None):
    """
    Show three orthogonal grayscale slices of a 3D volume.

    Args:
        volume: Array indexed as ``volume[i, j, k]``.
        slice_idx: Index along the last axis for the first panel; defaults to
            the middle of that axis. The other two panels always use the
            middle index of their axis.
    """
    if slice_idx is None:
        slice_idx = volume.shape[2] // 2

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    # (slice, title) for each panel, left to right
    mid_axis1 = volume.shape[1] // 2
    mid_axis0 = volume.shape[0] // 2
    panels = [
        (volume[:, :, slice_idx], f'Axial Slice {slice_idx}'),
        (volume[:, mid_axis1, :], 'Coronal View'),
        (volume[mid_axis0, :, :], 'Sagittal View'),
    ]

    for panel_ax, (plane, title) in zip(axes, panels):
        panel_ax.imshow(plane, cmap='gray')
        panel_ax.set_title(title)
        panel_ax.axis('off')

    plt.tight_layout()
    plt.show()

## 11. Inference and Prediction

In [None]:
def predict_single_volume(model, volume, device='cuda'):
    """
    Make prediction for a single volume.

    Args:
        model: Module called as ``model(nifti_input=...)`` that returns a dict
            with ``'location'`` and ``'aneurysm'`` tensors.
        volume: NumPy array or tensor of rank 3 (spatial only), rank 4
            (channel + spatial) or rank 5 (already batched).
        device: Device the input is moved to. The model is not moved. When a
            CUDA device is requested but CUDA is unavailable, the device of
            the model's first parameter is used instead (CPU if the model has
            no parameters).

    Returns:
        Dict with NumPy arrays ``'location_probs'`` and ``'aneurysm_prob'``.

    Raises:
        ValueError: If ``volume`` does not have rank 3, 4 or 5.
    """
    model.eval()

    # The 'cuda' default cannot work on a CPU-only host; follow the model instead.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Prepare input
    if isinstance(volume, np.ndarray):
        volume = torch.from_numpy(volume).float()

    rank = len(volume.shape)
    if rank == 3:
        volume = volume.unsqueeze(0).unsqueeze(0)  # Add batch and channel dims
    elif rank == 4:
        volume = volume.unsqueeze(0)  # Add batch dim
    elif rank != 5:
        raise ValueError(
            f"volume must have 3, 4 or 5 dimensions, got shape {tuple(volume.shape)}"
        )

    volume = volume.to(device)

    # Predict
    with torch.no_grad():
        outputs = model(nifti_input=volume)

    return {
        'location_probs': outputs['location'].cpu().numpy(),
        'aneurysm_prob': outputs['aneurysm'].cpu().numpy()
    }

# Example usage
# dummy_volume = torch.randn(1, *CONFIG['image_size'])
# predictions = predict_single_volume(model, dummy_volume, CONFIG['device'])
# print(f"Aneurysm probability: {predictions['aneurysm_prob'][0, 0]:.3f}")

## 12. Model Explainability (Grad-CAM)

In [None]:
from monai.visualize import GradCAM

def generate_gradcam(model, volume, target_layer, device='cuda'):
    """
    Generate Grad-CAM visualization for model predictions.
    """
    # Initialize Grad-CAM
    cam = GradCAM(nn_module=model, target_layers=target_layer)
    
    # Prepare input
    if isinstance(volume, np.ndarray):
        volume = torch.from_numpy(volume).float()
    
    if len(volume.shape) == 3:
        volume = volume.unsqueeze(0).unsqueeze(0)
    
    volume = volume.to(device)
    
    # Generate CAM
    result = cam(x=volume)
    
    return result[0].cpu().numpy()

# Example usage (requires target layer specification)
# target_layer = 'nifti_branch.12'  # Last conv layer
# cam_result = generate_gradcam(model, dummy_volume, target_layer)

## 13. Export and Save Results

In [None]:
def save_predictions_to_csv(predictions, output_path):
    """
    Write predictions to a CSV file and report where they were written.

    Args:
        predictions: Any data accepted by the ``pandas.DataFrame`` constructor.
        output_path: Destination of the CSV file; the index is not written.
    """
    pd.DataFrame(predictions).to_csv(output_path, index=False)
    print(f"Predictions saved to {output_path}")

def export_model_to_onnx(model, output_path, input_shape):
    """
    Export model to ONNX format for deployment.

    Only the first model input is exported; the graph is traced with a
    single-channel random volume of batch size 1, and the batch axis is
    declared dynamic.

    Args:
        model: Network to export; switched to eval mode.
        output_path: Destination of the ``.onnx`` file.
        input_shape: Spatial shape of the input volume, without batch and
            channel dimensions.
    """
    model.eval()

    # Build the example input on the model's own device so the trace never
    # mixes devices, whatever the global configuration says.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    dummy_input = torch.randn(1, 1, *input_shape).to(device)

    # Only tensors that are actually passed can be named as graph inputs; a
    # None placeholder for the second modality does not become an input, so
    # naming it or giving it dynamic axes is invalid.
    # NOTE(review): the output names assume the dict outputs are flattened in
    # insertion order ('location', then 'aneurysm') - confirm on the
    # installed torch version.
    torch.onnx.export(
        model,
        (dummy_input,),
        output_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['nifti_input'],
        output_names=['location', 'aneurysm'],
        dynamic_axes={
            'nifti_input': {0: 'batch_size'},
            'location': {0: 'batch_size'},
            'aneurysm': {0: 'batch_size'}
        }
    )
    print(f"Model exported to {output_path}")

# Example usage
# export_model_to_onnx(model, 'model.onnx', CONFIG['image_size'])

## 14. Summary and Next Steps

This notebook provides a comprehensive framework for medical image analysis using MONAI, specifically designed for the intracranial aneurysm detection project. 

### Key Features:
- Multi-modal support (NIfTI and DICOM)
- 3D CNN architecture with late fusion
- Multi-label classification
- Brain-specific windowing
- MONAI transforms and data pipeline
- Training/validation loops
- Model explainability (Grad-CAM)
- Export capabilities (ONNX)

### Next Steps:
1. Connect to actual dataset (unzip and process the RSNA dataset)
2. Implement proper data manifest creation
3. Set up K-fold cross-validation
4. Fine-tune hyperparameters
5. Implement ensemble methods
6. Add more sophisticated augmentations
7. Deploy model for inference

### Note:
This implementation is for research purposes only and should not be used for clinical diagnosis without proper validation and regulatory approval.