In [None]:
import os
import numpy as np
import pandas as pd
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedKFold, cross_val_predict
import matplotlib.pyplot as plt

# Atlas selection. Each atlas is a contiguous, label-sliced run of ROI columns.
regions = 'dkt95'
# regions = 'dkt62'

# Data directory. Set the COPATH_DATA_DIR environment variable to run on
# another machine; the default reproduces the original hard-coded location.
DATA_DIR = os.environ.get(
    "COPATH_DATA_DIR",
    "C:/Users/BREIN/Desktop/copathology_visualization_temp/data",
)

# (first, last) column labels of the ROI slice, per atlas and per table layout.
# 'va'      -> tables whose ROI columns are named 'VA/<id>'
# 'wsev_hc' -> the WSEV healthy-control table, which uses anatomical names
_ROI_SLICES = {
    'dkt62': {
        'va': ('VA/1002', 'VA/2035'),
        'wsev_hc': ('ctx_lh_caudalanteriorcingulate', 'ctx_rh_insula'),
    },
    'dkt95': {
        'va': ('VA/2', 'VA/2035'),
        'wsev_hc': ('Left_Cerebral_White_Matter', 'ctx_rh_insula'),
    },
}
if regions not in _ROI_SLICES:
    # Previously an unknown value skipped both branches and surfaced later
    # as a NameError on region_cols.
    raise ValueError(f"regions must be one of {sorted(_ROI_SLICES)}, got {regions!r}")

_va_first, _va_last = _ROI_SLICES[regions]['va']
_hc_first, _hc_last = _ROI_SLICES[regions]['wsev_hc']


def _control_stats(X_ctrl):
    """Return column-wise (mean, std) of a control matrix, each shaped (1, n_roi).

    The std is the numpy default (ddof=0) plus 1e-8 to avoid divide-by-zero.
    """
    mean = X_ctrl.mean(axis=0, keepdims=True)
    std = X_ctrl.std(axis=0, keepdims=True) + 1e-8
    return mean, std


def _atrophy_scores(X_pat, ctrl_mean, ctrl_std, cohort):
    """Z-score patients against controls and keep only the below-control part.

    Returns (Z, atrophy) with atrophy = max(-Z, 0), so values are >= 0 and a
    larger value means the ROI is further below the control mean.

    Raises:
        ValueError: if patients and controls have a different number of ROI
            columns (a width-1 control slice would otherwise broadcast silently).
    """
    if X_pat.shape[1] != ctrl_mean.shape[1]:
        raise ValueError(
            f"{cohort}: patient matrix has {X_pat.shape[1]} ROI columns but the "
            f"control baseline has {ctrl_mean.shape[1]}"
        )
    Z = (X_pat - ctrl_mean) / ctrl_std
    return Z, np.maximum(-Z, 0.0)


## WSEV data prep
new_wsev = pd.read_csv(f"{DATA_DIR}/260108_wsev_final_df.csv")
hc_df = new_wsev[new_wsev['DX'] == 'HC']

df_wsev = pd.read_csv(f"{DATA_DIR}/wsev_old_100_data.csv")
df_wsev = df_wsev.dropna()
print(df_wsev['DX'].value_counts())

region_cols = df_wsev.loc[:, _va_first:_va_last].columns
# NOTE(review): the HC table names ROIs anatomically while the patient table
# uses 'VA/<id>' labels. The z-scoring below assumes both slices list the same
# ROIs in the same order -- only the column count is checked; confirm the order.
X_hc = hc_df[hc_df.loc[:, _hc_first:_hc_last].columns].values.astype(float)
print(f"HC: {X_hc.shape[0]} subjects")
hc_mean, hc_std = _control_stats(X_hc)

X_wsev = df_wsev[region_cols].values.astype(float)
print(f"Patients: {X_wsev.shape[0]} subjects")

Z_wsev, X_wsev = _atrophy_scores(X_wsev, hc_mean, hc_std, "WSEV")

## SMC data prep
df_smc = pd.read_csv(f"{DATA_DIR}/SMC_AD_FTD_VA_final.csv")
df_nc = df_smc[df_smc['DX'] == 'NC']
df_smc_pat = df_smc[df_smc['DX'] != 'NC']
print(df_smc_pat['DX'].value_counts())

# Controls and patients come from the same table here, so one slice serves both.
region_cols = df_smc_pat.loc[:, _va_first:_va_last].columns
X_nc = df_nc[region_cols].values.astype(float)
print(f"NC: {X_nc.shape[0]} subjects")
nc_mean, nc_std = _control_stats(X_nc)

X_smc = df_smc_pat[region_cols].values.astype(float)
print(f"Patients: {X_smc.shape[0]} subjects")

Z_smc, X_smc = _atrophy_scores(X_smc, nc_mean, nc_std, "SMC")


In [None]:
## combine cohorts
meta_cols = ["SUBJ_ID", "DX"]
# Both cohorts are labelled with whatever region_cols currently holds
# (the value left by the data-prep cell).
roi_cols = list(region_cols)


def _labelled_scores(scores, source):
    """Wrap an atrophy matrix in a frame led by SUBJ_ID / DX taken from `source`.

    `scores` rows are assumed to be in the same order as `source` rows
    (values are attached positionally via `.values`).
    """
    frame = pd.DataFrame(scores, columns=region_cols)
    frame.insert(0, "SUBJ_ID", source["PTID"].values)
    frame.insert(1, "DX", source["DX"].values)
    return frame[meta_cols + roi_cols]


# WSEV cohort
df_wsev_x = _labelled_scores(X_wsev, df_wsev)

# SMC cohort
df_smc_x = _labelled_scores(X_smc, df_smc_pat)

# Stack the cohorts row-wise with a fresh index, then drop incomplete rows.
df_combined = pd.concat(
    [df_wsev_x, df_smc_x],
    axis=0,
    ignore_index=True,
)
df_combined = df_combined.dropna()

print("Combined shape:", df_combined.shape)
print(df_combined['DX'].value_counts())


