## AA STAGE RESILIENCE INFERENCE

In [None]:
import os

import numpy as np
import pandas as pd

from LDA_XGB.pipeline import CopathologyPipeline
from LDA_XGB.data_processor import *
from LDA_XGB.visualizer import CopathologyVisualizer


# Directory that holds the serialized pipelines. Kept in one place so that the
# four paths below cannot drift apart when the project is moved.
MODEL_DIR = 'C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/STAGE_CoPathology/LDA_XGB/models'

# NOTE(review): the .pkl extension suggests pickle-based deserialization inside
# CopathologyPipeline.load -- only load model files from trusted sources; confirm.

# "AD vs non-AD" model, k=8 (per file name).
path_ad_non_ad_k_8 = f'{MODEL_DIR}/wsev_smc_ad_non_ad_with_cn_k_8.pkl'
mdl_ad_non_ad_k_8 = CopathologyPipeline.load(path_ad_non_ad_k_8)

# "All DX" model, k=8 (per file name).
path_all_dx_k_8 = f'{MODEL_DIR}/wsev_smc_all_dx_with_cn_k_8.pkl'
mdl_all_dx_k_8 = CopathologyPipeline.load(path_all_dx_k_8)

# "All DX" model, k=18 (per file name).
path_all_dx_k_18 = f'{MODEL_DIR}/wsev_smc_all_dx_with_cn_k_18.pkl'
mdl_all_dx_k_18 = CopathologyPipeline.load(path_all_dx_k_18)

# "AD vs non-AD" model, k=18 (per file name).
path_ad_non_ad_k_18 = f'{MODEL_DIR}/wsev_smc_ad_non_ad_with_cn_k_18.pkl'
mdl_ad_non_ad_k_18 = CopathologyPipeline.load(path_ad_non_ad_k_18)


**Utility Functions**

In [None]:
def aa_stage_resilient_group(inp_df):
    """Label each row by comparing the predicted AA tau-stage group with the observed one.

    Required columns:
        * ``tau_stage_aa/group`` -- observed group, one of low / mid / high
          (surrounding whitespace and letter case are ignored).
        * ``pred_tau_stage_aa/low``, ``pred_tau_stage_aa/mid``,
          ``pred_tau_stage_aa/high`` -- per-group prediction scores.

    Added columns (on a copy; ``inp_df`` itself is not modified):
        * ``pred_tau_stage_aa/group`` -- name of the highest-scoring group, or
          NaN when all three scores are missing for that row.
        * ``tau_stage_aa/res_group`` -- ``'Overpredict'`` when the predicted
          group ranks above the observed one, ``'Canonical'`` when they match,
          ``'Underpredict'`` when it ranks below, and ``None`` when either
          side is missing or is not one of low / mid / high.

    Raises:
        KeyError: if any required column is absent.
    """
    pred_prefix = 'pred_tau_stage_aa/'
    pred_cols = [pred_prefix + level for level in ('low', 'mid', 'high')]

    missing = [c for c in ['tau_stage_aa/group'] + pred_cols if c not in inp_df.columns]
    if missing:
        raise KeyError(f"aa_stage_resilient_group: missing required column(s): {missing}")

    out_df = inp_df.copy()

    # idxmax over an all-NaN row is deprecated in pandas 2.x (and slated to
    # raise), so only rows carrying at least one score are passed to it; the
    # remaining rows keep NaN as their predicted group.
    probs = out_df[pred_cols]
    has_pred = probs.notna().any(axis=1).to_numpy()
    pred_labels = np.full(len(out_df), np.nan, dtype=object)
    if has_pred.any():
        pred_labels[has_pred] = (
            probs.loc[has_pred]
            .idxmax(axis=1)
            .str.replace(pred_prefix, '', regex=False)
            .to_numpy()
        )
    out_df['pred_tau_stage_aa/group'] = pred_labels

    order = {'low': 0, 'mid': 1, 'high': 2}

    def _to_rank(labels):
        # Normalise the text, then map to an ordinal; anything unmapped becomes NaN.
        cleaned = labels.astype('string').str.strip().str.lower()
        return cleaned.map(order).to_numpy(dtype='float64', na_value=np.nan)

    gt_rank = _to_rank(out_df['tau_stage_aa/group'])
    pred_rank = _to_rank(out_df['pred_tau_stage_aa/group'])

    # Every comparison involving NaN is False, so rows with a missing side
    # fall through to the default (None).
    out_df['tau_stage_aa/res_group'] = np.select(
        [pred_rank > gt_rank, pred_rank == gt_rank, pred_rank < gt_rank],
        ['Overpredict', 'Canonical', 'Underpredict'],
        default=None,
    )

    return out_df

