In [1]:
"""
Benchmark 1: Clinical Cox Model (Baseline)
Traditional survival analysis with clinical features only

FIXED VERSION:
- Uses VALID_PATIENTS.pkl for consistent patient cohort
- Survival time measured from MCI diagnosis (consistent with thesis)
- Proper feature engineering without data leakage
- Clinical features only (no deep learning)
"""

import pandas as pd
import numpy as np
import os
import pickle
from lifelines.utils import concordance_index
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# FEATURE ENGINEERING
# ============================================================================

# Raw demographic columns. Not referenced by the code visible in this
# section; presumably kept for parity with the other benchmark scripts.
DEMOGRAPHIC_COLUMNS = [
    'AGE',
    'PTGENDER_encoded',
    'PTEDUCAT',
    'PTETHCAT_encoded',
    'PTRACCAT_encoded',
    'PTMARRY_encoded',
]

# Per-patient columns exported by main(). 'age_bl' is derived in
# engineer_features() from the first visit's AGE.
STATIC_FEATURES = [
    'age_bl',
    'PTGENDER_encoded',
    'PTEDUCAT',
    'PTETHCAT_encoded',
    'PTRACCAT_encoded',
    'PTMARRY_encoded',
]

# Per-visit columns exported by main(): raw scores plus the first group of
# features derived in engineer_features(). More names are appended to this
# list further down the module.
TEMPORAL_FEATURES = [
    'time_from_baseline',
    'AGE',
    'age_since_bl',
    'mmse_slope',
    'adas13_slope',
    'dx_progression',
    'cog_decline_index',
    'visit_number',
    'MMSE',
    'ADAS13',
]

def engineer_features(df):
    """Derive per-visit clinical features for a single patient's visit table.

    The input is copied, so the caller's frame is never modified. Rows are
    used in the order given; the first row is treated as the baseline visit.

    Args:
        df: DataFrame with one row per visit. Must contain the columns
            ``Years_bl``, ``AGE``, ``MMSE``, ``ADAS13``, ``DX`` and
            ``PTEDUCAT``. Any other columns are passed through.

    Returns:
        A new DataFrame with the original columns plus the derived ones
        (``time_from_baseline``, ``age_bl``, ``age_since_bl``,
        ``mmse_slope``, ``adas13_slope``, ``dx_progression``,
        ``cog_decline_index``, ``visit_number``, ``age_mmse_interaction``,
        ``education_cognitive_reserve``, ``rapid_decline_flag``,
        ``mmse_severity``, ``weighted_mmse_decline``, ``mmse_variability``,
        ``adas_mmse_discordance``). Every remaining missing value, in any
        column (including ``DX``), is replaced by 0.

    Raises:
        IndexError: if ``df`` has no rows (there is no baseline visit).
        KeyError: if a required column is missing.
    """
    df = df.copy()

    # Baseline-relative time and age (baseline = first row).
    df["time_from_baseline"] = df["Years_bl"] - df["Years_bl"].iloc[0]
    df["age_bl"] = df["AGE"].iloc[0]
    df["age_since_bl"] = df["AGE"] - df["age_bl"]

    # Two visits recorded at the same Years_bl give a zero time step. Dividing
    # by it would yield +/-inf (or NaN for 0/0); inf survives fillna(0) and
    # would leak into weighted_mmse_decline and falsely trip
    # rapid_decline_flag. Mask zero steps to NaN so such slopes end up as 0.
    elapsed = df["Years_bl"].diff()
    elapsed = elapsed.where(elapsed != 0)
    df["mmse_slope"] = df["MMSE"].diff() / elapsed
    df["adas13_slope"] = df["ADAS13"].diff() / elapsed

    # Visit-to-visit change in diagnosis stage; labels outside the map
    # become NaN and are zero-filled at the end.
    dx_map = {"CN": 0, "MCI": 1, "AD": 2, "Dementia": 2}
    df["dx_progression"] = df["DX"].map(dx_map).diff()

    df["cog_decline_index"] = df["ADAS13"] - df["MMSE"]
    df["visit_number"] = range(len(df))

    df['age_mmse_interaction'] = df['AGE'] * (30 - df['MMSE']) / 30
    df['education_cognitive_reserve'] = df['PTEDUCAT'] * df['MMSE'] / 30
    # NaN slopes compare False, so the first visit is never flagged.
    df['rapid_decline_flag'] = (df['mmse_slope'] < -2).astype(float)

    # Severity code: 2 for MMSE in [0, 20], 1 for (20, 24], 0 for (24, 30].
    # include_lowest=True keeps MMSE == 0 in the most severe bin; without it
    # the score falls outside every bin and is later zero-filled to code 0.
    mmse_bins = [0, 20, 24, 30]
    df['mmse_severity'] = pd.cut(
        df['MMSE'], bins=mmse_bins, labels=[2, 1, 0], include_lowest=True
    ).astype(float)

    # Down-weight slopes observed later in follow-up.
    df['weighted_mmse_decline'] = df['mmse_slope'] * np.exp(-0.1 * df['time_from_baseline'])
    df['mmse_variability'] = df['MMSE'].rolling(window=3, min_periods=1).std()

    # NOTE(review): these z-scores use the mean/std of the patient's whole
    # trajectory, so each visit's value depends on later visits too. Kept
    # as-is to stay comparable with the thesis pipeline; confirm this
    # look-ahead is acceptable for the intended model.
    df['adas_mmse_discordance'] = np.abs(
        (df['ADAS13'] - df['ADAS13'].mean()) / (df['ADAS13'].std() + 1e-7) -
        (df['MMSE'] - df['MMSE'].mean()) / (df['MMSE'].std() + 1e-7)
    )

    # Zero-fill whatever is still missing (first-visit diffs, masked slopes,
    # missing raw scores, unmapped diagnoses).
    df = df.fillna(0)

    return df

