In [None]:
# nacc_raw = nacc_raw.merge(
#     df_nacc_resilience[["SUBJ_ID", "TN_group"]],
#     left_on="subject_id",
#     right_on='SUBJ_ID',
#     how="left"
# )
# nacc_raw = nacc_raw.dropna(subset=['TN_group'])


## NACC co-pathology check: column groups used to label each subject
# Columns counted as "AD spectrum" by assign_pathology.
ad_cols = ["NACC_IMCI", "NACC_MCI", "NACC_AD"]

# Every other pathology column; a value of 1 here counts as non-AD pathology.
non_ad_cols = ["NACC_PD", "NACC_LBD", "NACC_VD", "NACC_PCA", "NACC_FTD_ANY"]

# Full list, AD-spectrum columns first.
path_cols = ad_cols + non_ad_cols


def assign_pathology(row):
    """Build a co-pathology label for one subject row.

    A column listed in ``ad_cols`` / ``non_ad_cols`` counts as positive when
    its value equals 1.

    Returns
    -------
    str
        "AD_SPECTRUM" - at least one AD-spectrum column positive, no other;
        "AD+<X>+<Y>"  - AD-spectrum positive plus the sorted non-AD names;
        "<X>+<Y>"     - only non-AD columns positive (sorted, "NACC_" removed);
        "None"        - nothing positive.
    """
    has_ad = False
    for col in ad_cols:
        if row[col] == 1:
            has_ad = True
            break

    others = sorted(
        col.replace("NACC_", "") for col in non_ad_cols if row[col] == 1
    )

    # No non-AD pathology: either pure AD spectrum or nothing at all.
    if not others:
        return "AD_SPECTRUM" if has_ad else "None"

    joined = "+".join(others)
    return "AD+" + joined if has_ad else joined


# Label every row of nacc_raw with its co-pathology string (row-wise apply).
# NOTE(review): in the visible notebook nacc_raw is only read in a later cell,
# so this cell fails on a fresh kernel unless that cell has been run first.
nacc_raw["PATH_LABEL"] = nacc_raw.apply(assign_pathology, axis=1)

# Frequency of each label.
print(nacc_raw["PATH_LABEL"].value_counts())


def collapse_train_group(path_label):
    """Collapse a detailed PATH_LABEL string into a coarse training group.

    "AD_SPECTRUM" -> "AD_PURE"; any label starting with "AD+" -> "AD_MIXED";
    "None" stays "None"; every other label -> "OTHER".
    """
    for exact_label, group in (("AD_SPECTRUM", "AD_PURE"), ("None", "None")):
        if path_label == exact_label:
            return group

    return "AD_MIXED" if path_label.startswith("AD+") else "OTHER"


# Map each detailed PATH_LABEL to its coarse group (AD_PURE / AD_MIXED / None / OTHER).
nacc_raw["TRAIN_GROUP"] = nacc_raw["PATH_LABEL"].apply(collapse_train_group)

# Frequency of each coarse group.
print(nacc_raw["TRAIN_GROUP"].value_counts())



In [1]:
import pandas as pd
import numpy as np
from data_processor import *
from lda_model import LDATopicModel
from classifier import TopicClassifier
from visualizer import *
from brain_visualizer import *

# Root folder of the project data; every relative file below is read from here.
data_path = 'C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/data'

# Training table; DX is the diagnosis label used throughout the notebook.
inp_df = pd.read_csv(data_path + '/260120_wsev_smc_combined_zscores.csv')
# inp_df = inp_df[~inp_df["DX"].isin(['svPPA', 'nfvPPA'])] #######################
print(inp_df['DX'].value_counts())


# Resilience / group tables used when interpreting the inference results.
# The NACC table used to be read twice: first from
# data_path + '/nacc/NACC_resilience_inference.csv' and then immediately
# overwritten by the linear-group file below. The first read was never used,
# so it has been dropped.
# NOTE(review): this absolute path sits under a different user folder than
# data_path — confirm it exists on the machine running the notebook.
df_nacc_resilience = pd.read_csv('C:/Users/BREIN/Desktop/stage_copath/20260122_NACC_linear_group.csv')
df_adni4_resilience = pd.read_csv(data_path + '/adni/ADNI4_resilience_inference.csv')

# Give both tables a "SUBJ_ID" column name where the source column exists
# (rename silently ignores a missing source column).
df_adni4_resilience = df_adni4_resilience.rename(columns={"FULL_ID": "SUBJ_ID"})
df_nacc_resilience = df_nacc_resilience.rename(columns={"subject_id" : "SUBJ_ID"})

# Raw NACC table (region columns, DX and NACC_* pathology columns).
nacc_raw = pd.read_csv(data_path + '/nacc/260120_NACC_VA_TAU_PATH_matched.csv')

ModuleNotFoundError: No module named 'xgboost'

