In [None]:
#Import all libraries
import sys
import h5py
import json
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from pathlib import Path
from typing import Dict
import pandas as pd
import torch
import numpy as np
from typing import Dict, List, Tuple, Optional
import cv2
import random
from scipy.ndimage import rotate, shift, zoom
from skimage.restoration import denoise_nl_means, estimate_sigma
from skimage.filters import gaussian, median
from scipy import ndimage
import os
from torch.utils.tensorboard import SummaryWriter
import time



In [None]:
# Mount Google Drive — run this cell FIRST when working in Colab.
# Troubleshooting if the mount fails:
#   1. Runtime -> Restart runtime
#   2. Re-run this cell and grant the requested permissions
#   3. Wait until "Mounted at /content/drive" is printed

try:
    import os
    from google.colab import drive  # only importable inside Colab

    mount_point = '/content/drive'

    # Mount only when the Drive is not mounted yet, so re-running is harmless.
    if not os.path.ismount(mount_point):
        print("Mounting Google Drive...")
        print("Please grant permissions when prompted.")
        drive.mount(mount_point)
        print("✓ Google Drive mounted successfully!")
    else:
        print("✓ Google Drive is already mounted!")

    # Confirm the HDF5 dataset is present at the expected Drive location.
    dataset_path = '/content/drive/MyDrive/brain_tumor_dataset.h5'
    if not os.path.exists(dataset_path):
        print(f"⚠ WARNING: Dataset not found at {dataset_path}")
        print("Please ensure 'brain_tumor_dataset.h5' is in your Google Drive's MyDrive folder.")
    else:
        print(f"✓ Dataset found at: {dataset_path}")

except ImportError:
    # Outside Colab the google.colab package does not exist.
    print("Not running in Colab - skipping Google Drive mount")
except Exception as e:
    # Best-effort cell: report the failure plus recovery steps instead of crashing.
    print(f"Error: {e}")
    for hint in (
        "\nIf mounting fails, try these steps:",
        "1. Runtime -> Restart runtime",
        "2. Clear all outputs",
        "3. Run this cell again",
        "4. Make sure to grant permissions when the popup appears",
    ):
        print(hint)


In [None]:
"""
Configuration management for brain tumor segmentation project.
"""

class Config:
    """Project configuration settings.

    All settings are class attributes. The helpers below are classmethods that
    derive filesystem paths from those attributes.
    """
    
    # Project paths
    ENVIRONMENT = "colab"  # Options: "local", "colab", "server"
    # A notebook kernel has no `__file__`, so the project root is the directory the
    # kernel runs in (the literal `Path('__file__').parent.parent` only ever
    # evaluated to the relative path '.').
    PROJECT_ROOT = Path.cwd()
    DATA_DIR = PROJECT_ROOT / "data"
    RAW_DATA_DIR = DATA_DIR / "raw"
    PROCESSED_DATA_DIR = DATA_DIR / "processed"
    EXTERNAL_DATA_DIR = DATA_DIR / "external"
    SPLITS_DIR = DATA_DIR / "splits"
    
    MODELS_DIR = PROJECT_ROOT / "models"
    CUSTOM_MODELS_DIR = MODELS_DIR / "custom"
    PRETRAINED_MODELS_DIR = MODELS_DIR / "pretrained"
    SAVED_MODELS_DIR = MODELS_DIR / "saved"
    
    RESULTS_DIR = PROJECT_ROOT / "results"
    LOGS_DIR = PROJECT_ROOT / "logs"
    
    # Dataset paths
    HDF5_DATASET_PATH = PROJECT_ROOT / "dataset" / "brain_tumor_dataset.h5"
    # Dataset location on a mounted Google Drive (used when ENVIRONMENT == "colab").
    COLAB_DATASET_PATH = '/content/drive/MyDrive/brain_tumor_dataset.h5'
    
    # Data settings
    IMAGE_SIZE = 256  # Standardize to 256x256
    TRAIN_SPLIT = 0.70
    VAL_SPLIT = 0.15
    TEST_SPLIT = 0.15
    RANDOM_SEED = 42
    
    # Preprocessing settings
    DENOISING_METHOD = "bilateral"  # Options: gaussian, median, bilateral, nlm, wavelet
    NORMALIZATION_METHOD = "z_score"  # Options: z_score, min_max, histogram_eq, clahe
    
    # Data augmentation settings
    AUGMENTATION_ENABLED = True
    ROTATION_RANGE = 15  # degrees
    TRANSLATION_RANGE = 0.1  # 10% of image size
    SCALE_RANGE = (0.9, 1.1)
    BRIGHTNESS_RANGE = 0.1  # ±10%
    CONTRAST_RANGE = 0.1  # ±10%
    FLIP_PROBABILITY = 0.5
    
    # Training settings
    BATCH_SIZE = 16
    NUM_EPOCHS = 100
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    EARLY_STOPPING_PATIENCE = 15
    REDUCE_LR_PATIENCE = 10
    REDUCE_LR_FACTOR = 0.5
    MIN_LEARNING_RATE = 1e-7
    
    # Model settings
    NUM_CLASSES = 1  # Binary segmentation
    DROPOUT_RATE = 0.2
    
    # Custom model settings
    CUSTOM_MODEL_FILTERS = [64, 128, 256, 512, 1024]
    
    # Pre-trained model settings
    PRETRAINED_MODEL_TYPE = "unet"  # Options: unet, resnet, vgg
    PRETRAINED_ENCODER = "resnet50"  # For ResNet-based models
    FREEZE_ENCODER_EPOCHS = 5  # Freeze encoder for first N epochs
    
    # Evaluation settings
    EVAL_METRICS = ["dice", "iou", "accuracy", "sensitivity", "specificity", "precision", "f1"]
    
    # Hardware settings
    NUM_WORKERS = 4
    PIN_MEMORY = True
    USE_MIXED_PRECISION = True
    
    # Logging settings
    LOG_INTERVAL = 10  # Log every N batches
    SAVE_INTERVAL = 5  # Save checkpoint every N epochs
    VISUALIZE_INTERVAL = 50  # Visualize predictions every N batches
    
    @classmethod
    def get_environment(cls):
        """Return the configured execution environment name."""
        return cls.ENVIRONMENT

    @classmethod
    def get_data_from_google_drive(cls):
        """Return the path (as a string) to the HDF5 dataset for the current environment.

        In the "colab" environment the dataset is read from the mounted Google Drive
        (COLAB_DATASET_PATH); the Drive must already be mounted. In every other
        environment the project-local HDF5_DATASET_PATH is used.

        Raises:
            FileNotFoundError: if the dataset file does not exist at the location
                for the current environment.
        """
        if cls.ENVIRONMENT == "colab":
            dataset_path = cls.COLAB_DATASET_PATH
            if os.path.exists(dataset_path):
                print(f"✓ Using dataset from: {dataset_path}")
                return dataset_path
            error_msg = (
                f"\n{'='*60}\n"
                f"ERROR: Dataset not found!\n"
                f"{'='*60}\n"
                f"Expected location: {dataset_path}\n\n"
                f"Please ensure:\n"
                f"1. You ran the 'Mount Google Drive' cell first\n"
                f"2. The file 'brain_tumor_dataset.h5' is in your Google Drive's MyDrive folder\n"
                f"3. The mount was successful (you should see 'Mounted at /content/drive')\n"
                f"{'='*60}"
            )
            raise FileNotFoundError(error_msg)

        # Non-Colab environments: use the local dataset instead of silently
        # returning None (which would only fail later, far from the cause).
        dataset_path = str(cls.HDF5_DATASET_PATH)
        if os.path.exists(dataset_path):
            print(f"✓ Using dataset from: {dataset_path}")
            return dataset_path
        raise FileNotFoundError(
            f"Dataset not found at {dataset_path} (ENVIRONMENT={cls.ENVIRONMENT!r}). "
            f"Place 'brain_tumor_dataset.h5' there, or set Config.ENVIRONMENT = 'colab' "
            f"and mount Google Drive."
        )
                
    @classmethod
    def create_directories(cls):
        """Create every project directory (no error if it already exists)."""
        directories = [
            cls.DATA_DIR, cls.RAW_DATA_DIR, cls.PROCESSED_DATA_DIR,
            cls.EXTERNAL_DATA_DIR, cls.SPLITS_DIR,
            cls.MODELS_DIR, cls.CUSTOM_MODELS_DIR, cls.PRETRAINED_MODELS_DIR,
            cls.SAVED_MODELS_DIR, cls.RESULTS_DIR, cls.LOGS_DIR
        ]
        for directory in directories:
            directory.mkdir(parents=True, exist_ok=True)
    
    @classmethod
    def get_model_path(cls, model_name: str, model_type: str = "custom"):
        """Return the checkpoint path for `model_name`.

        Any `model_type` other than "custom" maps to the "_pretrained" file name.
        """
        if model_type == "custom":
            return cls.SAVED_MODELS_DIR / f"{model_name}_custom.pth"
        else:
            return cls.SAVED_MODELS_DIR / f"{model_name}_pretrained.pth"



In [None]:
"""
Model comparison utilities.
"""

def generate_comparison_report(
    results_path: Path,
    output_dir: Path
):
    """
    Generate a comparison report (CSV table, Markdown summary, bar chart) for the
    custom and pre-trained models.

    Args:
        results_path: Path (or str) to the evaluation-results JSON. It must contain
            'custom' and 'pretrained' entries; each may hold a 'metrics' dict and
            optional 'best_val_dice' / 'epoch' values. An optional top-level
            'overall_winner' is copied into the Markdown report.
        output_dir: Directory (Path or str) to write the report into; created if missing.

    Notes:
        A metric missing from a model's 'metrics' dict is treated as 0.
        The 'Winner' column is 'tie' when both models score exactly the same.
        If either model entry is missing, nothing is written and a warning is printed.
    """
    results_path = Path(results_path)
    output_dir = Path(output_dir)

    with open(results_path, 'r') as f:
        results = json.load(f)

    output_dir.mkdir(parents=True, exist_ok=True)

    if 'custom' not in results or 'pretrained' not in results:
        print("⚠ Results must contain both 'custom' and 'pretrained' entries - "
              "no comparison report written.")
        return

    custom = results['custom']
    pretrained = results['pretrained']
    custom_metrics = custom.get('metrics', {})
    pretrained_metrics = pretrained.get('metrics', {})

    comparison_data = []
    for metric in ['dice', 'iou', 'accuracy', 'sensitivity', 'specificity', 'precision', 'f1']:
        custom_val = custom_metrics.get(metric, 0)
        pretrained_val = pretrained_metrics.get(metric, 0)
        diff = pretrained_val - custom_val

        # Equal scores are a tie, not a win for either model.
        if diff > 0:
            winner = 'pretrained'
        elif diff < 0:
            winner = 'custom'
        else:
            winner = 'tie'

        comparison_data.append({
            'Metric': metric.upper(),
            'Custom Model': custom_val,
            'Pre-trained Model': pretrained_val,
            'Difference': diff,
            'Winner': winner
        })

    df = pd.DataFrame(comparison_data)

    # Save as CSV
    csv_path = output_dir / "comparison_table.csv"
    df.to_csv(csv_path, index=False)
    print(f"Comparison table saved to: {csv_path}")

    # DataFrame.to_markdown needs the optional `tabulate` package; fall back to a
    # fenced plain-text table so the report is still written without it.
    try:
        table_md = df.to_markdown(index=False)
    except ImportError:
        table_md = "```\n" + df.to_string(index=False) + "\n```"

    # Save as markdown
    md_path = output_dir / "comparison_report.md"
    with open(md_path, 'w', encoding='utf-8') as f:
        f.write("# Model Comparison Report\n\n")
        f.write(f"**Overall Winner**: {results.get('overall_winner', 'N/A')}\n\n")
        f.write("## Metrics Comparison\n\n")
        f.write(table_md)
        f.write("\n\n")
        f.write("## Summary\n\n")
        f.write(f"- Custom Model Best Val DSC: {custom.get('best_val_dice', 'N/A')}\n")
        f.write(f"- Pre-trained Model Best Val DSC: {pretrained.get('best_val_dice', 'N/A')}\n")
        f.write(f"- Custom Model Epochs: {custom.get('epoch', 'N/A')}\n")
        f.write(f"- Pre-trained Model Epochs: {pretrained.get('epoch', 'N/A')}\n")

    print(f"Comparison report saved to: {md_path}")

    # Create visualization
    plot_path = output_dir / "metrics_comparison.png"
    plot_metrics_comparison(
        custom_metrics,
        pretrained_metrics,
        save_path=plot_path
    )

    print(f"\nComparison report generated in: {output_dir}")



