## Colab Setup

In [None]:
import sys

# Treat the kernel as Colab when the `google.colab` package is already present
# in the interpreter's module cache.
IS_COLAB = 'google.colab' in sys.modules.keys()
print("Running in Google Colab: {}".format(IS_COLAB))

In [None]:
import platform
import psutil
import subprocess
import os

# Print a short hardware/software summary when running on Colab.
if IS_COLAB:
    print("Google Colab Environment Specifications:")
    print("="*50)

    # Operating system and interpreter
    print(f"Operating System: {platform.system()} {platform.release()}")
    print(f"Architecture: {platform.machine()}")
    print(f"Python Version: {platform.python_version()}")

    # Memory info
    memory = psutil.virtual_memory()
    print(f"Total RAM: {memory.total / (1024**3):.1f} GB")
    print(f"Available RAM: {memory.available / (1024**3):.1f} GB")

    # CPU info
    print(f"CPU Cores: {psutil.cpu_count(logical=False)} physical, {psutil.cpu_count(logical=True)} logical")

    # GPU info: ask nvidia-smi for one "name, total-memory" CSV row per GPU.
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader,nounits'],
            capture_output=True, text=True)
    except (OSError, subprocess.SubprocessError):
        # nvidia-smi is not installed or could not be launched; only these
        # expected failures are handled so real bugs are not hidden.
        print("GPU: Not detected")
    else:
        gpu_rows = [row for row in result.stdout.strip().split('\n') if row.strip()]
        if result.returncode == 0 and gpu_rows:
            for i, gpu in enumerate(gpu_rows):
                # Split on the last comma so an unexpected row format cannot raise,
                # and use distinct names so `memory` (psutil result) is not clobbered.
                gpu_name, _, vram_mb = gpu.rpartition(',')
                print(f"GPU {i}: {gpu_name.strip()}, {vram_mb.strip()} MB VRAM")
        else:
            print("GPU: Not detected or nvidia-smi unavailable")

    # Disk space
    disk = psutil.disk_usage('/')
    print(f"Disk Space: {disk.free / (1024**3):.1f} GB free / {disk.total / (1024**3):.1f} GB total")

    print("="*50)
else:
    print("Not running in Google Colab environment")

In [None]:
import os
import sys

# Fetch the project sources, install its requirements and make `src` importable.
if IS_COLAB:
    print("Running in Google Colab environment.")
    # Reuse an existing checkout under /content when present; otherwise clone.
    if os.path.exists('/content/aai521_3proj'):
        print("Repository already exists. Pulling latest changes...")
        %cd /content/aai521_3proj
        !git pull
    else:
        print("Cloning repository...")
        # NOTE(review): the clone lands in the current working directory, while the
        # existence check above and the sys.path entry below assume /content —
        # confirm the kernel always starts there.
        !git clone https://github.com/swapnilprakashpatil/aai521_3proj.git
        %cd aai521_3proj    
    # Both branches have changed into the repository directory by this point,
    # so the relative requirements path is resolved from there.
    %pip install -r requirements.txt
    sys.path.append('/content/aai521_3proj/src')
    %ls
else:
    # Local runs: paths are relative to the current working directory, one level
    # below the directory holding requirements.txt and src.
    print("Running in local environment. Installing packages...")
    %pip install -r ../requirements.txt
    sys.path.append('../src')

## 1. Setup & Imports

In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Make the project's source directory importable
sys.path.append('../src')

# Refresh any project modules that are already imported so edits made since the
# last import take effect (same order as before: dataset, models, config).
import importlib
for _module_name in ('dataset', 'models', 'config'):
    if _module_name in sys.modules:
        importlib.reload(sys.modules[_module_name])

# Import custom modules
import config
from dataset import create_dataloaders, FloodDataset
from models import create_model, UNetPlusPlus, DeepLabV3Plus, SegFormer
from losses import create_loss_function
from metrics import MetricsTracker, SegmentationMetrics
from trainer import Trainer
from experiment_tracking import ExperimentLogger, ExperimentComparator

# Plot styling used by the figures in this notebook
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Report the PyTorch build and, when CUDA is usable, details of device 0
_cuda_ready = torch.cuda.is_available()
print("PyTorch version: {}".format(torch.__version__))
print("CUDA available: {}".format(_cuda_ready))
if _cuda_ready:
    print("CUDA device: {}".format(torch.cuda.get_device_name(0)))
    _device_bytes = torch.cuda.get_device_properties(0).total_memory
    print("CUDA memory: {:.2f} GB".format(_device_bytes / 1e9))

## 2. Data Loading & Exploration

In [None]:
# Build the train/val/test loaders from the preprocessed directories in config
print("Creating dataloaders...")
loader_options = dict(
    batch_size=64,    # A100 can handle 64 easily with 40GB VRAM
    num_workers=4,    # Linux supports more workers without issues
    pin_memory=True,  # Faster GPU transfer
)
train_loader, val_loader, test_loader = create_dataloaders(
    train_dir=config.PROCESSED_TRAIN_DIR,
    val_dir=config.PROCESSED_VAL_DIR,
    test_dir=config.PROCESSED_TEST_DIR,
    **loader_options,
)

# Report how many patches and batches each split contains
print("\nDataset sizes:")
for split_label, loader in (
    ('Training', train_loader),
    ('Validation', val_loader),
    ('Test', test_loader),
):
    print(f"  {split_label}: {len(loader.dataset)} patches ({len(loader)} batches)")

# Class weights as provided by the training dataset
class_weights = train_loader.dataset.get_class_weights()
print(f"\nClass weights: {class_weights}")

# Per-class counts and their share of the total
class_dist = train_loader.dataset.get_class_distribution()
print("\nClass distribution:")
class_total = class_dist.sum()
for class_name, count in zip(config.CLASS_NAMES, class_dist):
    percentage = (count / class_total) * 100
    print(f"  {class_name}: {count:,} ({percentage:.2f}%)")

### Visualize Class Distribution

In [None]:
# Two-panel view of the training-set class distribution: bars (left), pie (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))

num_classes_shown = len(config.CLASS_NAMES)
class_positions = range(num_classes_shown)
colors = plt.cm.tab10(np.linspace(0, 1, num_classes_shown))

# Left panel: per-class counts
bars = ax1.bar(class_positions, class_dist, color=colors, alpha=0.7)
ax1.set_xlabel('Class', fontsize=12)
ax1.set_ylabel('Pixel Count', fontsize=12)
ax1.set_title('Class Distribution in Training Set', fontsize=14, fontweight='bold')
ax1.set_xticks(class_positions)
ax1.set_xticklabels(config.CLASS_NAMES, rotation=45, ha='right')
ax1.grid(axis='y', alpha=0.3)

# Annotate every bar with its share of the total
distribution_total = class_dist.sum()
for bar, count in zip(bars, class_dist):
    share = (count / distribution_total) * 100
    label_x = bar.get_x() + bar.get_width()/2.
    ax1.text(label_x, bar.get_height(), f'{share:.1f}%',
             ha='center', va='bottom', fontsize=9)

# Right panel: the same distribution as percentages
ax2.pie(class_dist, labels=config.CLASS_NAMES, autopct='%1.1f%%',
        colors=colors, startangle=90)
ax2.set_title('Class Distribution (Percentage)', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

### Visualize Sample Data

In [None]:
# Pull the first batch from the training loader and unpack its tensors
train_iter = iter(train_loader)
batch = next(train_iter)
images, masks = batch['image'], batch['mask']

# Quick sanity summary of the batch contents
image_low, image_high = images.min(), images.max()
mask_values = masks.unique().tolist()
print(f"Batch shape: {images.shape}")
print(f"Mask shape: {masks.shape}")
print(f"Image range: [{image_low:.3f}, {image_high:.3f}]")
print(f"Mask classes: {mask_values}")

# Visualize samples
def visualize_samples(images, masks, num_samples=3):
    """Show pre-event image, post-event image and mask for the first samples of a batch.

    Args:
        images: Batch tensor indexed as ``images[i, channel]``; channels ``:3`` are
            drawn as the pre-event image and ``3:`` as the post-event image.
            (assumes 6 channels in total — TODO confirm against the dataset)
        masks: Batch of class-index masks indexed as ``masks[i]``.
        num_samples: Number of rows to draw; clamped to the batch size.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Never ask for more rows than the batch can provide.
    num_samples = min(num_samples, len(images))
    if num_samples <= 0:
        return

    n_classes = len(config.CLASS_NAMES)

    # squeeze=False keeps `axes` two-dimensional even when only one row is drawn.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples), squeeze=False)

    # Color map for masks. `plt.cm.get_cmap` has been removed from recent
    # matplotlib releases; `plt.get_cmap` accepts the same (name, lut) arguments.
    cmap = plt.get_cmap('tab10', n_classes)

    def _as_displayable(chw):
        """Convert a (C, H, W) tensor into an (H, W, C) array scaled to [0, 1]."""
        # .detach().cpu() makes this work for tensors on the GPU or with gradients.
        img = chw.detach().cpu().permute(1, 2, 0).numpy()
        return (img - img.min()) / (img.max() - img.min() + 1e-8)

    for i in range(num_samples):
        pre_img = _as_displayable(images[i, :3])    # Pre-event image (first 3 channels)
        post_img = _as_displayable(images[i, 3:])   # Post-event image (remaining channels)
        mask = masks[i].detach().cpu().numpy()

        # Plot pre-event
        axes[i, 0].imshow(pre_img)
        axes[i, 0].set_title('Pre-Event Image', fontsize=12, fontweight='bold')
        axes[i, 0].axis('off')

        # Plot post-event
        axes[i, 1].imshow(post_img)
        axes[i, 1].set_title('Post-Event Image', fontsize=12, fontweight='bold')
        axes[i, 1].axis('off')

        # Plot mask with a fixed value range so colors mean the same class in every row
        mask_plot = axes[i, 2].imshow(mask, cmap=cmap, vmin=0, vmax=n_classes-1)
        axes[i, 2].set_title('Ground Truth Mask', fontsize=12, fontweight='bold')
        axes[i, 2].axis('off')

        # Add colorbar to last mask
        if i == num_samples - 1:
            cbar = plt.colorbar(mask_plot, ax=axes[i, 2], orientation='horizontal',
                                pad=0.05, fraction=0.046)
            cbar.set_ticks(range(n_classes))
            cbar.set_ticklabels(config.CLASS_NAMES, rotation=45, ha='right', fontsize=8)

    plt.tight_layout()
    plt.show()

visualize_samples(images, masks, num_samples=3)

## 3. Model Architecture Overview

In [None]:
# Instantiate each architecture once to report its parameter counts
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# List of all models to train
ALL_MODELS = ['unet++', 'deeplabv3+', 'segformer', 'fc_siam_diff', 'siamese_unet++', 'stanet']

models_info = []

for model_name in ALL_MODELS:
    # Models whose name contains "siamese" are built with 3 input channels, all others with 6
    input_channels = 3 if 'siamese' in model_name.lower() else 6
    model = create_model(
        model_name=model_name,
        in_channels=input_channels,
        num_classes=config.NUM_CLASSES,
        **config.MODEL_CONFIGS.get(model_name, {})
    )

    # Collect (size, trainable?) for every parameter tensor in one pass
    param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(size for size, _ in param_stats)
    trainable_params = sum(size for size, is_trainable in param_stats if is_trainable)

    models_info.append({
        'Model': model_name.upper(),
        'Total Parameters': f"{total_params:,}",
        'Trainable Parameters': f"{trainable_params:,}",
        # 4 bytes per parameter
        'Size (MB)': f"{total_params * 4 / 1e6:.2f}"
    })

    # Drop the reference before building the next model
    del model

# Display as table
models_df = pd.DataFrame(models_info)
table_rule = "="*80
print("\n" + table_rule)
print("MODEL ARCHITECTURE COMPARISON")
print(table_rule)
print(models_df.to_string(index=False))
print(table_rule)

## 4. Training Configuration

In [None]:
# Training configuration - OPTIMIZED FOR A100 GPU (40GB VRAM)
# These values are read by train_model(); entries such as 'batch_size' are
# reported and logged, while the DataLoaders themselves are built earlier.
TRAINING_CONFIG = {
    'batch_size': 64,  # A100 can handle large batches with 40GB VRAM
    'num_epochs': 30,  # Balanced for quality and speed
    'learning_rate': 1e-4,  # Slightly higher LR for larger batch
    'weight_decay': 1e-4,
    'use_amp': True,  # Critical for A100 - uses Tensor Cores
    'gradient_clip': 1.0,
    'early_stopping_patience': 10,
    'loss_type': 'combined',
    'scheduler_type': 'plateau',
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'gradient_accumulation_steps': 1,  # No accumulation needed with batch=64
    'print_every_n_epochs': 5,
    # Loss weights for class imbalance
    'ce_weight': 0.1,
    'dice_weight': 2.0,
    'focal_weight': 3.0,
    'focal_gamma': 3.0
}

# The banner strings below previously contained mis-decoded UTF-8 sequences;
# they are restored to the intended symbols.
print("\n" + "="*80)
print("TRAINING CONFIGURATION - A100 GPU OPTIMIZATION")
print("="*80)
for key, value in TRAINING_CONFIG.items():
    print(f"  {key}: {value}")
print("\n  🚀 A100 GPU OPTIMIZATIONS:")
print("  ✓ Batch size: 64 (8x original - maximum GPU utilization)")
print("  ✓ Data workers: 4 (Linux + persistent_workers)")
print("  ✓ Pin memory: enabled")
print("  ✓ Prefetch factor: 2")
print("  ✓ Mixed precision (AMP): Tensor Cores enabled")
print("  ✓ TF32: ON for faster matrix operations")
print("\n  ⚡ EXPECTED PERFORMANCE ON A100:")
print("  • Speed: ~30-40 sec/epoch (10x faster than Windows!)")
print("  • Time per model: ~20-25 minutes")
print("  • Total for 6 models: ~2-2.5 hours")
print("  • GPU utilization: 30-40% (optimal for this model size)")
print("="*80)

In [None]:
# Apply A100 GPU-specific optimizations (skipped entirely when CUDA is unavailable)
if torch.cuda.is_available():
    # Enable cuDNN auto-tuner to find best algorithms
    torch.backends.cudnn.benchmark = True

    # Enable TF32 for A100 (faster computation on Ampere GPUs)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Set optimal number of threads for Colab
    torch.set_num_threads(8)

    # Check if A100 is detected
    gpu_name = torch.cuda.get_device_name(0)

    print("\n" + "="*80)
    print("🚀 A100 GPU OPTIMIZATIONS ENABLED")
    print("="*80)
    print(f"  GPU: {gpu_name}")
    if 'A100' not in gpu_name:
        # The flags above are still set, but the A100-specific figures printed
        # below were written for a 40 GB A100 and may not hold on this device.
        print("  NOTE: detected GPU is not an A100 - the A100-specific figures below may not apply")
    print("  ✓ cuDNN benchmark mode: ON (auto-tune algorithms)")
    print("  ✓ TF32 tensor cores: ON (A100 specific - 8x faster!)")
    print("  ✓ Mixed Precision (AMP): Enabled in config")
    print("  ✓ CPU threads: 8 (Colab optimized)")
    print("  ✓ Persistent workers: ON (prefetching enabled)")

    # Display memory info
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"\n  💾 GPU Memory: {total_memory:.1f} GB")
    print("  📊 Expected usage with batch=64: ~15-20 GB (50% utilization)")
    print("="*80 + "\n")

### 🚀 A100 GPU Performance Optimizations

**Optimized for Google Colab A100 (40GB VRAM) - 10x faster than local!**

#### 1. **Hardware Advantages**:
   - **A100 GPU**: 40GB VRAM (vs 4-8GB consumer GPUs)
   - **Tensor Cores**: Specialized hardware for mixed precision
   - **High bandwidth**: 1.5TB/s memory bandwidth
   - **Linux environment**: Better multiprocessing support

#### 2. **Configuration Optimizations**:
   - **Batch size: 64** (8x larger than Windows - 50% GPU utilization)
   - **Workers: 4** with persistent_workers (no Windows limitations)
   - **Prefetch factor: 2** (continuous data streaming)
   - **Mixed Precision (AMP)**: Leverages A100 Tensor Cores for 2-3x speedup

#### 3. **PyTorch Optimizations**:
   - **TF32 mode**: A100-specific acceleration (8x faster matmul)
   - **cuDNN benchmark**: Auto-tunes convolution algorithms
   - **No gradient accumulation**: Direct large batches

#### 4. **Expected Performance** ⚡:
   - **Speed**: ~30-40 sec/epoch (was ~377 sec on Windows)
   - **Per model**: ~20-25 minutes (was ~5 hours)
   - **Total (6 models)**: ~2-2.5 hours (was ~30 hours)
   - **Overall speedup**: **10-12x faster than original setup!**

#### 5. **Memory Usage**:
   - Batch=64: ~15-20 GB / 40 GB (~50% utilization)
   - Leaves plenty of headroom for model optimization
   - Can potentially increase to batch=96 if needed

In [None]:
# Check GPU memory and recommend optimal batch size
if torch.cuda.is_available():
    # Release cached blocks first so the usage figures below are not inflated
    torch.cuda.empty_cache()
    gpu_props = torch.cuda.get_device_properties(0)
    total_memory_gb = gpu_props.total_memory / 1e9

    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Total Memory: {total_memory_gb:.2f} GB")

    # Rough estimation for batch size recommendation. The tiers no longer quote
    # a stale "current: 16" value (the configured size is printed separately
    # below), and a 40 GB tier matches the batch size this notebook targets.
    if total_memory_gb >= 40:
        recommended_batch = "64 (A100-class VRAM, matches this notebook's configuration)"
    elif total_memory_gb >= 24:
        recommended_batch = "24-32 (you have plenty of VRAM)"
    elif total_memory_gb >= 16:
        recommended_batch = "16-24"
    elif total_memory_gb >= 12:
        recommended_batch = "12-16 (16 might be tight)"
    elif total_memory_gb >= 8:
        recommended_batch = "8-12 (reduce to 12 if OOM)"
    else:
        recommended_batch = "4-8 (reduce to 8 and increase grad accumulation)"

    print(f"Recommended batch size: {recommended_batch}")
    print(f"Current batch size: {TRAINING_CONFIG['batch_size']}")
    print(f"Effective batch size: {TRAINING_CONFIG['batch_size'] * TRAINING_CONFIG['gradient_accumulation_steps']}")

    # Monitor current GPU state (memory held by this process's PyTorch allocator)
    allocated = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    print(f"\nCurrent GPU Usage:")
    print(f"  Allocated: {allocated:.2f} GB")
    print(f"  Reserved: {reserved:.2f} GB")
    print(f"  Free: {total_memory_gb - reserved:.2f} GB")
else:
    print("No GPU available - using CPU")

## 5. Training Function

In [None]:
def _history_to_jsonable(value):
    """Recursively convert a history value into JSON-serialisable built-ins."""
    if isinstance(value, dict):
        return {key: _history_to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_history_to_jsonable(item) for item in value]
    if hasattr(value, 'item'):
        try:
            # Scalar tensor / numpy scalar: same float() coercion as before.
            return float(value)
        except (TypeError, ValueError, RuntimeError):
            # Multi-element array or tensor (e.g. per-class metrics).
            if hasattr(value, 'tolist'):
                return _history_to_jsonable(value.tolist())
            raise
    return value


def train_model(model_name, config_dict, train_loader, val_loader, class_weights):
    """Train a single model and return its training history and output directory.

    Args:
        model_name: Name passed to ``create_model`` and used to look up
            ``config.MODEL_CONFIGS``.
        config_dict: Training settings. Reads 'device', 'loss_type',
            'learning_rate', 'weight_decay', 'use_amp', 'gradient_clip',
            'early_stopping_patience' and 'num_epochs', plus the optional
            'ce_weight', 'dice_weight', 'focal_weight', 'focal_gamma' and
            'gradient_accumulation_steps'.
        train_loader: Training loader handed to ``Trainer``.
        val_loader: Validation loader handed to ``Trainer``.
        class_weights: Tensor of class weights; moved to the configured device
            and passed to the loss function.

    Returns:
        Tuple ``(history, output_dir)``: the dict returned by ``Trainer.train``
        and the ``Path`` under ``../outputs/training`` holding the checkpoints
        and ``training_history.json``.
    """
    print(f"\n{'='*80}")
    print(f"Training {model_name.upper()}")
    print(f"{'='*80}\n")

    # CUDA optimizations
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Create output directory (timestamped so repeated runs do not collide)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir = Path('../outputs/training') / f'{model_name}_{timestamp}'
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_dir = output_dir / 'checkpoints'
    checkpoint_dir.mkdir(exist_ok=True)

    # Create model
    # NOTE(review): only names containing "siamese" get 3 input channels, so
    # 'fc_siam_diff' is built with 6 — confirm that is intended.
    model = create_model(
        model_name=model_name,
        in_channels=6 if 'siamese' not in model_name.lower() else 3,
        num_classes=config.NUM_CLASSES,
        **config.MODEL_CONFIGS.get(model_name, {})
    )
    model = model.to(config_dict['device'])

    # Use torch.compile if available; fall back to the eager model on failure,
    # but say so instead of hiding the reason.
    if hasattr(torch, 'compile') and config_dict['device'] == 'cuda':
        try:
            model = torch.compile(model, mode='default')
        except Exception as e:
            print(f"[WARN] torch.compile failed for {model_name} ({type(e).__name__}: {e}); using eager mode")

    # Create loss function
    loss_fn = create_loss_function(
        loss_type=config_dict['loss_type'],
        num_classes=config.NUM_CLASSES,
        class_weights=class_weights.to(config_dict['device']),
        device=config_dict['device'],
        ce_weight=config_dict.get('ce_weight', 0.1),
        dice_weight=config_dict.get('dice_weight', 2.0),
        focal_weight=config_dict.get('focal_weight', 3.0),
        focal_gamma=config_dict.get('focal_gamma', 3.0)
    )

    # Create optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config_dict['learning_rate'],
        weight_decay=config_dict['weight_decay'],
        betas=(0.9, 0.999),
        eps=1e-8
    )

    # Create scheduler
    # NOTE(review): config_dict['scheduler_type'] is not consulted here; a
    # plateau scheduler is always used. mode='max' presumes the trainer steps it
    # with a metric to maximise — verify against Trainer.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=5,
        min_lr=1e-6
    )

    # Create experiment logger
    logger = ExperimentLogger(
        log_dir=Path('../outputs/tensorboard'),
        experiment_name=f'{model_name}_{timestamp}'
    )

    # Everything that uses the logger runs inside try/finally so it is closed
    # even when training raises.
    try:
        logger.log_hyperparameters(config_dict)

        # Create trainer
        trainer = Trainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            loss_fn=loss_fn,
            num_classes=config.NUM_CLASSES,
            device=config_dict['device'],
            checkpoint_dir=checkpoint_dir,
            experiment_name=f'{model_name}_{timestamp}',
            use_amp=config_dict['use_amp'],
            gradient_clip_val=config_dict['gradient_clip'],
            early_stopping_patience=config_dict['early_stopping_patience'],
            gradient_accumulation_steps=config_dict.get('gradient_accumulation_steps', 1),
            class_names=config.CLASS_NAMES
        )

        # Train
        history = trainer.train(num_epochs=config_dict['num_epochs'])

        # Print final summary (skipped when no epoch was recorded, where max()
        # over an empty range would raise)
        num_recorded = len(history['val_iou'])
        print(f"\n{'='*80}")
        print(f"FINAL RESULTS - {model_name.upper()}")
        print(f"{'='*80}")
        if num_recorded > 0:
            best_epoch = max(range(num_recorded), key=lambda i: history['val_iou'][i])
            print(f"Best epoch: {best_epoch + 1}/{num_recorded}")
            print(f"\nBest validation metrics:")
            print(f"  IoU:  {history['val_iou'][best_epoch]:.4f}")
            print(f"  Dice: {history['val_dice'][best_epoch]:.4f}")
            print(f"  F1:   {history['val_f1'][best_epoch]:.4f}")

            # Per-class metrics if available
            if 'val_iou_per_class' in history:
                print(f"\nPer-class IoU (Best Epoch):")
                for class_name, iou in zip(config.CLASS_NAMES, history['val_iou_per_class'][best_epoch]):
                    print(f"  {class_name}: {iou:.4f}")
        else:
            print("No epochs were recorded.")

        print(f"{'='*80}\n")

        # Log metrics to TensorBoard
        for epoch in range(len(history['train_loss'])):
            logger.log_scalar('Loss/train', history['train_loss'][epoch], epoch)
            logger.log_scalar('Loss/val', history['val_loss'][epoch], epoch)
            logger.log_scalar('IoU/train', history['train_iou'][epoch], epoch)
            logger.log_scalar('IoU/val', history['val_iou'][epoch], epoch)
    finally:
        logger.close()

    # Save history. The recursive conversion also handles nested per-class
    # arrays, on which a plain float() would raise.
    history_json = {key: _history_to_jsonable(values) for key, values in history.items()}

    with open(output_dir / 'training_history.json', 'w') as f:
        json.dump(history_json, f, indent=2)

    print(f"[SAVED] Checkpoints: {checkpoint_dir}")
    print(f"[SAVED] Training history: {output_dir / 'training_history.json'}\n")

    return history, output_dir

---

## 6. Train All Models

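**The training mode (parallel or sequential) is selected with the `USE_PARALLEL_TRAINING` flag in the Training Mode Configuration cell further below. The cells in between prepare the GPU and define the helper functions.**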

In [None]:
# Clear GPU cache before training
if torch.cuda.is_available():
    import gc
    # Collect garbage first so tensors that are no longer referenced are freed,
    # then hand the cached blocks back to the driver.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    # Check available memory. mem_get_info() reports the device's actual free
    # memory, whereas total minus this process's allocations ignores memory held
    # by the caching allocator or by other processes.
    free_memory, _device_total = torch.cuda.mem_get_info()
    free_gb = free_memory / 1e9
    print(f"GPU Memory Available: {free_gb:.2f} GB")

    if free_gb < 2.0:
        print("WARNING: Less than 2GB free GPU memory!")
        print("Please restart kernel and re-run from the beginning.")
    else:
        print("GPU memory check passed. Starting training...")
else:
    print("No GPU available - training will be slow on CPU")

In [None]:
# Reload trainer module
import importlib
_trainer_already_loaded = 'trainer' in sys.modules
if _trainer_already_loaded:
    importlib.reload(sys.modules['trainer'])

# importlib.reload() updates the module object only; a name bound earlier with
# `from trainer import Trainer` keeps pointing at the old class. Re-import so the
# notebook's `Trainer` refers to the freshly loaded definition.
from trainer import Trainer
print("Trainer module reloaded" if _trainer_already_loaded else "Trainer module loaded")

In [None]:
import threading
from queue import Queue
import copy

# Shared state for parallel training: `training_lock` guards every access to
# `current_training_models`, the names of the models being trained right now.
training_lock = threading.Lock()
current_training_models = list()

def train_model_parallel(model_name, config_dict, train_loader, val_loader, class_weights, results_queue):
    """Train a single model in parallel mode and report the outcome on a queue.

    Intended as a thread target: it never raises. Instead it puts one dict on
    ``results_queue`` with the keys 'model_name', 'history', 'output_dir',
    'success' and 'error'.

    Args:
        model_name: Model to train; also used to register it in
            ``current_training_models`` while it runs.
        config_dict: Training configuration; deep-copied so this worker never
            shares a mutable config with another worker.
        train_loader: Forwarded unchanged to ``train_model``.
        val_loader: Forwarded unchanged to ``train_model``.
        class_weights: Forwarded unchanged to ``train_model``.
        results_queue: Queue receiving the result dict.
    """
    try:
        # Register this model as currently training
        with training_lock:
            current_training_models.append(model_name)

        # Create separate config for this model.
        # The previous 'batch_size' = 32 override is dropped: the DataLoaders are
        # built by the caller and shared, so the override never changed the real
        # batch size and only made the logged hyperparameters wrong.
        model_config = copy.deepcopy(config_dict)

        # Train model
        history, output_dir = train_model(model_name, model_config, train_loader, val_loader, class_weights)

        # Store results
        results_queue.put({
            'model_name': model_name,
            'history': history,
            'output_dir': output_dir,
            'success': True,
            'error': None
        })
    except Exception as e:
        results_queue.put({
            'model_name': model_name,
            'history': None,
            'output_dir': None,
            'success': False,
            # Include the exception type: str(e) alone can be cryptic
            # (a KeyError, for example, prints only the missing key).
            'error': f"{type(e).__name__}: {e}"
        })
    finally:
        # Unregister this model
        with training_lock:
            if model_name in current_training_models:
                current_training_models.remove(model_name)

def train_models_parallel_pairs(model_pairs, config_dict, train_loader, val_loader, class_weights):
    """Train each group of models concurrently, one group after another.

    Every model in a group runs ``train_model_parallel`` in its own thread; the
    next group starts only after all threads of the current one have finished.

    Returns:
        Dict mapping model name to ``{'history', 'output_dir'}`` for the models
        that succeeded. Failed models are reported on stdout and left out.
    """
    banner = '=' * 80
    all_results = {}

    for pair_number, pair in enumerate(model_pairs, start=1):
        pair_label = ' + '.join(name.upper() for name in pair)
        print(f"\n{banner}")
        print(f"PARALLEL TRAINING - PAIR {pair_number}: {pair_label}")
        print(f"{banner}\n")

        # One queue per group collects the workers' result dicts
        results_queue = Queue()

        # Build one worker thread per model, then launch them in order
        threads = [
            threading.Thread(
                target=train_model_parallel,
                args=(name, config_dict, train_loader, val_loader, class_weights, results_queue)
            )
            for name in pair
        ]
        for worker in threads:
            worker.start()

        # Block until every worker of this group is done
        for worker in threads:
            worker.join()

        # All producers have finished, so the queue size is final
        for _ in range(results_queue.qsize()):
            result = results_queue.get()
            model_name = result['model_name']

            if not result['success']:
                print(f"\n{model_name.upper()} failed: {result['error']}")
                continue

            all_results[model_name] = {
                'history': result['history'],
                'output_dir': result['output_dir']
            }
            print(f"\n{model_name.upper()} completed successfully!")

        # Release cached GPU memory before the next group starts
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        print(f"\n{banner}")
        print(f"PAIR {pair_number} COMPLETED")
        print(f"{banner}\n")

    return all_results

print("Parallel training functions loaded!")


In [None]:
# Monitor GPU usage during parallel training
def monitor_gpu_usage(interval=10, duration=None):
    """Monitor GPU memory usage in real-time with model names.

    Clears the cell output and prints the allocated/reserved/free GPU memory of
    this process's PyTorch allocator together with the models listed in
    ``current_training_models``.

    Args:
        interval: Seconds to sleep between refreshes.
        duration: Stop after this many seconds; ``None`` or ``0`` means run until
            interrupted.

    Note:
        ``KeyboardInterrupt`` is normally delivered to the main thread only, so
        when this function runs in a background thread rely on ``duration`` to
        end it.
    """
    import time
    from IPython.display import clear_output

    start_time = time.time()
    rule = '=' * 70

    try:
        while True:
            if duration and (time.time() - start_time) > duration:
                break

            clear_output(wait=True)

            if torch.cuda.is_available():
                allocated = torch.cuda.memory_allocated() / 1e9
                reserved = torch.cuda.memory_reserved() / 1e9
                total = torch.cuda.get_device_properties(0).total_memory / 1e9

                # Snapshot the shared list under the lock so it cannot change
                # while it is being formatted
                with training_lock:
                    training_models = current_training_models.copy()

                print(f"{rule}")
                print(f"GPU Memory Monitor (refreshes every {interval}s)")
                print(f"{rule}")

                # Show currently training models
                if training_models:
                    models_str = ' + '.join([m.upper() for m in training_models])
                    print(f"  Currently Training: {models_str}")
                else:
                    print(f"  Currently Training: (idle)")
                print(f"{'-'*70}")

                print(f"  Allocated: {allocated:.2f} GB / {total:.1f} GB ({allocated/total*100:.1f}%)")
                print(f"  Reserved:  {reserved:.2f} GB / {total:.1f} GB ({reserved/total*100:.1f}%)")
                print(f"  Free:      {total - reserved:.2f} GB ({(total-reserved)/total*100:.1f}%)")
                # The stop symbol was previously printed as mis-decoded UTF-8
                print(f"\n  Press Interrupt (■) to stop monitoring")
                print(f"{rule}")
            else:
                print("No GPU available")
                break

            time.sleep(interval)

    except KeyboardInterrupt:
        print("\nMonitoring stopped.")

print("GPU monitoring function ready.")

# Launch the GPU monitor as a daemon thread (refresh every 10 s, stop after
# 3600 s) so it runs alongside training and never keeps the kernel alive.
import threading
monitor_thread = threading.Thread(
    target=monitor_gpu_usage,
    args=(10, 3600),
    daemon=True,
)
monitor_thread.start()


---

### 6.1 Training Mode Configuration

**Set the training mode below:**

In [None]:
# ============================================================================
# TRAINING MODE CONFIGURATION
# ============================================================================
# Set USE_PARALLEL_TRAINING to True for parallel training (faster on A100)
# Set to False for sequential training (more stable, easier to debug)

USE_PARALLEL_TRAINING = True  # Change to False for sequential training

# Define all models to train
ALL_MODELS = ['unet++', 'deeplabv3+', 'segformer', 'fc_siam_diff', 'siamese_unet++', 'stanet']

# Define model pairs for parallel training (only used if USE_PARALLEL_TRAINING=True)
MODEL_PAIRS = [
    ['unet++', 'deeplabv3+'],           # Pair 1: ~25 minutes
    ['segformer', 'fc_siam_diff'],      # Pair 2: ~25 minutes  
    ['siamese_unet++', 'stanet']        # Pair 3: ~25 minutes
]

# ============================================================================
# TRAINING EXECUTION
# ============================================================================

if USE_PARALLEL_TRAINING:
    print("\n" + "="*80)
    print("PARALLEL TRAINING MODE - Training 6 models in 3 pairs")
    print("="*80)
    print("="*80 + "\n")

    # Train all pairs
    results = train_models_parallel_pairs(
        MODEL_PAIRS,
        TRAINING_CONFIG,
        train_loader,
        val_loader,
        class_weights
    )

else:
    print("\n" + "="*80)
    print("SEQUENTIAL TRAINING MODE - Training 6 models one by one")
    print("="*80)
    print("="*80 + "\n")

    # Train models sequentially
    results = {}
    for model_name in ALL_MODELS:
        history, output_dir = train_model(
            model_name,
            TRAINING_CONFIG,
            train_loader,
            val_loader,
            class_weights
        )
        results[model_name] = {
            'history': history,
            'output_dir': output_dir
        }

        # Clear GPU cache between models
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

# ============================================================================
# EXTRACT RESULTS FOR ALL MODELS
# ============================================================================

# In parallel mode a failed model is absent from `results`. Look entries up
# defensively so one failure does not abort the cell with a KeyError; the
# variables of a failed model are set to None.
failed_models = [name for name in ALL_MODELS if name not in results]
_no_result = {'history': None, 'output_dir': None}

# Extract individual results (works for both parallel and sequential)
unet_history = results.get('unet++', _no_result)['history']
unet_output_dir = results.get('unet++', _no_result)['output_dir']

deeplab_history = results.get('deeplabv3+', _no_result)['history']
deeplab_output_dir = results.get('deeplabv3+', _no_result)['output_dir']

segformer_history = results.get('segformer', _no_result)['history']
segformer_output_dir = results.get('segformer', _no_result)['output_dir']

fcsiamdiff_history = results.get('fc_siam_diff', _no_result)['history']
fcsiamdiff_output_dir = results.get('fc_siam_diff', _no_result)['output_dir']

siamese_unet_history = results.get('siamese_unet++', _no_result)['history']
siamese_unet_output_dir = results.get('siamese_unet++', _no_result)['output_dir']

stanet_history = results.get('stanet', _no_result)['history']
stanet_output_dir = results.get('stanet', _no_result)['output_dir']

# ============================================================================
# TRAINING COMPLETE
# ============================================================================

print("\n" + "="*80)
if failed_models:
    # Report the real outcome instead of an unconditional success banner
    print(f"{len(ALL_MODELS) - len(failed_models)}/{len(ALL_MODELS)} MODELS TRAINED SUCCESSFULLY")
    print(f"WARNING: no results for: {', '.join(name.upper() for name in failed_models)}")
else:
    print(f"ALL {len(ALL_MODELS)} MODELS TRAINED SUCCESSFULLY!")
print("="*80)
print(f"Training mode: {'PARALLEL' if USE_PARALLEL_TRAINING else 'SEQUENTIAL'}")
print(f"\nResults stored:")
for model_name, result in results.items():
    print(f"  {model_name.upper()}: {result['output_dir']}")
print("="*80 + "\n")


---

## 7. Training Metrics Visualization

In [None]:
def plot_training_history(history, model_name, save_path=None):
    """Draw loss, IoU and Dice curves for one model and print its best epoch.

    Args:
        history: Dict with per-epoch sequences under 'train_loss', 'val_loss',
            'train_iou', 'val_iou', 'train_dice', 'val_dice' and 'val_f1'.
        model_name: Label used in the figure titles and the printed summary.
        save_path: When truthy, the figure is also written there at 300 dpi.
    """
    fig = plt.figure(figsize=(16, 10))
    gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

    epochs = range(1, len(history['train_loss']) + 1)

    # Top row (full width): training vs. validation loss
    loss_ax = fig.add_subplot(gs[0, :])
    loss_ax.plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    loss_ax.plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    loss_ax.set_xlabel('Epoch', fontsize=12)
    loss_ax.set_ylabel('Loss', fontsize=12)
    loss_ax.set_title(f'{model_name} - Training and Validation Loss', fontsize=14, fontweight='bold')
    loss_ax.legend(fontsize=11)
    loss_ax.grid(True, alpha=0.3)

    # Bottom row: IoU (left) and Dice (right) share the same layout
    metric_panels = (
        (gs[1, 0], 'iou', 'IoU', 'Mean IoU'),
        (gs[1, 1], 'dice', 'Dice', 'Mean Dice Coefficient'),
    )
    for grid_cell, metric_key, metric_label, panel_title in metric_panels:
        metric_ax = fig.add_subplot(grid_cell)
        metric_ax.plot(epochs, history[f'train_{metric_key}'], 'b-',
                       label=f'Train {metric_label}', linewidth=2)
        metric_ax.plot(epochs, history[f'val_{metric_key}'], 'r-',
                       label=f'Val {metric_label}', linewidth=2)
        metric_ax.set_xlabel('Epoch', fontsize=11)
        metric_ax.set_ylabel(metric_label, fontsize=11)
        metric_ax.set_title(panel_title, fontsize=12, fontweight='bold')
        metric_ax.legend(fontsize=10)
        metric_ax.grid(True, alpha=0.3)

    plt.suptitle(f'{model_name} Training Metrics', fontsize=16, fontweight='bold', y=0.995)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')

    plt.show()

    # Best epoch = first epoch reaching the highest validation IoU
    val_iou = history['val_iou']
    best_epoch = max(range(len(val_iou)), key=lambda idx: val_iou[idx])
    rule = '=' * 80
    print(f"\n{rule}")
    print(f"{model_name} - Best Validation Metrics (Epoch {best_epoch + 1})")
    print(f"{rule}")
    print(f"Mean IoU: {val_iou[best_epoch]:.4f}")
    print(f"Mean Dice: {history['val_dice'][best_epoch]:.4f}")
    print(f"Mean F1: {history['val_f1'][best_epoch]:.4f}")
    print(f"{rule}\n")

## 8. Model Comparison

In [None]:
# Compare all models
def compare_models(histories, model_names, save_path='../outputs/model_comparison.png'):
    """Compare the validation metrics of several models side by side.

    Draws a 2x2 figure: validation loss, validation IoU and validation Dice
    curves (one line per model), plus a grouped bar chart of each model's
    IoU / Dice / F1 at its best epoch. The best epoch of a model is the first
    epoch with the highest ``val_iou``. A text table with the same best-epoch
    numbers is printed afterwards.

    Args:
        histories: Sequence of per-model history mappings. Each one is read
            through the keys 'val_loss', 'val_iou', 'val_dice' and 'val_f1',
            which are indexed per epoch.
        model_names: Display names, in the same order as ``histories``.
        save_path: Where to write the figure. The parent directory is created
            when missing. Pass ``None`` (or an empty string) to skip saving.

    Raises:
        ValueError: If ``histories`` and ``model_names`` differ in length.
    """
    # Function-scope import so this cell does not depend on an earlier cell
    # having imported `os`.
    import os

    # Fail early with a clear message instead of a shape error inside `bar`.
    if len(histories) != len(model_names):
        raise ValueError(
            f"histories and model_names must have the same length "
            f"(got {len(histories)} and {len(model_names)})"
        )

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']

    # Plots 1-3: one validation curve per model for loss, IoU and Dice.
    curve_panels = [
        (axes[0, 0], 'val_loss', 'Loss', 'Validation Loss Comparison'),
        (axes[0, 1], 'val_iou', 'IoU', 'Validation Mean IoU Comparison'),
        (axes[1, 0], 'val_dice', 'Dice', 'Validation Mean Dice Comparison'),
    ]
    for ax, key, ylabel, title in curve_panels:
        for idx, (history, name) in enumerate(zip(histories, model_names)):
            epochs = range(1, len(history[key]) + 1)
            # Cycle the palette so models beyond the sixth are still drawn
            # (zipping against the colour list would silently drop them).
            color = colors[idx % len(colors)]
            ax.plot(epochs, history[key], label=name, linewidth=2, color=color)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

    # Best-epoch metrics, computed once and shared by the bar chart and table.
    best_rows = []
    for name, history in zip(model_names, histories):
        val_iou = history['val_iou']
        best_epoch = max(range(len(val_iou)), key=lambda i: val_iou[i])
        best_rows.append((
            name,
            best_epoch,
            val_iou[best_epoch],
            history['val_dice'][best_epoch],
            history['val_f1'][best_epoch],
        ))

    # Plot 4: Best Metrics Bar Chart
    x = np.arange(len(model_names))
    width = 0.25
    iou_values = [row[2] for row in best_rows]
    dice_values = [row[3] for row in best_rows]
    f1_values = [row[4] for row in best_rows]

    bar_ax = axes[1, 1]
    bar_ax.bar(x - width, iou_values, width, label='IoU', alpha=0.8)
    bar_ax.bar(x, dice_values, width, label='Dice', alpha=0.8)
    bar_ax.bar(x + width, f1_values, width, label='F1', alpha=0.8)

    bar_ax.set_xlabel('Model', fontsize=12)
    bar_ax.set_ylabel('Score', fontsize=12)
    bar_ax.set_title('Best Validation Metrics Comparison', fontsize=14, fontweight='bold')
    bar_ax.set_xticks(x)
    bar_ax.set_xticklabels(model_names, rotation=45, ha='right')
    bar_ax.legend(fontsize=11)
    bar_ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    if save_path:
        # savefig does not create missing directories.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    # Print comparison table
    print(f"\n{'='*100}")
    print("MODEL COMPARISON - BEST VALIDATION METRICS")
    print(f"{'='*100}")
    print(f"{'Model':<20} {'Best Epoch':<12} {'Mean IoU':<12} {'Mean Dice':<12} {'Mean F1':<12}")
    print(f"{'-'*100}")

    for name, best_epoch, best_iou, best_dice, best_f1 in best_rows:
        print(f"{name:<20} {best_epoch+1:<12} {best_iou:<12.4f} "
              f"{best_dice:<12.4f} {best_f1:<12.4f}")

    print(f"{'='*100}\n")

# Run the comparison across every trained model; names are listed in the same
# order as their histories.
compare_models(
    histories=[
        unet_history,
        deeplab_history,
        segformer_history,
        fcsiamdiff_history,
        siamese_unet_history,
        stanet_history,
    ],
    model_names=[
        'U-Net++',
        'DeepLabV3+',
        'SegFormer',
        'FC-Siam-Diff',
        'Siamese U-Net++',
        'STANet',
    ],
)

## 9. TensorBoard Visualization

In [None]:
# Print instructions for opening TensorBoard (this cell does not start it).
print(
    "To view TensorBoard:",
    "1. Run in terminal: tensorboard --logdir=../outputs/tensorboard --port=6006",
    "2. Open browser: http://localhost:6006",
    "\nTensorBoard shows:",
    "  - Training/validation loss curves",
    "  - IoU, Dice, F1 metrics over time",
    "  - Per-class performance",
    "  - Learning rate schedules",
    "  - Model graphs",
    sep="\n",
)