In [None]:
# ============================================================================
# CELL 1: Environment Setup and GPU Check
# ============================================================================
import torch
import os
import sys
from pathlib import Path

print("="*70, "TOPO-BRAIN GAN TRAINING - ENVIRONMENT SETUP", "="*70, sep="\n")

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise CPU.
if not torch.cuda.is_available():
    print("‚ö†Ô∏è No GPU available, using CPU (training will be slow!)")
    device = 'cpu'
else:
    print(f"‚úì GPU Available: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    device = 'cuda'

# Hosted-platform detection is purely path based: each platform is
# recognised by a directory that exists on its machines.
IS_KAGGLE = Path('/kaggle/input').exists()
IS_COLAB = Path('/content').exists()

# Kaggle takes precedence over Colab when both marker directories exist;
# anything else is treated as a local run rooted at the current directory.
if not (IS_KAGGLE or IS_COLAB):
    print("\n‚úì Running locally")
    BASE_DIR = Path.cwd()
elif IS_KAGGLE:
    print("\n‚úì Running on Kaggle")
    BASE_DIR = Path('/kaggle/working')
else:
    print("\n‚úì Running on Google Colab")
    BASE_DIR = Path('/content')

print(f"\nBase directory: {BASE_DIR}", "="*70, sep="\n")

In [None]:
# ============================================================================
# CELL 2: Install Dependencies
# ============================================================================
print("\n" + "="*70)
print("INSTALLING DEPENDENCIES")
print("="*70)

!pip install -q monai nibabel matplotlib tqdm scikit-image scikit-learn tensorboard

print("‚úì Core dependencies installed")
print("‚úì MONAI, NiBabel, Matplotlib, TensorBoard ready")
print("="*70)

In [None]:
# ============================================================================
# CELL 3: Clone Repository
# ============================================================================
print("\n" + "="*70)
print("CLONING TOPO-BRAIN REPOSITORY")
print("="*70)

os.chdir(BASE_DIR)

# Clone if not already present
if not (BASE_DIR / 'Topo-Brain').exists():
    !git clone https://github.com/prabeshx12/Topo-Brain.git
    print("‚úì Repository cloned")
else:
    print("‚úì Repository already exists")

os.chdir(BASE_DIR / 'Topo-Brain')
sys.path.insert(0, str(BASE_DIR / 'Topo-Brain'))

print(f"\n‚úì Working directory: {Path.cwd()}")
print("‚úì Python path updated")
print("="*70)

In [None]:
# ============================================================================
# CELL 4: Setup Output Directories
# ============================================================================
print("\n" + "="*70, "CREATING OUTPUT DIRECTORIES", "="*70, sep="\n")

# Every output location lives directly under BASE_DIR and is named
# "gan_<key>"; the key order below is the order in which they are created.
DIRS = {
    key: BASE_DIR / f'gan_{key}'
    for key in ('checkpoints', 'logs', 'visualizations', 'cache', 'final_output')
}

# mkdir is idempotent here (exist_ok=True), so the cell can be re-run safely.
for name, path in DIRS.items():
    path.mkdir(parents=True, exist_ok=True)
    print(f"‚úì Created: {name} -> {path}")

print("\n‚úì All directories ready", "="*70, sep="\n")

In [None]:
# ============================================================================
# CELL 5: Link Preprocessed Data
# ============================================================================
# Locate the preprocessed NIfTI volumes and fail early if they are missing.
print("\n" + "="*70)
print("LINKING PREPROCESSED DATA")
print("="*70)

# UPDATE THIS PATH to your Kaggle dataset!
# After uploading preprocessed data, update this path
if IS_KAGGLE:
    PREPROCESSED_DATA_PATH = Path('/kaggle/input/unc-paired-3t-7t-preprocessed/')
    # Alternative: If you ran preprocessing in same notebook:
    # PREPROCESSED_DATA_PATH = Path('/kaggle/working/preprocessed_no_n4/')
else:
    # Used for both Colab and local runs. The path is relative, so it resolves
    # against the current working directory at the time this cell runs
    # (Cell 3 changes into the cloned Topo-Brain checkout).
    PREPROCESSED_DATA_PATH = Path('./preprocessed')  # Local path

print(f"Preprocessed data path: {PREPROCESSED_DATA_PATH}")

# Verify data exists
if PREPROCESSED_DATA_PATH.exists():
    print(f"‚úì Found preprocessed data directory")
    
    # Count files (recursive search for the "*_preprocessed.nii.gz" suffix)
    nifti_files = list(PREPROCESSED_DATA_PATH.rglob("*_preprocessed.nii.gz"))
    print(f"‚úì Total preprocessed volumes: {len(nifti_files)}")
    
    # Show structure: only immediate child directories named "sub-*" are listed
    subjects = sorted([d.name for d in PREPROCESSED_DATA_PATH.iterdir() 
                      if d.is_dir() and d.name.startswith('sub-')])
    print(f"‚úì Found {len(subjects)} subjects: {subjects}")
    
    # Show sample files (first five names in sorted path order)
    print("\nüìÑ Sample files:")
    for f in sorted(nifti_files)[:5]:
        print(f"   {f.name}")
else:
    # Print remediation steps, then abort the cell so later cells do not run
    # against a missing dataset.
    print("‚ùå PREPROCESSED DATA NOT FOUND!")
    print("\nüëâ TO FIX:")
    print("   1. Upload preprocessed data as Kaggle Dataset")
    print("   2. Or run preprocessing notebook first")
    print("   3. Update PREPROCESSED_DATA_PATH in this cell")
    raise FileNotFoundError(f"Preprocessed data not found at {PREPROCESSED_DATA_PATH}")

print("\n‚úì Data validation complete")
print("="*70)

In [None]:
# ============================================================================
# CELL 6: Import Models and Utilities
# ============================================================================
# Import the training stack and the project models, falling back to minimal
# stand-in networks when the repository modules cannot be imported.
print("\n" + "="*70)
print("IMPORTING MODELS AND UTILITIES")
print("="*70)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import nibabel as nib
import numpy as np
import random
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import json
from datetime import datetime
import warnings
# Silences every warning for the rest of the session, including
# deprecation and numerical warnings raised by later cells.
warnings.filterwarnings('ignore')

# Import your models
try:
    from models.generator_unet3d import UNet3DGenerator
    from models.discriminator_patchgan3d import PatchGANDiscriminator3D
    print("‚úì Imported models from repository")
except ImportError as e:
    # The fallback classes below reuse the repository class names so that the
    # rest of the notebook runs unchanged.
    print(f"‚ö†Ô∏è Import error: {e}")
    print("‚ö†Ô∏è Will use simplified models")
    
    # Simplified fallback models (basic versions)
    class UNet3DGenerator(nn.Module):
        """Placeholder generator used only when the repository model is unavailable.

        Despite the name this is not a U-Net: it is one 3x3x3 convolution with
        instance norm and ReLU, followed by one 3x3x3 convolution and Tanh.
        There is no down/up-sampling and there are no skip connections, so the
        output has the same spatial size as the input and lies in [-1, 1].

        NOTE(review): Cell 13 displays images with vmin=-3/vmax=3, which
        suggests intensities outside [-1, 1]; confirm the target normalisation
        is compatible with a Tanh output.
        """
        def __init__(self, in_channels=1, out_channels=1, base_features=32):
            super().__init__()
            # Simplified U-Net - replace with your actual implementation
            self.encoder = nn.Sequential(
                nn.Conv3d(in_channels, base_features, 3, padding=1),
                nn.InstanceNorm3d(base_features),
                nn.ReLU(inplace=True),
            )
            self.decoder = nn.Sequential(
                nn.Conv3d(base_features, out_channels, 3, padding=1),
                nn.Tanh(),
            )
        
        def forward(self, x):
            """Map a 5-D input batch to an output of the same spatial size."""
            features = self.encoder(x)
            output = self.decoder(features)
            return output
    
    class PatchGANDiscriminator3D(nn.Module):
        """Placeholder patch discriminator used only when the repository model is unavailable.

        Three 3-D convolutions: two stride-2 layers (kernel 4, padding 1)
        followed by a stride-1 layer that emits a single-channel score map.
        No activation is applied to the final scores.
        """
        def __init__(self, in_channels=1, base_features=64):
            super().__init__()
            # Simplified PatchGAN - replace with your actual implementation
            self.model = nn.Sequential(
                nn.Conv3d(in_channels, base_features, 4, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv3d(base_features, base_features*2, 4, 2, 1),
                nn.InstanceNorm3d(base_features*2),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv3d(base_features*2, 1, 4, 1, 1),
            )
        
        def forward(self, x):
            """Return the per-patch score map for input batch ``x``."""
            return self.model(x)

print("‚úì Models ready")
print("‚úì PyTorch imports complete")
print("="*70)

In [None]:
# ============================================================================
# CELL 7: Dataset Class for Paired 3T-7T Data
# ============================================================================
print("\n" + "="*70)
print("CREATING DATASET CLASS")
print("="*70)

class Paired3T7TDataset(Dataset):
    """
    PyTorch Dataset for paired 3T-7T MRI volumes.

    Each item is one randomly positioned 3-D patch cut from the same location
    in a 3T volume and its paired 7T volume. Every volume pair contributes
    ``num_patches_per_volume`` consecutive indices, so
    ``len(dataset) == len(data_pairs) * num_patches_per_volume``.

    Args:
        data_pairs: Sequence of dicts with the keys ``'input_3t'`` and
            ``'target_7t'`` (paths readable by nibabel) and optionally
            ``'subject'``.
        patch_size: Patch extent along the three volume axes. Any length-3
            sequence is accepted and stored as a tuple.
        num_patches_per_volume: Number of indices mapped onto each pair.
        transform: Optional callable applied to each tensor separately.
            Because it is called once for the 3T patch and once for the 7T
            patch, a transform with internal randomness will NOT be applied
            identically to both.
        cache_data: Keep loaded volumes in an in-memory dict keyed by pair
            index. With ``num_workers > 0`` every DataLoader worker process
            holds its own copy of the dataset and therefore its own cache.

    Patch positions come from the ``random`` module, so the same index yields
    a different patch on every access.
    """
    def __init__(
        self,
        data_pairs,
        patch_size=(64, 64, 64),
        num_patches_per_volume=10,
        transform=None,
        cache_data=False,
    ):
        self.data_pairs = data_pairs
        # Stored as a tuple so it can be compared with ndarray.shape; a list
        # would never compare equal and would force a (no-op) pad on every patch.
        self.patch_size = tuple(patch_size)
        self.num_patches_per_volume = num_patches_per_volume
        self.transform = transform
        self.cache_data = cache_data
        self.cache = {}

        self.total_patches = len(data_pairs) * num_patches_per_volume

    def __len__(self):
        """Total number of patches served per pass over the dataset."""
        return self.total_patches

    def __getitem__(self, idx):
        """Return one patch pair as ``{'input_3t', 'target_7t', 'subject'}``.

        Both tensors have shape ``[1, *patch_size]``. ``'subject'`` falls back
        to ``'unknown'`` when the pair dict has no such key.
        """
        # Consecutive blocks of num_patches_per_volume indices share one pair.
        pair_idx = idx // self.num_patches_per_volume
        pair = self.data_pairs[pair_idx]

        # Load or retrieve from cache
        if self.cache_data and pair_idx in self.cache:
            vol_3t, vol_7t = self.cache[pair_idx]
        else:
            vol_3t = nib.load(pair['input_3t']).get_fdata().astype(np.float32)
            vol_7t = nib.load(pair['target_7t']).get_fdata().astype(np.float32)

            if self.cache_data:
                self.cache[pair_idx] = (vol_3t, vol_7t)

        # Extract random patch
        patch_3t, patch_7t = self._extract_random_patch(vol_3t, vol_7t)

        # To tensor and add channel dimension -> [1, D, H, W].
        # The explicit copy detaches the tensor from the source volume: a plain
        # slice is a view, so an in-place transform would otherwise write into
        # the cached volume and the tensor would keep the whole volume alive.
        patch_3t = torch.from_numpy(patch_3t[None, ...].copy())
        patch_7t = torch.from_numpy(patch_7t[None, ...].copy())

        # Apply transforms if any
        if self.transform:
            patch_3t = self.transform(patch_3t)
            patch_7t = self.transform(patch_7t)

        return {
            'input_3t': patch_3t,
            'target_7t': patch_7t,
            'subject': pair.get('subject', 'unknown'),
        }

    def _extract_random_patch(self, vol_3t, vol_7t):
        """Extract matching random patch from both volumes.

        The patch origin is drawn inside the region common to both volumes and
        the same slices are applied to each, so the two patches always have
        identical shapes and cover the same voxel coordinates. When the common
        region is smaller than ``patch_size`` along an axis, both patches are
        zero-padded identically up to ``patch_size``.

        Raises:
            ValueError: If either volume is not 3-D.
        """
        if vol_3t.ndim != 3 or vol_7t.ndim != 3:
            raise ValueError(
                f"Expected 3-D volumes, got shapes {vol_3t.shape} and {vol_7t.shape}"
            )

        # Extent shared by both volumes; equals the volume shape when the
        # pair is the same size (the expected case).
        d, h, w = (min(a, b) for a, b in zip(vol_3t.shape, vol_7t.shape))
        pd, ph, pw = self.patch_size

        # Random starting coordinates (0 when the axis is shorter than the patch)
        start_d = random.randint(0, max(0, d - pd))
        start_h = random.randint(0, max(0, h - ph))
        start_w = random.randint(0, max(0, w - pw))

        # Clamp the stop index to the shared extent so that a larger volume
        # cannot yield a bigger patch than its smaller partner.
        region = (
            slice(start_d, min(start_d + pd, d)),
            slice(start_h, min(start_h + ph, h)),
            slice(start_w, min(start_w + pw, w)),
        )
        patch_3t = vol_3t[region]
        patch_7t = vol_7t[region]

        # Pad if necessary (both patches have the same shape at this point)
        if patch_3t.shape != self.patch_size:
            patch_3t = self._pad_to_size(patch_3t, self.patch_size)
            patch_7t = self._pad_to_size(patch_7t, self.patch_size)

        return patch_3t, patch_7t

    def _pad_to_size(self, volume, target_size):
        """Zero-pad a 3-D ``volume`` to ``target_size``.

        The deficit on each axis is split between both sides, with the extra
        voxel going after the data when the deficit is odd. ``volume`` must
        not exceed ``target_size`` on any axis (``np.pad`` rejects negative
        widths).
        """
        pad_width = []
        for i in range(3):
            diff = target_size[i] - volume.shape[i]
            pad_before = diff // 2
            pad_after = diff - pad_before
            pad_width.append((pad_before, pad_after))

        return np.pad(volume, pad_width, mode='constant', constant_values=0)

print("‚úì Dataset class defined")
print("‚úì Supports random patch extraction")
print("‚úì Optional data caching")
print("="*70)

In [None]:
# ============================================================================
# CELL 8: Create Data Pairs and Splits (Multi-Modal Support)
# ============================================================================
print("\n" + "="*70)
print("CREATING 3T-7T DATA PAIRS")
print("="*70)

def create_paired_data_list(preprocessed_dir, modalities):
    """
    Create list of paired 3T-7T volumes.

    Files are found recursively under ``preprocessed_dir`` with the pattern
    ``*<modality>_preprocessed.nii.gz`` and are expected to be named like
    ``sub-01_ses-1_T1w_preprocessed.nii.gz``: the first ``_``-separated token
    is the subject, the second the session.
    Assumes: ses-1 = 3T, ses-2 = 7T

    Args:
        preprocessed_dir: Root directory to search (str or Path).
        modalities: Single modality (str) or several (list of str).

    Returns:
        List of dicts with keys ``'subject'``, ``'input_3t'``, ``'target_7t'``
        (both paths as str) and ``'modality'``. Only subjects that have both
        ``ses-1`` and ``ses-2`` for a modality produce an entry. Entries are
        ordered by modality (in the order given), then by subject name.
    """
    preprocessed_dir = Path(preprocessed_dir)

    # Handle both string and list inputs
    if isinstance(modalities, str):
        modalities = [modalities]

    all_pairs = []

    for modality in modalities:
        print(f"üîç Searching for {modality} pairs...")

        # Group files by subject. rglob yields entries in filesystem order,
        # so sort to make the result identical on every machine and run.
        files_by_subject = {}
        for file in sorted(preprocessed_dir.rglob(f"*{modality}_preprocessed.nii.gz")):
            # Parse: sub-01_ses-1_T1w_preprocessed.nii.gz
            parts = file.stem.replace('_preprocessed', '').split('_')
            if len(parts) < 2:
                # Matches the glob but has no subject/session tokens;
                # skip it instead of failing with IndexError.
                print(f"   Skipping file with unexpected name: {file.name}")
                continue
            subject = parts[0]  # sub-01
            session = parts[1]  # ses-1 or ses-2

            # If several files share a subject and session, the last one in
            # sorted order wins.
            files_by_subject.setdefault(subject, {})[session] = file

        # Create pairs
        n_found = 0
        for subject in sorted(files_by_subject):
            sessions = files_by_subject[subject]
            if 'ses-1' in sessions and 'ses-2' in sessions:
                all_pairs.append({
                    'subject': subject,
                    'input_3t': str(sessions['ses-1']),
                    'target_7t': str(sessions['ses-2']),
                    'modality': modality,
                })
                n_found += 1

        print(f"   ‚úì Found {n_found} {modality} pairs")

    return all_pairs

# Configuration
# ============================================
# CHANGE THIS LINE to choose the modalities to train on:
# ============================================
MODALITIES = ['T1w', 'T2w']  # train on both modalities
# OR use single: MODALITIES = 'T1w'  # train on T1w only

TRAIN_RATIO = 0.6
VAL_RATIO = 0.2
TEST_RATIO = 0.2   # informational: the test split is whatever remains after train and val
RANDOM_SEED = 42

# Set seed for reproducibility
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Create pairs (works with both single and multiple modalities)
print(f"üîç Searching for {MODALITIES} pairs...")

all_pairs = create_paired_data_list(PREPROCESSED_DATA_PATH, MODALITIES)
print(f"‚úì Total: {len(all_pairs)} paired volumes")

# Show breakdown
if isinstance(MODALITIES, list):
    for mod in MODALITIES:
        n = sum(1 for p in all_pairs if p['modality'] == mod)
        print(f"   ‚Ä¢ {mod}: {n} pairs")

# Patient-level split (NO DATA LEAKAGE!)
# sorted() is essential: the iteration order of a set of strings changes
# between interpreter sessions (string hashing is randomised), so shuffling
# an unsorted list would give a different split after every kernel restart
# even with a fixed seed.
subjects = sorted({p['subject'] for p in all_pairs})
print(f"‚úì Total subjects: {len(subjects)}")

random.shuffle(subjects)

n_train = int(len(subjects) * TRAIN_RATIO)
n_val = int(len(subjects) * VAL_RATIO)

train_subjects = subjects[:n_train]
val_subjects = subjects[n_train:n_train + n_val]
test_subjects = subjects[n_train + n_val:]

# Create splits (includes all modalities for each subject)
train_pairs = [p for p in all_pairs if p['subject'] in train_subjects]
val_pairs = [p for p in all_pairs if p['subject'] in val_subjects]
test_pairs = [p for p in all_pairs if p['subject'] in test_subjects]

# int() truncation can leave a split empty when there are only a few
# subjects. Training needs both train and validation data, so stop here with
# a clear message rather than failing later inside the DataLoader or the
# training loop.
if not train_pairs or not val_pairs:
    raise ValueError(
        f"Not enough subjects for a train/val split: {len(subjects)} subject(s) gave "
        f"{len(train_subjects)} train and {len(val_subjects)} val. "
        "Add data or adjust TRAIN_RATIO / VAL_RATIO."
    )

print(f"\nüìä Data Splits:")
print(f"   Train: {len(train_pairs)} pairs from {len(train_subjects)} subjects")
print(f"   Val:   {len(val_pairs)} pairs from {len(val_subjects)} subjects")
print(f"   Test:  {len(test_pairs)} pairs from {len(test_subjects)} subjects")

print(f"\nüìã Train subjects: {train_subjects}")
print(f"üìã Val subjects:   {val_subjects}")
print(f"üìã Test subjects:  {test_subjects}")

# Save split info so the exact partition can be inspected or reused later
split_info = {
    'train_subjects': train_subjects,
    'val_subjects': val_subjects,
    'test_subjects': test_subjects,
    'train_pairs': len(train_pairs),
    'val_pairs': len(val_pairs),
    'test_pairs': len(test_pairs),
    'modalities': MODALITIES if isinstance(MODALITIES, list) else [MODALITIES],
    'random_seed': RANDOM_SEED,
    'created': datetime.now().isoformat(),
}

with open(DIRS['cache'] / 'data_split.json', 'w') as f:
    json.dump(split_info, f, indent=2)

print(f"\n‚úì Split info saved to: {DIRS['cache'] / 'data_split.json'}")
print("="*70)

In [None]:
# ============================================================================
# CELL 9: Create DataLoaders
# ============================================================================
# Build the train/validation datasets and loaders, then pull one batch as a
# smoke test of the whole loading path.
print("\n" + "="*70)
print("CREATING PYTORCH DATALOADERS")
print("="*70)

# Hyperparameters
BATCH_SIZE = 2              # Reduce to 1 if Out-Of-Memory
PATCH_SIZE = (64, 64, 64)   # Reduce to (32, 32, 32) if OOM
NUM_PATCHES_PER_VOLUME = 10 # Patches per volume per epoch
NUM_WORKERS = 2             # Parallel data loading

print(f"‚öôÔ∏è Configuration:")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Patch size: {PATCH_SIZE}")
print(f"   Patches per volume: {NUM_PATCHES_PER_VOLUME}")
print(f"   Num workers: {NUM_WORKERS}")

# Create datasets
print(f"\nüì¶ Creating datasets...")

train_dataset = Paired3T7TDataset(
    data_pairs=train_pairs,
    patch_size=PATCH_SIZE,
    num_patches_per_volume=NUM_PATCHES_PER_VOLUME,
    cache_data=False,  # Set True if enough RAM
)

# The validation set uses the same dataset class, which draws patch positions
# at random on every access, so validation batches differ between epochs.
val_dataset = Paired3T7TDataset(
    data_pairs=val_pairs,
    patch_size=PATCH_SIZE,
    num_patches_per_volume=5,  # Fewer for validation
    cache_data=False,
)

# Create loaders
# pin_memory is enabled only when training on CUDA.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True if device == 'cuda' else False,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True if device == 'cuda' else False,
)

print(f"\n‚úì Train loader: {len(train_loader)} batches")
print(f"‚úì Val loader:   {len(val_loader)} batches")

# Test loading
# Despite its name, test_batch is drawn from the TRAINING loader; it only
# verifies that a batch can be produced and shows the tensor shapes.
print(f"\nüß™ Testing data loading...")
test_batch = next(iter(train_loader))
print(f"‚úì Batch loaded successfully!")
print(f"   Input 3T shape:  {test_batch['input_3t'].shape}")
print(f"   Target 7T shape: {test_batch['target_7t'].shape}")
print(f"   Subjects: {test_batch['subject']}")

print("\n‚úì DataLoaders ready for training")
print("="*70)

In [None]:
# ============================================================================
# CELL 10: Initialize Models and Optimizers
# ============================================================================
# Instantiate the generator and discriminator, their Adam optimizers, the
# loss functions/weights, and the mixed-precision gradient scalers.
print("\n" + "="*70)
print("INITIALIZING MODELS AND OPTIMIZERS")
print("="*70)

# Model parameters
BASE_FEATURES_G = 32  # Generator base features
BASE_FEATURES_D = 64  # Discriminator base features

print(f"üèóÔ∏è Creating models...")

# Generator: 3D U-Net
# (UNet3DGenerator is either the repository model or the Cell 6 fallback.)
generator = UNet3DGenerator(
    in_channels=1,
    out_channels=1,
    base_features=BASE_FEATURES_G,
).to(device)

# Discriminator: 3D PatchGAN
# in_channels=1: the discriminator is given a single volume (real or
# generated 7T), not a 3T/7T pair.
discriminator = PatchGANDiscriminator3D(
    in_channels=1,
    base_features=BASE_FEATURES_D,
).to(device)

# Count parameters
n_params_g = sum(p.numel() for p in generator.parameters())
n_params_d = sum(p.numel() for p in discriminator.parameters())

print(f"‚úì Generator parameters:     {n_params_g:,}")
print(f"‚úì Discriminator parameters: {n_params_d:,}")

# Optimizers
LR_G = 2e-4
LR_D = 2e-4
BETA1 = 0.5
BETA2 = 0.999

optimizer_g = optim.Adam(generator.parameters(), lr=LR_G, betas=(BETA1, BETA2))
optimizer_d = optim.Adam(discriminator.parameters(), lr=LR_D, betas=(BETA1, BETA2))

print(f"\n‚öôÔ∏è Optimizers:")
print(f"   Generator LR:     {LR_G}")
print(f"   Discriminator LR: {LR_D}")

# Loss functions
criterion_l1 = nn.L1Loss()
criterion_adv = nn.MSELoss()  # LSGAN (more stable)

# Loss weights
# The generator objective in the training cell is
# LAMBDA_ADV * adversarial + LAMBDA_L1 * L1.
LAMBDA_L1 = 100.0   # Weight for L1 reconstruction
LAMBDA_ADV = 1.0    # Weight for adversarial loss

print(f"\nüìä Loss configuration:")
print(f"   L1 weight:  {LAMBDA_L1}")
print(f"   Adv weight: {LAMBDA_ADV}")

# Mixed precision training
# AMP is used only on CUDA; on CPU both scalers are None and the training
# cell takes the plain backward()/step() path.
USE_AMP = True if device == 'cuda' else False
scaler_g = GradScaler() if USE_AMP else None
scaler_d = GradScaler() if USE_AMP else None

print(f"\n‚ö° Mixed precision: {USE_AMP}")

print("\n‚úì All models and optimizers ready")
print("="*70)

In [None]:
# ============================================================================
# CELL 11: Training Loop
# ============================================================================
print("\n" + "="*70)
print("STARTING GAN TRAINING")
print("="*70)

import time

# Training configuration
NUM_EPOCHS = 100
SAVE_INTERVAL = 5    # write a numbered checkpoint every SAVE_INTERVAL epochs
LOG_INTERVAL = 10    # not used by this cell; kept for code that may rely on it

# Fail fast: the per-epoch averages below divide by the number of batches,
# so an empty loader would otherwise surface as a ZeroDivisionError only
# after a full (wasted) epoch.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError(
        f"Cannot train with empty loaders (train batches: {len(train_loader)}, "
        f"val batches: {len(val_loader)}). Check the data split."
    )

