In [1]:
"""
STEP 0: MASTER PATIENT SELECTION
Creates a single list of valid patients that ALL methods will use.

This ensures a fair comparison by evaluating all methods on the SAME patients.
Patients must have:
1. An MCI diagnosis ("MCI" in the DX column) at some visit.
2. At least 2 visits whose image file (`image_path`, remapped from
   /home/mason/ADNI_Dataset/ to ./AD_Multimodal/ADNI_Dataset/) exists on disk.
3. Clinical fields AGE, PTGENDER_encoded, PTEDUCAT, MMSE and ADAS13 present,
   each with at least one non-missing value. Per-visit completeness is NOT
   required.

For each valid patient, a survival outcome is computed from the first MCI
visit. The event is the first later visit diagnosed "AD" or "Dementia";
otherwise the patient is censored at the last visit.

Outputs: VALID_PATIENTS.pkl (set of PTIDs) and VALID_PATIENTS_INFO.csv.
"""

import pandas as pd
import numpy as np
import os
import pickle

def check_patient_validity(ptid, manifest):
    """Return True if patient ``ptid`` meets every selection criterion.

    Criteria:
      1. "MCI" appears at least once in the patient's ``DX`` column.
      2. At least two visits have an ``image_path`` that exists on disk after
         remapping the ``/home/mason/ADNI_Dataset/`` root to
         ``./AD_Multimodal/ADNI_Dataset/``. Non-string entries (NaN/None) are
         treated as "no image" for that visit rather than as an error.
      3. Each of ``AGE``, ``PTGENDER_encoded``, ``PTEDUCAT``, ``MMSE`` and
         ``ADAS13`` is a column with at least one non-missing value.

    Args:
        ptid: Patient identifier matched against ``manifest['PTID']``.
        manifest: DataFrame with ``PTID`` and ``path`` columns, where ``path``
            points to a pickled per-visit DataFrame. Only the first manifest
            row for ``ptid`` is used.

    Returns:
        bool: True when all criteria hold. False when a criterion fails,
        when ``ptid`` is absent from the manifest, or when loading or
        inspecting the patient's data raises (best-effort screening).
    """
    required_fields = ['AGE', 'PTGENDER_encoded', 'PTEDUCAT', 'MMSE', 'ADAS13']

    try:
        patient_rows = manifest[manifest['PTID'] == ptid]
        if patient_rows.empty:
            return False

        # NOTE(review): read_pickle can execute arbitrary code; only load
        # trusted, locally generated patient files.
        df = pd.read_pickle(patient_rows.iloc[0]["path"])

        # Criterion 1: MCI diagnosis at some visit
        if "MCI" not in df["DX"].tolist():
            return False

        # Criterion 2: at least 2 visits whose image exists on disk
        valid_image_count = 0
        for image_path in df["image_path"]:
            # A visit without a recorded image (NaN/None) simply does not
            # count. Calling .replace on it would raise and reject the patient.
            if not isinstance(image_path, str):
                continue
            local_path = image_path.replace(
                "/home/mason/ADNI_Dataset/",
                "./AD_Multimodal/ADNI_Dataset/",
            )
            if os.path.exists(local_path):
                valid_image_count += 1
        if valid_image_count < 2:
            return False

        # Criterion 3: each required field present with >= 1 non-missing value
        return all(
            field in df.columns and not df[field].isna().all()
            for field in required_fields
        )

    except Exception:
        # Deliberate best-effort: unreadable or malformed patient files are
        # treated as invalid instead of aborting the caller.
        return False

def _count_existing_images(df):
    """Return how many visits in ``df`` have an image file present on disk.

    Each ``image_path`` is remapped from ``/home/mason/ADNI_Dataset/`` to
    ``./AD_Multimodal/ADNI_Dataset/`` before the ``os.path.exists`` check.
    Non-string entries (NaN/None, i.e. no recorded image) are skipped rather
    than raising. A single missing path therefore no longer turns an
    otherwise valid patient into an "other error".
    """
    n_existing = 0
    for image_path in df["image_path"]:
        if not isinstance(image_path, str):
            continue
        local_path = image_path.replace(
            "/home/mason/ADNI_Dataset/",
            "./AD_Multimodal/ADNI_Dataset/",
        )
        if os.path.exists(local_path):
            n_existing += 1
    return n_existing