In [None]:
import pandas as pd
import numpy as np
from data_processor import *
from lda_model import LDATopicModel
from classifier import TopicClassifier
from visualizer import *
from brain_visualizer import *

N_TOPICS = 18 ###
inp_df = pd.read_csv('C:/Users/BREIN/Desktop/copathology_visualization_temp/data/260120_wsev_smc_combined_zscores.csv')

# Cap every diagnosis at N subjects; smaller groups are kept whole.
N = 25
dx_col = "DX"
balanced_parts = [
    grp if len(grp) <= N else grp.sample(n=N, replace=False, random_state=42)
    for _, grp in inp_df.groupby(dx_col)
]

balanced_df = pd.concat(balanced_parts).reset_index(drop=True)
print(balanced_df[dx_col].value_counts())

# From here on the balanced subset is the model input.
inp_df = balanced_df

# ROI feature columns, diagnosis labels and subject ids of the balanced set
region_cols = list(inp_df.loc[:, "VA/2":"VA/2035"].columns)
labels = inp_df["DX"].values
ids = inp_df["SUBJ_ID"].values

# Topic model on the combined z-scores -> per-subject topic weights
lda = LDATopicModel(n_topics=N_TOPICS)
theta = lda.fit_transform(inp_df[region_cols])

# Classifier on the topic weights, then its cross-validation report
classifier = TopicClassifier(n_splits=5)
classifier.fit(theta, labels)
cv_results = classifier.cross_validate(theta, labels, ids, verbose=True)

print(f"CV Accuracy: {cv_results['accuracy']:.4f}")

In [None]:
## Model Internal Visualization ##
# print(classifier.get_confusion_matrix())
# print(lda.get_topic_patterns())

# Topic patterns of the LDA model fitted in the training cell. Within the
# visible notebook this is only consumed by the commented-out plotting calls
# that follow.
topic_patterns = lda.get_topic_patterns()


# fig4 = visualizer.plot_confusion_matrix(
#     cm=classifier.get_confusion_matrix(),
#     class_names=classifier._classes
# )
# fig = visualizer.plot_topic_heatmap(
#     topic_patterns=topic_patterns,
#     region_names=region_cols,
#     topic_names=['Thal', 'LF', 'P', 'RT', 'RF', 'LT']
# )
# fig2 = visualizer.plot_topic_distribution_by_dx(
#     theta=lda._theta,
#     dx_labels=labels,
#     topic_names=['Thal', 'LF', 'P', 'RT', 'RF', 'LT']
# )
# fig3 = visualizer.plot_diagnosis_topic_profiles(
#     theta=lda._theta,
#     dx_labels = labels,
#     label_map={'Topic_0': 'Thal', 'Topic_1': 'LF', 'Topic_2': 'P', 'Topic_3': 'RT', 'Topic_4': 'RF', 'Topic_5': 'LT'}
# )
# fig_conf_matrix = visualizer.plot_confusion_matrix(
#     cm=classifier.get_confusion_matrix(),
#     class_names=classifier._classes,
#     accuracy=cv_results['accuracy'],
#     title='5-Fold CV Confusion Matrix'
# )
# prediction_probabilities = visualizer.plot_prediction_probabilities(
#     proba_df=classifier.get_cv_results(),
#     dx_order = ["AD", "DLB", "PD", "SVAD", "bvFTD", "nfvPPA", "svPPA"]
# )
# probabilities_heatmap = visualizer.plot_probability_heatmap(
#     proba_df=classifier.get_cv_results(),
#     dx_order = ["AD", "DLB", "PD", "SVAD", "bvFTD", "nfvPPA", "svPPA"]
# )

# top_regions = visualizer.plot_top_regions_per_topic(
#     topic_patterns = lda.get_topic_patterns(),
#     region_names=region_cols,
#     topic_names=['Thal', 'LF', 'P', 'RT', 'RF', 'LT']
# )

# cv_results = classifier.get_cv_results()
# copathology_stacked_bar = visualizer.plot_copathology_stacked_bars(
#     theta=lda._theta,
#     dx_labels=labels,
#     # topic_names=['Topic_0', 'Topic_1', 'Topic_2', 'Topic_3', 'Topic_4', 'Topic_5']
#     label_map={'Topic_0': 'Thal', 'Topic_1': 'LF', 'Topic_2': 'P', 'Topic_3': 'RT', 'Topic_4': 'RF', 'Topic_5': 'LT'},
#     predictions=cv_results["DX_pred"].values,
#     proba_df=cv_results
# )

# feature_importance = visualizer.plot_feature_importance(
#     importance_df=classifier.get_feature_importance()
# )

In [None]:
## ADNI Inference ## 260120
adni_raw = pd.read_csv('C:/Users/BREIN/Desktop/copathology_visualization_temp/data/stage_data/ptau_volume_model/ADNI_3.csv')
region_cols = adni_raw.loc[:, 'VA/2':'VA/2035'].columns
# Drop subjects with missing ROI values up front, as the ADNI4 cell does, so
# the rows fed to the model stay one-to-one with adni_pat.
adni_raw = adni_raw.dropna(subset=region_cols)
adni_cn = adni_raw[adni_raw['DX'] == 'CN']
adni_pat = adni_raw[adni_raw['DX'] != 'CN']

# --------------------------------
# Tau staging: ground truth vs. predicted
# --------------------------------
gt_cols = ['tau_stage_aa/low', 'tau_stage_aa/mid', 'tau_stage_aa/high']
prob_cols = ['pred_tau_stage_aa/low','pred_tau_stage_aa/mid','pred_tau_stage_aa/high']
adni_stage_df = adni_pat[['FULL_ID', 'DX'] + gt_cols + prob_cols].dropna()

# Collapse the predicted stage probabilities to a one-hot of the arg-max column.
max_col = adni_stage_df[prob_cols].idxmax(axis=1)
adni_stage_df[prob_cols] = (
    pd.get_dummies(max_col).reindex(columns=prob_cols, fill_value=0).astype(float)
)