**Visualization Functions**

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr
from LDA_XGB.visualizer import CopathologyVisualizer

def nacc_prediction_probability_heatmap(inp_df, prob_cols, dx_cols, prefix=''):
    """Placeholder: the current body only prints an empty line.

    None of the parameters (``inp_df``, ``prob_cols``, ``dx_cols``, ``prefix``)
    are used yet. NOTE(review): presumably intended to draw a heatmap of
    predicted probabilities per diagnosis -- not implemented; confirm intent.
    """
    print()

def nacc_copath_radar_plot(inp_df, copath_col='copath', copath_count_col='copath_counts', y_max=0.35, prefix='',
                           output_dir='C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/results/nacc_copath_figures'):
    """Draw one diagnosis/topic profile plot per distinct co-pathology count.

    Rows with a missing ``copath_col`` are dropped, the label ``"FTD_ANY"`` is
    rewritten to ``"FTD"``, and for every distinct non-missing value of
    ``copath_count_col`` the matching rows are handed to
    ``CopathologyVisualizer.plot_diagnosis_topic_profiles`` (with ``save=False``).

    Args:
        inp_df: DataFrame containing ``copath_col``, ``copath_count_col`` and
            the topic-weight columns (names starting with ``'Topic_'``).
            It is not modified.
        copath_col: column holding the co-pathology label.
        copath_count_col: column holding the co-pathology count; each distinct
            value gets its own plot and must be convertible with ``int()``.
        y_max: forwarded to the visualizer as ``y_max``.
        prefix: text prepended to every plot title.
        output_dir: forwarded to ``CopathologyVisualizer(output_dir=...)``.
            Defaults to the previously hard-coded results directory.

    Returns:
        Whatever ``plot_diagnosis_topic_profiles`` returned for the LAST count
        value processed, or ``None`` when there was nothing to plot.
    """
    # Explicit copy so the label rewrite below never touches (or warns about) inp_df.
    plot_df = inp_df.dropna(subset=[copath_col]).copy()
    plot_df[copath_col] = plot_df[copath_col].str.replace("FTD_ANY", "FTD", regex=False)  ## HARD-CODING

    # The set of columns is the same for every subset, so resolve it once.
    topic_cols = [c for c in plot_df.columns if c.startswith('Topic_')]

    copath_radar = None  # stays None when no count value is available
    # NaN counts are skipped: they match no row and int(NaN) would raise.
    for count_value in plot_df[copath_count_col].dropna().unique():
        print(count_value)
        temp_df = plot_df[plot_df[copath_count_col] == count_value]
        # A fresh visualizer per plot, as before (constructor side effects are not visible here).
        vis = CopathologyVisualizer(output_dir=output_dir)
        print(temp_df.shape)
        copath_radar = vis.plot_diagnosis_topic_profiles(
            theta=temp_df[topic_cols],
            dx_labels=temp_df[copath_col],
            y_max=y_max,
            title=f'{prefix} Diagnostics Topic Profiles ({copath_col}={int(count_value)})',
            save=False,
        )
    return copath_radar

def res_group_mean_heatmap(inp_df, prob_cols, group_col, group_order, prefix=''):
    """Show a heatmap of the mean of each ``prob_cols`` column per subgroup.

    Rows are the values of ``group_col`` arranged in ``group_order``; columns
    are ``prob_cols``. The colour scale is fixed to [0, 1] and every cell is
    annotated with two decimals. The figure is displayed with ``plt.show()``;
    nothing is returned.
    """
    per_group = inp_df.groupby(group_col)[prob_cols].mean()
    ordered_means = per_group.reindex(group_order)

    heatmap_options = dict(
        cmap="Reds",
        annot=True,
        fmt=".2f",
        linewidths=0.5,
        vmin=0,
        vmax=1,
        cbar_kws={"label": "Mean predicted probability"},
    )

    plt.figure(figsize=(8, 5))
    sns.heatmap(ordered_means, **heatmap_options)

    plt.xlabel("Predicted pathology")
    plt.ylabel("Subgroup")
    plt.title(f"{prefix} Group-wise Mean Predicted Probability Distribution")
    plt.tight_layout()
    plt.show()

