# Balance PINN Training - Two-Stage with GPU Optimization

**BEFORE RUNNING: Upload these files to Colab:**
1. `processed_data/` folder (containing batch_0.h5, batch_1.h5)
2. `user_ages.csv`
3. `improved_models.py`
4. `enhanced_datasets.py`
5. `training_utils.py`
6. `config.py`

In [None]:
# Install dependencies
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install matplotlib seaborn pyyaml h5py pandas numpy tqdm

In [None]:
# GPU Validation and Performance Setup
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast
import numpy as np
import time

print("=== GPU VALIDATION ===")

# Check CUDA availability
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"PyTorch version: {torch.__version__}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")

    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(f"GPU {i}: {props.name}")
        print(f"  Memory: {props.total_memory / 1024**3:.1f}GB")
        print(f"  Compute Capability: {props.major}.{props.minor}")

    # Set device
    device = torch.device("cuda:0")
    torch.cuda.set_device(0)

    # GPU memory test
    print("\n=== GPU MEMORY TEST ===")
    test_tensor = torch.randn(10000, 10000, device=device)
    memory_allocated = torch.cuda.memory_allocated() / 1024**3
    memory_reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"Test allocation: {memory_allocated:.2f}GB allocated, {memory_reserved:.2f}GB reserved")
    del test_tensor
    torch.cuda.empty_cache()

    # GPU compute test
    print("\n=== GPU COMPUTE TEST ===")
    a = torch.randn(5000, 5000, device=device)
    b = torch.randn(5000, 5000, device=device)
    # Warm-up run: keeps tensor allocation and any one-off first-call
    # initialisation out of the timed region, so the threshold below
    # measures the matmul itself.
    c = torch.matmul(a, b)
    torch.cuda.synchronize()

    start_time = time.perf_counter()
    c = torch.matmul(a, b)
    torch.cuda.synchronize()  # CUDA kernels are asynchronous: wait before reading the clock
    compute_time = time.perf_counter() - start_time
    print(f"Matrix multiplication (5000x5000): {compute_time:.3f} seconds")

    if compute_time < 0.1:
        print("✅ GPU is working correctly and fast!")
    else:
        print("⚠️ GPU compute seems slow - check GPU utilization")

    del a, b, c
    torch.cuda.empty_cache()

else:
    print("❌ CUDA not available - using CPU (will be very slow!)")
    device = torch.device("cpu")

In [None]:
# Check uploaded files
import os
from pathlib import Path

print("=== FILE VERIFICATION ===")
required_files = [
    'improved_models.py',
    'enhanced_datasets.py',
    'training_utils.py',
    'config.py',
    'user_ages.csv'
]

# Collect everything that is missing so the user sees the full list at once.
missing_files = []
for filename in required_files:
    if os.path.exists(filename):
        print(f"✅ {filename}")
    else:
        print(f"❌ {filename} - MISSING!")
        missing_files.append(filename)

data_dir = Path('processed_data')
if data_dir.is_dir():
    batch_files = sorted(data_dir.glob('batch_*.h5'))
    if batch_files:
        print(f"✅ processed_data/ - {len(batch_files)} batch files found")
        for batch_file in batch_files:
            print(f"   {batch_file.name}")
    else:
        # An empty folder would otherwise pass this check and only fail
        # later, when the dataset is built.
        print("❌ processed_data/ - no batch_*.h5 files found!")
        missing_files.append('processed_data/batch_*.h5')
else:
    print("❌ processed_data/ - MISSING!")
    missing_files.append('processed_data/')

if missing_files:
    print(f"\n❌ Please upload missing files: {missing_files}")
    raise SystemExit("Missing required files")
else:
    print("\n✅ All required files present!")

In [None]:
# Data inspection with GPU validation
import pandas as pd
import h5py

print("=== DATA INSPECTION ===")

# Check age data (best-effort: a failure is reported, not raised)
try:
    age_df = pd.read_csv('user_ages.csv')
    print(f"Age data: {len(age_df)} subjects")
    print(f"Age range: {age_df['age'].min():.1f} - {age_df['age'].max():.1f} years")
    print(f"Age mean ± std: {age_df['age'].mean():.1f} ± {age_df['age'].std():.1f}")
except Exception as e:
    print(f"Error loading age data: {e}")

# Check batch files (best-effort: a failure is reported, not raised)
try:
    total_subjects = set()
    total_points = 0

    for batch_file in Path('processed_data').glob('batch_*.h5'):
        with h5py.File(batch_file, 'r') as f:
            subjects = f['subject_id'][:]
            unique_subjects = set(subjects.astype(str))
            total_subjects.update(unique_subjects)
            # Read the row count from the dataset's shape metadata instead
            # of loading the whole 't' array into memory just to count it.
            points = f['t'].shape[0]
            total_points += points
            print(f"{batch_file.name}: {len(unique_subjects)} subjects, {points:,} points")

    print(f"\nTotal: {len(total_subjects)} subjects, {total_points:,} data points")

    # GPU memory estimate
    if torch.cuda.is_available():
        estimated_memory = (total_points * 4 * 4) / 1024**3  # 4 floats per point, 4 bytes per float
        print(f"Estimated GPU memory needed: {estimated_memory:.1f}GB")

        available_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
        if estimated_memory < available_memory * 0.8:
            print(f"✅ Data will fit in GPU memory ({available_memory:.1f}GB available)")
        else:
            print(f"⚠️ Data might not fit in GPU memory ({available_memory:.1f}GB available)")

