# Tumor Growth Prediction - Demo Notebook

This notebook demonstrates how to use the tumor growth prediction framework for training and inference on longitudinal medical imaging data.

## 1. Setup and Imports

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from src.models import Conv3DLSTM, Recurrent3DCNN, Baseline3DCNN
from src.data import TumorGrowthDataset, LongitudinalDataLoader
from src.training import Trainer, CombinedLoss, compute_metrics
from src.utils import Config, visualize_prediction

# Report the runtime environment up front so that version or GPU problems
# are visible before any data loading or training starts.
torch_version = torch.__version__
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch_version}")
print(f"CUDA available: {cuda_ready}")

## 2. Load Configuration

In [None]:
# Read the experiment settings from the project's default YAML file.
config = Config.from_file('configs/default_config.yaml')

# Alternative: build the configuration in code instead of reading a file.
# config = Config()

# Dump every setting so the notebook output records what was used.
print("Model Configuration:")
settings = config.to_dict()
for name, setting in settings.items():
    print(f"  {name}: {setting}")

## 3. Prepare Data

Create data loaders for training, validation, and testing.

In [None]:
# Collect the loader settings from the configuration in one place, then
# build the train / validation / test loaders in a single call.
loader_options = dict(
    data_dir=config.data_dir,
    batch_size=config.batch_size,
    num_time_steps=config.num_time_steps,
    train_split=config.train_split,
    val_split=config.val_split,
    num_workers=config.num_workers,
)
train_loader, val_loader, test_loader = LongitudinalDataLoader.create_dataloaders(
    **loader_options
)

# Report how many batches each split yields.
num_train_batches = len(train_loader)
print(f"Training batches: {num_train_batches}")
num_val_batches = len(val_loader)
print(f"Validation batches: {num_val_batches}")
num_test_batches = len(test_loader)
print(f"Test batches: {num_test_batches}")

## 4. Visualize Sample Data

In [None]:
# Pull one batch from the training loader to inspect it.
inputs, targets = next(iter(train_loader))

print(f"Input shape: {inputs.shape}")  # (batch, time_steps, channels, D, H, W)
print(f"Target shape: {targets.shape}")  # (batch, channels, D, H, W)

# Work with the first sample of the batch.
sample_input = inputs[0]  # (time_steps, channels, D, H, W)
sample_target = targets[0]  # (channels, D, H, W)

# One column per time step: axial view on the top row, coronal on the bottom.
# squeeze=False keeps `axes` two-dimensional even when num_time_steps == 1,
# so the axes[row, col] indexing below always works.
fig, axes = plt.subplots(2, config.num_time_steps, figsize=(16, 8), squeeze=False)

# Each view is cut at the midpoint of the axis it slices through: the axial
# view fixes the D axis (dim 2), the coronal view fixes the H axis (dim 3).
# Reusing the D midpoint for H would be off-centre, or out of range, for
# non-cubic volumes.
slice_idx = sample_input.shape[2] // 2
coronal_idx = sample_input.shape[3] // 2

for t in range(config.num_time_steps):
    # Axial view
    axes[0, t].imshow(sample_input[t, 0, slice_idx, :, :], cmap='gray')
    axes[0, t].set_title(f'Time {t} - Axial')
    axes[0, t].axis('off')

    # Coronal view
    axes[1, t].imshow(sample_input[t, 0, :, coronal_idx, :], cmap='gray')
    axes[1, t].set_title(f'Time {t} - Coronal')
    axes[1, t].axis('off')

plt.tight_layout()
plt.show()

## 5. Create Model

Choose and instantiate a model architecture.

In [None]:
# Instantiate the Conv3D-LSTM architecture with hyper-parameters taken
# from the loaded configuration.
conv_lstm_options = dict(
    in_channels=config.in_channels,
    base_features=config.base_features,
    hidden_size=config.hidden_size,
    num_lstm_layers=config.num_lstm_layers,
    num_time_steps=config.num_time_steps,
    output_channels=config.output_channels,
)
model = Conv3DLSTM(**conv_lstm_options)

# Alternative architecture: Recurrent 3D CNN
# model = Recurrent3DCNN(
#     in_channels=config.in_channels,
#     hidden_channels=[32, 64, 128],
#     num_layers=3,
#     output_channels=config.output_channels
# )

# Summarise the model: class name, total and trainable parameter counts.
num_params = sum(weight.numel() for weight in model.parameters())
print(f"Model: {model.__class__.__name__}")
print(f"Total parameters: {num_params:,}")
trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
print(f"Trainable parameters: {trainable_params:,}")

## 6. Test Forward Pass

In [None]:
# Smoke-test the model: push two samples from the inspected batch through
# it in evaluation mode, without tracking gradients.
model.eval()
sample_batch = inputs[:2]
with torch.no_grad():
    output = model(sample_batch)

print(f"Input shape: {sample_batch.shape}")
print(f"Output shape: {output.shape}")
lowest, highest = output.min(), output.max()
print(f"Output range: [{lowest:.3f}, {highest:.3f}]")

## 7. Setup Training