# Register the second group of columns derived in engineer_features() so
# that main() exports them alongside the base temporal features. This is an
# in-place append to the existing module-level list.
TEMPORAL_FEATURES += [
    'age_mmse_interaction',
    'education_cognitive_reserve',
    'rapid_decline_flag',
    'mmse_severity',
    'weighted_mmse_decline',
    'mmse_variability',
    'adas_mmse_discordance',
]

# ============================================================================
# MAIN
# ============================================================================

def main():
    """Extract per-visit clinical features and survival labels to a CSV.

    Steps:
      1. Load the patient cohort from ``VALID_PATIENTS.pkl``.
      2. Read the patient manifest and normalise its ``path`` column.
      3. For every manifest patient in the cohort, load the per-patient
         pickle, run ``engineer_features`` and compute time from the first
         MCI visit to the first later AD/Dementia visit (event = 1), or to
         the last visit when no conversion is seen (event = 0).
      4. Emit one row per visit, each carrying that patient's
         ``time_to_event`` / ``event``, into
         ``baseline_clinical_features.csv`` and print a summary.

    Patients that fail to load or process are reported and skipped
    (best-effort), and patients without any MCI visit are excluded. Both are
    counted and shown in the summary when non-zero. If no rows are produced
    at all, nothing is written and the function returns early.
    """
    print("=" * 80)
    print("BENCHMARK 1: CLINICAL COX MODEL (BASELINE)")
    print("=" * 80)

    # Load valid patients
    print("\nLoading valid patient list...")
    # NOTE: unpickling can execute arbitrary code; only load cohort and
    # patient pickles that come from a trusted source.
    with open('VALID_PATIENTS.pkl', 'rb') as f:
        VALID_PATIENTS = pickle.load(f)
    print(f"Valid patients to process: {len(VALID_PATIENTS)}")

    # Load manifest; paths are stored with Windows separators and relative
    # to the dataset root.
    manifest = pd.read_csv("./AD_Multimodal/TFN_AD/AD_Patient_Manifest.csv")
    manifest["path"] = manifest["path"].str.replace("\\", "/", regex=False)
    manifest["path"] = "./AD_Multimodal/TFN_AD/" + manifest["path"]

    all_rows = []
    processed = 0
    skipped = 0   # in the manifest but not in the valid cohort
    no_mci = 0    # in the cohort but never diagnosed MCI
    failed = 0    # raised while loading / processing

    # Loop-invariant: the full list of feature names we try to export.
    candidate_features = TEMPORAL_FEATURES + STATIC_FEATURES

    print("\nProcessing patients...")
    for ptid in manifest['PTID'].unique():
        # CRITICAL: Only process valid patients
        if ptid not in VALID_PATIENTS:
            skipped += 1
            continue

        try:
            patient_rows = manifest[manifest['PTID'] == ptid]
            # Only the first manifest entry of a patient is used.
            df = pd.read_pickle(patient_rows.iloc[0]["path"])
            df = engineer_features(df)

            dx_seq = df["DX"].tolist()
            if "MCI" not in dx_seq:
                no_mci += 1
                continue

            # Survival time from MCI diagnosis (consistent with thesis):
            # first MCI visit -> first later AD/Dementia visit.
            mci_idx = dx_seq.index("MCI")
            ad_idx = next((i for i, x in enumerate(dx_seq[mci_idx+1:], start=mci_idx+1)
                          if x in ["AD", "Dementia"]), -1)

            if ad_idx != -1:
                time_to_event = df["Years_bl"].iloc[ad_idx] - df["Years_bl"].iloc[mci_idx]
                event = 1
            else:
                # Censored at the last visit. NOTE(review): this is 0 when
                # the first MCI visit is also the last visit; confirm the
                # downstream model accepts zero follow-up.
                time_to_event = df["Years_bl"].iloc[-1] - df["Years_bl"].iloc[mci_idx]
                event = 0

            # Which candidate features exist for this patient (same for
            # every visit, so resolve once per patient).
            present_features = [feat for feat in candidate_features if feat in df.columns]

            # Extract all visits (for longitudinal modeling in R)
            for _, visit in df.iterrows():
                row = {
                    "PTID": ptid,
                    "Years_bl": visit["Years_bl"],
                    "MMSE": visit["MMSE"],
                    "ADAS13": visit["ADAS13"],
                    "time_to_event": time_to_event,
                    "event": event,
                }

                # Add all clinical features
                for feat in present_features:
                    row[feat] = visit[feat]

                # Add demographics (renamed for R compatibility)
                row["AGE"] = visit["AGE"]
                row["PTGENDER"] = visit["PTGENDER_encoded"]
                row["PTEDUCAT"] = visit["PTEDUCAT"]

                all_rows.append(row)

            processed += 1

        except Exception as e:
            # Best-effort: report and move on, but keep count so the summary
            # can be reconciled with the manifest.
            failed += 1
            print(f"⚠️ Skipped patient {ptid}: {e}")
            continue

    # With no rows the frame has no columns, so sorting by PTID would raise
    # and the event percentage would divide by zero. Stop cleanly instead.
    if not all_rows:
        print("\n⚠️ No patient rows were produced; nothing written.")
        print(f"  - Total valid patients: {len(VALID_PATIENTS)}")
        print(f"  - Skipped (not in valid set): {skipped}")
        print(f"  - Excluded (no MCI diagnosis): {no_mci}")
        print(f"  - Failed to load/process: {failed}")
        return

    # Create DataFrame
    df_out = pd.DataFrame(all_rows)
    df_out = df_out.sort_values(['PTID', 'Years_bl'])

    # Save
    output_path = "baseline_clinical_features.csv"
    df_out.to_csv(output_path, index=False)

    # Summary (event is constant within a patient, so first() is enough)
    n_patients = df_out['PTID'].nunique()
    n_visits = len(df_out)
    n_events = df_out.groupby('PTID')['event'].first().sum()

    print("\n" + "=" * 80)
    print("✓ CLINICAL FEATURES EXTRACTED")
    print("=" * 80)
    print(f"\nOutput: {output_path}")
    print(f"  - Total valid patients: {len(VALID_PATIENTS)}")
    print(f"  - Processed: {processed}")
    print(f"  - Skipped (not in valid set): {skipped}")
    # Only shown when something was actually dropped, so a clean run prints
    # exactly the same summary as before.
    if no_mci:
        print(f"  - Excluded (no MCI diagnosis): {no_mci}")
    if failed:
        print(f"  - Failed to load/process: {failed}")
    print(f"  - Final patients in output: {n_patients}")
    print(f"  - Total visits: {n_visits}")
    print(f"  - Events: {n_events} ({100*n_events/n_patients:.1f}%)")
    print(f"  - Features: {len([c for c in df_out.columns if c not in ['PTID', 'Years_bl', 'time_to_event', 'event']])}")
    print("\n⚠️  This uses the SAME patient cohort as all other benchmarks")
    print("=" * 80)