except Exception as e:
    print(f"Error inspecting batch files: {e}")

In [None]:
# GPU-Optimized Configuration for A100
# Builds the `config` dict consumed by the following cells. Every branch
# must define the keys that are later read with plain indexing:
# batch_size, stage1_epochs, stage2_epochs, stage1_lr and stage2_lr.
print("=== PERFORMANCE OPTIMIZATION ===")

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_memory_gb:.1f}GB")

    # A100 optimized configuration
    if 'A100' in gpu_name and gpu_memory_gb > 35:
        config = {
            # Data loading - optimized for A100
            'batch_size': 32768,        # Even larger batches
            'num_workers': 8,
            'pin_memory': True,
            'prefetch_factor': 4,
            'persistent_workers': True,

            # Training - A100 optimized
            'mixed_precision': True,    # Essential for A100
            'compile_model': True,      # PyTorch 2.0 compilation
            'gradient_checkpointing': False,  # A100 has enough memory

            # Physics computation - reduced for speed
            'use_simplified_physics': True,
            'physics_computation_frequency': 8,  # Every 8th batch only

            # Validation - less frequent for speed
            'val_frequency': 10,        # Validate every 10 epochs

            # Learning rates - optimized for large batches
            'stage1_lr': 3e-3,          # Higher LR for large batches
            'stage2_lr': 2e-3,

            # Epochs - reduced with better optimization
            'stage1_epochs': 30,        # Should converge faster
            'stage2_epochs': 20,

            # Memory management
            'empty_cache_frequency': 50,
        }
        print("🚀 A100 ULTRA-FAST CONFIG LOADED")
        print("Expected training time: 30-45 minutes total")
    else:
        # Fallback for other GPUs
        config = {
            'batch_size': 8192,
            'num_workers': 4,
            'mixed_precision': True,
            'physics_computation_frequency': 4,
            'stage1_epochs': 50,
            'stage2_epochs': 30,
            'stage1_lr': 2e-3,
            'stage2_lr': 1e-3,
        }
        print("⚡ Standard GPU optimization")
else:
    config = {
        'batch_size': 2048,
        'num_workers': 2,
        'pin_memory': False,        # Pinned host memory only helps host-to-GPU copies
        'mixed_precision': False,
        'stage1_epochs': 20,
        'stage2_epochs': 15,
        # The optimizers index these keys directly; without them the CPU
        # fallback failed with a KeyError when the Stage 1 optimizer was built.
        'stage1_lr': 1e-3,
        'stage2_lr': 1e-3,
    }
    print("⚠️ CPU fallback - will be slow")

# Add base configuration
config.update({
    'data_folder': 'processed_data',
    'age_csv_path': 'user_ages.csv',
    'train_ratio': 0.7,
    'val_ratio': 0.15,
    'test_ratio': 0.15,
    'random_seed': 42,
    'min_points_per_subject': 100,
    'stage1_physics_weight': 0.01,  # Low physics weight for speed
    'stage2_reg_weight': 0.1,
    'stage1_patience': 15,
    'stage2_patience': 10,
    'param_bounds': {
        'K': (500.0, 3000.0),
        'B': (20.0, 150.0),
        'tau': (0.05, 0.4)
    }
})

print("\nFinal config:")
print(f"  Batch size: {config['batch_size']:,}")
print(f"  Mixed precision: {config.get('mixed_precision', False)}")
print(f"  Stage 1 epochs: {config['stage1_epochs']}")
print(f"  Stage 2 epochs: {config['stage2_epochs']}")

In [None]:
# Setup imports and logging
import logging
from torch.utils.data import DataLoader, Subset
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
import json
import matplotlib.pyplot as plt

# Import our modules
from improved_models import SubjectPINN, AgeParameterModel
from enhanced_datasets import SubjectAwareDataset, create_subject_splits, create_filtered_dataset
from training_utils import SimplePhysicsLoss, ParameterRegularizationLoss, EarlyStopping

# Configure logging for the session: INFO level, timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Logger named after the current module, shared by the cells below.
logger = logging.getLogger(__name__)

print("✅ All imports successful")

In [None]:
# Setup optimized datasets
# Loads the dataset, splits it by subject and builds the train/val loaders.
print("=== DATASET SETUP ===")

# Load dataset
dataset = SubjectAwareDataset(
    processed_data_folder=config['data_folder'],
    age_csv_path=config['age_csv_path'],
    min_points_per_subject=config['min_points_per_subject']
)

print(f"Dataset loaded: {len(dataset)} total points")
print(f"Valid subjects: {len(dataset.valid_subjects)}")

# Create subject splits
subject_splits = create_subject_splits(
    dataset,
    train_ratio=config['train_ratio'],
    val_ratio=config['val_ratio'],
    test_ratio=config['test_ratio'],
    random_seed=config['random_seed']
)

