# Axon IA: Model Training

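This notebook demonstrates how to train a segmentation model with Axon IA. We'll cover:

1. Loading the dataset
2. Creating a model
3. Setting up training parameters
4. Training the model
5. Analyzing the training history
6. Evaluating the model on a sample
7. Computing evaluation metrics
8. Generating an evaluation report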

In [None]:
import os
import random
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

# Make the project root (the parent of this notebook's directory) importable so
# that the `axon_ia` package resolves without being installed.
module_path = os.path.abspath("..")
if module_path not in sys.path:
    sys.path.append(module_path)

# Seed every RNG the notebook can touch (Python, NumPy, PyTorch) so that data
# shuffling, weight initialisation and - presumably - random augmentations are
# reproducible under "Restart Kernel -> Run All".
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## 1. Load the Dataset

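First, let's load the dataset prepared in the previous notebook (`01_data_preparation.ipynb`). To keep the example small, the training split is also used for validation, so the validation metrics below are optimistic and do not measure generalization: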

In [None]:
from axon_ia.data.dataset import AxonDataset
from axon_ia.data.transforms import get_train_transform, get_val_transform

# Location of the data produced by 01_data_preparation.ipynb.
data_dir = Path("../data_example")
# Input modalities loaded for each case, in this order.
modalities = ["flair", "t1", "t2", "dwi"]

# Fail loudly: every later cell needs `train_dataset` / `val_dataset`, so only
# printing a message would defer the failure to a confusing NameError further
# down the notebook.
if not data_dir.exists():
    raise FileNotFoundError(
        f"Data directory {data_dir.resolve()} not found. "
        "Please run the '01_data_preparation.ipynb' notebook first."
    )

train_dataset = AxonDataset(
    root_dir=data_dir,
    split="train",
    modalities=modalities,
    target="mask",
    transform=get_train_transform(),
)

# NOTE: for this example the *training* split doubles as validation data (only
# the transform differs), so validation metrics are optimistic.
val_dataset = AxonDataset(
    root_dir=data_dir,
    split="train",
    modalities=modalities,
    target="mask",
    transform=get_val_transform(),
)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

In [None]:
# Create data loaders.
# Pinned (page-locked) host memory only speeds up host-to-GPU transfers; on a
# CPU-only machine it cannot be used and PyTorch emits a warning, so enable it
# only when CUDA is available.
use_pinned_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=1,  # Small batch size for this example
    shuffle=True,
    num_workers=0,  # Load in the main process for easier debugging
    pin_memory=use_pinned_memory,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,  # Fixed order so validation runs are comparable
    num_workers=0,
    pin_memory=use_pinned_memory,
)

## 2. Create the Model

Now, let's create a model for training:

In [None]:
from axon_ia.models import create_model

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Infer the network's input geometry from one training example.
# NOTE(review): assumes image tensors are (C, D, H, W), with the stacked
# modalities on the channel axis - confirm against AxonDataset.
sample = train_dataset[0]
image_shape = sample["image"].shape
in_channels = image_shape[0]
input_shape = image_shape[1:]

print(f"Input shape: {input_shape}, Channels: {in_channels}")

# A deliberately small UNETR with one output channel (binary segmentation) and
# deep supervision enabled; the Trainer below is configured to match.
model_config = dict(
    architecture="unetr",
    in_channels=in_channels,
    out_channels=1,
    img_size=input_shape,
    feature_size=8,
    hidden_size=128,
    use_deep_supervision=True,
)
model = create_model(**model_config).to(device)

## 3. Set Up Training Parameters

Let's set up the loss function, optimizer, and learning rate scheduler:

In [None]:
from axon_ia.losses import create_loss_function
from axon_ia.training import create_scheduler

# Hyperparameters for this short demonstration run.
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
NUM_EPOCHS = 10  # keep in sync with `epochs` passed to trainer.train()
WARMUP_EPOCHS = 2

# Combined Dice + cross-entropy objective with both terms weighted equally.
# NOTE(review): the model is built with out_channels=1 and inference applies a
# sigmoid; confirm that the "dice_ce" loss treats a single-channel output as
# binary (sigmoid/BCE) and what include_background=False means with 1 channel.
loss_fn = create_loss_function(
    loss_type="dice_ce",
    include_background=False,
    ce_weight=1.0,
    dice_weight=1.0,
)

optimizer = torch.optim.AdamW(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

# Cosine schedule preceded by a warm-up phase.
scheduler = create_scheduler(
    optimizer,
    scheduler_type="cosine_warmup",
    num_epochs=NUM_EPOCHS,
    warmup_epochs=WARMUP_EPOCHS,
)

## 4. Train the Model

Now, let's train the model:

In [None]:
from axon_ia.training import Trainer
from axon_ia.training.callbacks import EarlyStopping, ModelCheckpoint

# All artefacts of this run live under one output directory.
output_dir = Path("../outputs/example_training")
checkpoint_dir = output_dir / "checkpoints"
# Create the checkpoint sub-directory up front (this also creates output_dir):
# ModelCheckpoint is handed a file path inside it, and the notebook should not
# depend on the callback creating missing parent directories.
checkpoint_dir.mkdir(parents=True, exist_ok=True)

callbacks = [
    # Stop when validation Dice has not improved for 5 validation rounds.
    EarlyStopping(
        monitor="val_dice",
        patience=5,
        mode="max",
    ),
    # Keep only the checkpoint with the best validation Dice.
    ModelCheckpoint(
        filepath=checkpoint_dir / "model_{epoch:02d}.pth",
        monitor="val_dice",
        save_best_only=True,
        mode="max",
    ),
]

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device=device,
    lr_scheduler=scheduler,
    callbacks=callbacks,
    grad_clip=1.0,
    use_amp=False,  # Disable mixed precision for this example
    deep_supervision=True,  # Matches use_deep_supervision=True on the model
)

In [None]:
# Run the deliberately short training loop, validating and logging every
# epoch. `epochs` should match the num_epochs the LR scheduler was built with.
train_config = {
    "epochs": 10,
    "val_interval": 1,
    "log_interval": 1,
    "checkpoint_dir": output_dir / "checkpoints",
    "metrics": ["dice", "iou"],
}
history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    **train_config,
)

## 5. Analyze Training Results

Let's visualize the training history:

In [None]:
# Plot the per-epoch training loss next to the validation loss and metrics.
# NOTE(review): assumes history["train"] / history["val"] are lists of
# per-epoch dicts, as the indexing below requires.
train_records = history["train"]
val_records = history["val"]

train_losses = [record["loss"] for record in train_records]
val_losses = [record["val_loss"] for record in val_records]
# A missing metric becomes NaN (drawn as a gap) rather than a misleading 0.
val_dice = [record.get("val_dice", float("nan")) for record in val_records]
val_iou = [record.get("val_iou", float("nan")) for record in val_records]

# 1-based epoch numbers; the validation axis assumes validation ran every
# epoch (val_interval=1 in the training call).
train_epochs = range(1, len(train_losses) + 1)
val_epochs = range(1, len(val_losses) + 1)

fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(12, 5))

ax_train.plot(train_epochs, train_losses, "b-", label="Training Loss")
ax_train.set(xlabel="Epoch", ylabel="Loss", title="Training Loss")
ax_train.legend()
ax_train.grid(True)

ax_val.plot(val_epochs, val_losses, "r-", label="Validation Loss")
ax_val.plot(val_epochs, val_dice, "g-", label="Validation Dice")
ax_val.plot(val_epochs, val_iou, "b-", label="Validation IoU")
ax_val.set(xlabel="Epoch", ylabel="Value", title="Validation Metrics")
ax_val.legend()
ax_val.grid(True)

fig.tight_layout()
plt.show()

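## 6. Evaluate Model on Sample Data

Run the trained model on one validation sample and compare its thresholded prediction with the ground truth in three orthogonal views: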

In [None]:
from axon_ia.utils.visualization import plot_multiple_slices

# Take one sample from the validation dataset (built from the training split
# in this example, so this is a sanity check rather than a held-out test).
sample = val_dataset[0]

# Add a leading batch dimension and move the tensors to the model's device.
image = sample["image"].unsqueeze(0).to(device)
mask = sample["mask"].unsqueeze(0).to(device)

# NOTE(review): `model` holds the weights of the last trained epoch, not
# necessarily the best checkpoint written by ModelCheckpoint.
model.eval()

with torch.no_grad():
    pred = model(image)

    # With deep supervision the model may return several outputs; keep the
    # first (presumably the main, full-resolution head - confirm against the
    # model). Accept a list as well as a tuple.
    if isinstance(pred, (tuple, list)):
        pred = pred[0]

    # Single-channel logits -> probabilities -> binary mask.
    pred_probs = torch.sigmoid(pred)
    pred_binary = (pred_probs > 0.5).float()

# Move to CPU and drop the batch dimension for visualisation.
image_np = image.cpu().numpy()[0]
mask_np = mask.cpu().numpy()[0]
pred_np = pred_binary.cpu().numpy()[0]

# Show the first modality (FLAIR) with ground truth and prediction overlaid,
# once per view.
# NOTE(review): the axis -> view mapping assumes the sagittal axis is stored
# first and the axial axis last; confirm against the data orientation.
for axis, view_name in [(2, "Axial"), (1, "Coronal"), (0, "Sagittal")]:
    fig = plot_multiple_slices(
        image=image_np[0],
        mask=mask_np[0],
        prediction=pred_np[0],
        axis=axis,
        num_slices=3,
        figsize=(15, 5),
        title=f"Prediction vs Ground Truth ({view_name} View)",
    )
    plt.show()

## 7. Calculate Evaluation Metrics

Quantify how well the prediction for this sample agrees with the ground truth:

In [None]:
from axon_ia.evaluation.metrics import compute_metrics

# Score the binarised prediction against the ground-truth mask; both arrays
# still carry their batch dimension here.
requested_metrics = ["dice", "iou", "precision", "recall", "hausdorff"]
metrics = compute_metrics(
    y_pred=pred_binary.cpu().numpy(),
    y_true=mask.cpu().numpy(),
    metrics=requested_metrics,
)

# One line per metric, formatted to four decimals.
print("Evaluation Metrics:")
for name, score in metrics.items():
    print(f"{name}: {score:.4f}")

## 8. Generate Evaluation Report

Finally, let's generate an evaluation report. In this example it covers only the single sample evaluated above:

In [None]:
from axon_ia.evaluation.report_generator import generate_evaluation_report

# The report takes per-patient dictionaries keyed by sample id; here they hold
# only the one sample evaluated above.
sample_id = sample["sample_id"]
patient_metrics = {sample_id: metrics}
patient_images = {sample_id: image_np[0]}  # first modality (FLAIR)
patient_targets = {sample_id: mask_np[0]}
patient_predictions = {sample_id: pred_np[0]}

report_path = generate_evaluation_report(
    patient_metrics=patient_metrics,
    output_dir=output_dir / "evaluation",
    patient_images=patient_images,
    patient_targets=patient_targets,
    patient_predictions=patient_predictions,
    title="Example Evaluation Report",
    model_name="UNETR",
)

print(f"Evaluation report generated at {report_path}")