In [1]:
# Cell 1: Install Dependencies
# ==================================================================================
!pip install -q monai
!pip install -q segmentation_models_pytorch
print("‚úÖ Libraries Installed: MONAI, SMP, WandB")

‚úÖ Libraries Installed: MONAI, SMP, WandB


In [2]:
# Cell 2: Configuration & Paths
# ==================================================================================
import os
import torch

CONFIG = {
    # --- 1. GEOMETRY (The Aspect Ratio Fix) ---
    # We pad 208 -> 240, then Resize -> 224.
    "IMG_SIZE": (224, 224),
    "NATIVE_PAD_SIZE": (240, 240), 
    
    # --- 2. DATA MIXING (The Anchor Strategy) ---
    # We want an effective batch size of 16.
    # Since MaxViT is heavy, we might use Grad Accumulation if T4 runs OOM.
    "REAL_BS": 16,       # Real images per step
    "PSEUDO_BS": 16,     # Pseudo images per step
    "GRAD_ACCUM_STEPS": 1, # Increase to 2 if you hit OOM (Effective BS becomes 32)
    
    # --- 3. TRAINING HYPERPARAMETERS ---
    "EPOCHS": 20,
    "LR": 1e-4,
    "WEIGHT_DECAY": 1e-2, # Stronger decay for Transformers
    "SEED": 42,
    "PRECISION": "amp",   # Automatic Mixed Precision
    
    # --- 4. PATHS (Mapped from your Screenshot) ---
    # Dataset 1: 200 Gold Standard
    "GOLD_IMG_DIR": "/kaggle/input/200-gold-standard-adni/200_AD_CN_MCI_11112025/images",
    "GOLD_MASK_DIR": "/kaggle/input/200-gold-standard-adni/200_AD_CN_MCI_11112025/masks",
    
    # Dataset 2: Gold Metadata
    "GOLD_CSV": "/kaggle/input/metadatafor200gd/metadata.csv",
    
    # Dataset 3: Titanium (Pseudo)
    "TITANIUM_IMG_DIR": "/kaggle/input/tittanium-standard-dataset/TITANIUM_20K_DATASET/images",
    "TITANIUM_MASK_DIR": "/kaggle/input/tittanium-standard-dataset/TITANIUM_20K_DATASET/masks",
    "TITANIUM_CSV": "/kaggle/input/tittanium-standard-dataset/TITANIUM_20K_DATASET/metadata.csv",
    
    # --- 5. LOGGING ---
    "PROJECT_NAME": "Brain_SOTA_MaxViT_AnchorUT",
    "ENTITY": "alzhemer_segmentaion", # Update if needed
    "FOLDS_TO_RUN": [0], # Run all 5 folds for scientific validity
    "CACHE_RATE": 1.0, # Cache 100% of Real data in RAM (Kaggle has plenty)
}

# Device Check
CONFIG['DEVICE'] = "cuda" if torch.cuda.is_available() else "cpu"
print(f"‚úÖ Configuration Loaded.")
print(f"   Target Resolution: {CONFIG['IMG_SIZE']}")
print(f"   Batch Composition: {CONFIG['REAL_BS']} Real + {CONFIG['PSEUDO_BS']} Pseudo")
print(f"   Path Check (Gold): {os.path.exists(CONFIG['GOLD_IMG_DIR'])}")
print(f"   Path Check (Titanium): {os.path.exists(CONFIG['TITANIUM_IMG_DIR'])}")

‚úÖ Configuration Loaded.
   Target Resolution: (224, 224)
   Batch Composition: 16 Real + 16 Pseudo
   Path Check (Gold): True
   Path Check (Titanium): True


In [3]:
# Cell 3: Imports & Reproducibility
# ==================================================================================
import gc
import sys
import random
import numpy as np
import pandas as pd
import cv2
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader

# MONAI (The Medical Engine)
import monai
from monai.data import Dataset, CacheDataset, DataLoader as MonaiLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd, 
    ResizeWithPadOrCropd, Resized, RandFlipd, RandRotate90d, 
    RandShiftIntensityd, RandCoarseDropoutd, EnsureTyped, NormalizeIntensityd
)
from monai.utils import set_determinism

