In [None]:
import numpy as np
import pandas as pd
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from atlas_vis import DKTAtlas62ROIPlotter


# ============================================================
# LDA on regional atrophy scores (patients z-scored against HC)
# ============================================================
import os

RESULTS_DIR = "./old_wsev_results"
os.makedirs(RESULTS_DIR, exist_ok=True)  # every to_csv below fails if this folder is missing

# -------------------------------
# 1. Load HC reference and patient data
# -------------------------------
new_wsev = pd.read_csv("C:/Users/BREIN/Desktop/copathology_visualization_temp/data/260108_wsev_final_df.csv")
hc_df = new_wsev[new_wsev['DX'] == 'HC']

df_all = pd.read_csv('C:/Users/BREIN/Desktop/copathology_visualization_temp/data/wsev_old_100_data.csv')
df_all = df_all.dropna()
print(df_all['DX'].value_counts())

# Patient ROI columns: contiguous 'VA/1002' .. 'VA/2035' range in the CSV.
region_cols = df_all.loc[:, 'VA/1002':'VA/2035'].columns

# The HC file names its ROI columns differently (ctx_lh_* .. ctx_rh_*), so HC and patient
# columns are matched purely by position.
# NOTE(review): assumes both ranges list the same ROIs in the same order — confirm.
X_hc = hc_df.loc[:, 'ctx_lh_caudalanteriorcingulate':'ctx_rh_insula'].values.astype(float)
X_pat = df_all[region_cols].values.astype(float)

if X_hc.shape[1] != X_pat.shape[1]:
    raise ValueError(
        f"ROI count mismatch: HC has {X_hc.shape[1]} columns, patients have {X_pat.shape[1]}"
    )
if not np.isfinite(X_hc).all():
    # hc_df is not dropna()'d like df_all; a NaN here would silently poison every z-score.
    raise ValueError("HC ROI data contains NaN/inf values; HC mean/std would be undefined")

print(f"HC: {X_hc.shape[0]} subjects")
print(f"Patients: {X_pat.shape[0]} subjects")

# -------------------------------
# 2. Z-score patients against the HC distribution
# -------------------------------
hc_mean = X_hc.mean(axis=0, keepdims=True)
hc_std  = X_hc.std(axis=0, keepdims=True) + 1e-8  # avoid divide-by-zero

Z = (X_pat - hc_mean) / hc_std

# -------------------------------
# 3. Non-negative atrophy scores for LDA
# -------------------------------
# Regions below the HC mean (negative z) become positive scores; everything else is 0.
# np.maximum(..., 0.0) already guarantees the non-negative input LDA requires, so no
# second clamp is needed. The scores are continuous and are used as pseudo-counts.
X_atrophy = np.maximum(-Z, 0.0)
X = X_atrophy

n_topics = 6  # try 3–6 in sensitivity analyses

lda = LatentDirichletAllocation(
    n_components=n_topics,
    doc_topic_prior=1.0,      # alpha
    topic_word_prior=0.1,     # beta
    learning_method='batch',
    max_iter=500,
    random_state=42
)

theta = lda.fit_transform(X)     # subject × topic
beta = lda.components_           # topic × region

print("LDA fitting complete")

# -------------------------------
# 4. Normalize topic maps (each topic's region weights sum to 1)
# -------------------------------

beta_norm = beta / beta.sum(axis=1, keepdims=True)

topic_df = pd.DataFrame(
    beta_norm.T,
    index=region_cols,
    columns=[f"Topic_{k}" for k in range(n_topics)]
)

# -------------------------------
# 5. Save topic (atrophy pattern) maps
# -------------------------------

topic_df.to_csv(f"{RESULTS_DIR}/lda_topic_atrophy_patterns.csv")
print("Saved topic atrophy patterns")

# -------------------------------
# 6. Subject-level topic mixtures
# -------------------------------

theta_df = pd.DataFrame(
    theta,
    columns=[f"Topic_{k}" for k in range(n_topics)]
)

theta_df["DX"] = df_all["DX"].values
theta_df.to_csv(f"{RESULTS_DIR}/lda_subject_topic_weights.csv", index=False)

print("Saved subject-level topic mixtures")

# -------------------------------
# 7. Diagnosis-wise topic expression (post hoc)
# -------------------------------

group_means = theta_df.groupby("DX").mean()
group_means.to_csv(f"{RESULTS_DIR}/lda_diagnosis_topic_expression.csv")

print("Saved diagnosis-wise topic expression")

