## NMF Topic Modelling

**RAW VA INPUT**

In [None]:
import os

import pandas as pd
import numpy as np
from LDA_XGB.data_processor import *
from LDA_XGB.visualizer import *
from LDA_XGB.brain_visualizer import *
import matplotlib.pyplot as plt
from sklearn.decomposition import NMF

def class_balance(inp_df, class_col, n=25, special_care=('AD', 'NC'), special_n=50, random_state=42):
    """Cap the number of rows per class by random down-sampling.

    Classes listed in ``special_care`` keep at most ``special_n`` rows; every
    other class keeps at most ``n``. Classes already at or below their cap are
    kept whole. Sampling is without replacement and seeded with
    ``random_state``, so repeated calls return the same rows.

    Parameters
    ----------
    inp_df : pandas.DataFrame
        Input rows.
    class_col : str
        Column whose values define the classes. Rows where it is NaN are
        dropped (``groupby`` default).
    n : int
        Cap for ordinary classes.
    special_care : container of class labels
        Classes that use ``special_n`` instead of ``n``. The default is an
        immutable tuple; a list default would be shared across calls.
    special_n : int
        Cap for the classes in ``special_care``.
    random_state : int
        Seed passed to ``DataFrame.sample``. The default of 42 reproduces the
        previous behaviour.

    Returns
    -------
    pandas.DataFrame
        The balanced rows, with classes in ``groupby`` (sorted) order and a fresh
        0..N-1 index. An empty input yields an empty frame with the same columns,
        instead of raising from ``pd.concat([])``.
    """
    balanced = []
    for label, group in inp_df.groupby(class_col):
        cap = special_n if label in special_care else n
        if len(group) > cap:
            group = group.sample(n=cap, replace=False, random_state=random_state)
        balanced.append(group)

    if not balanced:
        # Nothing to concatenate: keep the schema, return zero rows.
        return inp_df.iloc[0:0].reset_index(drop=True)
    return pd.concat(balanced).reset_index(drop=True)

In [None]:
def aa_stage_resilient_group(inp_df):
    """Compare the predicted AA tau stage with the ground-truth stage.

    Requires the columns ``tau_stage_aa/group`` and
    ``pred_tau_stage_aa/low``, ``pred_tau_stage_aa/mid`` and
    ``pred_tau_stage_aa/high``.

    Returns a copy of ``inp_df`` with two added columns:

    - ``pred_tau_stage_aa/group``: the stage name ('low', 'mid' or 'high')
      whose prediction column holds the row's largest value.
    - ``tau_stage_aa/res_group``: 'Overpredict', 'Canonical' or
      'Underpredict', depending on whether the predicted stage is above, equal
      to or below the ground-truth stage on the low < mid < high scale. It is
      None when either stage is missing or not one of the three names (after
      stripping whitespace and lower-casing).
    """
    stage_rank = {'low': 0, 'mid': 1, 'high': 2}
    result = inp_df.copy()

    score_cols = ['pred_tau_stage_aa/low', 'pred_tau_stage_aa/mid', 'pred_tau_stage_aa/high']
    winning_col = result[score_cols].idxmax(axis=1)
    result['pred_tau_stage_aa/group'] = winning_col.str.replace('pred_tau_stage_aa/', '', regex=False)

    def to_rank(col_name):
        # Normalise whitespace and case before looking up the ordinal rank;
        # unknown names map to NaN, and every comparison with NaN is False.
        labels = result[col_name].astype('string').str.strip().str.lower()
        return labels.map(stage_rank)

    true_rank = to_rank('tau_stage_aa/group')
    predicted_rank = to_rank('pred_tau_stage_aa/group')

    conditions = [
        predicted_rank > true_rank,
        predicted_rank == true_rank,
        predicted_rank < true_rank,
    ]
    labels = ['Overpredict', 'Canonical', 'Underpredict']
    result['tau_stage_aa/res_group'] = np.select(conditions, labels, default=None)

    return result

**Visualization Functions**

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr
from LDA_XGB.visualizer import CopathologyVisualizer

