In [1]:
# Cell 1 - imports & settings
import os
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
import warnings
# Install a global "ignore" filter: no category or module is given, so every
# warning raised after this line is silenced.
# NOTE(review): this also hides deprecation / numerical warnings from the
# libraries used below - consider narrowing the filter once the noisy source
# is known.
warnings.filterwarnings("ignore")

# Select seaborn's "white" style for the figures produced by this notebook.
sns.set(style="white")


In [2]:
# Cell 2 - utilities
def detect_dataset_name(consensus_folder):
    """
    Work out the dataset prefix used by the files in a consensus output folder.

    The folder's entries are visited in sorted order; the first regular file
    whose name contains 'gene_spectra_score' and ends in '.txt' decides the
    result. Everything in that file name before '.gene_spectra_score' is
    returned, e.g.
      anal_pc5_c21_S11.filtered.gene_spectra_score.k_7.dt_0_01.txt
      -> 'anal_pc5_c21_S11.filtered'

    Parameters
    ----------
    consensus_folder : str or Path
        Directory to scan (not recursive).

    Raises
    ------
    FileNotFoundError
        If no matching file is present in the folder.
    """
    marker = "gene_spectra_score"
    root = Path(consensus_folder)
    for entry in sorted(root.iterdir()):
        if not entry.is_file():
            continue
        if entry.suffix != ".txt" or marker not in entry.name:
            continue
        # keep only the part in front of ".gene_spectra_score"
        prefix, _, _ = entry.name.partition("." + marker)
        return prefix
    raise FileNotFoundError(f"No gene_spectra_score files found in {root}")

def expected_gene_spectra_path(folder, dataset_prefix, k, density_threshold=0.01):
    """
    Build the path at which the gene_spectra_score file for one K is expected.

    The file name has the form
      '{dataset_prefix}.gene_spectra_score.k_{k}.dt_{token}.txt'
    where `token` is the plain string form of the density threshold with the
    decimal point replaced by an underscore (0.01 -> '0_01', 0.1 -> '0_1',
    2.0 -> '2_0'). The threshold is deliberately NOT rounded or zero-padded:
    fixed two-decimal formatting would turn 0.1 into '0_10' and 0.005 into
    '0_01', which would not match files named from the raw value.
    NOTE(review): this mirrors the naming used by cNMF's consensus step -
    confirm against the cNMF version that produced the files.

    Parameters
    ----------
    folder : str or Path
        Directory that holds the consensus outputs.
    dataset_prefix : str
        Leading part of the file name (see `detect_dataset_name`).
    k : int
        Number of components the file was produced for.
    density_threshold : float, default 0.01
        Density threshold used for the consensus run; pass the same value
        (and type) that was given to the consensus step.

    Returns
    -------
    pathlib.Path
        The expected path; the file is not checked for existence here.
    """
    dt_str = str(density_threshold).replace(".", "_")
    fn = f"{dataset_prefix}.gene_spectra_score.k_{k}.dt_{dt_str}.txt"
    return Path(folder) / fn


In [3]:
# Cell 3 - core processing functions
def load_gene_spectra(fp, k):
    """
    Read one tab-separated gene spectra file into a DataFrame.

    The first column of the file is used as the row index on read and is then
    replaced: rows are relabelled 'k{k}_GEP1', 'k{k}_GEP2', ... in file order,
    so rows loaded for different K values get distinct labels.

    Parameters
    ----------
    fp : str or Path
        Path of the file to read.
    k : int
        K value the file belongs to; only used to build the row labels.

    Returns
    -------
    pandas.DataFrame
        One row per program in the file, columns as given by the header line.
    """
    spectra = pd.read_csv(fp, sep="\t", index_col=0)
    n_programs = spectra.shape[0]
    spectra.index = [f"k{k}_GEP{rank}" for rank in range(1, n_programs + 1)]
    return spectra

def stack_matrices(gene_spectra_dict):
    """
    Concatenate the per-K spectra tables row-wise into a single DataFrame.

    Parameters
    ----------
    gene_spectra_dict : dict
        Maps each K value to its DataFrame (as built by `load_gene_spectra`).

    Returns
    -------
    pandas.DataFrame
        All rows of all tables, with the tables taken in ascending key order.
    """
    frames = []
    for key in sorted(gene_spectra_dict):
        frames.append(gene_spectra_dict[key])
    return pd.concat(frames, axis=0)