# -------------------------------
# 8. Reconstruct diagnosis-specific atrophy maps
#    (mean topic weights of the group × normalized topic maps)
# -------------------------------

dx_maps = {}

for dx in group_means.index:
    weights = group_means.loc[dx].values
    dx_map = np.dot(weights, beta_norm)
    dx_maps[dx] = dx_map

dx_maps_df = pd.DataFrame(dx_maps, index=region_cols)
dx_maps_df.to_csv(f"{RESULTS_DIR}/lda_diagnosis_atrophy_maps.csv")

print("Saved diagnosis-specific atrophy maps")

# -------------------------------
# 9. Quick sanity checks
# -------------------------------

print("\nTop regions per topic:")
for k in range(n_topics):
    top_regions = (
        topic_df[f"Topic_{k}"]
        .sort_values(ascending=False)
        .head(8)
    )
    print(f"\nTopic {k}:")
    print(top_regions)

print("\nMean topic expression per diagnosis:")
print(group_means)

print("\nDone.")


In [None]:
from atlas_vis import DKTAtlas62ROIPlotter
## Surface map of the mean sign-flipped VA z-score per diagnosis ##
import os

ZSCORE_MAP_DIR = './old_wsev_results/surface_maps/zscore'
os.makedirs(ZSCORE_MAP_DIR, exist_ok=True)  # the plotter's save fails if the folder is missing
# The first 31 ROI values are passed to the plotter as the left hemisphere, the rest as right.
N_ROI_PER_HEMI = 31

plotter_62  = DKTAtlas62ROIPlotter(
    cmap='Reds',
    clim=(0, 3.0),  # NOTE(review): negative values (above HC mean) presumably clip to the low end — confirm in atlas_vis
    window_size=(1200, 1000),
    nan_color='lightgray',
    background='white',
    template_key='pial'
)

# Reload inputs so this cell runs on its own (same preprocessing as the LDA cell).
new_wsev = pd.read_csv("C:/Users/BREIN/Desktop/copathology_visualization_temp/data/260108_wsev_final_df.csv")
hc_df = new_wsev[new_wsev['DX'] == 'HC']

df_all = pd.read_csv('C:/Users/BREIN/Desktop/copathology_visualization_temp/data/wsev_old_100_data.csv')
df_all = df_all.dropna()
print(df_all['DX'].value_counts())

region_cols = df_all.loc[:, 'VA/1002':'VA/2035'].columns

# HC columns are matched to patient columns by position (different naming schemes).
X_hc = hc_df.loc[:, 'ctx_lh_caudalanteriorcingulate':'ctx_rh_insula'].values.astype(float)
X_pat = df_all[region_cols].values.astype(float)

if X_hc.shape[1] != len(region_cols) or len(region_cols) != 2 * N_ROI_PER_HEMI:
    raise ValueError(
        f"Expected {2 * N_ROI_PER_HEMI} matching ROIs; got HC={X_hc.shape[1]}, "
        f"patients={len(region_cols)}"
    )

hc_mean = X_hc.mean(axis=0, keepdims=True)
hc_std  = X_hc.std(axis=0, keepdims=True) + 1e-8  # avoid divide-by-zero

for dx in df_all['DX'].unique():
    df_dx = df_all[df_all['DX'] == dx]
    print(f"{dx}: {len(df_dx)} subjects")

    X_raw = df_dx[region_cols].values.astype(float)
    # Sign-flipped z-score: positive values mean the region is below the HC mean.
    X_zscore = -(X_raw - hc_mean) / hc_std
    # Mean across subjects of this diagnosis -> one value per ROI
    X_mean_zscore = X_zscore.mean(axis=0)

    l_values = X_mean_zscore[:N_ROI_PER_HEMI].tolist()
    r_values = X_mean_zscore[N_ROI_PER_HEMI:].tolist()
    all_values = l_values + r_values
    print(f"  min={np.min(all_values):.3f}  max={np.max(all_values):.3f}")

    plotter_62(l_values, r_values, save_path=f'{ZSCORE_MAP_DIR}/{dx}.png')


In [None]:

## Surface maps of the LDA topic atrophy patterns (one map per topic) ##
import os

TOPIC_MAP_DIR = './old_wsev_results/surface_maps/topicwise'
os.makedirs(TOPIC_MAP_DIR, exist_ok=True)  # the plotter's save fails if the folder is missing
# The first 31 region rows are passed to the plotter as the left hemisphere, the rest as right.
N_ROI_PER_HEMI = 31