def nacc_copath_radar_plot(inp_df, copath_col='copath', copath_count_col='copath_counts', y_max=0.35, prefix='',
                           output_dir='C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/results/nacc_copath_figures'):
    """Plot NMF topic profiles per co-pathology label, once for each co-pathology count.

    For each distinct value of ``copath_count_col``, the matching subjects are
    passed to ``CopathologyVisualizer.plot_diagnosis_topic_profiles``. The
    ``nmf/Topic_*`` columns are passed as ``theta`` and ``copath_col`` as the labels.

    Parameters
    ----------
    inp_df : pandas.DataFrame
        Must contain ``copath_col``, ``copath_count_col`` and the ``nmf/Topic_*`` columns.
    copath_col : str
        Label column. The substring "FTD_ANY" is rewritten to "FTD".
    copath_count_col : str
        Numeric column that splits subjects into one plot per value.
    y_max : float
        Radial limit forwarded to the visualizer.
    prefix : str
        Text prepended to each plot title.
    output_dir : str
        Directory handed to ``CopathologyVisualizer``. Plots are drawn with ``save=False``.

    Returns
    -------
    The visualizer's return value for the last count plotted, or None when no
    rows remain after dropping missing labels or counts.
    """
    # Drop rows without a count as well: the title calls int(count), which
    # raised on NaN. The explicit .copy() means the label rewrite below never
    # writes through to inp_df and does not trigger SettingWithCopyWarning.
    plot_df = inp_df.dropna(subset=[copath_col, copath_count_col]).copy()
    plot_df[copath_col] = plot_df[copath_col].str.replace("FTD_ANY", "FTD", regex=False) ## HARD-CODING
    topic_cols = [c for c in plot_df.columns if c.startswith('nmf/Topic_')]

    copath_radar = None  # stays None if there is nothing to plot
    for i in plot_df[copath_count_col].unique():
        print(i)
        temp_df = plot_df[plot_df[copath_count_col] == i]
        vis = CopathologyVisualizer(
            output_dir=output_dir
        )
        print(temp_df.shape)
        copath_radar = vis.plot_diagnosis_topic_profiles(
            theta=temp_df[topic_cols],
            dx_labels=temp_df[copath_col],
            y_max=y_max,
            title = f'{prefix} Diagnostics Topic Profiles ({copath_col}={int(i)})',
            save=False
        )
    # plt.show()
    return copath_radar

def res_group_mean_heatmap(inp_df, prob_cols, group_col, group_order, prefix=''):
    """Show a heatmap of the mean predicted probabilities for each subgroup.

    Rows are the values of ``group_col`` in ``group_order`` order; a listed
    group with no rows becomes an all-NaN row. Columns are ``prob_cols``.
    Cells are annotated to two decimals and coloured on a fixed 0-1 scale.
    The figure is displayed with ``plt.show()`` and nothing is returned.
    """
    per_group = inp_df.groupby(group_col)[prob_cols].mean()
    per_group = per_group.reindex(group_order)

    plt.figure(figsize=(8, 5))
    sns.heatmap(
        per_group,
        vmin=0,
        vmax=1,
        cmap="Reds",
        linewidths=0.5,
        annot=True,
        fmt=".2f",
        cbar_kws={"label": "Mean predicted probability"},
    )

    plt.title(f"{prefix} Group-wise Mean Predicted Probability Distribution")
    plt.ylabel("Subgroup")
    plt.xlabel("Predicted pathology")
    plt.tight_layout()
    plt.show()

def res_group_subjectwise_heatmap(inp_df, prob_cols, group_col, group_order, prefix='', sort_col="nmf_P(AD)"):
    """Show a subject-level heatmap of predicted probabilities, stacked by subgroup.

    Rows are subjects. They are grouped in ``group_order`` order and, within
    each group, sorted by ``sort_col`` in descending order. Columns are
    ``prob_cols``. Black horizontal lines separate the groups, and each group's
    name is placed at the vertical centre of its band. Rows whose
    ``group_col`` value is not in ``group_order`` become NaN categories and end
    up after the labelled bands.

    ``inp_df`` is not modified. ``sort_col`` defaults to the previously
    hard-coded "nmf_P(AD)".
    """
    # Work on a copy. Converting group_col in place used to change the
    # caller's DataFrame into a Categorical, which altered the later
    # groupby-based plots, and raised SettingWithCopyWarning on filtered slices.
    df = inp_df.copy()
    df[group_col] = pd.Categorical(
        df[group_col],
        categories=group_order,
        ordered=True
    )
    df_sorted = (
        df
        .sort_values([group_col, sort_col], ascending=[True, False])
        .reset_index(drop=True)
    )
    heatmap_data = df_sorted[prob_cols]
    group_counts = (
        df_sorted[group_col]
        .value_counts()
        .reindex(group_order)
    )

    # Heatmap row i spans [i, i + 1], so a band of `count` rows that starts at
    # `start` is centred at start + count / 2.
    group_centers = {}
    start = 0
    for grp, count in group_counts.items():
        group_centers[grp] = start + count / 2
        start += count

    plt.figure(figsize=(10, 10))
    ax = sns.heatmap(
        heatmap_data,
        cmap="Reds",
        vmin=0,
        vmax=1,
        yticklabels=False,
        cbar_kws={"label": "Predicted probability"}
    )
    # Separator lines at the cumulative group boundaries (none after the last group).
    cum_sizes = np.cumsum(group_counts.values)
    for y in cum_sizes[:-1]:
        ax.hlines(y, *ax.get_xlim(), colors="black", linewidth=1.5)
    ax.set_yticks(list(group_centers.values()))
    ax.set_yticklabels(list(group_centers.keys()), rotation=0, fontsize=11)
    ax.set_xlabel("Predicted Pathology")
    ax.set_ylabel("Subgroup")
    # The title names the sort column without the "nmf_" prefix
    # (the default gives "P(AD)", matching the previous title).
    sort_label = sort_col[len("nmf_"):] if sort_col.startswith("nmf_") else sort_col
    ax.set_title(f"{prefix} Subject-level Predicted Probability Heatmap\n(sorted by descending {sort_label})")
    plt.tight_layout()
    plt.show()