# Run the extraction only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
# """
# Benchmark 1: Clinical Cox Model (Baseline)
# Traditional survival analysis with clinical features only

# This is the standard approach in Alzheimer's research:
# - Uses only tabular clinical data (no images)
# - Simple Cox proportional hazards model
# - No deep learning, no latent features
# - Provides baseline performance to beat
# """

# import pandas as pd
# import numpy as np
# import os
# from lifelines.utils import concordance_index
# import warnings
# warnings.filterwarnings('ignore')

# # ============================================================================
# # FEATURE ENGINEERING (Same as thesis for fair comparison)
# # ============================================================================

# DEMOGRAPHIC_COLUMNS = [
#     'AGE', 'PTGENDER_encoded', 'PTEDUCAT', 'PTETHCAT_encoded', 
#     'PTRACCAT_encoded', 'PTMARRY_encoded'
# ]

# STATIC_FEATURES = [
#     'age_bl', 'PTGENDER_encoded', 'PTEDUCAT', 'PTETHCAT_encoded', 
#     'PTRACCAT_encoded', 'PTMARRY_encoded'
# ]

# TEMPORAL_FEATURES = [
#     'time_from_baseline', 'AGE', 'age_since_bl', 'mmse_slope', 
#     'adas13_slope', 'dx_progression', 'cog_decline_index', 
#     'visit_number', 'MMSE', 'ADAS13'
# ]