# Tracking (one entry per completed epoch)
train_losses_g = []
train_losses_d = []
val_losses = []
best_val_loss = float('inf')

print(f"üöÄ Training Configuration:")
print(f"   Epochs: {NUM_EPOCHS}")
print(f"   Device: {device}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Patch size: {PATCH_SIZE}")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")
print(f"\n{'='*70}\n")

start_time = time.time()

for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start = time.time()

    # ============ TRAINING PHASE ============
    generator.train()
    discriminator.train()

    epoch_loss_g = 0.0
    epoch_loss_d = 0.0
    epoch_loss_l1 = 0.0
    epoch_loss_adv = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS}", leave=False)

    for batch in pbar:
        input_3t = batch['input_3t'].to(device)
        target_7t = batch['target_7t'].to(device)

        # ===== TRAIN DISCRIMINATOR =====
        optimizer_d.zero_grad()

        with autocast(enabled=USE_AMP):
            # Generate fake 7T
            fake_7t = generator(input_3t)

            # Discriminator predictions. detach() keeps the discriminator
            # loss from back-propagating into the generator.
            pred_real = discriminator(target_7t)
            pred_fake = discriminator(fake_7t.detach())

            # Labels (LSGAN: real=1, fake=0)
            real_label = torch.ones_like(pred_real)
            fake_label = torch.zeros_like(pred_fake)

            # Discriminator loss
            loss_d_real = criterion_adv(pred_real, real_label)
            loss_d_fake = criterion_adv(pred_fake, fake_label)
            loss_d = 0.5 * (loss_d_real + loss_d_fake)

        if USE_AMP:
            scaler_d.scale(loss_d).backward()
            scaler_d.step(optimizer_d)
            scaler_d.update()
        else:
            loss_d.backward()
            optimizer_d.step()

        # ===== TRAIN GENERATOR =====
        optimizer_g.zero_grad()

        with autocast(enabled=USE_AMP):
            # Generate fake 7T (fresh forward pass)
            fake_7t = generator(input_3t)

            # Adversarial loss (fool discriminator): the generator is rewarded
            # when the just-updated discriminator scores its output as real.
            pred_fake = discriminator(fake_7t)
            loss_adv = criterion_adv(pred_fake, real_label)

            # L1 reconstruction loss
            loss_l1 = criterion_l1(fake_7t, target_7t)

            # Total generator loss
            loss_g = LAMBDA_ADV * loss_adv + LAMBDA_L1 * loss_l1

        # This backward pass also deposits gradients on the discriminator;
        # they are discarded by optimizer_d.zero_grad() at the next step.
        if USE_AMP:
            scaler_g.scale(loss_g).backward()
            scaler_g.step(optimizer_g)
            scaler_g.update()
        else:
            loss_g.backward()
            optimizer_g.step()

        # Track losses
        epoch_loss_g += loss_g.item()
        epoch_loss_d += loss_d.item()
        epoch_loss_l1 += loss_l1.item()
        epoch_loss_adv += loss_adv.item()

        # Update progress bar
        pbar.set_postfix({
            'G': f'{loss_g.item():.3f}',
            'D': f'{loss_d.item():.3f}',
            'L1': f'{loss_l1.item():.4f}',
        })

    # Average losses (per batch)
    avg_loss_g = epoch_loss_g / len(train_loader)
    avg_loss_d = epoch_loss_d / len(train_loader)
    avg_loss_l1 = epoch_loss_l1 / len(train_loader)
    avg_loss_adv = epoch_loss_adv / len(train_loader)

    train_losses_g.append(avg_loss_g)
    train_losses_d.append(avg_loss_d)

    # ============ VALIDATION PHASE ============
    # Validation measures the L1 reconstruction error only.
    generator.eval()
    val_loss_total = 0.0

    with torch.no_grad():
        for batch in val_loader:
            input_3t = batch['input_3t'].to(device)
            target_7t = batch['target_7t'].to(device)

            fake_7t = generator(input_3t)
            loss = criterion_l1(fake_7t, target_7t)
            val_loss_total += loss.item()

    avg_val_loss = val_loss_total / len(val_loader)
    val_losses.append(avg_val_loss)

    # ============ LOGGING ============
    epoch_time = time.time() - epoch_start
    elapsed_total = time.time() - start_time

    print(f"\nEpoch {epoch}/{NUM_EPOCHS} ({epoch_time:.1f}s, total: {elapsed_total/60:.1f}m)")
    print(f"  Train - G: {avg_loss_g:.4f} | D: {avg_loss_d:.4f} | L1: {avg_loss_l1:.4f} | Adv: {avg_loss_adv:.4f}")
    print(f"  Val   - L1: {avg_val_loss:.4f} | Best: {min(val_losses):.4f}")

    # ============ SAVE CHECKPOINT ============
    if epoch % SAVE_INTERVAL == 0 or avg_val_loss < best_val_loss:
        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
            # The GradScaler state (current loss scale) is required to resume
            # a mixed-precision run consistently; None when AMP is disabled.
            'scaler_g_state_dict': scaler_g.state_dict() if USE_AMP else None,
            'scaler_d_state_dict': scaler_d.state_dict() if USE_AMP else None,
            'train_loss_g': train_losses_g,
            'train_loss_d': train_losses_d,
            'val_losses': val_losses,
            'best_val_loss': min(best_val_loss, avg_val_loss),
            'config': {
                'batch_size': BATCH_SIZE,
                'patch_size': PATCH_SIZE,
                'lambda_l1': LAMBDA_L1,
                'lambda_adv': LAMBDA_ADV,
                'lr_g': LR_G,
                'lr_d': LR_D,
            }
        }

        # Regular checkpoint
        if epoch % SAVE_INTERVAL == 0:
            checkpoint_path = DIRS['checkpoints'] / f'checkpoint_epoch_{epoch}.pth'
            torch.save(checkpoint, checkpoint_path)
            print(f"  üíæ Saved: checkpoint_epoch_{epoch}.pth")

        # Best model (lowest validation L1 so far)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(checkpoint, DIRS['checkpoints'] / 'best_model.pth')
            print(f"  ‚≠ê NEW BEST MODEL! Val loss: {avg_val_loss:.4f}")