Configure loss function, optimizer, and trainer.

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Combined loss; the relative weight of each term comes from the config.
loss_weights = dict(
    mse_weight=config.mse_weight,
    dice_weight=config.dice_weight,
    smooth_weight=config.smooth_weight,
)
criterion = CombinedLoss(**loss_weights)

# Hand the model, data loaders, loss and run settings to the trainer.
trainer_options = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    device=device,
    learning_rate=config.learning_rate,
    num_epochs=config.num_epochs,
    checkpoint_dir=config.checkpoint_dir,
    log_dir=config.log_dir,
)
trainer = Trainer(**trainer_options)

print("Trainer configured successfully!")

## 8. Train Model

Start training (note: this may take a while).

In [None]:
# Run the full training loop.
# Tip: for a quick demonstration, lower the epoch count first, e.g.
# trainer.num_epochs = 5  # Quick demo

trainer.train()

# Report the best validation loss recorded by the trainer.
best_loss = trainer.best_val_loss
print(f"\nBest validation loss: {best_loss:.6f}")

## 9. Evaluate on Test Set

In [None]:
# Load best model.
# The checkpoint is looked up in the same directory that was handed to the
# Trainer (config.checkpoint_dir) rather than a hard-coded 'checkpoints/'
# folder. NOTE(review): assumes the Trainer writes 'best_model.pth' into
# that directory -- confirm against the Trainer implementation.
# map_location=device lets a checkpoint saved on a GPU be restored on a
# CPU-only machine.
checkpoint_path = Path(config.checkpoint_dir) / 'best_model.pth'
best_checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(best_checkpoint['model_state_dict'])
model = model.to(device)
model.eval()

# Evaluate on test set, collecting one metrics dict per sample.
all_metrics = []

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs = model(inputs)

        # Compute metrics for each sample in batch; the i:i+1 slices keep
        # the leading batch dimension that the model output carries.
        for i in range(outputs.shape[0]):
            metrics = compute_metrics(outputs[i:i+1], targets[i:i+1])
            all_metrics.append(metrics)

# Print average metrics
print("\nTest Set Metrics:")
print("=" * 50)
if not all_metrics:
    # An empty test loader would otherwise raise IndexError on all_metrics[0].
    print("No test samples were evaluated.")
else:
    for key in all_metrics[0]:
        # Infinite values are left out so they do not swamp the mean and std.
        values = [m[key] for m in all_metrics if m[key] != float('inf')]
        if values:
            mean_val = np.mean(values)
            std_val = np.std(values)
            print(f"{key}: {mean_val:.4f} ± {std_val:.4f}")

## 10. Visualize Predictions

In [None]:
# Get a test batch and move it to the device the model runs on.
test_inputs, test_targets = next(iter(test_loader))
test_inputs = test_inputs.to(device)
test_targets = test_targets.to(device)

# Make prediction (no gradients needed for inference).
with torch.no_grad():
    predictions = model(test_inputs)

# Make sure the output folder exists before saving into it.
# NOTE(review): visualize_prediction may already create it; mkdir with
# exist_ok=True is harmless either way.
save_path = 'outputs/sample_prediction.png'
Path(save_path).parent.mkdir(parents=True, exist_ok=True)

# Visualize first sample of the batch: input sequence, prediction, target.
visualize_prediction(
    test_inputs[0],
    predictions[0],
    test_targets[0],
    save_path=save_path
)

print(f"Visualization saved to {save_path}")

## 11. Future Predictions

Generate predictions for multiple future time steps.

In [None]:
# Forecast several future time steps for the first test sample, using
# whichever multi-step method the instantiated model exposes.

# For Conv3DLSTM, use predict_sequence
if hasattr(model, 'predict_sequence'):
    with torch.no_grad():
        future_predictions = model.predict_sequence(test_inputs[:1])

    print(f"Future predictions shape: {future_predictions.shape}")

    # Visualize predictions at different future time steps, one panel each.
    # The indexing below treats dim 1 as the future-step axis and dim 3 as
    # the axis the slice is taken through.
    # squeeze=False keeps `axes` a 2-D array even when only one future step
    # is returned; without it a single Axes object comes back and indexing
    # it fails.
    num_future_steps = future_predictions.shape[1]
    fig, axes = plt.subplots(1, num_future_steps, figsize=(20, 4), squeeze=False)
    slice_idx = future_predictions.shape[3] // 2

    for t in range(num_future_steps):
        ax = axes[0, t]
        # .cpu() because the predictions may live on the GPU.
        ax.imshow(future_predictions[0, t, 0, slice_idx, :, :].cpu(), cmap='hot')
        ax.set_title(f'Future T+{t+1}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# For Recurrent3DCNN, use predict_future
elif hasattr(model, 'predict_future'):
    with torch.no_grad():
        future_predictions = model.predict_future(test_inputs[:1], num_future_steps=3)

    print(f"Future predictions shape: {future_predictions.shape}")

## 12. Summary

This notebook demonstrated:
1. Loading and preparing longitudinal medical imaging data
2. Creating and configuring deep learning models for tumor growth prediction
3. Training models with appropriate loss functions
4. Evaluating model performance with multiple metrics
5. Visualizing predictions and future forecasts

For more information, see the [README](../README.md) file.