stage_map = {
    'low': 0,
    'mid': 1,
    'high': 2
}
def get_stage(colname):
    """Map a '<prefix>/<low|mid|high>' column name to its ordinal stage 0/1/2."""
    return stage_map[colname.split('/')[-1]]

adni_stage_df['gt_stage'] = adni_stage_df[gt_cols].idxmax(axis=1).apply(get_stage)
# The arg-max of the one-hot is max_col itself, so reuse it directly.
adni_stage_df['pred_stage'] = max_col.apply(get_stage)

# --------------------------------
# Subject grouping
# --------------------------------
_summary_cols = ['FULL_ID', 'DX', 'gt_stage', 'pred_stage']

adni_lower_than_pred = adni_stage_df[
    adni_stage_df['gt_stage'] < adni_stage_df['pred_stage']
][_summary_cols]

adni_exact_match = adni_stage_df[
    adni_stage_df['gt_stage'] == adni_stage_df['pred_stage']
][_summary_cols]

## Inference ##
adni_prep = DataProcessor(
    region_cols=region_cols,
    dx_col='DX',
    subject_col='FULL_ID'
)
adni_prep.fit_baseline(hc_data=adni_cn)
adni_Z = adni_prep.compute_atrophy_scores(data=adni_pat)
print(adni_Z.shape)
print(adni_cn.shape)

adni_theta = lda.transform(adni_Z)
adni_y_pred = classifier.predict(adni_theta)
adni_y_proba = classifier.predict_proba(adni_theta)

# Subject ids are attached positionally below, so the model output must have
# exactly one row per patient.
assert len(adni_theta) == len(adni_pat), "topic weights are not aligned with adni_pat rows"

adni_results = pd.DataFrame(adni_theta, columns=[f"Topic_{k}" for k in range(lda.n_topics)])
print(adni_results.shape)

subj_col = adni_prep.subject_col
if subj_col in adni_pat.columns:
    adni_results.insert(0, "SUBJ_ID", adni_pat[subj_col].values)

adni_results['pred_DX'] = adni_y_pred
for i, dx in enumerate(classifier.classes):
    adni_results[f"P({dx})"] = adni_y_proba[:,i]

In [None]:
# ============================================================
# FIGURES 1 & 2 — per-subject DX probabilities and topic weights
# ============================================================
# NOTE(review): this cell reads df, min_subjects, target_dx, prob_cols,
# colors_dx, dx_labels, label_map and sns. In this notebook those appear to be
# first assigned in the cell that follows (unless a star import provides
# them), so on a fresh kernel that cell has to run first. Before it runs,
# prob_cols still holds the tau-stage columns set by the ADNI inference cell.


def _order_by_dx(frame, dx_values, sort_col):
    """Group rows by DX (in `dx_values` order), sorting each group by `sort_col` descending.

    Returns (ordered_frame, boundaries) where boundaries is a list of
    (start, end, dx) half-open row ranges into ordered_frame.
    """
    parts, boundaries, start = [], [], 0
    for dx in dx_values:
        part = frame[frame["DX"] == dx].sort_values(sort_col, ascending=False)
        parts.append(part)
        boundaries.append((start, start + len(part), dx))
        start += len(part)
    ordered = pd.concat(parts) if parts else frame.iloc[0:0]
    return ordered, boundaries


def _plot_stacked_by_dx(frame, value_cols, colors, labels, boundaries,
                        ylabel, title, legend_title):
    """Draw one stacked bar per row of `frame`, with separators and labels per DX range."""
    fig, ax = plt.subplots(figsize=(12, 4))
    x = np.arange(len(frame))
    bottom = np.zeros(len(frame))
    for col, color, label in zip(value_cols, colors, labels):
        heights = frame[col].values
        ax.bar(x, heights, bottom=bottom, color=color, width=1.0, label=label)
        bottom += heights

    ax.set_ylim(0, 1)
    ax.set_xticks([])  # DX labels are drawn manually below

    for start, end, dx in boundaries:
        ax.axvline(start - 0.5, color='black', linewidth=1.2)
        ax.text((start + end) / 2 - 0.5, -0.05, dx,
                ha='center', va='top', fontsize=12)
    if boundaries:
        ax.axvline(boundaries[-1][1] - 0.5, color='black', linewidth=1.2)

    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(title=legend_title, bbox_to_anchor=(1.02, 1), loc="upper left")

    plt.tight_layout()
    plt.show()


# ------------------------------------------------------------
# FIGURE 1 — individual subject DX profiles (grouped by DX)
# ------------------------------------------------------------
subject_order_dict = {}   # group -> SUBJ_ID order used in Figure 1
_ordered_by_group = {}    # group -> (ordered frame, DX boundaries), reused by Figure 2

for group in df["stage_group"].unique():
    df_g = df[df["stage_group"] == group]
    if df_g.shape[0] < min_subjects:
        continue

    # Rows are ordered directly (no SUBJ_ID index lookup), so a repeated
    # SUBJ_ID cannot multiply rows.
    ordered, dx_boundaries = _order_by_dx(df_g, df_g["DX"].unique(), f"P({target_dx})")
    subject_order_dict[group] = ordered["SUBJ_ID"].tolist()
    _ordered_by_group[group] = (ordered, dx_boundaries)

    _plot_stacked_by_dx(
        ordered, prob_cols, colors_dx, dx_labels, dx_boundaries,
        ylabel="Predicted DX probability",
        title=f"{group} subjects — grouped by DX, ordered by P({target_dx})",
        legend_title="Predicted DX",
    )


# ------------------------------------------------------------
# FIGURE 2 — topic distribution per subject
# ------------------------------------------------------------
# Stack every topic column present, so the bars are not limited to the topics
# that happen to have a display name in label_map.
topic_cols = [c for c in df.columns if c.startswith("Topic_")]
topic_labels = [label_map.get(t, t) for t in topic_cols]
# Dark2 has 8 colours; switch to tab20 when there are more topics than that.
colors_topic = sns.color_palette("Dark2" if len(topic_cols) <= 8 else "tab20", len(topic_cols))

