In [3]:
import os
import pandas as pd
import numpy as np
from data_processor import *
from lda_model import LDATopicModel
from classifier import TopicClassifier
from visualizer import *
from brain_visualizer import *
from pipeline import *

# Root folder holding the data/, results/ and saved_mdls/ sub-folders.
# The default is the original local checkout; set the STAGE_COPATH_ROOT
# environment variable to run this notebook on another machine without editing it.
project_root = os.environ.get(
    'STAGE_COPATH_ROOT',
    'C:/Users/WooSikKim/Desktop/Research/projects/co_pathology/scripts/stage_copath',
)
data_dir = f'{project_root}/data'
fig_save_dir = f'{project_root}/results'
mdl_save_dir = f'{project_root}/saved_mdls'

train_df = pd.read_csv(os.path.join(data_dir, '260128_wsev_smc_combined_cn_included.csv'))
# Drop rows labelled 'HC' in the DX column.
# NOTE(review): the recorded output of this cell lists an 'NC' label and no 'HC',
# so this filter may be a no-op on the current file -- confirm which control
# label is meant to be excluded.
train_df = train_df[train_df['DX'] != 'HC']

# Per-diagnosis sample counts after filtering.
print(train_df['DX'].value_counts())


DX
NC        166
AD         72
svPPA      59
bvFTD      53
nfvPPA     46
DLB        25
PD         24
SVAD       24
Name: count, dtype: int64


In [None]:
from scipy.stats import pearsonr  # or spearmanr if you prefer
def _plot_group_mean_heatmap(plot_df, prob_cols, group_col, group_order, cohort):
    """Draw an annotated heatmap of the per-group mean of every probability column."""
    group_means = (
        plot_df
        .groupby(group_col, observed=False)[prob_cols]
        .mean()
        .reindex(group_order)
    )

    plt.figure(figsize=(8, 5))
    sns.heatmap(
        group_means,
        cmap="Reds",
        annot=True,
        fmt=".2f",
        linewidths=0.5,
        vmin=0,
        vmax=1,
        cbar_kws={"label": "Mean predicted probability"},
    )
    plt.xlabel("Predicted pathology")
    plt.ylabel("Subgroup")
    plt.title(f"{cohort} Group-wise Mean Predicted Probability Distribution")
    plt.tight_layout()
    plt.show()


def _plot_subject_heatmap(plot_df, prob_cols, group_col, group_order, cohort, sort_col):
    """Draw one heatmap row per subject, grouped by subgroup and sorted by ``sort_col``."""
    # Group first (in group_order, via the ordered categorical), then descending sort_col.
    # Rows whose group is not in group_order have a NaN category and sort to the bottom.
    df_sorted = (
        plot_df
        .sort_values([group_col, sort_col], ascending=[True, False])
        .reset_index(drop=True)
    )
    heatmap_data = df_sorted[prob_cols]

    # Rows per group, in display order; groups with no rows count as 0.
    group_counts = (
        df_sorted[group_col]
        .value_counts()
        .reindex(group_order, fill_value=0)
        .to_numpy()
    )
    group_ends = np.cumsum(group_counts)
    # Tick position = middle of each group's block of rows.
    group_centers = (group_ends - group_counts) + group_counts / 2

    plt.figure(figsize=(10, 10))
    ax = sns.heatmap(
        heatmap_data,
        cmap="Reds",
        vmin=0,
        vmax=1,
        yticklabels=False,
        cbar_kws={"label": "Predicted probability"},
    )

    # Horizontal separators between consecutive groups (none after the last one).
    for y in group_ends[:-1]:
        ax.hlines(y, *ax.get_xlim(), colors="black", linewidth=1.5)

    # One y tick label per subgroup, centred on its rows.
    ax.set_yticks(list(group_centers))
    ax.set_yticklabels(list(group_order), rotation=0, fontsize=11)

    ax.set_xlabel("Predicted pathology")
    ax.set_ylabel("Subgroup")
    ax.set_title(
        f"{cohort} Subject-level Predicted Probability Heatmap\n(sorted by descending {sort_col})"
    )
    plt.tight_layout()
    plt.show()


def _plot_topic_radar(plot_df, group_col, group_order, cohort):
    """Draw one radar panel per subgroup showing its mean ``Topic_*`` column values."""
    import warnings

    topic_cols = [c for c in plot_df.columns if str(c).startswith("Topic_")]
    if not topic_cols:
        warnings.warn("No 'Topic_' columns found; skipping the radar plot.")
        return

    # Mean of every topic column per group, one row per entry of group_order
    # (all-NaN row for a group that has no subjects).
    topic_means = (
        plot_df
        .groupby(group_col, observed=False)[topic_cols]
        .mean()
        .reindex(group_order)
    )

    # Shared radial limit; ignore NaN rows so an empty group cannot turn it into NaN.
    mean_values = topic_means.to_numpy(dtype=float)
    finite_values = mean_values[np.isfinite(mean_values)]
    if finite_values.size == 0:
        warnings.warn("No subjects fall in any of the requested groups; skipping the radar plot.")
        return
    global_max = finite_values.max()

    # One spoke per topic; repeat the first angle to close the polygon.
    angles = np.linspace(0, 2 * np.pi, len(topic_cols), endpoint=False)
    angles = np.concatenate([angles, [angles[0]]])

    # Panel count follows group_order (the sequence actually iterated below).
    # squeeze=False always returns a 2-D array, so one group needs no special case.
    n_groups = len(group_order)
    fig, axes = plt.subplots(
        1, n_groups,
        figsize=(4 * n_groups, 4),
        subplot_kw=dict(polar=True),
        squeeze=False,
    )

    for ax, (grp, grp_means) in zip(axes.ravel(), topic_means.iterrows()):
        mean_topics = grp_means.to_numpy(dtype=float)
        mean_topics = np.concatenate([mean_topics, [mean_topics[0]]])

        ax.plot(angles, mean_topics, linewidth=2)
        ax.fill(angles, mean_topics, alpha=0.25)

        ax.set_title(grp, pad=20)

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(topic_cols, fontsize=9)

        ax.set_ylim(0, global_max * 1.1)   # shared scale across all panels
        ax.set_yticklabels([])

    plt.suptitle(f"{cohort} Resilience Subgroup Topic Weight Profiles (shared radial scale)", fontsize=14)
    plt.tight_layout()
    plt.show()