In [None]:
## SAMPLE MCI FROM NACC AND ADD TO TRAIN DATA ##
# Candidates: NACC rows with NACC_MCI == 1 and DX == 'MCI' whose subject_id
# does not appear in the resilience table.
mci_candidates = nacc_raw[
    (nacc_raw["NACC_MCI"] == 1) &  # MCI positive
    (nacc_raw['DX'] == 'MCI') &
    (~nacc_raw["subject_id"].isin(df_nacc_resilience["FULL_ID"]))  # not already in resilience
]
print(f"Total MCI candidates outside resilience: {len(mci_candidates)}")
N = 25  # number you want to add
# Fixed seed: the same subjects are drawn on every run.
mci_sample = mci_candidates.sample(n=min(N, len(mci_candidates)), random_state=42)

# Region columns are the contiguous block from 'VA/2' to 'VA/2035'.
region_cols = nacc_raw.loc[:, 'VA/2':'VA/2035'].columns
pathology_cols = nacc_raw.loc[:, 'NACC_AD':'NACC_svPPA'].columns
mci_prep = DataProcessor(
    region_cols=region_cols,
    dx_col='DX',
    subject_col='subject_id'
)

# The baseline is fitted on NACC rows with DX == 'CN'.
nacc_filtered = nacc_raw[nacc_raw['DX'] != 'Unknown']
nacc_cn = nacc_filtered[nacc_filtered['DX'] == 'CN']

mci_prep.fit_baseline(hc_data=nacc_cn)
nacc_Z = mci_prep.compute_atrophy_scores(data=mci_sample)
print(type(nacc_Z))
print(nacc_Z.shape)

# Wrap the scores in a DataFrame and attach subject id and diagnosis.
nacc_Z_df = pd.DataFrame(nacc_Z, columns=region_cols)
nacc_Z_df['SUBJ_ID'] = mci_sample['subject_id'].values
nacc_Z_df['DX'] = mci_sample['DX'].values

# Re-running this cell used to append the same sampled subjects again.
# Drop any rows already carrying these SUBJ_IDs first so the cell is
# idempotent (a no-op on the first run when none of the ids are present).
if "SUBJ_ID" in inp_df.columns:
    inp_df = inp_df[~inp_df["SUBJ_ID"].isin(nacc_Z_df["SUBJ_ID"])]
inp_df = pd.concat([inp_df, nacc_Z_df], ignore_index=True)
print(f"New training dataset shape: {inp_df.shape}")

print(inp_df['DX'].value_counts())

## DOWNSAMPLE LARGE DX
# Cap every diagnosis group at N rows.
N = 25
dx_col = "DX"

# Groups larger than N are randomly reduced (fixed seed, no replacement);
# smaller groups pass through untouched.
balanced_parts = [
    grp.sample(n=N, replace=False, random_state=42) if len(grp) > N else grp
    for _, grp in inp_df.groupby(dx_col)
]

balanced_df = pd.concat(balanced_parts).reset_index(drop=True)

# Class sizes after balancing.
print(balanced_df[dx_col].value_counts())

In [None]:
# Fit the topic model and the topic-based classifier on the balanced training set.
# The commented-out `for` headers are earlier sweeps over the number of topics;
# the cell currently runs for a single value.
# for n in list(range(6, 25, 2)):
n=18
# for n in [10,12,14,16,18,20,22]:
print('k_topics = ', n)
N_TOPICS = n ###
# Region columns are the contiguous block from "VA/2" to "VA/2035".
region_cols = list(balanced_df.loc[:, "VA/2":"VA/2035"].columns)
labels = balanced_df["DX"].values
ids = balanced_df["SUBJ_ID"].values

# Fit LDA on combined z-scores
lda = LDATopicModel(n_topics=N_TOPICS)
# theta: output of fit_transform on the region columns
# (presumably one row of topic weights per subject — TODO confirm).
theta = lda.fit_transform(balanced_df[region_cols])

# Fit classifier
classifier = TopicClassifier(n_splits=5) ##
# Cross-validate first (subject ids are passed along with the labels), then
# refit on all of theta; the refitted classifier is what later cells use for prediction.
cv_results = classifier.cross_validate(theta, labels, ids, verbose=False)
classifier.fit(theta, labels)

print(f"K_topics {n}, CV Accuracy: {cv_results['accuracy']:.4f}")

# Optional figures (currently disabled).
# visualizer = CopathologyVisualizer(
#     output_dir=f'./train_mci_added/topics_{N_TOPICS}_downsampled'
# )

# fig_conf_mat = visualizer.plot_confusion_matrix(
#     cm=classifier.get_confusion_matrix(),
#     class_names=classifier._classes
# )

# fig_top_regions = visualizer.plot_top_regions_per_topic(
#     topic_patterns = lda.get_topic_patterns(),
#     region_names=region_cols
# )

# fig3 = visualizer.plot_diagnosis_topic_profiles(
#     theta=lda._theta,
#     dx_labels = labels
# )

In [None]:
# Shape of the array stored on the fitted LDA wrapper's private `_theta`
# attribute (presumably subjects x topics — TODO confirm).
print(lda._theta.shape)