# def engineer_features(df):
#     """Enhanced feature engineering (same as thesis)"""
#     df = df.copy()
    
#     df["time_from_baseline"] = df["Years_bl"] - df["Years_bl"].iloc[0]
#     df["age_bl"] = df["AGE"].iloc[0]
#     df["age_since_bl"] = df["AGE"] - df["age_bl"]
    
#     df["mmse_slope"] = df["MMSE"].diff() / df["Years_bl"].diff()
#     df["adas13_slope"] = df["ADAS13"].diff() / df["Years_bl"].diff()
    
#     dx_map = {"CN": 0, "MCI": 1, "AD": 2, "Dementia": 2}
#     df["dx_progression"] = df["DX"].map(dx_map).diff()
    
#     df["cog_decline_index"] = df["ADAS13"] - df["MMSE"]
#     df["visit_number"] = range(len(df))
    
#     df['age_mmse_interaction'] = df['AGE'] * (30 - df['MMSE']) / 30
#     df['education_cognitive_reserve'] = df['PTEDUCAT'] * df['MMSE'] / 30
#     df['rapid_decline_flag'] = (df['mmse_slope'] < -2).astype(float)
    
#     mmse_bins = [0, 20, 24, 30]
#     df['mmse_severity'] = pd.cut(df['MMSE'], bins=mmse_bins, labels=[2, 1, 0]).astype(float)
    
#     df['weighted_mmse_decline'] = df['mmse_slope'] * np.exp(-0.1 * df['time_from_baseline'])
#     df['mmse_variability'] = df['MMSE'].rolling(window=3, min_periods=1).std()
    
#     df['adas_mmse_discordance'] = np.abs(
#         (df['ADAS13'] - df['ADAS13'].mean()) / (df['ADAS13'].std() + 1e-7) - 
#         (df['MMSE'] - df['MMSE'].mean()) / (df['MMSE'].std() + 1e-7)
#     )
    
#     df = df.fillna(0)
    
#     return df

# TEMPORAL_FEATURES.extend([
#     'age_mmse_interaction', 'education_cognitive_reserve', 'rapid_decline_flag',
#     'mmse_severity', 'weighted_mmse_decline', 'mmse_variability', 'adas_mmse_discordance'
# ])

# # ============================================================================
# # MAIN: EXTRACT CLINICAL FEATURES FOR R
# # ============================================================================