def res_group_topic_radar_plot(inp_df, prob_cols, group_col, group_order, prefix=''):
    """Draw one polar (radar) plot of mean ``nmf/Topic_*`` weights per subgroup.

    One panel is drawn for each entry of ``group_order``, left to right. All
    panels share the same radial limit: 1.1 x the largest group-mean topic
    weight, or 1.0 when there is no positive finite mean. ``prob_cols`` is not
    used; it is accepted so that calls match the other ``res_group_*`` helpers.
    """
    topic_cols = [c for c in inp_df.columns if c.startswith("nmf/Topic_")]
    n_topics = len(topic_cols)
    # Size the figure from group_order, the list that is actually iterated
    # below. Counting the groups present in the data could drop a listed
    # group's panel or leave an unused empty axis.
    n_groups = len(group_order)

    # observed=True: with a Categorical group_col, unobserved categories would
    # otherwise add all-NaN rows, and a plain .max() would then return NaN.
    group_means = (
        inp_df
        .groupby(group_col, observed=True)[topic_cols]
        .mean()
        .reindex(group_order)
    )
    mean_values = group_means.values
    finite_means = mean_values[np.isfinite(mean_values)]
    global_max = finite_means.max() if finite_means.size else 0.0
    # set_ylim rejects NaN limits; fall back to a unit radius if nothing is plottable.
    rmax = global_max * 1.1 if global_max > 0 else 1.0
    rticks = np.linspace(0, rmax, 5)  # 0%, 25%, 50%, 75%, 100% of rmax

    angles = np.linspace(0, 2 * np.pi, n_topics, endpoint=False)
    angles = np.concatenate([angles, [angles[0]]])  # close the polygon

    fig, axes = plt.subplots(
        1, n_groups,
        figsize=(4 * n_groups, 4),
        subplot_kw=dict(polar=True)
    )
    axes = np.atleast_1d(axes)  # a single subplot comes back as a bare Axes

    for ax, grp in zip(axes, group_order):
        mean_topics = group_means.loc[grp].values
        mean_topics = np.concatenate([mean_topics, [mean_topics[0]]])

        ax.plot(angles, mean_topics, linewidth=2)
        ax.fill(angles, mean_topics, alpha=0.25)

        ax.set_title(grp, pad=20)

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(topic_cols, fontsize=9)
        ax.set_ylim(0, rmax)  # shared radial scale across panels
        ax.set_yticks(rticks)
        ax.set_yticklabels([f"{t:.2f}" for t in rticks], fontsize=8)
        ax.set_rlabel_position(90)  # radial tick labels drawn at 90 degrees

    plt.suptitle(f"{prefix} Resilience Subgroup Topic Weight Profiles (shared radial scale)", fontsize=14)
    plt.tight_layout()
    plt.show()

**Train Data and Model Initialization**

In [None]:
# Load the DKT label table and the two training cohorts (WSEV and SMC), pool
# them, and build two class-balanced training sets.
data_path = 'C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/data'
dkt_labels = pd.read_csv(os.path.join(data_path, 'dkt_labels.csv'))
# Region names are read from the first data row of dkt_labels.csv.
dkt_rois = dkt_labels.iloc[0].tolist()