# Create filtered datasets
train_indices = create_filtered_dataset(dataset, subject_splits['train'])
val_indices = create_filtered_dataset(dataset, subject_splits['val'])
test_indices = create_filtered_dataset(dataset, subject_splits['test'])

print("Data splits:")
print(f"  Train: {len(subject_splits['train'])} subjects, {len(train_indices):,} points")
print(f"  Val:   {len(subject_splits['val'])} subjects, {len(val_indices):,} points")
print(f"  Test:  {len(subject_splits['test'])} subjects, {len(test_indices):,} points")

# Shared loader settings
num_workers = config.get('num_workers', 4)
# Pinned memory only speeds up host-to-GPU copies, so default to off on CPU.
pin_memory = config.get('pin_memory', torch.cuda.is_available())

# DataLoader rejects prefetch_factor / persistent_workers when no worker
# processes are used, so only pass them when num_workers > 0.
train_worker_kwargs = {}
if num_workers > 0:
    train_worker_kwargs = {
        'prefetch_factor': config.get('prefetch_factor', 2),
        'persistent_workers': config.get('persistent_workers', False),
    }

# Dropping the trailing partial batch keeps batch sizes consistent, but if the
# training split is smaller than one batch it would leave the loader empty.
drop_incomplete_batch = len(train_indices) >= config['batch_size']
if not drop_incomplete_batch:
    print("⚠️ Training split is smaller than one batch - keeping the partial batch")

# Create optimized data loaders
train_loader = DataLoader(
    Subset(dataset, train_indices),
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    drop_last=drop_incomplete_batch,
    **train_worker_kwargs
)

val_loader = DataLoader(
    Subset(dataset, val_indices),
    batch_size=config['batch_size'],
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory
)

print("Data loaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Batch size: {config['batch_size']:,}")

In [None]:
# STAGE 1: Subject Parameter Learning
# Builds the Stage 1 model, optimizer, scheduler, loss and AMP state used by
# the Stage 1 training loop.
print("=" * 60)
print("STAGE 1: TRAINING SUBJECT PINN")
print("=" * 60)

# Create Stage 1 model
subject_pinn = SubjectPINN(
    subject_ids=dataset.valid_subjects,
    hidden_dims=[256, 256, 256, 256],
    param_bounds=config['param_bounds']
).to(device)

# Model compilation for faster execution (PyTorch 2.0+)
if config.get('compile_model', False) and hasattr(torch, 'compile'):
    try:
        subject_pinn = torch.compile(subject_pinn)
        print("✅ Model compiled with torch.compile")
    except Exception as compile_error:
        # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit
        # still propagate, and the reason for the fallback is visible.
        print(f"⚠️ Model compilation failed, using standard model ({compile_error})")

# Optimized optimizer
stage1_optimizer = torch.optim.AdamW(
    subject_pinn.parameters(),
    lr=config['stage1_lr'],
    weight_decay=1e-5,
    eps=1e-6,  # More stable for mixed precision
    betas=(0.9, 0.95)  # Optimized for large batches
)

# Learning rate scheduler: cosine decay over the Stage 1 epochs down to 1% of the initial LR
stage1_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    stage1_optimizer,
    T_max=config['stage1_epochs'],
    eta_min=config['stage1_lr'] * 0.01
)

# Loss function
stage1_loss_fn = SimplePhysicsLoss(weight=config['stage1_physics_weight']).to(device)
stage1_early_stopping = EarlyStopping(patience=config['stage1_patience'])

# Mixed precision setup (the gradient scaler is only needed when AMP is on)
use_amp = config.get('mixed_precision', False)
scaler = GradScaler() if use_amp else None
physics_freq = config.get('physics_computation_frequency', 1)

print(f"Stage 1 model: {sum(p.numel() for p in subject_pinn.parameters()):,} parameters")
print(f"Mixed precision: {use_amp}")
print(f"Physics computation: every {physics_freq} batches")
print("Expected time per epoch: 2-3 minutes (A100)")

In [None]:
# Stage 1 Training Loop with GPU Performance Monitoring
# Trains `subject_pinn` on data + physics loss, validates every
# `val_frequency` epochs, checkpoints the best validation loss to
# 'best_stage1_model.pth', and restores those best weights at the end.
import time

best_val_loss = float('inf')
stage1_metrics = []  # one record per epoch: timing, train losses, val losses when computed
val_frequency = config.get('val_frequency', 5)
empty_cache_frequency = config.get('empty_cache_frequency', 50)

# Fail early with a clear message instead of a ZeroDivisionError when the
# per-sample loss averages are computed below.
if len(train_loader) == 0:
    raise RuntimeError(
        "train_loader yields no batches - check the training split size "
        "against config['batch_size']"
    )
if len(val_loader) == 0:
    raise RuntimeError("val_loader yields no batches - the validation split is empty")

# Performance monitoring
batch_times = []
epoch_start_time = time.time()

