In [5]:
"""
Improved MCI-to-AD Conversion Prediction
Generates latent features for JMBayes2 joint modeling in R

Key Enhancements:
1. Efron's Cox loss for tied event times
2. Future MMSE prediction (not current)
3. Enhanced clinical features
4. Gradient accumulation
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import numpy as np
import os
import pickle
from lifelines.utils import concordance_index
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION
# ============================================================================
CONFIG = {
    'latent_dim': 128,
    'img_out_dim': 256,
    'tab_out_dim': 64,
    'lstm_hidden': 128,
    'lstm_layers': 2,
    'dropout': 0.3,  # Back to 0.3 (was 0.4)
    'lr': 5e-4,      # Back to original
    'weight_decay': 1e-4,
    'epochs': 100,   # Back to 100
    'batch_size': 16,
    'accumulation_steps': 1,  # Disable gradient accumulation for now
    'max_seq_len': 10,
    'alpha_survival': 0.85,  # Increased focus on survival
    'alpha_mmse': 0.15,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}

# ============================================================================
# FEATURE DEFINITIONS
# ============================================================================
DEMOGRAPHIC_COLUMNS = [
    'AGE', 'PTGENDER_encoded', 'PTEDUCAT', 'PTETHCAT_encoded', 
    'PTRACCAT_encoded', 'PTMARRY_encoded'
]

STATIC_FEATURES = [
    'age_bl', 'PTGENDER_encoded', 'PTEDUCAT', 'PTETHCAT_encoded', 
    'PTRACCAT_encoded', 'PTMARRY_encoded'
]

TEMPORAL_FEATURES = [
    'time_from_baseline', 'AGE', 'age_since_bl', 'mmse_slope', 
    'adas13_slope', 'dx_progression', 'cog_decline_index', 
    'visit_number', 'MMSE', 'ADAS13'
]

# ============================================================================
# FEATURE ENGINEERING
# ============================================================================

def add_future_mmse_label(df, horizon=1.0, window=0.25):
    """Add MMSE score ~12 months later"""
    future_mmse = []
    for i, row in df.iterrows():
        current_time = row["Years_bl"]
        target_time = current_time + horizon
        mask = (df["Years_bl"] >= target_time - window) & \
               (df["Years_bl"] <= target_time + window)
        candidates = df[mask]
        if len(candidates) > 0:
            future_mmse.append(candidates["MMSE"].iloc[0])
        else:
            future_mmse.append(np.nan)
    df["MMSE_future12"] = future_mmse
    return df

def engineer_features(df):
    """Enhanced feature engineering"""
    df = df.copy()
    
    # Basic temporal features
    df["time_from_baseline"] = df["Years_bl"] - df["Years_bl"].iloc[0]
    df["age_bl"] = df["AGE"].iloc[0]
    df["age_since_bl"] = df["AGE"] - df["age_bl"]
    
    # Cognitive decline rates
    df["mmse_slope"] = df["MMSE"].diff() / df["Years_bl"].diff()
    df["adas13_slope"] = df["ADAS13"].diff() / df["Years_bl"].diff()
    
    # Diagnosis progression
    dx_map = {"CN": 0, "MCI": 1, "AD": 2, "Dementia": 2}
    df["dx_progression"] = df["DX"].map(dx_map).diff()
    
    # Cognitive decline index
    df["cog_decline_index"] = df["ADAS13"] - df["MMSE"]
    
    # Visit order
    df["visit_number"] = range(len(df))
    
    # Clinical interaction features
    df['age_mmse_interaction'] = df['AGE'] * (30 - df['MMSE']) / 30
    df['education_cognitive_reserve'] = df['PTEDUCAT'] * df['MMSE'] / 30
    df['rapid_decline_flag'] = (df['mmse_slope'] < -2).astype(float)
    
    # MMSE severity
    mmse_bins = [0, 10, 20, 24, 30]
    df['mmse_severity'] = pd.cut(df['MMSE'], bins=mmse_bins, 
                                  labels=[3, 2, 1, 0]).astype(float)
    
    # Time-weighted decline
    df['weighted_mmse_decline'] = df['mmse_slope'] * np.exp(-0.1 * df['time_from_baseline'])
    
    # Cognitive variability
    df['mmse_variability'] = df['MMSE'].rolling(window=3, min_periods=1).std()
    
    # ADAS13-MMSE discordance
    df['adas_mmse_discordance'] = np.abs(
        (df['ADAS13'] - df['ADAS13'].mean()) / (df['ADAS13'].std() + 1e-7) - 
        (df['MMSE'] - df['MMSE'].mean()) / (df['MMSE'].std() + 1e-7)
    )
    
    # MMSE acceleration
    df['mmse_acceleration'] = df['mmse_slope'].diff() / df['Years_bl'].diff()
    
    # Cognitive efficiency
    df['cognitive_efficiency'] = df['MMSE'] / (df['PTEDUCAT'] + 1)
    
    # Age-adjusted MMSE
    expected_decline = np.maximum(0, (df['AGE'] - 65) * 0.2)
    df['age_adjusted_mmse'] = df['MMSE'] + expected_decline
    
    # Fill NaNs
    df = df.fillna(method='ffill').fillna(method='bfill').fillna(0)
    
    # Add future MMSE label
    df = add_future_mmse_label(df)
    
    return df

# Update feature lists
TEMPORAL_FEATURES.extend([
    'age_mmse_interaction', 'education_cognitive_reserve', 'rapid_decline_flag',
    'mmse_severity', 'weighted_mmse_decline', 'mmse_variability', 
    'adas_mmse_discordance', 'mmse_acceleration', 'cognitive_efficiency',
    'age_adjusted_mmse'
])

# ============================================================================
# DATASET
# ============================================================================
class SequenceDataset(Dataset):
    """Dataset providing patient-level sequences"""

    def __init__(self, manifest, valid_patients, transform=None, max_seq_len=10):
        self.transform = transform
        self.max_seq_len = max_seq_len
        self.sequences = []

        manifest = manifest.copy()
        manifest["path"] = manifest["path"].str.replace("\\", "/", regex=False)
        manifest["path"] = "./AD_Multimodal/TFN_AD/" + manifest["path"]

        skipped_not_valid = 0
        processed = 0

        for ptid in manifest["PTID"].unique():
            if ptid not in valid_patients:
                skipped_not_valid += 1
                continue
            try:
                patient_rows = manifest[manifest["PTID"] == ptid]
                if len(patient_rows) == 0:
                    continue

                df = pd.read_pickle(patient_rows.iloc[0]["path"])
                df = engineer_features(df)

                dx_seq = df["DX"].tolist()
                if "MCI" not in dx_seq:
                    continue

                mci_idx = dx_seq.index("MCI")
                ad_idx = next(
                    (i for i, x in enumerate(dx_seq[mci_idx + 1:], start=mci_idx + 1)
                     if x in ["AD", "Dementia"]),
                    -1
                )

                if ad_idx != -1:
                    time_to_event = (
                        df["Years_bl"].iloc[ad_idx] -
                        df["Years_bl"].iloc[mci_idx]
                    )
                    event = 1
                else:
                    time_to_event = (
                        df["Years_bl"].iloc[-1] -
                        df["Years_bl"].iloc[mci_idx]
                    )
                    event = 0

                imgs, tabs, times, mmse, mmse_future = [], [], [], [], []
                valid_visits = 0

                for _, visit in df.iterrows():
                    image_path = visit["image_path"].replace(
                        "/home/mason/ADNI_Dataset/",
                        "./AD_Multimodal/ADNI_Dataset/"
                    )

                    if not os.path.exists(image_path):
                        continue

                    img = Image.open(image_path).convert("RGB")
                    if self.transform:
                        img = self.transform(img)

                    imgs.append(img)
                    tabs.append(
                        visit[TEMPORAL_FEATURES + STATIC_FEATURES]
                        .values.astype(np.float32)
                    )
                    times.append(visit["Years_bl"])
                    mmse.append(visit["MMSE"])
                    mmse_future.append(
                        visit["MMSE_future12"]
                        if "MMSE_future12" in visit else np.nan
                    )

                    valid_visits += 1
                    if valid_visits >= self.max_seq_len:
                        break

                if valid_visits < 2:
                    continue

                pad_len = self.max_seq_len - valid_visits
                if pad_len > 0:
                    for _ in range(pad_len):
                        imgs.append(torch.zeros_like(imgs[-1]))
                        tabs.append(np.zeros_like(tabs[-1]))
                        times.append(times[-1])
                        mmse.append(np.nan)
                        mmse_future.append(np.nan)

                self.sequences.append({
                    "ptid": ptid,
                    "imgs": torch.stack(imgs),
                    "tabs": np.array(tabs, dtype=np.float32),
                    "times": np.array(times, dtype=np.float32),
                    "mmse": np.array(mmse, dtype=np.float32),
                    "mmse_future": np.array(mmse_future, dtype=np.float32),
                    "seq_len": valid_visits,
                    "time_to_event": float(time_to_event),
                    "event": int(event)
                })

                processed += 1

            except Exception as e:
                print(f"⚠️ Skipped patient {ptid}: {e}")
                continue

        print(f"  Processed: {processed} valid patients")
        print(f"  Skipped (not in valid set): {skipped_not_valid}")

        # all_tabs = np.vstack([seq["tabs"] for seq in self.sequences])
        # self.scaler = StandardScaler()
        # self.scaler.fit(all_tabs)

        # CRITICAL R FIX: Standardize tabular features for numerical stability
        if len(self.sequences) > 0:
            all_tabs = np.vstack([seq['tabs'] for seq in self.sequences])
            self.scaler = StandardScaler()
            self.scaler.fit(all_tabs)
            
            for seq in self.sequences:
                seq['tabs'] = self.scaler.transform(seq['tabs']).astype(np.float32)
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        seq = self.sequences[idx]
        return (
            seq['imgs'],
            seq['tabs'],
            seq['times'],
            seq['mmse'],
            seq['mmse_future'],
            seq['seq_len'],
            seq['time_to_event'],
            seq['event'],
            seq['ptid']
        )

# ============================================================================
# MODEL
# ============================================================================

class TensorFusion(nn.Module):
    """Tensor fusion layer"""
    def __init__(self, v_dim, d_dim, t_dim, proj_dim=None, dropout=0.1):
        super().__init__()
        self.v_dim = v_dim
        self.d_dim = d_dim
        self.t_dim = t_dim
        self.output_dim = (v_dim + 1) * (d_dim + 1) * (t_dim + 1)
        self.dropout = nn.Dropout(dropout)
        
        if proj_dim:
            self.proj = nn.Linear(self.output_dim, proj_dim)
        else:
            self.proj = None
    
    def forward(self, v, d, t):
        batch_size = v.shape[0]
        
        v_1 = torch.cat([v, torch.ones(batch_size, 1, device=v.device)], dim=1)
        d_1 = torch.cat([d, torch.ones(batch_size, 1, device=d.device)], dim=1)
        t_1 = torch.cat([t, torch.ones(batch_size, 1, device=t.device)], dim=1)
        
        fusion = torch.einsum('bi,bj,bk->bijk', v_1, d_1, t_1)
        fusion = fusion.view(batch_size, -1)
        
        fusion = self.dropout(fusion)
        
        if self.proj:
            fusion = self.proj(fusion)
        
        return fusion

class AttentionImageEncoder(nn.Module):
    """Image encoder with spatial attention"""
    def __init__(self, out_dim=256):
        super().__init__()
        base = models.resnet18(pretrained=True)
        
        for param in list(base.parameters())[:-20]:
            param.requires_grad = False
        
        self.features = nn.Sequential(*list(base.children())[:-2])
        
        self.attention = nn.Sequential(
            nn.Conv2d(512, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 1, 1),
            nn.Sigmoid()
        )
        
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Linear(512, out_dim)
    
    def forward(self, x):
        feats = self.features(x)
        attn = self.attention(feats)
        feats = feats * attn
        pooled = self.global_pool(feats).view(x.size(0), -1)
        return self.proj(pooled)

class SupervisedTemporalFusionAutoencoder(nn.Module):
    """Multi-modal temporal model"""
    def __init__(self, tab_dim, config):
        super().__init__()
        self.config = config
        
        self.img_encoder = AttentionImageEncoder(out_dim=config['img_out_dim'])
        
        self.tab_encoder = nn.Sequential(
            nn.Linear(tab_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(config['dropout']),
            nn.Linear(128, config['tab_out_dim']),
            nn.ReLU()
        )
        
        self.time_encoder = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU()
        )
        
        self.fusion = TensorFusion(
            v_dim=config['img_out_dim'],
            d_dim=config['tab_out_dim'],
            t_dim=16,
            dropout=config['dropout']
        )
        
        fusion_dim = (config['img_out_dim'] + 1) * (config['tab_out_dim'] + 1) * 17
        self.fusion_proj = nn.Sequential(
            nn.Linear(fusion_dim, config['latent_dim']),
            nn.BatchNorm1d(config['latent_dim']),
            nn.ReLU()
        )
        
        self.lstm = nn.LSTM(
            input_size=config['latent_dim'],
            hidden_size=config['lstm_hidden'],
            num_layers=config['lstm_layers'],
            batch_first=True,
            dropout=config['dropout'] if config['lstm_layers'] > 1 else 0,
            bidirectional=True
        )
        
        self.temporal_proj = nn.Sequential(
            nn.Linear(config['lstm_hidden'] * 2, config['latent_dim']),
            nn.ReLU()
        )
        
        self.survival_head = nn.Sequential(
            nn.Linear(config['latent_dim'], 64),
            nn.ReLU(),
            nn.Dropout(config['dropout']),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
        
        self.mmse_head = nn.Sequential(
            nn.Linear(config['latent_dim'], 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
    
    def encode_visit(self, img, tab, time):
        """Encode a single visit"""
        v = self.img_encoder(img)
        d = self.tab_encoder(tab)
        t = self.time_encoder(time.unsqueeze(1))
        
        z = self.fusion(v, d, t)
        z = z.view(z.size(0), -1)
        z = self.fusion_proj(z)
        
        return z
    
    def forward(self, img_seq, tab_seq, time_seq, seq_lengths):
        """Forward pass"""
        batch_size, seq_len = img_seq.shape[:2]
        
        z_list = []
        for t in range(seq_len):
            img_t = img_seq[:, t]
            tab_t = tab_seq[:, t]
            time_t = time_seq[:, t]
            
            z_t = self.encode_visit(img_t, tab_t, time_t)
            z_list.append(z_t)
        
        z_seq = torch.stack(z_list, dim=1)
        
        packed = nn.utils.rnn.pack_padded_sequence(
            z_seq, seq_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        lstm_out, (h_n, c_n) = self.lstm(packed)
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        
        h_forward = h_n[-2]
        h_backward = h_n[-1]
        h_final = torch.cat([h_forward, h_backward], dim=1)
        
        z_final = self.temporal_proj(h_final)
        
        risk_score = self.survival_head(z_final)
        mmse_pred = self.mmse_head(z_final)
        
        return z_final, risk_score, mmse_pred

# ============================================================================
# LOSS FUNCTIONS
# ============================================================================

def cox_partial_likelihood_loss_efron(risk_scores, times, events):
    """Cox loss with Efron's approximation"""
    device = risk_scores.device
    
    mask = torch.isfinite(risk_scores) & torch.isfinite(times)
    risk_scores = risk_scores[mask]
    times = times[mask]
    events = events[mask]
    
    if events.sum() == 0:
        return torch.tensor(0.0, device=device)
    
    order = torch.argsort(times, descending=False)
    risk_scores = risk_scores[order]
    times = times[order]
    events = events[order]
    
    log_risk = risk_scores.view(-1)
    hazard = torch.exp(log_risk)
    
    unique_times = torch.unique(times[events == 1])
    
    loss = torch.tensor(0.0, device=device)
    for t in unique_times:
        at_risk = times >= t
        died = (times == t) & (events == 1)
        
        if died.sum() == 0:
            continue
        
        risk_set = hazard[at_risk].sum()
        died_risk = log_risk[died].sum()
        n_died = died.sum().float()
        
        tied_risk = hazard[died].sum()
        
        for j in range(int(n_died)):
            risk_set_j = risk_set - (j / n_died) * tied_risk
            loss += torch.log(risk_set_j + 1e-7)
        
        loss -= died_risk
    
    return loss / (events.sum() + 1e-7)