total_time = time.time() - start_time
print(f"\n{'='*70}")
print(f"üéâ TRAINING COMPLETE!")
print(f"   Total time: {total_time/3600:.2f} hours")
print(f"   Best val loss: {best_val_loss:.4f}")
print(f"   Final epoch: {NUM_EPOCHS}")
print(f"{'='*70}")

In [None]:
# ============================================================================
# CELL 12: Plot Training Curves
# ============================================================================
# Plot the per-epoch loss histories recorded by the training cell and save
# the figure under the logs directory.
print("\n" + "="*70)
print("VISUALIZING TRAINING PROGRESS")
print("="*70)

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot 1: Generator and Discriminator losses
# The x axis is the list index, so epoch 1 is drawn at x=0.
axes[0].plot(train_losses_g, label='Generator', linewidth=2, color='blue')
axes[0].plot(train_losses_d, label='Discriminator', linewidth=2, color='red')
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Losses (G vs D)', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Plot 2: Validation loss
# The dashed horizontal line marks the lowest validation loss seen.
axes[1].plot(val_losses, label='Validation L1', color='green', linewidth=2)
axes[1].axhline(y=min(val_losses), color='r', linestyle='--', 
                label=f'Best: {min(val_losses):.4f}', linewidth=1.5)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('L1 Loss', fontsize=12)
