# ML Training with MSD Lung RadiObject

This notebook trains a binary lung tumor classifier on the Medical Segmentation Decathlon (MSD) lung data.

## Overview

1. **Load** RadiObject from URI (S3 or local)
2. **Explore** data and label distribution
3. **Split** into train/validation sets
4. **Train** a 3D CNN classifier
5. **Evaluate** model performance

## Task

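Binary classification: predict `has_tumor` (0 or 1) from CT volume patches. The label is defined per subject in `obs_meta`, so every patch from a subject inherits that subject's label, including patches that do not contain the tumor.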

**Prerequisites:** Run [05_ingest_msd.ipynb](./05_ingest_msd.ipynb) first to create the MSD Lung RadiObject.

## 1. Setup

In [None]:
import tempfile
import shutil
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from config import MSD_LUNG_URI, S3_REGION
from radiobject import RadiObject, configure
from radiobject.ctx import S3Config
from radiobject.ml import (
    create_training_dataloader,
    create_validation_dataloader,
    Compose,
    IntensityNormalize,
    RandomFlip3D,
)

print(f"PyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"RadiObject URI: {MSD_LUNG_URI}")

In [None]:
# Determine compute device
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Training device: {DEVICE}")

In [None]:
# Configure S3 access if using S3 URI
if MSD_LUNG_URI.startswith("s3://"):
    configure(s3=S3Config(region=S3_REGION, max_parallel_ops=8))

TEMP_DIR = tempfile.mkdtemp(prefix="msd_ml_")
print(f"Temp directory: {TEMP_DIR}")

## 2. Load RadiObject from URI

In [None]:
# Load RadiObject
radi = RadiObject(MSD_LUNG_URI)

print(f"RadiObject: {radi}")
print(f"Subjects: {len(radi)}")
print(f"Collections: {radi.collection_names}")

In [None]:
# Display subject metadata
obs_meta = radi.obs_meta.read()
print(f"obs_meta columns: {list(obs_meta.columns)}")
obs_meta.head(10)

## 3. Explore Data

In [None]:
# Label distribution
label_counts = obs_meta['has_tumor'].value_counts().sort_index()

fig, ax = plt.subplots(figsize=(6, 4))
label_counts.plot(kind='bar', ax=ax, color=['steelblue', 'coral'])
ax.set_xlabel('has_tumor')
ax.set_ylabel('Count')
ax.set_title('Label Distribution')
ax.set_xticklabels(['No Tumor (0)', 'Has Tumor (1)'], rotation=0)

for i, v in enumerate(label_counts.values):
    ax.text(i, v + 1, str(v), ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

print(f"Label distribution: {label_counts.to_dict()}")

In [None]:
# Visualize a mid-volume axial slice for subjects with and without tumor
collection_name = radi.collection_names[0]

# Subject IDs for each class
tumor_subjects = obs_meta[obs_meta['has_tumor'] == 1]['obs_subject_id'].tolist()
no_tumor_subjects = obs_meta[obs_meta['has_tumor'] == 0]['obs_subject_id'].tolist()

fig, axes = plt.subplots(2, 3, figsize=(12, 8))

# No tumor samples
for i, subject_id in enumerate(no_tumor_subjects[:3]):
    vol = radi.loc[subject_id].collection(collection_name).iloc[0]
    mid_z = vol.shape[2] // 2
    axes[0, i].imshow(vol.axial(z=mid_z).T, cmap='gray', origin='lower')
    axes[0, i].set_title(f'{subject_id} (no tumor)')
    axes[0, i].axis('off')

# Tumor samples
for i, subject_id in enumerate(tumor_subjects[:3]):
    vol = radi.loc[subject_id].collection(collection_name).iloc[0]
    mid_z = vol.shape[2] // 2
    axes[1, i].imshow(vol.axial(z=mid_z).T, cmap='gray', origin='lower')
    axes[1, i].set_title(f'{subject_id} (has tumor)')
    axes[1, i].axis('off')

plt.suptitle('Sample CT Scans by Label', fontsize=14)
plt.tight_layout()
plt.show()

## 4. Train/Validation Split
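The split cell below shuffles all subject IDs together, so the class ratio can drift between the train and validation sets, especially with few subjects. A minimal stratified alternative is sketched here with a hypothetical helper `stratified_split` and made-up toy metadata: it shuffles and splits each `has_tumor` group separately, then concatenates the pieces.

```python
import numpy as np
import pandas as pd


def stratified_split(meta: pd.DataFrame, id_col: str, label_col: str,
                     train_frac: float = 0.8, seed: int = 42):
    """Hypothetical helper: split IDs per label so each class keeps roughly the same ratio."""
    rng = np.random.default_rng(seed)
    train_ids, val_ids = [], []
    # Shuffle and split each class on its own, then concatenate the pieces
    for _, group in meta.groupby(label_col):
        ids = rng.permutation(group[id_col].to_numpy()).tolist()
        n_train = int(train_frac * len(ids))
        train_ids.extend(ids[:n_train])
        val_ids.extend(ids[n_train:])
    return train_ids, val_ids


# Toy metadata (made up): 10 subjects with tumor, 5 without
toy = pd.DataFrame({
    "obs_subject_id": [f"sub_{i:02d}" for i in range(15)],
    "has_tumor": [1] * 10 + [0] * 5,
})
toy_train, toy_val = stratified_split(toy, "obs_subject_id", "has_tumor")
print(len(toy_train), len(toy_val))
```

Applied to this notebook it would be called as `stratified_split(obs_meta, 'obs_subject_id', 'has_tumor')`, assuming `obs_meta` lists exactly the subjects in `radi.obs_subject_ids`. Because `int()` truncates, a class whose size is not a multiple of 5 contributes slightly fewer than 80% of its subjects to training.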

In [None]:
# 80/20 random split (not stratified: class ratios can differ between the sets)
all_ids = list(radi.obs_subject_ids)
np.random.seed(42)
np.random.shuffle(all_ids)

split_idx = int(0.8 * len(all_ids))
train_ids = all_ids[:split_idx]
val_ids = all_ids[split_idx:]

print(f"Training subjects: {len(train_ids)}")
print(f"Validation subjects: {len(val_ids)}")

# Check label distribution in splits
train_labels = obs_meta[obs_meta['obs_subject_id'].isin(train_ids)]['has_tumor']
val_labels = obs_meta[obs_meta['obs_subject_id'].isin(val_ids)]['has_tumor']

print(f"\nTrain label distribution: {train_labels.value_counts().to_dict()}")
print(f"Val label distribution: {val_labels.value_counts().to_dict()}")

In [None]:
# Create train/val RadiObjects
train_uri = f"{TEMP_DIR}/train_radi"
val_uri = f"{TEMP_DIR}/val_radi"

radi_train = radi.loc[train_ids].to_radi_object(train_uri)
radi_val = radi.loc[val_ids].to_radi_object(val_uri)

print(f"Train RadiObject: {radi_train}")
print(f"Val RadiObject: {radi_val}")

## 5. Create DataLoaders
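The next cell composes `IntensityNormalize` and `RandomFlip3D` and batches 64³ patches. As an illustration only, the NumPy sketch below assumes `IntensityNormalize` is a per-patch z-score and `RandomFlip3D(axes=(0, 1, 2), prob=0.5)` flips each spatial axis independently with probability 0.5; the real RadiObject transforms may differ (for example, in which axes they treat as spatial). The helpers `intensity_normalize`, `random_flip_3d`, and `compose` are hypothetical. The sketch also reproduces the per-batch memory figure printed when a batch is inspected, assuming single-channel float32 patches.

```python
import numpy as np


def intensity_normalize(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Assumed behavior: z-score normalization over the whole patch."""
    return (x - x.mean()) / (x.std() + eps)


def random_flip_3d(x: np.ndarray, rng: np.random.Generator,
                   axes=(0, 1, 2), prob: float = 0.5) -> np.ndarray:
    """Flip each spatial axis independently with probability `prob`.

    `x` has shape (C, D, H, W), so spatial axis i is array axis i + 1.
    """
    for ax in axes:
        if rng.random() < prob:
            x = np.flip(x, axis=ax + 1)
    return x


def compose(*transforms):
    """Apply transforms left to right, like a Compose([...]) pipeline."""
    def apply(x):
        for t in transforms:
            x = t(x)
        return x
    return apply


rng = np.random.default_rng(0)
train_tf = compose(intensity_normalize, lambda x: random_flip_3d(x, rng))

# Synthetic single-channel 64^3 patch with CT-like intensities (made-up values)
patch = rng.normal(loc=-500.0, scale=300.0, size=(1, 64, 64, 64)).astype(np.float32)
out = train_tf(patch)

# Memory for a batch of 4 such patches, computed as bytes / 1024^2
batch_size = 4
mb_per_batch = batch_size * out.nbytes / 1024 / 1024
print(out.shape, out.dtype, f"{mb_per_batch:.1f} MB per batch")
```

Flipping only rearranges voxels, so it leaves the patch's mean and standard deviation unchanged; for this particular pair of transforms the order in the pipeline does not affect the result.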

In [None]:
# Training hyperparameters
BATCH_SIZE = 4
PATCH_SIZE = (64, 64, 64)

# Define transforms
train_transform = Compose([
    IntensityNormalize(),
    RandomFlip3D(axes=(0, 1, 2), prob=0.5),
])

val_transform = Compose([
    IntensityNormalize(),
])

# Create dataloaders: the training loader shuffles and drops the last incomplete batch
train_loader = create_training_dataloader(
    radi_train,
    modalities=[radi.collection_names[0]],
    label_column="has_tumor",
    batch_size=BATCH_SIZE,
    patch_size=PATCH_SIZE,
    num_workers=0,
    pin_memory=False,
    persistent_workers=False,
    transform=train_transform,
)

# Use create_validation_dataloader for val set (no shuffle, no drop_last)
val_loader = create_validation_dataloader(
    radi_val,
    modalities=[radi.collection_names[0]],
    label_column="has_tumor",
    batch_size=BATCH_SIZE,
    patch_size=PATCH_SIZE,
    num_workers=0,
    pin_memory=False,
    transform=val_transform,
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Train transform: {train_transform}")
print(f"Val transform: {val_transform}")

In [None]:
# Inspect a batch
batch = next(iter(train_loader))

print(f"Batch keys: {list(batch.keys())}")
print(f"Image shape: {batch['image'].shape}")  # (B, C, D, H, W)
print(f"Image dtype: {batch['image'].dtype}")
print(f"Labels: {batch['label'].tolist()}")
print(f"Memory per batch: {batch['image'].nbytes / 1024 / 1024:.1f} MB")

## 6. Define Model

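A simple 3D CNN classifier: three Conv3d, BatchNorm3d, and ReLU blocks (with 2× max pooling after the first two), followed by global average pooling and a linear layer that outputs two class logits.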
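As a hand check on the parameter count printed by the model cell below, this sketch totals learnable parameters layer by layer (Conv3d weights and biases, BatchNorm3d scale and shift, and the final Linear layer) and traces the spatial size of a 64³ patch through the two 2× max-pooling steps. The running statistics of `BatchNorm3d` are buffers, so `model.parameters()` does not count them. The helper functions are illustrative, not part of PyTorch.

```python
def conv3d_params(c_in: int, c_out: int, k: int = 3) -> int:
    # k^3 weights for every (input, output) channel pair, plus one bias per output channel
    return c_in * c_out * k**3 + c_out


def bn_params(c: int) -> int:
    # Learnable scale (weight) and shift (bias) per channel
    return 2 * c


def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out


channels = [1, 16, 32, 64]  # in_channels=1, then conv1..conv3 output widths
total = sum(conv3d_params(a, b) + bn_params(b) for a, b in zip(channels, channels[1:]))
total += linear_params(64, 2)  # fc: 64 pooled features -> 2 logits

# padding=1 with kernel_size=3 keeps the size; each max_pool3d(x, 2) halves it
size = 64
for _ in range(2):
    size //= 2

print(total, size)  # 70018 16
```

After the second pooling step each of the 64 feature maps is 16³; `AdaptiveAvgPool3d(1)` averages each map to a single value, giving the 64-dimensional input to the linear layer.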

In [None]:
class Simple3DCNN(nn.Module):
    """3D CNN for binary classification."""

    def __init__(self, in_channels: int = 1, num_classes: int = 2):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(16)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(32)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm3d(64)
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool3d(x, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool3d(x, 2)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool(x).flatten(1)
        return self.fc(x)


model = Simple3DCNN(in_channels=1, num_classes=2).to(DEVICE)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## 7. Training Loop
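The loop below uses an unweighted `nn.CrossEntropyLoss`. If the label counts from Section 3 are imbalanced, one common remedy is inverse-frequency class weighting, where each class gets weight `total / (num_classes * count)`. The sketch computes such weights with NumPy; `inverse_frequency_weights` is a hypothetical helper and the label vector is made up.

```python
import numpy as np


def inverse_frequency_weights(labels, num_classes: int = 2) -> np.ndarray:
    """Hypothetical helper: weight each class by total / (num_classes * count)."""
    counts = np.bincount(np.asarray(labels, dtype=np.int64), minlength=num_classes)
    if np.any(counts == 0):
        # A missing class would cause a division by zero
        raise ValueError(f"every class needs at least one sample, got counts {counts}")
    return counts.sum() / (num_classes * counts)


# Made-up labels: 10 subjects without tumor (0), 30 with tumor (1)
weights = inverse_frequency_weights([0] * 10 + [1] * 30)
print(weights)  # rarer class 0 gets the larger weight
```

The result could be passed to PyTorch as `nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32, device=DEVICE))`, computing `weights` from `train_labels` rather than from the full dataset so the validation set does not influence training.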

In [None]:
# Training configuration
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

# Training history
history = {
    "train_loss": [],
    "train_acc": [],
    "val_loss": [],
    "val_acc": [],
}

In [None]:
print(f"Training on {DEVICE} for {NUM_EPOCHS} epochs...\n")

for epoch in range(NUM_EPOCHS):
    # Training phase
    model.train()
    train_loss, train_correct, train_total = 0.0, 0, 0

    for batch in train_loader:
        images = batch["image"].to(DEVICE)
        labels = batch["label"].long().to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_correct += (outputs.argmax(1) == labels).sum().item()
        train_total += labels.size(0)

    # Validation phase
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0

    with torch.no_grad():
        for batch in val_loader:
            images = batch["image"].to(DEVICE)
            labels = batch["label"].long().to(DEVICE)

            outputs = model(images)
            loss = criterion(outputs, labels)

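            # Weight by batch size so a smaller final validation batch is averaged correctly
            val_loss += loss.item() * labels.size(0)
            val_correct += (outputs.argmax(1) == labels).sum().item()
            val_total += labels.size(0)

    # Record metrics (train loss: mean over equal-size batches; val loss: mean per sample)
    history["train_loss"].append(train_loss / len(train_loader))
    history["train_acc"].append(100.0 * train_correct / train_total)
    history["val_loss"].append(val_loss / val_total)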
    history["val_acc"].append(100.0 * val_correct / val_total)

    print(
        f"Epoch {epoch + 1:2d}/{NUM_EPOCHS}: "
        f"Train Loss={history['train_loss'][-1]:.4f}, "
        f"Train Acc={history['train_acc'][-1]:.1f}%, "
        f"Val Loss={history['val_loss'][-1]:.4f}, "
        f"Val Acc={history['val_acc'][-1]:.1f}%"
    )

## 8. Visualize Results
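The training loop reports accuracy only, which can look good on imbalanced data even when one class is never predicted. A minimal sketch of confusion-matrix metrics, treating `has_tumor = 1` as the positive class; `binary_metrics` is a hypothetical helper that takes plain label lists.

```python
import numpy as np


def binary_metrics(y_true, y_pred):
    """Hypothetical helper: confusion-matrix metrics with label 1 (has_tumor) as positive."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    # NaN when a class is absent from y_true, since the rate is undefined
    sensitivity = tp / (tp + fn) if tp + fn else float("nan")
    specificity = tn / (tn + fp) if tn + fp else float("nan")
    return {
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "accuracy": (tp + tn) / max(y_true.size, 1),
        "sensitivity": sensitivity,
        "specificity": specificity,
        "balanced_accuracy": (sensitivity + specificity) / 2,
    }


# Degenerate classifier: always predicts "tumor" on 8 tumor / 2 non-tumor subjects
m = binary_metrics([1] * 8 + [0] * 2, [1] * 10)
print(m)
```

In the validation loop, `outputs.argmax(1).cpu().tolist()` and `labels.cpu().tolist()` could be accumulated across batches and passed in. The example shows why this matters: always predicting "tumor" on an 8:2 split gives 80% accuracy but 0% specificity and 50% balanced accuracy.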

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss
axes[0].plot(history["train_loss"], marker="o", label="Train")
axes[0].plot(history["val_loss"], marker="s", label="Validation")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training & Validation Loss")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(history["train_acc"], marker="o", label="Train")
axes[1].plot(history["val_acc"], marker="s", label="Validation")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("Accuracy (%)")
axes[1].set_title("Training & Validation Accuracy")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Final metrics
print("=" * 40)
print("Final Results")
print("=" * 40)
print(f"Best Train Accuracy: {max(history['train_acc']):.1f}%")
print(f"Best Val Accuracy: {max(history['val_acc']):.1f}%")
print(f"Final Train Loss: {history['train_loss'][-1]:.4f}")
print(f"Final Val Loss: {history['val_loss'][-1]:.4f}")

## 9. Cleanup

In [None]:
# Cleanup temporary directory
shutil.rmtree(TEMP_DIR)
print(f"Cleaned up: {TEMP_DIR}")

## Summary

This notebook demonstrated:

1. **Loading** RadiObject from configured URI (S3 or local)
2. **Exploring** data and label distributions
3. **Splitting** data into train/validation sets
4. **Training** a 3D CNN classifier with PyTorch
5. **Evaluating** model performance

### ML Framework Features

| Feature | Example |
|---------|---------|
| **Training dataloader** | `create_training_dataloader(radi, ...)` - shuffled, drop_last |
| **Validation dataloader** | `create_validation_dataloader(radi, ...)` - no shuffle, no drop_last |
| **Transform composition** | `Compose([IntensityNormalize(), RandomFlip3D()])` |
| **Patch extraction** | Efficient 64³ patches from ISOTROPIC-tiled volumes |
| **Label integration** | Automatic obs_meta lookup via `label_column` |

### RadiObject Benefits for ML Training
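The *ISOTROPIC tiles* row below can be made concrete with some arithmetic. Assuming reads happen at tile granularity and tiles are aligned to multiples of the tile size, a 64³ patch on 64³ tiles straddles at most 2 tiles per axis, so it touches at most 8 tiles. For comparison, hypothetical slice-shaped 512 × 512 × 1 tiles would require 64 full-slice reads for the same patch. `tiles_touched` is a hypothetical helper, not a RadiObject function.

```python
def tiles_touched(start, patch=(64, 64, 64), tile=(64, 64, 64)) -> int:
    """Count tiles a patch overlaps, assuming tiles aligned at voxel 0 on every axis."""
    count = 1
    for s, p, t in zip(start, patch, tile):
        first_tile = s // t
        last_tile = (s + p - 1) // t
        count *= last_tile - first_tile + 1
    return count


aligned = tiles_touched((0, 0, 0))       # patch coincides with exactly one tile
one_axis = tiles_touched((10, 0, 0))     # straddles a tile boundary along one axis
worst = tiles_touched((10, 20, 30))      # straddles boundaries along all three axes
slices = tiles_touched((10, 20, 30), tile=(512, 512, 1))  # hypothetical slice tiling
print(aligned, one_axis, worst, slices)  # 1 2 8 64
```

In voxels, that is at most 8 × 64³ ≈ 2.1 million voxels read per patch, versus 64 × 512² ≈ 16.8 million for the slice layout.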

| Feature | Benefit |
|---------|---------|
| **ISOTROPIC tiles** | 64³ chunks optimized for random patch access |
| **TileDB storage** | Tile-granular reads: only tiles overlapping the requested region are loaded |
| **S3 support** | Direct cloud storage access |
| **Compression** | ZSTD compression; the storage reduction (roughly 3-10x) depends on the data |

### Next Steps

- Increase `NUM_EPOCHS` for better convergence
- Try more augmentations: `RandomNoise`, `WindowLevel` for CT
- Use `CacheStrategy.IN_MEMORY` for small datasets
- Experiment with different architectures (ResNet3D, UNet, etc.)