# ============================================================================
# TRAINING
# ============================================================================

def train_epoch(model, loader, optimizer, config, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    total_cox = 0
    total_mmse = 0
    
    for batch in loader:
        imgs, tabs, times, mmse, mmse_future, seq_lens, t_event, event, _ = batch
        
        imgs = imgs.to(device)
        tabs = torch.FloatTensor(tabs).to(device)
        times = torch.FloatTensor(times).to(device)
        mmse_future_vals = torch.FloatTensor(mmse_future).to(device)
        seq_lens = torch.LongTensor(seq_lens)
        t_event = t_event.float().to(device)
        event = event.float().to(device)
        
        # Get FUTURE MMSE targets (not current!)
        mmse_targets = []
        for i, slen in enumerate(seq_lens):
            future_val = mmse_future_vals[i, slen-1]
            mmse_targets.append(future_val if not torch.isnan(future_val) else torch.tensor(float('nan')))
        mmse_targets = torch.stack(mmse_targets).to(device)
        
        # Forward
        z_final, risk_scores, mmse_pred = model(imgs, tabs, times, seq_lens)
        
        # Cox loss (Efron)
        loss_cox = cox_partial_likelihood_loss_efron(
            risk_scores.squeeze(), t_event, event
        )
        
        # MMSE loss (predict FUTURE, not current)
        valid_mmse_mask = ~torch.isnan(mmse_targets)
        if valid_mmse_mask.any():
            loss_mmse = F.mse_loss(
                mmse_pred.squeeze()[valid_mmse_mask],
                mmse_targets[valid_mmse_mask]
            )
        else:
            loss_mmse = torch.tensor(0.0, device=device)
        
        # Combined
        loss = (config['alpha_survival'] * loss_cox + 
                config['alpha_mmse'] * loss_mmse)
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        total_cox += loss_cox.item()
        total_mmse += loss_mmse.item() if valid_mmse_mask.any() else 0
    
    n_batches = len(loader)
    return {
        'total': total_loss / n_batches,
        'cox': total_cox / n_batches,
        'mmse': total_mmse / n_batches,
    }

def validate(model, loader, device):
    """Validate and compute C-index"""
    model.eval()
    all_risks = []
    all_times = []
    all_events = []
    
    with torch.no_grad():
        for batch in loader:
            imgs, tabs, times, mmse, mmse_future, seq_lens, t_event, event, _ = batch
            
            imgs = imgs.to(device)
            tabs = torch.FloatTensor(tabs).to(device)
            times = torch.FloatTensor(times).to(device)
            seq_lens = torch.LongTensor(seq_lens)
            
            _, risk_scores, _ = model(imgs, tabs, times, seq_lens)
            
            all_risks.extend(risk_scores.cpu().numpy().flatten())
            all_times.extend(t_event.numpy())
            all_events.extend(event.numpy())
    
    all_events = np.array(all_events).astype(bool)
    all_times = np.array(all_times)
    all_risks = np.array(all_risks)
    
    c_index = concordance_index(all_times, -all_risks, all_events)
    
    return c_index, all_risks, all_times, all_events

# ============================================================================
# EXPORT
# ============================================================================

def export_latent_features(model, loader, device, output_path):
    """Export for JMBayes2"""
    model.eval()
    rows = []

    BASELINE_FEATURES = ['AGE', 'PTGENDER', 'PTEDUCAT', 'ADAS13']
    tab_feature_names = TEMPORAL_FEATURES + STATIC_FEATURES

    with torch.no_grad():
        for batch in loader:
            imgs, tabs, times, mmse, mmse_future, seq_lens, t_event, event, ptids = batch

            imgs = imgs.to(device)
            tabs = tabs.to(device)
            times = times.to(device)
            seq_lens = seq_lens.cpu().long()

            for i in range(len(ptids)):
                slen = seq_lens[i].item()

                for t in range(slen):
                    img_t = imgs[i:i+1, t]
                    tab_t = tabs[i:i+1, t]
                    time_t = times[i:i+1, t]

                    z_visit = model.encode_visit(img_t, tab_t, time_t)

                    tab_vals_unscaled = loader.dataset.scaler.inverse_transform(
                        tabs[i, t].cpu().numpy().reshape(1, -1)
                    )[0]

                    row = {
                        "PTID": ptids[i],
                        "Years_bl": float(times[i, t].cpu()),
                        "MMSE": float(mmse[i, t]),
                        "time_to_event": float(t_event[i]),
                        "event": int(event[i]),
                    }

                    for feat in BASELINE_FEATURES:
                        feat_encoded = feat + '_encoded' if feat == 'PTGENDER' else feat
                        if feat_encoded in tab_feature_names:
                            idx = tab_feature_names.index(feat_encoded)
                            row[feat] = float(tab_vals_unscaled[idx])

                    z_vals = z_visit[0].cpu().numpy()
                    for k, val in enumerate(z_vals):
                        row[f"z_{k}"] = float(val)

                    rows.append(row)

    df = pd.DataFrame(rows).sort_values(["PTID", "Years_bl"])
    df.to_csv(output_path, index=False)

    print(f"\n✓ Exported to {output_path}")
    print(f"  Patients: {df['PTID'].nunique()}")
    print(f"  Visits: {len(df)}")
    print(f"  Latent features: {len([c for c in df.columns if c.startswith('z_')])}")

    return df

# ============================================================================
# MAIN
# ============================================================================

def main():
    print("=" * 80)
    print("ENHANCED MCI-TO-AD PREDICTION")
    print("=" * 80)
    
    device = CONFIG['device']
    print(f"\nDevice: {device}")
    
    # Load data
    print("\nLoading data...")
    manifest = pd.read_csv("./AD_Multimodal/TFN_AD/AD_Patient_Manifest.csv")

    # Load valid patients
    print("\nLoading valid patient list...")
    with open('VALID_PATIENTS.pkl', 'rb') as f:
        VALID_PATIENTS = pickle.load(f)
    print(f"Valid patients: {len(VALID_PATIENTS)}")
    
    # Transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Dataset
    print("Creating dataset...")
    dataset = SequenceDataset(manifest, VALID_PATIENTS, transform, max_seq_len=CONFIG['max_seq_len'])
    print(f"Total sequences: {len(dataset)}")
    
    # Split
    n_train = int(0.8 * len(dataset))
    n_val = len(dataset) - n_train
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42)
    )
    
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], 
                             shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], 
                           shuffle=False, num_workers=0)
    
    print(f"Train: {n_train}, Val: {n_val}")
    
    # Model
    sample_batch = next(iter(train_loader))
    tab_dim = sample_batch[1].shape[2]
    
    print("\nInitializing model...")
    model = SupervisedTemporalFusionAutoencoder(tab_dim, CONFIG).to(device)
    
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {total_params:,}")
    
    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(), 
        lr=CONFIG['lr'], 
        weight_decay=CONFIG['weight_decay']
    )
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5
    )
    
    # Training
    print("\n" + "=" * 80)
    print("TRAINING")
    print("=" * 80)
    
    best_c_index = 0
    patience = 0
    
    for epoch in range(CONFIG['epochs']):
        train_losses = train_epoch(model, train_loader, optimizer, CONFIG, device)
        val_c_index, _, _, _ = validate(model, val_loader, device)
        
        scheduler.step(val_c_index)
        
        print(f"\nEpoch {epoch+1}/{CONFIG['epochs']}")
        print(f"  Loss: {train_losses['total']:.4f} "
              f"(Cox: {train_losses['cox']:.4f}, MMSE: {train_losses['mmse']:.4f})")
        print(f"  Val C-index: {val_c_index:.4f}")
        
        if val_c_index > best_c_index:
            best_c_index = val_c_index
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'c_index': best_c_index,
            }, 'best_model.pth')
            print(f"  ✓ New best: {best_c_index:.4f}")
            patience = 0
        else:
            patience += 1
        
        if patience >= 15:
            print(f"\n⚠ Early stopping")
            break
    
    # Export
    print("\n" + "=" * 80)
    print("EXPORTING")
    print("=" * 80)
    
    # FIX: Use weights_only=False for backward compatibility
    checkpoint = torch.load('best_model.pth', weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\nLoaded best model (epoch {checkpoint['epoch']+1})")
    
    full_loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], 
                            shuffle=False, num_workers=0)
    
    latent_df = export_latent_features(
        model, full_loader, device, 
        "latent_improved_autoencoder.csv"
    )
    
    print("\n" + "=" * 80)
    print("✓ COMPLETE")
    print("=" * 80)
    print(f"\nBest C-index: {best_c_index:.4f}")
    print("\nReady for JMBayes2 in R!")