In [None]:
"""
Evaluation metrics for brain tumor segmentation.
Primary metric: Dice Similarity Coefficient (DSC)
"""


def dice_coefficient(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    threshold: float = 0.5,
    smooth: float = 1e-6
) -> float:
    """
    Compute the Dice Similarity Coefficient (DSC), pooled over the whole batch.

    DSC = (2 * |P ∩ T| + smooth) / (|P| + |T| + smooth), where P is the prediction
    binarized at `threshold` and T the target. All elements of the batch are
    flattened together, so this is one score for the batch, not a per-sample mean.

    Args:
        predictions: Predicted masks (B, 1, H, W) with values in [0, 1]
        targets: Ground truth masks (B, 1, H, W) with values in [0, 1]
        threshold: Threshold for binarizing predictions
        smooth: Smoothing factor to avoid division by zero

    Returns:
        Dice coefficient (0-1, higher is better); 1.0 when both masks are empty.
    """
    predictions_flat = (predictions > threshold).float().reshape(-1)
    # reshape (not view): targets.float() returns the input itself when it is
    # already float, and that tensor may be non-contiguous (sliced/permuted).
    targets_flat = targets.float().reshape(-1)

    intersection = (predictions_flat * targets_flat).sum()
    dice = (2.0 * intersection + smooth) / (
        predictions_flat.sum() + targets_flat.sum() + smooth
    )

    return dice.item()


def iou_score(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    threshold: float = 0.5,
    smooth: float = 1e-6
) -> float:
    """
    Compute Intersection over Union (Jaccard index), pooled over the whole batch.

    IoU = (|P ∩ T| + smooth) / (|P ∪ T| + smooth), where P is the prediction
    binarized at `threshold` and T the target.

    Args:
        predictions: Predicted masks
        targets: Ground truth masks
        threshold: Threshold for binarizing predictions
        smooth: Smoothing factor

    Returns:
        IoU score (0-1, higher is better); 1.0 when both masks are empty.
    """
    predictions_flat = (predictions > threshold).float().reshape(-1)
    # reshape (not view): targets.float() may be a non-contiguous input tensor.
    targets_flat = targets.float().reshape(-1)

    intersection = (predictions_flat * targets_flat).sum()
    union = predictions_flat.sum() + targets_flat.sum() - intersection

    iou = (intersection + smooth) / (union + smooth)

    return iou.item()


def pixel_accuracy(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    threshold: float = 0.5
) -> float:
    """
    Fraction of elements whose binarized prediction equals the target value.

    Args:
        predictions: Predicted masks, compared element-wise with `targets`
        targets: Ground truth masks
        threshold: Values strictly above this count as foreground

    Returns:
        Pixel accuracy in [0, 1] (higher is better)
    """
    agreement = ((predictions > threshold).float() == targets.float()).float()
    total = agreement.numel()
    return (agreement.sum() / total).item()


def sensitivity(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    threshold: float = 0.5,
    smooth: float = 1e-6
) -> float:
    """
    Recall / true-positive rate: TP / (TP + FN), smoothed.

    Only target elements exactly equal to 1 count as positives.

    Args:
        predictions: Predicted masks
        targets: Ground truth masks
        threshold: Values strictly above this are predicted foreground
        smooth: Added to numerator and denominator to avoid 0/0

    Returns:
        Sensitivity in [0, 1] (higher is better)
    """
    predicted_positive = (predictions > threshold).float() == 1
    actual_positive = targets.float() == 1

    tp = (predicted_positive & actual_positive).float().sum()
    fn = (~predicted_positive & actual_positive).float().sum()

    recall = (tp + smooth) / (tp + fn + smooth)
    return recall.item()


def specificity(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    threshold: float = 0.5,
    smooth: float = 1e-6
) -> float:
    """
    True-negative rate: TN / (TN + FP), smoothed.

    Only target elements exactly equal to 0 count as negatives.

    Args:
        predictions: Predicted masks
        targets: Ground truth masks
        threshold: Values strictly above this are predicted foreground
        smooth: Added to numerator and denominator to avoid 0/0

    Returns:
        Specificity in [0, 1] (higher is better)
    """
    predicted_positive = (predictions > threshold).float() == 1
    actual_negative = targets.float() == 0

    tn = (~predicted_positive & actual_negative).float().sum()
    fp = (predicted_positive & actual_negative).float().sum()

    true_negative_rate = (tn + smooth) / (tn + fp + smooth)
    return true_negative_rate.item()


def precision_score(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    threshold: float = 0.5,
    smooth: float = 1e-6
) -> float:
    """
    Positive predictive value: TP / (TP + FP), smoothed.

    A predicted-foreground element is a TP when its target equals 1 and an FP
    when its target equals 0.

    Args:
        predictions: Predicted masks
        targets: Ground truth masks
        threshold: Values strictly above this are predicted foreground
        smooth: Added to numerator and denominator to avoid 0/0

    Returns:
        Precision in [0, 1] (higher is better)
    """
    predicted_positive = (predictions > threshold).float() == 1
    target_values = targets.float()

    tp = (predicted_positive & (target_values == 1)).float().sum()
    fp = (predicted_positive & (target_values == 0)).float().sum()

    ppv = (tp + smooth) / (tp + fp + smooth)
    return ppv.item()


def f1_score(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    threshold: float = 0.5,
    smooth: float = 1e-6
) -> float:
    """
    F1 score: harmonic mean of (smoothed) precision and sensitivity.

    `smooth` is applied once inside precision/sensitivity and once more in the
    harmonic mean itself, so the result is always finite.

    Args:
        predictions: Predicted masks
        targets: Ground truth masks
        threshold: Threshold for binarizing predictions
        smooth: Smoothing factor

    Returns:
        F1 score in [0, 1] (higher is better)
    """
    precision_val = precision_score(predictions, targets, threshold, smooth)
    recall_val = sensitivity(predictions, targets, threshold, smooth)

    numerator = 2 * precision_val * recall_val + smooth
    denominator = precision_val + recall_val + smooth
    return numerator / denominator


def compute_all_metrics(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    threshold: float = 0.5
) -> Dict[str, float]:
    """
    Evaluate every segmentation metric on one set of predictions.

    Each metric uses its own default smoothing factor.

    Args:
        predictions: Predicted masks
        targets: Ground truth masks
        threshold: Threshold for binarizing predictions

    Returns:
        Mapping of metric name ('dice', 'iou', 'accuracy', 'sensitivity',
        'specificity', 'precision', 'f1') to its value, in that order.
    """
    scorers = (
        ('dice', dice_coefficient),
        ('iou', iou_score),
        ('accuracy', pixel_accuracy),
        ('sensitivity', sensitivity),
        ('specificity', specificity),
        ('precision', precision_score),
        ('f1', f1_score),
    )
    return {name: scorer(predictions, targets, threshold) for name, scorer in scorers}


def evaluate_model(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    device: torch.device,
    threshold: float = 0.5
) -> Dict[str, float]:
    """
    Evaluate a model over every batch of a dataloader.

    Metrics are computed per batch with compute_all_metrics and then averaged.
    Each batch is weighted by its number of samples, so a smaller final batch does
    not count as much as a full one. The model's previous train/eval mode is
    restored before returning.

    Args:
        model: Trained model
        dataloader: Yields dicts with 'image' and 'mask' tensors
        device: Device to run evaluation on
        threshold: Threshold for binarizing predictions

    Returns:
        Dictionary mapping each metric name to its sample-weighted average

    Raises:
        ValueError: if the dataloader yields no samples
    """
    was_training = model.training
    model.eval()

    weighted_totals: Dict[str, float] = {}
    n_samples = 0
    try:
        with torch.no_grad():
            for batch in dataloader:
                images = batch['image'].to(device)
                masks = batch['mask'].to(device)

                # NOTE(review): the raw model output is thresholded directly, so the
                # model is assumed to end in a sigmoid (probabilities) - confirm.
                predictions = model(images)

                batch_metrics = compute_all_metrics(predictions, masks, threshold)
                batch_size = images.shape[0]
                for key, value in batch_metrics.items():
                    weighted_totals[key] = weighted_totals.get(key, 0.0) + value * batch_size
                n_samples += batch_size
    finally:
        # Don't leave a model that is mid-training stuck in eval mode.
        model.train(was_training)

    if n_samples == 0:
        raise ValueError("evaluate_model: dataloader yielded no samples")

    return {key: total / n_samples for key, total in weighted_totals.items()}



In [None]:
"""
Visualization utilities for model predictions and results.
"""
import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Optional, List
import seaborn as sns