In [None]:
## Surface Mapping
## Topicwise Surface Maps
# `os` is needed for makedirs below but is not imported anywhere in the visible
# notebook (it presumably leaked in through a star import), so import it here.
import os

from atlas_vis import DKTAtlas62ROIPlotter

plotter_62 = DKTAtlas62ROIPlotter(
    cmap='Reds',
    clim=(0, 0.1),
    window_size=(1200, 1000),
    nan_color='lightgray',
    background='white',
    template_key='pial'
)

# One output folder per topic count; built once and reused for every figure.
topicwise_dir = f'./train_mci_added/topics_{n}_downsampled/topicwise'
os.makedirs(topicwise_dir, exist_ok=True)

# Rows = regions, columns = topics.
topic_df = pd.DataFrame(
    lda.get_topic_patterns().T,
    index=region_cols,
    columns=[f"Topic_{k}" for k in range(n)]
)

# Keep the last 62 region rows and renumber them 0..61
# (presumably the 62 cortical ROIs the plotter expects — TODO confirm).
df = topic_df.tail(62).reset_index(drop=True)

print(len(df))
for col in df.columns:
    print(col)
    # .loc slicing is label-inclusive: rows 0..30 and rows 31..61
    # (presumably left then right hemisphere — TODO confirm ordering).
    l_values = df.loc[:30, col].to_list()
    r_values = df.loc[31:, col].to_list()
    print(len(l_values))
    print(len(r_values))
    print(np.min(l_values + r_values))
    print(np.max(l_values + r_values))

    plotter_62(l_values, r_values, save_path=f'{topicwise_dir}/{col}.png')


In [None]:
## NACC Inference ## 260120
# Apply the fitted LDA model and classifier to NACC rows whose DX is neither
# 'Unknown' nor 'CN', then attach group and pathology columns to the results.
from data_processor import *
# NOTE(review): hard-coded absolute path under a different user folder than
# `data_path`; the file name matches the NACC table read in the loading cell — confirm.
nacc_raw = pd.read_csv('C:/Users/BREIN/Desktop/stage_copath/data/nacc/260120_NACC_VA_TAU_PATH_matched.csv')
# nacc_raw.rename(columns=col_map)
region_cols = nacc_raw.loc[:, 'VA/2':'VA/2035'].columns
pathology_cols = nacc_raw.loc[:, 'NACC_AD':'NACC_svPPA'].columns
nacc_filtered = nacc_raw[nacc_raw['DX'] != 'Unknown']

# CN rows form the baseline; every other known DX is scored.
nacc_cn = nacc_filtered[nacc_filtered['DX'] == 'CN']
nacc_pat = nacc_filtered[nacc_filtered['DX'] != 'CN']
# nacc_pat = nacc_filtered

print(nacc_cn.shape)
print(nacc_pat.shape)

nacc_prep = DataProcessor(
    region_cols=region_cols,
    dx_col='DX',
    subject_col='subject_id'
)
nacc_prep.fit_baseline(hc_data=nacc_cn)
nacc_Z = nacc_prep.compute_atrophy_scores(data=nacc_pat)
print(type(nacc_Z))

# Project onto the already-fitted topics, then predict class and class probabilities.
nacc_theta = lda.transform(nacc_Z)
y_pred = classifier.predict(nacc_theta)
y_proba = classifier.predict_proba(nacc_theta)
# print(nacc_theta.shape)
# print(y_pred.shape)
# print(y_proba.shape)

# One row per scored subject: topic weights first.
nacc_results = pd.DataFrame(nacc_theta, columns=[f"Topic_{k}" for k in range(lda.n_topics)])
print(nacc_results.shape)

# Subject ids are attached by position, so this relies on nacc_theta rows
# being in the same order as nacc_pat.
subj_col = nacc_prep.subject_col
if subj_col in nacc_pat.columns:
    nacc_results.insert(0, "SUBJ_ID", nacc_pat[subj_col].values)

# Predicted class plus one "P(<class>)" column per classifier class.
nacc_results['pred_DX'] = y_pred
for i, dx in enumerate(classifier.classes):
    nacc_results[f"P({dx})"] = y_proba[:,i]

# nacc_results = nacc_results.merge(
#     df_nacc_resilience[["SUBJ_ID", "TN_group", "standardized_residual"]],
#     on="SUBJ_ID",
#     how="left"
# )
# nacc_results = nacc_results.dropna(subset=['TN_group'])

# Attach group label and residual; subjects without a group are dropped right after.
# NOTE(review): a left merge repeats rows if FULL_ID is not unique in the
# resilience table — confirm uniqueness.
nacc_results = nacc_results.merge(
    df_nacc_resilience[["FULL_ID", "linear_group", "standardized_residual"]],
    left_on="SUBJ_ID",
    right_on='FULL_ID',
    how="left"
)
nacc_results = nacc_results.dropna(subset=['linear_group'])