# Only groups that were actually plotted in Figure 1 have an order; iterating
# over them avoids a KeyError for groups skipped by the min_subjects filter.
for group, (ordered, dx_boundaries) in _ordered_by_group.items():
    _plot_stacked_by_dx(
        ordered, topic_cols, colors_topic, topic_labels, dx_boundaries,
        ylabel="Topic weight",
        title=f"{group} subjects — topic distribution grouped by DX",
        legend_title="Topics",
    )


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# ============================================================
# SETTINGS
# ============================================================
# Display names for topics; a topic without an entry is shown by column name.
label_map = {'Topic_0': 'Thal', 'Topic_1': 'LF', 'Topic_2': 'P',
             'Topic_3': 'RT', 'Topic_4': 'RF', 'Topic_5': 'LT'}
target_dx = "AD"
min_subjects = 5
# Clinical DX groups drawn left to right; subjects with any other DX are not plotted.
dx_values = ['MCI', 'Dementia']

prob_cols = [c for c in adni_results.columns if c.startswith("P(")]
dx_labels = [c.replace("P(", "").replace(")", "") for c in prob_cols]

# Stack every topic the model produced, so the bars are not limited to the
# topics that happen to have a display name in label_map.
topic_cols = [c for c in adni_results.columns if c.startswith("Topic_")]
topic_labels = [label_map.get(t, t) for t in topic_cols]

colors_dx = sns.color_palette("tab10", len(dx_labels))
# Dark2 has 8 colours; switch to tab20 when there are more topics than that.
colors_topic = sns.color_palette("Dark2" if len(topic_cols) <= 8 else "tab20", len(topic_cols))

groups = {
    "GT < Pred": adni_lower_than_pred,
    "GT = Pred": adni_exact_match
}


def _order_by_dx(frame, dx_values, sort_col):
    """Group rows by DX (in `dx_values` order), sorting each group by `sort_col` descending.

    Returns (ordered_frame, boundaries) where boundaries is a list of
    (start, end, dx) half-open row ranges into ordered_frame.
    """
    parts, boundaries, start = [], [], 0
    for dx in dx_values:
        part = frame[frame["DX"] == dx].sort_values(sort_col, ascending=False)
        parts.append(part)
        boundaries.append((start, start + len(part), dx))
        start += len(part)
    ordered = pd.concat(parts) if parts else frame.iloc[0:0]
    return ordered, boundaries


def _plot_stacked_by_dx(frame, value_cols, colors, labels, boundaries,
                        ylabel, title, legend_title):
    """Draw one stacked bar per row of `frame`, with separators and labels per DX range."""
    fig, ax = plt.subplots(figsize=(12, 4))
    x = np.arange(len(frame))
    bottom = np.zeros(len(frame))
    for col, color, label in zip(value_cols, colors, labels):
        heights = frame[col].values
        ax.bar(x, heights, bottom=bottom, color=color, width=1.0, label=label)
        bottom += heights

    ax.set_ylim(0, 1)
    ax.set_xticks([])  # DX labels are drawn manually below

    for start, end, dx in boundaries:
        ax.axvline(start - 0.5, color='black', linewidth=1.2)
        ax.text((start + end) / 2 - 0.5, -0.05, dx,
                ha='center', va='top', fontsize=12)
    if boundaries:
        ax.axvline(boundaries[-1][1] - 0.5, color='black', linewidth=1.2)

    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(title=legend_title, bbox_to_anchor=(1.02, 1), loc="upper left")

    plt.tight_layout()
    plt.show()


# ============================================================
# MERGE STAGING + INFERENCE
# ============================================================

stage_info = pd.concat(
    [
        adni_lower_than_pred.assign(stage_group="GT < Pred"),
        adni_exact_match.assign(stage_group="GT = Pred")
    ],
    axis=0
)

df = adni_results.merge(
    stage_info,
    left_on="SUBJ_ID",
    right_on="FULL_ID",
    how="inner"
)

print("Subjects included:", df.shape[0])

# ============================================================
# FIGURE 1 — INDIVIDUAL SUBJECT DX PROFILES (grouped by DX)
# ============================================================

subject_order_dict = {}   # group -> SUBJ_ID order used in Figure 1
_ordered_by_group = {}    # group -> (ordered frame, DX boundaries), reused by Figure 2

for group in df["stage_group"].unique():
    df_g = df[df["stage_group"] == group]
    if df_g.shape[0] < min_subjects:
        continue

    # Rows are ordered directly (no SUBJ_ID index lookup), so a repeated
    # SUBJ_ID cannot multiply rows.
    ordered, dx_boundaries = _order_by_dx(df_g, dx_values, f"P({target_dx})")
    subject_order_dict[group] = ordered["SUBJ_ID"].tolist()
    _ordered_by_group[group] = (ordered, dx_boundaries)

    _plot_stacked_by_dx(
        ordered, prob_cols, colors_dx, dx_labels, dx_boundaries,
        ylabel="Predicted DX probability",
        title=f"{group} subjects — grouped by DX, ordered by P({target_dx})",
        legend_title="Predicted DX",
    )

# ============================================================
# FIGURE 2 — TOPIC DISTRIBUTION PER SUBJECT
# ============================================================

# Only groups that were actually plotted in Figure 1 have an order; iterating
# over them avoids a KeyError for groups skipped by the min_subjects filter.
for group, (ordered, dx_boundaries) in _ordered_by_group.items():
    _plot_stacked_by_dx(
        ordered, topic_cols, colors_topic, topic_labels, dx_boundaries,
        ylabel="Topic weight",
        title=f"{group} subjects — topic distribution grouped by DX",
        legend_title="Topics",
    )

# ============================================================
# FIGURE 3 — MEAN DX COMPOSITION
# ============================================================

mean_df = (
    df.groupby("stage_group")[prob_cols]
    .mean()
    .reset_index()
)