plotter_62  = DKTAtlas62ROIPlotter(
    cmap='Reds',
    clim=(0, 0.2),  # normalized topic weights (each topic sums to 1 across regions)
    window_size=(1200, 1000),
    nan_color='lightgray',
    background='white',
    template_key='pial'
)

# The LDA cell wrote region names as the CSV index; read them back as the index so that
# every remaining column is a topic and row slicing is purely positional.
df = pd.read_csv(
    'C:/Users/BREIN/Desktop/copathology_visualization_temp/lda_with_xg/old_wsev_results/lda_topic_atrophy_patterns.csv',
    index_col=0,
)

print(len(df))
for col in df.columns:
    print(col)
    l_values = df[col].iloc[:N_ROI_PER_HEMI].to_list()
    r_values = df[col].iloc[N_ROI_PER_HEMI:].to_list()
    print(len(l_values))
    print(len(r_values))
    print(np.min(l_values + r_values))
    print(np.max(l_values + r_values))

    plotter_62(l_values, r_values, save_path=f'{TOPIC_MAP_DIR}/{col}.png')

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Radar (spider) plots of mean LDA topic weight per diagnosis group:
# one combined panel, then one panel per group.
import os

FIG_DIR = './old_wsev_results/figures'
os.makedirs(FIG_DIR, exist_ok=True)  # savefig fails if the folder does not exist

df = pd.read_csv('C:/Users/BREIN/Desktop/copathology_visualization_temp/lda_with_xg/old_wsev_results/lda_diagnosis_topic_expression.csv')
# This CSV is the groupby("DX").mean() table written by the LDA cell, so each row is a
# diagnosis group (first column = DX), not an individual subject.
dx_groups = df.iloc[:, 0].values
categories = df.columns[1:].tolist()
data = df.iloc[:, 1:].values

# NOTE(review): these abbreviations differ from the label_map in the topic-mixture
# bar-plot cell (LT/RT/P/CO/ROF/LPF) — confirm which set matches the topic maps.
label_map = {'Topic_0': 'TO', 'Topic_1': 'MT', 'Topic_2': 'TP', 'Topic_3': 'CP', 'Topic_4': 'OFL', 'Topic_5': 'PF'}
# .get() falls back to the raw column name, so a different n_topics does not raise KeyError.
labels = [label_map.get(cat, cat) for cat in categories]
num_vars = len(categories)
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()
angles += angles[:1]  # repeat the first angle to close each polygon

# Grid size: one combined panel + one panel per diagnosis group
n_groups = len(dx_groups)
n_plots = n_groups + 1
n_cols = 3
n_rows = int(np.ceil(n_plots / n_cols))
y_max = data.max() * 1.1  # shared radial limit so panels are directly comparable

fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), subplot_kw=dict(polar=True))
axes = axes.flatten()

colors = plt.cm.tab10.colors


def _format_radar_axis(ax, title):
    """Apply the shared topic tick labels, radial limit and title to one polar axis."""
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(labels, size=15)
    ax.set_ylim(0, y_max)
    ax.set_title(title, size=20)


# Panel 0: all diagnosis groups overlaid
ax = axes[0]
for idx, (dx, values) in enumerate(zip(dx_groups, data)):
    color = colors[idx % len(colors)]
    values_closed = values.tolist() + [values[0]]
    ax.plot(angles, values_closed, 'o-', linewidth=2, color=color, label=dx, alpha=0.7)
    ax.fill(angles, values_closed, alpha=0.1, color=color)
_format_radar_axis(ax, 'All')
ax.legend(loc='upper right', bbox_to_anchor=(1.1, 1.0), fontsize=6)

# Remaining panels: one per diagnosis group (offset by 1 for the combined panel)
for idx, (dx, values) in enumerate(zip(dx_groups, data)):
    ax = axes[idx + 1]
    color = colors[idx % len(colors)]
    values_closed = values.tolist() + [values[0]]
    ax.plot(angles, values_closed, 'o-', linewidth=2, color=color)
    ax.fill(angles, values_closed, alpha=0.25, color=color)
    _format_radar_axis(ax, dx)

# Hide unused grid slots
for idx in range(n_plots, len(axes)):
    axes[idx].set_visible(False)