def calculate_correlation_matrix(stacked_matrix):
    """
    Pearson correlation between every pair of rows of `stacked_matrix`.

    `DataFrame.corr` correlates columns, so the table is transposed first;
    the result is a square (GEP x GEP) DataFrame labelled by the input's row
    index on both axes.
    """
    columns_are_geps = stacked_matrix.transpose()
    return columns_are_geps.corr(method="pearson")

def perform_hierarchical_clustering(corr_matrix, linkage_method="average"):
    """
    Hierarchically cluster the items of a correlation matrix.

    The correlation r is converted to the distance 1 - r, limited to [0, 2]
    (the range 1 - r can take for r in [-1, 1]) with an exact zero diagonal,
    flattened to condensed form and passed to scipy's `linkage`.

    Parameters
    ----------
    corr_matrix : pandas.DataFrame
        Square correlation matrix; it is not modified.
    linkage_method : str, default "average"
        Forwarded to `scipy.cluster.hierarchy.linkage` as `method`.

    Returns
    -------
    numpy.ndarray
        The linkage matrix returned by scipy.
    """
    distances = np.clip(1 - corr_matrix.values, 0, 2)
    np.fill_diagonal(distances, 0)
    # checks=False: skip squareform's symmetry / zero-diagonal validation
    return linkage(squareform(distances, checks=False), method=linkage_method)


In [4]:
# Cell 4 - plotting functions
def plot_dendrogram(linkage_matrix, labels, output_prefix):
    """
    Draw the dendrogram for `linkage_matrix` and save it as a PNG.

    Leaf tick labels are coloured by the K value encoded in each label, and a
    legend maps colours to K values.

    Parameters
    ----------
    linkage_matrix : numpy.ndarray
        Linkage matrix as accepted by `scipy.cluster.hierarchy.dendrogram`.
    labels : list of str
        One label per leaf, of the form 'k{K}_GEP{i}'; K is read as the
        integer between the leading character and the first underscore.
    output_prefix : str
        The figure is written to '{output_prefix}_dendrogram.png'.

    Returns
    -------
    None
        The figure is saved to disk and closed; the output path is printed.

    Raises
    ------
    ValueError
        If the K part of a label cannot be converted to int.
    """
    # K for every leaf label, e.g. 'k7_GEP3' -> 7
    k_for_labels = [int(lbl.split('_')[0][1:]) for lbl in labels]
    unique_k = sorted(set(k_for_labels))
    colors = plt.cm.tab10(np.linspace(0, 1, max(1, len(unique_k))))
    k_to_color = {k: colors[i % len(colors)] for i, k in enumerate(unique_k)}
    # tick text -> colour, so tick labels are coloured by direct lookup
    # instead of being re-parsed
    label_to_color = {str(lbl): k_to_color[k] for lbl, k in zip(labels, k_for_labels)}

    fig, ax = plt.subplots(figsize=(16, 7))
    try:
        dendrogram(linkage_matrix, labels=labels, leaf_rotation=90, leaf_font_size=7, ax=ax)

        # colour tick labels; a tick whose text is not a known label keeps
        # its default colour
        for tick in ax.get_xmajorticklabels():
            tick_color = label_to_color.get(tick.get_text())
            if tick_color is not None:
                tick.set_color(tick_color)

        ax.set_title('Hierarchical Clustering of GEPs Across K Values', fontsize=14, fontweight='bold')
        ax.set_xlabel('GEP ID')
        ax.set_ylabel('Distance (1 - Pearson Correlation)')

        # legend: one coloured line per K value
        legend_elems = [plt.Line2D([0], [0], color=k_to_color[k], lw=4, label=f'k={k}') for k in unique_k]
        ax.legend(handles=legend_elems, loc='upper right', title='K Value')

        out = f"{output_prefix}_dendrogram.png"
        fig.tight_layout()
        fig.savefig(out, dpi=300, bbox_inches='tight')
    finally:
        # always release the figure, even if drawing or saving failed
        plt.close(fig)
    print(f"Saved dendrogram: {out}")