fig, ax = plt.subplots(figsize=(7, 4))
bottom = np.zeros(mean_df.shape[0])

for col, color, label in zip(prob_cols, colors_dx, dx_labels):
    ax.bar(
        mean_df["stage_group"],
        mean_df[col],
        bottom=bottom,
        color=color,
        label=label
    )
    bottom += mean_df[col].values

ax.set_ylim(0, 1)
ax.set_ylabel("Mean predicted probability")
ax.set_title("Predicted DX enrichment by tau-stage mismatch")

ax.legend(
    title="Predicted DX",
    bbox_to_anchor=(1.02, 1),
    loc="upper left"
)

plt.tight_layout()
plt.show()

# ============================================================
# FIGURE 4 — COPATHOLOGY INDEX
# ============================================================

# Non-AD copathology signal = total probability of every class except the
# target. Compare whole column names: a substring test on "AD" would also
# drop any other class whose name contains "AD" (e.g. "SVAD").
non_ad_cols = [c for c in prob_cols if c != f"P({target_dx})"]
df["copath_index"] = df[non_ad_cols].sum(axis=1)

plt.figure(figsize=(6, 4))
sns.boxplot(
    data=df,
    x="stage_group",
    y="copath_index"
)

plt.ylabel("Non-AD copathology probability")
plt.title("Copathology enrichment in tau-stage mismatch")
plt.tight_layout()
plt.show()


In [None]:
## ADNI4 Inference ## 260120
adni4_raw = pd.read_csv('C:/Users/BREIN/Desktop/copathology_visualization_temp/data/stage_data/ptau_volume_model/ADNI4_3.csv')
region_cols = adni4_raw.loc[:, 'VA/2':'VA/2035'].columns
adni4_raw = adni4_raw.dropna(subset=region_cols)  # keep only rows with complete ROI values

# Split into the control baseline (CN) and everyone else.
_is_cn = adni4_raw['DX'] == 'CN'
adni4_cn = adni4_raw[_is_cn]
adni4_pat = adni4_raw[~_is_cn]

# --------------------------------
# Tau staging: ground truth vs. predicted
# --------------------------------
_gt_cols = ['tau_stage_aa/low', 'tau_stage_aa/mid', 'tau_stage_aa/high']
prob_cols = ['pred_tau_stage_aa/low','pred_tau_stage_aa/mid','pred_tau_stage_aa/high']
adni4_stage_df = adni4_pat[['FULL_ID', 'DX'] + _gt_cols + prob_cols].dropna()

# Replace the predicted probabilities by a one-hot of the arg-max column.
max_col = adni4_stage_df[prob_cols].idxmax(axis=1)
_one_hot = pd.get_dummies(max_col).reindex(columns=prob_cols, fill_value=0).astype(float)
adni4_stage_df[prob_cols] = 0
adni4_stage_df.loc[:, prob_cols] = _one_hot

stage_map = {'low': 0, 'mid': 1, 'high': 2}

def get_stage(colname):
    """Translate a '<prefix>/<low|mid|high>' column name into stage 0, 1 or 2."""
    suffix = colname.rsplit('/', 1)[-1]
    return stage_map[suffix]

adni4_stage_df['gt_stage'] = adni4_stage_df[_gt_cols].idxmax(axis=1).map(get_stage)
adni4_stage_df['pred_stage'] = adni4_stage_df[prob_cols].idxmax(axis=1).map(get_stage)

# --------------------------------
# Subject grouping
# --------------------------------
_summary_cols = ['FULL_ID', 'DX', 'gt_stage', 'pred_stage']
_gt, _pred = adni4_stage_df['gt_stage'], adni4_stage_df['pred_stage']

adni4_lower_than_pred = adni4_stage_df.loc[_gt < _pred, _summary_cols]
adni4_exact_match = adni4_stage_df.loc[_gt == _pred, _summary_cols]

# --------------------------------
# Inference with the already fitted LDA model and classifier
# --------------------------------
adni4_prep = DataProcessor(region_cols=region_cols, dx_col='DX', subject_col='FULL_ID')
adni4_prep.fit_baseline(hc_data=adni4_cn)
adni4_Z = adni4_prep.compute_atrophy_scores(data=adni4_pat)
print(adni4_Z.shape)
print(adni4_cn.shape)

adni4_theta = lda.transform(adni4_Z)
adni4_y_pred = classifier.predict(adni4_theta)
adni4_y_proba = classifier.predict_proba(adni4_theta)

_topic_names = [f"Topic_{k}" for k in range(lda.n_topics)]
adni4_results = pd.DataFrame(adni4_theta, columns=_topic_names)
print(adni4_results.shape)

# Subject ids are attached by position, in adni4_pat row order.
subj_col = adni4_prep.subject_col
if subj_col in adni4_pat.columns:
    adni4_results.insert(0, "SUBJ_ID", adni4_pat[subj_col].values)

adni4_results['pred_DX'] = adni4_y_pred
for i, dx in enumerate(classifier.classes):
    adni4_results[f"P({dx})"] = adni4_y_proba[:,i]

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# ============================================================
# SETTINGS
# ============================================================
# Display names for topics; a topic without an entry is shown by column name.
label_map = {'Topic_0': 'Thal', 'Topic_1': 'LF', 'Topic_2': 'P',
             'Topic_3': 'RT', 'Topic_4': 'RF', 'Topic_5': 'LT'}
target_dx = "AD"
min_subjects = 5
# Clinical DX groups drawn left to right; subjects with any other DX are not plotted.
dx_values = ['MCI', 'AD']

prob_cols = [c for c in adni4_results.columns if c.startswith("P(")]
dx_labels = [c.replace("P(", "").replace(")", "") for c in prob_cols]

# Stack every topic the model produced, so the bars are not limited to the
# topics that happen to have a display name in label_map.
topic_cols = [c for c in adni4_results.columns if c.startswith("Topic_")]
topic_labels = [label_map.get(t, t) for t in topic_cols]

