In [2]:
"""
Benchmark 4: Concatenation Fusion
Simple concatenation instead of tensor fusion

FIXED VERSION:
- Uses VALID_PATIENTS.pkl for consistent patient cohort
- Survival time measured from MCI diagnosis (consistent)
- Proper feature standardization and NaN handling
- Shows that tensor fusion is superior to naive concatenation

Architecture:
- Same as thesis BUT: simple concatenation [img; tab] instead of tensor product
- Everything else identical (LSTM, autoencoder, etc.)
- Expected to be worse than thesis but better than single modality
"""

import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import numpy as np
import os
import pickle
from lifelines.utils import concordance_index
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

CONFIG = {
    'latent_dim': 128,
    'img_out_dim': 256,
    'tab_out_dim': 64,
    'lstm_hidden': 128,
    'lstm_layers': 2,
    'dropout': 0.3,
    'lr': 5e-4,
    'weight_decay': 1e-4,
    'epochs': 100,
    'batch_size': 16,
    'max_seq_len': 10,
    'alpha_recon': 0.2,
    'alpha_survival': 0.6,
    'alpha_mmse': 0.2,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

# ============================================================================
# FEATURE ENGINEERING (Same as thesis)
# ============================================================================

STATIC_FEATURES = [
    'age_bl', 'PTGENDER_encoded', 'PTEDUCAT', 'PTETHCAT_encoded', 
    'PTRACCAT_encoded', 'PTMARRY_encoded'
]

TEMPORAL_FEATURES = [
    'time_from_baseline', 'AGE', 'age_since_bl', 'mmse_slope', 
    'adas13_slope', 'dx_progression', 'cog_decline_index', 
    'visit_number', 'MMSE', 'ADAS13'
]

def engineer_features(df):
    df = df.copy()
    df["time_from_baseline"] = df["Years_bl"] - df["Years_bl"].iloc[0]
    df["age_bl"] = df["AGE"].iloc[0]
    df["age_since_bl"] = df["AGE"] - df["age_bl"]
    df["mmse_slope"] = df["MMSE"].diff() / df["Years_bl"].diff()
    df["adas13_slope"] = df["ADAS13"].diff() / df["Years_bl"].diff()
    dx_map = {"CN": 0, "MCI": 1, "AD": 2, "Dementia": 2}
    df["dx_progression"] = df["DX"].map(dx_map).diff()
    df["cog_decline_index"] = df["ADAS13"] - df["MMSE"]
    df["visit_number"] = range(len(df))
    df['age_mmse_interaction'] = df['AGE'] * (30 - df['MMSE']) / 30
    df['education_cognitive_reserve'] = df['PTEDUCAT'] * df['MMSE'] / 30
    df['rapid_decline_flag'] = (df['mmse_slope'] < -2).astype(float)
    mmse_bins = [0, 20, 24, 30]
    df['mmse_severity'] = pd.cut(df['MMSE'], bins=mmse_bins, labels=[2, 1, 0]).astype(float)
    df['weighted_mmse_decline'] = df['mmse_slope'] * np.exp(-0.1 * df['time_from_baseline'])
    df['mmse_variability'] = df['MMSE'].rolling(window=3, min_periods=1).std()
    df['adas_mmse_discordance'] = np.abs(
        (df['ADAS13'] - df['ADAS13'].mean()) / (df['ADAS13'].std() + 1e-7) - 
        (df['MMSE'] - df['MMSE'].mean()) / (df['MMSE'].std() + 1e-7)
    )
    df = df.fillna(0)
    return df

TEMPORAL_FEATURES.extend([
    'age_mmse_interaction', 'education_cognitive_reserve', 'rapid_decline_flag',
    'mmse_severity', 'weighted_mmse_decline', 'mmse_variability', 'adas_mmse_discordance'
])

# ============================================================================
# DATASET
# ============================================================================

class SequenceDataset(Dataset):
    def __init__(self, manifest, valid_patients, transform=None, max_seq_len=10):
        self.sequences = []
        self.transform = transform
        self.max_seq_len = max_seq_len
        
        manifest["path"] = manifest["path"].str.replace("\\", "/", regex=False)
        manifest["path"] = "./AD_Multimodal/TFN_AD/" + manifest["path"]
        
        processed = 0
        skipped_not_valid = 0
        
        for ptid in manifest['PTID'].unique():
            # CRITICAL: Only process valid patients
            if ptid not in valid_patients:
                skipped_not_valid += 1
                continue
            
            try:
                patient_rows = manifest[manifest['PTID'] == ptid]
                if len(patient_rows) == 0:
                    continue
                
                df = pd.read_pickle(patient_rows.iloc[0]["path"])
                df = engineer_features(df)
                
                dx_seq = df["DX"].tolist()
                if "MCI" not in dx_seq:
                    continue
                
                # FIXED: Time from MCI diagnosis
                mci_idx = dx_seq.index("MCI")
                ad_idx = next((i for i, x in enumerate(dx_seq[mci_idx+1:], start=mci_idx+1) 
                              if x in ["AD", "Dementia"]), -1)
                
                if ad_idx != -1:
                    time_to_event = df["Years_bl"].iloc[ad_idx] - df["Years_bl"].iloc[mci_idx]
                    event = 1
                else:
                    time_to_event = df["Years_bl"].iloc[-1] - df["Years_bl"].iloc[mci_idx]
                    event = 0
                
                imgs, tabs, times, mmse_vals = [], [], [], []
                valid_visits = 0
                
                for _, visit in df.iterrows():
                    image_path = visit["image_path"].replace(
                        "/home/mason/ADNI_Dataset/", 
                        "./AD_Multimodal/ADNI_Dataset/"
                    )
                    
                    if not os.path.exists(image_path):
                        continue
                    
                    img = Image.open(image_path).convert("RGB")
                    if self.transform:
                        img = self.transform(img)
                    
                    imgs.append(img)
                    tabs.append(visit[TEMPORAL_FEATURES + STATIC_FEATURES].values.astype(np.float32))
                    times.append(visit["Years_bl"])
                    mmse_vals.append(visit["MMSE"])
                    valid_visits += 1
                    
                    if valid_visits >= max_seq_len:
                        break
                
                if valid_visits < 2:
                    continue
                
                pad_len = max_seq_len - len(imgs)
                if pad_len > 0:
                    for _ in range(pad_len):
                        imgs.append(torch.zeros_like(imgs[-1]))
                        tabs.append(np.zeros_like(tabs[-1]))
                        times.append(times[-1])
                        mmse_vals.append(0.0)
                
                self.sequences.append({
                    'ptid': ptid,
                    'imgs': torch.stack(imgs),
                    'tabs': np.array(tabs, dtype=np.float32),
                    'times': np.array(times, dtype=np.float32),
                    'mmse': np.array(mmse_vals, dtype=np.float32),
                    'seq_len': valid_visits,
                    'time_to_event': time_to_event,
                    'event': event
                })
                
                processed += 1
                
            except Exception as e:
                continue
        
        # ADDED: Standardize tabular features for stability
        if len(self.sequences) > 0:
            all_tabs = np.vstack([seq['tabs'] for seq in self.sequences])
            self.scaler = StandardScaler()
            self.scaler.fit(all_tabs)
            
            for seq in self.sequences:
                seq['tabs'] = self.scaler.transform(seq['tabs']).astype(np.float32)
        
        print(f"  Processed: {processed} valid patients")
        print(f"  Skipped (not in valid set): {skipped_not_valid}")
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        seq = self.sequences[idx]
        return (
            seq['imgs'], seq['tabs'], seq['times'], seq['mmse'],
            seq['seq_len'], seq['time_to_event'], seq['event'], seq['ptid']
        )

# ============================================================================
# MODEL: CONCATENATION INSTEAD OF TENSOR FUSION
# ============================================================================

class AttentionImageEncoder(nn.Module):
    def __init__(self, out_dim=256):
        super().__init__()
        base = models.resnet18(pretrained=True)
        for param in list(base.parameters())[:-20]:
            param.requires_grad = False
        
        self.features = nn.Sequential(*list(base.children())[:-2])
        self.attention = nn.Sequential(
            nn.Conv2d(512, 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 1, 1),
            nn.Sigmoid()
        )
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Linear(512, out_dim)
    
    def forward(self, x):
        feats = self.features(x)
        attn = self.attention(feats)
        feats = feats * attn
        pooled = self.global_pool(feats).view(x.size(0), -1)
        return self.proj(pooled)

class ConcatenationFusionAutoencoder(nn.Module):
    """
    KEY DIFFERENCE: Uses simple concatenation [img; tab; time] 
    instead of tensor product fusion
    """
    def __init__(self, tab_dim, config):
        super().__init__()
        self.config = config
        
        self.img_encoder = AttentionImageEncoder(out_dim=config['img_out_dim'])
        
        self.tab_encoder = nn.Sequential(
            nn.Linear(tab_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(config['dropout']),
            nn.Linear(128, config['tab_out_dim']),
            nn.ReLU()
        )
        
        self.time_encoder = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU()
        )
        
        # CONCATENATION instead of tensor fusion
        concat_dim = config['img_out_dim'] + config['tab_out_dim'] + 16
        
        self.concat_proj = nn.Sequential(
            nn.Linear(concat_dim, config['latent_dim']),
            nn.BatchNorm1d(config['latent_dim']),
            nn.ReLU()
        )
        
        self.lstm = nn.LSTM(
            input_size=config['latent_dim'],
            hidden_size=config['lstm_hidden'],
            num_layers=config['lstm_layers'],
            batch_first=True,
            dropout=config['dropout'] if config['lstm_layers'] > 1 else 0,
            bidirectional=True
        )
        
        self.temporal_proj = nn.Sequential(
            nn.Linear(config['lstm_hidden'] * 2, config['latent_dim']),
            nn.ReLU()
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(config['latent_dim'], 128),
            nn.ReLU(),
            nn.Dropout(config['dropout']),
            nn.Linear(128, tab_dim)
        )
        
        self.survival_head = nn.Sequential(
            nn.Linear(config['latent_dim'], 64),
            nn.ReLU(),
            nn.Dropout(config['dropout']),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
        
        self.mmse_head = nn.Sequential(
            nn.Linear(config['latent_dim'], 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
    
    def encode_visit(self, img, tab, time):
        v = self.img_encoder(img)
        d = self.tab_encoder(tab)
        t = self.time_encoder(time.unsqueeze(1))
        
        # SIMPLE CONCATENATION (not tensor fusion!)
        concat = torch.cat([v, d, t], dim=1)
        z = self.concat_proj(concat)
        
        return z
    
    def forward(self, img_seq, tab_seq, time_seq, seq_lengths):
        batch_size, seq_len = img_seq.shape[:2]
        
        z_list = []
        for t in range(seq_len):
            z_t = self.encode_visit(img_seq[:, t], tab_seq[:, t], time_seq[:, t])
            z_list.append(z_t)
        
        z_seq = torch.stack(z_list, dim=1)
        
        packed = nn.utils.rnn.pack_padded_sequence(
            z_seq, seq_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        lstm_out, (h_n, c_n) = self.lstm(packed)
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        
        h_forward = h_n[-2]
        h_backward = h_n[-1]
        h_final = torch.cat([h_forward, h_backward], dim=1)
        
        z_final = self.temporal_proj(h_final)
        
        tab_recon = self.decoder(z_final)
        risk_score = self.survival_head(z_final)
        mmse_pred = self.mmse_head(z_final)
        
        return z_final, tab_recon, risk_score, mmse_pred

# ============================================================================
# LOSS & TRAINING
# ============================================================================

def cox_loss(risk_scores, times, events):
    order = torch.argsort(times, descending=True)
    risk_scores = risk_scores[order]
    events = events[order]
    log_risk = risk_scores.view(-1)
    log_cumsum_hazard = torch.logcumsumexp(log_risk, dim=0)
    loss = -(log_risk - log_cumsum_hazard) * events
    return loss.sum() / (events.sum() + 1e-7)

def train_epoch(model, loader, optimizer, config, device):
    model.train()
    total_loss = 0
    
    for batch in loader:
        imgs, tabs, times, mmse, seq_lens, t_event, event, _ = batch
        
        imgs = imgs.to(device)
        tabs = torch.FloatTensor(tabs).to(device)
        times = torch.FloatTensor(times).to(device)
        mmse = torch.FloatTensor(mmse).to(device)
        seq_lens = torch.LongTensor(seq_lens)
        t_event = t_event.float().to(device)
        event = event.float().to(device)
        
        last_tabs = torch.stack([tabs[i, seq_lens[i]-1] for i in range(len(seq_lens))])
        last_mmse = torch.FloatTensor([mmse[i, seq_lens[i]-1] for i in range(len(seq_lens))]).to(device)
        
        z_final, tab_recon, risk_scores, mmse_pred = model(imgs, tabs, times, seq_lens)
        
        loss_recon = nn.MSELoss()(tab_recon, last_tabs)
        loss_cox = cox_loss(risk_scores.squeeze(), t_event, event)
        loss_mmse = nn.MSELoss()(mmse_pred.squeeze(), last_mmse)
        
        loss = (config['alpha_recon'] * loss_recon + 
                config['alpha_survival'] * loss_cox + 
                config['alpha_mmse'] * loss_mmse)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(loader)

def validate(model, loader, device):
    model.eval()
    all_risks, all_times, all_events = [], [], []
    
    with torch.no_grad():
        for batch in loader:
            imgs, tabs, times, mmse, seq_lens, t_event, event, _ = batch
            
            imgs = imgs.to(device)
            tabs = torch.FloatTensor(tabs).to(device)
            times = torch.FloatTensor(times).to(device)
            seq_lens = torch.LongTensor(seq_lens)
            
            _, _, risk_scores, _ = model(imgs, tabs, times, seq_lens)
            
            all_risks.extend(risk_scores.cpu().numpy().flatten())
            all_times.extend(t_event.numpy())
            all_events.extend(event.numpy())
    
    c_index = concordance_index(np.array(all_times), -np.array(all_risks), np.array(all_events).astype(bool))
    return c_index

# ============================================================================
# EXPORT
# ============================================================================

def export_features(model, loader, device, output_path):
    """Export features with proper column structure and NaN handling"""
    model.eval()
    rows = []
    
    BASELINE_FEATURES = ['AGE', 'PTGENDER_encoded', 'PTEDUCAT', 'ADAS13']
    
    with torch.no_grad():
        for batch in loader:
            imgs, tabs, times, mmse, seq_lens, t_event, event, ptids = batch
            
            imgs = imgs.to(device)
            tabs = torch.FloatTensor(tabs).to(device)
            times = torch.FloatTensor(times).to(device)
            seq_lens = torch.LongTensor(seq_lens)
            
            for i in range(len(ptids)):
                slen = seq_lens[i].item()
                
                for t in range(slen):
                    feat = model.encode_visit(imgs[i:i+1, t], tabs[i:i+1, t], times[i:i+1, t])
                    feat_vals = feat[0].cpu().numpy()
                    
                    # CRITICAL: Check for NaN/Inf
                    if np.any(np.isnan(feat_vals)) or np.any(np.isinf(feat_vals)):
                        feat_vals = np.nan_to_num(feat_vals, nan=0.0, posinf=0.0, neginf=0.0)
                    
                    # Unscale tabular features
                    tab_vals = loader.dataset.scaler.inverse_transform(
                        tabs[i, t].cpu().numpy().reshape(1, -1)
                    )[0]
                    
                    row = {
                        "PTID": ptids[i],
                        "Years_bl": float(times[i, t].cpu()),
                        "MMSE": float(mmse[i, t]),
                        "time_to_event": float(t_event[i]),
                        "event": int(event[i]),
                    }
                    
                    # Add clinical features
                    tab_feature_names = TEMPORAL_FEATURES + STATIC_FEATURES
                    for f in BASELINE_FEATURES:
                        if f in tab_feature_names:
                            idx = tab_feature_names.index(f)
                            if idx < len(tab_vals):
                                val = float(tab_vals[idx])
                                if np.isnan(val) or np.isinf(val):
                                    val = 0.0
                                r_name = f.replace('_encoded', '')
                                row[r_name] = val
                    
                    # Add ADAS13
                    if 'ADAS13' in tab_feature_names:
                        idx = tab_feature_names.index('ADAS13')
                        if idx < len(tab_vals):
                            val = float(tab_vals[idx])
                            if np.isnan(val) or np.isinf(val):
                                val = 0.0
                            row['ADAS13'] = val
                    
                    # Add latent features
                    for k in range(len(feat_vals)):
                        row[f"z_{k}"] = float(feat_vals[k])
                    
                    rows.append(row)
    
    df = pd.DataFrame(rows).sort_values(['PTID', 'Years_bl'])
    
    # Final NaN/Inf check
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    for col in numeric_cols:
        if df[col].isna().any():
            df[col].fillna(0, inplace=True)
        if np.isinf(df[col]).any():
            df[col] = df[col].replace([np.inf, -np.inf], 0)
    
    df.to_csv(output_path, index=False)
    
    print(f"\n✓ Exported to {output_path}")
    print(f"  Patients: {df['PTID'].nunique()}")
    print(f"  Visits: {len(df)}")
    print(f"  Features: {len([c for c in df.columns if c.startswith('z_')])}")
    
    return df

# ============================================================================
# MAIN
# ============================================================================

def main():
    print("=" * 80)
    print("BENCHMARK 4: CONCATENATION FUSION (vs Tensor Fusion)")
    print("=" * 80)
    
    device = CONFIG['device']
    print(f"\nDevice: {device}")
    
    # Load valid patients
    print("\nLoading valid patient list...")
    with open('VALID_PATIENTS.pkl', 'rb') as f:
        VALID_PATIENTS = pickle.load(f)
    print(f"Valid patients: {len(VALID_PATIENTS)}")
    
    manifest = pd.read_csv("./AD_Multimodal/TFN_AD/AD_Patient_Manifest.csv")
    
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    dataset = SequenceDataset(manifest, VALID_PATIENTS, transform, max_seq_len=CONFIG['max_seq_len'])
    print(f"Total sequences: {len(dataset)}")
    
    n_train = int(0.8 * len(dataset))
    n_val = len(dataset) - n_train
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42)
    )
    
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=0)
    
    tab_dim = next(iter(train_loader))[1].shape[2]
    
    model = ConcatenationFusionAutoencoder(tab_dim, CONFIG).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
    
    print("\nTraining...")
    best_c_index = 0
    patience_counter = 0
    
    for epoch in range(CONFIG['epochs']):
        train_loss = train_epoch(model, train_loader, optimizer, CONFIG, device)
        val_c_index = validate(model, val_loader, device)
        
        scheduler.step(val_c_index)
        print(f"Epoch {epoch+1}/{CONFIG['epochs']} - Loss: {train_loss:.4f}, C-index: {val_c_index:.4f}")
        
        if val_c_index > best_c_index:
            best_c_index = val_c_index
            torch.save(model.state_dict(), 'concat_fusion_model.pth')
            patience_counter = 0
        else:
            patience_counter += 1
        
        if patience_counter >= 15:
            break
    
    model.load_state_dict(torch.load('concat_fusion_model.pth'))
    full_loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=0)
    export_features(model, full_loader, device, "concat_fusion_features.csv")
    
    print("\n" + "=" * 80)
    print(f"✓ BEST C-INDEX: {best_c_index:.4f}")
    print("⚠️  This uses the SAME patient cohort as all other benchmarks")
    print("=" * 80)

if __name__ == "__main__":
    main()
# """
# Benchmark 4: Concatenation Fusion
# Simple concatenation instead of tensor fusion

# Shows that tensor fusion is superior to naive concatenation.
# This is a critical ablation - proves your fusion method adds value!

# Architecture:
# - Same as thesis BUT: simple concatenation [img; tab] instead of tensor product
# - Everything else identical (LSTM, autoencoder, etc.)
# - Expected to be worse than thesis but better than single modality
# """

# import torch
# import torch.nn as nn
# from torchvision import models, transforms
# from torch.utils.data import Dataset, DataLoader
# from PIL import Image
# import pandas as pd
# import numpy as np
# import os
# import pickle
# from lifelines.utils import concordance_index
# import warnings
# warnings.filterwarnings('ignore')

# CONFIG = {
#     'latent_dim': 128,
#     'img_out_dim': 256,
#     'tab_out_dim': 64,
#     'lstm_hidden': 128,
#     'lstm_layers': 2,
#     'dropout': 0.3,
#     'lr': 5e-4,
#     'weight_decay': 1e-4,
#     'epochs': 100,
#     'batch_size': 16,
#     'max_seq_len': 10,
#     'alpha_recon': 0.2,
#     'alpha_survival': 0.6,
#     'alpha_mmse': 0.2,
#     'device': 'cuda' if torch.cuda.is_available() else 'cpu'
# }

# # ============================================================================
# # FEATURE ENGINEERING (Same as thesis)
# # ============================================================================

# STATIC_FEATURES = [
#     'age_bl', 'PTGENDER_encoded', 'PTEDUCAT', 'PTETHCAT_encoded', 
#     'PTRACCAT_encoded', 'PTMARRY_encoded'
# ]

# TEMPORAL_FEATURES = [
#     'time_from_baseline', 'AGE', 'age_since_bl', 'mmse_slope', 
#     'adas13_slope', 'dx_progression', 'cog_decline_index', 
#     'visit_number', 'MMSE', 'ADAS13'
# ]

# def engineer_features(df):
#     df = df.copy()
#     df["time_from_baseline"] = df["Years_bl"] - df["Years_bl"].iloc[0]
#     df["age_bl"] = df["AGE"].iloc[0]
#     df["age_since_bl"] = df["AGE"] - df["age_bl"]
#     df["mmse_slope"] = df["MMSE"].diff() / df["Years_bl"].diff()
#     df["adas13_slope"] = df["ADAS13"].diff() / df["Years_bl"].diff()
#     dx_map = {"CN": 0, "MCI": 1, "AD": 2, "Dementia": 2}
#     df["dx_progression"] = df["DX"].map(dx_map).diff()
#     df["cog_decline_index"] = df["ADAS13"] - df["MMSE"]
#     df["visit_number"] = range(len(df))
#     df['age_mmse_interaction'] = df['AGE'] * (30 - df['MMSE']) / 30
#     df['education_cognitive_reserve'] = df['PTEDUCAT'] * df['MMSE'] / 30
#     df['rapid_decline_flag'] = (df['mmse_slope'] < -2).astype(float)
#     mmse_bins = [0, 20, 24, 30]
#     df['mmse_severity'] = pd.cut(df['MMSE'], bins=mmse_bins, labels=[2, 1, 0]).astype(float)
#     df['weighted_mmse_decline'] = df['mmse_slope'] * np.exp(-0.1 * df['time_from_baseline'])
#     df['mmse_variability'] = df['MMSE'].rolling(window=3, min_periods=1).std()
#     df['adas_mmse_discordance'] = np.abs(
#         (df['ADAS13'] - df['ADAS13'].mean()) / (df['ADAS13'].std() + 1e-7) - 
#         (df['MMSE'] - df['MMSE'].mean()) / (df['MMSE'].std() + 1e-7)
#     )
#     df = df.fillna(0)
#     return df

# TEMPORAL_FEATURES.extend([
#     'age_mmse_interaction', 'education_cognitive_reserve', 'rapid_decline_flag',
#     'mmse_severity', 'weighted_mmse_decline', 'mmse_variability', 'adas_mmse_discordance'
# ])

# # ============================================================================
# # DATASET (Same as thesis)
# # ============================================================================

# class SequenceDataset(Dataset):
#     def __init__(self, manifest, valid_patients, transform=None, max_seq_len=10):
#         self.sequences = []
#         self.transform = transform
#         self.max_seq_len = max_seq_len
        
#         manifest["path"] = manifest["path"].str.replace("\\", "/", regex=False)
#         manifest["path"] = "./AD_Multimodal/TFN_AD/" + manifest["path"]

#         skipped_not_valid = 0
#         processed = 0
        
#         for ptid in manifest['PTID'].unique():
#             if ptid not in valid_patients:
#                 skipped_not_valid += 1
#                 continue
                
#             try:
#                 patient_rows = manifest[manifest['PTID'] == ptid]
#                 if len(patient_rows) == 0:
#                     continue
                
#                 df = pd.read_pickle(patient_rows.iloc[0]["path"])
#                 df = engineer_features(df)
                
#                 dx_seq = df["DX"].tolist()
#                 if "MCI" not in dx_seq:
#                     continue
                
#                 mci_idx = dx_seq.index("MCI")
#                 ad_idx = next((i for i, x in enumerate(dx_seq[mci_idx+1:], start=mci_idx+1) 
#                               if x in ["AD", "Dementia"]), -1)
                
#                 if ad_idx != -1:
#                     time_to_event = df["Years_bl"].iloc[ad_idx]
#                     event = 1
#                 else:
#                     time_to_event = df["Years_bl"].iloc[-1]
#                     event = 0
                
#                 imgs, tabs, times, mmse_vals = [], [], [], []
#                 valid_visits = 0
                
#                 for _, visit in df.iterrows():
#                     image_path = visit["image_path"].replace(
#                         "/home/mason/ADNI_Dataset/", 
#                         "./AD_Multimodal/ADNI_Dataset/"
#                     )
                    
#                     if not os.path.exists(image_path):
#                         continue
                    
#                     img = Image.open(image_path).convert("RGB")
#                     if self.transform:
#                         img = self.transform(img)
                    
#                     imgs.append(img)
#                     tabs.append(visit[TEMPORAL_FEATURES + STATIC_FEATURES].values.astype(np.float32))
#                     times.append(visit["Years_bl"])
#                     mmse_vals.append(visit["MMSE"])
#                     valid_visits += 1
                    
#                     if valid_visits >= max_seq_len:
#                         break
                
#                 if valid_visits < 2:
#                     continue
                
#                 pad_len = max_seq_len - len(imgs)
#                 if pad_len > 0:
#                     for _ in range(pad_len):
#                         imgs.append(torch.zeros_like(imgs[-1]))
#                         tabs.append(np.zeros_like(tabs[-1]))
#                         times.append(times[-1])
#                         mmse_vals.append(0.0)
                
#                 self.sequences.append({
#                     'ptid': ptid,
#                     'imgs': torch.stack(imgs),
#                     'tabs': np.array(tabs, dtype=np.float32),
#                     'times': np.array(times, dtype=np.float32),
#                     'mmse': np.array(mmse_vals, dtype=np.float32),
#                     'seq_len': valid_visits,
#                     'time_to_event': time_to_event,
#                     'event': event
#                 })

#                 processed += 1
                
#             except Exception as e:
#                 continue

#         print(f"  Processed: {processed} valid patients")
#         print(f"  Skipped (not in valid set): {skipped_not_valid}")
    
#     def __len__(self):
#         return len(self.sequences)
    
#     def __getitem__(self, idx):
#         seq = self.sequences[idx]
#         return (
#             seq['imgs'], seq['tabs'], seq['times'], seq['mmse'],
#             seq['seq_len'], seq['time_to_event'], seq['event'], seq['ptid']
#         )

# # ============================================================================
# # MODEL: CONCATENATION INSTEAD OF TENSOR FUSION
# # ============================================================================

# class AttentionImageEncoder(nn.Module):
#     def __init__(self, out_dim=256):
#         super().__init__()
#         base = models.resnet18(pretrained=True)
#         for param in list(base.parameters())[:-20]:
#             param.requires_grad = False
        
#         self.features = nn.Sequential(*list(base.children())[:-2])
#         self.attention = nn.Sequential(
#             nn.Conv2d(512, 256, 1),
#             nn.BatchNorm2d(256),
#             nn.ReLU(),
#             nn.Conv2d(256, 1, 1),
#             nn.Sigmoid()
#         )
#         self.global_pool = nn.AdaptiveAvgPool2d(1)
#         self.proj = nn.Linear(512, out_dim)
    
#     def forward(self, x):
#         feats = self.features(x)
#         attn = self.attention(feats)
#         feats = feats * attn
#         pooled = self.global_pool(feats).view(x.size(0), -1)
#         return self.proj(pooled)

# class ConcatenationFusionAutoencoder(nn.Module):
#     """
#     KEY DIFFERENCE: Uses simple concatenation [img; tab; time] 
#     instead of tensor product fusion
#     """
#     def __init__(self, tab_dim, config):
#         super().__init__()
#         self.config = config
        
#         self.img_encoder = AttentionImageEncoder(out_dim=config['img_out_dim'])
        
#         self.tab_encoder = nn.Sequential(
#             nn.Linear(tab_dim, 128),
#             nn.BatchNorm1d(128),
#             nn.ReLU(),
#             nn.Dropout(config['dropout']),
#             nn.Linear(128, config['tab_out_dim']),
#             nn.ReLU()
#         )
#         # self.tab_encoder = nn.Sequential(
#         #     nn.Linear(tab_dim, 32),
#         #     nn.ReLU(),
#         #     nn.Dropout(config['dropout']),
#         #     nn.Linear(32, 16),
#         #     nn.ReLU()
#         # )

        
#         self.time_encoder = nn.Sequential(
#             nn.Linear(1, 16),
#             nn.ReLU()
#         )
        
#         # CONCATENATION instead of tensor fusion
#         concat_dim = config['img_out_dim'] + config['tab_out_dim'] + 16
#         #concat_dim = config['img_out_dim'] + 16 + 16

        
#         self.concat_proj = nn.Sequential(
#             nn.Linear(concat_dim, config['latent_dim']),
#             nn.BatchNorm1d(config['latent_dim']),
#             nn.ReLU()
#         )
        
#         self.lstm = nn.LSTM(
#             input_size=config['latent_dim'],
#             hidden_size=config['lstm_hidden'],
#             num_layers=config['lstm_layers'],
#             batch_first=True,
#             dropout=config['dropout'] if config['lstm_layers'] > 1 else 0,
#             bidirectional=True
#         )
        
#         self.temporal_proj = nn.Sequential(
#             nn.Linear(config['lstm_hidden'] * 2, config['latent_dim']),
#             nn.ReLU()
#         )
        
#         self.decoder = nn.Sequential(
#             nn.Linear(config['latent_dim'], 128),
#             nn.ReLU(),
#             nn.Dropout(config['dropout']),
#             nn.Linear(128, tab_dim)
#         )
        
#         self.survival_head = nn.Sequential(
#             nn.Linear(config['latent_dim'], 64),
#             nn.ReLU(),
#             nn.Dropout(config['dropout']),
#             nn.Linear(64, 32),
#             nn.ReLU(),
#             nn.Linear(32, 1)
#         )
        
#         self.mmse_head = nn.Sequential(
#             nn.Linear(config['latent_dim'], 32),
#             nn.ReLU(),
#             nn.Linear(32, 1)
#         )
    
#     def encode_visit(self, img, tab, time):
#         v = self.img_encoder(img)
#         d = self.tab_encoder(tab)
#         t = self.time_encoder(time.unsqueeze(1))
        
#         # SIMPLE CONCATENATION (not tensor fusion!)
#         concat = torch.cat([v, d, t], dim=1)
#         z = self.concat_proj(concat)
        
#         return z
    
#     def forward(self, img_seq, tab_seq, time_seq, seq_lengths):
#         batch_size, seq_len = img_seq.shape[:2]
        
#         z_list = []
#         for t in range(seq_len):
#             z_t = self.encode_visit(img_seq[:, t], tab_seq[:, t], time_seq[:, t])
#             z_list.append(z_t)
        
#         z_seq = torch.stack(z_list, dim=1)
        
#         packed = nn.utils.rnn.pack_padded_sequence(
#             z_seq, seq_lengths.cpu(), batch_first=True, enforce_sorted=False
#         )
#         lstm_out, (h_n, c_n) = self.lstm(packed)
#         lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        
#         h_forward = h_n[-2]
#         h_backward = h_n[-1]
#         h_final = torch.cat([h_forward, h_backward], dim=1)
        
#         z_final = self.temporal_proj(h_final)
        
#         tab_recon = self.decoder(z_final)
#         risk_score = self.survival_head(z_final)
#         mmse_pred = self.mmse_head(z_final)
        
#         return z_final, tab_recon, risk_score, mmse_pred

# # ============================================================================
# # LOSS & TRAINING (Same as thesis)
# # ============================================================================

# def cox_loss(risk_scores, times, events):
#     order = torch.argsort(times, descending=True)
#     risk_scores = risk_scores[order]
#     events = events[order]
#     log_risk = risk_scores.view(-1)
#     log_cumsum_hazard = torch.logcumsumexp(log_risk, dim=0)
#     loss = -(log_risk - log_cumsum_hazard) * events
#     return loss.sum() / (events.sum() + 1e-7)

# def train_epoch(model, loader, optimizer, config, device):
#     model.train()
#     total_loss = 0
    
#     for batch in loader:
#         imgs, tabs, times, mmse, seq_lens, t_event, event, _ = batch
        
#         imgs = imgs.to(device)
#         tabs = torch.FloatTensor(tabs).to(device)
#         times = torch.FloatTensor(times).to(device)
#         mmse = torch.FloatTensor(mmse).to(device)
#         seq_lens = torch.LongTensor(seq_lens)
#         t_event = t_event.float().to(device)
#         event = event.float().to(device)
        
#         last_tabs = torch.stack([tabs[i, seq_lens[i]-1] for i in range(len(seq_lens))])
#         last_mmse = torch.FloatTensor([mmse[i, seq_lens[i]-1] for i in range(len(seq_lens))]).to(device)
        
#         z_final, tab_recon, risk_scores, mmse_pred = model(imgs, tabs, times, seq_lens)
        
#         loss_recon = nn.MSELoss()(tab_recon, last_tabs)
#         loss_cox = cox_loss(risk_scores.squeeze(), t_event, event)
#         loss_mmse = nn.MSELoss()(mmse_pred.squeeze(), last_mmse)
        
#         loss = (config['alpha_recon'] * loss_recon + 
#                 config['alpha_survival'] * loss_cox + 
#                 config['alpha_mmse'] * loss_mmse)
        
#         optimizer.zero_grad()
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#         optimizer.step()
        
#         total_loss += loss.item()
    
#     return total_loss / len(loader)

# def validate(model, loader, device):
#     model.eval()
#     all_risks, all_times, all_events = [], [], []
    
#     with torch.no_grad():
#         for batch in loader:
#             imgs, tabs, times, mmse, seq_lens, t_event, event, _ = batch
            
#             imgs = imgs.to(device)
#             tabs = torch.FloatTensor(tabs).to(device)
#             times = torch.FloatTensor(times).to(device)
#             seq_lens = torch.LongTensor(seq_lens)
            
#             _, _, risk_scores, _ = model(imgs, tabs, times, seq_lens)
            
#             all_risks.extend(risk_scores.cpu().numpy().flatten())
#             all_times.extend(t_event.numpy())
#             all_events.extend(event.numpy())
    
#     c_index = concordance_index(np.array(all_times), -np.array(all_risks), np.array(all_events).astype(bool))
#     return c_index

# # ============================================================================
# # EXPORT
# # ============================================================================

# def export_features(model, loader, device, output_path):
#     """Export features with proper column structure"""
#     model.eval()
#     rows = []
    
#     BASELINE_FEATURES = ['AGE', 'PTGENDER_encoded', 'PTEDUCAT', 'ADAS13']
    
#     with torch.no_grad():
#         for batch in loader:
#             imgs, tabs, times, mmse, seq_lens, t_event, event, ptids = batch
            
#             imgs = imgs.to(device)
#             tabs = torch.FloatTensor(tabs).to(device)
#             times = torch.FloatTensor(times).to(device)
#             seq_lens = torch.LongTensor(seq_lens)
            
#             batch_size, seq_len = imgs.shape[:2]
            
#             for i in range(len(ptids)):
#                 slen = seq_lens[i].item()
                
#                 for t in range(slen):
#                     img_t = imgs[i:i+1, t]
#                     tab_t = tabs[i:i+1, t]
#                     time_t = times[i:i+1, t]
                    
#                     z_visit = model.encode_visit(img_t, tab_t, time_t)
#                     tab_vals = tabs[i, t].cpu().numpy()
                    
#                     row = {
#                         "PTID": ptids[i],
#                         "Years_bl": float(times[i, t].cpu()),
#                         "MMSE": float(mmse[i, t]),
#                         "time_to_event": float(t_event[i]),
#                         "event": int(event[i]),
#                     }
                    
#                     # Add clinical features with proper names
#                     tab_feature_names = TEMPORAL_FEATURES + STATIC_FEATURES
#                     for feat in BASELINE_FEATURES:
#                         if feat in tab_feature_names:
#                             idx = tab_feature_names.index(feat)
#                             if idx < len(tab_vals):
#                                 r_name = feat.replace('_encoded', '')
#                                 row[r_name] = float(tab_vals[idx])
                    
#                     # Add ADAS13 from tab_vals if available
#                     if 'ADAS13' in tab_feature_names:
#                         idx = tab_feature_names.index('ADAS13')
#                         if idx < len(tab_vals):
#                             row['ADAS13'] = float(tab_vals[idx])
                    
#                     # Add latent features
#                     z_vals = z_visit[0].cpu().numpy()
#                     for k in range(len(z_vals)):
#                         row[f"z_{k}"] = float(z_vals[k])
                    
#                     # # ZERO OUT TABULAR-LEAKAGE HALF
#                     # z_vals[int(len(z_vals)*0.5):] = 0
                    
#                     rows.append(row)
    
#     df = pd.DataFrame(rows).sort_values(['PTID', 'Years_bl'])
    
#     # Ensure PTID is unique per row for grouping
#     common_ptids = set(df['PTID'])
#     df = df[df['PTID'].isin(common_ptids)]
    
#     df.to_csv(output_path, index=False)
    
#     print(f"\n✓ Exported to {output_path}")
#     print(f"  Patients: {df['PTID'].nunique()}")
#     print(f"  Visits: {len(df)}")
#     print(f"  Latent features: {len([c for c in df.columns if c.startswith('z_')])}")
    
#     return df

# # ============================================================================
# # MAIN
# # ============================================================================

# def main():
#     print("=" * 80)
#     print("BENCHMARK 4: CONCATENATION FUSION (vs Tensor Fusion)")
#     print("=" * 80)
    
#     device = CONFIG['device']
#     manifest = pd.read_csv("./AD_Multimodal/TFN_AD/AD_Patient_Manifest.csv")

#     # Load valid patients
#     print("\nLoading valid patient list...")
#     with open('VALID_PATIENTS.pkl', 'rb') as f:
#         VALID_PATIENTS = pickle.load(f)
#     print(f"Valid patients: {len(VALID_PATIENTS)}")
    
#     transform = transforms.Compose([
#         transforms.Resize((224, 224)),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#     ])
    
#     dataset = SequenceDataset(manifest, VALID_PATIENTS, transform, max_seq_len=CONFIG['max_seq_len'])
#     print(f"Total sequences: {len(dataset)}")
    
#     n_train = int(0.8 * len(dataset))
#     n_val = len(dataset) - n_train
#     train_dataset, val_dataset = torch.utils.data.random_split(
#         dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42)
#     )
    
#     train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=0)
#     val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=0)
    
#     tab_dim = next(iter(train_loader))[1].shape[2]
    
#     model = ConcatenationFusionAutoencoder(tab_dim, CONFIG).to(device)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
    
#     print("\nTraining...")
#     best_c_index = 0
#     patience_counter = 0
    
#     for epoch in range(CONFIG['epochs']):
#         train_loss = train_epoch(model, train_loader, optimizer, CONFIG, device)
#         val_c_index = validate(model, val_loader, device)
        
#         scheduler.step(val_c_index)
#         print(f"Epoch {epoch+1}/{CONFIG['epochs']} - Loss: {train_loss:.4f}, C-index: {val_c_index:.4f}")
        
#         if val_c_index > best_c_index:
#             best_c_index = val_c_index
#             torch.save(model.state_dict(), 'concat_fusion_model.pth')
#             patience_counter = 0
#         else:
#             patience_counter += 1
        
#         if patience_counter >= 15:
#             break
    
#     model.load_state_dict(torch.load('concat_fusion_model.pth'))
#     full_loader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=0)
#     export_features(model, full_loader, device, "concat_fusion_features.csv")
    
#     print("\n" + "=" * 80)
#     print(f"✓ BEST C-INDEX: {best_c_index:.4f}")
#     print("=" * 80)

# if __name__ == "__main__":
#     main()

BENCHMARK 4: CONCATENATION FUSION (vs Tensor Fusion)

Device: cuda

Loading valid patient list...
Valid patients: 161
  Processed: 161 valid patients
  Skipped (not in valid set): 221
Total sequences: 161

Training...
Epoch 1/100 - Loss: 130.8061, C-index: 0.4272
Epoch 2/100 - Loss: 126.9472, C-index: 0.6052
Epoch 3/100 - Loss: 112.3578, C-index: 0.7443
Epoch 4/100 - Loss: 88.8343, C-index: 0.5890
Epoch 5/100 - Loss: 61.7459, C-index: 0.2104
Epoch 6/100 - Loss: 35.4176, C-index: 0.2006
Epoch 7/100 - Loss: 15.2519, C-index: 0.2201
Epoch 8/100 - Loss: 5.8667, C-index: 0.2395
Epoch 9/100 - Loss: 7.3109, C-index: 0.2427
Epoch 10/100 - Loss: 6.0735, C-index: 0.2751
Epoch 11/100 - Loss: 5.5766, C-index: 0.3010
Epoch 12/100 - Loss: 5.4346, C-index: 0.2783
Epoch 13/100 - Loss: 5.0478, C-index: 0.2071
Epoch 14/100 - Loss: 3.9996, C-index: 0.1748
Epoch 15/100 - Loss: 3.6109, C-index: 0.1748
Epoch 16/100 - Loss: 3.1099, C-index: 0.1812
Epoch 17/100 - Loss: 3.4874, C-index: 0.1812
Epoch 18/100 - L