def visualize_predictions(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    device: torch.device,
    num_samples: int = 8,
    threshold: float = 0.5,
    save_path: Optional[Path] = None
):
    """
    Plot model predictions for the first samples of one batch.

    Each row shows the input image, the ground-truth mask, the raw model output
    (colour-mapped) and the output binarized at `threshold`. Only channel 0 of each
    tensor is drawn.

    Args:
        model: Trained model; its previous train/eval mode is restored afterwards
        dataloader: Yields dicts with 'image' and 'mask' tensors; only the first
            batch is used
        device: Device to run inference on
        num_samples: Maximum number of rows; clipped to the size of the batch
        threshold: Threshold for binarizing predictions
        save_path: Where to save the figure; when None, plt.show() is called

    Raises:
        ValueError: if num_samples < 1 or the first batch is empty
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    batch = next(iter(dataloader))
    images = batch['image'][:num_samples].to(device)
    masks = batch['mask'][:num_samples]

    # The batch may hold fewer items than requested; plot only what is there.
    n_rows = images.shape[0]
    if n_rows == 0:
        raise ValueError("The first batch of the dataloader is empty")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            predictions = model(images)
            predictions_binary = (predictions > threshold).float()
    finally:
        model.train(was_training)

    images = images.cpu().numpy()
    masks = masks.cpu().numpy()
    predictions = predictions.cpu().numpy()
    predictions_binary = predictions_binary.cpu().numpy()

    # squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4 * n_rows), squeeze=False)

    panels = (
        (images, 'gray', 'Original Image'),
        (masks, 'gray', 'Ground Truth'),
        (predictions, 'hot', 'Prediction (Probability)'),
        (predictions_binary, 'gray', 'Prediction (Binary)'),
    )
    for i in range(n_rows):
        for j, (data, cmap, title) in enumerate(panels):
            ax = axes[i, j]
            ax.imshow(data[i, 0], cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Visualization saved to {save_path}")
    else:
        plt.show()

    plt.close(fig)


def plot_training_curves(
    train_losses: List[float],
    val_dice_scores: List[float],
    save_path: Optional[Path] = None
):
    """
    Draw side-by-side curves of training loss and validation DSC.

    The x-axis is the list index of each value (labelled 'Epoch').

    Args:
        train_losses: Training loss, one value per epoch
        val_dice_scores: Validation Dice score, one value per epoch
        save_path: Output image path; when falsy the figure is shown instead
    """
    fig, (loss_ax, dice_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # (axes, series, legend label, colour, y-axis label, title)
    panels = (
        (loss_ax, train_losses, 'Training Loss', 'blue', 'Loss', 'Training Loss'),
        (dice_ax, val_dice_scores, 'Validation DSC', 'green',
         'Dice Similarity Coefficient', 'Validation DSC'),
    )
    for ax, series, label, color, ylabel, title in panels:
        ax.plot(series, label=label, color=color)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if not save_path:
        plt.show()
    else:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Training curves saved to {save_path}")

    plt.close()


def plot_metrics_comparison(
    custom_metrics: dict,
    pretrained_metrics: dict,
    save_path: Optional[Path] = None
):
    """
    Grouped bar chart comparing the two models on every evaluation metric.

    Missing metrics are drawn as 0; each bar is annotated with its value.

    Args:
        custom_metrics: Metric name -> value for the custom model
        pretrained_metrics: Metric name -> value for the pre-trained model
        save_path: Output image path; when falsy the figure is shown instead
    """
    metric_keys = ['dice', 'iou', 'accuracy', 'sensitivity', 'specificity', 'precision', 'f1']
    positions = np.arange(len(metric_keys))
    bar_width = 0.35

    fig, ax = plt.subplots(figsize=(12, 6))

    bar_groups = []
    for offset, scores, label in (
        (-bar_width / 2, custom_metrics, 'Custom Model'),
        (bar_width / 2, pretrained_metrics, 'Pre-trained Model'),
    ):
        values = [scores.get(key, 0) for key in metric_keys]
        bar_groups.append(ax.bar(positions + offset, values, bar_width, label=label, alpha=0.8))

    ax.set_ylabel('Score')
    ax.set_title('Model Comparison - Evaluation Metrics')
    ax.set_xticks(positions)
    ax.set_xticklabels([key.upper() for key in metric_keys], rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')
    ax.set_ylim([0, 1.1])

    # Annotate every bar with its value, just above its top edge.
    for bar in (b for group in bar_groups for b in group):
        top = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., top, f'{top:.3f}',
                ha='center', va='bottom', fontsize=8)

    plt.tight_layout()

    if not save_path:
        plt.show()
    else:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Comparison plot saved to {save_path}")

    plt.close()


def create_confusion_matrix_plot(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    threshold: float = 0.5,
    save_path: Optional[Path] = None
):
    """
    Compute and plot the pixel-level 2x2 confusion matrix.

    Rows are actual labels and columns are predicted labels, in the order
    (Background=0, Tumor=1). Target values other than 0/1 are ignored by the
    `labels=[0, 1]` restriction.

    Args:
        predictions: Predicted masks
        targets: Ground truth masks
        threshold: Threshold for binarizing predictions
        save_path: Output image path; when falsy the figure is shown instead
    """
    predictions_binary = (predictions > threshold).float()
    targets_binary = targets.float()

    # reshape (not view): targets.float() returns the input itself when it is
    # already float, and that tensor may be non-contiguous.
    pred_flat = predictions_binary.reshape(-1).cpu().numpy()
    target_flat = targets_binary.reshape(-1).cpu().numpy()

    from sklearn.metrics import confusion_matrix
    cm = confusion_matrix(target_flat, pred_flat, labels=[0, 1])

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=['Background', 'Tumor'],
                yticklabels=['Background', 'Tumor'])
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title('Confusion Matrix')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Confusion matrix saved to {save_path}")
    else:
        plt.show()

    plt.close(fig)



In [None]:
"""
Data augmentation transforms for training.
"""
class AugmentationTransform:
    """Apply the same random geometric augmentation to an image and its mask.

    Rotation, translation, scaling, flips and the optional elastic deformation are
    applied identically to image and mask. Brightness/contrast changes are applied
    to the image only. Arrays are assumed to be 2-D (H, W) - TODO confirm for
    multi-channel inputs (the padding in `_scale` only handles 2-D).

    Warping is done directly in float32. Intensities are therefore neither
    quantized to 8 bits nor corrupted when they fall outside [0, 255] (e.g.
    z-scored images with negative values). Masks are resampled with
    nearest-neighbour interpolation so they stay binary.

    Random parameters come from Python's `random` module; elastic displacement
    fields come from `np.random`. Seed both for reproducible augmentation.
    """
    
    def __init__(
        self,
        rotation_range: float = 15,
        translation_range: float = 0.1,
        scale_range: Tuple[float, float] = (0.9, 1.1),
        brightness_range: float = 0.1,
        contrast_range: float = 0.1,
        flip_probability: float = 0.5,
        elastic_deformation: bool = False
    ):
        """
        Initialize augmentation parameters.
        
        Args:
            rotation_range: Maximum rotation in degrees (0 disables rotation)
            translation_range: Maximum translation as fraction of image size (0 disables)
            scale_range: (min, max) scaling factors; (1.0, 1.0) disables scaling
            brightness_range: Maximum additive brightness adjustment (±)
            contrast_range: Maximum relative contrast adjustment (±)
            flip_probability: Probability of a flip (horizontal or vertical, 50/50)
            elastic_deformation: Whether to apply elastic deformation
        """
        self.rotation_range = rotation_range
        self.translation_range = translation_range
        self.scale_range = scale_range
        self.brightness_range = brightness_range
        self.contrast_range = contrast_range
        self.flip_probability = flip_probability
        self.elastic_deformation = elastic_deformation
    
    def __call__(self, sample: Dict) -> Dict:
        """
        Apply augmentation to sample.
        
        Args:
            sample: Dictionary with 'image' and 'mask' keys (not modified)
        
        Returns:
            New dict with augmented, C-contiguous 'image' and 'mask' arrays
        """
        image = sample['image'].copy()
        mask = sample['mask'].copy()
        
        # Random rotation
        if self.rotation_range > 0:
            angle = random.uniform(-self.rotation_range, self.rotation_range)
            image = self._rotate(image, angle)
            mask = self._rotate(mask, angle, is_mask=True)
        
        # Random translation
        if self.translation_range > 0:
            h, w = image.shape[:2]
            tx = random.uniform(-self.translation_range, self.translation_range) * w
            ty = random.uniform(-self.translation_range, self.translation_range) * h
            image = self._translate(image, tx, ty)
            mask = self._translate(mask, tx, ty, is_mask=True)
        
        # Random scaling
        if self.scale_range[0] != 1.0 or self.scale_range[1] != 1.0:
            scale = random.uniform(self.scale_range[0], self.scale_range[1])
            image = self._scale(image, scale)
            mask = self._scale(mask, scale, is_mask=True)
        
        # Random flips
        if random.random() < self.flip_probability:
            flip = np.fliplr if random.random() < 0.5 else np.flipud
            image = flip(image)
            mask = flip(mask)
        
        # Intensity augmentations (only for image)
        if self.brightness_range > 0:
            brightness = random.uniform(-self.brightness_range, self.brightness_range)
            image = self._adjust_brightness(image, brightness)
        
        if self.contrast_range > 0:
            contrast = random.uniform(-self.contrast_range, self.contrast_range)
            image = self._adjust_contrast(image, contrast)
        
        # Elastic deformation
        if self.elastic_deformation:
            image, mask = self._elastic_deform(image, mask)
        
        # Flips return negative-stride views, which torch.from_numpy rejects;
        # hand back contiguous arrays.
        return {'image': np.ascontiguousarray(image), 'mask': np.ascontiguousarray(mask)}
    
    @staticmethod
    def _to_float32(image: np.ndarray) -> np.ndarray:
        """Return `image` as a C-contiguous float32 array, as OpenCV expects."""
        return np.ascontiguousarray(image, dtype=np.float32)
    
    @staticmethod
    def _interpolation(is_mask: bool) -> int:
        """Nearest-neighbour for masks (keeps labels binary), bilinear for images."""
        return cv2.INTER_NEAREST if is_mask else cv2.INTER_LINEAR
    
    @staticmethod
    def _in_unit_range(image: np.ndarray) -> bool:
        """True when every value of `image` lies in [0, 1] (or it is empty)."""
        return image.size == 0 or (image.min() >= 0 and image.max() <= 1)
    
    def _rotate(self, image: np.ndarray, angle: float, is_mask: bool = False) -> np.ndarray:
        """Rotate by `angle` degrees about the centre, reflecting at the borders."""
        h, w = image.shape[:2]
        M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
        return cv2.warpAffine(
            self._to_float32(image), M, (w, h),
            flags=self._interpolation(is_mask),
            borderMode=cv2.BORDER_REFLECT
        )
    
    def _translate(self, image: np.ndarray, tx: float, ty: float, is_mask: bool = False) -> np.ndarray:
        """Shift by (tx, ty) pixels, reflecting at the borders."""
        h, w = image.shape[:2]
        M = np.float32([[1, 0, tx], [0, 1, ty]])
        return cv2.warpAffine(
            self._to_float32(image), M, (w, h),
            flags=self._interpolation(is_mask),
            borderMode=cv2.BORDER_REFLECT
        )
    
    def _scale(self, image: np.ndarray, scale: float, is_mask: bool = False) -> np.ndarray:
        """Zoom by `scale`, then centre-crop (scale > 1) or reflect-pad back to the input size."""
        h, w = image.shape[:2]
        new_h, new_w = int(h * scale), int(w * scale)
        
        scaled = cv2.resize(
            self._to_float32(image), (new_w, new_h),
            interpolation=self._interpolation(is_mask)
        )
        
        if scale > 1.0:
            # Crop center
            start_h = (new_h - h) // 2
            start_w = (new_w - w) // 2
            scaled = scaled[start_h:start_h + h, start_w:start_w + w]
        else:
            # Pad symmetrically; any odd leftover pixel goes to the bottom/right
            pad_h = (h - new_h) // 2
            pad_w = (w - new_w) // 2
            scaled = np.pad(
                scaled,
                ((pad_h, h - new_h - pad_h), (pad_w, w - new_w - pad_w)),
                mode='reflect'
            )
        
        return scaled
    
    def _adjust_brightness(self, image: np.ndarray, brightness: float) -> np.ndarray:
        """Add `brightness` to every pixel.

        The result is clipped to [0, 1] only when the input already lies in [0, 1];
        inputs on another scale (z-scored, 0-255) are shifted without being
        collapsed into [0, 1].
        """
        shifted = image + brightness
        return np.clip(shifted, 0, 1) if self._in_unit_range(image) else shifted
    
    def _adjust_contrast(self, image: np.ndarray, contrast: float) -> np.ndarray:
        """Scale deviations from the mean by (1 + contrast).

        The result is clipped to [0, 1] only when the input already lies in [0, 1].
        """
        mean = np.mean(image)
        stretched = (image - mean) * (1 + contrast) + mean
        return np.clip(stretched, 0, 1) if self._in_unit_range(image) else stretched
    
    def _elastic_deform(self, image: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Apply the same random, Gaussian-smoothed displacement field to image and mask."""
        from scipy.ndimage import gaussian_filter
        
        alpha = random.uniform(50, 150)
        sigma = random.uniform(5, 10)
        
        h, w = image.shape[:2]
        dx = np.random.randn(h, w) * alpha
        dy = np.random.randn(h, w) * alpha
        
        # Smooth the displacement fields so neighbouring pixels move together
        dx = gaussian_filter(dx, sigma)
        dy = gaussian_filter(dy, sigma)
        
        # Absolute sampling coordinates for cv2.remap, kept inside the image
        x, y = np.meshgrid(np.arange(w), np.arange(h))
        map_x = np.clip(x + dx, 0, w - 1).astype(np.float32)
        map_y = np.clip(y + dy, 0, h - 1).astype(np.float32)
        
        image_deformed = cv2.remap(
            self._to_float32(image), map_x, map_y,
            cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_REFLECT
        )
        mask_deformed = cv2.remap(
            self._to_float32(mask), map_x, map_y,
            cv2.INTER_NEAREST,
            borderMode=cv2.BORDER_REFLECT
        )
        
        return image_deformed, mask_deformed