for epoch in range(config['stage1_epochs']):
    epoch_start = time.time()

    # Training
    subject_pinn.train()
    train_losses = defaultdict(float)
    train_samples = 0

    pbar = tqdm(train_loader, desc=f"Stage 1 Epoch {epoch+1}")
    for batch_idx, (t, age, xy_true, subject_idx) in enumerate(pbar):
        batch_start = time.time()

        # Move to GPU with non_blocking for faster transfer
        t = t.to(device, non_blocking=True).requires_grad_(True)
        age = age.to(device, non_blocking=True)
        xy_true = xy_true.to(device, non_blocking=True)
        subject_idx = subject_idx.to(device, non_blocking=True)

        # Forward pass with mixed precision
        with autocast(enabled=use_amp):
            xy_pred, params = subject_pinn(t, subject_idx)

            # Data loss
            data_loss = nn.functional.mse_loss(xy_pred, xy_true)

            # Physics loss (computed less frequently)
            if batch_idx % physics_freq == 0:
                physics_loss = stage1_loss_fn(t, xy_pred, params)
            else:
                physics_loss = torch.tensor(0.0, device=device)

            total_loss = data_loss + physics_loss

        # Backward pass; gradients are cleared first in both code paths
        stage1_optimizer.zero_grad()
        if use_amp:
            scaler.scale(total_loss).backward()
            # Unscale before clipping so the norm is measured on true gradients
            scaler.unscale_(stage1_optimizer)
            torch.nn.utils.clip_grad_norm_(subject_pinn.parameters(), max_norm=1.0)
            scaler.step(stage1_optimizer)
            scaler.update()
        else:
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(subject_pinn.parameters(), max_norm=1.0)
            stage1_optimizer.step()

        # Track metrics (sample-weighted sums, averaged after the epoch)
        batch_size = t.shape[0]
        train_losses['data'] += data_loss.item() * batch_size
        train_losses['physics'] += physics_loss.item() * batch_size
        train_losses['total'] += total_loss.item() * batch_size
        train_samples += batch_size

        # Performance monitoring
        batch_time = time.time() - batch_start
        batch_times.append(batch_time)
        samples_per_sec = batch_size / batch_time

        # Update progress
        pbar.set_postfix({
            'Data': f"{data_loss.item():.1f}",
            'Physics': f"{physics_loss.item():.4f}",
            'Speed': f"{samples_per_sec:.0f} smp/s"
        })

        # Memory management (only meaningful when a CUDA device is in use)
        if torch.cuda.is_available() and batch_idx % empty_cache_frequency == 0:
            torch.cuda.empty_cache()

    # Epoch timing (average over this epoch's batches only)
    epoch_time = time.time() - epoch_start
    avg_batch_time = np.mean(batch_times[-len(train_loader):])
    avg_samples_per_sec = config['batch_size'] / avg_batch_time

    # Learning rate step
    stage1_scheduler.step()

    # Calculate average losses
    avg_train_losses = {k: v / train_samples for k, v in train_losses.items()}

    epoch_record = {
        'epoch': epoch + 1,
        'epoch_time': epoch_time,
        'train': avg_train_losses,
    }
    stage1_metrics.append(epoch_record)

    # Validation (less frequent for speed; always on the final epoch)
    if epoch % val_frequency == 0 or epoch == config['stage1_epochs'] - 1:
        # Validation pass
        subject_pinn.eval()
        val_losses = defaultdict(float)
        val_samples = 0

        # NOTE(review): the physics loss is evaluated inside torch.no_grad().
        # If SimplePhysicsLoss differentiates xy_pred w.r.t. t via autograd,
        # no graph is recorded here - confirm against training_utils.
        with torch.no_grad():
            for t, age, xy_true, subject_idx in val_loader:
                t = t.to(device, non_blocking=True).requires_grad_(True)
                xy_true = xy_true.to(device, non_blocking=True)
                subject_idx = subject_idx.to(device, non_blocking=True)

                with autocast(enabled=use_amp):
                    xy_pred, params = subject_pinn(t, subject_idx)
                    data_loss = nn.functional.mse_loss(xy_pred, xy_true)
                    physics_loss = stage1_loss_fn(t, xy_pred, params)
                    total_loss = data_loss + physics_loss

                batch_size = t.shape[0]
                val_losses['data'] += data_loss.item() * batch_size
                val_losses['physics'] += physics_loss.item() * batch_size
                val_losses['total'] += total_loss.item() * batch_size
                val_samples += batch_size

        avg_val_losses = {k: v / val_samples for k, v in val_losses.items()}
        val_loss = avg_val_losses['total']
        epoch_record['val'] = avg_val_losses

        # GPU memory status
        if torch.cuda.is_available():
            memory_used = torch.cuda.memory_allocated() / 1024**3
            memory_cached = torch.cuda.memory_reserved() / 1024**3
            gpu_util_info = f"GPU: {memory_used:.1f}GB used, {memory_cached:.1f}GB cached"
        else:
            gpu_util_info = "CPU mode"

        print(f"\nEpoch {epoch+1}/{config['stage1_epochs']} - {epoch_time:.1f}s")
        print(f"  Train: Data={avg_train_losses['data']:.1f}, Physics={avg_train_losses['physics']:.6f}")
        print(f"  Val:   Data={avg_val_losses['data']:.1f}, Physics={avg_val_losses['physics']:.6f}")
        print(f"  Speed: {avg_samples_per_sec:.0f} samples/sec")
        print(f"  {gpu_util_info}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'model_state_dict': subject_pinn.state_dict(),
                'epoch': epoch,
                'val_loss': val_loss,
                'config': config
            }, 'best_stage1_model.pth')
            print(f"  ✅ Best model saved (val_loss={val_loss:.6f})")

        # Early stopping check
        if stage1_early_stopping(val_loss, subject_pinn):
            print(f"  🛑 Early stopping at epoch {epoch+1}")
            break
    else:
        print(f"Epoch {epoch+1}/{config['stage1_epochs']} - {epoch_time:.1f}s - Train Loss: {avg_train_losses['total']:.1f}")