colors_dx = sns.color_palette("tab10", len(dx_labels))
# Dark2 has 8 colours; switch to tab20 when there are more topics than that.
colors_topic = sns.color_palette("Dark2" if len(topic_cols) <= 8 else "tab20", len(topic_cols))

groups = {
    "GT < Pred": adni4_lower_than_pred,
    "GT = Pred": adni4_exact_match
}


def _order_by_dx(frame, dx_values, sort_col):
    """Group rows by DX (in `dx_values` order), sorting each group by `sort_col` descending.

    Returns (ordered_frame, boundaries) where boundaries is a list of
    (start, end, dx) half-open row ranges into ordered_frame.
    """
    parts, boundaries, start = [], [], 0
    for dx in dx_values:
        part = frame[frame["DX"] == dx].sort_values(sort_col, ascending=False)
        parts.append(part)
        boundaries.append((start, start + len(part), dx))
        start += len(part)
    ordered = pd.concat(parts) if parts else frame.iloc[0:0]
    return ordered, boundaries


def _plot_stacked_by_dx(frame, value_cols, colors, labels, boundaries,
                        ylabel, title, legend_title):
    """Draw one stacked bar per row of `frame`, with separators and labels per DX range."""
    fig, ax = plt.subplots(figsize=(12, 4))
    x = np.arange(len(frame))
    bottom = np.zeros(len(frame))
    for col, color, label in zip(value_cols, colors, labels):
        heights = frame[col].values
        ax.bar(x, heights, bottom=bottom, color=color, width=1.0, label=label)
        bottom += heights

    ax.set_ylim(0, 1)
    ax.set_xticks([])  # DX labels are drawn manually below

    for start, end, dx in boundaries:
        ax.axvline(start - 0.5, color='black', linewidth=1.2)
        ax.text((start + end) / 2 - 0.5, -0.05, dx,
                ha='center', va='top', fontsize=12)
    if boundaries:
        ax.axvline(boundaries[-1][1] - 0.5, color='black', linewidth=1.2)

    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(title=legend_title, bbox_to_anchor=(1.02, 1), loc="upper left")

    plt.tight_layout()
    plt.show()


# ============================================================
# MERGE STAGING + INFERENCE
# ============================================================

stage_info = pd.concat(
    [
        adni4_lower_than_pred.assign(stage_group="GT < Pred"),
        adni4_exact_match.assign(stage_group="GT = Pred")
    ],
    axis=0
)

df = adni4_results.merge(
    stage_info,
    left_on="SUBJ_ID",
    right_on="FULL_ID",
    how="inner"
)

print("Subjects included:", df.shape[0])

# ============================================================
# FIGURE 1 — INDIVIDUAL SUBJECT DX PROFILES (grouped by DX)
# ============================================================

subject_order_dict = {}   # group -> SUBJ_ID order used in Figure 1
_ordered_by_group = {}    # group -> (ordered frame, DX boundaries), reused by Figure 2

for group in df["stage_group"].unique():
    df_g = df[df["stage_group"] == group]
    if df_g.shape[0] < min_subjects:
        continue

    # Rows are ordered directly (no SUBJ_ID index lookup), so a repeated
    # SUBJ_ID cannot multiply rows.
    ordered, dx_boundaries = _order_by_dx(df_g, dx_values, f"P({target_dx})")
    subject_order_dict[group] = ordered["SUBJ_ID"].tolist()
    _ordered_by_group[group] = (ordered, dx_boundaries)

    _plot_stacked_by_dx(
        ordered, prob_cols, colors_dx, dx_labels, dx_boundaries,
        ylabel="Predicted DX probability",
        title=f"{group} subjects — grouped by DX, ordered by P({target_dx})",
        legend_title="Predicted DX",
    )

# ============================================================
# FIGURE 2 — TOPIC DISTRIBUTION PER SUBJECT
# ============================================================

# Only groups that were actually plotted in Figure 1 have an order; iterating
# over them avoids a KeyError for groups skipped by the min_subjects filter.
for group, (ordered, dx_boundaries) in _ordered_by_group.items():
    _plot_stacked_by_dx(
        ordered, topic_cols, colors_topic, topic_labels, dx_boundaries,
        ylabel="Topic weight",
        title=f"{group} subjects — topic distribution grouped by DX",
        legend_title="Topics",
    )

# ============================================================
# FIGURE 3 — MEAN DX COMPOSITION
# ============================================================

mean_df = (
    df.groupby("stage_group")[prob_cols]
    .mean()
    .reset_index()
)

fig, ax = plt.subplots(figsize=(7, 4))
bottom = np.zeros(mean_df.shape[0])

for col, color, label in zip(prob_cols, colors_dx, dx_labels):
    ax.bar(
        mean_df["stage_group"],
        mean_df[col],
        bottom=bottom,
        color=color,
        label=label
    )
    bottom += mean_df[col].values

ax.set_ylim(0, 1)
ax.set_ylabel("Mean predicted probability")
ax.set_title("Predicted DX enrichment by tau-stage mismatch")

ax.legend(
    title="Predicted DX",
    bbox_to_anchor=(1.02, 1),
    loc="upper left"
)

plt.tight_layout()
plt.show()

# ============================================================
# FIGURE 4 — COPATHOLOGY INDEX
# ============================================================

# Non-AD copathology signal = total probability of every class except the
# target. Compare whole column names: a substring test on "AD" would also
# drop any other class whose name contains "AD" (e.g. "SVAD").
non_ad_cols = [c for c in prob_cols if c != f"P({target_dx})"]
df["copath_index"] = df[non_ad_cols].sum(axis=1)

plt.figure(figsize=(6, 4))
sns.boxplot(
    data=df,
    x="stage_group",
    y="copath_index"
)

plt.ylabel("Non-AD copathology probability")
plt.title("Copathology enrichment in tau-stage mismatch")
plt.tight_layout()
plt.show()