df_wsev = pd.read_csv(os.path.join(data_path, 'wsev_old_100_data.csv'))
df_smc = pd.read_csv(os.path.join(data_path, 'SMC_AD_FTD_VA_final.csv'))
# Wide VA column range (VA/2 .. VA/2035, in the SMC file's column order). It is
# used only to drop subjects with a missing VA value; region_cols is narrowed
# to the modelling regions at the end of this cell.
region_cols = df_smc.loc[:, 'VA/2':'VA/2035'].columns.to_list()

# Stack both cohorts. Columns present in only one cohort are NaN-filled by concat.
raw_va_train_df = pd.concat([df_wsev, df_smc], axis=0, ignore_index=True)
raw_va_train_df = raw_va_train_df.dropna(subset=region_cols)
raw_va_train_df = raw_va_train_df[raw_va_train_df['DX']!='HC'] ## get rid of WSEV HC control

# Per-DX balance with class_balance defaults: at most 50 AD / NC rows and 25 of every other DX.
train_all_dx = class_balance(raw_va_train_df, class_col='DX')
print(train_all_dx['DX'].value_counts())

# Same balancing but with up to 100 AD / NC rows; every other DX is collapsed
# into a single 'non-AD' label in the AD_label column.
train_ad_non_ad = class_balance(raw_va_train_df, class_col='DX', special_n=100)
train_ad_non_ad['AD_label'] = np.where(train_ad_non_ad['DX'].isin(['AD','NC']), train_ad_non_ad['DX'], 'non-AD')
print(train_ad_non_ad['AD_label'].value_counts())

K_TOPICS = 6  # number of NMF components (topics)
PARAM = 2  # VA values are modelled as PARAM - VA in the NMF cells below
# Modelling regions: VA/1002 .. VA/2035 (62 regions per the original note).
region_cols = train_all_dx.loc[:, 'VA/1002':'VA/2035'].columns.to_list() ############ 62 regions 

In [None]:
## ALL DX MODEL ## (PARAM-raw_Va)
# Fit the main NMF topic model on PARAM - raw VA for the training set selected
# below. Then plot per-label topic profiles, the top regions per topic, and
# surface maps of each topic.
print('K-topics =', K_TOPICS)
print(f'Subtracting Raw VA from {PARAM}')
print('regions count:', len(region_cols))

##############################
# Training-set switch. Take a copy: the in-place PARAM - VA transform below
# used to rewrite train_ad_non_ad itself through the alias. Re-running this
# cell then undid the transform (PARAM - (PARAM - VA) = VA) and fitted on raw VA.
# train_copy = train_all_dx.copy()
train_copy = train_ad_non_ad.copy()
dx_label = 'AD_label'
# dx_label = 'DX'
##################################
# sklearn's NMF rejects negative input, so this assumes every VA value is <= PARAM.
train_copy[region_cols] = PARAM-train_copy[region_cols]
va_means = train_copy[region_cols].mean(axis=0)
dx = train_copy[dx_label]
print(dx.unique())

# W: (subjects x topics) weights; H: (topics x regions) patterns.
# Hn: each topic's region pattern rescaled to sum to 1 (used for plotting).
nmf_transormed_va = NMF(
    n_components=K_TOPICS,
    init="nndsvda",
    solver="cd",
    max_iter=2000,
    random_state=42
)
W = nmf_transormed_va.fit_transform(train_copy[region_cols])
H = nmf_transormed_va.components_
Hn = H / (H.sum(axis=1, keepdims=True) + 1e-12)

W_df = pd.DataFrame(
    W,
    columns=[f"Topic_{k}" for k in range(K_TOPICS)]
)
W_df[dx_label] = dx.values
W_df['FULL_ID'] = train_copy['PTID'].values

H_df = pd.DataFrame(
    H,
    columns=region_cols,
    index=[f'Topic_{k}' for k in range(K_TOPICS)]
)

print('W :', W_df.shape)
print('H :', H_df.shape)

## Visualization
nmf_visualizer = CopathologyVisualizer(
    output_dir=f'C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/results/nmf_{PARAM}-raw_va/{K_TOPICS}/figures',
)
dx_radar = nmf_visualizer.plot_diagnosis_topic_profiles(
    theta = W_df.loc[:,[f"Topic_{k}" for k in range(K_TOPICS)]],
    dx_labels=W_df[dx_label]
)
top_topics = nmf_visualizer.plot_top_regions_per_topic(
    topic_patterns=Hn,
    region_names=dkt_rois[-62:]
)
surface_mapper = BrainVisualizer(
    output_dir=f'C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/results/nmf_{PARAM}-raw_va/{K_TOPICS}/surface_maps',
    # clim=(0,Hn.max())
    clim=(0,0.03)
)
surface_mapper.plot_all_topics(
    topic_patterns=Hn
)