plt.tight_layout()
plt.savefig(f'{FIG_DIR}/spider_grid.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================================
# XGBoost Multiclass Classification with 5-Fold CV
# Using LDA Topic Weights (Out-of-Fold Ensemble Predictions)
# ============================================================

import numpy as np
import pandas as pd
from xgboost import XGBClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix

# -------------------------
# 1. Load data
# -------------------------
# NOTE: the LDA topic weights were fit on ALL subjects before this split (unsupervised,
# no labels used), so topics are not re-learned per fold. Treat these CV scores as
# somewhat optimistic compared with a fully nested LDA -> XGBoost pipeline.
theta_df = pd.read_csv(
    "C:/Users/BREIN/Desktop/copathology_visualization_temp/lda_with_xg/old_wsev_results/lda_subject_topic_weights.csv"
)

topic_cols = [c for c in theta_df.columns if c.startswith("Topic_")]
X = theta_df[topic_cols].values
y = theta_df["DX"].values
subject_idx = theta_df.index.values

# Encode DX labels as 0..K-1 (what XGBClassifier expects)
le = LabelEncoder()
y_encoded = le.fit_transform(y)
dx_classes = le.classes_
class_ids = np.arange(len(dx_classes))  # pinned so reports/matrices always have K rows

print("DX classes:", dx_classes)

# -------------------------
# 2. Stratified 5-fold CV
# -------------------------
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

all_y_true = []
all_y_pred = []
all_y_proba = []
all_idx = []

for fold, (train_idx, test_idx) in enumerate(skf.split(X, y_encoded), 1):
    print(f"\n===== Fold {fold} =====")

    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y_encoded[train_idx], y_encoded[test_idx]

    # `use_label_encoder` was deprecated in xgboost 1.6 and removed in 2.0; labels are
    # already integer-encoded above, so it is omitted.
    xgb = XGBClassifier(
        objective="multi:softprob",
        num_class=len(dx_classes),
        eval_metric="mlogloss",
        random_state=42
    )

    xgb.fit(X_train, y_train)

    y_pred = xgb.predict(X_test)
    y_proba = xgb.predict_proba(X_test)

    all_y_true.append(y_test)
    all_y_pred.append(y_pred)
    all_y_proba.append(y_proba)
    all_idx.append(subject_idx[test_idx])

# -------------------------
# 3. Aggregate CV results (every subject appears exactly once, as a test sample)
# -------------------------
y_true_all = np.concatenate(all_y_true)
y_pred_all = np.concatenate(all_y_pred)
y_proba_all = np.vstack(all_y_proba)
idx_all = np.concatenate(all_idx)

# -------------------------
# 4. Metrics (CV-aggregated)
# -------------------------
print("\n===== CV-Aggregated Classification Report =====")
print(
    classification_report(
        y_true_all,
        y_pred_all,
        labels=class_ids,
        target_names=dx_classes
    )
)

# Rows = true DX, columns = predicted DX, in le.classes_ order
cm = confusion_matrix(y_true_all, y_pred_all, labels=class_ids)
accuracy = np.trace(cm) / np.sum(cm)

print("CV Accuracy:", round(accuracy, 4))
print("Confusion Matrix:\n", cm)

# -------------------------
# 5. Build final results DataFrame
# -------------------------
results_df = pd.DataFrame(
    y_proba_all,
    columns=[f"P({dx})" for dx in dx_classes]
)

results_df["DX_true"] = le.inverse_transform(y_true_all)
results_df["DX_pred"] = le.inverse_transform(y_pred_all)
results_df["subject_index"] = idx_all

# Sort by true DX for visualization; a stable sort keeps the within-group order
# deterministic (fold order) instead of implementation-dependent.
results_df = results_df.sort_values("DX_true", kind="mergesort").reset_index(drop=True)

# -------------------------
# 6. Save results
# -------------------------
out_path = (
    "C:/Users/BREIN/Desktop/copathology_visualization_temp/"
    "lda_with_xg/old_wsev_results/xgb_5fold_cv_test_predictions.csv"
)
results_df.to_csv(out_path, index=False)

print("\nSaved CV predictions to:")
print(out_path)


In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# Plot the CV-aggregated confusion matrix (`cm` and `le` come from the XGBoost CV cell).
# Rows are true DX, columns are predicted DX. For row-normalized proportions, recompute
# the matrix with confusion_matrix(..., normalize="true") and pass values_format=".2f";
# ConfusionMatrixDisplay.plot itself has no `normalize` argument.
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=le.classes_)
disp.plot(cmap='Blues', values_format='d')
disp.ax_.set_title("XGBoost 5-fold CV confusion matrix");  # ';' suppresses the Text repr


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# ---------------------------------
# 1. Merge CV results with topic mixtures
# ---------------------------------
# Row i of plot_df and row i of results_df describe the same subject: topic rows are
# pulled from theta_df in results_df's order via the stored subject index.
plot_df = theta_df.loc[results_df["subject_index"]].copy()