# Bring back the clinical DX and selected NACC_* pathology columns.
# NOTE(review): same caveat — rows multiply if subject_id repeats in nacc_raw.
nacc_results = nacc_results.merge(
    nacc_raw[["subject_id", "DX", 'NACC_AD', 'NACC_PD', 'NACC_VD', 'NACC_LBD', 'NACC_SVAD', 'NACC_PCA', 'NACC_bvFTD']],
    left_on="SUBJ_ID",
    right_on="subject_id",
    how="left"
)

# subject_id duplicates SUBJ_ID after the merge.
nacc_results = nacc_results.drop(columns=["subject_id"])

In [None]:
## TN group margin change: re-derive groups from a +/-1 residual cut-off ##
resid = nacc_results['standardized_residual']

# Residual above 1 -> Vulnerable, below -1 -> Resilient.
choices = ['Vulnerable', 'Resilient']
conditions = [resid.gt(1), resid.lt(-1)]

# Rows matching neither condition (including missing residuals) fall back to 'Canonical'.
nacc_results['TN_group_1'] = np.select(conditions, choices, default='Canonical')

# Quick check
print(nacc_results[['standardized_residual', 'TN_group_1']].head())


In [None]:
# Probability columns used by the NACC plots below, in display order.
prob_cols = [
    f"P({dx_name})"
    for dx_name in ("AD", "MCI", "PD", "DLB", "SVAD", "bvFTD", "nfvPPA", "svPPA")
]

# Column that splits subjects into groups for the NACC plots.
# Alternatives used previously: 'TN_group', 'TN_group_1', 'TN_group_15'.
group_col = 'linear_group'

In [None]:
## Group-wise proportion of NACC co-pathology positivity
import pandas as pd
import matplotlib.pyplot as plt

# Binary (0/1) NACC pathology columns to summarise per group.
status_cols = ['NACC_AD', 'NACC_PD', 'NACC_VD', 'NACC_LBD', 'NACC_PCA']


def _positivity_by_group(status_col):
    """Per-group share of positive and negative rows for one 0/1 status column."""
    positive = nacc_results.groupby(group_col)[status_col].mean()
    return pd.DataFrame({
        'Group': positive.index,
        'Positive': positive.values,
        'Negative': (1 - positive).values,
        'Condition': status_col
    })


# Long table: one row per (condition, group).
prop_nacc_results_list = [_positivity_by_group(status) for status in status_cols]
plot_nacc_results = pd.concat(prop_nacc_results_list)

# Grid of stacked bars, one panel per condition.
conditions = plot_nacc_results['Condition'].unique()
n_cols = 2
n_rows = (len(conditions) + 1) // n_cols

fig, axes = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 5*n_rows))
axes = axes.flatten()

colors = {'Positive':'blue', 'Negative':'red'}
stack_order = ['Negative', 'Positive']

for ax, cond in zip(axes, conditions):
    is_cond = plot_nacc_results['Condition'] == cond
    nacc_results_cond = plot_nacc_results[is_cond].set_index('Group')
    nacc_results_cond[stack_order].plot.bar(
        stacked=True,
        ax=ax,
        color=[colors[status] for status in stack_order],
        legend=False
    )
    ax.set(ylabel='Proportion', title=cond, ylim=(0, 1))

# One legend for the whole figure instead of one per panel.
handles = [plt.Rectangle((0,0),1,1,color=shade) for shade in colors.values()]
fig.legend(handles, colors.keys(), loc='upper right', title='Status')

# Drop grid cells that have no condition to show.
for ax in axes[len(conditions):]:
    ax.remove()

fig.suptitle('NACC Pathology Positivity Proportions')
fig.tight_layout()
plt.show()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Mean predicted probability of every class within each group.
group_means = nacc_results.groupby(group_col)[prob_cols].mean()

# Annotated heatmap: rows = groups, columns = predicted classes.
fig_means, ax_means = plt.subplots(figsize=(8, 5))
sns.heatmap(
    group_means,
    ax=ax_means,
    cmap="Reds",
    vmin=0,
    vmax=0.4,
    annot=True,
    fmt=".2f",
    linewidths=0.5,
    cbar_kws={"label": "Mean predicted probability"}
)
ax_means.set_xlabel("Predicted pathology")
ax_means.set_ylabel("Subgroup")
ax_means.set_title("NACC Group-wise Mean Predicted Probability Distribution")

fig_means.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# prob_cols = nacc_results.loc[:, 'P(AD)':'P(svPPA)'].columns
# prob_cols = nacc_results.loc[:, 'P(AD)':'P(bvFTD)'].columns

# ------------------------------------------------------------
# Sort: group first, then descending P(AD)
# ------------------------------------------------------------
nacc_sorted = (
    nacc_results
    .sort_values([group_col, "P(AD)"], ascending=[True, False])
    .reset_index(drop=True)
)

heatmap_data = nacc_sorted[prob_cols]

# ------------------------------------------------------------
# Compute group positions for y-axis labels
# ------------------------------------------------------------
group_counts = nacc_sorted[group_col].value_counts(sort=False)