In [None]:
## NACC Inference ## 260120
from data_processor import *
nacc_raw = pd.read_csv('C:/Users/BREIN/Desktop/copathology_visualization_temp/data/nacc/260120_NACC_VA_TAU_PATH_matched.csv')

# Label-sliced column groups: ROI features and the per-pathology flags.
region_cols = nacc_raw.loc[:, 'VA/2':'VA/2035'].columns
pathology_cols = nacc_raw.loc[:, 'NACC_AD':'NACC_svPPA'].columns

# Discard rows labelled 'Unknown', then split controls (CN) from the rest.
nacc_filtered = nacc_raw[nacc_raw['DX'] != 'Unknown']
_is_cn = nacc_filtered['DX'] == 'CN'
nacc_cn = nacc_filtered[_is_cn]
nacc_pat = nacc_filtered[~_is_cn]

print(nacc_cn.shape)
print(nacc_pat.shape)

# Atrophy scores of the patients relative to the CN baseline.
nacc_prep = DataProcessor(region_cols=region_cols, dx_col='DX', subject_col='subject_id')
nacc_prep.fit_baseline(hc_data=nacc_cn)
nacc_Z = nacc_prep.compute_atrophy_scores(data=nacc_pat)
print(type(nacc_Z))

# Apply the already fitted topic model and classifier.
nacc_theta = lda.transform(nacc_Z)
y_pred = classifier.predict(nacc_theta)
y_proba = classifier.predict_proba(nacc_theta)

_topic_names = [f"Topic_{k}" for k in range(lda.n_topics)]
nacc_results = pd.DataFrame(nacc_theta, columns=_topic_names)
print(nacc_results.shape)

# Subject ids are attached by position, in nacc_pat row order.
subj_col = nacc_prep.subject_col
if subj_col in nacc_pat.columns:
    nacc_results.insert(0, "SUBJ_ID", nacc_pat[subj_col].values)

nacc_results['pred_DX'] = y_pred
for i, dx in enumerate(classifier.classes):
    nacc_results[f"P({dx})"] = y_proba[:,i]


***Inference Visualization TESTS***

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from math import ceil
from itertools import cycle

# Attach the pathology flags to the inference results, one lookup per subject id.
_pathology_view = nacc_pat[[nacc_prep.subject_col] + list(pathology_cols)]
viz_df = nacc_results.merge(
    _pathology_view,
    left_on="SUBJ_ID",
    right_on=nacc_prep.subject_col,
    how="left",
)

# -----------------------------
# Topic label map (column name -> display name); insertion order is the
# order of the radar axes.
# -----------------------------
label_map = dict(
    Topic_0='Thal',
    Topic_1='LF',
    Topic_2='P',
    Topic_3='RT',
    Topic_4='RF',
    Topic_5='LT',
)

topic_cols = list(label_map)
topic_labels = list(label_map.values())

# -----------------------------
# Radar utilities
# -----------------------------
def compute_topic_profile(df):
    """Mean of each column in the module-level `topic_cols`, closed into a loop.

    The first mean is repeated at the end so the result can be drawn as a
    closed radar polygon; the output has len(topic_cols) + 1 entries.
    """
    topic_means = df[topic_cols].mean().values
    return np.append(topic_means, topic_means[0])


def radar_angles(n):
    """Return n evenly spaced angles in [0, 2*pi) followed by the first angle again.

    The repeated first angle closes the polygon, so the result has n + 1 entries.
    """
    spokes = np.linspace(0, 2 * np.pi, n, endpoint=False)
    return np.append(spokes, spokes[0])


# -----------------------------
# Main visualization
# -----------------------------
def plot_pathology_radar_panels(
    viz_df,
    pathology_cols,
    min_subjects=1,
    ncols=4
):
    """
    Draw radar (polar) plots of the mean topic expression per pathology.

    Creates:
    1) Grid of individual pathology radar plots
    2) One large combined radar plot

    Parameters
    ----------
    viz_df : pandas.DataFrame
        One row per subject. Must contain the columns listed in the
        module-level ``topic_cols`` and every column in ``pathology_cols``.
    pathology_cols : iterable of str
        Columns in which the value ``1`` marks a pathology-positive subject.
    min_subjects : int, default 1
        Pathologies with fewer positive subjects than this are skipped.
    ncols : int, default 4
        Number of columns in the grid of individual panels.

    Returns
    -------
    None
        Both figures are displayed with ``plt.show()``. When no pathology
        reaches ``min_subjects`` a message is printed and nothing is drawn.
    """

    # -------------------------
    # collect valid pathologies
    # -------------------------
    valid = []
    for p in pathology_cols:
        n = (viz_df[p] == 1).sum()
        if n >= min_subjects:
            valid.append((p, n))

    if len(valid) == 0:
        print("No pathology columns with positive subjects.")
        return

    angles = radar_angles(len(topic_cols))
    # Spoke positions in degrees; the last angle only closes the polygon.
    tick_degrees = angles[:-1] * 180 / np.pi
    colors = plt.cm.tab10.colors

    # Mean topic profile of the positive subjects, computed once per pathology
    # and shared by both figures.
    profiles = [
        compute_topic_profile(viz_df[viz_df[pathology] == 1])
        for pathology, _ in valid
    ]

    # ======================================================
    # 1. INDIVIDUAL RADAR SUBPLOTS
    # ======================================================
    n_panels = len(valid)
    nrows = ceil(n_panels / ncols)

    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(4 * ncols, 4 * nrows),
        subplot_kw=dict(polar=True)
    )

    # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a flat array.
    axes = np.array(axes).reshape(-1)

    for ax, (pathology, n_pos), values, color in zip(
        axes, valid, profiles, cycle(colors)
    ):
        ax.plot(angles, values, linewidth=2.5, color=color)
        ax.fill(angles, values, alpha=0.25, color=color)

        ax.set_thetagrids(tick_degrees, topic_labels, fontsize=10)
        ax.set_title(f"{pathology}\n(n={n_pos})", fontsize=11, pad=12)

    # remove unused axes
    for ax in axes[len(valid):]:
        ax.remove()

    fig.suptitle(
        "Topic Expression by Copathology (Positive Subjects Only)",
        fontsize=16,
        y=1.02
    )

    plt.tight_layout()
    plt.show()

    # ======================================================
    # 2. COMBINED RADAR OVERLAY
    # ======================================================
    plt.figure(figsize=(9, 9))
    ax = plt.subplot(111, polar=True)

    # Keep the line artist of every pathology so the legend can be built from
    # explicit (handle, label) pairs. Passing only a label list makes
    # Matplotlib pair labels with artists by order, and each pathology adds
    # both a line and a filled polygon, so labels can end up on the wrong
    # artist (and therefore the wrong colour).
    legend_handles = []
    legend_labels = []

    # Restart the colour cycle so each pathology keeps its colour from the grid.
    for (pathology, n_pos), values, color in zip(valid, profiles, cycle(colors)):
        line, = ax.plot(angles, values, linewidth=2.5, color=color)
        ax.fill(angles, values, alpha=0.12, color=color)

        legend_handles.append(line)
        legend_labels.append(f"{pathology} (n={n_pos})")

    ax.set_thetagrids(tick_degrees, topic_labels, fontsize=12)

    ax.set_title(
        "Combined Topic Profiles Across Copathologies",
        fontsize=16,
        pad=30
    )

    ax.legend(
        legend_handles,
        legend_labels,
        bbox_to_anchor=(1.35, 1.1),
        loc="upper right",
        frameon=False
    )

    plt.tight_layout()
    plt.show()