axes[1].set_title('Validation Loss', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

# Plot 3: Loss ratio (G/D balance)
# The small epsilon guards against division by a zero discriminator loss.
# The recorded generator loss includes the L1 term weighted by LAMBDA_L1, so
# this ratio compares the full generator objective with the discriminator loss.
loss_ratio = np.array(train_losses_g) / (np.array(train_losses_d) + 1e-8)
axes[2].plot(loss_ratio, label='G/D Ratio', color='purple', linewidth=2)
axes[2].axhline(y=1.0, color='r', linestyle='--', label='Balanced (G=D)', linewidth=1.5)
axes[2].set_xlabel('Epoch', fontsize=12)
axes[2].set_ylabel('Loss Ratio', fontsize=12)
axes[2].set_title('Generator/Discriminator Balance', fontsize=14, fontweight='bold')
axes[2].legend(fontsize=11)
axes[2].grid(True, alpha=0.3)

plt.tight_layout()

# Save figure
curves_path = DIRS['logs'] / 'training_curves.png'
plt.savefig(curves_path, dpi=150, bbox_inches='tight')
plt.show()

# argmin is a 0-based list index; +1 converts it to the 1-based epoch number.
print(f"\n‚úì Training curves saved to: {curves_path.name}")
print(f"‚úì Best validation loss: {min(val_losses):.4f} at epoch {np.argmin(val_losses) + 1}")
print("="*70)

In [None]:
# ============================================================================
# CELL 13: Generate and Visualize Test Samples
# ============================================================================
# Restore the best generator weights and run them on one validation batch.
print("\n" + "="*70)
print("GENERATING TEST SAMPLES")
print("="*70)

# Load best model
# best_model.pth is written by the training cell whenever the validation
# loss improves, so that cell must have completed at least one epoch.
checkpoint = torch.load(DIRS['checkpoints'] / 'best_model.pth', map_location=device)
generator.load_state_dict(checkpoint['generator_state_dict'])
generator.eval()

print(f"‚úì Loaded best model from epoch {checkpoint['epoch']}")

# Get validation samples
# The samples come from val_loader (first batch), not from the held-out test
# split, despite the "TEST SAMPLES" banner above.
val_iter = iter(val_loader)
test_batch = next(val_iter)

input_3t = test_batch['input_3t'].to(device)
target_7t = test_batch['target_7t'].to(device)
# Rebinds `subjects`: from here on it holds this batch's subject labels,
# no longer the shuffled subject list built in the data-split cell.
subjects = test_batch['subject']

# Generate
print(f"‚úì Generating 7T images from 3T...")
with torch.no_grad():
    fake_7t = generator(input_3t)

print(f"‚úì Generated {fake_7t.shape[0]} samples")
# Visualization function
def visualize_comparison(input_3t, fake_7t, target_7t, subject, save_dir,
                         tag=None, vmin=-3, vmax=3):
    """Compare 3T input, Generated 7T, and Real 7T.

    Draws a 3x3 grid (rows: axial, coronal, sagittal middle slices; columns:
    3T input, generated 7T, real 7T) for the FIRST sample of each tensor and
    saves it as a PNG.

    Args:
        input_3t, fake_7t, target_7t: Tensors indexed as ``[batch, channel, D, H, W]``;
            only element ``[0, 0]`` of each is plotted.
        subject: Label used in the figure title and the file name.
        save_dir: Directory the PNG is written to (str or Path).
        tag: Optional suffix that makes the file name unique. Without it the
            file is ``generation_<subject>.png``, so two patches from the same
            subject would overwrite each other.
        vmin, vmax: Grey-level display window passed to ``imshow``.

    Returns:
        Path of the saved image.
    """
    # Move to CPU and get numpy arrays
    input_np = input_3t[0, 0].detach().cpu().numpy()
    fake_np = fake_7t[0, 0].detach().cpu().numpy()
    target_np = target_7t[0, 0].detach().cpu().numpy()

    # Get middle slices
    d, h, w = input_np.shape
    slice_idx = {
        'axial': d // 2,
        'coronal': h // 2,
        'sagittal': w // 2,
    }

    # Create figure
    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    fig.suptitle(f'Subject: {subject}', fontsize=16, fontweight='bold', y=0.98)

    titles = ['3T Input', 'Generated 7T', 'Real 7T']

    for col, (data, title) in enumerate(zip([input_np, fake_np, target_np], titles)):
        # One (row label, 2-D slice) entry per anatomical view, in row order.
        views = [
            ('Axial', data[slice_idx['axial'], :, :]),
            ('Coronal', data[:, slice_idx['coronal'], :]),
            ('Sagittal', data[:, :, slice_idx['sagittal']]),
        ]
        for row, (view_name, image) in enumerate(views):
            axes[row, col].imshow(image, cmap='gray', vmin=vmin, vmax=vmax)
            axes[row, col].set_title(f'{title} - {view_name}', fontsize=12, fontweight='bold')
            axes[row, col].axis('off')

    plt.tight_layout()

    # Save
    file_name = f'generation_{subject}.png' if tag is None else f'generation_{subject}_{tag}.png'
    save_path = Path(save_dir) / file_name
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    return save_path

# Visualize all samples in batch
# A batch can contain several patches from the same subject, so the batch
# index is passed as the tag to keep one image per sample.
print(f"\nüìä Creating visualizations...")
example_path = None
for i in range(input_3t.size(0)):
    vis_path = visualize_comparison(
        input_3t[i:i+1], 
        fake_7t[i:i+1], 
        target_7t[i:i+1],
        subjects[i],
        DIRS['visualizations'],
        tag=i,
    )
    print(f"‚úì Saved: {vis_path.name}")
    if example_path is None:
        example_path = vis_path

# Display one example (the first image written above)
if example_path is not None and example_path.exists():
    from IPython.display import Image, display
    display(Image(filename=str(example_path)))

print("\n‚úì All visualizations complete")
print("="*70)

In [None]:
# ============================================================================
# CELL 14: Compute Quantitative Metrics (PSNR, SSIM)
# ============================================================================
print("\n" + "="*70)
print("COMPUTING QUANTITATIVE METRICS")
print("="*70)

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def compute_metrics(fake, real):
    """Compute PSNR and SSIM between a generated and a real image.

    Both inputs are converted to NumPy arrays and min-max rescaled to [0, 1],
    each with its own minimum and maximum, before the metrics are evaluated
    with ``data_range=1.0``.

    NOTE(review): because every image is rescaled by its *own* range, a global
    intensity scale or offset error in ``fake`` does not lower the scores.
    Confirm that this is the intended evaluation protocol.

    Args:
        fake: Generated image as a torch tensor (any device; may still be
            attached to the autograd graph).
        real: Reference image as a torch tensor.

    Returns:
        tuple: ``(psnr_val, ssim_val)`` as returned by the skimage metrics.
    """
    def _to_unit_range(tensor):
        # detach() first: Tensor.numpy() refuses tensors that require grad,
        # so the function also works outside a torch.no_grad() context.
        arr = tensor.detach().cpu().numpy()
        # The small epsilon avoids a division by zero for constant images.
        return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)

    fake_norm = _to_unit_range(fake)
    real_norm = _to_unit_range(real)

    # Reference image first, generated image second.
    psnr_val = psnr(real_norm, fake_norm, data_range=1.0)
    ssim_val = ssim(real_norm, fake_norm, data_range=1.0)

    return psnr_val, ssim_val