# Restore the best checkpoint so later cells use the weights that achieved
# best_val_loss rather than whatever the final epoch produced. A finite
# best_val_loss means a checkpoint was written during this run.
if best_val_loss < float('inf'):
    best_stage1_checkpoint = torch.load('best_stage1_model.pth', map_location=device)
    subject_pinn.load_state_dict(best_stage1_checkpoint['model_state_dict'])
    print(f"\nRestored best Stage 1 weights (epoch {best_stage1_checkpoint['epoch'] + 1})")

print(f"\n✅ STAGE 1 COMPLETED!")
print(f"Best validation loss: {best_val_loss:.6f}")
print(f"Total Stage 1 time: {(time.time() - epoch_start_time)/60:.1f} minutes")

In [None]:
# Extract Subject Parameters from Stage 1
# Reads the learned (K, B, tau) for every valid subject, stores them with the
# subject's age in `subject_parameters`, writes them to
# 'subject_parameters.json' and reports how much they vary across subjects.
print("=== EXTRACTING SUBJECT PARAMETERS ===")

subject_pinn.eval()
subject_parameters = {}

with torch.no_grad():
    for i, subject_id in enumerate(dataset.valid_subjects):
        # 0-dim index tensor, same as the squeezed [[i]] tensor used before
        subject_idx = torch.tensor(i, dtype=torch.long, device=device)
        K, B, tau = subject_pinn.get_parameters(subject_idx)

        subject_info = dataset.get_subject_info(subject_id)

        # Cast to native Python numbers: json.dump cannot serialise NumPy
        # scalar types such as int64 / float32.
        subject_parameters[subject_id] = {
            'age': float(subject_info.get('age', 0)),
            'K': K.item(),
            'B': B.item(),
            'tau': tau.item(),
            'n_points': int(subject_info.get('n_points', 0))
        }

# Save parameters
with open('subject_parameters.json', 'w') as f:
    json.dump(subject_parameters, f, indent=2)

# Parameter statistics
ages = [p['age'] for p in subject_parameters.values()]
Ks = [p['K'] for p in subject_parameters.values()]
Bs = [p['B'] for p in subject_parameters.values()]
taus = [p['tau'] for p in subject_parameters.values()]

print(f"Extracted parameters for {len(subject_parameters)} subjects:")
print(f"  Age range: {min(ages):.1f} - {max(ages):.1f}")
print(f"  K range: {min(Ks):.1f} - {max(Ks):.1f}")
print(f"  B range: {min(Bs):.1f} - {max(Bs):.1f}")
print(f"  τ range: {min(taus):.3f} - {max(taus):.3f}")

# Check parameter variation (should be >0.1 for good learning)
# Coefficient of variation = std / mean, a scale-free spread measure.
K_cv = np.std(Ks) / np.mean(Ks)
B_cv = np.std(Bs) / np.mean(Bs)
tau_cv = np.std(taus) / np.mean(taus)

print(f"\nParameter variation (coefficient of variation):")
print(f"  K: {K_cv:.3f} {'✅' if K_cv > 0.1 else '❌'} (need >0.1)")
print(f"  B: {B_cv:.3f} {'✅' if B_cv > 0.1 else '❌'} (need >0.1)")
print(f"  τ: {tau_cv:.3f} {'✅' if tau_cv > 0.1 else '❌'} (need >0.1)")

if all(cv > 0.1 for cv in [K_cv, B_cv, tau_cv]):
    print("\n✅ Good parameter variation - Stage 1 learned meaningful differences!")
else:
    print("\n⚠️ Low parameter variation - may need more training or different loss weights")

In [None]:
# STAGE 2: Age Parameter Learning
# Builds the age -> parameter model with its optimizer, scheduler and
# regulariser, then assembles the per-subject (age, [K, B, tau]) pairs.
stage2_banner = "=" * 60
print(stage2_banner)
print("STAGE 2: TRAINING AGE PARAMETER MODEL")
print(stage2_banner)

# Create Stage 2 model
age_model = AgeParameterModel(param_bounds=config['param_bounds'], use_probabilistic=True)
age_model = age_model.to(device)

# Stage 2 optimizer and cosine learning-rate schedule
stage2_optimizer = torch.optim.AdamW(age_model.parameters(), lr=config['stage2_lr'], weight_decay=1e-5)
stage2_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(stage2_optimizer, T_max=config['stage2_epochs'])