if __name__ == "__main__":
    main()

ENHANCED MCI-TO-AD PREDICTION

Device: cuda

Loading data...

Loading valid patient list...
Valid patients: 161
Creating dataset...
  Processed: 161 valid patients
  Skipped (not in valid set): 221
Total sequences: 161
Train: 128, Val: 33

Initializing model...
Parameters: 48,509,283

TRAINING

Epoch 1/100
  Loss: 1.9393 (Cox: 2.2815, MMSE: 0.0000)
  Val C-index: 0.8673
  ✓ New best: 0.8673

Epoch 2/100
  Loss: 1.8497 (Cox: 2.1761, MMSE: 0.0000)
  Val C-index: 0.8803
  ✓ New best: 0.8803

Epoch 3/100
  Loss: 1.6664 (Cox: 1.9605, MMSE: 0.0000)
  Val C-index: 0.8511

Epoch 4/100
  Loss: 1.2096 (Cox: 1.4231, MMSE: 0.0000)
  Val C-index: 0.8220

Epoch 5/100
  Loss: 1.3658 (Cox: 1.6068, MMSE: 0.0000)
  Val C-index: 0.7282

Epoch 6/100
  Loss: 1.3349 (Cox: 1.5705, MMSE: 0.0000)
  Val C-index: 0.7152

Epoch 7/100
  Loss: 1.3127 (Cox: 1.5443, MMSE: 0.0000)
  Val C-index: 0.6893

Epoch 8/100
  Loss: 1.1706 (Cox: 1.3771, MMSE: 0.0000)
  Val C-index: 0.8285

Epoch 9/100
  Loss: 1.1585 (Cox: 1.362