# Advanced Training Techniques - IRST Library
## Multi-GPU Training, Hyperparameter Optimization & Custom Losses

This notebook demonstrates advanced training techniques for infrared small target detection using the IRST Library.

### 🎯 **What you'll learn:**
- 🔥 **Multi-GPU Training** - Distributed training across multiple GPUs
- 🎛️ **Hyperparameter Optimization** - Automated hyperparameter tuning with Optuna
- 🧠 **Custom Loss Functions** - Design and implement domain-specific losses
- 📊 **Advanced Metrics** - Comprehensive evaluation with custom metrics
- 🚀 **Mixed Precision Training** - Accelerated training with reduced memory usage
- 🔄 **Learning Rate Schedules** - Advanced scheduling strategies
- 💾 **Model Checkpointing** - Advanced saving and loading strategies
- 📈 **Experiment Tracking** - MLflow integration for experiment management

### 🛠️ **Prerequisites:**
- Completed the basic tutorial
- PyTorch with CUDA support
- Multiple GPUs (recommended)
- Understanding of deep learning fundamentals

## 1. Setup and Imports

In [None]:
# Core libraries
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Deep learning libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler

# IRST Library
from irst_library import IRSTDetector
from irst_library.datasets import SIRSTDataset, IRSTD1kDataset
from irst_library.models import SERANKNet, ACMNet, MSHNet
from irst_library.training import IRSTTrainer, DistributedTrainer
from irst_library.evaluation import IRSTEvaluator
from irst_library.utils import visualize_detection, plot_metrics
from irst_library.losses import FocalLoss, DiceLoss, IoULoss, TverskyLoss

# Hyperparameter optimization
import optuna
from optuna.integration import PyTorchLightningPruningCallback

# Experiment tracking
import mlflow
import mlflow.pytorch

# Advanced utilities
from sklearn.model_selection import StratifiedKFold
from torch.optim.lr_scheduler import OneCycleLR, CosineAnnealingWarmRestarts
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Progress tracking
from tqdm.auto import tqdm
import logging
from datetime import datetime

# Set up environment.
# Default to GPUs 0 and 1, but respect a device selection the user already
# made. The variable is only honoured if it is set before CUDA is first
# initialised, so it must stay ahead of the torch.cuda queries below.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
world_size = torch.cuda.device_count()

print("🚀 Advanced Training Setup Complete!")
print(f"   💻 Device: {device}")
print(f"   🔥 Available GPUs: {world_size}")
print(f"   🧠 PyTorch Version: {torch.__version__}")
# torch.version.cuda is None on CPU-only PyTorch builds.
print(f"   📊 CUDA Version: {torch.version.cuda or 'N/A'}")

# Check distributed training capability
if world_size > 1:
    print(f"   ✅ Multi-GPU training available with {world_size} GPUs")
elif world_size == 1:
    print("   ⚠️ Single GPU mode - consider using multiple GPUs for faster training")
else:
    print("   ⚠️ No CUDA GPU detected - training will run on the CPU")

## 2. Multi-GPU Distributed Training

Multi-GPU training can significantly speed up training and allows larger effective batch sizes, because each GPU processes its own slice of every batch. Let's set up distributed training with PyTorch's `DistributedDataParallel` (DDP).