# Loss functions
stage2_reg_loss = ParameterRegularizationLoss(
    smoothness_weight=0.1,
    variation_weight=0.1,
    param_bounds=config['param_bounds'],
).to(device)

stage2_early_stopping = EarlyStopping(patience=config['stage2_patience'])

stage2_param_count = sum(weights.numel() for weights in age_model.parameters())
print(f"Stage 2 model: {stage2_param_count:,} parameters")


def _collect_age_param_pairs(split_ids):
    """Return (ages, [K, B, tau] rows) for the subjects in split_ids that have Stage 1 parameters."""
    known = [subject_parameters[sid] for sid in split_ids if sid in subject_parameters]
    split_ages = [entry['age'] for entry in known]
    split_params = [[entry['K'], entry['B'], entry['tau']] for entry in known]
    return split_ages, split_params


# Prepare Stage 2 data (age -> parameters)
train_ages, train_params = _collect_age_param_pairs(subject_splits['train'])
val_ages, val_params = _collect_age_param_pairs(subject_splits['val'])

# Convert to tensors (ages become a column vector)
train_ages_tensor = torch.tensor(train_ages, dtype=torch.float32).unsqueeze(-1)
train_params_tensor = torch.tensor(train_params, dtype=torch.float32)
val_ages_tensor = torch.tensor(val_ages, dtype=torch.float32).unsqueeze(-1)
val_params_tensor = torch.tensor(val_params, dtype=torch.float32)

print(f"Stage 2 data: {len(train_ages)} train, {len(val_ages)} val subjects")

In [None]:
# Stage 2 Training Loop
# Fits `age_model` to the per-subject parameters with a Gaussian negative
# log-likelihood plus a weighted regularisation term, and checkpoints the
# best validation loss to 'best_two_stage_model.pth'.


def _gaussian_nll(targets, means, stds, eps=1e-6):
    """Return the mean Gaussian negative log-likelihood (constant term dropped).

    Per element this is 0.5 * z**2 + log(sigma) with z = (target - mean) / sigma.
    The earlier form 0.5 * (z**2 + log(sigma)) under-weighted the log term,
    which biases the learned sigma upwards by a factor of sqrt(2).
    """
    safe_stds = stds + eps  # eps keeps the division and the log finite
    return torch.mean(0.5 * ((targets - means) / safe_stds) ** 2 + torch.log(safe_stds))


best_stage2_val_loss = float('inf')
stage2_metrics = []  # one record per epoch: timing, train losses, val losses

# Fail early with a clear message: an empty train split would otherwise give
# range() a zero step, and the loss averages below would divide by zero.
if len(train_ages) == 0:
    raise RuntimeError("No training subjects with Stage 1 parameters - cannot train Stage 2")
if len(val_ages) == 0:
    raise RuntimeError("No validation subjects with Stage 1 parameters - cannot validate Stage 2")

stage2_start_time = time.time()

# The tensors do not change between epochs, so move them to the device once.
train_ages_device = train_ages_tensor.to(device)
train_params_device = train_params_tensor.to(device)
val_ages_gpu = val_ages_tensor.to(device)
val_params_gpu = val_params_tensor.to(device)

# Mini-batch training for Stage 2
batch_size = min(32, len(train_ages))  # Smaller batches for Stage 2