group_centers = {}
start = 0

for grp, count in group_counts.items():
    center = start + count / 2
    group_centers[grp] = center
    start += count

# ------------------------------------------------------------
# Plot
# ------------------------------------------------------------
plt.figure(figsize=(10, 10))

ax = sns.heatmap(
    heatmap_data,
    cmap="Reds",
    vmin=0,
    vmax=1,
    yticklabels=False,
    cbar_kws={"label": "Predicted probability"}
)

# ------------------------------------------------------------
# Horizontal lines between groups
# ------------------------------------------------------------
cum_sizes = np.cumsum(group_counts.values)

for y in cum_sizes[:-1]:
    ax.hlines(y, *ax.get_xlim(), colors="black", linewidth=1.5)

# ------------------------------------------------------------
# TN subgroup labels on y-axis
# ------------------------------------------------------------
ax.set_yticks(list(group_centers.values()))
ax.set_yticklabels(list(group_centers.keys()), rotation=0, fontsize=11)

# ------------------------------------------------------------
# Labels
# ------------------------------------------------------------
ax.set_xlabel("Predicted pathology")
ax.set_ylabel("TN subgroup")
ax.set_title("Subject-level Predicted Probability Heatmap\n(sorted by descending P(AD))")

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

topic_cols = [c for c in nacc_results.columns if c.startswith("Topic_")]

groups = nacc_results[group_col].unique()
n_groups = len(groups)
n_topics = len(topic_cols)

# ------------------------------------------------------------
# Global max for shared axis
# ------------------------------------------------------------
global_max = (
    nacc_results
    .groupby(group_col)[topic_cols]
    .mean()
    .values
    .max()
)

# ------------------------------------------------------------
# Radar setup
# ------------------------------------------------------------
angles = np.linspace(0, 2 * np.pi, n_topics, endpoint=False)
angles = np.concatenate([angles, [angles[0]]])

fig, axes = plt.subplots(
    1, n_groups,
    figsize=(4 * n_groups, 4),
    subplot_kw=dict(polar=True)
)

if n_groups == 1:
    axes = [axes]

# ------------------------------------------------------------
# Plot
# ------------------------------------------------------------
for ax, grp in zip(axes, groups):

    grp_df = nacc_results[nacc_results[group_col] == grp]
    mean_topics = grp_df[topic_cols].mean().values
    mean_topics = np.concatenate([mean_topics, [mean_topics[0]]])

    ax.plot(angles, mean_topics, linewidth=2)
    ax.fill(angles, mean_topics, alpha=0.25)

    ax.set_title(grp, pad=20)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(topic_cols, fontsize=9)

    ax.set_ylim(0, global_max * 1.1)   # âœ… shared scale
    ax.set_yticklabels([])

plt.suptitle("TN Subgroup Topic Weight Profiles (shared radial scale)", fontsize=14)
plt.tight_layout()
plt.show()


In [None]:
## NACC vulnerable check
# Diagnosis counts inside each group; after the loop r2 holds the Canonical subset.
for heading, group_name in (
    ('VULN', 'Vulnerable'),
    ('\nRESIL', 'Resilient'),
    ('\nCANON', 'Canonical'),
):
    print(heading)
    r2 = nacc_results[nacc_results[group_col] == group_name]
    print(r2['DX'].value_counts())

In [None]:
## Correlation Subplots ##
# Scatter of each predicted probability against the standardized residual,
# coloured by group, with a linear fit and Pearson r / p annotation.
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr  # or spearmanr if you prefer
import numpy as np

# cols_to_corr: probability columns to correlate
# target_col: column to correlate against
cols_to_corr = prob_cols
target_col = 'standardized_residual'

# -----------------------------
# Plotting setup
# -----------------------------
n_cols = 3  # how many subplots per row
n_rows = int(np.ceil(len(cols_to_corr) / n_cols))

# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() always works.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows), squeeze=False)
axes = axes.flatten()
unique_groups = nacc_results[group_col].unique()
palette = sns.color_palette("tab10", n_colors=len(unique_groups))
group_palette = dict(zip(unique_groups, palette))

# Legend entries are taken from the first plotted panel. The previous code read
# them from `ax` after the "remove empty axes" loop, i.e. from a removed empty
# panel, which produced an empty figure legend. It also stored them in
# `labels`, overwriting the training labels defined earlier in the notebook.
legend_handles, legend_labels = [], []