In [None]:
def setup_distributed(rank, world_size, backend=None):
    """Initialize the default process group for single-node distributed training.

    Args:
        rank: Index of this process within the group. When CUDA is available
            it is also used as the GPU index this process is pinned to.
        world_size: Total number of participating processes.
        backend: Process-group backend to use. When omitted, ``"nccl"`` is
            chosen if CUDA and NCCL are available (the backend PyTorch
            recommends for GPU collectives), otherwise ``"gloo"``.
    """
    # Keep any rendezvous address/port already provided by a launcher;
    # fall back to a local single-node configuration.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    use_cuda = torch.cuda.is_available()
    if backend is None:
        backend = "nccl" if use_cuda and dist.is_nccl_available() else "gloo"

    if use_cuda:
        # Pin this process to its own GPU so its CUDA context and collectives
        # do not default to device 0.
        torch.cuda.set_device(rank)

    # Initialize the process group
    dist.init_process_group(backend, rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group if one is currently initialized.

    Calling this when no process group exists (for example on the
    single-GPU path, where ``setup_distributed`` is never run) is a no-op
    instead of an error.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

class DistributedIRSTTrainer:
    """Multi-GPU trainer that wraps a model in ``DistributedDataParallel``.

    One instance is created per process; the process ``rank`` doubles as the
    CUDA device index. DDP requires an initialized default process group
    (see ``setup_distributed``) before this class is constructed.

    Args:
        model: Module to train. It is moved to ``cuda:{rank}`` and wrapped in DDP.
        train_loader: Training loader; ``len(train_loader)`` is used as the
            OneCycle steps-per-epoch.
        val_loader: Validation loader (stored on the instance).
        config: Training configuration dict. Keys read by this class:
            ``learning_rate`` and ``weight_decay`` (required); ``epochs``
            (required unless ``scheduler == 'cosine_restart'``);
            ``scheduler`` (``'onecycle'``, ``'cosine_restart'``; any other
            value selects plain cosine annealing); ``mixed_precision``
            (default ``True``); ``loss_weights``; and, for
            ``'cosine_restart'`` only, ``restart_t0`` (default 10),
            ``restart_t_mult`` (default 2) and ``restart_eta_min``
            (default 1e-6).
        rank: Rank of this process.
        world_size: Total number of processes.
    """

    # Weights applied to each entry of ``self.losses`` when the config does
    # not provide ``loss_weights``. Copied on use, never mutated.
    _DEFAULT_LOSS_WEIGHTS = {'focal': 0.4, 'dice': 0.3, 'iou': 0.2, 'tversky': 0.1}

    def __init__(self, model, train_loader, val_loader, config, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.device = torch.device(f'cuda:{rank}')

        # DDP expects the module to already live on the device in device_ids.
        self.model = DDP(model.to(self.device), device_ids=[rank])

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        # Optimizer over the DDP-wrapped model's parameters.
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay'],
        )

        # Mixed precision scaler; None means AMP gradient scaling is disabled.
        self.scaler = GradScaler() if config.get('mixed_precision', True) else None

        # Built after the optimizer, which the scheduler wraps.
        self.scheduler = self.setup_scheduler()

        # Loss functions
        self.setup_losses()

        print(f"🚀 Distributed trainer initialized on rank {rank}/{world_size}")

    def setup_scheduler(self):
        """Build the learning-rate scheduler selected by ``config['scheduler']``.

        Returns:
            ``OneCycleLR`` for ``'onecycle'`` (its total length is
            ``epochs * len(train_loader)`` steps), ``CosineAnnealingWarmRestarts``
            for ``'cosine_restart'``, and ``CosineAnnealingLR`` with
            ``T_max=config['epochs']`` for anything else.
        """
        name = self.config.get('scheduler')
        if name == 'onecycle':
            return OneCycleLR(
                self.optimizer,
                max_lr=self.config['learning_rate'],
                epochs=self.config['epochs'],
                steps_per_epoch=len(self.train_loader),
            )
        if name == 'cosine_restart':
            return CosineAnnealingWarmRestarts(
                self.optimizer,
                T_0=self.config.get('restart_t0', 10),
                T_mult=self.config.get('restart_t_mult', 2),
                eta_min=self.config.get('restart_eta_min', 1e-6),
            )
        return optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=self.config['epochs'],
        )

    def setup_losses(self):
        """Instantiate the loss functions and their combination weights.

        ``loss_weights`` is taken from the config when present, otherwise
        from the class defaults. It is copied so later edits on the trainer
        do not mutate the caller's config dict.

        Raises:
            ValueError: If ``loss_weights`` names a loss that is not one of
                the keys of ``self.losses``.
        """
        self.losses = {
            'focal': FocalLoss(alpha=0.25, gamma=2.0),
            'dice': DiceLoss(smooth=1.0),
            'iou': IoULoss(),
            'tversky': TverskyLoss(alpha=0.7, beta=0.3)
        }
        self.loss_weights = dict(
            self.config.get('loss_weights', self._DEFAULT_LOSS_WEIGHTS)
        )
        # Fail at construction time rather than on the first training step.
        unknown = sorted(set(self.loss_weights) - set(self.losses))
        if unknown:
            raise ValueError(
                f"Unknown loss name(s) in loss_weights: {unknown}; "
                f"expected a subset of {sorted(self.losses)}"
            )