for epoch in range(config['stage2_epochs']):
    epoch_start = time.time()

    # Training
    age_model.train()

    # Shuffle training data
    indices = torch.randperm(len(train_ages)).to(device)
    train_ages_shuffled = train_ages_device[indices]
    train_params_shuffled = train_params_device[indices]

    train_losses = defaultdict(float)
    n_batches = 0

    for i in range(0, len(train_ages), batch_size):
        batch_ages = train_ages_shuffled[i:i+batch_size]
        batch_params = train_params_shuffled[i:i+batch_size]

        # Forward pass
        with autocast(enabled=use_amp):
            pred_means, pred_stds = age_model.predict_parameters(batch_ages)

            # Negative log-likelihood loss
            param_loss = _gaussian_nll(batch_params, pred_means, pred_stds)

            # Regularization
            # NOTE(review): the regulariser receives the target parameters,
            # not pred_means, so it may contribute no gradient to age_model.
            # Confirm the intended inputs against ParameterRegularizationLoss.
            reg_losses = stage2_reg_loss(batch_ages, batch_params)
            total_reg_loss = sum(reg_losses.values())

            total_loss = param_loss + config['stage2_reg_weight'] * total_reg_loss

        # Backward pass; gradients are cleared first in both code paths
        stage2_optimizer.zero_grad()
        if use_amp:
            scaler.scale(total_loss).backward()
            scaler.step(stage2_optimizer)
            scaler.update()
        else:
            total_loss.backward()
            stage2_optimizer.step()

        # Track losses
        train_losses['param'] += param_loss.item()
        train_losses['reg'] += total_reg_loss.item()
        train_losses['total'] += total_loss.item()
        n_batches += 1

    # Learning rate step
    stage2_scheduler.step()

    # Average losses
    avg_train_losses = {k: v / n_batches for k, v in train_losses.items()}

    # Validation (whole validation set in one pass)
    age_model.eval()
    with torch.no_grad():
        with autocast(enabled=use_amp):
            pred_means, pred_stds = age_model.predict_parameters(val_ages_gpu)
            val_param_loss = _gaussian_nll(val_params_gpu, pred_means, pred_stds)
            val_reg_losses = stage2_reg_loss(val_ages_gpu, val_params_gpu)
            val_total_reg = sum(val_reg_losses.values())
            val_total_loss = val_param_loss + config['stage2_reg_weight'] * val_total_reg

    epoch_time = time.time() - epoch_start

    stage2_metrics.append({
        'epoch': epoch + 1,
        'epoch_time': epoch_time,
        'train': avg_train_losses,
        'val': {
            'param': val_param_loss.item(),
            'reg': val_total_reg.item(),
            'total': val_total_loss.item(),
        },
    })

    print(f"Stage 2 Epoch {epoch+1}/{config['stage2_epochs']} - {epoch_time:.1f}s")
    print(f"  Train: Param={avg_train_losses['param']:.6f}, Reg={avg_train_losses['reg']:.6f}")
    print(f"  Val:   Param={val_param_loss.item():.6f}, Reg={val_total_reg.item():.6f}")

    # Save best model
    if val_total_loss.item() < best_stage2_val_loss:
        best_stage2_val_loss = val_total_loss.item()
        torch.save({
            'subject_pinn_state_dict': subject_pinn.state_dict(),
            'age_model_state_dict': age_model.state_dict(),
            'subject_parameters': subject_parameters,
            'config': config
        }, 'best_two_stage_model.pth')
        print(f"  ✅ Best two-stage model saved")

    # Early stopping
    if stage2_early_stopping(val_total_loss.item(), age_model):
        print(f"  🛑 Early stopping at epoch {epoch+1}")
        break

print(f"\n✅ STAGE 2 COMPLETED!")
print(f"Best Stage 2 validation loss: {best_stage2_val_loss:.6f}")
print(f"Total Stage 2 time: {(time.time() - stage2_start_time)/60:.1f} minutes")

In [None]:
# Age Relationship Analysis and Visualization
# Plots each Stage 1 parameter against age with the Stage 2 mean ± std band,
# then reports the Pearson correlation with age on the training subjects.
print("=== AGE RELATIONSHIP ANALYSIS ===")

age_model.eval()

# Generate predictions across age range
ages_test = np.linspace(20, 90, 100)
age_tensor = torch.tensor(ages_test, dtype=torch.float32).unsqueeze(-1).to(device)

with torch.no_grad():
    pred_means, pred_stds = age_model.predict_parameters(age_tensor)
    pred_means = pred_means.cpu().numpy()
    pred_stds = pred_stds.cpu().numpy()

# Create visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

param_names = ['Stiffness (K)', 'Damping (B)', 'Neural Delay (τ)']
all_subject_records = list(subject_parameters.values())
subject_params = [
    [record[key] for record in all_subject_records]
    for key in ('K', 'B', 'tau')
]
subject_ages = [record['age'] for record in all_subject_records]

for column, (ax, name, subject_param) in enumerate(zip(axes, param_names, subject_params)):
    trend = pred_means[:, column]
    spread = pred_stds[:, column]

    # Plot subject data
    ax.scatter(subject_ages, subject_param, alpha=0.6, s=20, color='blue', label='Subjects')

    # Plot learned curve
    ax.plot(ages_test, trend, 'red', linewidth=2, label='Learned Trend')

    # Plot uncertainty (one predicted standard deviation either side)
    ax.fill_between(ages_test, trend - spread, trend + spread,
                    alpha=0.2, color='red', label='Uncertainty')

    ax.set_xlabel('Age (years)')
    ax.set_ylabel(name)
    ax.set_title(f'{name} vs Age')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('parameter_age_relationships.png', dpi=150, bbox_inches='tight')
plt.show()

# Calculate age correlations
from scipy.stats import pearsonr

# Training subjects that have Stage 1 parameters; identical for every parameter
train_subject_ids = [sid for sid in subject_splits['train'] if sid in subject_parameters]

correlations = {}
for param_name in ('K', 'B', 'tau'):
    param_values = [subject_parameters[sid][param_name] for sid in train_subject_ids]
    param_ages = [subject_parameters[sid]['age'] for sid in train_subject_ids]

    # Skip the correlation when there are three or fewer subjects
    if len(param_values) <= 3:
        continue

    corr, p_value = pearsonr(param_ages, param_values)
    correlations[param_name] = {'correlation': corr, 'p_value': p_value}
    print(f"{param_name}-age correlation: {corr:.3f} (p={p_value:.3f})")

print("\n✅ Age relationship analysis complete!")

In [None]:
# Final Model Testing and Performance Summary
print("=== FINAL MODEL TESTING ===")