In [None]:
## XGB Classifier ##
# Fit and cross-validate a TopicClassifier on the NMF topic weights, then plot
# the confusion matrix and the per-subject probability heatmap.
from LDA_XGB.classifier import TopicClassifier
from LDA_XGB.visualizer import CopathologyVisualizer

nmf_classifier = TopicClassifier(
    n_splits=5
)

# Use every fitted topic, driven by K_TOPICS. The previous hard-coded
# 'Topic_0':'Topic_5' slice only matched K_TOPICS == 6.
theta = W_df.loc[:, [f"Topic_{k}" for k in range(K_TOPICS)]].values
y = W_df[dx_label].values
nmf_classifier.fit(theta, y)
cv_results = nmf_classifier.cross_validate(
    theta, y, subject_ids=W_df['FULL_ID']
)
print(cv_results['accuracy'])
r2 = cv_results['results_df']

# label_order / dx_order assume the AD_label setup (NC / AD / non-AD).
nmf_visualizer.plot_confusion_matrix(
    cm=cv_results['confusion_matrix'],
    class_names = list(nmf_classifier.classes),
    accuracy=cv_results['accuracy'],
    label_order=['NC', 'AD', 'non-AD']
)
nmf_visualizer.plot_probability_heatmap(cv_results['results_df'], dx_order=['NC', 'AD', 'non-AD'])


In [None]:
# Attach the cross-validated predictions to the training rows (matched on PTID).
# Then show which original DX the true 'non-AD' subjects come from, split by the
# class they were predicted as.
r2 = cv_results['results_df']
r2 = pd.merge(train_copy, r2, left_on='PTID', right_on='subject_id', how='left')
# r2.to_csv('nmf_cv_results.csv', index=False)


def dx_proportions(df, dx_col="DX"):
    """Return each ``dx_col`` value's share of ``df`` (NaN counted) as columns [dx_col, 'prop']."""
    return (df[dx_col]
            .value_counts(dropna=False, normalize=True)
            .rename_axis(dx_col)
            .reset_index(name="prop"))


true_non_ad = r2['DX_true'] == 'non-AD'
box1 = r2[(r2['DX_pred'] == 'NC') & true_non_ad]
print(box1.shape)
box2 = r2[(r2['DX_pred'] == 'AD') & true_non_ad]
print(box2.shape)
box3 = r2[(r2['DX_pred'] == 'non-AD') & true_non_ad]
print(box3.shape)

boxes = {
    "Pred NC | True non-AD": box1,
    "Pred AD | True non-AD": box2,
    "Pred non-AD | True non-AD": box3,
}

# One row per (box, DX) holding that DX's share inside the box. dx_proportions
# names the DX column explicitly, so this no longer depends on how the pandas
# version names the value_counts() output (the old "index" -> "DX" rename).
prop_long = pd.concat(
    [dx_proportions(box_df).assign(box=box_name) for box_name, box_df in boxes.items()],
    ignore_index=True,
)

prop_wide = prop_long.pivot(index="box", columns="DX", values="prop").fillna(0)

bar_order = ["Pred NC | True non-AD", "Pred AD | True non-AD", "Pred non-AD | True non-AD"]  # desired bar order
ax = prop_wide.reindex(bar_order).plot(kind="bar", stacked=True, figsize=(10, 4), colormap="tab20")
ax.set_ylabel("Proportion")
ax.set_xlabel("")
plt.setp(ax.get_xticklabels(), rotation=0, ha="center")
ax.legend(title="DX", bbox_to_anchor=(1.02, 1), loc="upper left")
plt.tight_layout()
plt.show()

In [None]:
## ALL DX MODEL ## log(PARAM-raw_Va) - OUTDATED
# Legacy variant: NMF on log1p(PARAM - VA) over the class-balanced all-DX set.
# Results are stored under *_log names. Running this cell used to overwrite
# W_df / H / Hn / nmf_visualizer / surface_mapper, which the NACC inference
# cells combine with nmf_transormed_va from the main PARAM - VA model, so the
# two models were silently mixed.
print('K-topics =', K_TOPICS)
print(f'Subtracting Raw VA from {PARAM}')
region_cols = train_all_dx.loc[:, 'VA/1002':'VA/2035'].columns.to_list()