class PreprocessingTransform:
    """Apply preprocessing (resizing, denoising, normalization) to an image/mask pair."""

    def __init__(
        self,
        target_size: int = 256,
        denoising_method: Optional[str] = None,
        normalization_method: str = "z_score"
    ):
        """
        Initialize preprocessing transform.

        Args:
            target_size: Target image size (square)
            denoising_method: Method name passed to ``apply_denoising`` (None to skip)
            normalization_method: Method name passed to ``apply_normalization``
                (falsy to skip)
        """
        self.target_size = target_size
        self.denoising_method = denoising_method
        self.normalization_method = normalization_method

    def __call__(self, sample: Dict) -> Dict:
        """
        Apply preprocessing to sample.

        The image is resized, then optionally denoised and normalized. The mask
        is only resized and then re-binarized with a 0.5 threshold. The input
        arrays are copied, not modified.

        Args:
            sample: Dictionary with 'image' and 'mask' keys

        Returns:
            New dictionary with preprocessed 'image' and binary float32 'mask'
        """
        image = sample['image'].copy()
        mask = sample['mask'].copy()

        # Resize if needed
        if image.shape[0] != self.target_size or image.shape[1] != self.target_size:
            image = self._resize(image, self.target_size)
            mask = self._resize(mask, self.target_size, is_mask=True)

        # apply_denoising / apply_normalization are defined at module level by
        # this notebook's denoising and normalization cells. They are called
        # directly rather than imported from a `preprocessing` package, which
        # the notebook does not provide.
        if self.denoising_method:
            image = apply_denoising(image, method=self.denoising_method)

        if self.normalization_method:
            image = apply_normalization(image, method=self.normalization_method)

        # Ensure mask is binary
        mask = (mask > 0.5).astype(np.float32)

        return {'image': image, 'mask': mask}

    def _resize(self, image: np.ndarray, target_size: int, is_mask: bool = False) -> np.ndarray:
        """Resize an array to ``target_size`` x ``target_size`` as float32.

        Resizing works on float32 data directly. The previous uint8 round-trip
        lost precision and wrapped out-of-range values (a 0/255 mask times 255
        overflowed). Masks use nearest-neighbour interpolation so label values
        are not blended.
        """
        interpolation = cv2.INTER_NEAREST if is_mask else cv2.INTER_LINEAR
        resized = cv2.resize(
            np.ascontiguousarray(image, dtype=np.float32),
            (target_size, target_size),
            interpolation=interpolation
        )
        return resized.astype(np.float32, copy=False)



In [None]:
"""
Image denoising methods for MRI brain images.
"""

def gaussian_denoise(
    image: np.ndarray,
    kernel_size: int = 5,
    sigma: float = 1.0
) -> np.ndarray:
    """
    Apply Gaussian filtering for denoising.

    Args:
        image: Input grayscale image
        kernel_size: Size of the square Gaussian kernel. Even values are bumped
            to the next odd size.
        sigma: Standard deviation of Gaussian kernel

    Returns:
        Denoised image (as returned by ``skimage.filters.gaussian``)
    """
    if kernel_size % 2 == 0:
        kernel_size += 1

    # kernel_size used to be computed and then ignored. scikit-image/SciPy size
    # the kernel as radius = int(truncate * sigma + 0.5), so pick truncate to
    # give radius == kernel_size // 2. With sigma <= 0 there is nothing to
    # size, so keep the library default.
    truncate = (kernel_size // 2) / sigma if sigma > 0 else 4.0

    return gaussian(image, sigma=sigma, mode='reflect', truncate=truncate)


def median_denoise(
    image: np.ndarray,
    kernel_size: int = 5
) -> np.ndarray:
    """
    Apply median filtering for denoising.

    Args:
        image: Input grayscale image
        kernel_size: Side length of the square median window. Even values are
            bumped to the next odd size.

    Returns:
        Denoised image
    """
    if kernel_size % 2 == 0:
        kernel_size += 1

    # scikit-image renamed `selem` to `footprint` (0.19) and later removed
    # `selem`, so the old keyword fails on current releases.
    footprint = np.ones((kernel_size, kernel_size), dtype=bool)
    return median(image, footprint=footprint)


def bilateral_denoise(
    image: np.ndarray,
    d: int = 9,
    sigma_color: float = 75.0,
    sigma_space: float = 75.0
) -> np.ndarray:
    """
    Apply bilateral filtering for edge-preserving denoising.

    The filter runs on an 8-bit copy of the image. Inputs whose maximum is
    <= 1.0 are treated as [0, 1] data, scaled by 255 and scaled back
    afterwards. Other inputs are used on the 0-255 scale directly.

    Args:
        image: Input grayscale image, either in [0, 1] or in 0-255
        d: Diameter of pixel neighborhood
        sigma_color: Filter sigma in the color space (0-255 units)
        sigma_space: Filter sigma in the coordinate space

    Returns:
        Denoised float32 image on the same scale as the input
    """
    unit_range = image.max() <= 1.0
    scaled = image * 255 if unit_range else image

    # Clip before casting: out-of-range values (negative, or > 255) would
    # otherwise wrap around when converted to uint8.
    image_uint8 = np.clip(scaled, 0, 255).astype(np.uint8)

    denoised = cv2.bilateralFilter(
        image_uint8, d, sigma_color, sigma_space
    )

    # Convert back to float on the caller's scale.
    return denoised.astype(np.float32) / 255.0 if unit_range else denoised.astype(np.float32)


def nlm_denoise(
    image: np.ndarray,
    patch_size: int = 5,
    patch_distance: int = 6,
    h: Optional[float] = None
) -> np.ndarray:
    """
    Apply Non-Local Means denoising.

    Args:
        image: Input grayscale image (single channel)
        patch_size: Size of patches used for denoising
        patch_distance: Max distance to search for patches
        h: Cut-off distance for the exponential function. If None, it is set
            to 0.8 x the noise standard deviation from ``estimate_sigma``.

    Returns:
        Denoised image
    """
    # `multichannel=` was deprecated in scikit-image 0.19 in favour of
    # `channel_axis` and has since been removed. `channel_axis=None` is the
    # equivalent of the old `multichannel=False` (grayscale).
    if h is None:
        sigma_est = estimate_sigma(image, channel_axis=None)
        h = 0.8 * sigma_est

    denoised = denoise_nl_means(
        image,
        patch_size=patch_size,
        patch_distance=patch_distance,
        h=h,
        channel_axis=None,
        fast_mode=True
    )

    return denoised


def wavelet_denoise(
    image: np.ndarray,
    wavelet: str = 'db4',
    mode: str = 'soft',
    threshold: Optional[float] = None
) -> np.ndarray:
    """
    Apply wavelet-shrinkage denoising.

    Only the detail coefficients are thresholded. The approximation band holds
    the low-frequency image content and is left untouched; soft-thresholding
    it would shift every intensity towards zero.

    Args:
        image: Input grayscale image
        wavelet: Wavelet type (e.g., 'db4', 'haar', 'bior2.2')
        mode: Thresholding mode ('soft' or 'hard')
        threshold: Threshold value. If None, the universal threshold is used.

    Returns:
        Denoised image with the same shape as ``image``

    Raises:
        ImportError: If PyWavelets is not installed
    """
    try:
        import pywt
    except ImportError as err:
        raise ImportError("pywt (PyWavelets) is required for wavelet denoising") from err

    # Decompose
    coeffs = pywt.wavedec2(image, wavelet, mode='symmetric')

    if threshold is None:
        # Universal threshold sigma * sqrt(2 ln N). The noise sigma is the
        # median absolute deviation of the finest diagonal detail band
        # (coeffs[-1] = (cH1, cV1, cD1)).
        finest_diagonal = coeffs[-1][-1]
        sigma = np.median(np.abs(finest_diagonal)) / 0.6745
        threshold = sigma * np.sqrt(2 * np.log(image.size))

    # Keep the approximation band; threshold every detail band.
    coeffs_thresh = [coeffs[0]] + [
        tuple(pywt.threshold(detail, threshold, mode=mode) for detail in level)
        for level in coeffs[1:]
    ]

    # Reconstruct
    denoised = pywt.waverec2(coeffs_thresh, wavelet, mode='symmetric')

    # waverec2 may return one extra row/column for odd-sized inputs; crop back
    # so the output matches the input (and its mask).
    return denoised[tuple(slice(0, size) for size in image.shape)]


def apply_denoising(
    image: np.ndarray,
    method: str = "bilateral",
    **kwargs
) -> np.ndarray:
    """
    Dispatch to one of the denoising functions by name.

    Args:
        image: Input grayscale image
        method: One of 'gaussian', 'median', 'bilateral', 'nlm', 'wavelet'
        **kwargs: Forwarded unchanged to the selected denoising function

    Returns:
        Output of the selected denoising function

    Raises:
        ValueError: If ``method`` is not a known denoising method
    """
    registry = dict(
        gaussian=gaussian_denoise,
        median=median_denoise,
        bilateral=bilateral_denoise,
        nlm=nlm_denoise,
        wavelet=wavelet_denoise,
    )

    denoiser = registry.get(method)
    if denoiser is None:
        raise ValueError(f"Unknown denoising method: {method}. Choose from {list(registry.keys())}")

    return denoiser(image, **kwargs)



In [None]:
"""
HDF5 dataset loader for brain tumor MRI images.
"""
# import h5py
# import numpy as np
# import torch
from torch.utils.data import Dataset, DataLoader
# from typing import Tuple, List, Dict, Optional
# from pathlib import Path
# import json
# from sklearn.model_selection import train_test_split

# from utils.config import Config


class BrainTumorDataset(Dataset):
    """PyTorch Dataset that reads image/mask pairs from an HDF5 file.

    Each entry of ``patient_ids`` names a top-level HDF5 group containing an
    ``image`` dataset, a ``tumor_mask`` dataset and a scalar ``label``. The
    label is read only when ``return_label`` is True.
    """

    def __init__(
        self,
        hdf5_path: Path,
        patient_ids: List[str],
        transform=None,
        return_label: bool = False
    ):
        """
        Initialize dataset.

        Args:
            hdf5_path: Path to HDF5 file
            patient_ids: List of patient group IDs to use
            transform: Optional callable taking and returning a dict with
                'image' and 'mask' NumPy arrays
            return_label: Whether to return tumor type label
        """
        self.hdf5_path = hdf5_path
        self.patient_ids = patient_ids
        self.transform = transform
        self.return_label = return_label

    def __len__(self) -> int:
        return len(self.patient_ids)

    def __getitem__(self, idx: int) -> Dict:
        """Load one sample and return it with a leading channel dimension.

        Returns:
            Dict with float32 tensors 'image' and 'mask' (a channel axis is
            prepended to each array), 'patient_id', and 'label' (int) when
            ``return_label`` is True.
        """
        patient_id = self.patient_ids[idx]

        # The file is opened per sample; no handle is kept on the Dataset.
        with h5py.File(self.hdf5_path, 'r') as f:
            group = f[patient_id]
            image = np.array(group['image'], dtype=np.float32)
            mask = np.array(group['tumor_mask'], dtype=np.float32)
            label = int(group['label'][()]) if self.return_label else None

        # Scale intensities into [0, 1]. 8-bit data is divided by 255. Data with
        # a larger range (e.g. 16-bit MRI) is divided by its own maximum, since
        # dividing it by 255 would leave values far above 1.
        peak = float(image.max())
        if peak > 1.0:
            image = image / (255.0 if peak <= 255.0 else peak)

        # Ensure mask is binary [0, 1]
        mask = (mask > 0.5).astype(np.float32)

        # Apply transforms
        if self.transform:
            transformed = self.transform({'image': image, 'mask': mask})
            image = transformed['image']
            mask = transformed['mask']

        # Transforms may return views with negative strides (e.g. flips), which
        # torch.from_numpy rejects, or arrays of another dtype. Normalise both
        # so every batch is contiguous float32.
        image = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32)).unsqueeze(0)
        mask = torch.from_numpy(np.ascontiguousarray(mask, dtype=np.float32)).unsqueeze(0)

        result = {
            'image': image,
            'mask': mask,
            'patient_id': patient_id
        }

        if self.return_label:
            result['label'] = label

        return result