# Load best model
# map_location places the tensors on the active device regardless of where
# the checkpoint was written. weights_only=False is needed because the
# checkpoint also stores the plain-Python `subject_parameters` / `config`
# objects, which the restricted loader may reject.
# NOTE: this unpickles arbitrary objects - only load checkpoints produced by
# this notebook, never files from an untrusted source.
checkpoint = torch.load(
    'best_two_stage_model.pth',
    map_location=device,
    weights_only=False,
)
subject_pinn.load_state_dict(checkpoint['subject_pinn_state_dict'])
age_model.load_state_dict(checkpoint['age_model_state_dict'])

print("✅ Best models loaded")

# Test age comparison functionality
def compare_ages(age1: float, age2: float, model=None, target_device=None) -> dict:
    """Compare predicted balance parameters between two ages.

    Args:
        age1: First age; ints are accepted and converted to float32.
        age2: Second age; ints are accepted and converted to float32.
        model: Object exposing ``predict_parameters(ages)`` that returns a
            ``(means, stds)`` pair. Defaults to the notebook's ``age_model``.
        target_device: Device the age tensors are created on. Defaults to the
            notebook's ``device``.

    Returns:
        Dict with NumPy arrays ``'age1_params'`` and ``'age2_params'`` (the
        predicted means for each age) and ``'differences'``
        (``age2_params - age1_params``).
    """
    if model is None:
        model = age_model
    if target_device is None:
        target_device = device

    # Force float32: integer ages such as 30 would otherwise create int64
    # tensors, which floating-point layers reject.
    age1_tensor = torch.tensor([[age1]], dtype=torch.float32, device=target_device)
    age2_tensor = torch.tensor([[age2]], dtype=torch.float32, device=target_device)

    with torch.no_grad():
        params1_mean, _ = model.predict_parameters(age1_tensor)
        params2_mean, _ = model.predict_parameters(age2_tensor)

        # Calculate parameter differences
        param_diff = (params2_mean - params1_mean).cpu().numpy().squeeze()

        return {
            'age1_params': params1_mean.cpu().numpy().squeeze(),
            'age2_params': params2_mean.cpu().numpy().squeeze(),
            'differences': param_diff
        }

# Test age comparisons
# Prints the predicted K, B and τ for a few age pairs with the signed change.
print("\n=== AGE COMPARISON TEST ===")
test_comparisons = [
    (30, 60),  # Young vs middle-aged
    (40, 70),  # Middle-aged vs older
    (60, 80),  # Older vs elderly
]

# (label, float format) per parameter, in the order K, B, tau
parameter_formats = [('K', '.1f'), ('B', '.1f'), ('τ', '.3f')]

for age1, age2 in test_comparisons:
    comparison = compare_ages(age1, age2)
    before = comparison['age1_params']
    after = comparison['age2_params']
    delta = comparison['differences']

    print(f"\nAge {age1} vs {age2}:")
    for position, (label, spec) in enumerate(parameter_formats):
        print(f"  {label}: {before[position]:{spec}} → {after[position]:{spec}} (Δ={delta[position]:+{spec}})")

print("\n✅ Age comparison functionality working!")

In [None]:
# Training Summary and Model Download
# Wall-clock time since the Stage 1 training cell started (includes any idle
# time between cells).
total_training_time = time.time() - epoch_start_time

print("=" * 60)
print("TRAINING COMPLETE!")
print("=" * 60)

print(f"\n📊 PERFORMANCE SUMMARY:")
if torch.cuda.is_available():
    final_memory = torch.cuda.memory_allocated() / 1024**3
    max_memory = torch.cuda.max_memory_allocated() / 1024**3
    print(f"  GPU Memory: {final_memory:.1f}GB current, {max_memory:.1f}GB peak")

print(f"  Total training time: {total_training_time/60:.1f} minutes")
print(f"  Stage 1 epochs: {config['stage1_epochs']}")
print(f"  Stage 2 epochs: {config['stage2_epochs']}")
print(f"  Final batch size: {config['batch_size']:,}")

# all() is True for an empty iterable, so require at least one computed
# correlation before reporting that age relationships were learned.
relationships_learned = bool(correlations) and all(
    abs(corr['correlation']) > 0.1 for corr in correlations.values()
)

print(f"\n🎯 MODEL RESULTS:")
print(f"  Subjects processed: {len(subject_parameters)}")
print(f"  Parameter variation: K={K_cv:.3f}, B={B_cv:.3f}, τ={tau_cv:.3f}")
print(f"  Age relationships learned: {'✅' if relationships_learned else '⚠️'}")

print(f"\n📁 SAVED FILES:")
print(f"  best_two_stage_model.pth - Complete trained model")
print(f"  subject_parameters.json - Individual subject parameters")
print(f"  parameter_age_relationships.png - Visualization")

print(f"\n🚀 Ready for cross-age balance comparison!")

# Download instructions
# google.colab only exists inside Colab; the summary above should still
# print when the notebook runs elsewhere.
try:
    from google.colab import files
except ImportError:
    print(f"\n📥 Not running in Colab - the saved files are in the working directory.")
else:
    print(f"\n📥 DOWNLOAD TRAINED MODELS:")
    print(f"Run these commands to download your trained models:")
    print(f"files.download('best_two_stage_model.pth')")
    print(f"files.download('subject_parameters.json')")
    print(f"files.download('parameter_age_relationships.png')")