def plot_correlation_clustermap(corr_matrix, output_prefix, figsize=(12,10), linkage_matrix=None):
    """
    Draw a clustered heatmap of the GEP x GEP correlation matrix and save it.

    Rows and columns carry a colour strip encoding the K value parsed from
    each label, and a legend maps colours to K values.

    Parameters
    ----------
    corr_matrix : pandas.DataFrame
        Square correlation matrix whose index labels look like 'k{K}_GEP{i}'.
    output_prefix : str
        The figure is written to '{output_prefix}_correlation_heatmap.png'.
    figsize : tuple, default (12, 10)
        Figure size passed to `sns.clustermap`.
    linkage_matrix : numpy.ndarray, optional
        Precomputed linkage for the rows of `corr_matrix`, in the same row
        order (e.g. the result of `perform_hierarchical_clustering`). When
        given, it is used for both rows and columns so the heatmap ordering
        matches that clustering. When None (default), seaborn computes its own
        clustering from the matrix values with its default method and metric,
        which need not agree with a 1 - correlation dendrogram.

    Returns
    -------
    None
        The figure is saved to disk and closed; the output path is printed.

    Raises
    ------
    ValueError
        If the K part of a label cannot be converted to int.
    """
    from matplotlib.patches import Patch

    labels = corr_matrix.index.tolist()
    # K for every label, e.g. 'k7_GEP3' -> 7
    k_values = [int(l.split('_')[0][1:]) for l in labels]
    unique_k = sorted(set(k_values))
    colors = plt.cm.tab10(np.linspace(0, 1, max(1, len(unique_k))))
    k_to_color = {k: colors[i % len(colors)] for i, k in enumerate(unique_k)}
    row_colors = [k_to_color[k] for k in k_values]
    col_colors = row_colors.copy()

    g = sns.clustermap(
        corr_matrix,
        cmap='RdBu_r',
        center=0,
        vmin=-1, vmax=1,
        figsize=figsize,
        dendrogram_ratio=0.15,
        cbar_pos=(0.02, 0.83, 0.03, 0.15),
        linewidths=0,
        xticklabels=True,
        yticklabels=True,
        cbar_kws={'label': 'Pearson Correlation'},
        row_colors=row_colors,
        col_colors=col_colors,
        # None keeps seaborn's own clustering (the previous behaviour)
        row_linkage=linkage_matrix,
        col_linkage=linkage_matrix
    )
    # operate on the clustermap's own figure rather than the implicit current one
    fig = g.ax_heatmap.figure
    try:
        g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xmajorticklabels(), rotation=90, fontsize=6)
        g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_ymajorticklabels(), rotation=0, fontsize=6)
        fig.suptitle('Pearson Correlation Between GEPs Across K Values', fontsize=14, fontweight='bold', y=0.98)
        # legend patches, one per K value
        legend_elements = [Patch(facecolor=k_to_color[k], label=f'k={k}') for k in unique_k]
        g.ax_heatmap.legend(handles=legend_elements, title='K Value', bbox_to_anchor=(1.35, 1.0), loc='upper left', frameon=True)
        out = f"{output_prefix}_correlation_heatmap.png"
        fig.savefig(out, dpi=300, bbox_inches='tight')
    finally:
        # always release the figure, even if decorating or saving failed
        plt.close(fig)
    print(f"Saved correlation heatmap: {out}")