def res_group_subjectwise_heatmap(inp_df, prob_cols, group_col, group_order, prefix='', sort_col="P(AD)"):
    """Show a subject-level heatmap of ``prob_cols``, with rows blocked by subgroup.

    Rows are ordered by ``group_col`` (following ``group_order``) and, inside
    each subgroup, by descending ``sort_col``. Black horizontal lines separate
    the subgroups and each subgroup is labelled at the centre of its block.
    Rows whose ``group_col`` value is not in ``group_order`` become NaN in the
    categorical and are sorted last, without a label.

    Args:
        inp_df: DataFrame with ``group_col``, ``sort_col`` and ``prob_cols``.
            It is NOT modified (the function works on a copy).
        prob_cols: columns shown as heatmap columns.
        group_col: column holding the subgroup label.
        group_order: subgroup labels in the top-to-bottom display order.
        prefix: text prepended to the title.
        sort_col: column used for the within-group descending sort; defaults
            to the previously hard-coded ``"P(AD)"``.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    # Work on a copy: overwriting group_col on the caller's frame would turn it
    # into a categorical there too and change later groupby/unique results.
    work_df = inp_df.copy()
    work_df[group_col] = pd.Categorical(
        work_df[group_col],
        categories=group_order,
        ordered=True
    )
    df_sorted = (
        work_df
        .sort_values([group_col, sort_col], ascending=[True, False])
        .reset_index(drop=True)
    )
    heatmap_data = df_sorted[prob_cols]

    # Rows per subgroup, in display order.
    group_counts = (
        df_sorted[group_col]
        .value_counts()
        .reindex(group_order)
    )

    # Row offset of the middle of every subgroup block (used for the y tick labels).
    group_centers = {}
    start = 0
    for grp, count in group_counts.items():
        group_centers[grp] = start + count / 2
        start += count

    plt.figure(figsize=(10, 10))
    ax = sns.heatmap(
        heatmap_data,
        cmap="Reds",
        vmin=0,
        vmax=1,
        yticklabels=False,
        cbar_kws={"label": "Predicted probability"}
    )

    # Separator after every subgroup block except the last one.
    cum_sizes = np.cumsum(group_counts.values)
    for y in cum_sizes[:-1]:
        ax.hlines(y, *ax.get_xlim(), colors="black", linewidth=1.5)

    ax.set_yticks(list(group_centers.values()))
    ax.set_yticklabels(list(group_centers.keys()), rotation=0, fontsize=11)
    ax.set_xlabel("Predicted Pathology")
    ax.set_ylabel("Subgroup")
    ax.set_title(f"{prefix} Subject-level Predicted Probability Heatmap\n(sorted by descending {sort_col})")
    plt.tight_layout()
    plt.show()

def res_group_topic_radar_plot(inp_df, prob_cols, group_col, group_order, prefix=''):
    """Show one radar (polar) plot of mean topic weights per subgroup.

    Topic columns are all columns of ``inp_df`` whose name starts with
    ``"Topic_"``. One panel is drawn for every entry of ``group_order`` that
    actually occurs in ``group_col``, in ``group_order`` order. All panels
    share the same radial limit (1.1 x the largest plotted group mean).

    Args:
        inp_df: DataFrame with ``group_col`` and the ``Topic_*`` columns.
        prob_cols: unused; kept so the signature matches the other
            ``res_group_*`` plotting helpers.
        group_col: column holding the subgroup label.
        group_order: subgroup labels in left-to-right panel order.
        prefix: text prepended to the figure title.

    Raises:
        ValueError: if there are no ``Topic_*`` columns, or if none of the
            ``group_order`` labels occurs in ``group_col``.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    topic_cols = [c for c in inp_df.columns if c.startswith("Topic_")]
    if not topic_cols:
        raise ValueError("res_group_topic_radar_plot: no columns starting with 'Topic_' found")

    # Only plot groups that are both requested and present. Sizing the figure
    # from the observed groups while iterating group_order (as before) paired
    # axes with the wrong groups whenever one subgroup was absent.
    present = set(inp_df[group_col].dropna().unique())
    groups = [grp for grp in group_order if grp in present]
    if not groups:
        raise ValueError(
            f"res_group_topic_radar_plot: none of {list(group_order)} found in column '{group_col}'"
        )

    # Mean topic weights per plotted group; computed per group so a categorical
    # group_col with unobserved categories cannot inject NaN rows.
    group_means = {
        grp: inp_df.loc[inp_df[group_col] == grp, topic_cols].mean().to_numpy()
        for grp in groups
    }
    global_max = max(np.nanmax(means) for means in group_means.values())

    # One spoke per topic; repeat the first point so the polygon closes.
    angles = np.linspace(0, 2 * np.pi, len(topic_cols), endpoint=False)
    angles = np.concatenate([angles, [angles[0]]])

    n_groups = len(groups)
    fig, axes = plt.subplots(
        1, n_groups,
        figsize=(4 * n_groups, 4),
        subplot_kw=dict(polar=True)
    )
    # plt.subplots returns a bare Axes (not an array) for a single panel.
    if n_groups == 1:
        axes = [axes]

    for ax, grp in zip(axes, groups):
        mean_topics = group_means[grp]
        mean_topics = np.concatenate([mean_topics, [mean_topics[0]]])

        ax.plot(angles, mean_topics, linewidth=2)
        ax.fill(angles, mean_topics, alpha=0.25)

        ax.set_title(grp, pad=20)

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(topic_cols, fontsize=9)

        ax.set_ylim(0, global_max * 1.1)   # shared radial scale across panels
        ax.set_yticklabels([])

    plt.suptitle(f"{prefix} Resilience Subgroup Topic Weight Profiles (shared radial scale)", fontsize=14)
    plt.tight_layout()
    plt.show()