# Setup data for distributed training
def _build_sirst_transforms():
    """Return the ``(train_transform, val_transform)`` albumentations pipelines."""
    # Training transforms with advanced augmentation
    train_transform = A.Compose([
        A.OneOf([
            A.RandomRotate90(p=1.0),
            A.Flip(p=1.0),
            A.Transpose(p=1.0),
        ], p=0.7),

        A.ShiftScaleRotate(
            shift_limit=0.1,
            scale_limit=0.2,
            rotate_limit=30,
            p=0.6
        ),

        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
            A.GaussNoise(var_limit=(10.0, 50.0)),
            A.GaussianBlur(blur_limit=(3, 7)),
            A.MotionBlur(blur_limit=7),
        ], p=0.5),

        # Advanced infrared-specific augmentations
        A.RandomGamma(gamma_limit=(80, 120), p=0.3),
        A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.3),

        # Single-channel normalisation, applied after the pixel-level augmentations.
        A.Normalize(mean=[0.485], std=[0.229], max_pixel_value=255.0),
        ToTensorV2()
    ])
    # ``mask`` is a built-in albumentations target, so spatial transforms are
    # applied to it without any ``additional_targets`` entry.

    val_transform = A.Compose([
        A.Normalize(mean=[0.485], std=[0.229], max_pixel_value=255.0),
        ToTensorV2()
    ])
    return train_transform, val_transform


def prepare_distributed_data(root="./data/SIRST"):
    """Build the SIRST training and validation datasets.

    Despite the name, this returns datasets, not data loaders. The caller
    wraps them in ``DataLoader``s (with a ``DistributedSampler`` for
    multi-GPU runs).

    Args:
        root: Directory containing the SIRST dataset.

    Returns:
        ``(train_dataset, val_dataset)``. The training split uses the
        augmentation pipeline. The ``"test"`` split is used for validation
        and is only normalised.
    """
    train_transform, val_transform = _build_sirst_transforms()

    train_dataset = SIRSTDataset(
        root=root,
        split="train",
        transform=train_transform
    )

    val_dataset = SIRSTDataset(
        root=root,
        split="test",
        transform=val_transform
    )

    return train_dataset, val_dataset

# Initialize distributed training if multiple GPUs available
if world_size > 1:
    print(f"🔥 Setting up distributed training with {world_size} GPUs")

    # Advanced training configuration
    distributed_config = {
        'epochs': 100,
        'batch_size': 16,  # Per GPU batch size
        'learning_rate': 2e-4,  # Scaled for multi-GPU
        'weight_decay': 1e-4,
        'scheduler': 'onecycle',
        'mixed_precision': True,
        'gradient_clipping': 1.0,
        'warmup_epochs': 10,
        'loss_weights': {
            'focal': 0.4,
            'dice': 0.3,
            'iou': 0.2,
            'tversky': 0.1
        },
        'save_best_only': True,
        'monitor_metric': 'val_iou',
        'early_stopping_patience': 15
    }
    config_title = "✅ Distributed training configuration:"
else:
    # world_size is 0 when no CUDA device is visible at all.
    mode = "Single GPU" if world_size == 1 else "CPU"
    print(f"⚠️ {mode} mode - distributed training features will be simulated")

    distributed_config = {
        'epochs': 50,
        'batch_size': 8,  # Smaller batch for single GPU
        'learning_rate': 1e-4,
        'weight_decay': 1e-5,
        'scheduler': 'cosine',
        # AMP (autocast + GradScaler) only has an effect on CUDA devices.
        'mixed_precision': torch.cuda.is_available()
    }
    config_title = "✅ Training configuration:"

print(config_title)
for key, value in distributed_config.items():
    print(f"   {key}: {value}")

# Both paths train on the same SIRST datasets, so build them once.
train_dataset, val_dataset = prepare_distributed_data()