# 🎯 Colab Training Template

Optimized training notebook for Google Colab.

**Before running:**
1. Complete `00_Colab_Setup.ipynb` first
2. Enable GPU runtime
3. Ensure dataset is in Drive
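
Once Drive is mounted, the third prerequisite can be checked programmatically. The following is a minimal sketch, not part of the original setup notebook; `missing_paths` is a hypothetical helper, and the two paths simply mirror the `data_root` and `splits_file` entries used in the configuration cell of this notebook.

```python
from pathlib import Path

def missing_paths(paths):
    """Return the subset of `paths` that do not exist on disk."""
    return [str(p) for p in map(Path, paths) if not p.exists()]

# Assumed layout: same project root as the Quick Setup cell in this notebook.
project_root = Path('/content/drive/MyDrive/endometriosis-uncertainty-seg')
required = [
    project_root / 'data' / 'raw' / 'UT-EndoMRI',
    project_root / 'data' / 'splits' / 'split_info.json',
]

for path in missing_paths(required):
    print(f"Missing: {path}")
```

If anything is printed, fix the Drive layout before starting a long training run.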

## 🚀 Quick Setup

In [None]:
# Mount Drive and setup paths
from google.colab import drive
import os
import sys

drive.mount('/content/drive')

# Set paths
PROJECT_ROOT = '/content/drive/MyDrive/endometriosis-uncertainty-seg'
CODE_DIR = '/content/endo-uncertainty-seg'

# Create symlinks
!mkdir -p {CODE_DIR}
%cd {CODE_DIR}

# -n replaces an existing link on re-runs instead of nesting a new link inside it
!ln -sfn {PROJECT_ROOT}/data ./data
!ln -sfn {PROJECT_ROOT}/experiments ./experiments

sys.path.append(CODE_DIR)

print("✓ Setup complete")

In [None]:
# Verify GPU
import torch

print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ GPU not available! Go to Runtime → Change runtime type → GPU")

## 📦 Install Dependencies

In [None]:
# Install missing packages (Colab has most pre-installed)
!pip install -q monai nibabel SimpleITK
!pip install -q pytorch-lightning
!pip install -q wandb  # Optional: for experiment tracking

print("✓ Dependencies installed")
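
To see which packages actually need installing, the import names can be probed before calling `pip`. This is a small sketch using only the standard library; `missing_modules` is a hypothetical helper, and the import names listed are assumptions about how each pip package is imported (for example `pytorch-lightning` is imported as `pytorch_lightning`).

```python
from importlib.util import find_spec

def missing_modules(module_names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in module_names if find_spec(name) is None]

# Import names, which can differ from the pip package names.
wanted = ['monai', 'nibabel', 'SimpleITK', 'pytorch_lightning', 'wandb']
print("Not installed:", missing_modules(wanted))
```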

## 🔧 Configuration

In [None]:
# Training configuration optimized for Colab T4
config = {
    # Paths
    'data_root': f'{PROJECT_ROOT}/data/raw/UT-EndoMRI',
    'splits_file': f'{PROJECT_ROOT}/data/splits/split_info.json',
    'checkpoint_dir': f'{PROJECT_ROOT}/experiments/checkpoints',
    'log_dir': f'{PROJECT_ROOT}/experiments/logs',
    
    # Data
    'dataset_name': 'D2_TCPW',
    'sequences': ['T2FS'],
    'structures': ['uterus', 'ovary', 'endometrioma'],
    
    # Preprocessing (smaller size for T4 GPU)
    'target_spacing': (5.0, 5.0, 5.0),
    'target_size': (96, 96, 24),  # Smaller for Colab
    
    # Training (optimized for T4)
    'batch_size': 2,  # Max for T4 with this size
    'num_workers': 2,
    'epochs': 100,
    'learning_rate': 1e-4,
    'weight_decay': 1e-5,
    
    # Save every N epochs (important for Colab!)
    'save_frequency': 10,
    
    # Model
    'model_name': 'unet',  # Start with simple U-Net
    'num_classes': 4,  # Background + 3 structures
    
    # Optimization
    'mixed_precision': True,  # Faster training
    'gradient_clip': 1.0,
}

print("Configuration:")
for k, v in config.items():
    print(f"  {k}: {v}")

## 📊 Load Data

In [None]:
# This assumes you have the data utils copied to the project
# If not, you'll need to define the necessary classes here or upload the files

import json

# Load splits
with open(config['splits_file'], 'r') as f:
    splits = json.load(f)

print(f"Data splits:")
print(f"  Train: {len(splits['train'])} subjects")
print(f"  Val: {len(splits['val'])} subjects")
print(f"  Test: {len(splits['test'])} subjects")
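
A subject that appears in more than one split leaks information into validation or test. The sketch below checks for that, assuming each entry of `split_info.json` is a list of subject IDs (the file's actual schema is not shown here); `find_split_overlaps` and the example IDs are hypothetical.

```python
from itertools import combinations

def find_split_overlaps(splits):
    """Return {(split_a, split_b): sorted shared IDs} for every overlapping pair."""
    overlaps = {}
    for a, b in combinations(sorted(splits), 2):
        shared = set(splits[a]) & set(splits[b])
        if shared:
            overlaps[(a, b)] = sorted(shared)
    return overlaps

# Made-up IDs for illustration only: 'S03' is deliberately in both train and test.
example = {'train': ['S01', 'S02', 'S03'], 'val': ['S04'], 'test': ['S03', 'S05']}
print(find_split_overlaps(example))  # {('test', 'train'): ['S03']}
```

Calling `find_split_overlaps(splits)` on the loaded splits should return an empty dict.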

## 🏗️ Create DataLoaders

**Note:** This cell assumes you have your source files uploaded. If not, you'll need to copy the relevant code from your src/ directory.

In [None]:
# Import your custom modules
# Make sure src/ folder is uploaded or linked

try:
    from src.data.dataloader import get_dataloaders
    
    dataloaders = get_dataloaders(
        data_root=config['data_root'],
        splits=splits,
        config=config,
        dataset_name=config['dataset_name'],
        num_workers=config['num_workers']
    )
    
    print("✓ DataLoaders created")
    print(f"  Train batches: {len(dataloaders['train'])}")
    print(f"  Val batches: {len(dataloaders['val'])}")
    
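except ImportError as e:
    # Show the underlying error: it may be a missing dependency rather than a missing src/ folder
    print(f"⚠️ Could not import data modules: {e}")
    print("Please upload your src/ files or define the classes here.")
    print("Later cells need `dataloaders` to be defined.")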

## 🤖 Create Model

Starting with a simple 3D U-Net for baseline.

In [None]:
from monai.networks.nets import UNet
import torch.nn as nn

# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNet(
    spatial_dims=3,
    in_channels=1,  # Single MRI sequence
    out_channels=config['num_classes'],
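    channels=(16, 32, 64, 128, 256),
    # Four stride-2 levels downsample by 16, so each spatial dimension of the input
    # should be a multiple of 16. The depth of 24 in config['target_size'] is not;
    # pad or resize it (e.g. to 32) unless the dataloader already does so.
    strides=(2, 2, 2, 2),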
    num_res_units=2,
)

model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✓ Model created")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Device: {device}")
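
The input size has to be compatible with the four stride-2 levels of this U-Net, because each skip connection concatenates feature maps that must have matching shapes. MONAI's `UNet` documentation recommends spatial dimensions that are multiples of 2^N for N stride-2 levels. The sketch below traces one spatial dimension through the encoder using standard convolution arithmetic; the kernel size of 3 with padding 1 is an assumption about the defaults, and both helper names are hypothetical.

```python
def trace_sizes(size, num_down, kernel=3, stride=2, padding=1):
    """Sizes of one spatial dimension after each strided convolution."""
    sizes = [size]
    for _ in range(num_down):
        size = (size + 2 * padding - kernel) // stride + 1
        sizes.append(size)
    return sizes

def fits_unet(shape, num_down, stride=2):
    """True if every dimension is divisible by the total downsampling factor."""
    factor = stride ** num_down
    return all(dim % factor == 0 for dim in shape)

print(trace_sizes(96, 4))            # [96, 48, 24, 12, 6]
print(trace_sizes(24, 4))            # [24, 12, 6, 3, 2]
print(fits_unet((96, 96, 24), 4))    # False
print(fits_unet((96, 96, 32), 4))    # True
```

In the depth-24 trace, the size drops from 3 to 2 at the last level. Upsampling 2 by a factor of two gives 4, which no longer matches the skip connection of size 3, so the configured `target_size` of `(96, 96, 24)` is worth double-checking against the dataloader's padding.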

## 📈 Training Setup

In [None]:
from monai.losses import DiceLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler

# Loss function
loss_fn = DiceLoss(
    include_background=False,
    to_onehot_y=True,
    softmax=True
)

# Optimizer
optimizer = AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay']
)

# Scheduler
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=config['epochs'],
    eta_min=1e-7
)

# Mixed precision
scaler = GradScaler() if config['mixed_precision'] else None

print("✓ Training setup complete")
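
For intuition about the scheduler, the closed form given in the PyTorch documentation for `CosineAnnealingLR` can be evaluated directly. This is a plain-Python sketch of that formula, not the scheduler itself; the defaults mirror `learning_rate`, `eta_min`, and `epochs` from this notebook, and `cosine_annealing_lr` is a hypothetical helper.

```python
import math

def cosine_annealing_lr(epoch, base_lr=1e-4, eta_min=1e-7, t_max=100):
    """Learning rate after `epoch` scheduler steps under cosine annealing."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

for epoch in (0, 25, 50, 75, 100):
    print(f"epoch {epoch:3d}: lr = {cosine_annealing_lr(epoch):.2e}")
```

The rate starts at `base_lr`, passes through the midpoint between `base_lr` and `eta_min` halfway through training, and reaches `eta_min` at `t_max`.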

## 💾 Checkpoint Management

In [None]:
import os
from pathlib import Path

# Create checkpoint directory
checkpoint_dir = Path(config['checkpoint_dir'])
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Check for existing checkpoints to resume
latest_checkpoint = checkpoint_dir / 'latest.pth'
start_epoch = 0
best_dice = 0.0

if latest_checkpoint.exists():
    print(f"Found checkpoint: {latest_checkpoint}")
    response = input("Resume from checkpoint? (y/n): ")
    
    if response.lower() == 'y':
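        # map_location lets a checkpoint saved on GPU load on whatever device is available now
        checkpoint = torch.load(latest_checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Restore the LR schedule too; otherwise cosine annealing restarts from the initial LR
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_dice = checkpoint.get('best_dice', 0.0)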
        
        print(f"✓ Resumed from epoch {start_epoch}")
        print(f"  Best Dice: {best_dice:.4f}")
else:
    print("No checkpoint found. Starting from scratch.")

print(f"Starting training from epoch {start_epoch}")

## 🏋️ Training Loop

**Important for Colab:**
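- Saves a checkpoint every `save_frequency` epochs (10 by default) and whenever validation Dice improves
- Can be interrupted and resumed from the most recently saved checkpoint
- Checkpoints are written to Google Drive, so they survive a runtime reset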
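
The save rule used by the main training loop in this notebook can be written as a small pure function, which makes it easy to see which files are written after a given epoch. This is a sketch that mirrors that loop's conditions; `checkpoint_files` is a hypothetical helper, not something the loop calls.

```python
def checkpoint_files(epoch, val_dice, best_dice, save_frequency=10):
    """Files written after 0-indexed `epoch` under the notebook's save rule."""
    periodic = (epoch + 1) % save_frequency == 0
    improved = val_dice > best_dice
    files = []
    if periodic or improved:
        files.append('latest.pth')
    if improved:
        files.append('best.pth')
    if periodic:
        files.append(f'epoch_{epoch + 1}.pth')
    return files

print(checkpoint_files(epoch=3, val_dice=0.50, best_dice=0.60))   # []
print(checkpoint_files(epoch=3, val_dice=0.70, best_dice=0.60))   # ['latest.pth', 'best.pth']
print(checkpoint_files(epoch=9, val_dice=0.50, best_dice=0.60))   # ['latest.pth', 'epoch_10.pth']
```

Note that `latest.pth` is not refreshed on epochs that are neither periodic nor a new best, so a resume can lose up to `save_frequency - 1` epochs of work.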

In [None]:
from tqdm.notebook import tqdm
import time

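def train_epoch(model, dataloader, optimizer, loss_fn, device, scaler=None,
                grad_clip=config.get('gradient_clip')):
    """Train for one epoch and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    
    for batch in tqdm(dataloader, desc='Training'):
        images = batch['image'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        
        if scaler:
            # Mixed-precision path: forward pass under autocast, scaled backward pass
            with autocast():
                outputs = model(images)
                loss = loss_fn(outputs, labels)
            
            scaler.scale(loss).backward()
            if grad_clip:
                # Unscale first so the threshold applies to the true gradient norm
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            if grad_clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)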

def validate(model, dataloader, loss_fn, device):
    """Validate model"""
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validation'):
            images = batch['image'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            
            total_loss += loss.item()
    
    return total_loss / len(dataloader)

print("Training functions defined")

In [None]:
# Main training loop
from datetime import datetime

print(f"Starting training at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Training for {config['epochs'] - start_epoch} epochs")
print("="*60)

try:
    for epoch in range(start_epoch, config['epochs']):
        print(f"\nEpoch {epoch+1}/{config['epochs']}")
        
        # Train
        train_loss = train_epoch(
            model, dataloaders['train'], optimizer, 
            loss_fn, device, scaler
        )
        
        # Validate
        val_loss = validate(
            model, dataloaders['val'], loss_fn, device
        )
        
        # Scheduler step
        scheduler.step()
        
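        # Approximate Dice as 1 - soft Dice loss (foreground classes only).
        # This is a proxy computed on softmax probabilities, not Dice on hard predictions.
        val_dice = 1 - val_loss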
        
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f}")
        print(f"LR: {optimizer.param_groups[0]['lr']:.6f}")
        
        # Save checkpoint
        if (epoch + 1) % config['save_frequency'] == 0 or val_dice > best_dice:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'val_dice': val_dice,
                'best_dice': max(best_dice, val_dice)
            }
            
            # Save latest
            torch.save(checkpoint, checkpoint_dir / 'latest.pth')
            
            # Save best
            if val_dice > best_dice:
                best_dice = val_dice
                torch.save(checkpoint, checkpoint_dir / 'best.pth')
                print(f"✓ New best model! Dice: {best_dice:.4f}")
            
            # Save periodic
            if (epoch + 1) % config['save_frequency'] == 0:
                torch.save(checkpoint, checkpoint_dir / f'epoch_{epoch+1}.pth')
                print(f"✓ Checkpoint saved (epoch {epoch+1})")

except KeyboardInterrupt:
    print("\n⚠️ Training interrupted!")
    print("No checkpoint is written on interrupt. Resume from the most recent one on Drive;")
    print("epochs since that save will be repeated.")

except Exception as e:
    print(f"\n❌ Error during training: {e}")
    import traceback
    traceback.print_exc()

finally:
    print(f"\nTraining session ended at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Best Dice score: {best_dice:.4f}")
    print(f"Checkpoints saved to: {checkpoint_dir}")
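
The `val_dice` reported by the loop is one minus the soft Dice loss, so it can differ from the Dice score of the final argmax predictions. For comparison, here is a minimal sketch of a hard per-class Dice on integer label maps, written in plain Python over flattened sequences; `dice_score` and the toy labels are hypothetical, and a real evaluation would more likely use a tensor-based metric.

```python
def dice_score(pred, target, cls):
    """Hard Dice for one class over two equal-length sequences of integer labels."""
    pred_mask = [p == cls for p in pred]
    target_mask = [t == cls for t in target]
    intersection = sum(p and t for p, t in zip(pred_mask, target_mask))
    denom = sum(pred_mask) + sum(target_mask)
    # Dice is undefined when the class is absent from both prediction and target
    return 2 * intersection / denom if denom else float('nan')

# Toy flattened label maps: 0 is background, 1-3 are foreground classes.
pred   = [0, 1, 1, 2, 2, 0]
target = [0, 1, 2, 2, 2, 0]

for cls in (1, 2, 3):
    print(f"class {cls}: {dice_score(pred, target, cls):.3f}")
```

Class 3 does not occur in either sequence, so its score is reported as `nan` rather than counted as a perfect or zero match.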

## 📊 Visualize Results

In [None]:
# Load best model and visualize predictions
import matplotlib.pyplot as plt
import numpy as np

# Load best checkpoint
best_checkpoint = torch.load(checkpoint_dir / 'best.pth', map_location=device)
model.load_state_dict(best_checkpoint['model_state_dict'])
model.eval()

# Get a validation batch
val_batch = next(iter(dataloaders['val']))
images = val_batch['image'].to(device)
labels = val_batch['label'].to(device)

# Predict
with torch.no_grad():
    outputs = model(images)
    preds = torch.argmax(outputs, dim=1)

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

slice_idx = images.shape[-1] // 2

axes[0].imshow(images[0, 0, :, :, slice_idx].cpu(), cmap='gray')
axes[0].set_title('Input Image')
axes[0].axis('off')

# Labels carry a channel dimension (B, 1, H, W, D), as DiceLoss(to_onehot_y=True) requires
axes[1].imshow(labels[0, 0, :, :, slice_idx].cpu(), cmap='jet')
axes[1].set_title('Ground Truth')
axes[1].axis('off')

axes[2].imshow(preds[0, :, :, slice_idx].cpu(), cmap='jet')
axes[2].set_title('Prediction')
axes[2].axis('off')

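plt.tight_layout()
# savefig fails if the target folder does not exist yet
os.makedirs(f'{PROJECT_ROOT}/experiments/results', exist_ok=True)
plt.savefig(f'{PROJECT_ROOT}/experiments/results/prediction_sample.png', dpi=150, bbox_inches='tight')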
plt.show()

print(f"✓ Visualization saved to {PROJECT_ROOT}/experiments/results/")

## 💾 Download Checkpoints (Optional)

Download checkpoints to your local machine for backup.

In [None]:
from google.colab import files

# Download best model
# files.download(str(checkpoint_dir / 'best.pth'))

print("Uncomment to download checkpoints")

## 🎉 Training Complete!

### Next Steps:
1. **Evaluate on the test set** - Use `test_phase1.py` or create an evaluation notebook
2. **Improve model** - Try Transformer architecture (Phase 3)
3. **Add uncertainty** - Implement MC Dropout or Ensembles (Phase 4)
4. **Analyze results** - Compare with paper baseline

### Your checkpoints are saved in:
```
Google Drive/endometriosis-uncertainty-seg/experiments/checkpoints/
├── best.pth      # Best model
├── latest.pth    # Latest checkpoint (for resuming)
└── epoch_*.pth   # Periodic checkpoints
```