# Score every sample of the validation loader
print("üìä Computing metrics on validation set...")
all_psnr, all_ssim = [], []

generator.eval()
with torch.no_grad():
    for batch in tqdm(val_loader, desc="Computing metrics"):
        # Move the paired volumes of this batch to the compute device
        input_3t, target_7t = (
            batch[key].to(device) for key in ('input_3t', 'target_7t')
        )

        fake_7t = generator(input_3t)

        # Metrics are computed per sample on channel 0 and collected
        for i in range(len(input_3t)):
            psnr_val, ssim_val = compute_metrics(fake_7t[i, 0], target_7t[i, 0])
            all_psnr.append(psnr_val)
            all_ssim.append(ssim_val)

# Aggregate the per-sample scores into mean and standard deviation
psnr_mean, psnr_std = np.mean(all_psnr), np.std(all_psnr)
ssim_mean, ssim_std = np.mean(all_ssim), np.std(all_ssim)

# Report mean +/- std together with the extremes of each metric
print(
    "\n" + "=" * 70,
    "üìä QUANTITATIVE RESULTS:",
    "=" * 70,
    f"  PSNR: {psnr_mean:.2f} ¬± {psnr_std:.2f} dB",
    f"       Min: {np.min(all_psnr):.2f} dB",
    f"       Max: {np.max(all_psnr):.2f} dB",
    f"\n  SSIM: {ssim_mean:.4f} ¬± {ssim_std:.4f}",
    f"       Min: {np.min(all_ssim):.4f}",
    f"       Max: {np.max(all_ssim):.4f}",
    "=" * 70,
    sep="\n",
)

