# In [1]:
"""
Benchmark 3: Tabular-Only Deep Neural Network
Tests whether deep learning on clinical features alone helps

FIXED VERSION:
- Uses VALID_PATIENTS.pkl for consistent patient cohort
- Survival time measured from MCI diagnosis (consistent)
- Same feature engineering as thesis
- Only tabular features (NO images)

Architecture:
- Deep feedforward network on tabular data
- LSTM for temporal modeling
- Shows value of adding imaging modality
"""

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import pickle
from lifelines.utils import concordance_index
import warnings
warnings.filterwarnings('ignore')

# Hyperparameters and runtime settings read by the model, the data loaders
# and the training loop below.
CONFIG = dict(
    tab_hidden=256,       # width of the first tabular-encoder layer
    lstm_hidden=128,      # hidden size of each LSTM direction
    lstm_layers=2,        # number of stacked LSTM layers
    dropout=0.3,          # dropout probability (encoder, LSTM, risk head)
    lr=5e-4,              # AdamW learning rate
    weight_decay=1e-4,    # AdamW weight decay
    epochs=100,           # maximum number of training epochs
    batch_size=16,        # patients per mini-batch
    max_seq_len=10,       # visits kept (and padded to) per patient
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

# ============================================================================
# FEATURE ENGINEERING (Same as thesis)
# ============================================================================

# Columns appended after the temporal ones in every visit's feature vector.
STATIC_FEATURES = [
    'age_bl',
    'PTGENDER_encoded',
    'PTEDUCAT',
    'PTETHCAT_encoded',
    'PTRACCAT_encoded',
    'PTMARRY_encoded',
]

# Per-visit columns. More engineered columns are appended to this list
# further down the file, after engineer_features() is defined.
TEMPORAL_FEATURES = [
    'time_from_baseline',
    'AGE',
    'age_since_bl',
    'mmse_slope',
    'adas13_slope',
    # NOTE(review): dx_progression is derived from the DX column, which also
    # defines the survival label - confirm this is not label leakage.
    'dx_progression',
    'cog_decline_index',
    'visit_number',
    'MMSE',
    'ADAS13',
]

def engineer_features(df):
    """Add the engineered longitudinal features to one patient's visit table.

    Args:
        df: DataFrame with one row per visit. The columns read here are
            ``Years_bl``, ``AGE``, ``MMSE``, ``ADAS13``, ``DX`` and
            ``PTEDUCAT``. Rows are assumed to be in chronological order
            (the first row is treated as baseline) - TODO confirm upstream.

    Returns:
        A copy of ``df`` with the engineered columns added and every NaN
        replaced by 0. The input frame is not modified.

    Raises:
        IndexError: if ``df`` has no rows (baseline is read with ``iloc[0]``).
    """
    df = df.copy()

    df["time_from_baseline"] = df["Years_bl"] - df["Years_bl"].iloc[0]
    df["age_bl"] = df["AGE"].iloc[0]
    df["age_since_bl"] = df["AGE"] - df["age_bl"]

    # Two rows with the same Years_bl give a zero time gap; dividing by it
    # yields +/-inf, which fillna(0) below does NOT remove. Treat a zero gap
    # as "slope unknown" (NaN -> 0) instead.
    visit_gap = df["Years_bl"].diff().replace(0, np.nan)
    df["mmse_slope"] = df["MMSE"].diff() / visit_gap
    df["adas13_slope"] = df["ADAS13"].diff() / visit_gap

    # Diagnoses outside this map become NaN and therefore 0 after fillna.
    dx_map = {"CN": 0, "MCI": 1, "AD": 2, "Dementia": 2}
    df["dx_progression"] = df["DX"].map(dx_map).diff()

    df["cog_decline_index"] = df["ADAS13"] - df["MMSE"]
    df["visit_number"] = range(len(df))

    df['age_mmse_interaction'] = df['AGE'] * (30 - df['MMSE']) / 30
    df['education_cognitive_reserve'] = df['PTEDUCAT'] * df['MMSE'] / 30
    df['rapid_decline_flag'] = (df['mmse_slope'] < -2).astype(float)

    # Severity code: 2 for MMSE in [0, 20], 1 for (20, 24], 0 for (24, 30].
    # include_lowest=True keeps MMSE == 0 in the most severe bin; without it
    # the value falls outside every bin, becomes NaN and is then filled with
    # 0, i.e. the *least* severe code.
    mmse_bins = [0, 20, 24, 30]
    df['mmse_severity'] = pd.cut(
        df['MMSE'], bins=mmse_bins, labels=[2, 1, 0], include_lowest=True
    ).astype(float)

    df['weighted_mmse_decline'] = df['mmse_slope'] * np.exp(-0.1 * df['time_from_baseline'])
    df['mmse_variability'] = df['MMSE'].rolling(window=3, min_periods=1).std()

    # Absolute difference between the within-patient z-scores of ADAS13 and
    # MMSE. The mean/std are taken over all of the patient's rows.
    df['adas_mmse_discordance'] = np.abs(
        (df['ADAS13'] - df['ADAS13'].mean()) / (df['ADAS13'].std() + 1e-7) -
        (df['MMSE'] - df['MMSE'].mean()) / (df['MMSE'].std() + 1e-7)
    )

    df = df.fillna(0)
    return df

# Append the remaining engineered columns created by engineer_features()
# to the existing list object (in place).
TEMPORAL_FEATURES += [
    'age_mmse_interaction',
    'education_cognitive_reserve',
    'rapid_decline_flag',
    'mmse_severity',
    'weighted_mmse_decline',
    'mmse_variability',
    'adas_mmse_discordance',
]

# ============================================================================
# DATASET
# ============================================================================

class TabularOnlyDataset(Dataset):
    """Per-patient padded visit sequences with an MCI-to-AD survival label.

    One item is built for every patient in ``manifest`` that is listed in
    ``valid_patients``, has at least one ``"MCI"`` diagnosis and at least two
    usable visits.

    Args:
        manifest: DataFrame with ``PTID`` and ``path`` columns; ``path`` is
            the location of the patient's pickled visit table relative to
            ``data_root``. The frame is not modified.
        valid_patients: container of PTIDs to keep (tested with ``in``).
        max_seq_len: number of visits kept per patient; shorter sequences are
            zero-padded to this length.
        data_root: prefix prepended to every manifest path.

    Attributes:
        sequences: list of dicts with keys ``ptid``, ``tabs``, ``times``,
            ``seq_len``, ``time_to_event`` and ``event``.
        load_failures: list of ``(ptid, repr(error))`` for patients whose
            record raised while being loaded or processed.
    """

    def __init__(self, manifest, valid_patients, max_seq_len=10,
                 data_root="./AD_Multimodal/TFN_AD/"):
        self.sequences = []
        self.max_seq_len = max_seq_len
        self.load_failures = []

        # Work on a copy: rewriting the caller's frame in place would prepend
        # the prefix again if the same manifest were used a second time.
        manifest = manifest.copy()
        manifest["path"] = manifest["path"].str.replace("\\", "/", regex=False)
        manifest["path"] = data_root + manifest["path"]

        # Membership is tested once per patient; a set makes that O(1).
        if isinstance(valid_patients, (list, tuple)):
            valid_patients = set(valid_patients)

        processed = 0
        skipped_not_valid = 0

        for ptid in manifest['PTID'].unique():
            # CRITICAL: Only process valid patients
            if ptid not in valid_patients:
                skipped_not_valid += 1
                continue

            try:
                sequence = self._build_sequence(ptid, manifest)
            except Exception as exc:
                # Best effort: one unreadable record must not abort the run,
                # but keep a trace so the cohort size can be audited.
                self.load_failures.append((ptid, repr(exc)))
                continue

            if sequence is None:
                continue

            self.sequences.append(sequence)
            processed += 1

        print(f"  Processed: {processed} valid patients")
        print(f"  Skipped (not in valid set): {skipped_not_valid}")
        if self.load_failures:
            print(f"  Failed to load: {len(self.load_failures)} (see .load_failures)")

    def _build_sequence(self, ptid, manifest):
        """Return the sequence dict for one patient, or None if not usable."""
        patient_rows = manifest[manifest['PTID'] == ptid]
        if len(patient_rows) == 0:
            return None

        # NOTE(security): read_pickle can execute arbitrary code; only use
        # manifests that point at trusted files.
        df = pd.read_pickle(patient_rows.iloc[0]["path"])
        df = engineer_features(df)

        dx_seq = df["DX"].tolist()
        if "MCI" not in dx_seq:
            return None

        # Survival time is measured from the first MCI visit to the first
        # later AD/Dementia visit (event) or to the last visit (censored).
        mci_idx = dx_seq.index("MCI")
        ad_idx = next((i for i, x in enumerate(dx_seq[mci_idx+1:], start=mci_idx+1)
                       if x in ["AD", "Dementia"]), -1)

        if ad_idx != -1:
            time_to_event = df["Years_bl"].iloc[ad_idx] - df["Years_bl"].iloc[mci_idx]
            event = 1
        else:
            time_to_event = df["Years_bl"].iloc[-1] - df["Years_bl"].iloc[mci_idx]
            event = 0

        # NOTE(review): the input window is the first max_seq_len rows of the
        # record, independent of mci_idx/ad_idx, so it can contain visits at
        # or after the conversion visit - confirm this is intended.
        feature_cols = TEMPORAL_FEATURES + STATIC_FEATURES
        tabs, times = [], []
        for _, visit in df.head(max(self.max_seq_len, 0)).iterrows():
            tabs.append(visit[feature_cols].values.astype(np.float32))
            times.append(visit["Years_bl"])
        valid_visits = len(tabs)

        if valid_visits < 2:
            return None

        # Pad with all-zero feature rows; padded times repeat the last time.
        pad_len = self.max_seq_len - valid_visits
        if pad_len > 0:
            tabs.extend(np.zeros_like(tabs[-1]) for _ in range(pad_len))
            times.extend([times[-1]] * pad_len)

        return {
            'ptid': ptid,
            'tabs': np.array(tabs, dtype=np.float32),
            'times': np.array(times, dtype=np.float32),
            'seq_len': valid_visits,
            'time_to_event': time_to_event,
            'event': event
        }

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return (tabs, times, seq_len, time_to_event, event, ptid)."""
        seq = self.sequences[idx]
        return (
            seq['tabs'], seq['times'], seq['seq_len'],
            seq['time_to_event'], seq['event'], seq['ptid']
        )

# ============================================================================
# MODEL
# ============================================================================

class TabularOnlyModel(nn.Module):
    """Visit-wise MLP encoder, bidirectional LSTM and a scalar risk head.

    Args:
        tab_dim: number of tabular features per visit.
        config: mapping providing ``tab_hidden``, ``lstm_hidden``,
            ``lstm_layers`` and ``dropout``.
    """

    def __init__(self, tab_dim, config):
        super().__init__()
        self.config = config

        width = config['tab_hidden']
        half_width = width // 2
        p_drop = config['dropout']
        rnn_hidden = config['lstm_hidden']
        rnn_layers = config['lstm_layers']

        # Per-visit encoder: tab_dim -> width -> width // 2 -> 128.
        encoder_layers = [
            nn.Linear(tab_dim, width),
            nn.BatchNorm1d(width),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(width, half_width),
            nn.BatchNorm1d(half_width),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(half_width, 128),
            nn.ReLU(),
        ]
        self.tab_encoder = nn.Sequential(*encoder_layers)

        # Temporal model over the encoded visits; inter-layer dropout is only
        # requested when there is more than one layer.
        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=rnn_hidden,
            num_layers=rnn_layers,
            batch_first=True,
            dropout=p_drop if rnn_layers > 1 else 0,
            bidirectional=True,
        )

        # Maps the concatenated forward/backward state to one risk score.
        head_layers = [
            nn.Linear(rnn_hidden * 2, 64),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(64, 1),
        ]
        self.risk_head = nn.Sequential(*head_layers)

    def encode_tabular(self, tab):
        """Encode a (batch, tab_dim) tensor of visits into 128-d features."""
        return self.tab_encoder(tab)

    def forward(self, tab_seq, seq_lengths):
        """Return a (batch, 1) risk score.

        Args:
            tab_seq: (batch, seq_len, tab_dim) padded visit features.
            seq_lengths: number of real (non-padded) visits per sample.
        """
        _, n_steps, _ = tab_seq.shape

        # The encoder is applied one timestep at a time, so in training mode
        # BatchNorm statistics are computed per timestep (padded rows included).
        encoded = torch.stack(
            [self.encode_tabular(tab_seq[:, step]) for step in range(n_steps)],
            dim=1,
        )

        # Packing makes the LSTM ignore the padded tail of each sequence.
        packed = nn.utils.rnn.pack_padded_sequence(
            encoded, seq_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (hidden, _) = self.lstm(packed)

        # Last two entries of h_n: forward and backward state of the top layer.
        summary = torch.cat((hidden[-2], hidden[-1]), dim=1)
        return self.risk_head(summary)

# ============================================================================
# LOSS & TRAINING
# ============================================================================

def cox_loss(risk_scores, times, events):
    """Negative Cox partial log-likelihood, averaged over observed events.

    Tied times are handled with the Breslow convention: every sample's risk
    set contains all samples whose time is >= its own, including samples with
    exactly the same time. The result therefore does not depend on the order
    in which tied samples happen to be sorted.

    Args:
        risk_scores: predicted log-risk per sample; any shape, flattened here
            (a 0-dim tensor from ``.squeeze()`` on a batch of one is accepted).
        times: time-to-event or censoring time per sample.
        events: 1 where the event was observed, 0 where censored.

    Returns:
        Scalar tensor. It is 0 when the batch contains no observed event.
    """
    log_risk = risk_scores.reshape(-1)
    times = times.reshape(-1)
    events = events.reshape(-1)

    # Longest time first, so a running log-sum-exp at position i covers
    # every sample with time >= times[i].
    order = torch.argsort(times, descending=True)
    log_risk = log_risk[order]
    sorted_times = times[order]
    events = events[order]

    log_cumsum_hazard = torch.logcumsumexp(log_risk, dim=0)

    # Within a run of equal times, use the running sum at the END of the run
    # so tied samples all share the same, complete risk set.
    _, group_of, group_size = torch.unique_consecutive(
        sorted_times, return_inverse=True, return_counts=True
    )
    last_in_group = torch.cumsum(group_size, dim=0) - 1
    log_risk_set = log_cumsum_hazard[last_in_group[group_of]]

    loss = -(log_risk - log_risk_set) * events

    # The epsilon keeps the division defined when no event was observed.
    return loss.sum() / (events.sum() + 1e-7)

def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: called as ``model(tabs, seq_lens)``; must return one risk
            score per sample.
        loader: iterable of batches shaped
            ``(tabs, times, seq_lens, time_to_event, event, ptids)``.
        optimizer: optimiser over ``model.parameters()``.
        device: device the features and targets are moved to.

    Returns:
        Mean Cox loss over the batches that were actually trained on, or
        ``0.0`` when there were none.
    """
    model.train()
    total_loss = 0.0
    batches_used = 0

    for tabs, _times, seq_lens, t_event, event, _ptids in loader:
        tabs = torch.as_tensor(tabs, dtype=torch.float32)

        # A batch of one has a Cox partial likelihood of exactly 0 (its risk
        # set is itself), so it carries no gradient. It would also make
        # BatchNorm1d raise in training mode. Skip it.
        if tabs.shape[0] < 2:
            continue

        tabs = tabs.to(device)
        # Sequence lengths stay on the CPU; packing requires CPU lengths.
        seq_lens = torch.as_tensor(seq_lens, dtype=torch.long)
        t_event = t_event.float().to(device)
        event = event.float().to(device)

        risk_scores = model(tabs, seq_lens)
        loss = cox_loss(risk_scores.reshape(-1), t_event, event)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        batches_used += 1

    # Guard against an empty loader (or one made only of skipped batches).
    return total_loss / batches_used if batches_used else 0.0

def validate(model, loader, device):
    """Return the concordance index of the model's risk scores on ``loader``.

    Risk scores are negated before scoring, so a higher predicted risk is
    treated as a shorter predicted time.
    """
    model.eval()
    risks, durations, observed = [], [], []

    with torch.no_grad():
        for tabs, _, seq_lens, t_event, event, _ in loader:
            features = torch.FloatTensor(tabs).to(device)
            lengths = torch.LongTensor(seq_lens)

            scores = model(features, lengths)

            risks.extend(scores.cpu().numpy().flatten())
            durations.extend(t_event.numpy())
            observed.extend(event.numpy())

    observed = np.array(observed).astype(bool)
    durations = np.array(durations)
    risks = np.array(risks)

    return concordance_index(durations, -risks, observed)

# ============================================================================
# EXPORT
# ============================================================================

def export_features(model, loader, device, output_path):
    """Write the encoded first-visit features of every patient to a CSV.

    For each sample only the first timestep of its sequence is passed
    through ``model.encode_tabular``.

    Args:
        model: object exposing ``eval()`` and ``encode_tabular(tensor)``.
        loader: iterable of batches shaped
            ``(tabs, times, seq_lens, time_to_event, event, ptids)``.
        device: device the features are moved to before encoding.
        output_path: destination CSV path.

    Returns:
        The exported DataFrame, sorted by ``PTID`` then ``Years_bl``, with
        columns ``PTID``, ``Years_bl``, ``time_to_event``, ``event`` and
        ``tab_feat_<k>``. It has no rows when ``loader`` is empty.
    """
    model.eval()
    rows = []

    with torch.no_grad():
        for tabs, times, _seq_lens, t_event, event, ptids in loader:
            # Encode the first visit of the whole batch in one call; in eval
            # mode this matches encoding the samples one by one.
            first_visit = torch.as_tensor(tabs, dtype=torch.float32)[:, 0].to(device)
            encoded = model.encode_tabular(first_visit).cpu().numpy()

            for i, ptid in enumerate(ptids):
                row = {
                    "PTID": ptid,
                    "Years_bl": float(times[i][0]),
                    "time_to_event": float(t_event[i]),
                    "event": int(event[i])
                }
                row.update(
                    {f"tab_feat_{k}": float(v) for k, v in enumerate(encoded[i])}
                )
                rows.append(row)

    if rows:
        df = pd.DataFrame(rows).sort_values(['PTID', 'Years_bl'])
    else:
        # sort_values would raise KeyError on a frame with no columns.
        df = pd.DataFrame(columns=["PTID", "Years_bl", "time_to_event", "event"])
    df.to_csv(output_path, index=False)

    print(f"\n✓ Exported to {output_path}")
    print(f"  Patients: {df['PTID'].nunique()}, Features: {len([c for c in df.columns if c.startswith('tab_feat')])}")

    return df

# ============================================================================
# MAIN
# ============================================================================

def main():
    """Train the tabular-only benchmark and export its baseline features.

    Steps: load the valid-patient list and the manifest, build the dataset,
    split it 80/20 with a fixed seed, train with early stopping on the
    validation C-index, reload the best checkpoint and export features for
    the full dataset.

    Raises:
        RuntimeError: if fewer than two training sequences are available.
    """
    print("=" * 80)
    print("BENCHMARK 3: TABULAR-ONLY DEEP NETWORK")
    print("=" * 80)

    device = CONFIG['device']
    print(f"\nDevice: {device}")

    # Load valid patients
    print("\nLoading valid patient list...")
    # NOTE(security): pickle.load can execute arbitrary code; only load a
    # VALID_PATIENTS.pkl produced by a trusted run.
    with open('VALID_PATIENTS.pkl', 'rb') as f:
        VALID_PATIENTS = pickle.load(f)
    print(f"Valid patients: {len(VALID_PATIENTS)}")

    manifest = pd.read_csv("./AD_Multimodal/TFN_AD/AD_Patient_Manifest.csv")

    dataset = TabularOnlyDataset(manifest, VALID_PATIENTS, max_seq_len=CONFIG['max_seq_len'])
    print(f"Total sequences: {len(dataset)}")

    n_train = int(0.8 * len(dataset))
    n_val = len(dataset) - n_train
    if n_train < 2:
        # Fail with a clear message instead of a bare StopIteration below.
        raise RuntimeError(
            f"Need at least 2 training sequences, got {n_train} "
            f"(dataset size {len(dataset)}); check the manifest and VALID_PATIENTS.pkl"
        )

    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42)
    )

    # The model's encoder uses BatchNorm1d, which raises in training mode on
    # a batch of one; drop the last batch only when it would be a singleton.
    singleton_tail = (n_train % CONFIG['batch_size'] == 1)
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                              shuffle=True, num_workers=0, drop_last=singleton_tail)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                            shuffle=False, num_workers=0)

    # Feature width is taken from the data, without consuming a shuffled batch.
    tab_dim = dataset[0][0].shape[1]
    print(f"Tabular features: {tab_dim}")

    model = TabularOnlyModel(tab_dim, CONFIG).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr'],
                                  weight_decay=CONFIG['weight_decay'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5
    )

    print("\nTraining...")
    checkpoint_path = 'tabular_only_model.pth'
    # Start below any attainable C-index so the first finite score always
    # writes a checkpoint.
    best_c_index = float('-inf')
    checkpoint_saved = False
    patience_counter = 0

    for epoch in range(CONFIG['epochs']):
        train_loss = train_epoch(model, train_loader, optimizer, device)
        val_c_index = validate(model, val_loader, device)

        scheduler.step(val_c_index)

        print(f"Epoch {epoch+1}/{CONFIG['epochs']} - Loss: {train_loss:.4f}, C-index: {val_c_index:.4f}")

        if val_c_index > best_c_index:
            best_c_index = val_c_index
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= 15:
            print("Early stopping")
            break

    # Export
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        # Never load a checkpoint left on disk by a previous run.
        print("WARNING: no checkpoint was saved in this run; exporting with the current weights")

    full_loader = DataLoader(dataset, batch_size=CONFIG['batch_size'],
                             shuffle=False, num_workers=0)

    export_features(model, full_loader, device, "tabular_only_features.csv")

    # NOTE(review): the reported value is the maximum validation C-index over
    # epochs on the same split used for early stopping, so it is optimistic.
    print("\n" + "=" * 80)
    print(f"✓ BEST C-INDEX: {best_c_index:.4f}")
    print("⚠️  This uses the SAME patient cohort as all other benchmarks")
    print("=" * 80)

# Run the benchmark only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
# """
# Benchmark 3: Tabular-Only Deep Neural Network
# Tests whether deep learning on clinical features alone helps

# Architecture:
# - Deep feedforward network on tabular data
# - LSTM for temporal modeling
# - NO images
# - Shows value of adding imaging modality
# """

# import torch
# import torch.nn as nn
# from torch.utils.data import Dataset, DataLoader
# import pandas as pd
# import numpy as np
# from lifelines.utils import concordance_index
# import warnings
# warnings.filterwarnings('ignore')

# CONFIG = {
#     'tab_hidden': 256,
#     'lstm_hidden': 128,
#     'lstm_layers': 2,
#     'dropout': 0.3,
#     'lr': 5e-4,
#     'weight_decay': 1e-4,
#     'epochs': 100,
#     'batch_size': 16,
#     'max_seq_len': 10,
#     'device': 'cuda' if torch.cuda.is_available() else 'cpu'
# }

# # ============================================================================
# # FEATURE ENGINEERING (Same as thesis)
# # ============================================================================

# STATIC_FEATURES = [
#     'age_bl', 'PTGENDER_encoded', 'PTEDUCAT', 'PTETHCAT_encoded', 
#     'PTRACCAT_encoded', 'PTMARRY_encoded'
# ]

# TEMPORAL_FEATURES = [
#     'time_from_baseline', 'AGE', 'age_since_bl', 'mmse_slope', 
#     'adas13_slope', 'dx_progression', 'cog_decline_index', 
#     'visit_number', 'MMSE', 'ADAS13'
# ]

# def engineer_features(df):
#     df = df.copy()
    
#     df["time_from_baseline"] = df["Years_bl"] - df["Years_bl"].iloc[0]
#     df["age_bl"] = df["AGE"].iloc[0]
#     df["age_since_bl"] = df["AGE"] - df["age_bl"]
    
#     df["mmse_slope"] = df["MMSE"].diff() / df["Years_bl"].diff()
#     df["adas13_slope"] = df["ADAS13"].diff() / df["Years_bl"].diff()
    
#     dx_map = {"CN": 0, "MCI": 1, "AD": 2, "Dementia": 2}
#     df["dx_progression"] = df["DX"].map(dx_map).diff()
    
#     df["cog_decline_index"] = df["ADAS13"] - df["MMSE"]
#     df["visit_number"] = range(len(df))
    
#     df['age_mmse_interaction'] = df['AGE'] * (30 - df['MMSE']) / 30
#     df['education_cognitive_reserve'] = df['PTEDUCAT'] * df['MMSE'] / 30
#     df['rapid_decline_flag'] = (df['mmse_slope'] < -2).astype(float)
    
#     mmse_bins = [0, 20, 24, 30]
#     df['mmse_severity'] = pd.cut(df['MMSE'], bins=mmse_bins, labels=[2, 1, 0]).astype(float)
    
#     df['weighted_mmse_decline'] = df['mmse_slope'] * np.exp(-0.1 * df['time_from_baseline'])
#     df['mmse_variability'] = df['MMSE'].rolling(window=3, min_periods=1).std()
    
#     df['adas_mmse_discordance'] = np.abs(
#         (df['ADAS13'] - df['ADAS13'].mean()) / (df['ADAS13'].std() + 1e-7) - 
#         (df['MMSE'] - df['MMSE'].mean()) / (df['MMSE'].std() + 1e-7)
#     )
    
#     df = df.fillna(0)
#     return df

# TEMPORAL_FEATURES.extend([
#     'age_mmse_interaction', 'education_cognitive_reserve', 'rapid_decline_flag',
#     'mmse_severity', 'weighted_mmse_decline', 'mmse_variability', 'adas_mmse_discordance'
# ])

# # ============================================================================
# # DATASET
# # ============================================================================

# class TabularOnlyDataset(Dataset):
#     def __init__(self, manifest, max_seq_len=10):
#         self.sequences = []
#         self.max_seq_len = max_seq_len
        
#         manifest["path"] = manifest["path"].str.replace("\\", "/", regex=False)
#         manifest["path"] = "./AD_Multimodal/TFN_AD/" + manifest["path"]
        
#         for ptid in manifest['PTID'].unique():
#             try:
#                 patient_rows = manifest[manifest['PTID'] == ptid]
#                 if len(patient_rows) == 0:
#                     continue
                
#                 df = pd.read_pickle(patient_rows.iloc[0]["path"])
#                 df = engineer_features(df)
                
#                 dx_seq = df["DX"].tolist()
#                 if "MCI" not in dx_seq:
#                     continue
                
#                 mci_idx = dx_seq.index("MCI")
#                 ad_idx = next((i for i, x in enumerate(dx_seq[mci_idx+1:], start=mci_idx+1) 
#                               if x in ["AD", "Dementia"]), -1)
                
#                 if ad_idx != -1:
#                     time_to_event = df["Years_bl"].iloc[ad_idx]
#                     event = 1
#                 else:
#                     time_to_event = df["Years_bl"].iloc[-1]
#                     event = 0
                
#                 tabs, times = [], []
#                 valid_visits = 0
                
#                 for _, visit in df.iterrows():
#                     tabs.append(visit[TEMPORAL_FEATURES + STATIC_FEATURES].values.astype(np.float32))
#                     times.append(visit["Years_bl"])
#                     valid_visits += 1
                    
#                     if valid_visits >= max_seq_len:
#                         break
                
#                 if valid_visits < 2:
#                     continue
                
#                 pad_len = max_seq_len - len(tabs)
#                 if pad_len > 0:
#                     for _ in range(pad_len):
#                         tabs.append(np.zeros_like(tabs[-1]))
#                         times.append(times[-1])
                
#                 self.sequences.append({
#                     'ptid': ptid,
#                     'tabs': np.array(tabs, dtype=np.float32),
#                     'times': np.array(times, dtype=np.float32),
#                     'seq_len': valid_visits,
#                     'time_to_event': time_to_event,
#                     'event': event
#                 })
                
#             except Exception as e:
#                 continue
    
#     def __len__(self):
#         return len(self.sequences)
    
#     def __getitem__(self, idx):
#         seq = self.sequences[idx]
#         return (
#             seq['tabs'], seq['times'], seq['seq_len'],
#             seq['time_to_event'], seq['event'], seq['ptid']
#         )

# # ============================================================================
# # MODEL
# # ============================================================================

# class TabularOnlyModel(nn.Module):
#     def __init__(self, tab_dim, config):
#         super().__init__()
#         self.config = config
        
#         # Deep tabular encoder
#         self.tab_encoder = nn.Sequential(
#             nn.Linear(tab_dim, config['tab_hidden']),
#             nn.BatchNorm1d(config['tab_hidden']),
#             nn.ReLU(),
#             nn.Dropout(config['dropout']),
#             nn.Linear(config['tab_hidden'], config['tab_hidden'] // 2),
#             nn.BatchNorm1d(config['tab_hidden'] // 2),
#             nn.ReLU(),
#             nn.Dropout(config['dropout']),
#             nn.Linear(config['tab_hidden'] // 2, 128),
#             nn.ReLU()
#         )
        
#         # LSTM for temporal modeling
#         self.lstm = nn.LSTM(
#             input_size=128,
#             hidden_size=config['lstm_hidden'],
#             num_layers=config['lstm_layers'],
#             batch_first=True,
#             dropout=config['dropout'] if config['lstm_layers'] > 1 else 0,
#             bidirectional=True
#         )
        
#         # Risk prediction
#         self.risk_head = nn.Sequential(
#             nn.Linear(config['lstm_hidden'] * 2, 64),
#             nn.ReLU(),
#             nn.Dropout(config['dropout']),
#             nn.Linear(64, 1)
#         )
    
#     def encode_tabular(self, tab):
#         return self.tab_encoder(tab)
    
#     def forward(self, tab_seq, seq_lengths):
#         batch_size, seq_len, _ = tab_seq.shape
        
#         # Encode each timestep
#         tab_feats = []
#         for t in range(seq_len):
#             feat = self.encode_tabular(tab_seq[:, t])
#             tab_feats.append(feat)
        
#         tab_seq_feat = torch.stack(tab_feats, dim=1)
        
#         # LSTM
#         packed = nn.utils.rnn.pack_padded_sequence(
#             tab_seq_feat, seq_lengths.cpu(), batch_first=True, enforce_sorted=False
#         )
#         lstm_out, (h_n, c_n) = self.lstm(packed)
        
#         h_forward = h_n[-2]
#         h_backward = h_n[-1]
#         h_final = torch.cat([h_forward, h_backward], dim=1)
        
#         risk_score = self.risk_head(h_final)
        
#         return risk_score

# # ============================================================================
# # LOSS & TRAINING
# # ============================================================================

# def cox_loss(risk_scores, times, events):
#     order = torch.argsort(times, descending=True)
#     risk_scores = risk_scores[order]
#     events = events[order]
    
#     log_risk = risk_scores.view(-1)
#     log_cumsum_hazard = torch.logcumsumexp(log_risk, dim=0)
#     loss = -(log_risk - log_cumsum_hazard) * events
    
#     return loss.sum() / (events.sum() + 1e-7)

# def train_epoch(model, loader, optimizer, device):
#     model.train()
#     total_loss = 0
    
#     for batch in loader:
#         tabs, times, seq_lens, t_event, event, _ = batch
        
#         tabs = torch.FloatTensor(tabs).to(device)
#         seq_lens = torch.LongTensor(seq_lens)
#         t_event = t_event.float().to(device)
#         event = event.float().to(device)
        
#         risk_scores = model(tabs, seq_lens)
#         loss = cox_loss(risk_scores.squeeze(), t_event, event)
        
#         optimizer.zero_grad()
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#         optimizer.step()
        
#         total_loss += loss.item()
    
#     return total_loss / len(loader)

# def validate(model, loader, device):
#     model.eval()
#     all_risks, all_times, all_events = [], [], []
    
#     with torch.no_grad():
#         for batch in loader:
#             tabs, times, seq_lens, t_event, event, _ = batch
            
#             tabs = torch.FloatTensor(tabs).to(device)
#             seq_lens = torch.LongTensor(seq_lens)
            
#             risk_scores = model(tabs, seq_lens)
            
#             all_risks.extend(risk_scores.cpu().numpy().flatten())
#             all_times.extend(t_event.numpy())
#             all_events.extend(event.numpy())
    
#     all_events = np.array(all_events).astype(bool)
#     all_times = np.array(all_times)
#     all_risks = np.array(all_risks)
    
#     c_index = concordance_index(all_times, -all_risks, all_events)
#     return c_index

# # ============================================================================
# # EXPORT
# # ============================================================================

# def export_features(model, loader, device, output_path):
#     """Export tabular features - SURVIVAL DATA ONLY"""
#     model.eval()
#     rows = []
    
#     with torch.no_grad():
#         for batch in loader:
#             tabs, times, seq_lens, t_event, event, ptids = batch
            
#             tabs = torch.FloatTensor(tabs).to(device)
#             seq_lens = torch.LongTensor(seq_lens)
            
#             # Only use FIRST visit for each patient (baseline)
#             for i in range(len(ptids)):
#                 tab_feat = model.encode_tabular(tabs[i:i+1, 0])
#                 tab_vals = tabs[i, 0].cpu().numpy()
                
#                 row = {
#                     "PTID": ptids[i],
#                     "Years_bl": float(times[i][0]),
#                     "time_to_event": float(t_event[i]),
#                     "event": int(event[i])
#                 }
                
#                 feat_vals = tab_feat[0].cpu().numpy()
#                 for k in range(len(feat_vals)):
#                     row[f"tab_feat_{k}"] = float(feat_vals[k])
                
#                 rows.append(row)
    
#     df = pd.DataFrame(rows).sort_values(['PTID', 'Years_bl'])
#     df.to_csv(output_path, index=False)
    
#     print(f"\n✓ Exported to {output_path}")
#     print(f"  Patients: {df['PTID'].nunique()}, Features: {len([c for c in df.columns if c.startswith('tab_feat')])}")
    
#     return df

# # ============================================================================
# # MAIN
# # ============================================================================

# def main():
#     print("=" * 80)
#     print("BENCHMARK 3: TABULAR-ONLY DEEP NETWORK")
#     print("=" * 80)
    
#     device = CONFIG['device']
#     print(f"\nDevice: {device}")
    
#     manifest = pd.read_csv("./AD_Multimodal/TFN_AD/AD_Patient_Manifest.csv")
    
#     dataset = TabularOnlyDataset(manifest, max_seq_len=CONFIG['max_seq_len'])
#     print(f"Total sequences: {len(dataset)}")
    
#     n_train = int(0.8 * len(dataset))
#     n_val = len(dataset) - n_train
#     train_dataset, val_dataset = torch.utils.data.random_split(
#         dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42)
#     )
    
#     train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], 
#                              shuffle=True, num_workers=0)
#     val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], 
#                            shuffle=False, num_workers=0)
    
#     sample_batch = next(iter(train_loader))
#     tab_dim = sample_batch[0].shape[2]
#     print(f"Tabular features: {tab_dim}")
    
#     model = TabularOnlyModel(tab_dim, CONFIG).to(device)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr'], 
#                                   weight_decay=CONFIG['weight_decay'])
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#         optimizer, mode='max', factor=0.5, patience=5
#     )
    
#     print("\nTraining...")
#     best_c_index = 0
#     patience_counter = 0
    
#     for epoch in range(CONFIG['epochs']):
#         train_loss = train_epoch(model, train_loader, optimizer, device)
#         val_c_index = validate(model, val_loader, device)
        
#         scheduler.step(val_c_index)
        
#         print(f"Epoch {epoch+1}/{CONFIG['epochs']} - Loss: {train_loss:.4f}, C-index: {val_c_index:.4f}")
        
#         if val_c_index > best_c_index:
#             best_c_index = val_c_index
#             torch.save(model.state_dict(), 'tabular_only_model.pth')
#             patience_counter = 0
#         else:
#             patience_counter += 1
        
#         if patience_counter >= 15:
#             print("Early stopping")
#             break
    
#     # Export
#     model.load_state_dict(torch.load('tabular_only_model.pth'))
#     full_loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], 
#                             shuffle=False, num_workers=0)
    
#     export_features(model, full_loader, device, "tabular_only_features.csv")
    
#     print("\n" + "=" * 80)
#     print(f"✓ BEST C-INDEX: {best_c_index:.4f}")
#     print("=" * 80)

# if __name__ == "__main__":
#     main()
    

# # """
# # STEP 2: BASELINE METHOD - Export Clinical Features Only
# # This creates the comparison baseline WITHOUT tensor fusion/autoencoder
# # Same 161 patients, same preprocessing, just no deep learning
# # """

# # import pandas as pd
# # import numpy as np
# # import os

# # def engineer_features(df):
# #     """Same feature engineering as thesis method"""
# #     df = df.copy()
# #     df["time_from_baseline"] = df["Years_bl"] - df["Years_bl"].iloc[0]
# #     df["age_bl"] = df["AGE"].iloc[0]
# #     df["age_since_bl"] = df["AGE"] - df["age_bl"]
# #     df["mmse_slope"] = df["MMSE"].diff() / df["Years_bl"].diff()
# #     df["adas13_slope"] = df["ADAS13"].diff() / df["Years_bl"].diff()
# #     dx_map = {"CN": 0, "MCI": 1, "AD": 2, "Dementia": 2}
# #     df["dx_progression"] = df["DX"].map(dx_map).diff()
# #     df["cog_decline_index"] = df["ADAS13"] - df["MMSE"]
# #     df["visit_number"] = range(len(df))
# #     df = df.fillna(0)
# #     return df

# # print("="*80)
# # print("BASELINE METHOD: Exporting Clinical Features")
# # print("="*80)

# # # Load manifest
# # manifest = pd.read_csv("./AD_Multimodal/TFN_AD/AD_Patient_Manifest.csv")
# # manifest["path"] = manifest["path"].str.replace("\\", "/", regex=False)
# # manifest["path"] = "./AD_Multimodal/TFN_AD/" + manifest["path"]

# # rows = []
# # patients_processed = 0

# # for idx, row in manifest.iterrows():
# #     if not os.path.exists(row["path"]):
# #         continue
    
# #     try:
# #         df = pd.read_pickle(row["path"])
# #         df = engineer_features(df)
# #         dx_seq = df["DX"].tolist()
        
# #         # Only MCI patients
# #         if "MCI" not in dx_seq:
# #             continue
        
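# #         # Compute survival outcome: the event is the first AD/Dementia visit after the
# #         # first MCI visit; if there is none, the patient is censored at the last visit.
# #         # NOTE(review): times are raw Years_bl values (not re-zeroed at the MCI visit) --
# #         # confirm this matches the time origin used by the method being compared against.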
# #         mci_idx = dx_seq.index("MCI")
# #         ad_idx = next((i for i, x in enumerate(dx_seq[mci_idx+1:], start=mci_idx+1)
# #                        if x in ["AD", "Dementia"]), -1)
        
# #         if ad_idx != -1:
# #             time_to_event = df["Years_bl"].iloc[ad_idx]
# #             event = 1
# #         else:
# #             time_to_event = df["Years_bl"].iloc[-1]
# #             event = 0
        
# #         # Export each visit with ONLY clinical/tabular features
# #         for _, visit in df.iterrows():
# #             # Skip visits whose image file is missing on disk (for consistency with thesis method)
# #             image_path = visit["image_path"].replace(
# #                 "/home/mason/ADNI_Dataset/", 
# #                 "./AD_Multimodal/ADNI_Dataset/"
# #             )
            
# #             if not os.path.exists(image_path):
# #                 continue
            
# #             # Extract clinical features
# #             row_data = {
# #                 "PTID": visit["PTID"],
# #                 "Years_bl": visit["Years_bl"],
# #                 "time_to_event": time_to_event,
# #                 "event": event,
                
# #                 # Core clinical features (like benchmark)
# #                 "MMSE": visit["MMSE"],
# #                 "ADAS13": visit["ADAS13"],
# #                 "AGE": visit["AGE"],
# #                 "PTGENDER": visit["PTGENDER_encoded"],
# #                 "PTEDUCAT": visit["PTEDUCAT"],
                
# #                 # Additional demographics
# #                 "PTETHCAT": visit["PTETHCAT_encoded"],
# #                 "PTRACCAT": visit["PTRACCAT_encoded"],
# #                 "PTMARRY": visit["PTMARRY_encoded"],
                
# #                 # Engineered features (your contribution)
# #                 "age_bl": visit["age_bl"],
# #                 "age_since_bl": visit["age_since_bl"],
# #                 "mmse_slope": visit["mmse_slope"],
# #                 "adas13_slope": visit["adas13_slope"],
# #                 "dx_progression": visit["dx_progression"],
# #                 "cog_decline_index": visit["cog_decline_index"],
# #                 "visit_number": visit["visit_number"],
# #             }
            
# #             rows.append(row_data)
        
# #         patients_processed += 1
        
# #     except Exception as e:
# #         print(f"⚠️ Error processing {row['path']}: {e}")
# #         continue

# # # Create DataFrame
# # baseline_df = pd.DataFrame(rows)

# # print(f"\n✓ Processed {patients_processed} patients")
# # print(f"✓ Total visits: {len(baseline_df)}")
# # print(f"✓ Events: {baseline_df['event'].sum()}")
# # print(f"✓ Event rate: {baseline_df.groupby('PTID')['event'].first().mean()*100:.1f}%")

# # # Save
# # baseline_df.to_csv("baseline_clinical_features.csv", index=False)
# # print(f"\n✓ Saved to: baseline_clinical_features.csv")

# # print("\n" + "="*80)
# # print("BASELINE FEATURES EXPORTED")
# # print("="*80)
# # print("This dataset contains ONLY clinical/tabular features.")
# # print("Use this for comparison against your thesis method (fusion + autoencoder).")
# # print("\nFeatures included:")
# # print("  - Core clinical: MMSE, ADAS13, AGE, PTGENDER, PTEDUCAT")
# # print("  - Demographics: PTETHCAT, PTRACCAT, PTMARRY")
# # print("  - Engineered: slopes, progression, decline index")
# # print(f"\nNext: Run your thesis method (fusion + autoencoder) on same {patients_processed} patients")

BENCHMARK 3: TABULAR-ONLY DEEP NETWORK

Device: cuda

Loading valid patient list...
Valid patients: 161
  Processed: 161 valid patients
  Skipped (not in valid set): 221
Total sequences: 161
Tabular features: 23

Training...
Epoch 1/100 - Loss: 2.2746, C-index: 0.8350
Epoch 2/100 - Loss: 2.2125, C-index: 0.8188
Epoch 3/100 - Loss: 1.9936, C-index: 0.8252
Epoch 4/100 - Loss: 1.8367, C-index: 0.8285
Epoch 5/100 - Loss: 1.8272, C-index: 0.8414
Epoch 6/100 - Loss: 1.7868, C-index: 0.8414
Epoch 7/100 - Loss: 1.7790, C-index: 0.8447
Epoch 8/100 - Loss: 1.7369, C-index: 0.8414
Epoch 9/100 - Loss: 1.7838, C-index: 0.8414
Epoch 10/100 - Loss: 1.7139, C-index: 0.8350
Epoch 11/100 - Loss: 1.6903, C-index: 0.8350
Epoch 12/100 - Loss: 1.6914, C-index: 0.8252
Epoch 13/100 - Loss: 1.5982, C-index: 0.8220
Epoch 14/100 - Loss: 1.6423, C-index: 0.8285
Epoch 15/100 - Loss: 1.7197, C-index: 0.8317
Epoch 16/100 - Loss: 1.5417, C-index: 0.8188
Epoch 17/100 - Loss: 1.6740, C-index: 0.8220
Epoch 18/100 - Loss