## Colab Setup

In [None]:
import sys

# Colab pre-imports the `google.colab` package into every kernel, so its
# presence among the loaded modules is used as the environment signal.
IS_COLAB = 'google.colab' in sys.modules.keys()
print("Running in Google Colab: " + str(IS_COLAB))

In [None]:
import os
import sys

# Environment bootstrap: make the project repository available, install its
# requirements and put its `src/` directory on sys.path so the project modules
# (config, dataset, models, ...) imported later can be found.
# NOTE: this cell uses IPython magics (%cd, !git, %pip, %ls) and only runs in a
# notebook kernel.
if IS_COLAB:
    print("Running in Google Colab environment.")
    if os.path.exists('/content/aai521_3proj'):
        # Re-run: reuse the existing clone and update it.
        print("Repository already exists. Pulling latest changes...")
        %cd /content/aai521_3proj
        !git pull
    else:
        # First run: clone into the current working directory, then enter it.
        # NOTE(review): the existence check above assumes cwd is /content
        # (Colab's default); from another cwd the clone lands elsewhere.
        print("Cloning repository...")
        !git clone https://github.com/swapnilprakashpatil/aai521_3proj.git
        %cd aai521_3proj    
    # Paths below are relative to the repository root entered above.
    %pip install -r requirements.txt
    sys.path.append('/content/aai521_3proj/src')
    %ls
else:
    # Local run: the notebook is expected to live one level below the repo root.
    print("Running in local environment. Installing packages...")
    %pip install -r ../requirements.txt
    sys.path.append('../src')

In [None]:
import platform
import psutil
import subprocess
import os

# Print a hardware/OS summary when running in Colab (useful for reproducing
# timings and memory limits); does nothing beyond a notice elsewhere.
if IS_COLAB:
    print("Google Colab Environment Specifications:")
    print("="*50)

    # System info
    print(f"Operating System: {platform.system()} {platform.release()}")
    print(f"Architecture: {platform.machine()}")
    print(f"Python Version: {platform.python_version()}")

    # Memory info
    memory = psutil.virtual_memory()
    print(f"Total RAM: {memory.total / (1024**3):.1f} GB")
    print(f"Available RAM: {memory.available / (1024**3):.1f} GB")

    # CPU info (psutil may report None for the physical count on some platforms)
    print(f"CPU Cores: {psutil.cpu_count(logical=False)} physical, {psutil.cpu_count(logical=True)} logical")

    # GPU info via nvidia-smi. A missing binary raises FileNotFoundError (an
    # OSError); only those launch failures are treated as "no GPU" instead of
    # silently swallowing every exception.
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader,nounits'],
            capture_output=True, text=True,
        )
    except (OSError, subprocess.SubprocessError):
        print("GPU: Not detected")
    else:
        gpu_lines = result.stdout.strip().splitlines() if result.returncode == 0 else []
        if not gpu_lines:
            print("GPU: Not detected or nvidia-smi unavailable")
        for i, line in enumerate(gpu_lines):
            # Each line is "<name>, <total memory>"; split on the LAST ", " so a
            # GPU name that itself contains ", " cannot break the parsing.
            gpu_name, sep, vram = line.rpartition(', ')
            if sep:
                print(f"GPU {i}: {gpu_name}, {vram} MB VRAM")
            else:
                print(f"GPU {i}: {line}")

    # Disk space
    disk = psutil.disk_usage('/')
    print(f"Disk Space: {disk.free / (1024**3):.1f} GB free / {disk.total / (1024**3):.1f} GB total")

    print("="*50)

    if not os.path.exists('/content/aai521_3proj'):
        print("WARNING: Cloning project repository required.")
        print("="*50)
else:
    print("Not running in Google Colab environment")

## 1. Setup & Imports

In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Add src to path (only once, so re-running this cell does not keep growing sys.path)
if '../src' not in sys.path:
    sys.path.append('../src')

# Reload project modules so edits under src/ are picked up without restarting
# the kernel. `config` is reloaded FIRST: any module that copies values out of
# it at import time only sees fresh values if config is reloaded before it.
# The remaining order is an assumed dependency order (leaf modules before
# trainer / light_pipeline) - NOTE(review): confirm against src/.
# `from x import Name` bindings in this notebook are refreshed by the import
# statements below (and in later cells), which re-read the reloaded modules.
import importlib
for _module_name in (
    'config', 'gpu_manager', 'dataset', 'models', 'losses', 'metrics',
    'experiment_tracking', 'trainer', 'light_pipeline', 'visualizations',
):
    if _module_name in sys.modules:
        importlib.reload(sys.modules[_module_name])

# Import custom modules
import config
from dataset import create_dataloaders, FloodDataset
from models import create_model, UNetPlusPlus, DeepLabV3Plus, SegFormer
from losses import create_loss_function
from metrics import MetricsTracker, SegmentationMetrics
from trainer import Trainer
from experiment_tracking import ExperimentLogger, ExperimentComparator
from gpu_manager import GPUManager

# Plot appearance shared by every figure in this notebook.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# A single GPUManager instance (`gpu_mgr`) is created here and reused by all
# later cells.
gpu_mgr = GPUManager()
gpu_mgr.setup()
gpu_mgr.print_nvidia_smi_info()

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {gpu_mgr.is_available()}")
if gpu_mgr.is_available():
    print("\n".join([
        f"CUDA device: {gpu_mgr.gpu_name}",
        f"CUDA memory: {gpu_mgr.total_memory_gb:.2f} GB",
    ]))

## 2. Data Loading & Exploration

In [None]:
# Print GPU information: device summary first, then current memory usage.
for _gpu_report in (gpu_mgr.print_info, gpu_mgr.print_memory_stats):
    _gpu_report()