plot_df["DX_true"] = results_df["DX_true"].values
plot_df["DX_pred"] = results_df["DX_pred"].values
# Keep the subject index as a column; the reset_index(drop=True) below discards the index.
plot_df["subject_index"] = results_df["subject_index"].values

# ---------------------------------
# 2. Compute P(true DX) per subject
# ---------------------------------
proba_cols = [c for c in results_df.columns if c.startswith("P(")]

# Vectorized gather: for each row, take the probability column named after its true DX
# (raises KeyError if a DX has no P(...) column, like a direct column lookup would).
proba_matrix = results_df[proba_cols].to_numpy()
col_position = {col: pos for pos, col in enumerate(proba_cols)}
true_col_idx = [col_position[f"P({dx})"] for dx in plot_df["DX_true"]]
plot_df["p_true_dx"] = proba_matrix[np.arange(len(plot_df)), true_col_idx]

# ---------------------------------
# 3. Sort subjects:
#    group by true DX, then by P(true DX)
# ---------------------------------
plot_df = plot_df.sort_values(
    ["DX_true", "p_true_dx"],
    ascending=[True, False]
).reset_index(drop=True)

# ---------------------------------
# 4. Plot stacked topic mixtures
# ---------------------------------
# NOTE(review): these abbreviations differ from the spider-plot cell's label_map
# (TO/MT/TP/CP/OFL/PF) — confirm which set matches the topic maps.
label_map = {
    'Topic_0': 'LT',
    'Topic_1': 'RT',
    'Topic_2': 'P',
    'Topic_3': 'CO',
    'Topic_4': 'ROF',
    'Topic_5': 'LPF'
}

topic_cols = [c for c in plot_df.columns if c.startswith("Topic_")]

fig, ax = plt.subplots(figsize=(14, 5))

bottom = np.zeros(len(plot_df))
colors = sns.color_palette("tab20", len(topic_cols))

for i, topic in enumerate(topic_cols):
    ax.bar(
        np.arange(len(plot_df)),
        plot_df[topic],
        bottom=bottom,
        color=colors[i],
        label=label_map.get(topic, topic)
    )
    bottom += plot_df[topic].values

# ---------------------------------
# 5. Axis styling
# ---------------------------------
ax.set_xticks([])
ax.set_ylabel("Topic proportion")
# Topic proportions stack to 1; the extra headroom keeps the misclassification
# markers drawn at y=1.03 inside the axes.
ax.set_ylim(0, 1.08)
ax.set_title(
    "Topic Mixtures per Subject (5-fold CV)\n"
    "(Grouped by DX, Sorted by P(Actual DX))"
)

# ---------------------------------
# 6. DX group labels & separators
# ---------------------------------
current = 0
for dx in plot_df["DX_true"].unique():
    count = (plot_df["DX_true"] == dx).sum()

    ax.text(
        current + count / 2 - 0.5,
        -0.05,
        dx,
        ha="center",
        va="top",
        fontsize=12
    )

    ax.axvline(current - 0.5, color="black", linewidth=1.5)
    current += count

ax.axvline(current - 0.5, color="black", linewidth=1.5)

# ---------------------------------
# 7. Mark misclassified subjects
# ---------------------------------
misclassified = plot_df["DX_true"] != plot_df["DX_pred"]

ax.scatter(
    np.where(misclassified)[0],
    np.ones(misclassified.sum()) * 1.03,
    color="red",
    marker="x",
    s=60,
    linewidths=2,
    label="Misclassified",
    zorder=10
)

# ---------------------------------
# 8. Legend & layout
# ---------------------------------
ax.legend(
    title="LDA Topics",
    bbox_to_anchor=(1.05, 1),
    loc="upper left"
)

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# ---------------------------------
# 1. Prepare probability columns
# ---------------------------------
proba_cols = [c for c in results_df.columns if c.startswith("P(")]
dx_labels = [c.replace("P(", "").replace(")", "") for c in proba_cols]

# Take probabilities AND labels from the same results_df rows so every column stays
# aligned per subject. plot_df is additionally sorted by P(true DX) within each DX group,
# so copying DX_pred from it by position would pair predictions with the wrong subjects.
df = results_df[proba_cols + ["DX_true", "DX_pred"]].copy()

# ---------------------------------
# 2. Sort subjects:
#    group by true DX,
#    within group sort by P(true DX)
# ---------------------------------
sorted_blocks = []