class HDF5DatasetExplorer:
    """Utility class to explore HDF5 dataset structure."""

    # Readable names for the integer `label` stored in each HDF5 group.
    LABEL_NAMES = {1: 'Meningioma', 2: 'Glioma', 3: 'Pituitary'}

    def __init__(self, hdf5_path: Path):
        self.hdf5_path = hdf5_path

    def explore(self) -> Dict:
        """Scan every top-level group and collect shapes and label statistics.

        Returns:
            Dict with 'total_patients', 'patient_ids', 'image_shapes',
            'mask_shapes', 'labels' (one entry per group) and
            'label_distribution' (label name -> count).
        """
        stats = {
            'total_patients': 0,
            'patient_ids': [],
            'image_shapes': [],
            'mask_shapes': [],
            'labels': [],
            'label_distribution': {}
        }

        with h5py.File(self.hdf5_path, 'r') as f:
            patient_ids = list(f.keys())
            stats['total_patients'] = len(patient_ids)
            stats['patient_ids'] = patient_ids

            for patient_id in patient_ids:
                group = f[patient_id]

                # Only shapes are read here, not the pixel data.
                stats['image_shapes'].append(group['image'].shape)
                stats['mask_shapes'].append(group['tumor_mask'].shape)

                label = int(group['label'][()])
                stats['labels'].append(label)

                # Unexpected label values are reported rather than raising KeyError.
                label_name = self.LABEL_NAMES.get(label, f"Unknown ({label})")
                stats['label_distribution'][label_name] = \
                    stats['label_distribution'].get(label_name, 0) + 1

        return stats

    def print_summary(self):
        """Print patient count, label distribution and shape histograms."""
        from collections import Counter

        stats = self.explore()
        total = stats['total_patients']

        print("=" * 60)
        print("Dataset Summary")
        print("=" * 60)
        print(f"Total Patients: {total}")
        print("\nLabel Distribution:")
        # Non-empty distribution implies total > 0, so the division is safe.
        for label, count in stats['label_distribution'].items():
            percentage = (count / total) * 100
            print(f"  {label}: {count} ({percentage:.1f}%)")

        # Counter counts in one pass (list.count inside a loop was quadratic).
        print("\nImage Shapes:")
        for shape, count in Counter(stats['image_shapes']).items():
            print(f"  {shape}: {count} images")

        print("\nMask Shapes:")
        for shape, count in Counter(stats['mask_shapes']).items():
            print(f"  {shape}: {count} masks")
        print("=" * 60)


def create_data_splits(
    hdf5_path: Path,
    output_dir: Path,
    train_ratio: float = 0.70,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    random_seed: int = 42,
    stratify: bool = True
) -> Dict[str, List[str]]:
    """
    Create train/validation/test splits with optional stratification.

    The data is split twice: first train vs. (val + test), then that
    remainder into val and test. Each split's IDs are written to
    ``<output_dir>/<split>_ids.json``.

    NOTE(review): the split is over top-level HDF5 groups. If a group is a
    single slice rather than a whole patient, slices from one patient can land
    in different splits — confirm against how the HDF5 file was built.

    Args:
        hdf5_path: Path to HDF5 file
        output_dir: Directory to save split files (created if missing)
        train_ratio: Proportion for training set
        val_ratio: Proportion for validation set
        test_ratio: Proportion for test set
        random_seed: Random seed for reproducibility
        stratify: Whether to stratify by tumor type label

    Returns:
        Dictionary with 'train', 'val', 'test' patient ID lists

    Raises:
        ValueError: If the three ratios do not sum to 1
    """
    # This cell's sklearn import is commented out, so import it where it is used.
    from sklearn.model_selection import train_test_split

    ratio_sum = train_ratio + val_ratio + test_ratio
    if not np.isclose(ratio_sum, 1.0):
        # train_ratio is implied by the other two; reject inconsistent inputs
        # instead of silently ignoring train_ratio.
        raise ValueError(
            f"train_ratio + val_ratio + test_ratio must equal 1.0, got {ratio_sum}"
        )

    # Load all patient IDs and labels
    with h5py.File(hdf5_path, 'r') as f:
        id_list = list(f.keys())
        label_list = [int(f[patient_id]['label'][()]) for patient_id in id_list]

    patient_ids = np.array(id_list)
    labels = np.array(label_list)

    holdout_ratio = val_ratio + test_ratio

    # First split: train vs (val + test). The shuffle depends only on the
    # sample count, seed and `stratify`, so one code path covers both modes.
    train_ids, temp_ids, _, temp_labels = train_test_split(
        patient_ids, labels,
        test_size=holdout_ratio,
        random_state=random_seed,
        stratify=labels if stratify else None
    )

    # Second split: val vs test, rescaled to the remainder.
    val_ratio_adjusted = val_ratio / holdout_ratio
    val_ids, test_ids = train_test_split(
        temp_ids,
        test_size=(1 - val_ratio_adjusted),
        random_state=random_seed,
        stratify=temp_labels if stratify else None
    )

    splits = {
        'train': train_ids.tolist(),
        'val': val_ids.tolist(),
        'test': test_ids.tolist()
    }

    # Save splits to JSON files
    output_dir.mkdir(parents=True, exist_ok=True)
    for split_name, ids in splits.items():
        output_path = output_dir / f"{split_name}_ids.json"
        with open(output_path, 'w') as f:
            json.dump(ids, f, indent=2)

    # Print split statistics
    print("=" * 60)
    print("Data Split Summary")
    print("=" * 60)
    for split_name, ids in splits.items():
        print(f"{split_name.upper()}: {len(ids)} patients ({len(ids)/len(patient_ids)*100:.1f}%)")
    print("=" * 60)

    return splits


def load_data_splits(splits_dir: Path) -> Dict[str, List[str]]:
    """Read the train/val/test ID lists written by ``create_data_splits``.

    Args:
        splits_dir: Directory containing ``<split>_ids.json`` files

    Returns:
        Mapping from 'train', 'val' and 'test' to the stored ID lists

    Raises:
        FileNotFoundError: If any of the three split files is missing
    """
    loaded: Dict[str, List[str]] = {}
    for name in ('train', 'val', 'test'):
        path = splits_dir / f"{name}_ids.json"
        if not path.exists():
            raise FileNotFoundError(f"Split file not found: {path}")
        with open(path, 'r') as handle:
            loaded[name] = json.load(handle)
    return loaded


def create_dataloaders(
    hdf5_path: Path,
    splits: Dict[str, List[str]],
    batch_size: int = 16,
    num_workers: int = 4,
    train_transform=None,
    val_transform=None,
    return_label: bool = False
) -> Dict[str, DataLoader]:
    """
    Create DataLoaders for train, validation, and test sets.

    The training loader shuffles and drops the last incomplete batch. The
    validation and test loaders keep order and all samples, and both use
    ``val_transform``.

    Args:
        hdf5_path: Path to HDF5 file
        splits: Mapping with 'train', 'val' and 'test' ID lists
        batch_size: Batch size for all loaders
        num_workers: DataLoader worker processes
        train_transform: Transform for the training set
        val_transform: Transform for the validation and test sets
        return_label: Whether samples include the tumor type label

    Returns:
        Dictionary with 'train', 'val', 'test' DataLoaders
    """
    # Pinned host memory only helps host-to-GPU copies; without CUDA it just
    # triggers a warning.
    pin_memory = torch.cuda.is_available()

    split_transforms = {
        'train': train_transform,
        'val': val_transform,
        'test': val_transform,
    }

    dataloaders = {}
    for split_name, transform in split_transforms.items():
        dataset = BrainTumorDataset(
            hdf5_path, splits[split_name],
            transform=transform,
            return_label=return_label
        )
        is_train = split_name == 'train'
        dataloaders[split_name] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=is_train
        )

    return dataloaders