def plot_pairwise_k_comparison(corr_matrix, k1, k2, output_prefix):
    """
    Save a heatmap of the correlations between the GEPs of two K values.

    Rows are the labels starting with 'k{k1}_', columns those starting with
    'k{k2}_'. If either group is empty a message is printed and nothing is
    drawn. Cell values are annotated only when both groups have at most 12
    members. The figure is written to
    '{output_prefix}_pairwise_k{k1}_vs_k{k2}.png' and the path is printed.
    """
    row_prefix, col_prefix = f'k{k1}_', f'k{k2}_'
    row_ids = [name for name in corr_matrix.index if name.startswith(row_prefix)]
    col_ids = [name for name in corr_matrix.index if name.startswith(col_prefix)]
    if not (row_ids and col_ids):
        print(f"No GEPs for k={k1} or k={k2}, skipping pairwise.")
        return

    block = corr_matrix.loc[row_ids, col_ids]
    n_rows, n_cols = len(row_ids), len(col_ids)
    show_values = n_rows <= 12 and n_cols <= 12

    # half an inch per GEP, but never smaller than 8 x 6
    width = max(8, n_cols*0.5)
    height = max(6, n_rows*0.5)
    fig, ax = plt.subplots(figsize=(width, height))
    sns.heatmap(
        block,
        cmap='RdBu_r', center=0, vmin=-1, vmax=1,
        annot=show_values, fmt='.2f',
        cbar_kws={'label': 'Pearson Correlation'},
        linewidths=0.5, linecolor='gray',
        ax=ax,
    )
    ax.set_title(f'Pairwise GEP Correlation: k={k1} vs k={k2}')
    ax.set_xlabel(f'GEPs from k={k2}')
    ax.set_ylabel(f'GEPs from k={k1}')

    out = f"{output_prefix}_pairwise_k{k1}_vs_k{k2}.png"
    fig.tight_layout()
    fig.savefig(out, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved pairwise heatmap: {out}")


In [5]:
# Cell 5 - RUN: set paths and parameters here
patient_folder = Path("all_datasets_consensus/anal_pc5_c21_S1.filtered_consensus_outputs")
k_values = [7, 8, 9]
density_threshold = 0.01
output_prefix = "anal_pc5_c21_S1_kcomparison"   # change to taste
# e.g. 'average', 'complete', 'single'. Avoid 'ward' / 'centroid' / 'median':
# scipy defines them for Euclidean distances only, and the clustering below
# uses 1 - Pearson correlation.
linkage_method = "average"

# ---- detect dataset prefix
dataset_prefix = detect_dataset_name(patient_folder)
print("Detected dataset prefix:", dataset_prefix)

# ---- locate files for requested K values (missing ones are reported and skipped)
found = {}
for k in sorted(k_values):
    p = expected_gene_spectra_path(patient_folder, dataset_prefix, k, density_threshold)
    if p.exists():
        found[k] = p
        print(f"Found k={k}: {p.name}")
    else:
        print(f"Missing k={k}: expected {p.name}")

if not found:
    # a regular exception (rather than SystemExit) so the failure is reported
    # as an ordinary error with a traceback
    raise FileNotFoundError("No gene_spectra files found for requested k values. Aborting.")

# ---- load matrices (rows are relabelled 'k{K}_GEP{i}')
gene_spectra_dict = {}
for k, path in sorted(found.items()):
    gene_spectra_dict[k] = load_gene_spectra(path, k)

# ---- stack & save stacked matrix
stacked = stack_matrices(gene_spectra_dict)
stacked_out = f"{output_prefix}_stacked_gep_matrix.csv"
stacked.to_csv(stacked_out)
print(f"Saved stacked matrix: {stacked_out}")

# ---- correlation & save
corr = calculate_correlation_matrix(stacked)
corr_out = f"{output_prefix}_correlation_matrix.csv"
corr.to_csv(corr_out)
print(f"Saved correlation matrix: {corr_out}")

# ---- hierarchical clustering
link = perform_hierarchical_clustering(corr, linkage_method=linkage_method)

# ---- visualizations
plot_dendrogram(link, list(stacked.index), output_prefix)
plot_correlation_clustermap(corr, output_prefix, figsize=(12,10))

# ---- pairwise comparisons (only if <= 4 k values, to keep the number of figures small)
ks = sorted(gene_spectra_dict.keys())
if len(ks) <= 4:
    # every unordered pair (k_a < k_b) exactly once
    for i, k_a in enumerate(ks):
        for k_b in ks[i + 1:]:
            plot_pairwise_k_comparison(corr, k_a, k_b, output_prefix)
else:
    print("Skipping pairwise comparisons (more than 4 k-values).")


print("\nALL DONE.")


Detected dataset prefix: anal_pc5_c21_S1.filtered
Found k=7: anal_pc5_c21_S1.filtered.gene_spectra_score.k_7.dt_0_01.txt
Found k=8: anal_pc5_c21_S1.filtered.gene_spectra_score.k_8.dt_0_01.txt
Found k=9: anal_pc5_c21_S1.filtered.gene_spectra_score.k_9.dt_0_01.txt
Saved stacked matrix: anal_pc5_c21_S1_kcomparison_stacked_gep_matrix.csv
Saved correlation matrix: anal_pc5_c21_S1_kcomparison_correlation_matrix.csv
Saved dendrogram: anal_pc5_c21_S1_kcomparison_dendrogram.png
Saved correlation heatmap: anal_pc5_c21_S1_kcomparison_correlation_heatmap.png
Saved pairwise heatmap: anal_pc5_c21_S1_kcomparison_pairwise_k7_vs_k8.png
Saved pairwise heatmap: anal_pc5_c21_S1_kcomparison_pairwise_k7_vs_k9.png
Saved pairwise heatmap: anal_pc5_c21_S1_kcomparison_pairwise_k8_vs_k9.png

ALL DONE.