# .copy(): the former alias transformed train_all_dx itself, so a re-run
# applied log1p(PARAM - .) to already-transformed data.
train_log_copy = train_all_dx.copy()
train_log_copy[region_cols] = np.log1p(PARAM-train_log_copy[region_cols])
dx_log = train_log_copy['DX']

nmf_log = NMF(
    n_components=K_TOPICS,
    init="nndsvda",
    solver="cd",
    max_iter=2000,
    random_state=42
)
W_log = nmf_log.fit_transform(train_log_copy[region_cols])
H_log = nmf_log.components_
Hn_log = H_log / (H_log.sum(axis=1, keepdims=True) + 1e-12)

W_log_df = pd.DataFrame(
    W_log,
    columns=[f"Topic_{k}" for k in range(K_TOPICS)]
)
W_log_df['DX'] = dx_log.values
H_log_df = pd.DataFrame(
    H_log,
    columns=region_cols,
    index=[f'Topic_{k}' for k in range(K_TOPICS)]
)

print('W :', W_log_df.shape)
print('H :', H_log_df.shape)

## Visualization
nmf_log_visualizer = CopathologyVisualizer(
    output_dir=f'C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/results/nmf_{PARAM}-raw_va_log/{K_TOPICS}/figures',
)
dx_radar_log = nmf_log_visualizer.plot_diagnosis_topic_profiles(
    theta = W_log_df.loc[:,[f"Topic_{k}" for k in range(K_TOPICS)]],
    dx_labels=W_log_df['DX']
)
top_topics_log = nmf_log_visualizer.plot_top_regions_per_topic(
    topic_patterns=Hn_log,
    region_names=dkt_rois[-62:]
)
surface_mapper_log = BrainVisualizer(
    output_dir=f'C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/results/nmf_{PARAM}-raw_va_log/{K_TOPICS}/surface_maps',
    clim=(0,Hn_log.max())
)
surface_mapper_log.plot_all_topics(
    topic_patterns=Hn_log
)

### NACC INFERENCE

In [None]:
# Rank the NMF topics by how much more (or less) they load in AD than in the
# other labels. The AD-related topics are used for the VA reconstruction below.
topic_cols = [c for c in W_df.columns if c.startswith('Topic_')]
topic_means = W_df.groupby(dx_label)[topic_cols].mean()

# AD mean minus the unweighted average of every other label's mean
# (with dx_label == 'AD_label' that is NC and non-AD together).
ad_mean = topic_means.loc["AD"]
non_ad_mean = topic_means.drop(index="AD").mean()
ad_contrast = (ad_mean - non_ad_mean).sort_values(ascending=False)

print(ad_contrast)

# A positive contrast marks a topic as AD-related and a negative one as non-AD.
# A contrast of exactly 0 (or NaN) lands in neither list.
ad_topics = [topic for topic, diff in ad_contrast.items() if diff > 0]
non_ad_topics = [topic for topic, diff in ad_contrast.items() if diff < 0]
print("AD-related topics:", ad_topics)

# 'Topic_3' -> 3: column index into W / row index into H.
ad_topic_idx = [int(topic.rsplit("_", 1)[-1]) for topic in ad_topics]
non_ad_topic_idx = [int(topic.rsplit("_", 1)[-1]) for topic in non_ad_topics]
print(ad_topic_idx)
print(non_ad_topic_idx)

In [None]:
# NACC external inference: load the cohort, add the tau-stage resilience
# groups, project VA onto the trained NMF topics, classify the topic weights
# with nmf_classifier, and merge everything into nmf_results.
# nacc_df_orig = pd.read_csv(data_path+'/nacc/NACC_external_260212_all_dx_lda.csv')
# nacc_df_orig = pd.read_csv(data_path+'/nacc/NACC_external_260212_ad_nonad_lda.csv')
##################################
nacc_data_path = 'C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/data/nacc'
# nacc_df_orig = pd.read_csv(os.path.join(nacc_data_path, 'NACC_external_260212.csv'))
nacc_df_orig = pd.read_csv(os.path.join(nacc_data_path, 'NACC_external_260212_inferenced_ref_model.csv'))
# Adds pred_tau_stage_aa/group and tau_stage_aa/res_group (Overpredict / Canonical / Underpredict).
nacc_df_orig = aa_stage_resilient_group(nacc_df_orig)

print(nacc_df_orig.shape)