def _survival_outcome(dx_seq, years_bl):
    """Return ``(time_to_event, event)`` measured from the first MCI visit.

    ``event`` is 1 when some visit after the first "MCI" visit is diagnosed
    "AD" or "Dementia". ``time_to_event`` is then the ``Years_bl`` difference
    between the first such visit and the first MCI visit. Otherwise
    ``event`` is 0 and ``time_to_event`` runs to the last visit (censored).

    Args:
        dx_seq: Diagnoses in row order; must contain "MCI".
        years_bl: Series of ``Years_bl`` values aligned positionally with
            ``dx_seq``.
    """
    # NOTE(review): assumes rows are in chronological order (Years_bl
    # ascending). This is not checked; unsorted rows would give wrong,
    # possibly negative, times.
    mci_idx = dx_seq.index("MCI")
    ad_idx = next(
        (i for i in range(mci_idx + 1, len(dx_seq))
         if dx_seq[i] in ("AD", "Dementia")),
        None,
    )
    if ad_idx is not None:
        return years_bl.iloc[ad_idx] - years_bl.iloc[mci_idx], 1
    return years_bl.iloc[-1] - years_bl.iloc[mci_idx], 0


def main():
    """Select the master MCI cohort and save it for all benchmark scripts.

    1. Load ``./AD_Multimodal/TFN_AD/AD_Patient_Manifest.csv`` and rewrite its
       ``path`` column as forward-slash paths under ``./AD_Multimodal/TFN_AD/``.
    2. For each distinct PTID (first manifest row only), load the pickled
       visit table and apply the criteria in order:
       - MCI diagnosis;
       - at least 2 existing images;
       - each required clinical field not entirely missing.
       The first failed criterion is tallied in ``invalid_reasons``. Any
       exception while loading or inspecting a patient is tallied as
       ``other_error``, and its message is printed in the summary.
    3. For valid patients, compute the survival outcome from the first MCI
       visit (see ``_survival_outcome``).
    4. Print a summary, then write ``VALID_PATIENTS.pkl`` (set of PTIDs) and
       ``VALID_PATIENTS_INFO.csv`` (one row per valid patient) to the current
       working directory.
    """
    required_fields = ['AGE', 'PTGENDER_encoded', 'PTEDUCAT', 'MMSE', 'ADAS13']
    info_columns = ['PTID', 'n_visits', 'n_images', 'time_to_event', 'event']

    print("=" * 80)
    print("STEP 0: MASTER PATIENT SELECTION")
    print("=" * 80)

    # Load manifest; normalise Windows separators and anchor paths locally.
    manifest = pd.read_csv("./AD_Multimodal/TFN_AD/AD_Patient_Manifest.csv")
    manifest["path"] = manifest["path"].str.replace("\\", "/", regex=False)
    manifest["path"] = "./AD_Multimodal/TFN_AD/" + manifest["path"]

    n_manifest_patients = manifest['PTID'].nunique()
    print(f"\nTotal patients in manifest: {n_manifest_patients}")

    valid_patients = []
    invalid_reasons = {
        'no_mci': 0,
        'insufficient_images': 0,
        'missing_clinical': 0,
        'other_error': 0
    }
    error_details = []

    # First manifest row per PTID, in one pass instead of re-filtering the
    # whole manifest for every patient. NaN PTIDs are dropped so the tallies
    # add up to nunique(), which also ignores NaN.
    first_rows = (
        manifest.dropna(subset=['PTID'])
        .drop_duplicates(subset='PTID', keep='first')
    )

    for ptid, path in zip(first_rows['PTID'], first_rows['path']):
        try:
            # NOTE(review): read_pickle can execute arbitrary code; only load
            # trusted, locally generated patient files.
            df = pd.read_pickle(path)

            # Criterion 1: MCI diagnosis at some visit
            dx_seq = df["DX"].tolist()
            if "MCI" not in dx_seq:
                invalid_reasons['no_mci'] += 1
                continue

            # Criterion 2: at least 2 visits whose image exists on disk
            valid_image_count = _count_existing_images(df)
            if valid_image_count < 2:
                invalid_reasons['insufficient_images'] += 1
                continue

            # Criterion 3: each required field present with >= 1 non-missing value
            has_complete_data = all(
                field in df.columns and not df[field].isna().all()
                for field in required_fields
            )
            if not has_complete_data:
                invalid_reasons['missing_clinical'] += 1
                continue

            time_to_event, event = _survival_outcome(dx_seq, df["Years_bl"])

            valid_patients.append({
                'PTID': ptid,
                'n_visits': len(df),
                'n_images': valid_image_count,
                'time_to_event': time_to_event,
                'event': event,
            })

        except Exception as e:
            # Best-effort screening: keep going, but record why it failed.
            invalid_reasons['other_error'] += 1
            error_details.append(f"{ptid}: {type(e).__name__}: {e}")

    # Explicit columns keep the summary working when no patient is valid.
    valid_df = pd.DataFrame(valid_patients, columns=info_columns)
    n_valid = len(valid_df)
    n_converters = int(valid_df['event'].sum()) if n_valid else 0

    def column_mean(column):
        # Mean of a summary column; NaN when there are no valid patients.
        return valid_df[column].mean() if n_valid else float("nan")

    # Print summary
    print("\n" + "=" * 80)
    print("PATIENT SELECTION SUMMARY")
    print("=" * 80)
    print(f"\nValid patients: {n_valid}")
    print(f"  - Converters (MCI→AD): {n_converters}")
    print(f"  - Non-converters: {n_valid - n_converters}")
    print(f"  - Conversion rate: {100*column_mean('event'):.1f}%")

    print(f"\nInvalid patients: {n_manifest_patients - n_valid}")
    print(f"  - No MCI diagnosis: {invalid_reasons['no_mci']}")
    print(f"  - Insufficient images (<2): {invalid_reasons['insufficient_images']}")
    print(f"  - Missing clinical data: {invalid_reasons['missing_clinical']}")
    print(f"  - Other errors: {invalid_reasons['other_error']}")
    for detail in error_details:
        print(f"      {detail}")

    print(f"\nMean visits per patient: {column_mean('n_visits'):.1f}")
    print(f"Mean images per patient: {column_mean('n_images'):.1f}")
    print(f"Mean follow-up time: {column_mean('time_to_event'):.2f} years")

    # Save valid patient list
    valid_patient_ids = set(valid_df['PTID'].tolist())

    with open('VALID_PATIENTS.pkl', 'wb') as f:
        pickle.dump(valid_patient_ids, f)

    valid_df.to_csv('VALID_PATIENTS_INFO.csv', index=False)

    print("\n" + "=" * 80)
    print("✓ SAVED FILES:")
    print("=" * 80)
    print("  - VALID_PATIENTS.pkl (Python set for benchmarks)")
    print("  - VALID_PATIENTS_INFO.csv (Summary table)")
    print("\n⚠️  ALL BENCHMARK SCRIPTS MUST USE THESE PATIENTS!")
    print("=" * 80)

# Entry point: run the patient selection when this file/cell is executed
# directly. In a Jupyter/IPython kernel __name__ is also "__main__", which is
# why this cell runs main() and produces the output recorded below.
if __name__ == "__main__":
    main()

STEP 0: MASTER PATIENT SELECTION

Total patients in manifest: 382

PATIENT SELECTION SUMMARY

Valid patients: 161
  - Converters (MCI→AD): 74
  - Non-converters: 87
  - Conversion rate: 46.0%

Invalid patients: 221
  - No MCI diagnosis: 221
  - Insufficient images (<2): 0
  - Missing clinical data: 0
  - Other errors: 0

Mean visits per patient: 5.9
Mean images per patient: 5.9
Mean follow-up time: 2.26 years

✓ SAVED FILES:
  - VALID_PATIENTS.pkl (Python set for benchmarks)
  - VALID_PATIENTS_INFO.csv (Summary table)

⚠️  ALL BENCHMARK SCRIPTS MUST USE THESE PATIENTS!