# def main():
#     print("=" * 80)
#     print("BENCHMARK 1: CLINICAL COX MODEL (BASELINE)")
#     print("=" * 80)
    
#     print("\nLoading data...")
#     manifest = pd.read_csv("./AD_Multimodal/TFN_AD/AD_Patient_Manifest.csv")
    
#     manifest["path"] = manifest["path"].str.replace("\\", "/", regex=False)
#     manifest["path"] = "./AD_Multimodal/TFN_AD/" + manifest["path"]
    
#     all_rows = []
    
#     print("Processing patients...")
#     for ptid in manifest['PTID'].unique():
#         try:
#             patient_rows = manifest[manifest['PTID'] == ptid]
#             if len(patient_rows) == 0:
#                 continue
            
#             df = pd.read_pickle(patient_rows.iloc[0]["path"])
#             df = engineer_features(df)
            
#             dx_seq = df["DX"].tolist()
#             if "MCI" not in dx_seq:
#                 continue
            
#             # Compute survival info
#             mci_idx = dx_seq.index("MCI")
#             ad_idx = next((i for i, x in enumerate(dx_seq[mci_idx+1:], start=mci_idx+1) 
#                           if x in ["AD", "Dementia"]), -1)
            
#             if ad_idx != -1:
#                 time_to_event = df["Years_bl"].iloc[ad_idx]
#                 event = 1
#             else:
#                 time_to_event = df["Years_bl"].iloc[-1]
#                 event = 0
            
#             # Extract all visits (for longitudinal modeling in R)
#             for _, visit in df.iterrows():
#                 row = {
#                     "PTID": ptid,
#                     "Years_bl": visit["Years_bl"],
#                     "MMSE": visit["MMSE"],
#                     "ADAS13": visit["ADAS13"],
#                     "time_to_event": time_to_event,
#                     "event": event,
#                 }
                
#                 # Add all clinical features
#                 for feat in TEMPORAL_FEATURES + STATIC_FEATURES:
#                     if feat in df.columns:
#                         row[feat] = visit[feat]
                
#                 # Add demographics (renamed for R compatibility)
#                 row["AGE"] = visit["AGE"]
#                 row["PTGENDER"] = visit["PTGENDER_encoded"]
#                 row["PTEDUCAT"] = visit["PTEDUCAT"]
                
#                 all_rows.append(row)
            
#         except Exception as e:
#             print(f"⚠️ Skipped patient {ptid}: {e}")
#             continue
    
#     # Create DataFrame
#     df_out = pd.DataFrame(all_rows)
#     df_out = df_out.sort_values(['PTID', 'Years_bl'])
    
#     # Save
#     output_path = "baseline_clinical_features.csv"
#     df_out.to_csv(output_path, index=False)
    
#     # Summary
#     n_patients = df_out['PTID'].nunique()
#     n_visits = len(df_out)
#     n_events = df_out.groupby('PTID')['event'].first().sum()
    
#     print("\n" + "=" * 80)
#     print("✓ CLINICAL FEATURES EXTRACTED")
#     print("=" * 80)
#     print(f"\nOutput: {output_path}")
#     print(f"  - Patients: {n_patients}")
#     print(f"  - Total visits: {n_visits}")
#     print(f"  - Events: {n_events} ({100*n_events/n_patients:.1f}%)")
#     print(f"  - Features: {len([c for c in df_out.columns if c not in ['PTID', 'Years_bl', 'time_to_event', 'event']])}")
#     print("\nNext step: Run R script with this baseline file")
#     print("=" * 80)

# if __name__ == "__main__":
#     main()

BENCHMARK 1: CLINICAL COX MODEL (BASELINE)

Loading valid patient list...
Valid patients to process: 161

Processing patients...

✓ CLINICAL FEATURES EXTRACTED

Output: baseline_clinical_features.csv
  - Total valid patients: 161
  - Processed: 161
  - Skipped (not in valid set): 221
  - Final patients in output: 161
  - Total visits: 948
  - Events: 74 (46.0%)
  - Features: 24

⚠️  This uses the SAME patient cohort as all other benchmarks