### NACC Inference

In [None]:
## Data Prep ## - 260212 OUTDATED
# Older preparation path, kept for reference (marked outdated by the author).
# Load the inference output and keep only the subject key plus the columns
# whose name starts with 'pred_'.
nacc_pred_df = pd.read_csv('C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/data/nacc/STAGE_NACC_inference.csv')
pred_cols = [c for c in nacc_pred_df.columns if c.startswith('pred_')]
nacc_pred_df = nacc_pred_df[['FULL_ID'] + pred_cols]
nacc_external_df = pd.read_csv('C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/data/nacc/nacc_stage_external/NACC_external_260206.csv')

# Left merge on FULL_ID: every external row is kept; rows without a matching
# prediction get missing values in the pred_ columns.
nacc_combined_df = pd.merge(nacc_external_df, nacc_pred_df, on='FULL_ID', how='left')
# Adds 'pred_tau_stage_aa/group' and 'tau_stage_aa/res_group'.
nacc_combined_df = aa_stage_resilient_group(nacc_combined_df)
print(nacc_combined_df.shape)

# Label-based slice: all columns positioned from 'VA/2' through 'VA/2035' (inclusive).
# Rows with a missing value in any of them are dropped.
va_cols = nacc_combined_df.loc[:,'VA/2':'VA/2035'].columns
nacc_combined_df = nacc_combined_df.dropna(subset=va_cols)

In [None]:
## Data Prep ##
# Current preparation path: read the already-inferenced NACC table and attach
# the resilience labels. raw_nacc_df is the input of every inference cell below.
nacc_data_path = 'C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/data/nacc'
# raw_nacc_df = pd.read_csv(os.path.join(nacc_data_path, 'NACC_external_260212.csv'))
raw_nacc_df = pd.read_csv(os.path.join(nacc_data_path, 'NACC_external_260212_inferenced_ref_model.csv'))
# Adds 'pred_tau_stage_aa/group' and 'tau_stage_aa/res_group'.
raw_nacc_df = aa_stage_resilient_group(raw_nacc_df)

# Shape before dropping incomplete rows.
print(raw_nacc_df.shape)

# Label-based slice: all columns positioned from 'VA/2' through 'VA/2035' (inclusive).
# Rows with a missing value in any of them are dropped.
va_cols = raw_nacc_df.loc[:, 'VA/2':'VA/2035'].columns
raw_nacc_df = raw_nacc_df.dropna(subset=va_cols)
# Shape after dropping incomplete rows.
print(raw_nacc_df.shape)

# print(raw_nacc_df['copath'].value_counts())

**ALL DX**