# Collect validation scores, training info and run configuration in one
# JSON-serialisable summary.
metrics_summary = {
    # Keys are psnr_mean, psnr_std, psnr_min, psnr_max, then the same for ssim;
    # float() turns NumPy scalars into plain Python floats for json.
    'validation_metrics': {
        f'{metric}_{stat}': float(reduce_fn(scores))
        for metric, scores in (('psnr', all_psnr), ('ssim', all_ssim))
        for stat, reduce_fn in (
            ('mean', np.mean), ('std', np.std), ('min', np.min), ('max', np.max)
        )
    },
    'training_info': {
        'num_epochs': NUM_EPOCHS,
        'best_val_loss': float(best_val_loss),
        # argmin is zero-based; epochs are reported one-based
        'best_epoch': int(np.argmin(val_losses) + 1),
        'final_g_loss': float(train_losses_g[-1]),
        'final_d_loss': float(train_losses_d[-1]),
    },
    'configuration': {
        'batch_size': BATCH_SIZE,
        'patch_size': list(PATCH_SIZE),
        'train_subjects': train_subjects,
        'val_subjects': val_subjects,
        'test_subjects': test_subjects,
        'modality': MODALITIES,
    },
    'created': datetime.now().isoformat(),
}

# Persist the summary next to the other log files
metrics_path = DIRS['logs'] / 'metrics_summary.json'
metrics_path.write_text(json.dumps(metrics_summary, indent=2))