for ax, col in zip(axes, cols_to_corr):

    x = nacc_results[col]
    y = nacc_results[target_col]

    # Correlate complete pairs only; NaNs would otherwise propagate into r and p
    # (regplot below already ignores incomplete pairs).
    complete = x.notna() & y.notna()
    r, p = pearsonr(x[complete], y[complete])

    # Scatter plot
    sns.scatterplot(
        x=x, y=y, hue=nacc_results[group_col], palette=group_palette, ax=ax, s=60, alpha=0.8
    )

    # Fit line
    sns.regplot(x=x, y=y, ax=ax, scatter=False, color='red', ci=None)

    # Annotate r and p
    ax.text(0.05, 0.95, f"r={r:.2f}\np={p:.3f}",
            transform=ax.transAxes,
            verticalalignment='top',
            bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7))

    ax.set_xlabel(col)
    ax.set_ylabel(target_col)
    ax.set_title(f"{col} vs {target_col}")

    # Keep one shared legend: remember the entries once, drop the per-panel copy.
    if not legend_handles:
        legend_handles, legend_labels = ax.get_legend_handles_labels()
    panel_legend = ax.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

# Remove empty axes if any
for ax in axes[len(cols_to_corr):]:
    ax.remove()

fig.legend(legend_handles, legend_labels, loc='upper right', title=group_col, bbox_to_anchor=(1.05, 1))
plt.suptitle('NACC')
plt.tight_layout()
plt.show()

**ADNI4 INFERENCE**

In [None]:
## ADNI4 Inference ## 260120
# Load ADNI4, split CN / non-CN rows and build a table of true vs predicted tau stage.
# NOTE(review): hard-coded absolute path — confirm it exists on the machine running this.
adni4_raw = pd.read_csv('C:/Users/BREIN/Desktop/copathology_visualization_temp/data/stage_data/ptau_volume_model/ADNI4_3.csv')
region_cols = adni4_raw.loc[:, 'VA/2':'VA/2035'].columns
# Rows with a missing value in any region column are dropped.
adni4_raw = adni4_raw.dropna(subset=region_cols)
adni4_cn = adni4_raw[adni4_raw['DX'] == 'CN']
adni4_pat = adni4_raw[adni4_raw['DX'] != 'CN']
# Stage table for non-CN rows; rows missing any of the selected columns are dropped.
adni4_stage_df = adni4_pat[['FULL_ID', 'DX', 'tau_stage_aa/low', 'tau_stage_aa/mid', 'tau_stage_aa/high', 'pred_tau_stage_aa/low', 'pred_tau_stage_aa/mid', 'pred_tau_stage_aa/high']].dropna()
# NOTE: prob_cols names the predicted tau-stage columns here; a later cell
# reassigns it to the classifier "P(...)" columns.
prob_cols = ['pred_tau_stage_aa/low','pred_tau_stage_aa/mid','pred_tau_stage_aa/high']
# One-hot the predictions: 1 in the column holding the row-wise maximum, 0 elsewhere.
max_col = adni4_stage_df[prob_cols].idxmax(axis=1)
adni4_stage_df[prob_cols] = 0
adni4_stage_df.loc[:, prob_cols] = (pd.get_dummies(max_col).reindex(columns=prob_cols, fill_value=0).astype(float))
# Ordinal code for each tau-stage suffix.
stage_map = dict(low=0, mid=1, high=2)


def get_stage(colname):
    """Return the ordinal stage encoded in the last '/'-separated part of a column name.

    Raises KeyError when that part is not a key of ``stage_map``.
    """
    suffix = colname.rpartition('/')[2]
    return stage_map[suffix]
# Stage = ordinal code of the column holding the row-wise maximum.
gt_stage_cols = ['tau_stage_aa/low', 'tau_stage_aa/mid', 'tau_stage_aa/high']
gt_winner = adni4_stage_df[gt_stage_cols].idxmax(axis=1)
adni4_stage_df['gt_stage'] = gt_winner.apply(get_stage)

pred_winner = adni4_stage_df[prob_cols].idxmax(axis=1)
adni4_stage_df['pred_stage'] = pred_winner.apply(get_stage)

# --------------------------------
# Subject grouping
# --------------------------------
stage_summary_cols = ['FULL_ID', 'DX', 'gt_stage', 'pred_stage']

# Measured stage lower than the predicted one.
is_lower_than_pred = adni4_stage_df['gt_stage'] < adni4_stage_df['pred_stage']
adni4_lower_than_pred = adni4_stage_df.loc[is_lower_than_pred, stage_summary_cols]

# Measured stage equal to the predicted one.
is_exact_match = adni4_stage_df['gt_stage'] == adni4_stage_df['pred_stage']
adni4_exact_match = adni4_stage_df.loc[is_exact_match, stage_summary_cols]

# Score ADNI4 non-CN rows against a baseline fitted on ADNI4 CN rows.
adni4_prep = DataProcessor(
    region_cols=region_cols,
    dx_col='DX',
    subject_col='FULL_ID'
)
adni4_prep.fit_baseline(hc_data=adni4_cn)
adni4_Z = adni4_prep.compute_atrophy_scores(data=adni4_pat)
print(adni4_Z.shape)
print(adni4_cn.shape)

# Project onto the already-fitted topics, then predict class and class probabilities.
adni4_theta = lda.transform(adni4_Z)
adni4_y_pred = classifier.predict(adni4_theta)
adni4_y_proba = classifier.predict_proba(adni4_theta)