# Model Library
import segmentation_models_pytorch as smp

# Logging
import wandb
from kaggle_secrets import UserSecretsClient

# 1. Set Determinism (Reproducibility)
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False # Slower but reproducible
    set_determinism(seed=seed) # MONAI specific seed

seed_everything(CONFIG['SEED'])

# 2. Login to WandB
try:
    user_secrets = UserSecretsClient()
    wandb_key = user_secrets.get_secret("wandb_api")
    wandb.login(key=wandb_key)
    print("‚úÖ WandB Logged In")
except:
    print("‚ö†Ô∏è WandB Login Failed (Check Kaggle Secrets)")
    wandb.login(anonymous='must')

2025-11-20 22:30:16.467809: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763677816.490662     360 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763677816.497620     360 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mutkarsh3104-imp[0m ([33malzhemer_segmentaion[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


‚úÖ WandB Logged In


In [4]:
# Cell 4: MONAI Transforms
# ==================================================================================
def get_transforms(phase):
    """
    Returns MONAI Compose transforms.
    keys=["image", "mask"] tells MONAI to apply geometry changes to BOTH,
    but intensity changes ONLY to the image.
    """
    
    # 1. COMMON TRANSFORMS (Geometry Fix)
    # We Pad 208 -> 240 (Square), then Downsample -> 224.
    # This preserves the exact shape of the ventricles.
    common_transforms = [
        LoadImaged(keys=["image", "mask"]),
        EnsureChannelFirstd(keys=["image", "mask"]),
        
        # --- THE GEOMETRY FIX ---
        # Center the brain and pad with zeros to 240x240
        ResizeWithPadOrCropd(keys=["image", "mask"], spatial_size=CONFIG['NATIVE_PAD_SIZE']),
        # Resize to MaxViT input (224x224)
        Resized(keys=["image", "mask"], spatial_size=CONFIG['IMG_SIZE'], mode=("bilinear", "nearest")),
        
        # Normalize Intensity (Crucial for MRI)
        ScaleIntensityd(keys=["image"]), 
    ]

    # 2. TRAINING TRANSFORMS (Augmentation)
    if phase == 'train':
        train_transforms = [
            # Geometry Augs
            RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=0),
            RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=1),
            RandRotate90d(keys=["image", "mask"], prob=0.5, max_k=3),
            
            # Intensity Augs (Scanner Variance)
            RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
            
            # --- THE BIAS FIX (Spatial Dropout) ---
            # Cut 8 holes of size 16x16. 
            # This stops MaxViT from memorizing the skull/background.
            RandCoarseDropoutd(
                keys=["image", "mask"],
                holes=8, spatial_size=(20, 20),
                fill_value=0, prob=0.3
            ),
            
            EnsureTyped(keys=["image", "mask"])
        ]
        return Compose(common_transforms + train_transforms)

    # 3. VALIDATION TRANSFORMS (Clean)
    else:
        val_transforms = [
            EnsureTyped(keys=["image", "mask"])
        ]
        return Compose(common_transforms + val_transforms)

print("‚úÖ Transforms Defined using MONAI (Geometry-Safe)")

‚úÖ Transforms Defined using MONAI (Geometry-Safe)


In [5]:
# Cell 5: Custom Transforms & Data Helper (CORRECTED)
# ==================================================================================
from monai.transforms import MapTransform

class ProcessMaskd(MapTransform):
    """
    Custom MONAI Transform to handle Label Types.
    - If Real (Gold): Binarizes mask (0.0 or 1.0).
    - If Pseudo (Titanium): Normalizes mask (0.0 to 1.0) for Soft Labels.
    """
    def __init__(self, keys, is_pseudo=False):
        super().__init__(keys)
        self.is_pseudo = is_pseudo

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            # Pseudo: Keep uncertainty (0-255 -> 0.0-1.0)
            if self.is_pseudo:
                d[key] = d[key].float() / 255.0
            # Real: Hard Threshold (0-255 -> 0.0 or 1.0)
            else:
                d[key] = torch.where(d[key] > 127, 1.0, 0.0)
        return d

def get_data_dicts(gold_df, pseudo_df=None, fold=0, phase='train'):
    """
    Prepares list of dictionaries for MONAI Dataset.
    Handles None inputs safely.
    """
    data_dicts = []
    
    if phase == 'train':
        # 1. Process Gold Data (Only if provided)
        if gold_df is not None:
            for _, row in gold_df.iterrows():
                data_dicts.append({
                    "image": os.path.join(CONFIG['GOLD_IMG_DIR'], row['image_id']),
                    "mask": os.path.join(CONFIG['GOLD_MASK_DIR'], row['mask_id']),
                    "source": "gold"
                })
            
        # 2. Process Pseudo Data (Only if provided)
        if pseudo_df is not None:
            for _, row in pseudo_df.iterrows():
                data_dicts.append({
                    "image": os.path.join(CONFIG['TITANIUM_IMG_DIR'], row['image_id']),
                    "mask": os.path.join(CONFIG['TITANIUM_MASK_DIR'], row['mask_id']),
                    "source": "pseudo"
                })
    else:
        # Validation List (Always Gold)
        if gold_df is not None:
            for _, row in gold_df.iterrows():
                data_dicts.append({
                    "image": os.path.join(CONFIG['GOLD_IMG_DIR'], row['image_id']),
                    "mask": os.path.join(CONFIG['GOLD_MASK_DIR'], row['mask_id']),
                    "source": "gold"
                })
            
    return data_dicts

In [6]:
# Cell 6: Hybrid Loss & Metrics (UPDATED)
# ==================================================================================
from monai.losses import HausdorffDTLoss, DiceLoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric, MeanIoU

# Loss (Same as before)
class HybridLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss(sigmoid=True, batch=True)
        self.hausdorff = HausdorffDTLoss(sigmoid=True)

    def forward(self, pred, target):
        loss_bce = self.bce(pred, target)
        loss_dice = self.dice(pred, target)
        # HD is heavy, we weight it lower
        loss_hd = self.hausdorff(pred, target)
        return (0.4 * loss_bce) + (0.3 * loss_dice) + (0.3 * loss_hd)

# --- METRICS INITIALIZATION ---
# We need independent counters for Train and Val to avoid mixing data
dice_metric = DiceMetric(include_background=False, reduction="mean")
iou_metric = MeanIoU(include_background=False, reduction="mean")
hd_metric = HausdorffDistanceMetric(include_background=False, percentile=95, reduction="mean")

# Helper to calculate simple batch dice for training monitoring (faster than MONAI metric)
def get_batch_dice(y_pred, y_true):
    y_pred = (y_pred.sigmoid() > 0.5).float()
    # Threshold soft labels for metric calculation
    y_true = (y_true > 0.5).float()
    intersection = (y_pred * y_true).sum()
    union = y_pred.sum() + y_true.sum()
    return 2.0 * intersection / (union + 1e-6)

In [7]:
# Cell 7: Model Builder
# ==================================================================================
def build_model():
    model = smp.UnetPlusPlus(
        encoder_name="tu-maxvit_rmlp_small_rw_224", # MaxViT Tiny (224 Native)
        encoder_weights="imagenet",
        in_channels=1, # MONAI LoadImage usually keeps it 1 channel if greyscale
        classes=1,
        activation=None
    )
    return model

In [8]:
# Cell 8: Training Engine (FULL METRICS VERSION)
# ==================================================================================
from itertools import cycle

def train_one_epoch(model, loader_gold, loader_pseudo, optimizer, loss_fn, scaler, epoch):
    model.train()
    running_loss = 0
    running_train_dice = 0 # <--- NEW: Track Train Accuracy
    
    iterator = tqdm(zip(cycle(loader_gold), loader_pseudo), total=len(loader_pseudo), desc=f"Train Ep {epoch}", leave=False)
    
    optimizer.zero_grad()
    
    for step, (batch_gold, batch_pseudo) in enumerate(iterator):
        img_g, mask_g = batch_gold['image'].to(CONFIG['DEVICE']), batch_gold['mask'].to(CONFIG['DEVICE'])
        img_p, mask_p = batch_pseudo['image'].to(CONFIG['DEVICE']), batch_pseudo['mask'].to(CONFIG['DEVICE'])
        
        images = torch.cat([img_g, img_p], dim=0)
        masks = torch.cat([mask_g, mask_p], dim=0)
        
        with autocast():
            outputs = model(images)
            loss = loss_fn(outputs, masks)
            loss = loss / CONFIG['GRAD_ACCUM_STEPS']
            
        scaler.scale(loss).backward()
        
        if (step + 1) % CONFIG['GRAD_ACCUM_STEPS'] == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            
        # --- METRIC TRACKING ---
        current_loss = loss.item() * CONFIG['GRAD_ACCUM_STEPS']
        running_loss += current_loss
        
        # Calculate rough batch dice for monitoring (on GPU)
        batch_dice = get_batch_dice(outputs.detach(), masks)
        running_train_dice += batch_dice.item()
        
        # Detailed WandB Logging every 50 steps
        if step % 50 == 0:
            wandb.log({
                "train/step_loss": current_loss,
                "train/step_dice": batch_dice.item(),
                "train/lr": optimizer.param_groups[0]['lr']
            })
            
    epoch_loss = running_loss / len(loader_pseudo)
    epoch_dice = running_train_dice / len(loader_pseudo)
    return epoch_loss, epoch_dice

@torch.no_grad()
def valid_one_epoch(model, loader, epoch):
    model.eval()
    dice_metric.reset()
    hd_metric.reset()
    iou_metric.reset() # <--- NEW
    
    for batch in tqdm(loader, desc=f"Val Ep {epoch}", leave=False):
        images, masks = batch['image'].to(CONFIG['DEVICE']), batch['mask'].to(CONFIG['DEVICE'])
        
        with autocast():
            # TTA
            pred_1 = torch.sigmoid(model(images))
            pred_2 = torch.sigmoid(torch.flip(model(torch.flip(images, dims=[3])), dims=[3]))
            pred_avg = (pred_1 + pred_2) / 2.0
            
        pred_bin = (pred_avg > 0.5).float()
        
        # Update Metrics
        dice_metric(y_pred=pred_bin, y=masks)
        hd_metric(y_pred=pred_bin, y=masks)
        iou_metric(y_pred=pred_bin, y=masks) # <--- NEW
        
    mean_dice = dice_metric.aggregate().item()
    mean_hd = hd_metric.aggregate().item()
    mean_iou = iou_metric.aggregate().item() # <--- NEW
    
    # Log Sample to WandB
    wandb.log({
        "val/sample": wandb.Image(
            images[0].cpu().numpy().transpose(1, 2, 0),
            masks={
                "pred": {"mask_data": pred_bin[0,0].cpu().numpy(), "class_labels": {1: "Brain"}},
                "truth": {"mask_data": masks[0,0].cpu().numpy(), "class_labels": {1: "Brain"}}
            },
            caption=f"Epoch {epoch} Pred"
        )
    })
    
    return mean_dice, mean_hd, mean_iou

In [None]:
# Cell 9: K-Fold Execution (FINAL VERSION)
# ==================================================================================
from sklearn.model_selection import KFold
import json # Needed for report saving

# Load CSVs
gold_df = pd.read_csv(CONFIG['GOLD_CSV'])
pseudo_df = pd.read_csv(CONFIG['TITANIUM_CSV'])

# Define K-Fold
kf = KFold(n_splits=5, shuffle=True, random_state=CONFIG['SEED'])

for fold, (train_idx, val_idx) in enumerate(kf.split(gold_df)):
    if fold not in CONFIG['FOLDS_TO_RUN']:
        continue
        
    print(f"\nüöÄ STARTING FOLD {fold}")
    
    # 1. Init WandB
    run = wandb.init(project=CONFIG['PROJECT_NAME'], entity=CONFIG['ENTITY'], name=f"MaxViT_Fold_{fold}", config=CONFIG, reinit=True)
    
    # 2. Split Dataframes
    train_gold_fold = gold_df.iloc[train_idx]
    val_gold_fold = gold_df.iloc[val_idx]
    
    # 3. Prepare Dictionaries
    train_gold_dicts = get_data_dicts(train_gold_fold, phase='train')
    # Fix the None crash by handling pseudo safely in helper
    train_pseudo_dicts = get_data_dicts(None, pseudo_df, phase='train') 
    train_pseudo_only = [d for d in train_pseudo_dicts if d['source'] == 'pseudo']
    
    val_dicts = get_data_dicts(val_gold_fold, phase='valid')
    
    # 4. Create Datasets
    # CacheDataset for Real Data (Fast RAM access)
    ds_gold = CacheDataset(
        data=train_gold_dicts, 
        transform=Compose([get_transforms('train'), ProcessMaskd(keys=["mask"], is_pseudo=False)]),
        cache_rate=CONFIG['CACHE_RATE'], num_workers=4
    )
    
    # Standard Dataset for Pseudo (Stream from disk)
    ds_pseudo = Dataset(
        data=train_pseudo_only,
        transform=Compose([get_transforms('train'), ProcessMaskd(keys=["mask"], is_pseudo=True)])
    )
    
    ds_val = CacheDataset(
        data=val_dicts,
        transform=Compose([get_transforms('valid'), ProcessMaskd(keys=["mask"], is_pseudo=False)]),
        cache_rate=1.0
    )
    
    # 5. Create Loaders (SPEED OPTIMIZED)
    # Added persistent_workers=True to fix the "GPU Waiting" issue
    loader_gold = MonaiLoader(
        ds_gold, 
        batch_size=CONFIG['REAL_BS'], 
        shuffle=True, 
        num_workers=4, 
        persistent_workers=True, 
        drop_last=True
    )
    
    loader_pseudo = MonaiLoader(
        ds_pseudo, 
        batch_size=CONFIG['PSEUDO_BS'], 
        shuffle=True, 
        num_workers=4, 
        persistent_workers=True, 
        drop_last=True
    )
    
    loader_val = MonaiLoader(
        ds_val, 
        batch_size=24, # Increased for faster validation
        shuffle=False, 
        num_workers=4,
        persistent_workers=True
    )
    
    # 6. Setup Model & Opt
    model = build_model().to(CONFIG['DEVICE'])
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['LR'], weight_decay=CONFIG['WEIGHT_DECAY'])
    # Warmup Cosine Scheduler (Best for MaxViT)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=1, eta_min=1e-6)
    scaler = GradScaler()
    loss_fn = HybridLoss()
    
    # Tracking Variables
    best_dice = 0.0
    best_metrics = {} # To save for final report
    
    # History for Local Plots
    history = {
        'train_loss': [], 'train_dice': [], 
        'val_dice': [], 'val_iou': [], 'val_hd': []
    }
    
    # 7. Epoch Loop
    for epoch in range(CONFIG['EPOCHS']):
        # Train One Epoch
        train_loss, train_dice = train_one_epoch(model, loader_gold, loader_pseudo, optimizer, loss_fn, scaler, epoch)
        
        # Validate One Epoch (Returns 3 metrics now)
        val_dice, val_hd, val_iou = valid_one_epoch(model, loader_val, epoch)
        
        scheduler.step()
        
        # Update Local History
        history['train_loss'].append(train_loss)
        history['train_dice'].append(train_dice)
        history['val_dice'].append(val_dice)
        history['val_iou'].append(val_iou)
        history['val_hd'].append(val_hd)
        
        # Log to WandB
        wandb.log({
            "epoch": epoch,
            "train/loss": train_loss,
            "train/dice": train_dice,
            "val/dice": val_dice,
            "val/iou": val_iou,
            "val/hd": val_hd,
            "lr": optimizer.param_groups[0]['lr']
        })
    
        print(f"Ep {epoch} | Loss: {train_loss:.4f} | TrDice: {train_dice:.3f} | ValDice: {val_dice:.4f} | ValIoU: {val_iou:.4f}")
        
        # Save Best Model & Metrics
        if val_dice > best_dice:
            best_dice = val_dice
            # Snapshot metrics for report
            best_metrics = {
                "fold": fold,
                "best_dice": val_dice,
                "best_iou": val_iou,
                "best_hd": val_hd,
                "best_epoch": epoch
            }
            torch.save(model.state_dict(), f"MaxViT_Best_Fold_{fold}.pth")
            print(f"  ‚≠ê New Best Dice! Saved.")
            
    wandb.finish()
    
    # 8. Save Metrics to JSON (For Final Report Cell)
    with open(f"metrics_fold_{fold}.json", "w") as f:
        json.dump(best_metrics, f)
    print(f"‚úÖ Fold {fold} metrics saved to disk.")

    # 9. Generate Local Analytics Plot (Optional, if Cell 11 is defined)
    try:
        plot_training_analytics(history, fold)
    except NameError:
        print("‚ö†Ô∏è Analytics plot skipped (Cell 11 function not found).")


üöÄ STARTING FOLD 0




Loading dataset: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 160/160 [00:00<00:00, 231649.51it/s]
Loading dataset: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 40/40 [00:00<00:00, 104596.11it/s]
  scaler = GradScaler()


Train Ep 0:   0%|          | 0/776 [00:00<?, ?it/s]

  with autocast():


In [None]:
# Cell 10: Visualization
# ==================================================================================
def visualize_results(model_path, val_loader, fold):
    print(f"üìä Visualizing Fold {fold}...")
    model = build_model()
    model.load_state_dict(torch.load(model_path))
    model.to(CONFIG['DEVICE'])
    model.eval()
    
    batch = next(iter(val_loader))
    images, masks = batch['image'].to(CONFIG['DEVICE']), batch['mask'].to(CONFIG['DEVICE'])
    
    with torch.no_grad():
        preds = torch.sigmoid(model(images))
        preds = (preds > 0.5).float()
        
    # Plot 5 samples
    fig, axes = plt.subplots(5, 4, figsize=(15, 20))
    for i in range(5):
        if i >= len(images): break
        
        # Input
        img_np = images[i, 0].cpu().numpy()
        axes[i,0].imshow(img_np, cmap='gray')
        axes[i,0].set_title("Input MRI")
        
        # Truth
        gt_np = masks[i, 0].cpu().numpy()
        axes[i,1].imshow(gt_np, cmap='gray')
        axes[i,1].set_title("Ground Truth")
        
        # Pred
        pred_np = preds[i, 0].cpu().numpy()
        axes[i,2].imshow(pred_np, cmap='gray')
        axes[i,2].set_title(f"Prediction")
        
        # Error Map (Red = FP, Blue = FN)
        diff = pred_np - gt_np
        axes[i,3].imshow(diff, cmap='coolwarm')
        axes[i,3].set_title("Error Map (Red=FP, Blue=FN)")
        
    plt.tight_layout()
    plt.show()

# To run visualization after training:
# visualize_results(f"MaxViT_Best_Fold_0.pth", loader_val, 0)

In [None]:
# Cell 11: Advanced Training Analytics Plotting
# ==================================================================================
import matplotlib.pyplot as plt
import seaborn as sns

def plot_training_analytics(history, fold):
    # Set style
    sns.set_style("darkgrid")
    plt.rcParams['font.size'] = 10
    
    epochs = range(len(history['train_loss']))
    fig, axes = plt.subplots(2, 2, figsize=(20, 12))
    fig.suptitle(f'Fold {fold} Training Dynamics', fontsize=16, weight='bold')
    
    # 1. LOSS DYNAMICS
    # Shows if the model is actually learning
    axes[0,0].plot(epochs, history['train_loss'], label='Train Hybrid Loss', color='#FF5733', linewidth=2)
    axes[0,0].set_title("üìâ Loss Convergence", fontsize=12, weight='bold')
    axes[0,0].set_xlabel("Epochs")
    axes[0,0].set_ylabel("Loss")
    axes[0,0].legend()
    
    # 2. ACCURACY METRICS
    # Shows Dice and IoU together
    axes[0,1].plot(epochs, history['val_dice'], label='Val Dice', color='#2E86C1', linewidth=2)
    axes[0,1].plot(epochs, history['val_iou'], label='Val IoU', color='#28B463', linestyle='--')
    axes[0,1].set_title("üéØ Accuracy (Overlap)", fontsize=12, weight='bold')
    axes[0,1].set_xlabel("Epochs")
    axes[0,1].set_ylabel("Score (0-1)")
    axes[0,1].legend()
    
    # 3. GENERALIZATION GAP (Train vs Val)
    # CRITICAL: If Train (Red) is way higher than Val (Blue), you are overfitting (memorizing noise).
    # Ideally, they should move up together.
    # Note: We assume you logged 'train_dice' in history during Cell 9
    if 'train_dice' in history:
        axes[1,0].plot(epochs, history['train_dice'], label='Train Dice', color='#C0392B', linestyle=':')
        axes[1,0].plot(epochs, history['val_dice'], label='Val Dice', color='#2E86C1', linewidth=2)
        axes[1,0].set_title("üß† Generalization Gap (Train vs Val)", fontsize=12, weight='bold')
        axes[1,0].set_xlabel("Epochs")
        axes[1,0].set_ylabel("Dice Score")
        axes[1,0].legend()
    else:
        axes[1,0].text(0.5, 0.5, "Train Dice not found in history", ha='center')
    
    # 4. BOUNDARY QUALITY (Hausdorff)
    # Shows if edges are getting sharper. Lower is better.
    axes[1,1].plot(epochs, history['val_hd'], label='Val HD95', color='#884EA0', linewidth=2)
    axes[1,1].set_title("üìè Edge Precision (Hausdorff Dist)", fontsize=12, weight='bold')
    axes[1,1].set_xlabel("Epochs")
    axes[1,1].set_ylabel("Distance (Pixels)")
    axes[1,1].legend()
    axes[1,1].invert_yaxis() # Invert because lower is better
    
    plt.tight_layout()
    plt.savefig(f"analytics_fold_{fold}.png", dpi=300)
    plt.show()

# Usage Example (Put this inside Cell 9 after the loop, or run manually):
plot_training_analytics(history, fold)

In [None]:
# Cell 12: Final Aggregate Report
# ==================================================================================
import glob
import json

def generate_final_paper_report():
    print("\n" + "="*80)
    print("üìÑ GENERATING FINAL CROSS-VALIDATION REPORT")
    print("="*80)
    
    # 1. Find all metric files
    metric_files = glob.glob("metrics_fold_*.json")
    
    if not metric_files:
        print("‚ö†Ô∏è No metric files found! Did you complete any folds?")
        return
    
    results = []
    for fpath in metric_files:
        with open(fpath, 'r') as f:
            results.append(json.load(f))
            
    # 2. Create DataFrame
    df_res = pd.DataFrame(results).sort_values('fold').set_index('fold')
    
    # 3. Calculate Aggregate Stats
    mean_dice = df_res['best_dice'].mean()
    std_dice = df_res['best_dice'].std()
    
    mean_iou = df_res['best_iou'].mean()
    std_iou = df_res['best_iou'].std()
    
    mean_hd = df_res['best_hd'].mean()
    std_hd = df_res['best_hd'].std()
    
    # 4. Print detailed table
    print("\nüìä PER-FOLD PERFORMANCE:")
    print("-" * 60)
    print(df_res.to_string(float_format="{:.4f}".format))
    print("-" * 60)
    
    # 5. Print Scientific Summary
    report = f"""
    üèÜ FINAL AGGREGATE RESULTS ({len(df_res)} Folds):
    
    üîπ DICE SCORE:       {mean_dice:.4f} ¬± {std_dice:.4f}  (Target: >0.90)
    üîπ IOU SCORE:        {mean_iou:.4f}  ¬± {std_iou:.4f}
    üîπ HAUSDORFF (95%):  {mean_hd:.4f}   ¬± {std_hd:.4f}   (Lower is better)
    
    ------------------------------------------------------------
    CONCLUSION:
    The model achieved a mean Dice score of {mean_dice:.4f} across {len(df_res)} folds.
    The standard deviation of {std_dice:.4f} indicates {'STABLE' if std_dice < 0.015 else 'UNSTABLE'} performance.
    """
    
    print(report)
    
    # 6. Save to Text File (for your paper/report)
    with open("final_sota_report.txt", "w") as f:
        f.write(report)
        f.write("\n\nRaw Data:\n")
        f.write(df_res.to_string())
        
    print("‚úÖ Report saved to 'final_sota_report.txt'")

# Run the report
generate_final_paper_report()