### Visualize Class Distribution

In [None]:
# Clear GPU memory first (via GPUManager.cleanup) before the next cell builds
# the dataloaders and scans the whole training split.
gpu_mgr.cleanup()

In [None]:
# Create dataloaders
print("Creating dataloaders...")
train_loader, val_loader, test_loader = create_dataloaders(
    train_dir=config.PROCESSED_TRAIN_DIR,
    val_dir=config.PROCESSED_VAL_DIR,
    test_dir=config.PROCESSED_TEST_DIR,
    batch_size=8,
    num_workers=4,
    pin_memory=True
)

# Count pixels of each class over the whole training split.
print("\nCalculating class weights from training data...")
class_counts = torch.zeros(config.NUM_CLASSES)

for batch in tqdm(train_loader, desc="Computing class distribution"):
    masks = batch['mask']
    for cls in range(config.NUM_CLASSES):
        class_counts[cls] += (masks == cls).sum()

# Inverse-frequency weights, normalised so they sum to NUM_CLASSES.
# A class with zero training pixels would get an infinite weight, which turns
# the whole normalised vector into NaNs. Such classes get weight 0 instead
# (there are no training pixels for them to weight), and a warning is printed.
total_pixels = class_counts.sum()
present = class_counts > 0
class_weights = torch.zeros_like(class_counts)
class_weights[present] = total_pixels / (config.NUM_CLASSES * class_counts[present])
class_weights = class_weights / class_weights.sum() * config.NUM_CLASSES  # Normalize

missing_classes = [name for name, has_pixels in zip(config.CLASS_NAMES, present.tolist()) if not has_pixels]
if missing_classes:
    print(f"WARNING: no training pixels for class(es): {', '.join(missing_classes)}; their weight is set to 0.")

print("\nClass distribution and weights:")
for name, count, weight in zip(config.CLASS_NAMES, class_counts, class_weights):
    pct = (count / total_pixels * 100).item()
    print(f"  {name}: {count:.0f} pixels ({pct:.2f}%), weight: {weight:.4f}")

print(f"\nDataloaders created:")
print(f"  Train: {len(train_loader)} batches ({len(train_loader.dataset)} samples)")
print(f"  Val: {len(val_loader)} batches ({len(val_loader.dataset)} samples)")
print(f"  Test: {len(test_loader)} batches ({len(test_loader.dataset)} samples)")

### Visualize Sample Data

In [None]:
# Pull the first training batch for a quick sanity check of shapes and values.
train_iter = iter(train_loader)
batch = next(train_iter)
images, masks = batch['image'], batch['mask']

print("\n".join([
    f"Batch shape: {images.shape}",
    f"Mask shape: {masks.shape}",
    f"Image range: [{images.min():.3f}, {images.max():.3f}]",
    f"Mask classes: {masks.unique().tolist()}",
]))

# Visualize samples
def visualize_samples(images, masks, num_samples=3):
    """Visualize pre/post images and masks, one sample per row.

    Args:
        images: Batch tensor where images[i, :3] is the pre-event image and
            images[i, 3:] the post-event image (channel-first).
        masks: Batch of integer class masks, one per image.
        num_samples: Number of rows to draw; clamped to the batch size, and
            nothing is drawn when it is <= 0.
    """
    num_samples = min(num_samples, len(images))
    if num_samples <= 0:
        return

    class_names = config.CLASS_NAMES
    n_classes = len(class_names)

    # squeeze=False keeps `axes` 2-D, so axes[i, j] also works for a single row.
    _, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    # Color map for masks. `plt.cm.get_cmap` was deprecated in Matplotlib 3.7
    # and removed in 3.9; `pyplot.get_cmap(name, lut)` is the supported form.
    cmap = plt.get_cmap('tab10', n_classes)

    def _to_display_rgb(chw):
        """Convert a (C, H, W) tensor to an HWC array min-max scaled to [0, 1]."""
        img = chw.detach().cpu().permute(1, 2, 0).numpy()
        return (img - img.min()) / (img.max() - img.min() + 1e-8)

    for i in range(num_samples):
        pre_img = _to_display_rgb(images[i, :3])    # Pre-event image (first 3 channels)
        post_img = _to_display_rgb(images[i, 3:])   # Post-event image (last 3 channels)
        mask = masks[i].detach().cpu().numpy()

        axes[i, 0].imshow(pre_img)
        axes[i, 0].set_title('Pre-Event Image', fontsize=12, fontweight='bold')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(post_img)
        axes[i, 1].set_title('Post-Event Image', fontsize=12, fontweight='bold')
        axes[i, 1].axis('off')

        # Fixed vmin/vmax so each class keeps the same color in every row.
        mask_plot = axes[i, 2].imshow(mask, cmap=cmap, vmin=0, vmax=n_classes - 1)
        axes[i, 2].set_title('Ground Truth Mask', fontsize=12, fontweight='bold')
        axes[i, 2].axis('off')

        # Add colorbar to last mask
        if i == num_samples - 1:
            cbar = plt.colorbar(mask_plot, ax=axes[i, 2], orientation='horizontal',
                                pad=0.05, fraction=0.046)
            cbar.set_ticks(range(n_classes))
            cbar.set_ticklabels(class_names, rotation=45, ha='right', fontsize=8)

    plt.tight_layout()
    plt.show()

visualize_samples(images, masks, num_samples=3)

## 3. Model Architecture Overview

In [None]:
# Create models for architecture overview (parameter counts only)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# List of all models to train
ALL_MODELS = ['unet++', 'deeplabv3+', 'segformer', 'fc_siam_diff', 'siamese_unet++', 'stanet']

models_info = []