# One row per scored subject: topic weights first.
adni4_results = pd.DataFrame(adni4_theta, columns=[f"Topic_{k}" for k in range(lda.n_topics)])
print(adni4_results.shape)

# Subject ids are attached by position, so this relies on adni4_theta rows
# being in the same order as adni4_pat.
subj_col = adni4_prep.subject_col
if subj_col in adni4_pat.columns:
    adni4_results.insert(0, "SUBJ_ID", adni4_pat[subj_col].values)

# Predicted class plus one "P(<class>)" column per classifier class.
adni4_results['pred_DX'] = adni4_y_pred
for i, dx in enumerate(classifier.classes):
    adni4_results[f"P({dx})"] = adni4_y_proba[:,i]

# Bring back the clinical DX.
# NOTE(review): a left merge repeats rows if FULL_ID is not unique in adni4_raw — confirm.
adni4_results = adni4_results.merge(
    adni4_raw[["FULL_ID", "DX"]],
    left_on="SUBJ_ID",
    right_on="FULL_ID",
    how="left"
)
adni4_results = adni4_results.drop(columns=["FULL_ID"])

# Rewrite "_M" as "_m" in the ids before joining on SUBJ_ID
# (presumably to match the id format of the resilience table — TODO confirm).
adni4_results["SUBJ_ID"] = adni4_results["SUBJ_ID"].str.replace("_M", "_m", regex=False)

# Attach TN group and residual; subjects without a group are dropped.
adni4_results = adni4_results.merge(
    df_adni4_resilience[["SUBJ_ID", "TN_group", "standardized_residual"]],
    on="SUBJ_ID",
    how="left"
)
adni4_results = adni4_results.dropna(subset=['TN_group'])
print('!!!!', adni4_results.shape)

In [None]:
## TN group margin change: re-derive groups from a +/-1 residual cut-off ##
resid = adni4_results['standardized_residual']

# Residual above 1.0 -> Vulnerable, below -1.0 -> Resilient.
choices = ['Vulnerable', 'Resilient']
conditions = [resid.gt(1.0), resid.lt(-1.0)]

# Rows matching neither condition (including missing residuals) fall back to 'Canonical'.
adni4_results['TN_group_1'] = np.select(conditions, choices, default='Canonical')

# Quick check
print(adni4_results[['standardized_residual', 'TN_group_1']].head())


In [None]:
# Probability columns used by the ADNI4 plots below, in display order.
# (Resets prob_cols, which the ADNI4 loading cell pointed at the tau-stage columns.)
prob_cols = [
    f"P({dx_name})"
    for dx_name in ("AD", "MCI", "PD", "DLB", "SVAD", "bvFTD", "nfvPPA", "svPPA")
]

# Column that splits subjects into groups for the ADNI4 plots.
# Alternative used previously: 'TN_group_1'.
group_col = 'TN_group'

In [None]:
## Groupwise
# Mean predicted probability of every class within each group.
group_means = adni4_results.groupby(group_col)[prob_cols].mean()

# Annotated heatmap: rows = groups, columns = predicted classes.
fig_means, ax_means = plt.subplots(figsize=(8, 5))
sns.heatmap(
    group_means,
    ax=ax_means,
    cmap="Reds",
    vmin=0,
    vmax=0.4,
    annot=True,
    fmt=".2f",
    linewidths=0.5,
    cbar_kws={"label": "Mean predicted probability"}
)
ax_means.set_xlabel("Predicted pathology")
ax_means.set_ylabel("Subgroup")
ax_means.set_title("ADNI4 Group-wise Mean Predicted Probability Distribution")

fig_means.tight_layout()
plt.show()


In [None]:
## Subject-wise
# Order rows by group, and by descending P(AD) inside each group.
adni4_sorted = adni4_results.sort_values(
    by=[group_col, "P(AD)"], ascending=[True, False]
).reset_index(drop=True)

heatmap_data = adni4_sorted[prob_cols]

# Row extent of every group: cumulative end position and vertical midpoint.
group_counts = adni4_sorted[group_col].value_counts(sort=False)
cum_sizes = np.cumsum(group_counts.values)
group_centers = dict(zip(group_counts.index, cum_sizes - group_counts.values / 2))

# One heatmap row per subject, one column per predicted class.
_, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    heatmap_data,
    ax=ax,
    cmap="Reds",
    vmin=0,
    vmax=1,
    yticklabels=False,
    cbar_kws={"label": "Predicted probability"}
)

# Separator after every group except the last.
for boundary in cum_sizes[:-1]:
    ax.hlines(boundary, *ax.get_xlim(), colors="black", linewidth=1.5)

# Group names at the vertical centre of their band.
ax.set_yticks(list(group_centers.values()))
ax.set_yticklabels(list(group_centers.keys()), rotation=0, fontsize=11)

ax.set(
    xlabel="Predicted pathology",
    ylabel="TN subgroup",
    title="ADNI4 Subject-level Predicted Probability Heatmap\n(sorted by descending P(AD))"
)

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Topic-weight columns and the groups to draw (one polar panel per group).
topic_cols = [name for name in adni4_results.columns if name.startswith("Topic_")]
groups = adni4_results[group_col].unique()
n_groups = len(groups)
n_topics = len(topic_cols)

