In [None]:
# ADNI Diagnosis-Based Subject Sampling and Splitting Notebook

# This notebook filters and samples subjects from ADNIMERGE clinical data,
# balances across CN, MCI, AD groups, and splits into train/val/test sets.

import pandas as pd
from sklearn.model_selection import train_test_split

# ---- Step 1: Load ADNIMERGE data from GitHub ----
# NOTE: the CSV is fetched over the network each time this cell runs.
url = "https://raw.githubusercontent.com/treyschulman/cs598dlh_project/main/data/ADNIMERGE_17Apr2025.csv"
df = pd.read_csv(url)

# ---- Step 2: Filter to baseline visit only ----
# Only rows whose VISCODE is exactly "bl" are kept.
df_bl = df[df["VISCODE"] == "bl"]

# ---- Step 3: Keep necessary columns and drop missing values ----
# RID / PTID are kept as-is; rows lacking a diagnosis, age, or gender are dropped.
columns_to_keep = ["RID", "PTID", "DX_bl", "AGE", "PTGENDER"]
# The explicit .copy() makes this an independent frame, so the column
# assignment in Step 4 cannot trigger pandas' SettingWithCopyWarning (the
# frame is otherwise derived from slices of `df`).
df_bl_filtered = (
    df_bl[columns_to_keep]
    .dropna(subset=["DX_bl", "AGE", "PTGENDER"])
    .copy()
)

# ---- Step 4: Normalize diagnosis labels ----
# Early and late MCI are merged into a single "MCI" class.
df_bl_filtered["DX_bl"] = df_bl_filtered["DX_bl"].replace({"EMCI": "MCI", "LMCI": "MCI"})

# ---- Step 5: Keep only CN, MCI, AD groups ----
# Any other baseline diagnosis label is discarded. Copy again so that later
# steps work on a standalone frame rather than a boolean-mask slice.
df_bl_filtered = df_bl_filtered[df_bl_filtered["DX_bl"].isin(["CN", "MCI", "AD"])].copy()

# ---- Step 6: Sample 15 subjects per class ----
# Fail early, naming the offending classes, if any class cannot supply 15
# subjects (sampling is without replacement) or if nothing survived filtering.
class_sizes = df_bl_filtered["DX_bl"].value_counts()
if class_sizes.empty or (class_sizes < 15).any():
    raise ValueError(
        "Cannot sample 15 subjects per diagnosis class; "
        f"available baseline subjects per class: {class_sizes.to_dict()}"
    )

# Each diagnosis group is sampled explicitly and the pieces are concatenated.
# This avoids groupby(...).apply(...) operating on the grouping column, which
# recent pandas releases deprecate; once grouping columns are excluded from
# apply, "DX_bl" would be missing from the result and the stratified split
# that follows could not use it.
# The same seed (42) is reused for every group and groups are visited in
# sorted label order, so the selected rows and their order are unchanged.
sampled_df = pd.concat(
    [
        group.sample(15, random_state=42)
        for _, group in df_bl_filtered.groupby("DX_bl")
    ]
)

# ---- Step 7: Stratified split into train (70%), val (15%), test (15%) ----
# Stage one: hold out the test portion (15% of all sampled subjects).
# Stratifying on DX_bl keeps the diagnosis mix proportional in both parts.
train_val_df, test_df = train_test_split(
    sampled_df,
    stratify=sampled_df["DX_bl"],
    test_size=0.15,
    random_state=42,
)

# Stage two: carve the validation set out of the remaining 85%.
# 0.1765 is approximately 0.15 / 0.85, i.e. about 15% of the original sample.
# NOTE(review): with a small sample the resulting counts are rounded, so the
# 70/15/15 proportions are only approximate -- check the printed summary.
train_df, val_df = train_test_split(
    train_val_df,
    stratify=train_val_df["DX_bl"],
    test_size=0.1765,
    random_state=42,
)

# ---- Step 8: Tag split and combine ----
# assign() returns a new frame with the extra "Split" column appended, so the
# frames produced by the split step are not modified in place.
train_df = train_df.assign(Split="Train")
val_df = val_df.assign(Split="Validation")
test_df = test_df.assign(Split="Test")

# Stack the three tagged frames, then order rows by split name and diagnosis.
final_df = pd.concat([train_df, val_df, test_df])
final_df = final_df.sort_values(by=["Split", "DX_bl"])

# ---- Step 9: Save split files ----
# One combined file plus one file per split, written to the current working
# directory; the row index is not written out.
for csv_name, frame in (
    ("adni_subject_splits.csv", final_df),
    ("train_subjects.csv", train_df),
    ("val_subjects.csv", val_df),
    ("test_subjects.csv", test_df),
):
    frame.to_csv(csv_name, index=False)

# ---- Step 10: Display summary ----
# Number of subjects for every (split, diagnosis) combination.
print("\nFinal Split Counts:")
split_counts = final_df.groupby(["Split", "DX_bl"]).size()
print(split_counts)

# Kept as the last expression so a notebook renders the first rows as output.
final_df.head()