def _plot_probability_scatter(plot_df, prob_cols, group_col, group_order, cohort, scatter_col):
    """Scatter each probability column against ``scatter_col`` with a Pearson r/p annotation."""
    n_cols = 3  # subplots per row
    n_rows = int(np.ceil(len(prob_cols) / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.ravel()

    palette = sns.color_palette("tab10", n_colors=len(group_order))
    group_palette = dict(zip(group_order, palette))

    handles, labels = [], []
    for ax, col in zip(axes, prob_cols):
        x = plot_df[col]
        y = plot_df[scatter_col]

        # Pearson correlation on the rows where both values are present.
        valid = x.notna() & y.notna()
        if valid.sum() >= 2:
            r, p = pearsonr(x[valid].to_numpy(dtype=float), y[valid].to_numpy(dtype=float))
        else:
            r, p = np.nan, np.nan

        sns.scatterplot(
            x=x, y=y, hue=plot_df[group_col], palette=group_palette, ax=ax, s=60, alpha=0.8
        )

        # Least-squares fit line over all subjects (no confidence band).
        sns.regplot(x=x, y=y, ax=ax, scatter=False, color='red', ci=None)

        ax.text(0.05, 0.95, f"r={r:.2f}\np={p:.3f}",
                transform=ax.transAxes,
                verticalalignment='top',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7),
                fontsize=13)

        ax.set_xlabel(col)
        ax.set_ylabel(scatter_col)
        ax.set_title(f"{col} vs {scatter_col}")
        # Fixed limits so every panel shares the same axes.
        ax.set_xlim([0, 1])
        ax.set_ylim([-3, 4])

        # Keep the legend entries of the first populated panel for the shared
        # figure legend, then drop the per-panel copy so it is not duplicated.
        if not handles:
            handles, labels = ax.get_legend_handles_labels()
        panel_legend = ax.get_legend()
        if panel_legend is not None:
            panel_legend.remove()

    # Remove the unused panels of the last row.
    for spare_ax in axes[len(prob_cols):]:
        spare_ax.remove()

    fig.legend(handles, labels, loc='upper right', title=group_col, bbox_to_anchor=(1.05, 1))
    plt.suptitle(cohort)
    plt.tight_layout()
    plt.show()


def resilient_subgroup_visualization(inp_df, prob_cols, group_col, group_order, cohort='NACC',
                                     scatter_col='standardized_residual', sort_col='P(AD)'):
    """Show four figures summarising predicted probabilities per subgroup.

    The figures, in order, are:
      1. a heatmap of the mean of each ``prob_cols`` column per group;
      2. a subject-level heatmap of ``prob_cols``, rows grouped by ``group_col``
         and sorted by descending ``sort_col`` inside each group;
      3. one radar panel per group with the mean of every column whose name
         starts with ``"Topic_"`` (skipped with a warning if there are none);
      4. a scatter of each ``prob_cols`` column against ``scatter_col`` with a
         fitted line and the Pearson r / p-value.

    Parameters
    ----------
    inp_df : pandas.DataFrame
        One row per subject. It is not modified; all work happens on a copy.
    prob_cols : sequence of str
        Columns holding the predicted probabilities to display.
    group_col : str
        Column holding each subject's subgroup label.
    group_order : sequence
        Subgroup labels in display order. Labels in ``group_col`` that are not
        listed here are treated as missing (NaN category).
    cohort : str, default 'NACC'
        Cohort name used in the figure titles.
    scatter_col : str, default 'standardized_residual'
        Column plotted on the y axis of the correlation scatter.
    sort_col : str, default 'P(AD)'
        Column used for the descending within-group sort of the subject heatmap.

    Returns
    -------
    None
        Every figure is displayed with ``plt.show()``.

    Raises
    ------
    ValueError
        If ``prob_cols`` or ``group_order`` is empty.
    """
    prob_cols = list(prob_cols)
    group_order = list(group_order)
    if not prob_cols or not group_order:
        raise ValueError("prob_cols and group_order must each contain at least one entry")

    # Work on a copy: converting the group column to an ordered categorical on the
    # caller's frame would turn every label outside group_order into NaN for them too.
    plot_df = inp_df.copy()
    plot_df[group_col] = pd.Categorical(
        plot_df[group_col],
        categories=group_order,
        ordered=True,
    )

    _plot_group_mean_heatmap(plot_df, prob_cols, group_col, group_order, cohort)
    _plot_subject_heatmap(plot_df, prob_cols, group_col, group_order, cohort, sort_col)
    _plot_topic_radar(plot_df, group_col, group_order, cohort)
    _plot_probability_scatter(plot_df, prob_cols, group_col, group_order, cohort, scatter_col)