# Largest group-mean topic weight; used as a common radial limit.
global_max = adni4_results.groupby(group_col)[topic_cols].mean().values.max()

# One spoke per topic; the first angle is repeated to close the polygon.
angles = np.linspace(0, 2 * np.pi, n_topics, endpoint=False)
angles = np.append(angles, angles[0])

fig, axes = plt.subplots(
    1, n_groups,
    figsize=(4 * n_groups, 4),
    subplot_kw=dict(polar=True)
)

# plt.subplots returns a bare Axes when there is a single panel.
if n_groups == 1:
    axes = [axes]

for ax, grp in zip(axes, groups):
    in_group = adni4_results[group_col] == grp
    mean_topics = adni4_results[in_group][topic_cols].mean().values
    mean_topics = np.append(mean_topics, mean_topics[0])

    ax.plot(angles, mean_topics, linewidth=2)
    ax.fill(angles, mean_topics, alpha=0.25)

    ax.set_title(grp, pad=20)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(topic_cols, fontsize=9)

    ax.set_ylim(0, global_max * 1.1)   # shared scale across panels
    ax.set_yticklabels([])

fig.suptitle("ADNI4 TN Subgroup Topic Weight Profiles (shared radial scale)", fontsize=14)
fig.tight_layout()
plt.show()


In [None]:
## Correlation Subplots ##
# Scatter of each predicted probability against the standardized residual,
# coloured by group, with a linear fit and Pearson r / p annotation.
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr  # or spearmanr if you prefer
import numpy as np

# cols_to_corr: probability columns to correlate
# target_col: column to correlate against
cols_to_corr = prob_cols
target_col = 'standardized_residual'

# -----------------------------
# Plotting setup
# -----------------------------
n_cols = 3  # how many subplots per row
n_rows = int(np.ceil(len(cols_to_corr) / n_cols))

# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() always works.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows), squeeze=False)
axes = axes.flatten()
unique_groups = adni4_results[group_col].unique()
palette = sns.color_palette("tab10", n_colors=len(unique_groups))
group_palette = dict(zip(unique_groups, palette))

# Legend entries are taken from the first plotted panel. The previous code read
# them from `ax` after the "remove empty axes" loop, i.e. from a removed empty
# panel, which produced an empty figure legend. It also stored them in
# `labels`, overwriting the training labels defined earlier in the notebook.
legend_handles, legend_labels = [], []

for ax, col in zip(axes, cols_to_corr):

    x = adni4_results[col]
    y = adni4_results[target_col]

    # Correlate complete pairs only; NaNs would otherwise propagate into r and p
    # (regplot below already ignores incomplete pairs).
    complete = x.notna() & y.notna()
    r, p = pearsonr(x[complete], y[complete])

    # Scatter plot
    sns.scatterplot(
        x=x, y=y, hue=adni4_results[group_col], palette=group_palette, ax=ax, s=60, alpha=0.8
    )

    # Fit line
    sns.regplot(x=x, y=y, ax=ax, scatter=False, color='red', ci=None)

    # Annotate r and p
    ax.text(0.05, 0.95, f"r={r:.2f}\np={p:.3f}",
            transform=ax.transAxes,
            verticalalignment='top',
            bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7))

    ax.set_xlabel(col)
    ax.set_ylabel(target_col)
    ax.set_title(f"{col} vs {target_col}")

    # Keep one shared legend: remember the entries once, drop the per-panel copy.
    if not legend_handles:
        legend_handles, legend_labels = ax.get_legend_handles_labels()
    panel_legend = ax.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

# Remove empty axes if any
for ax in axes[len(cols_to_corr):]:
    ax.remove()

fig.legend(legend_handles, legend_labels, loc='upper right', title=group_col, bbox_to_anchor=(1.05, 1))
plt.suptitle('ADNI4')
plt.tight_layout()
plt.show()

In [None]:
# Subjects whose ground-truth stage is 0 (the 'low' code in stage_map).
# .copy() makes real_low an independent frame, so the FULL_ID rewrite below
# no longer assigns into a filtered slice of adni4_stage_df
# (which triggers pandas' SettingWithCopyWarning).
real_low = adni4_stage_df[adni4_stage_df['gt_stage']==0].copy()
print(real_low['pred_stage'].value_counts())
# Rewrite "_M" as "_m" in the ids, mirroring the change applied to adni4_results["SUBJ_ID"].
real_low["FULL_ID"] = real_low["FULL_ID"].str.replace("_M", "_m", regex=False)

In [None]:
# Inner join: classifier results restricted to the ground-truth-low subjects.
r2 = adni4_results.merge(real_low, left_on='SUBJ_ID', right_on='FULL_ID')

# Predicted-stage distribution among the matched subjects.
print(r2['pred_stage'].value_counts())