print(f"\n‚úì Metrics saved to: {metrics_path.name}", "=" * 70, sep="\n")

In [None]:
# ============================================================================
# CELL 15: Create Final Output Package
# ============================================================================
# Section banner for the packaging cell
print("\n" + "=" * 70, "PACKAGING OUTPUTS FOR DOWNLOAD", "=" * 70, sep="\n")

import shutil

# Destination folder that collects every deliverable of this run
final_dir = DIRS['final_output']
final_dir.mkdir(exist_ok=True)

print("üì¶ Copying files to final output directory...\n")

# 1.-4. Copy the single-file artifacts: best model, training curves, metrics
# and data split. Each row is (DIRS key, source file name, name in final_dir).
for dir_key, src_name, dest_name in (
    ('checkpoints', 'best_model.pth', 'best_generator.pth'),
    ('logs', 'training_curves.png', 'training_curves.png'),
    ('logs', 'metrics_summary.json', 'metrics_summary.json'),
    ('cache', 'data_split.json', 'data_split.json'),
):
    shutil.copy(DIRS[dir_key] / src_name, final_dir / dest_name)
    print(f"‚úì Copied: {dest_name}")

# 5. Copy visualizations (a copy left over from an earlier run is replaced)
vis_dest = final_dir / 'visualizations'
if vis_dest.exists():
    shutil.rmtree(vis_dest)
shutil.copytree(DIRS['visualizations'], vis_dest)
print(f"‚úì Copied: visualizations/ ({sum(1 for _ in vis_dest.glob('*.png'))} files)")

# 6. Create README
# The README is rendered from the variables of this run (configuration, data
# splits, losses and validation metrics) and written into the output package.
# NOTE(review): the usage snippet assumes the checkpoint stores the generator
# weights under 'generator_state_dict' and that UNet3DGenerator() needs no
# arguments -- confirm against the checkpoint-saving and model code.
# The snippet loads with map_location='cpu' so that a checkpoint saved on a GPU
# can also be opened on a machine without one.
readme_content = f"""# Topo-Brain GAN Training Results

**Generated on:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Training Summary

### Configuration
- **Epochs:** {NUM_EPOCHS}
- **Batch Size:** {BATCH_SIZE}
- **Patch Size:** {PATCH_SIZE}
- **Modality:** {MODALITIES}
- **Device:** {device}

### Data Splits
- **Train:** {len(train_subjects)} subjects ({len(train_pairs)} pairs)
- **Validation:** {len(val_subjects)} subjects ({len(val_pairs)} pairs)
- **Test:** {len(test_subjects)} subjects ({len(test_pairs)} pairs)

### Training Subjects
{train_subjects}

### Validation Subjects
{val_subjects}

### Test Subjects
{test_subjects}

## Results

### Training Performance
- **Best Validation Loss:** {best_val_loss:.4f}
- **Best Epoch:** {np.argmin(val_losses) + 1}
- **Final Generator Loss:** {train_losses_g[-1]:.4f}
- **Final Discriminator Loss:** {train_losses_d[-1]:.4f}

### Quantitative Metrics (Validation Set)
- **PSNR:** {psnr_mean:.2f} ¬± {psnr_std:.2f} dB
- **SSIM:** {ssim_mean:.4f} ¬± {ssim_std:.4f}

### Interpretation
- **PSNR > 25 dB:** Good quality reconstruction
- **SSIM > 0.85:** High structural similarity
- Current results: {"‚úÖ Good" if psnr_mean > 25 and ssim_mean > 0.85 else "‚ö†Ô∏è Needs improvement"}

## Files Included

1. **best_generator.pth** - Trained generator weights (load with PyTorch)
2. **training_curves.png** - Loss plots over training
3. **metrics_summary.json** - Detailed quantitative metrics
4. **data_split.json** - Train/val/test split information
5. **visualizations/** - Sample generated 7T images
6. **README.md** - This file

## Usage

### Loading the Model

```python
import torch
from models.generator_unet3d import UNet3DGenerator

# Create model
generator = UNet3DGenerator()

# Load weights
checkpoint = torch.load('best_generator.pth', map_location='cpu')
generator.load_state_dict(checkpoint['generator_state_dict'])
generator.eval()

# Use for inference
with torch.no_grad():
    generated_7t = generator(input_3t)
```

### Next Steps

1. **Evaluate on test set** - Use test subjects for final validation
2. **Full volume inference** - Generate complete 7T volumes (not just patches)
3. **Clinical validation** - Assess with radiologist review
4. **Topology loss** - Add persistent homology loss for better anatomy preservation
5. **Self-supervised pretraining** - Use ADNI dataset for improved generalization

## Citation

If you use this model, please cite:

```
@misc{{topobrain2025,
  author = {{Your Name}},
  title = {{Topo-Brain: Topology-Preserving 3T-to-7T MRI Super-Resolution}},
  year = {{2025}},
  publisher = {{GitHub}},
  url = {{https://github.com/prabeshx12/Topo-Brain}}
}}
```

## Contact

For questions or issues, please open an issue on GitHub:
https://github.com/prabeshx12/Topo-Brain

---

**Training completed successfully! üéâ**
"""