for model_name in ALL_MODELS:
    # Names containing 'siamese' are built with 3 input channels, every other
    # model with 6 (presumably the stacked pre/post-event images).
    model = create_model(
        model_name=model_name,
        in_channels=3 if 'siamese' in model_name.lower() else 6,
        num_classes=config.NUM_CLASSES,
        **config.MODEL_CONFIGS.get(model_name, {})
    )

    # (element count, trainable?) per parameter tensor - plain ints/bools, so
    # nothing keeps the model's tensors alive after `del model`.
    numels = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(n for n, _ in numels)
    trainable_params = sum(n for n, is_trainable in numels if is_trainable)

    models_info.append({
        'Model': model_name.upper(),
        'Total Parameters': f"{total_params:,}",
        'Trainable Parameters': f"{trainable_params:,}",
        'Size (MB)': f"{total_params * 4 / 1e6:.2f}",  # 4 bytes per float32 parameter
    })

    del model

# Display as table
models_df = pd.DataFrame(models_info)
separator = "=" * 80
print("\n" + separator)
print("MODEL ARCHITECTURE COMPARISON")
print(separator)
print(models_df.to_string(index=False))
print(separator)

## 4. Training Configuration

In [None]:
# Setup device using GPU manager
device = gpu_mgr.get_device()
print(f"Using device: {device}")

# Short smoke-test settings consumed by LightPipeline (Section 5);
# max_batches_per_epoch is presumably a per-epoch batch cap there.
LIGHT_CONFIG = dict(
    num_epochs=3,
    batch_size=8,
    learning_rate=1e-4,
    weight_decay=1e-4,
    device=device,
    use_amp=True,
    gradient_clip=1.0,
    max_batches_per_epoch=50,
    loss_type='combined',
    early_stopping_patience=5,
)

# Full training settings consumed by train_model() (Section 7).
TRAINING_CONFIG = dict(
    num_epochs=20,
    batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-4,
    device=device,
    use_amp=True,
    gradient_clip=1.0,
    gradient_accumulation_steps=2,
    loss_type='combined',
    early_stopping_patience=5,
)

light_epochs, light_batches = LIGHT_CONFIG['num_epochs'], LIGHT_CONFIG['max_batches_per_epoch']
full_epochs, full_patience = TRAINING_CONFIG['num_epochs'], TRAINING_CONFIG['early_stopping_patience']
print("Configuration loaded:")
print(f"  Light validation: {light_epochs} epochs, {light_batches} batches/epoch")
print(f"  Full training: {full_epochs} epochs, early stop patience={full_patience}")

In [None]:
# Apply GPU optimizations
if gpu_mgr.is_available():
    # cuDNN autotuning plus TF32 math for matmuls/convolutions trades a little
    # precision for speed on GPUs that support TF32.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_num_threads(8)  # CPU intra-op thread count

    print("\n".join([
        f"GPU: {gpu_mgr.gpu_name}",
        f"GPU Memory: {gpu_mgr.total_memory_gb:.1f} GB",
        "GPU optimizations enabled",
    ]))

In [None]:
# Compare the configured batch size with GPUManager's recommendation.
if not gpu_mgr.is_available():
    print("No GPU available - using CPU")
else:
    recommended_batch = gpu_mgr.recommend_batch_size()

    print(f"GPU: {gpu_mgr.gpu_name}")
    print(f"Total Memory: {gpu_mgr.total_memory_gb:.2f} GB")
    print(f"Recommended batch size: {recommended_batch}")
    per_step_batch = TRAINING_CONFIG['batch_size']
    print(f"Current batch size: {per_step_batch}")
    # Effective batch = per-step batch x gradient accumulation steps.
    print(f"Effective batch size: {per_step_batch * TRAINING_CONFIG['gradient_accumulation_steps']}")

    # Snapshot the GPU after releasing cached memory.
    gpu_mgr.cleanup()
    stats = gpu_mgr.get_memory_stats()
    print(f"\nCurrent GPU Usage: Allocated: {stats['allocated_gb']:.2f} GB, Reserved: {stats['reserved_gb']:.2f} GB, Free: {stats['free_gb']:.2f} GB")

## 5. Light Pipeline Validation

Quickly validate that all models can train without errors before committing to full training.

In [None]:
# Import light pipeline: the project's LightPipeline class, used in the next
# cell to smoke-test every model with LIGHT_CONFIG before full training.
from light_pipeline import LightPipeline

print("Light pipeline class loaded.")

In [None]:
# Smoke-test every model with the light configuration before full training.
light_pipeline = LightPipeline(LIGHT_CONFIG, class_weights)
validation_results = light_pipeline.validate_all_models(ALL_MODELS, train_loader, val_loader)

# Split the models by validation outcome.
passed_models, failed_models = light_pipeline.get_passed_models(), light_pipeline.get_failed_models()

print(f"\nReady to proceed with {len(passed_models)} validated models.")

### Validation Performance Comparison

In [None]:
# Import visualization utilities
from visualizations import ValidationVisualizer

viz = ValidationVisualizer()

if not validation_results:
    print("No validation results available for visualization.")
else:
    # Fig. 1 - training speed & success rate; Fig. 2 - learning progress and
    # convergence; Fig. 3 - top-performer podium; then a text summary.
    viz.plot_validation_overview(validation_results, ALL_MODELS)
    viz.plot_learning_analysis(validation_results, ALL_MODELS, len(train_loader))
    viz.plot_top_performers(validation_results, ALL_MODELS)
    viz.print_validation_statistics(validation_results, ALL_MODELS, len(train_loader))

## 6. Training Function