In [None]:
## Inference ## K=8
# Run the "all DX" k=8 model on the prepared NACC rows.
# NOTE(review): the returned frame presumably has one row per subject with a
# 'SUBJ_ID' key plus 'Topic_*' and 'P(...)' columns -- confirm against
# CopathologyPipeline.predict_new_subjects.
nacc_all_dx_k_8_df = mdl_all_dx_k_8.predict_new_subjects(inp_df=raw_nacc_df, subject_col='FULL_ID', cn_dx='CN')
# Left merge keeps every row of raw_nacc_df; both key columns (FULL_ID and SUBJ_ID) remain in the result.
nacc_all_dx_k_8_df = pd.merge(raw_nacc_df, nacc_all_dx_k_8_df, left_on='FULL_ID', right_on='SUBJ_ID', how='left')
print(nacc_all_dx_k_8_df.shape)

# nacc_all_dx_k_8_df.to_csv(os.path.join(nacc_data_path, 'NACC_external_260212_all_dx_lda.csv'), index=False)
# nacc_copath_radar_plot(inp_df=nacc_all_dx_k_8_df, y_max=0.5, prefix='NACC')

## RES Group Visualization ##
prob_cols = ['P(NC)', 'P(AD)', 'P(PD)', 'P(DLB)', 'P(SVAD)', 'P(bvFTD)', 'P(nfvPPA)', 'P(svPPA)']
group_col = 'tau_stage_aa/res_group'
group_order = ['Underpredict', 'Canonical', 'Overpredict']
# Keep only rows that received a resilience label.
plot_df = nacc_all_dx_k_8_df.dropna(subset=group_col)
# Exclude rows whose TRACER is 'MK6240'.
plot_df = plot_df[plot_df['TRACER']!='MK6240']
print(plot_df.shape)
# Restrict to DX == 'AD'.
# NOTE(review): the other inference cells filter on different DX sets
# (['MCI', 'AD'] or ['CN']) -- confirm this difference is intentional.
plot_df = plot_df[plot_df['DX'].isin(['AD'])]
print(plot_df.shape)