# The text contains non-ASCII characters, so the encoding is fixed to UTF-8
# instead of relying on the platform default (which can fail to encode them).
with open(final_dir / 'README.md', 'w', encoding='utf-8') as f:
    f.write(readme_content)
print("‚úì Created: README.md")

# Create archive for easy download
# The archive is built with shutil instead of the external `zip` command, so it
# works wherever Python runs and always packs the configured final_dir (the
# folder name is no longer hard-coded). Entries are stored relative to
# final_dir's parent, i.e. under "<final_dir name>/...".
print(f"\nüì¶ Creating archive...")
archive_name = f"topobrain_gan_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

# Kept for later cells that may rely on BASE_DIR being the working directory
os.chdir(BASE_DIR)

archive_path = BASE_DIR / f"{archive_name}.zip"
try:
    # make_archive appends the ".zip" suffix to the base name itself
    shutil.make_archive(
        str(BASE_DIR / archive_name),
        'zip',
        root_dir=str(final_dir.parent),
        base_dir=final_dir.name,
    )
except OSError as exc:
    # Archiving stays best-effort: report the problem and carry on
    print(f"\nWARNING: could not create archive: {exc}")

if archive_path.exists():
    size_mb = archive_path.stat().st_size / (1024**2)
    print(f"\n‚úì Archive created: {archive_name}.zip ({size_mb:.1f} MB)")

# List what can be downloaded: the single archive or the individual files
print(
    "\n" + "=" * 70,
    "üì• DOWNLOAD FILES:",
    "=" * 70,
    "\n1. Single archive (recommended):",
    f"   üì¶ {archive_name}.zip",
    "\n2. Individual files from gan_final_output/:",
    "   üìÑ best_generator.pth",
    "   üìÑ training_curves.png",
    "   üìÑ metrics_summary.json",
    "   üìÑ data_split.json",
    "   üìÅ visualizations/",
    "   üìÑ README.md",
    sep="\n",
)

# Closing banner of the packaging cell
print("\n" + "=" * 70, "‚úÖ ALL OUTPUTS PACKAGED AND READY!", "=" * 70, sep="\n")

In [None]:
# ============================================================================
# CELL 16: Final Summary and Next Steps
# ============================================================================
# Section banner for the closing cell
print("\n" + "=" * 70, "üéâ GAN TRAINING COMPLETE!", "=" * 70, sep="\n")

# Headline numbers of the run.
# NOTE(review): the reported time is measured from start_time up to the moment
# this cell runs, so it presumably also covers the evaluation and packaging
# cells (and any idle time in between) -- confirm where start_time is set.
print(
    "\nüìä FINAL SUMMARY:",
    f"   ‚Ä¢ Total training time: {(time.time() - start_time)/3600:.2f} hours",
    f"   ‚Ä¢ Epochs completed: {NUM_EPOCHS}",
    f"   ‚Ä¢ Best validation loss: {best_val_loss:.4f}",
    f"   ‚Ä¢ PSNR: {psnr_mean:.2f} ¬± {psnr_std:.2f} dB",
    f"   ‚Ä¢ SSIM: {ssim_mean:.4f} ¬± {ssim_std:.4f}",
    sep="\n",
)

# Where the artifacts of this run were written
print(
    "\nüìÅ OUTPUT LOCATIONS:",
    f"   ‚Ä¢ Checkpoints: {DIRS['checkpoints']}",
    f"   ‚Ä¢ Visualizations: {DIRS['visualizations']}",
    f"   ‚Ä¢ Logs: {DIRS['logs']}",
    f"   ‚Ä¢ Final package: {DIRS['final_output']}",
    sep="\n",
)

# Suggested follow-up work, one print per numbered item
print(
    "\nüéØ RECOMMENDED NEXT STEPS:",
    "\n   1. EVALUATE ON TEST SET",
    f"      ‚Üí Use held-out test subjects: {test_subjects}",
    "      ‚Üí Compute metrics on unseen data",
    "      ‚Üí Validate generalization",
    sep="\n",
)

print(
    "\n   2. FULL VOLUME INFERENCE",
    "      ‚Üí Generate complete 7T volumes (not just patches)",
    "      ‚Üí Use sliding window with overlap",
    "      ‚Üí Save as NIfTI for clinical review",
    sep="\n",
)

print(
    "\n   3. ADD TOPOLOGY PRESERVATION",
    "      ‚Üí Implement persistent homology loss",
    "      ‚Üí Preserve anatomical connectivity",
    "      ‚Üí Critical for Alzheimer's hippocampal analysis",
    sep="\n",
)

print(
    "\n   4. SELF-SUPERVISED PRETRAINING",
    "      ‚Üí Download ADNI Alzheimer's dataset",
    "      ‚Üí Pretrain on large 3T cohort",
    "      ‚Üí Fine-tune on paired 3T-7T data",
    sep="\n",
)

print(
    "\n   5. CLINICAL VALIDATION",
    "      ‚Üí Hippocampal volume measurement",
    "      ‚Üí Cortical thickness analysis",
    "      ‚Üí Radiologist quality assessment",
    sep="\n",
)

# Pointers to the project and the datasets
print(
    "\nüìö REFERENCES:",
    "   ‚Ä¢ Repository: https://github.com/prabeshx12/Topo-Brain",
    "   ‚Ä¢ ADNI dataset: https://adni.loni.usc.edu/",
    "   ‚Ä¢ UNC 3T-7T dataset: (your current data)",
    sep="\n",
)

# Sign-off
print(
    "\n" + "=" * 70,
    "Thank you for using Topo-Brain!",
    "For questions: https://github.com/prabeshx12/Topo-Brain/issues",
    "=" * 70,
    sep="\n",
)