In [None]:
import os
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import numpy as np

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def train_single_fold(
    fold_idx,
    dataloaders,
    *,
    num_folds=4,
    num_epochs=100,
    learning_rate=1e-4,
    weight_decay=1e-3,
    patience=30,
    max_grad_norm=1.0,
    eta_min=1e-6,
    ckpt_root='./checkpoints',
):
    """Train and validate one cross-validation fold with early stopping.

    A fresh ``MMFormer_LiFS_Improved_v2(num_classes=2, dropout_rate=0.1)`` is
    trained with AdamW, cross-entropy loss, gradient-norm clipping and a
    cosine-annealed learning rate. After every epoch the model is evaluated
    on the fold's validation loader; whenever validation accuracy strictly
    improves, a checkpoint is written to
    ``{ckpt_root}/fold_{fold_idx}/best_epoch_{epoch:03d}.pth``. Earlier
    checkpoints are not removed, so one file is kept per improvement.

    Args:
        fold_idx: Zero-based fold index; used to index ``dataloaders`` and
            to name the checkpoint directory.
        dataloaders: Indexable container where ``dataloaders[fold_idx]``
            unpacks to ``(train_loader, val_loader)``. Each batch must be a
            mapping with ``'image'``, ``'modality_mask'`` and ``'label'``
            tensors.
        num_folds: Total fold count, used only in the banner text.
        num_epochs: Maximum number of epochs; also the cosine period
            (``T_max``) of the scheduler.
        learning_rate: Initial AdamW learning rate.
        weight_decay: AdamW weight decay.
        patience: Stop after this many consecutive epochs without a strict
            improvement in validation accuracy.
        max_grad_norm: Maximum gradient norm passed to ``clip_grad_norm_``.
        eta_min: Lower bound of the cosine learning-rate schedule.
        ckpt_root: Directory under which ``fold_{fold_idx}`` is created.

    Returns:
        The best validation accuracy observed for this fold (``0.0`` if no
        epoch ever exceeded zero accuracy).
    """
    print(f"\n{'='*60}")
    print(f"Fold {fold_idx + 1}/{num_folds}")
    print(f"{'='*60}")

    # Get current fold data
    train_loader, val_loader = dataloaders[fold_idx]

    # Initialize model
    model = MMFormer_LiFS_Improved_v2(num_classes=2, dropout_rate=0.1).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Learning rate scheduler (stepped once per epoch)
    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=eta_min,
        last_epoch=-1
    )

    # Early stopping state
    best_val_acc = 0.0
    epochs_no_improve = 0

    # Training history
    train_history = {'loss': [], 'acc': []}
    val_history = {'loss': [], 'acc': []}

    # Create checkpoint directory
    ckpt_dir = f'{ckpt_root}/fold_{fold_idx}'
    os.makedirs(ckpt_dir, exist_ok=True)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    for epoch in range(1, num_epochs + 1):
        # ---- Training phase ----
        model.train()
        train_bar = tqdm(train_loader, desc=f"[Fold {fold_idx+1}] Epoch {epoch:03d}", leave=False)
        train_losses, train_preds, train_labels = [], [], []
        running_correct = 0  # incremental count so the progress bar stays O(batch)

        for batch in train_bar:
            imgs = batch['image'].to(device)
            masks = batch['modality_mask'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            logits = model(imgs, masks)
            loss = criterion(logits, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()

            # Record metrics; read the scalar loss once to avoid a second sync
            batch_loss = loss.item()
            batch_preds = logits.argmax(dim=1).cpu().tolist()
            batch_labels = labels.cpu().tolist()
            train_losses.append(batch_loss)
            train_preds += batch_preds
            train_labels += batch_labels

            # Update progress bar with the running accuracy. Counting matches
            # for this batch only avoids re-scoring the entire epoch history
            # on every step.
            running_correct += sum(p == y for p, y in zip(batch_preds, batch_labels))
            seen = len(train_labels)
            current_acc = running_correct / seen if seen else 0.0
            train_bar.set_postfix(
                loss=f"{batch_loss:.4f}",
                acc=f"{current_acc:.4f}"
            )

        # Training statistics
        avg_train_loss = np.mean(train_losses)
        train_acc = accuracy_score(train_labels, train_preds)
        train_history['loss'].append(avg_train_loss)
        train_history['acc'].append(train_acc)

        # ---- Validation phase ----
        model.eval()
        val_bar = tqdm(val_loader, desc=f"[Fold {fold_idx+1}] Validating", leave=False)
        val_losses, val_preds, val_labels = [], [], []

        with torch.no_grad():
            for batch in val_bar:
                imgs = batch['image'].to(device)
                masks = batch['modality_mask'].to(device)
                labels = batch['label'].to(device)

                logits = model(imgs, masks)
                loss = criterion(logits, labels)

                val_losses.append(loss.item())
                val_preds += logits.argmax(dim=1).cpu().tolist()
                val_labels += labels.cpu().tolist()

        # Validation statistics
        avg_val_loss = np.mean(val_losses)
        val_acc = accuracy_score(val_labels, val_preds)
        val_history['loss'].append(avg_val_loss)
        val_history['acc'].append(val_acc)

        # Print epoch results (LR shown is the one used during this epoch)
        print(f"[Fold {fold_idx+1}] Epoch {epoch:03d} | "
              f"Train Loss: {avg_train_loss:.4f} Acc: {train_acc:.4f} | "
              f"Val Loss: {avg_val_loss:.4f} Acc: {val_acc:.4f} | "
              f"LR: {optimizer.param_groups[0]['lr']:.2e}")

        # Learning rate scheduling
        scheduler.step()

        # Early stopping & save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_no_improve = 0

            # Save best model checkpoint. The scheduler state is stored too so
            # a resumed run can continue the cosine schedule where it left off.
            ckpt_path = os.path.join(ckpt_dir, f"best_epoch_{epoch:03d}.pth")
            torch.save({
                'epoch': epoch,
                'fold': fold_idx,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_acc': best_val_acc,
                'train_history': train_history,
                'val_history': val_history
            }, ckpt_path)

            print(f"[Fold {fold_idx+1}] Saved checkpoint (Val Acc: {best_val_acc:.4f}) -> {ckpt_path}")
        else:
            epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"[Fold {fold_idx+1}] Early stopping at epoch {epoch} (no improvement for {patience} epochs)")
                break

    print(f"[Fold {fold_idx+1}] Training completed | Best Val Acc: {best_val_acc:.4f}\n")

    return best_val_acc

def run_4fold_training(dataloaders, *, num_folds=4):
    """Run k-fold cross-validation training and print a summary.

    Folds ``0 .. num_folds - 1`` are trained one after another by calling
    ``train_single_fold(fold_idx, dataloaders)``. The per-fold best
    validation accuracies are then summarised with their mean, standard
    deviation (``np.std`` with its default arguments), maximum and minimum.

    Args:
        dataloaders: Container passed through unchanged to
            ``train_single_fold``; it must be indexable by every fold index
            in ``range(num_folds)``.
        num_folds: Number of folds to train. Defaults to 4.

    Returns:
        A list with the best validation accuracy of each fold, in fold order.

    Raises:
        ValueError: If ``num_folds`` is smaller than 1.
    """
    # Fail early: with no folds the summary statistics below are undefined.
    if num_folds < 1:
        raise ValueError(f"num_folds must be >= 1, got {num_folds}")

    fold_results = []

    print(f"\nStarting {num_folds}-Fold Cross-Validation Training...")
    print("="*60)

    # Cross-validation loop: folds are trained sequentially
    for fold_idx in range(num_folds):
        best_acc = train_single_fold(fold_idx, dataloaders)
        fold_results.append(best_acc)

    # Compute cross-validation statistics
    mean_acc = np.mean(fold_results)
    std_acc = np.std(fold_results)

    print(f"{num_folds}-Fold Cross-Validation Results")
    for i, acc in enumerate(fold_results):
        print(f"Fold {i+1}: {acc:.4f}")
    print("-"*60)
    print(f"Mean Accuracy: {mean_acc:.4f} +/- {std_acc:.4f}")
    print(f"Best Fold: {max(fold_results):.4f}")
    print(f"Worst Fold: {min(fold_results):.4f}")

    return fold_results

# Launch cross-validation; the best validation accuracy of each fold is kept in `fold_results`.
fold_results = run_4fold_training(dataloaders=dataloaders)