res_group_mean_heatmap(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')
res_group_subjectwise_heatmap(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')
res_group_topic_radar_plot(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')

In [None]:
## Inference ## K=18
# Run the "all DX" k=18 model on the prepared NACC rows.
# NOTE(review): the returned frame presumably has one row per subject with a
# 'SUBJ_ID' key plus 'Topic_*' and 'P(...)' columns -- confirm against
# CopathologyPipeline.predict_new_subjects.
nacc_all_dx_k_18_df = mdl_all_dx_k_18.predict_new_subjects(inp_df=raw_nacc_df, subject_col='FULL_ID', cn_dx='CN')
# Left merge keeps every row of raw_nacc_df; both key columns (FULL_ID and SUBJ_ID) remain in the result.
nacc_all_dx_k_18_df = pd.merge(raw_nacc_df, nacc_all_dx_k_18_df, left_on='FULL_ID', right_on='SUBJ_ID', how='left')
print(nacc_all_dx_k_18_df.shape)

# nacc_copath_radar_plot(inp_df=nacc_all_dx_k_18_df, y_max=0.35, prefix='NACC')
# NOTE: this export is active -- the CSV is (re)written every time the cell runs.
nacc_all_dx_k_18_df.to_csv(os.path.join(nacc_data_path, 'NACC_external_260212_all_dx_lda_k_18.csv'), index=False)

## RES Group Visualization ##
prob_cols = ['P(NC)', 'P(AD)', 'P(PD)', 'P(DLB)', 'P(SVAD)', 'P(bvFTD)', 'P(nfvPPA)', 'P(svPPA)']
group_col = 'tau_stage_aa/res_group'
group_order = ['Underpredict', 'Canonical', 'Overpredict']
# Keep only rows that received a resilience label.
plot_df = nacc_all_dx_k_18_df.dropna(subset=group_col)
# Exclude rows whose TRACER is 'MK6240'.
plot_df = plot_df[plot_df['TRACER']!='MK6240']
print(plot_df.shape)
# Restrict to DX in {'MCI', 'AD'}.
plot_df = plot_df[plot_df['DX'].isin(['MCI', 'AD'])]
print(plot_df.shape)

res_group_mean_heatmap(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')
res_group_subjectwise_heatmap(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')
res_group_topic_radar_plot(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')

**AD vs non-AD**

In [None]:
## Inference ## K=8
# Run the "AD vs non-AD" k=8 model on the prepared NACC rows.
# NOTE(review): the returned frame presumably has one row per subject with a
# 'SUBJ_ID' key plus 'Topic_*' and 'P(...)' columns -- confirm against
# CopathologyPipeline.predict_new_subjects.
nacc_ad_non_ad_k_8_df = mdl_ad_non_ad_k_8.predict_new_subjects(inp_df=raw_nacc_df, subject_col='FULL_ID', cn_dx='CN')
# Left merge keeps every row of raw_nacc_df; both key columns (FULL_ID and SUBJ_ID) remain in the result.
nacc_ad_non_ad_k_8_df = pd.merge(raw_nacc_df, nacc_ad_non_ad_k_8_df, left_on='FULL_ID', right_on='SUBJ_ID', how='left')
print(nacc_ad_non_ad_k_8_df.shape)
# nacc_ad_non_ad_k_8_df.to_csv(os.path.join(nacc_data_path, 'NACC_external_260212_ad_nonad_lda.csv'), index=False)
# nacc_copath_radar_plot(inp_df=nacc_ad_non_ad_k_8_df, y_max=0.5, prefix='NACC')

## Visualization ##
# This model exposes three probability columns instead of the eight used by the "all DX" cells.
prob_cols = ['P(NC)', 'P(AD)', 'P(non-AD)']
group_col = 'tau_stage_aa/res_group'
group_order = ['Underpredict', 'Canonical', 'Overpredict']
# Keep only rows that received a resilience label.
plot_df = nacc_ad_non_ad_k_8_df.dropna(subset=group_col)
# Exclude rows whose TRACER is 'MK6240'.
plot_df = plot_df[plot_df['TRACER']!='MK6240']
print(plot_df.shape)
# Restrict to DX == 'CN'.
# NOTE(review): the other inference cells filter on ['AD'] or ['MCI', 'AD'] --
# confirm that plotting only CN rows here is intentional.
plot_df = plot_df[plot_df['DX'].isin(['CN'])]
print(plot_df.shape)

res_group_mean_heatmap(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')
res_group_subjectwise_heatmap(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')
res_group_topic_radar_plot(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')

In [None]:
## Inference ## K=18
# Run the "AD vs non-AD" k=18 model on the prepared NACC rows.
# NOTE(review): the returned frame presumably has one row per subject with a
# 'SUBJ_ID' key plus 'Topic_*' and 'P(...)' columns -- confirm against
# CopathologyPipeline.predict_new_subjects.
nacc_ad_non_ad_k_18_df = mdl_ad_non_ad_k_18.predict_new_subjects(inp_df=raw_nacc_df, subject_col='FULL_ID', cn_dx='CN')
# Left merge keeps every row of raw_nacc_df; both key columns (FULL_ID and SUBJ_ID) remain in the result.
nacc_ad_non_ad_k_18_df = pd.merge(raw_nacc_df, nacc_ad_non_ad_k_18_df, left_on='FULL_ID', right_on='SUBJ_ID', how='left')
print(nacc_ad_non_ad_k_18_df.shape)
# NOTE(review): the first disabled export below has no k_18 suffix in its file
# name, so re-enabling it would write the same file name as the K=8 cell's
# disabled export.
# nacc_ad_non_ad_k_18_df.to_csv(os.path.join(nacc_data_path, 'NACC_external_260212_ad_nonad_lda.csv'), index=False)
# nacc_copath_radar_plot(inp_df=nacc_ad_non_ad_k_18_df, y_max=0.5, prefix='NACC')
# nacc_ad_non_ad_k_18_df.to_csv(os.path.join(nacc_data_path, 'NACC_external_260212_ad_non_ad_lda_k_18.csv'), index=False)

## Visualization ##
# This model exposes three probability columns instead of the eight used by the "all DX" cells.
prob_cols = ['P(NC)', 'P(AD)', 'P(non-AD)']
group_col = 'tau_stage_aa/res_group'
group_order = ['Underpredict', 'Canonical', 'Overpredict']
# Keep only rows that received a resilience label.
plot_df = nacc_ad_non_ad_k_18_df.dropna(subset=group_col)
# Exclude rows whose TRACER is 'MK6240'.
plot_df = plot_df[plot_df['TRACER']!='MK6240']
print(plot_df.shape)
# Restrict to DX in {'MCI', 'AD'}.
plot_df = plot_df[plot_df['DX'].isin(['MCI', 'AD'])]
print(plot_df.shape)

res_group_mean_heatmap(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')
res_group_subjectwise_heatmap(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')
res_group_topic_radar_plot(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')