# Render the per-pathology radar grid and the combined overlay, four panels per row.
plot_pathology_radar_panels(viz_df=viz_df, pathology_cols=pathology_cols, ncols=4)


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# ---------------------------------------
# settings
# ---------------------------------------
target_dx = "AD"           # sort by P(AD)
min_subjects = 5           # skip tiny groups

# Probability columns are named "P(<dx>)"; strip the wrapper to get the bare
# diagnosis names used in the legend.
prob_cols = [c for c in viz_df.columns if c.startswith("P(")]
dx_labels = [c.replace("P(", "").replace(")", "") for c in prob_cols]

colors = sns.color_palette("tab10", len(dx_labels))

# The sort column is the same for every pathology, so resolve it once and say
# so explicitly if it is missing instead of silently drawing nothing.
sort_col = f"P({target_dx})"
has_sort_col = sort_col in viz_df.columns
if not has_sort_col:
    print(f"Column {sort_col!r} not found in viz_df; no plots drawn.")


# ---------------------------------------
# loop over each pathology
# ---------------------------------------
for path in (pathology_cols if has_sort_col else []):

    # select pathology-positive subjects
    df_pos = viz_df[viz_df[path] == 1].copy()

    if df_pos.shape[0] < min_subjects:
        continue

    # sort subjects by predicted probability of the target diagnosis
    df_pos = df_pos.sort_values(
        by=sort_col,
        ascending=False
    ).reset_index(drop=True)

    # ---------------------------------------
    # stacked probability plot
    # ---------------------------------------
    fig, ax = plt.subplots(figsize=(12, 4))

    x = np.arange(len(df_pos))
    # Running top of the stack: each diagnosis is drawn on top of the previous ones.
    bottom = np.zeros(len(df_pos))

    for i, (dx, col) in enumerate(zip(dx_labels, prob_cols)):
        ax.bar(
            x,
            df_pos[col].values,
            bottom=bottom,
            color=colors[i],
            label=dx,
            width=1.0
        )
        bottom += df_pos[col].values

    ax.set_ylim(0, 1)
    ax.set_xlim(-0.5, len(df_pos) - 0.5)

    ax.set_xticks([])
    ax.set_ylabel("Predicted diagnosis probability")

    # The separator is an em dash; the previous text was a mis-decoded copy of it.
    ax.set_title(
        f"{path}+ subjects (n={len(df_pos)}) — sorted by P({target_dx})",
        fontsize=13
    )

    ax.legend(
        title="Predicted DX",
        bbox_to_anchor=(1.02, 1),
        loc="upper left"
    )

    plt.tight_layout()
    plt.show()
    # Release the figure so a long list of pathologies does not accumulate open figures.
    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

target_dx = "AD"
min_subjects = 5

# Probability columns are named "P(<dx>)"; strip the wrapper to get the bare
# diagnosis names used as row labels.
prob_cols = [c for c in viz_df.columns if c.startswith("P(")]
dx_labels = [c.replace("P(", "").replace(")", "") for c in prob_cols]

for path in pathology_cols:

    # pathology-positive subjects only
    df_pos = viz_df[viz_df[path] == 1].copy()

    if df_pos.shape[0] < min_subjects:
        continue

    sort_col = f"P({target_dx})"
    if sort_col not in df_pos.columns:
        continue

    # order subjects by predicted probability of the target diagnosis
    df_pos = df_pos.sort_values(
        by=sort_col,
        ascending=False
    ).reset_index(drop=True)

    # transpose -> diagnoses x subjects
    heatmap_data = df_pos[prob_cols].T

    plt.figure(figsize=(10, 4))

    sns.heatmap(
        heatmap_data,
        cmap="Reds",
        vmin=0,
        vmax=1,
        cbar_kws={"label": "Predicted probability"},
        xticklabels=False
    )

    # Replace the "P(<dx>)" row labels with the bare diagnosis names, centred
    # on each heatmap row.
    plt.yticks(
        np.arange(len(dx_labels)) + 0.5,
        dx_labels,
        rotation=0
    )

    # Label follows target_dx instead of hard-coding "AD"; the trailing arrow
    # replaces a mis-decoded copy of the same glyph.
    plt.xlabel(f"Subjects (ordered by P({target_dx}) ↓)")
    plt.ylabel("Predicted diagnosis")

    plt.title(
        f"{path}+ subjects (n={len(df_pos)})",
        fontsize=13
    )

    plt.tight_layout()
    plt.show()