In [None]:
def _is_cuda_device(device):
    """Return True if *device* (a ``torch.device`` or a string like 'cuda:0') is a CUDA device.

    A plain ``device == 'cuda'`` check is unreliable: ``torch.device('cuda') == 'cuda'``
    evaluates to False and ``'cuda:0' != 'cuda'``, so CUDA-only code would be skipped.
    """
    return torch.device(device).type == 'cuda'


def _release_cuda_memory():
    """Best-effort cleanup: run the garbage collector, then free cached CUDA memory."""
    import gc
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except RuntimeError as exc:
        # Not fatal: the model may still fit without the cache being released.
        print(f"Warning: torch.cuda.empty_cache() failed: {exc}")


def _to_jsonable(value):
    """Recursively convert tensors / NumPy values inside *value* to JSON-serializable objects.

    ``torch.Tensor``, ``numpy.ndarray`` and NumPy scalars all expose ``.tolist()``,
    which yields plain Python numbers or nested lists for any shape (``float(v)``
    fails on multi-element values such as per-class IoU vectors).
    """
    if isinstance(value, dict):
        return {key: _to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    if hasattr(value, 'tolist'):
        return value.tolist()
    return value


def _print_final_summary(model_name, history, best_idx, epoch_offset):
    """Print the best-epoch validation metrics of one training run.

    Args:
        model_name: Model name used in the header.
        history: Training history dict with 'val_iou', 'val_dice', 'val_f1'
            lists and optionally 'val_iou_per_class'.
        best_idx: Index of the best epoch within this run's history lists.
        epoch_offset: Epochs completed before this run (non-zero on resume),
            added so the reported epoch numbers are absolute.
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print(f"FINAL RESULTS - {model_name.upper()}")
    print(banner)
    print(f"Best epoch: {epoch_offset + best_idx + 1}/{epoch_offset + len(history['val_iou'])}")
    print(f"\nBest validation metrics:")
    print(f"  IoU:  {history['val_iou'][best_idx]:.4f}")
    print(f"  Dice: {history['val_dice'][best_idx]:.4f}")
    print(f"  F1:   {history['val_f1'][best_idx]:.4f}")
    print(f"\n{banner}")
    # Per-class metrics if available
    if 'val_iou_per_class' in history:
        print(f"\nPer-class IoU (Best Epoch):")
        for class_name, iou in zip(config.CLASS_NAMES, history['val_iou_per_class'][best_idx]):
            print(f"  {class_name}: {iou:.4f}")
    print(f"{banner}\n")


def train_model(model_name, config_dict, train_loader, val_loader, class_weights, resume_from_checkpoint=None):
    """Train a single model and return ``(history, output_dir)``.

    Args:
        model_name: Name of the model to train (key for create_model / MODEL_CONFIGS).
        config_dict: Training configuration dictionary. Reads 'device',
            'loss_type', 'learning_rate', 'weight_decay', 'use_amp',
            'gradient_clip', 'early_stopping_patience', 'num_epochs' and
            optionally 'gradient_accumulation_steps' and the loss weights.
            'batch_size' is NOT used here; batching comes from the loaders.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        class_weights: Class weights tensor for the loss function.
        resume_from_checkpoint: Path to a checkpoint file laid out as
            ``<run_dir>/checkpoints/<file>`` to resume from, or None.

    Returns:
        ``(history, output_dir)``. ``history`` is the dict returned by
        ``Trainer.train`` (None if a resumed run had no epochs left);
        ``output_dir`` is the run directory containing ``checkpoints/`` and
        ``training_history.json``.
    """
    device = config_dict['device']
    use_cuda = _is_cuda_device(device)
    banner = '=' * 80

    print(f"\n{banner}")
    print(f"Training {model_name.upper()}")
    print(f"{banner}\n")

    # CUDA optimizations
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Resolve the run directory: reuse the checkpoint's run dir when resuming,
    # otherwise create '<model_name>_<timestamp>' with a checkpoints/ subdir.
    if resume_from_checkpoint:
        checkpoint_path = Path(resume_from_checkpoint)
        checkpoint_dir = checkpoint_path.parent
        output_dir = checkpoint_dir.parent
        print(f"Resuming from checkpoint: {checkpoint_path}")
    else:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        output_dir = Path('../outputs/training') / f'{model_name}_{timestamp}'
        checkpoint_dir = output_dir / 'checkpoints'
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
    # Run directories are named '<model_name>_<timestamp>', so the directory
    # name is the experiment name for fresh and resumed runs alike. Splitting
    # it on the first '_' would mangle names such as 'fc_siam_diff'.
    experiment_name = output_dir.name

    # Clear GPU memory before creating model
    if use_cuda:
        _release_cuda_memory()

    # 'siamese' models are built with 3 input channels, all others with 6.
    model = create_model(
        model_name=model_name,
        in_channels=6 if 'siamese' not in model_name.lower() else 3,
        num_classes=config.NUM_CLASSES,
        **config.MODEL_CONFIGS.get(model_name, {})
    ).to(device)

    loss_fn = create_loss_function(
        loss_type=config_dict['loss_type'],
        num_classes=config.NUM_CLASSES,
        class_weights=class_weights.to(device),
        device=device,
        ce_weight=config_dict.get('ce_weight', 0.1),
        dice_weight=config_dict.get('dice_weight', 2.0),
        focal_weight=config_dict.get('focal_weight', 3.0),
        focal_gamma=config_dict.get('focal_gamma', 3.0)
    )

    optimizer = optim.AdamW(
        model.parameters(),
        lr=config_dict['learning_rate'],
        weight_decay=config_dict['weight_decay']
    )

    # mode='max': LR is halved when the monitored metric stops increasing.
    # NOTE(review): which value Trainer passes to scheduler.step() is not visible here.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=3
    )

    from experiment_tracking import ExperimentLogger
    logger = ExperimentLogger(experiment_name=f'{model_name}_{datetime.now().strftime("%Y%m%d_%H%M%S")}')

    # try/finally so the logger is closed on every path (skip, error, success).
    try:
        start_epoch = 0
        best_val_iou = 0.0
        if resume_from_checkpoint:
            # map_location lets a checkpoint saved on GPU be restored on CPU
            # (and vice versa) instead of failing on device mismatch.
            checkpoint = torch.load(resume_from_checkpoint, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if 'scheduler_state_dict' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            start_epoch = checkpoint.get('epoch', 0) + 1
            best_val_iou = checkpoint.get('best_val_iou', 0.0)
            print(f"Resuming from epoch {start_epoch}, best IoU: {best_val_iou:.4f}")
            # NOTE(review): start_epoch / best_val_iou are not passed to Trainer,
            # so its own best-model tracking and early stopping presumably
            # restart from scratch.

        trainer = Trainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            loss_fn=loss_fn,
            num_classes=config.NUM_CLASSES,
            device=device,
            checkpoint_dir=checkpoint_dir,
            experiment_name=experiment_name,
            use_amp=config_dict['use_amp'],
            gradient_clip_val=config_dict['gradient_clip'],
            early_stopping_patience=config_dict['early_stopping_patience'],
            gradient_accumulation_steps=config_dict.get('gradient_accumulation_steps', 1),
            class_names=config.CLASS_NAMES
        )

        # Only the epochs not already covered by the checkpoint are trained.
        remaining_epochs = config_dict['num_epochs'] - start_epoch
        if remaining_epochs <= 0:
            print(f"Training already completed ({start_epoch} epochs). Skipping.")
            return None, output_dir
        history = trainer.train(num_epochs=remaining_epochs)

        val_iou = history.get('val_iou', [])
        if len(val_iou) > 0:
            best_idx = max(range(len(val_iou)), key=val_iou.__getitem__)
            _print_final_summary(model_name, history, best_idx, start_epoch)
        else:
            print("No validation epochs recorded; skipping best-epoch summary.")

        # Log metrics to TensorBoard, stepping by absolute epoch so resumed
        # runs continue the curves instead of restarting at 0.
        for epoch in range(len(history['train_loss'])):
            step = start_epoch + epoch
            logger.log_scalar('Loss/train', history['train_loss'][epoch], step)
            logger.log_scalar('Loss/val', history['val_loss'][epoch], step)
            logger.log_scalar('IoU/train', history['train_iou'][epoch], step)
            logger.log_scalar('IoU/val', history['val_iou'][epoch], step)
    finally:
        logger.close()

    history_path = output_dir / 'training_history.json'
    with open(history_path, 'w') as f:
        json.dump(_to_jsonable(history), f, indent=2)

    print(f"[SAVED] Checkpoints: {checkpoint_dir}")
    print(f"[SAVED] Training history: {history_path}\n")

    return history, output_dir

## 6.1 Training Setup (GPU Check, Module Reload, Parallel Helpers)

In [None]:
# Release cached GPU memory and unreachable Python objects before training.
gpu_mgr.cleanup()

import gc
gc.collect()

if not gpu_mgr.is_available():
    print("No GPU available - training will be slow on CPU")
else:
    stats = gpu_mgr.get_memory_stats()
    total_memory_gb, allocated_gb, reserved_gb, free_gb = (
        stats['total_gb'], stats['allocated_gb'], stats['reserved_gb'], stats['free_gb']
    )

    print(f"GPU Memory Status:")
    for label, amount_gb in (("Total", total_memory_gb), ("Allocated", allocated_gb),
                             ("Reserved", reserved_gb), ("Available", free_gb)):
        print(f"  {label}: {amount_gb:.2f} GB")

    # Below 2 GB free, full training is likely to run out of memory.
    if free_gb < 2.0:
        print("\nWARNING: Less than 2GB free GPU memory!")
        print("Recommendation: Restart runtime to clear GPU memory completely.")
    else:
        print(f"\nGPU memory check passed. Ready for training.")

In [None]:
# Reload trainer module so edits to the project's trainer take effect.
# importlib.reload() refreshes the module object only: the global `Trainer`
# bound earlier by `from trainer import Trainer` would keep pointing at the OLD
# class, and train_model() looks `Trainer` up globally at call time - so the
# reloaded class must be rebound explicitly.
import importlib
if 'trainer' in sys.modules:
    Trainer = importlib.reload(sys.modules['trainer']).Trainer
    print("Trainer module reloaded")
else:
    from trainer import Trainer
    print("Trainer module loaded")

In [None]:
import threading
from queue import Queue
import copy
import gc

# Names of models whose training thread is currently running; every access
# goes through `training_lock`.
current_training_models = []
training_lock = threading.Lock()

def train_model_parallel(model_name, config_dict, train_loader, val_loader, class_weights, results_queue):
    """Thread target: train one model with its own small-batch dataloaders.

    The passed-in train_loader / val_loader are not used; fresh loaders with
    batch_size=4, one worker and no pinned memory are built instead. Exactly
    one result dict (keys: model_name, history, output_dir, success, error) is
    put on *results_queue*, whether training succeeds or raises.
    """
    try:
        with training_lock:
            current_training_models.append(model_name)

        # Private copy so the shared config dict is never mutated.
        local_config = copy.deepcopy(config_dict)
        local_config['batch_size'] = 4  # Reduced to 4 for parallel training

        gpu_mgr.cleanup()

        # One worker and no pinned memory to limit per-thread resource use.
        loader_kwargs = dict(
            train_dir=config.PROCESSED_TRAIN_DIR,
            val_dir=config.PROCESSED_VAL_DIR,
            test_dir=config.PROCESSED_TEST_DIR,
            batch_size=local_config['batch_size'],
            num_workers=1,
            pin_memory=False,
        )
        own_train_loader, own_val_loader, _ = create_dataloaders(**loader_kwargs)

        history, output_dir = train_model(model_name, local_config, own_train_loader, own_val_loader, class_weights)

        results_queue.put(dict(
            model_name=model_name,
            history=history,
            output_dir=output_dir,
            success=True,
            error=None,
        ))
    except Exception as exc:
        import traceback
        results_queue.put(dict(
            model_name=model_name,
            history=None,
            output_dir=None,
            success=False,
            error=f"{str(exc)}\n{traceback.format_exc()}",
        ))
    finally:
        with training_lock:
            if model_name in current_training_models:
                current_training_models.remove(model_name)
        gpu_mgr.cleanup()

def train_models_parallel_pairs(model_pairs, config_dict, train_loader, val_loader, class_weights):
    """Train models two at a time, one thread per model, pair after pair.

    Returns:
        dict mapping model name -> {'history', 'output_dir'} for every model
        whose thread reported success; failures are printed and left out.
    """
    all_results = {}
    bar = '=' * 80

    for pair_number, pair in enumerate(model_pairs, start=1):
        pair_title = ' + '.join(m.upper() for m in pair)
        print(f"\n{bar}")
        print(f"PARALLEL TRAINING - PAIR {pair_number}: {pair_title}")
        print(f"{bar}\n")

        gpu_mgr.cleanup()

        # Each worker puts exactly one result dict on this queue.
        results_queue = Queue()
        workers = []
        for name in pair:
            worker = threading.Thread(
                target=train_model_parallel,
                args=(name, config_dict, train_loader, val_loader, class_weights, results_queue)
            )
            workers.append(worker)
            worker.start()

        # Block until every model of the pair has finished.
        for worker in workers:
            worker.join()

        while not results_queue.empty():
            outcome = results_queue.get()
            name = outcome['model_name']
            if not outcome['success']:
                print(f"\n{name.upper()} failed:")
                print(f"Error: {outcome['error']}\n")
                continue
            all_results[name] = {
                'history': outcome['history'],
                'output_dir': outcome['output_dir']
            }
            print(f"\n{name.upper()} completed successfully!")

        gpu_mgr.cleanup()

        print(f"\n{bar}")
        print(f"PAIR {pair_number} COMPLETED")
        print(f"{bar}\n")

    return all_results

print("Parallel training functions loaded!")

In [None]:
# GPU monitoring using GPUManager
print("GPU monitoring available via GPUManager.")
print("Run gpu_mgr.monitor_memory(interval=30, duration=3600) to start monitoring.")

# Toggle for the GPU memory monitor (30 s interval, 1 hour duration). It was
# previously started unconditionally despite an "Uncomment to run" comment;
# set to False to skip it.
ENABLE_GPU_MONITOR = True

import threading
if ENABLE_GPU_MONITOR:
    # NOTE(review): monitor_memory presumably returns a background thread handle.
    monitor_thread = gpu_mgr.monitor_memory(interval=30, duration=3600)
    print("GPU monitor started in background (30s refresh, 1 hour duration).")
else:
    monitor_thread = None

## 7. Full Training Execution

**IMPORTANT:** Run the Light Pipeline Validation (Section 5) first before executing this section!

This section trains all models with the optimized configuration:
- 20 epochs (reduced from 30)
- Early stopping patience: 5 (reduced from 10)
- Sequential or parallel mode

In [None]:
# Training mode: True = train models in parallel pairs, False = one at a time.
USE_PARALLEL_TRAINING = False  # Sequential is the more stable option

# Resume from checkpoint (optional): map model name -> checkpoint path.
# Models not listed start fresh.
# Example: RESUME_CHECKPOINTS = {'deeplabv3+': '../outputs/training/deeplabv3+_20251202_123456/checkpoints/best_model.pth'}
RESUME_CHECKPOINTS = {}  # Empty dict = start fresh training for all models

# Models to train. If GPU memory is tight (e.g. other processes are holding
# memory - inspect them with `nvidia-smi` and stop them, or restart the
# runtime), restrict this list to the models that passed light validation:
# ALL_MODELS = ['deeplabv3+', 'segformer']
# ALL_MODELS = ['unet++', 'deeplabv3+', 'segformer', 'fc_siam_diff']
ALL_MODELS = ['unet++', 'deeplabv3+', 'segformer', 'fc_siam_diff', 'siamese_unet++', 'stanet']  # All 6 models

# Model pairs for parallel training (only used if USE_PARALLEL_TRAINING=True)
MODEL_PAIRS = [
    ['unet++', 'deeplabv3+'],           # Pair 1: Smaller models
    ['segformer', 'fc_siam_diff'],      # Pair 2: Medium models
    ['siamese_unet++', 'stanet']        # Pair 3: Larger models
]

# Check validation status
if 'validation_results' in globals():
    failed_models = [m for m, r in validation_results.items() if r['status'] == 'failed']
    if failed_models:
        print(f"\nWARNING: {len(failed_models)} model(s) failed validation!")
        print("Failed models:", ', '.join([m.upper() for m in failed_models]))
        print("Recommend fixing validation errors before full training.\n")
else:
    print("\nWARNING: Light validation not run yet!")
    print("Recommend running Section 5 (Light Pipeline Validation) first.\n")

# Execute training
if USE_PARALLEL_TRAINING:
    num_parallel_models = sum(len(pair) for pair in MODEL_PAIRS)
    print("\nPARALLEL TRAINING MODE")
    print(f"Training {num_parallel_models} models in {len(MODEL_PAIRS)} pairs with batch_size=4 per model")
    print("Total effective batch size: 8 (2 models × 4)\n")

    results = train_models_parallel_pairs(
        MODEL_PAIRS,
        TRAINING_CONFIG,
        train_loader,
        val_loader,
        class_weights
    )
else:
    # train_model() trains on the loaders it is given and never reads
    # TRAINING_CONFIG['batch_size'], so report the batch size actually used.
    loader_batch_size = getattr(train_loader, 'batch_size', None)
    print("\nSEQUENTIAL TRAINING MODE")
    print(f"Training {len(ALL_MODELS)} models one by one with batch_size={loader_batch_size}\n")
    if loader_batch_size != TRAINING_CONFIG['batch_size']:
        print(f"NOTE: TRAINING_CONFIG['batch_size']={TRAINING_CONFIG['batch_size']} is not applied; "
              f"the shared train_loader was built with batch_size={loader_batch_size}. "
              "Recreate the dataloaders to change it.\n")

    results = {}

    for model_name in ALL_MODELS:
        # Check if resume checkpoint exists for this model
        resume_checkpoint = RESUME_CHECKPOINTS.get(model_name)
        if resume_checkpoint:
            print(f"\nResuming {model_name} from checkpoint...")

        # One failing model (e.g. CUDA OOM) must not abort the remaining ones;
        # failures are recorded and reported in the summary below.
        try:
            history, output_dir = train_model(
                model_name,
                TRAINING_CONFIG,
                train_loader,
                val_loader,
                class_weights,
                resume_from_checkpoint=resume_checkpoint
            )
        except Exception as exc:
            import traceback
            print(f"\n{model_name.upper()} failed: {exc}")
            traceback.print_exc()
            results[model_name] = {'history': None, 'output_dir': None, 'error': str(exc)}
        else:
            results[model_name] = {'history': history, 'output_dir': output_dir}
        finally:
            # Clear GPU cache between models
            gpu_mgr.cleanup()

# Extract results (handle missing models gracefully)
unet_history = results.get('unet++', {}).get('history')
unet_output_dir = results.get('unet++', {}).get('output_dir')
deeplab_history = results.get('deeplabv3+', {}).get('history')
deeplab_output_dir = results.get('deeplabv3+', {}).get('output_dir')
segformer_history = results.get('segformer', {}).get('history')
segformer_output_dir = results.get('segformer', {}).get('output_dir')
fcsiamdiff_history = results.get('fc_siam_diff', {}).get('history')
fcsiamdiff_output_dir = results.get('fc_siam_diff', {}).get('output_dir')
siamese_unet_history = results.get('siamese_unet++', {}).get('history')
siamese_unet_output_dir = results.get('siamese_unet++', {}).get('output_dir')
stanet_history = results.get('stanet', {}).get('history')
stanet_output_dir = results.get('stanet', {}).get('output_dir')

# Training complete summary
print("\n" + "="*80)
print("TRAINING COMPLETE")
print("="*80)
print(f"Mode: {'PARALLEL' if USE_PARALLEL_TRAINING else 'SEQUENTIAL'}")

successful_models = []
failed_models = []

for model_name in ALL_MODELS:
    if results.get(model_name, {}).get('history') is not None:
        successful_models.append(model_name)
    else:
        failed_models.append(model_name)

if successful_models:
    print(f"\nSuccessfully trained models ({len(successful_models)}/{len(ALL_MODELS)}):")
    for model_name in successful_models:
        print(f"  {model_name.upper()}: {results[model_name]['output_dir']}")

if failed_models:
    print(f"\nFailed models ({len(failed_models)}/{len(ALL_MODELS)}):")
    for model_name in failed_models:
        error = results.get(model_name, {}).get('error')
        reason = f" ({error.splitlines()[0]})" if error else ""
        print(f"  {model_name.upper()}{reason}")

print("="*80)

## 8. Training Metrics Visualization

In [None]:
# Import training visualizer
from visualizations import TrainingVisualizer

# (history, display name) for every model; a None history means that model
# produced no results above.
model_histories = [
    (unet_history, 'U-Net++'),
    (deeplab_history, 'DeepLabV3+'),
    (segformer_history, 'SegFormer'),
    (fcsiamdiff_history, 'FC-Siam-Diff'),
    (siamese_unet_history, 'Siamese U-Net++'),
    (stanet_history, 'STANet')
]

for history, model_name in model_histories:
    if history is None:
        print(f"\nNo training history available for {model_name}")
        continue

    rule = '=' * 80
    print(f"\n{rule}")
    print(f"Visualizing {model_name} Training History")
    print(f"{rule}\n")

    # File-safe stem, e.g. 'Siamese U-Net++' -> 'siamese_u-netplusplus'.
    file_stem = model_name.lower().replace(" ", "_").replace("+", "plus")
    save_path = Path('../outputs/training') / f'{file_stem}_history.png'
    save_path.parent.mkdir(parents=True, exist_ok=True)

    TrainingVisualizer.plot_training_history(history, model_name, save_path)

## 9. Model Comparison

In [None]:
# Compare all models using TrainingVisualizer
from visualizations import TrainingVisualizer

all_histories = [
    unet_history,
    deeplab_history,
    segformer_history,
    fcsiamdiff_history,
    siamese_unet_history,
    stanet_history
]

all_names = [
    'U-Net++',
    'DeepLabV3+',
    'SegFormer',
    'FC-Siam-Diff',
    'Siamese U-Net++',
    'STANet'
]

# Only models that actually produced a history are handed to the plotting
# code; missing/failed ones are listed instead of being passed in as None.
completed_runs = [(h, n) for h, n in zip(all_histories, all_names) if h is not None]
missing_runs = [n for h, n in zip(all_histories, all_names) if h is None]

if completed_runs:
    print(f"\n{'='*100}")
    print("COMPARING ALL MODELS")
    print(f"{'='*100}\n")
    if missing_runs:
        print(f"Skipping models without training history: {', '.join(missing_runs)}\n")

    # Generate comparison visualization
    save_path = Path('../outputs/training/model_comparison.png')
    save_path.parent.mkdir(parents=True, exist_ok=True)

    TrainingVisualizer.compare_models(
        [h for h, _ in completed_runs],
        [n for _, n in completed_runs],
        save_path,
    )
else:
    print("No models completed training successfully. Unable to generate comparison.")

## 10. TensorBoard Visualization

In [None]:
# Launch TensorBoard to view training metrics
# Show how to launch TensorBoard on the training logs and what it displays.
# One joined print; each "\n" separator stands in for a separate print() call.
print("\n".join((
    "To view TensorBoard:",
    "1. Run in terminal: tensorboard --logdir=../outputs/tensorboard --port=6006",
    "2. Open browser: http://localhost:6006",
    "\nTensorBoard shows:",
    "  - Training/validation loss curves",
    "  - IoU, Dice, F1 metrics over time",
    "  - Per-class performance",
    "  - Learning rate schedules",
    "  - Model graphs",
)))

## 11. Commit and Push Checkpoints to Git

Save training checkpoints and results to version control.

In [None]:
import subprocess
from pathlib import Path

def _run_git(*args):
    """Run ``git <args>`` with text output captured; never raises on a non-zero exit."""
    return subprocess.run(['git', *args], capture_output=True, text=True)


def commit_and_push_checkpoints(commit_message=None):
    """Stage training outputs, commit them, and push to the configured remote.

    All paths are relative to the current working directory (``../outputs``).
    NOTE(review): the setup cell installs ``../requirements.txt`` in a local
    run but ``%cd``s into the repository root in Colab. Confirm that
    ``../outputs`` points at the intended directory in both environments.

    Args:
        commit_message: Message for the commit. When falsy, a message is built
            from ``ALL_MODELS`` (if defined in this module's globals) and the
            current timestamp.

    Returns:
        True only if both the commit and the push succeeded. False if git is
        missing, the cwd is not inside a git repository, ``../outputs`` does
        not exist, nothing was staged, or the commit/push failed.
    """
    # Imported locally so this cell does not depend on an import in another cell.
    from datetime import datetime

    # Check that git is available and that we are inside a repository.
    try:
        repo_check = _run_git('rev-parse', '--git-dir')
    except FileNotFoundError:
        print("ERROR: git executable not found. Cannot commit.")
        return False
    if repo_check.returncode != 0:
        print("ERROR: Not in a git repository. Cannot commit.")
        return False
    print("Git repository detected.\n")

    outputs_dir = Path('../outputs')
    if not outputs_dir.exists():
        print("No outputs directory found. Nothing to commit.")
        return False

    print("Adding training outputs to git...")

    # Add specific files (exclude large model weights if needed).
    # The patterns reach git unexpanded (no shell), and git matches them as
    # pathspec globs.
    files_to_add = [
        '../outputs/training/*/training_history.json',
        '../outputs/training/*/checkpoints/*.pth',
        # Same location the model-comparison cell saves its plot to.
        '../outputs/training/model_comparison.png',
        '../outputs/tensorboard'
    ]

    for pattern in files_to_add:
        added = _run_git('add', pattern)
        # e.g. "pathspec did not match any files". Report it instead of
        # letting the failure pass silently.
        if added.returncode != 0:
            print(f"Warning: Could not add {pattern}: {added.stderr.strip()}")

    # Only staged changes matter here. `git diff --cached --quiet` exits 0 when
    # the index matches HEAD. `git status --porcelain` would also count
    # unrelated untracked or unstaged files and trigger an empty commit attempt.
    if _run_git('diff', '--cached', '--quiet').returncode == 0:
        print("\nNo changes to commit.")
        return False

    # Create commit message
    if not commit_message:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        trained_models = ', '.join([m.upper() for m in ALL_MODELS]) if 'ALL_MODELS' in globals() else 'models'
        commit_message = f"Training checkpoint: {trained_models} - {timestamp}"

    print(f"\nCommit message: {commit_message}")

    # Commit. git prints some failures (e.g. "nothing to commit") on stdout,
    # so fall back to stdout when stderr is empty.
    committed = _run_git('commit', '-m', commit_message)
    if committed.returncode != 0:
        print(f"\nCommit failed: {(committed.stderr or committed.stdout).strip()}")
        return False
    print("\nCommit successful!")
    print(committed.stdout)

    # Push to remote
    print("\nPushing to remote repository...")
    pushed = _run_git('push')
    if pushed.returncode != 0:
        print(f"\nPush failed: {(pushed.stderr or pushed.stdout).strip()}")
        print("You may need to pull changes first or check your remote configuration.")
        return False
    print("Push successful!")
    print(pushed.stdout)
    return True

# Run the commit and push
print("="*80)
print("COMMITTING TRAINING CHECKPOINTS TO GIT")
print("="*80)
print("\nThis will commit:")
print("  - Training history JSON files")
print("  - Model checkpoints (.pth files)")
print("  - Comparison plots")
print("  - TensorBoard logs")
print("\nNote: Large checkpoint files may take time to upload.\n")

# commit_and_push_checkpoints with default message. It returns True only when
# both the commit and the push succeeded, so report the real outcome instead
# of always claiming completion.
checkpoints_pushed = commit_and_push_checkpoints()

if checkpoints_pushed:
    print("COMMITTING TRAINING CHECKPOINTS TO GIT - COMPLETE")
else:
    print("COMMITTING TRAINING CHECKPOINTS TO GIT - NOT COMPLETED (see messages above)")