In [None]:
"""
Image normalization methods for MRI brain images.
"""

def z_score_normalize(image: np.ndarray, eps: float = 1e-8) -> Tuple[np.ndarray, float, float]:
    """
    Z-score normalization: (pixel - mean) / std

    The output is always zero-centred. A (near-)constant image, with std
    <= ``eps``, maps to all zeros instead of being returned unchanged or
    divided by a round-off-sized std.

    Args:
        image: Input image
        eps: Standard deviations at or below this are treated as zero

    Returns:
        Normalized image, mean, std
    """
    mean = np.mean(image)
    std = np.std(image)

    centered = image - mean
    if std <= eps:
        # For a constant float image np.std can be ~1e-17 rather than 0, and
        # dividing by it would amplify round-off into values like +/-1.
        return np.zeros_like(centered), float(mean), float(std)

    normalized = centered / std
    return normalized, float(mean), float(std)


def min_max_normalize(image: np.ndarray) -> Tuple[np.ndarray, float, float]:
    """
    Min-max normalization: (pixel - min) / (max - min)

    Args:
        image: Input image

    Returns:
        Normalized image (0-1 range), min, max. A constant image has no
        dynamic range and maps to all zeros, so the output always stays in [0, 1].
    """
    min_val = np.min(image)
    max_val = np.max(image)

    if max_val == min_val:
        # Returning the input unchanged would violate the [0, 1] contract
        # (e.g. a constant image of 200). Keep float inputs' dtype; integer
        # inputs become float64, like the division below produces.
        out_dtype = image.dtype if np.issubdtype(image.dtype, np.floating) else np.float64
        return np.zeros(image.shape, dtype=out_dtype), min_val, max_val

    normalized = (image - min_val) / (max_val - min_val)
    return normalized, min_val, max_val


def histogram_equalize(image: np.ndarray) -> np.ndarray:
    """
    Histogram equalization for contrast enhancement.

    The image is equalized as 8-bit data. Inputs whose maximum is <= 1.0 are
    treated as [0, 1] data, scaled by 255 and scaled back afterwards.

    Args:
        image: Input image, either in [0, 1] or in 0-255

    Returns:
        Equalized float32 image on the same scale as the input
    """
    unit_range = image.max() <= 1.0
    scaled = image * 255 if unit_range else image

    # Clip before casting so out-of-range values do not wrap around in uint8.
    image_uint8 = np.clip(scaled, 0, 255).astype(np.uint8)

    equalized = cv2.equalizeHist(image_uint8)

    # Convert back to float on the caller's scale.
    return equalized.astype(np.float32) / 255.0 if unit_range else equalized.astype(np.float32)


def clahe_normalize(
    image: np.ndarray,
    clip_limit: float = 2.0,
    tile_grid_size: Tuple[int, int] = (8, 8)
) -> np.ndarray:
    """
    Contrast Limited Adaptive Histogram Equalization (CLAHE).

    The image is processed as 8-bit data. Inputs whose maximum is <= 1.0 are
    treated as [0, 1] data, scaled by 255 and scaled back afterwards.

    Args:
        image: Input image, either in [0, 1] or in 0-255
        clip_limit: Threshold for contrast limiting
        tile_grid_size: Size of grid for histogram equalization

    Returns:
        CLAHE-normalized float32 image on the same scale as the input
    """
    unit_range = image.max() <= 1.0
    scaled = image * 255 if unit_range else image

    # Clip before casting so out-of-range values do not wrap around in uint8.
    image_uint8 = np.clip(scaled, 0, 255).astype(np.uint8)

    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    normalized = clahe.apply(image_uint8)

    # Convert back to float on the caller's scale.
    return normalized.astype(np.float32) / 255.0 if unit_range else normalized.astype(np.float32)


def apply_normalization(
    image: np.ndarray,
    method: str = "z_score",
    **kwargs
) -> np.ndarray:
    """
    Dispatch to one of the normalization functions by name.

    For 'z_score' and 'min_max' only the normalized image is returned; the
    statistics those functions also compute are dropped.

    Args:
        image: Input image
        method: One of 'z_score', 'min_max', 'histogram_eq', 'clahe'
        **kwargs: Forwarded unchanged to the selected function

    Returns:
        Normalized image

    Raises:
        ValueError: If ``method`` is not a known normalization method
    """
    registry = dict(
        z_score=lambda img: z_score_normalize(img)[0],
        min_max=lambda img: min_max_normalize(img)[0],
        histogram_eq=histogram_equalize,
        clahe=clahe_normalize,
    )

    normalizer = registry.get(method)
    if normalizer is None:
        raise ValueError(f"Unknown normalization method: {method}. Choose from {list(registry.keys())}")

    return normalizer(image, **kwargs)



In [None]:
"""
Loss functions for brain tumor segmentation.
"""
class DiceLoss(nn.Module):
    """
    Dice Loss for binary segmentation.

    Directly optimizes the Dice Similarity Coefficient. All elements of the
    batch are pooled into one Dice score (not averaged per sample).
    """

    def __init__(self, smooth: float = 1e-6):
        """
        Initialize Dice Loss.

        Args:
            smooth: Smoothing factor to avoid division by zero
        """
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute Dice Loss.

        Args:
            predictions: Predicted masks (B, 1, H, W) with values in [0, 1]
            targets: Ground truth masks (B, 1, H, W) with values in [0, 1]

        Returns:
            Scalar tensor ``1 - dice``
        """
        # reshape (unlike view) also accepts non-contiguous tensors, e.g.
        # permuted, transposed or sliced predictions.
        predictions = predictions.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (predictions * targets).sum()
        dice = (2.0 * intersection + self.smooth) / (
            predictions.sum() + targets.sum() + self.smooth
        )

        # Return 1 - dice (loss to minimize)
        return 1 - dice


class CombinedLoss(nn.Module):
    """
    Weighted sum of Dice Loss and Binary Cross-Entropy Loss.

    Mixing the region-level Dice term with the pixel-level BCE term often
    makes training more stable than either loss alone.
    """

    def __init__(self, dice_weight: float = 0.5, bce_weight: float = 0.5, smooth: float = 1e-6):
        """
        Build the two component losses.

        Args:
            dice_weight: Multiplier applied to the Dice term
            bce_weight: Multiplier applied to the BCE term
            smooth: Smoothing factor forwarded to DiceLoss
        """
        super().__init__()
        self.dice_weight = dice_weight
        self.bce_weight = bce_weight
        self.dice_loss = DiceLoss(smooth=smooth)
        self.bce_loss = nn.BCELoss()

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Return ``dice_weight * dice + bce_weight * bce``.

        Args:
            predictions: Predicted probabilities (BCELoss expects values in [0, 1])
            targets: Ground truth masks

        Returns:
            Scalar loss tensor
        """
        dice_term = self.dice_loss(predictions, targets)
        bce_term = self.bce_loss(predictions, targets)
        return self.dice_weight * dice_term + self.bce_weight * bce_term