# Drop subjects with any missing VA value over the full VA/2 .. VA/2035 range.
va_cols = nacc_df_orig.loc[:, 'VA/2':'VA/2035'].columns
nacc_df_orig = nacc_df_orig.dropna(subset=va_cols)
print(nacc_df_orig.shape)
#################################

# Apply the training transform (PARAM - VA) to the model's regions and project
# onto the fitted topics. sklearn's NMF.transform rejects negative input, so
# this assumes every NACC VA value is <= PARAM.
nacc_inf_df = nacc_df_orig.dropna(subset=region_cols)
print(nacc_inf_df.shape)
X_ext = nacc_inf_df[region_cols]
Y_ext = PARAM-X_ext
W_ext = nmf_transormed_va.transform(Y_ext)
W_ext_df = pd.DataFrame(
    W_ext,
    columns=[f"nmf/Topic_{k}" for k in range(W_ext.shape[1])],
    index=nacc_inf_df.index
)
W_ext_df.insert(0, 'FULL_ID', nacc_inf_df['FULL_ID'].values)

# VA rebuilt from the AD-related topics only: PARAM - W_ext[:, ad] @ H[ad, :].
# H and ad_topic_idx must come from the same fit as nmf_transormed_va.
ad_decomposed = W_ext[:, ad_topic_idx] @ H[ad_topic_idx, :]
X_ext_weighted = PARAM - ad_decomposed
X_ext_weighted_df = pd.DataFrame(
    X_ext_weighted,
    columns=region_cols,
    index=nacc_inf_df.index
)

# Copy of the NACC rows with region_cols replaced by the AD-only reconstruction.
nacc_df_orig_weighted = nacc_inf_df.copy()
nacc_df_orig_weighted.loc[:, region_cols] = X_ext_weighted_df.loc[nacc_df_orig_weighted.index, region_cols]

# Classify the projected topic weights. The columns starting with "P(" and the
# DX_pred column are renamed with an "nmf_" prefix.
# NOTE(review): both merges in this cell key on FULL_ID. If a FULL_ID can occur
# on several rows (e.g. repeat scans), merge pairs every matching row and
# duplicates subjects; consider validate='one_to_one' if IDs should be unique.
xgb_results = nmf_classifier.predict_with_proba(
    W_ext,
    subject_ids=W_ext_df['FULL_ID'].values
)
W_ext_df = W_ext_df.merge(
    xgb_results,
    left_on='FULL_ID',
    right_on='subject_id',
    how='left'
).drop(columns='subject_id')
prob_cols = [c for c in W_ext_df.columns if c.startswith("P(")]
W_ext_df.rename(
    columns={c: f"nmf_{c}" for c in prob_cols},
    inplace=True
)
W_ext_df.rename(columns={'DX_pred': 'nmf_DX_pred'}, inplace=True)
# nacc_df_orig_weighted.to_csv(data_path+'/nacc/nacc_stage_external/NACC_external_260206_nmf_adjusted.csv', index=False) ## reconstructed VA

# Left merge keeps every VA-complete NACC subject; subjects dropped for missing
# region values get NaN topic weights and probabilities.
nmf_results = pd.merge(nacc_df_orig, W_ext_df, on='FULL_ID', how='left')
print(nmf_results.shape)
# nmf_results.to_csv(data_path+'/nacc/NACC_external_260212_all_dx_lda_nmf.csv', index=False)
# nmf_results.to_csv(data_path+'/nacc/NACC_external_260212_ad_nonad_lda_nmf.csv', index=False)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr
from LDA_XGB.visualizer import CopathologyVisualizer