for dx in df["DX_true"].unique():
    dx_block = df[df["DX_true"] == dx].copy()

    true_dx_col = f"P({dx})"
    if true_dx_col not in dx_block.columns:
        raise ValueError(f"Missing probability column: {true_dx_col}")

    dx_block = dx_block.sort_values(
        by=true_dx_col,
        ascending=False
    )

    sorted_blocks.append(dx_block)

proba_plot_df = (
    pd.concat(sorted_blocks)
    .reset_index(drop=True)
)

# ---------------------------------
# 3. Stacked bar plot (probabilities per subject sum to 1)
# ---------------------------------
fig, ax = plt.subplots(figsize=(14, 5))

bottom = np.zeros(len(proba_plot_df))
colors = sns.color_palette("tab10", len(dx_labels))

for i, (dx, col) in enumerate(zip(dx_labels, proba_cols)):
    ax.bar(
        np.arange(len(proba_plot_df)),
        proba_plot_df[col],
        bottom=bottom,
        color=colors[i],
        label=dx
    )
    bottom += proba_plot_df[col].values

ax.set_xticks([])
ax.set_ylabel("Predicted DX probability")
ax.set_ylim(0, 1)
ax.set_title(
    "Predicted Diagnosis Probabilities per Subject (5-fold CV)\n"
    "(Grouped by True DX, Sorted by P(True DX))"
)

# ---------------------------------
# 4. Add DX group labels & separators
# ---------------------------------
current = 0

for dx in proba_plot_df["DX_true"].unique():
    count = (proba_plot_df["DX_true"] == dx).sum()

    ax.text(
        current + count / 2 - 0.5,
        -0.05,
        dx,
        ha="center",
        va="top",
        fontsize=12
    )

    ax.axvline(current - 0.5, color="black", linewidth=1.5)
    current += count

ax.axvline(current - 0.5, color="black", linewidth=1.5)

ax.legend(
    title="Predicted DX",
    bbox_to_anchor=(1.05, 1),
    loc="upper left"
)

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# ---------------------------------
# Heatmap of per-subject predicted DX probabilities.
# Rows: subjects grouped by true DX (first-appearance order), each group sorted by
# P(true DX) descending. Columns: predicted diagnosis.
# ---------------------------------
proba_cols = [col for col in proba_plot_df.columns if col.startswith("P(")]
dx_labels = [col.replace("P(", "").replace(")", "") for col in proba_cols]

df = proba_plot_df.copy()


def _sort_blocks_by_true_dx(frame):
    """Split `frame` by DX_true and sort each piece by its own P(<true DX>) column, descending."""
    pieces = []
    for true_dx in frame["DX_true"].unique():
        prob_col = f"P({true_dx})"
        if prob_col not in frame.columns:
            raise ValueError(f"Missing probability column: {prob_col}")
        piece = frame[frame["DX_true"] == true_dx].copy()
        pieces.append(piece.sort_values(by=prob_col, ascending=False))
    return pieces


sorted_blocks = _sort_blocks_by_true_dx(df)
df_sorted = pd.concat(sorted_blocks).reset_index(drop=True)

heatmap_data = df_sorted[proba_cols].to_numpy()

fig, ax = plt.subplots(figsize=(8, 10))
sns.heatmap(
    heatmap_data,
    ax=ax,
    cmap="Reds",
    vmin=0,
    vmax=1,
    cbar_kws={"label": "Predicted DX probability"},
    yticklabels=False,
    xticklabels=dx_labels
)

ax.set_xlabel("Predicted Diagnosis")
ax.set_ylabel("Subjects (True DX)")
ax.set_title(
    "Predicted Diagnosis Probabilities per Subject\n"
    "(Grouped by True DX, Sorted by P(True DX))"
)

# Group boundaries: cumulative group sizes in row order; one black line at every
# boundary (including 0 and the final edge) and one y tick at each group's midpoint.
group_sizes = df_sorted.groupby("DX_true", sort=False).size()
group_starts = np.concatenate([[0], np.cumsum(group_sizes.to_numpy())[:-1]])
boundaries = np.append(group_starts, group_sizes.sum())

yticks = (group_starts + group_sizes.to_numpy() / 2).tolist()
ylabels = list(group_sizes.index)

ax.hlines(
    boundaries,
    xmin=0,
    xmax=len(dx_labels),
    colors="black",
    linewidth=1.5
)

ax.set_yticks(yticks)
ax.set_yticklabels(ylabels, rotation=0)

plt.tight_layout()
plt.show()