class FocalLoss(nn.Module):
    """
    Binary Focal Loss: BCE down-weighted on well-classified pixels.

    Useful when tumor regions are small compared to background, because the
    many easy background pixels contribute little to the loss.
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        """
        Store the loss hyper-parameters.

        Args:
            alpha: Weight given to positive (target == 1) pixels; negatives get 1 - alpha
            gamma: Focusing exponent applied to (1 - p_t)
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the mean focal loss over all elements.

        Args:
            predictions: Predicted probabilities in [0, 1]
            targets: Ground truth masks

        Returns:
            Scalar loss tensor
        """
        ce = F.binary_cross_entropy(predictions, targets, reduction='none')

        positive = targets
        negative = 1 - targets

        # Probability assigned to the true class.
        p_t = predictions * positive + (1 - predictions) * negative
        modulator = (1 - p_t) ** self.gamma
        class_weight = self.alpha * positive + (1 - self.alpha) * negative

        return (class_weight * modulator * ce).mean()



In [None]:
"""
Training utilities for brain tumor segmentation models.
"""
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
import time


class Trainer:
    """Epoch-based trainer for segmentation models.

    Each epoch trains on ``train_loader`` and evaluates on ``val_loader`` via
    ``evaluate_model``. It then steps a ReduceLROnPlateau scheduler on
    validation Dice, saves checkpoints and applies early stopping. Progress is
    printed and logged to TensorBoard.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        criterion: nn.Module,
        optimizer: optim.Optimizer,
        device: torch.device,
        config,
        model_name: str = "model"
    ):
        """
        Initialize trainer.

        Args:
            model: Model to train
            train_loader: Training data loader (batches are dicts with 'image' and 'mask')
            val_loader: Validation data loader
            criterion: Loss function
            optimizer: Optimizer
            device: Device to train on
            config: Configuration object. This class reads REDUCE_LR_FACTOR,
                REDUCE_LR_PATIENCE, MIN_LEARNING_RATE, USE_MIXED_PRECISION,
                LOGS_DIR, EARLY_STOPPING_PATIENCE, LOG_INTERVAL, SAVE_INTERVAL
                and SAVED_MODELS_DIR from it.
            model_name: Name for saving checkpoints and the TensorBoard run
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.config = config
        self.model_name = model_name

        # Maximise validation Dice. `verbose` is omitted because it is
        # deprecated in recent PyTorch; the LR is printed every epoch instead.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='max',
            factor=config.REDUCE_LR_FACTOR,
            patience=config.REDUCE_LR_PATIENCE,
            min_lr=config.MIN_LEARNING_RATE
        )

        # torch.cuda.amp autocast/GradScaler only work on CUDA devices.
        # hasattr(torch.cuda, 'amp') was always true, even on CPU.
        self.use_amp = bool(config.USE_MIXED_PRECISION) and torch.device(device).type == 'cuda'
        if self.use_amp:
            self.scaler = torch.cuda.amp.GradScaler()

        # TensorBoard writer
        log_dir = config.LOGS_DIR / model_name
        self.writer = SummaryWriter(log_dir=str(log_dir))

        # Training state
        self.current_epoch = 0
        self.best_val_dice = 0.0
        self.train_losses = []
        self.val_dice_scores = []

        # Early stopping
        self.patience_counter = 0
        self.early_stopping_patience = config.EARLY_STOPPING_PATIENCE

    def train_epoch(self) -> float:
        """Train for one epoch and return the mean batch loss.

        Raises:
            RuntimeError: If the training loader yields no batches
        """
        self.model.train()
        running_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(self.train_loader):
            images = batch['image'].to(self.device)
            masks = batch['mask'].to(self.device)

            # Zero gradients
            self.optimizer.zero_grad()

            if self.use_amp:
                with torch.cuda.amp.autocast():
                    predictions = self.model(images)
                # Compute the loss in float32 outside autocast: BCE-based losses
                # (BCELoss / binary_cross_entropy) raise inside a CUDA autocast region.
                loss = self.criterion(predictions.float(), masks.float())

                # Backward pass
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                predictions = self.model(images)
                loss = self.criterion(predictions, masks)

                # Backward pass
                loss.backward()
                self.optimizer.step()

            # One device sync per batch instead of three.
            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches += 1

            # Logging
            if batch_idx % self.config.LOG_INTERVAL == 0:
                print(f'Epoch {self.current_epoch}, Batch {batch_idx}/{len(self.train_loader)}, '
                      f'Loss: {batch_loss:.4f}')

                # Log to TensorBoard
                global_step = self.current_epoch * len(self.train_loader) + batch_idx
                self.writer.add_scalar('Train/Loss', batch_loss, global_step)

        if num_batches == 0:
            # E.g. drop_last=True with fewer samples than batch_size.
            raise RuntimeError(
                "Training loader yielded no batches; check dataset size, batch_size and drop_last."
            )

        return running_loss / num_batches

    def validate(self) -> Dict[str, float]:
        """Evaluate the model on the validation loader via ``evaluate_model``."""
        self.model.eval()

        # Evaluate on validation set
        metrics = evaluate_model(self.model, self.val_loader, self.device)

        return metrics

    def save_checkpoint(self, is_best: bool = False):
        """Save ``<model_name>_latest.pth``, and also ``<model_name>_best.pth`` when ``is_best``."""
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            # Scheduler and early-stopping state make resumed runs continue
            # exactly where they left off.
            'scheduler_state_dict': self.scheduler.state_dict(),
            'patience_counter': self.patience_counter,
            # Plain Python floats keep the checkpoint loadable with
            # torch.load(weights_only=True); metrics may be NumPy/tensor scalars.
            'best_val_dice': float(self.best_val_dice),
            'train_losses': [float(x) for x in self.train_losses],
            'val_dice_scores': [float(x) for x in self.val_dice_scores]
        }

        if self.use_amp:
            checkpoint['scaler_state_dict'] = self.scaler.state_dict()

        save_dir = Path(self.config.SAVED_MODELS_DIR)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Save latest checkpoint
        torch.save(checkpoint, save_dir / f"{self.model_name}_latest.pth")

        # Save best checkpoint
        if is_best:
            torch.save(checkpoint, save_dir / f"{self.model_name}_best.pth")
            print(f"Saved best model with DSC: {self.best_val_dice:.4f}")

    def load_checkpoint(self, checkpoint_path: Path):
        """Restore model, optimizer and training state from a checkpoint.

        Checkpoints without scheduler or patience state (older format) still load.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.current_epoch = checkpoint['epoch']
        self.best_val_dice = checkpoint['best_val_dice']
        self.patience_counter = checkpoint.get('patience_counter', 0)
        self.train_losses = checkpoint.get('train_losses', [])
        self.val_dice_scores = checkpoint.get('val_dice_scores', [])

        if self.use_amp and 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        print(f"Loaded checkpoint from epoch {self.current_epoch}")

    def train(self, num_epochs: int, resume_from: Optional[Path] = None):
        """
        Main training loop.

        Args:
            num_epochs: Total number of epochs (epoch indices run up to num_epochs - 1)
            resume_from: Path to checkpoint to resume from
        """
        start_epoch = self.current_epoch
        if resume_from:
            self.load_checkpoint(resume_from)
            # Checkpoints store the last *completed* epoch. Continue with the
            # next one instead of re-training it and appending duplicate history.
            start_epoch = self.current_epoch + 1

        print(f"Starting training for {num_epochs} epochs...")
        print(f"Device: {self.device}")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

        start_time = time.time()

        for epoch in range(start_epoch, num_epochs):
            self.current_epoch = epoch

            # Train
            train_loss = self.train_epoch()
            self.train_losses.append(train_loss)

            # Validate. Metrics may be NumPy/tensor scalars; keep plain floats.
            val_metrics = self.validate()
            val_dice = float(val_metrics['dice'])
            val_iou = float(val_metrics['iou'])
            self.val_dice_scores.append(val_dice)

            # Update learning rate
            self.scheduler.step(val_dice)

            # Check if best model
            is_best = val_dice > self.best_val_dice
            if is_best:
                self.best_val_dice = val_dice
                self.patience_counter = 0
            else:
                self.patience_counter += 1

            # Save checkpoint
            if (epoch + 1) % self.config.SAVE_INTERVAL == 0 or is_best:
                self.save_checkpoint(is_best=is_best)

            current_lr = self.optimizer.param_groups[0]['lr']

            # Log epoch summary
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val DSC: {val_dice:.4f}")
            print(f"Val IoU: {val_iou:.4f}")
            print(f"Best Val DSC: {self.best_val_dice:.4f}")
            print(f"LR: {current_lr:.6f}")

            # Log to TensorBoard
            self.writer.add_scalar('Epoch/Train_Loss', train_loss, epoch)
            self.writer.add_scalar('Epoch/Val_Dice', val_dice, epoch)
            self.writer.add_scalar('Epoch/Val_IoU', val_iou, epoch)
            self.writer.add_scalar('Epoch/Learning_Rate', current_lr, epoch)

            # Early stopping
            if self.patience_counter >= self.early_stopping_patience:
                print(f"\nEarly stopping triggered after {epoch + 1} epochs")
                break

        training_time = time.time() - start_time
        print(f"\nTraining completed in {training_time/60:.2f} minutes")
        print(f"Best validation DSC: {self.best_val_dice:.4f}")

        # Save final checkpoint
        self.save_checkpoint(is_best=False)
        self.writer.close()



In [None]:
"""
Pre-trained model architectures for brain tumor segmentation.
Uses transfer learning with ResNet, VGG, or U-Net backbones.
"""
class ResNetUNet(nn.Module):
    """
    U-Net with ResNet encoder for brain tumor segmentation.
    Uses pre-trained ResNet as encoder backbone.

    Every decoder stage after the first consumes the concatenation of the
    resized previous decoder output and the matching encoder feature map, so
    its input width is twice the skip connection's channel count.

    NOTE(review): the ResNet stem (conv1) is used unmodified, so inputs
    presumably need the channel count the torchvision backbone was built for
    (3) — confirm against the preprocessing pipeline.
    """

    def __init__(
        self,
        encoder_name: str = "resnet50",
        pretrained: bool = True,
        num_classes: int = 1,
        dropout: float = 0.2
    ):
        """
        Initialize ResNet-U-Net.

        Args:
            encoder_name: ResNet variant ('resnet18', 'resnet34', 'resnet50', 'resnet101')
            pretrained: Use ImageNet pre-trained weights
            num_classes: Number of output classes (1 for binary segmentation)
            dropout: Dropout rate

        Raises:
            ValueError: If ``encoder_name`` is not one of the supported variants.
        """
        super(ResNetUNet, self).__init__()

        # Load pre-trained ResNet
        resnet_models = {
            'resnet18': models.resnet18,
            'resnet34': models.resnet34,
            'resnet50': models.resnet50,
            'resnet101': models.resnet101
        }

        if encoder_name not in resnet_models:
            raise ValueError(f"Unknown encoder: {encoder_name}")

        encoder = resnet_models[encoder_name](pretrained=pretrained)

        # Encoder stages; their outputs are e0..e4 in forward()
        self.encoder0 = nn.Sequential(
            encoder.conv1,
            encoder.bn1,
            encoder.relu
        )
        self.encoder1 = nn.Sequential(
            encoder.maxpool,
            encoder.layer1
        )
        self.encoder2 = encoder.layer2
        self.encoder3 = encoder.layer3
        self.encoder4 = encoder.layer4

        # Output channel counts of e0..e4
        if encoder_name in ['resnet18', 'resnet34']:
            channels = [64, 64, 128, 256, 512]
        else:  # resnet50, resnet101
            channels = [64, 256, 512, 1024, 2048]

        # Decoder. decoder4 sees only e4; every later block sees
        # cat([decoder output, skip]) whose width is 2 * skip channels.
        self.decoder4 = self._make_decoder_block(channels[4], channels[3], dropout)
        self.decoder3 = self._make_decoder_block(2 * channels[3], channels[2], dropout)
        self.decoder2 = self._make_decoder_block(2 * channels[2], channels[1], dropout)
        self.decoder1 = self._make_decoder_block(2 * channels[1], channels[0], dropout)
        self.decoder0 = self._make_decoder_block(2 * channels[0], 64, dropout)

        # Output head: per-pixel sigmoid probabilities
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(32, num_classes, kernel_size=1),
            nn.Sigmoid()
        )

    def _make_decoder_block(self, in_channels: int, out_channels: int, dropout: float):
        """Conv-BN-ReLU-Dropout followed by 2x bilinear upsampling (skip concat happens in forward)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )

    @staticmethod
    def _resize_and_concat(decoded, skip):
        """Resize ``decoded`` to ``skip``'s spatial size and concatenate both along channels."""
        decoded = F.interpolate(decoded, size=skip.shape[2:], mode='bilinear', align_corners=True)
        return torch.cat([decoded, skip], dim=1)

    def forward(self, x):
        """
        Args:
            x: Input batch; indexing ``shape[2:]`` assumes (N, C, H, W).

        Returns:
            Sigmoid probabilities with ``num_classes`` channels at x's spatial size.
        """
        # Encoder
        e0 = self.encoder0(x)   # 64 channels
        e1 = self.encoder1(e0)  # 64/256 channels
        e2 = self.encoder2(e1)  # 128/512 channels
        e3 = self.encoder3(e2)  # 256/1024 channels
        e4 = self.encoder4(e3)  # 512/2048 channels

        # Decoder with skip connections
        d4 = self._resize_and_concat(self.decoder4(e4), e3)
        d3 = self._resize_and_concat(self.decoder3(d4), e2)
        d2 = self._resize_and_concat(self.decoder2(d3), e1)
        d1 = self._resize_and_concat(self.decoder1(d2), e0)
        d0 = self.decoder0(d1)

        # For input sizes the strided encoder cannot halve exactly, the final
        # 2x upsample may not land on the input size; align it with the mask.
        if d0.shape[2:] != x.shape[2:]:
            d0 = F.interpolate(d0, size=x.shape[2:], mode='bilinear', align_corners=True)

        # Output
        return self.final_conv(d0)