def nacc_copath_radar_plot(inp_df, copath_col='copath', copath_count_col='copath_counts', y_max=0.35, prefix='',
                           output_dir='C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath/results/nacc_copath_figures'):
    """Plot NMF topic profiles per co-pathology label, once for each co-pathology count.

    For each distinct value of ``copath_count_col``, the matching subjects are
    passed to ``CopathologyVisualizer.plot_diagnosis_topic_profiles``. The
    ``nmf/Topic_*`` columns are passed as ``theta`` and ``topic_names``, and
    ``copath_col`` as the labels.

    Parameters
    ----------
    inp_df : pandas.DataFrame
        Must contain ``copath_col``, ``copath_count_col`` and the ``nmf/Topic_*`` columns.
    copath_col : str
        Label column. The substring "FTD_ANY" is rewritten to "FTD".
    copath_count_col : str
        Numeric column that splits subjects into one plot per value.
    y_max : float
        Radial limit forwarded to the visualizer.
    prefix : str
        Text prepended to each plot title.
    output_dir : str
        Directory handed to ``CopathologyVisualizer``. Plots are drawn with ``save=False``.

    Returns
    -------
    The visualizer's return value for the last count plotted, or None when no
    rows remain after dropping missing labels or counts.
    """
    # Drop rows without a count as well: the title calls int(count), which
    # raised on NaN. The explicit .copy() means the label rewrite below never
    # writes through to inp_df and does not trigger SettingWithCopyWarning.
    plot_df = inp_df.dropna(subset=[copath_col, copath_count_col]).copy()
    plot_df[copath_col] = plot_df[copath_col].str.replace("FTD_ANY", "FTD", regex=False) ## HARD-CODING
    # Every per-count subset has the same columns, so compute the topic columns once.
    topic_cols = [c for c in plot_df.columns if c.startswith('nmf/Topic_')]

    copath_radar = None  # stays None if there is nothing to plot
    for i in plot_df[copath_count_col].unique():
        print(i)
        temp_df = plot_df[plot_df[copath_count_col] == i]
        vis = CopathologyVisualizer(
            output_dir=output_dir
        )
        print(temp_df.shape)
        print(topic_cols)
        copath_radar = vis.plot_diagnosis_topic_profiles(
            theta=temp_df[topic_cols],
            dx_labels=temp_df[copath_col],
            topic_names=topic_cols,
            y_max=y_max,
            title = f'{prefix} Diagnostics Topic Profiles ({copath_col}={int(i)})',
            save=False
        )
    # plt.show()
    return copath_radar


nacc_copath_radar_plot(nmf_results, y_max=0.7, prefix='NACC')

In [None]:
# CN subjects, non-MK6240 tracers: compare the classifier's class probabilities
# across the tau-stage resilience subgroups.
# Use the probability columns the fitted classifier actually produced. The
# previous hard-coded 8-class list raised KeyError with the NC / AD / non-AD
# model. Known classes keep this preferred order; any others follow.
preferred_prob_cols = ['nmf_P(NC)', 'nmf_P(AD)', 'nmf_P(PD)', 'nmf_P(DLB)', 'nmf_P(SVAD)', 'nmf_P(bvFTD)', 'nmf_P(nfvPPA)', 'nmf_P(svPPA)']
available_prob_cols = [c for c in nmf_results.columns if c.startswith('nmf_P(')]
prob_cols = ([c for c in preferred_prob_cols if c in available_prob_cols]
             + [c for c in available_prob_cols if c not in preferred_prob_cols])
group_col = 'tau_stage_aa/res_group'
group_order = ['Underpredict', 'Canonical', 'Overpredict']
plot_df = nmf_results.dropna(subset=group_col)
plot_df = plot_df[plot_df['TRACER']!='MK6240']
print(plot_df.shape)
# .copy(): hand the plotting helpers an independent frame rather than a
# filtered slice of nmf_results.
plot_df = plot_df[plot_df['DX'].isin(['CN'])].copy()
print(plot_df.shape)

res_group_mean_heatmap(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')
res_group_subjectwise_heatmap(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')
res_group_topic_radar_plot(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')

In [None]:
# MCI / AD subjects, non-MK6240 tracers: compare the classifier's class
# probabilities across the tau-stage resilience subgroups.
# Use the probability columns the fitted classifier actually produced, with
# NC / AD / non-AD first in this order and any other classes after.
preferred_prob_cols = ['nmf_P(NC)', 'nmf_P(AD)', 'nmf_P(non-AD)']
available_prob_cols = [c for c in nmf_results.columns if c.startswith('nmf_P(')]
prob_cols = ([c for c in preferred_prob_cols if c in available_prob_cols]
             + [c for c in available_prob_cols if c not in preferred_prob_cols])
group_col = 'tau_stage_aa/res_group'
group_order = ['Underpredict', 'Canonical', 'Overpredict']
plot_df = nmf_results.dropna(subset=group_col)
plot_df = plot_df[plot_df['TRACER']!='MK6240']
print(plot_df.shape)
# .copy(): hand the plotting helpers an independent frame rather than a
# filtered slice of nmf_results.
plot_df = plot_df[plot_df['DX'].isin(['MCI', 'AD'])].copy()
print(plot_df.shape)

res_group_mean_heatmap(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')
res_group_subjectwise_heatmap(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')
res_group_topic_radar_plot(plot_df, prob_cols=prob_cols, group_col=group_col, group_order=group_order, prefix='NACC')