In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os

# Plot a clustered heatmap of per-sample R2 values for one gene set across
# cell types. R2 values that are missing, or whose p-value exceeds
# `p_value_threshold`, are replaced by 1 before plotting.

# Gene set to plot; must match a column name in the R2 / p-value tables
specific_gene_set = "Cellular Response To Interleukin-1 (GO:0071347)"

# Define paths and cell types
cell_types = ["Podo", "TAL", "DCT_CNT_CD", "EC", "Stromal", "Immune", "PEC", "PT", "IC", "DTL_ATL"]
base_path = ".../Atlas/scSPECTRA/onthefly/onthefly_V2/"
r2_folder = "R2/"
pval_folder = "Pval/"
metadat_path = ".../Atlas/scSPECTRA/R2_pval/Atlas_Extended_II_Albuminuria_gpt.csv"

# R2 values are kept only where the matching p-value is <= this threshold
p_value_threshold = 0.01

# Read metadata and keep every sample whose Disease_level2 is not "Control"
metadat = pd.read_csv(metadat_path)
unique_samples = metadat[metadat['Disease_level2'] != "Control"]['Sample'].unique()

# Combined tables: one row per sample, one column per loaded cell type
combined_r2 = pd.DataFrame(index=unique_samples)
combined_pval = pd.DataFrame(index=unique_samples)

# Cell types that actually contributed a column. Later steps use this list
# instead of `cell_types` so that a skipped cell type cannot cause a KeyError.
loaded_cell_types = []

for cell_type in cell_types:
    r2_path = os.path.join(base_path, r2_folder, f"R2_{cell_type}.csv")
    pval_path = os.path.join(base_path, pval_folder, f"Pval_{cell_type}.csv")

    if not (os.path.exists(r2_path) and os.path.exists(pval_path)):
        print(f"Skipping {cell_type}: R2 and/or p-value file not found.")
        continue

    # The first CSV column is used as the index and joined against sample IDs
    r2_df = pd.read_csv(r2_path, index_col=0)
    pval_df = pd.read_csv(pval_path, index_col=0)

    if specific_gene_set not in r2_df.columns or specific_gene_set not in pval_df.columns:
        print(f"Skipping {cell_type}: gene set '{specific_gene_set}' not found in its tables.")
        continue

    # Prefix the column with the cell type so columns stay unique after joining
    column_name = f"{cell_type}_{specific_gene_set}"
    r2_column = r2_df[[specific_gene_set]].rename(columns={specific_gene_set: column_name})
    pval_column = pval_df[[specific_gene_set]].rename(columns={specific_gene_set: column_name})

    # Left join keeps exactly the non-control samples; absent samples become NaN
    combined_r2 = combined_r2.join(r2_column, how='left')
    combined_pval = combined_pval.join(pval_column, how='left')
    loaded_cell_types.append(cell_type)

if not loaded_cell_types:
    raise ValueError(
        f"No R2/p-value data could be loaded for gene set '{specific_gene_set}'."
    )

# Set NaN values to 1 (a missing p-value therefore counts as non-significant)
combined_r2 = combined_r2.fillna(1)
combined_pval = combined_pval.fillna(1)

# Set non-significant R2 values to 1 (vectorised over all columns at once)
combined_r2 = combined_r2.mask(combined_pval > p_value_threshold, 1)

# Select the gene-set columns in cell-type order and label them by cell type
# only; .copy() makes the relabelling independent of combined_r2.
plotting_matrix = combined_r2[[f"{ct}_{specific_gene_set}" for ct in loaded_cell_types]].copy()
plotting_matrix.columns = loaded_cell_types

# Clustered heatmap with row and column dendrograms and no colorbar.
# clustermap creates its own figure, so no separate plt.figure() call is
# needed. Hierarchical clustering needs at least two entries along an axis,
# so it is switched off for an axis that has fewer.
sns.clustermap(
    plotting_matrix,
    cmap="rocket",
    figsize=(8, 8),
    linewidths=.5,
    linecolor='black',
    row_cluster=plotting_matrix.shape[0] > 1,
    col_cluster=plotting_matrix.shape[1] > 1,
    cbar_pos=None,  # documented way to disable the colorbar
    vmin=0,
    vmax=1,
)

plt.show()