class VGGUNet(nn.Module):
    """
    U-Net with VGG encoder for brain tumor segmentation.
    Uses pre-trained VGG as encoder backbone.

    The VGG feature extractor is split into five stages at its max-pool
    layers. Each decoder stage after the first consumes the concatenation of
    the resized previous decoder output and the matching encoder stage, so its
    input width is twice the skip connection's channel count.

    NOTE(review): the first VGG conv is used unmodified, so inputs presumably
    need the channel count the torchvision backbone was built for (3).
    """

    def __init__(
        self,
        encoder_name: str = "vgg16",
        pretrained: bool = True,
        num_classes: int = 1,
        dropout: float = 0.2
    ):
        """
        Initialize VGG-U-Net.

        Args:
            encoder_name: VGG variant ('vgg11', 'vgg13', 'vgg16', 'vgg19')
            pretrained: Use ImageNet pre-trained weights
            num_classes: Number of output classes
            dropout: Dropout rate

        Raises:
            ValueError: If ``encoder_name`` is unsupported or its feature
                extractor does not contain five pooled stages.
        """
        super(VGGUNet, self).__init__()

        # Load pre-trained VGG
        vgg_models = {
            'vgg11': models.vgg11,
            'vgg13': models.vgg13,
            'vgg16': models.vgg16,
            'vgg19': models.vgg19
        }

        if encoder_name not in vgg_models:
            raise ValueError(f"Unknown encoder: {encoder_name}")

        vgg = vgg_models[encoder_name](pretrained=pretrained)

        # Encoder blocks, split at MaxPool2d layers so that every VGG depth
        # works (for vgg16 this equals features[0:4], [4:9], [9:16], [16:23], [23:30]).
        stages = self._split_vgg_stages(vgg.features)
        self.encoder1 = stages[0]  # 64 channels
        self.encoder2 = stages[1]  # 128 channels
        self.encoder3 = stages[2]  # 256 channels
        self.encoder4 = stages[3]  # 512 channels
        self.encoder5 = stages[4]  # 512 channels

        # Decoder. decoder5 sees only e5; later blocks see cat([decoder output, skip]).
        # e1 is not pooled, so decoder1 must not upsample or the output would be 2x the input.
        self.decoder5 = self._make_decoder_block(512, 512, dropout)
        self.decoder4 = self._make_decoder_block(2 * 512, 256, dropout)
        self.decoder3 = self._make_decoder_block(2 * 256, 128, dropout)
        self.decoder2 = self._make_decoder_block(2 * 128, 64, dropout)
        self.decoder1 = self._make_decoder_block(2 * 64, 64, dropout, upsample=False)

        # Output head: per-pixel sigmoid probabilities
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
            nn.Conv2d(32, num_classes, kernel_size=1),
            nn.Sigmoid()
        )

    @staticmethod
    def _split_vgg_stages(features: nn.Module) -> List[nn.Sequential]:
        """Split a VGG ``features`` stack into its first five stages; each max-pool opens a new stage."""
        groups = [[]]
        for layer in features.children():
            if isinstance(layer, nn.MaxPool2d):
                groups.append([])
            groups[-1].append(layer)
        # The group after the last conv stage holds only the final max-pool and is unused.
        if len(groups) < 5:
            raise ValueError(f"Expected at least 5 VGG stages, found {len(groups)}")
        return [nn.Sequential(*group) for group in groups[:5]]

    def _make_decoder_block(self, in_channels: int, out_channels: int, dropout: float,
                            upsample: bool = True):
        """Conv-BN-ReLU-Dropout, optionally followed by 2x bilinear upsampling."""
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(dropout),
        ]
        if upsample:
            layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
        return nn.Sequential(*layers)

    @staticmethod
    def _resize_and_concat(decoded, skip):
        """Resize ``decoded`` to ``skip``'s spatial size and concatenate both along channels."""
        decoded = F.interpolate(decoded, size=skip.shape[2:], mode='bilinear', align_corners=True)
        return torch.cat([decoded, skip], dim=1)

    def forward(self, x):
        """
        Args:
            x: Input batch; indexing ``shape[2:]`` assumes (N, C, H, W).

        Returns:
            Sigmoid probabilities with ``num_classes`` channels at e1's
            (i.e. the unpooled first stage's) spatial size.
        """
        # Encoder
        e1 = self.encoder1(x)   # 64
        e2 = self.encoder2(e1)  # 128
        e3 = self.encoder3(e2)  # 256
        e4 = self.encoder4(e3)  # 512
        e5 = self.encoder5(e4)  # 512

        # Decoder with skip connections
        d5 = self._resize_and_concat(self.decoder5(e5), e4)
        d4 = self._resize_and_concat(self.decoder4(d5), e3)
        d3 = self._resize_and_concat(self.decoder3(d4), e2)
        d2 = self._resize_and_concat(self.decoder2(d3), e1)
        d1 = self.decoder1(d2)

        # Output
        return self.final_conv(d1)


def get_pretrained_model(
    model_type: str = "resnet",
    encoder_name: str = "resnet50",
    pretrained: bool = True,
    num_classes: int = 1,
    dropout: float = 0.2
):
    """
    Build a U-Net whose encoder is a pre-trained ResNet or VGG.

    Args:
        model_type: Encoder family, matched case-insensitively ('resnet' or 'vgg').
        encoder_name: Concrete encoder variant forwarded to the model class
            (e.g. 'resnet50', 'vgg16').
        pretrained: Whether the encoder loads pre-trained weights.
        num_classes: Number of output channels of the segmentation head.
        dropout: Dropout rate used in the decoder and head.

    Returns:
        A ``ResNetUNet`` or ``VGGUNet`` instance.

    Raises:
        ValueError: If ``model_type`` names neither supported family.
    """
    family = model_type.lower()

    if family == "resnet":
        model_cls = ResNetUNet
    elif family == "vgg":
        model_cls = VGGUNet
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    return model_cls(encoder_name, pretrained, num_classes, dropout)



In [None]:
"""
Training script for pre-trained brain tumor segmentation model.
"""



In [None]:

def main():
    """
    Main training function.

    Seeds the RNGs, summarises the HDF5 dataset, creates (or reloads) the
    patient splits, builds preprocessing/augmentation transforms and data
    loaders, builds the pre-trained U-Net selected in ``Config`` and trains it
    with a Dice loss and an Adam optimizer that gives the pre-trained encoder
    a 10x lower learning rate than the decoder.
    """
    # Initialize configuration
    config = Config()
    config.create_directories()

    # Set random seeds for reproducibility (np and random come from the
    # notebook's import cell).
    torch.manual_seed(config.RANDOM_SEED)
    torch.cuda.manual_seed_all(config.RANDOM_SEED)
    np.random.seed(config.RANDOM_SEED)
    random.seed(config.RANDOM_SEED)

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Explore dataset
    print("\n" + "="*60)
    print("Exploring Dataset")
    print("="*60)
    # Resolve the dataset location once so exploration, splitting and data
    # loading all read the same file in both local and Google Drive setups.
    dataset_path = (
        config.HDF5_DATASET_PATH
        if config.ENVIRONMENT == "local"
        else config.get_data_from_google_drive()
    )
    explorer = HDF5DatasetExplorer(dataset_path)
    explorer.print_summary()

    # Create or load data splits
    splits_dir = config.SPLITS_DIR
    if not (splits_dir / "train_ids.json").exists():
        print("\n" + "="*60)
        print("Creating Data Splits")
        print("="*60)
        splits = create_data_splits(
            dataset_path,
            splits_dir,
            train_ratio=config.TRAIN_SPLIT,
            val_ratio=config.VAL_SPLIT,
            test_ratio=config.TEST_SPLIT,
            random_seed=config.RANDOM_SEED,
            stratify=True
        )
    else:
        print("\n" + "="*60)
        print("Loading Existing Data Splits")
        print("="*60)
        splits = load_data_splits(splits_dir)
        for split_name, ids in splits.items():
            print(f"{split_name.upper()}: {len(ids)} patients")

    # Create transforms
    print("\n" + "="*60)
    print("Setting up Preprocessing and Augmentation")
    print("="*60)

    preprocessing = PreprocessingTransform(
        target_size=config.IMAGE_SIZE,
        denoising_method=config.DENOISING_METHOD,
        normalization_method=config.NORMALIZATION_METHOD
    )

    augmentation = AugmentationTransform(
        rotation_range=config.ROTATION_RANGE,
        translation_range=config.TRANSLATION_RANGE,
        scale_range=config.SCALE_RANGE,
        brightness_range=config.BRIGHTNESS_RANGE,
        contrast_range=config.CONTRAST_RANGE,
        flip_probability=config.FLIP_PROBABILITY
    ) if config.AUGMENTATION_ENABLED else None

    # Compose transforms: training always preprocesses, then optionally augments
    def train_transform(sample):
        sample = preprocessing(sample)
        if augmentation:
            sample = augmentation(sample)
        return sample

    # Create data loaders
    print("\n" + "="*60)
    print("Creating Data Loaders")
    print("="*60)
    dataloaders = create_dataloaders(
        dataset_path,
        splits,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        train_transform=train_transform,
        val_transform=preprocessing,
        return_label=False
    )

    print(f"Train batches: {len(dataloaders['train'])}")
    print(f"Val batches: {len(dataloaders['val'])}")
    print(f"Test batches: {len(dataloaders['test'])}")

    # Create model
    print("\n" + "="*60)
    print("Creating Pre-trained Model")
    print("="*60)

    # Determine model type and encoder
    requested_type = config.PRETRAINED_MODEL_TYPE.lower()
    if requested_type == "resnet":
        model_type = "resnet"
        encoder_name = config.PRETRAINED_ENCODER
    elif requested_type == "vgg":
        model_type = "vgg"
        # Honour a configured VGG variant; otherwise default to vgg16.
        configured_encoder = str(config.PRETRAINED_ENCODER).lower()
        encoder_name = configured_encoder if configured_encoder.startswith("vgg") else "vgg16"
    else:
        print(f"WARNING: unknown PRETRAINED_MODEL_TYPE {config.PRETRAINED_MODEL_TYPE!r}; "
              f"falling back to resnet50")
        model_type = "resnet"
        encoder_name = "resnet50"

    model = get_pretrained_model(
        model_type=model_type,
        encoder_name=encoder_name,
        pretrained=True,
        num_classes=1,
        dropout=config.DROPOUT_RATE
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Loss function
    criterion = DiceLoss(smooth=1e-6)
    # Alternative: CombinedLoss(dice_weight=0.5, bce_weight=0.5)

    # Optimizer with differential learning rates:
    # lower LR for the pre-trained encoder, higher for the newly initialised decoder/head.
    # Encoder submodules in both model classes are attributes named 'encoder*'.
    encoder_params = []
    decoder_params = []
    for name, param in model.named_parameters():
        if name.startswith('encoder'):
            encoder_params.append(param)
        else:
            decoder_params.append(param)

    # torch.optim is referenced explicitly: the notebook's import cell defines no `optim` alias.
    optimizer = torch.optim.Adam([
        {'params': encoder_params, 'lr': config.LEARNING_RATE * 0.1},  # 10x lower LR for encoder
        {'params': decoder_params, 'lr': config.LEARNING_RATE}
    ], weight_decay=config.WEIGHT_DECAY)

    # Create trainer
    trainer = Trainer(
        model=model,
        train_loader=dataloaders['train'],
        val_loader=dataloaders['val'],
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        config=config,
        model_name=f"pretrained_{model_type}_{encoder_name}"
    )

    # Train
    print("\n" + "="*60)
    print("Starting Training")
    print("="*60)
    trainer.train(num_epochs=config.NUM_EPOCHS)

    print("\n" + "="*60)
    print("Training Complete!")
    print("="*60)
    print(f"Best validation DSC: {trainer.best_val_dice:.4f}")
    print(f"Model saved to: {config.SAVED_MODELS_DIR}")



In [None]:

# Run the end-to-end pipeline: dataset exploration, splits, data loaders, model construction and training.
main()
