In [None]:
import json
import multiprocessing
import operator
import os
import pickle
from functools import reduce
from itertools import product

import joblib
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotx
import numpy as np
import pandas as pd
import pyarrow.feather as feather
import scipy
import seaborn as sns
import shap
import xgboost as xgb
from adjustText import adjust_text
from kneed import KneeLocator
from scipy.stats import gaussian_kde
from sklearn.cluster import KMeans
from sklearn.datasets import load_svmlight_file
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import partial_dependence
from sklearn.linear_model import LinearRegression
from sklearn.metrics import root_mean_squared_error, silhouette_score
from sklearn.model_selection import (
    GridSearchCV,
    KFold,
    cross_val_score,
    cross_validate,
    train_test_split,
)
from sklearn.pipeline import Pipeline

In [None]:
# Figure defaults: TrueType (type 42) fonts keep PDF text editable.
mpl.rcParams['pdf.fonttype'] = 42
plt.rcParams['font.family'] = 'Arial'

# Gene category lists keyed by category name.
gene_select_dir = "/dcs05/hongkai/data/next_cutntag/bulk/dna_methylation/cpg_island"
gene_select_filename = "gene_categories.json"
gene_select_dir_filename = os.path.join(gene_select_dir, gene_select_filename)
with open(gene_select_dir_filename) as category_file:
    gene_select_dict = json.load(category_file)

# "coding_all" = CpG-island coding genes followed by non-CpG-island coding genes.
gene_select_dict["coding_all"] = [
    *gene_select_dict["coding_cpg"],
    *gene_select_dict["coding_non_cpg"],
]

In [None]:
# Mapping table from full target-pair names ("targets") to shorthand labels
# ("shorthand"), space-separated; consumed by map_target_names.
target_pair_mapping_file_36x = "/dcs05/hongkai/data/next_cutntag/script/utils/target_pair_short_hand.csv"
target_pair_mapping_df = pd.read_csv(target_pair_mapping_file_36x, sep=" ")

# Pre-selected target pairs (JSON list).
target_pairs_selected_file = "/dcs05/hongkai/data/next_cutntag/bulk/explainability/pdp_explain_results/target_pairs_selected.json"
with open(target_pairs_selected_file) as selected_file:
    target_pairs = json.load(selected_file)

In [None]:
# helper functions
def map_target_names(target_pair_list, target_pair_mapping_df, from_col="targets", to_col="shorthand"):
    """Rewrite target-pair names using a name -> shorthand mapping table.

    Every row of ``target_pair_mapping_df`` is applied, in row order, as a plain
    substring replacement (``str.replace``) on every name. Later rows therefore
    operate on the output of earlier rows.

    Parameters
    ----------
    target_pair_list : iterable of str
        Names to rewrite; a list or a pandas Series both work.
    target_pair_mapping_df : pandas.DataFrame
        Table with the source strings in ``from_col`` and replacements in ``to_col``.
    from_col, to_col : str
        Column names for the source and replacement strings.

    Returns
    -------
    list of str
        Rewritten names, in input order.
    """
    result = list(target_pair_list)
    # NOTE(review): substring-based and chained; a shorthand that itself contains
    # another row's source string would be rewritten a second time.
    for cur_name, new_name in zip(target_pair_mapping_df[from_col], target_pair_mapping_df[to_col]):
        result = [target_pair.replace(cur_name, new_name) for target_pair in result]
    return result
def column_to_rownames(wgc, var="pos"):
    """Return a copy of ``wgc`` with column ``var`` moved into the row index.

    Uses ``DataFrame.set_index`` with its default ``drop=True``, so ``var`` is
    no longer a column of the result. The input frame is not modified.

    Raises
    ------
    KeyError
        If ``var`` is not a column of ``wgc``.
    """
    return wgc.set_index(var)
def remove_model_in_param_grid(best_params):
    """Strip the leading ``model__`` step prefix from hyper-parameter names.

    Keys like ``"model__max_depth"`` (Pipeline-step syntax, presumably from a
    grid search over a Pipeline step named ``model``) become ``"max_depth"``.
    The result can be passed to ``set_params`` on the bare estimator. Only the
    leading prefix is removed; a later ``model__`` inside a nested name is kept.
    Keys without the prefix pass through unchanged.

    Parameters
    ----------
    best_params : dict
        Parameter name -> value.

    Returns
    -------
    dict
        New dict with prefixed names shortened; values unchanged.
    """
    prefix = "model__"
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in best_params.items()
    }
def calculate_effect_size(x, y, method="slope"):
    """Summarise a partial-dependence curve ``y(x)`` as one number.

    Parameters
    ----------
    x, y : array-like of float
        Grid values and the curve evaluated on them (same length).
    method : {"slope", "auc", "end"}
        - ``"slope"``: end-to-end secant slope ``(y[-1] - y[0]) / (x[-1] - x[0])``.
        - ``"auc"``: trapezoidal area under the curve after sorting by ``x``.
        - ``"end"``: the last curve value ``y[-1]``.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If ``method`` is not one of the names above.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if method == "slope":
        return (y[-1] - y[0]) / (x[-1] - x[0])
    if method == "auc":
        order = np.argsort(x)
        # np.trapz was renamed np.trapezoid in NumPy 2.0 (old name deprecated).
        trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
        return trapezoid(y[order], x[order])
    if method == "end":
        return y[-1]
    raise ValueError(f"unknown effect-size method {method!r}; expected 'slope', 'auc' or 'end'")

def get_average_lines(pdp_results, target_pair, gene_select_name="coding_all"):
    """Stack one feature's individual (ICE) curves from every trained model.

    Despite the name, nothing is averaged here. For each seed, in the dict's
    iteration order, ``["individual"][0]`` is taken and the pieces are stacked
    vertically. Callers then select rows (e.g. a cluster) and average them.

    Parameters
    ----------
    pdp_results : dict
        ``{seed: {gene_select_name: {target_pair: {"individual": ...}}}}``.
        Assumes ``["individual"][0]`` is a 2-D array of genes x grid points
        (TODO confirm against the code that wrote the pickles).
    target_pair : str
        Feature key inside each seed's result.
    gene_select_name : str, default "coding_all"
        Gene-set key inside each seed's result.

    Returns
    -------
    numpy.ndarray
        Row-wise concatenation: seed order first, then row order within a seed.
    """
    return np.vstack([
        seed_result[gene_select_name][target_pair]["individual"][0]
        for seed_result in pdp_results.values()
    ])

In [None]:
# Earlier run settings; the next cell re-defines data / model_designs / frag_types.
data = {}
model_designs = ["rnaseq_vs_hiplex"]
frag_types = ["mixed"]
# Reference cluster table indexed by gene id (per its filename: 5 clusters of
# coding CpG genes). Later cells use its columns for per-feature x-axis limits.
cluster_result = pd.read_csv(
    "/dcs05/hongkai/data/next_cutntag/bulk/explainability/leave_one_out/V/all/coding_cpg_clusters=5.csv",
    index_col=0,
).set_index("gene_id")

In [None]:
# load data: log10 RNA-seq TPM (target) and scaled Hi-Plex promoter signal
# (features) for the selected gene set
data = {}
model_designs = ["rnaseq_vs_hiplex_rm_outlier_log"]
model_design = "rnaseq_vs_hiplex_rm_outlier_log"
frag_types = ["mixed"]
frag_type = "mixed"
gene_select_names = ["coding_all"]
gene_select_name = "coding_all"
# Genes of the selected category (gene_select_dict comes from the config cell).
# Previously this name was never defined and the overlap below raised NameError.
gene_select_list = gene_select_dict[gene_select_name]

# read in rnaseq TPM
rnaseq_dir = "/dcs05/hongkai/data/next_cutntag/bulk/RNA-seq"
rnaseq_filename = "RNA_seq_TPM_all.csv"
rnaseq_dir_filename = os.path.join(rnaseq_dir, rnaseq_filename)
rnaseq = pd.read_csv(rnaseq_dir_filename, sep=",", header=0, index_col=0)
# Despite its name, "sqrt_V" holds log10(TPM + 1). The name is kept because
# downstream cells and saved files refer to it.
rnaseq.loc[:, "sqrt_V"] = np.log10(rnaseq.loc[:, "V1V2"] + 1)

# read in hiplex norm counts
wgc_root_dir = "/dcs05/hongkai/data/next_cutntag/bulk/wgc"
bin_size = "promoter_-1000-1000"
scen = "V"
target_qc_type = "all"
post_process = "libnorm"
wgc_dir = os.path.join(wgc_root_dir, frag_type, bin_size)
wgc_filename = f"{scen}_{frag_type}_{bin_size}_colQC-{target_qc_type}_{post_process}.feather"
wgc_dir_filename = os.path.join(wgc_dir, wgc_filename)
wgc_raw = feather.read_feather(wgc_dir_filename)
with open('/dcs05/hongkai/data/next_cutntag/script/utils/filtered_target_pairs.json', 'r') as file:
    features = json.load(file)
wgc_vals = wgc_raw[features]
# Genes with no Hi-Plex signal in any feature (computed on the untransformed counts).
zero_hiplex_genes = wgc_raw["pos"][wgc_vals.sum(axis=1) == 0].to_list()
# log10(x + 1), then min-max scale each feature column to [0, 1].
# NOTE(review): a constant feature column would yield 0/0 = NaN here.
wgc_raw[features] = np.log10(wgc_raw[features] + 1)
wgc_raw[features] = (wgc_raw[features] - wgc_raw[features].min()) / (wgc_raw[features].max() - wgc_raw[features].min())

features = ["pos"] + features
wgc_raw = wgc_raw.loc[:, features]

wgc_raw_list = wgc_raw.loc[:, "pos"].values.tolist()
# Sort the overlap so the run is reproducible. Set iteration order for strings
# depends on PYTHONHASHSEED. The row order produced here feeds train_test_split
# and KFold, so an unsorted overlap gave a different split after every restart.
overlap_gene_list = sorted(set(wgc_raw_list) & set(gene_select_list))
rnaseq_avail = rnaseq.loc[overlap_gene_list, ["gene_id", "sqrt_V"]]
# filter genes with extremely high gene expression (per-design cutoff)
with open("/dcs05/hongkai/data/next_cutntag/bulk/explainability/rnaseq_hiplex_cutoff.json", 'r') as f:
    rnaseq_cutoffs = json.load(f)
q99 = rnaseq_cutoffs[model_design]
rnaseq_avail = rnaseq_avail.loc[rnaseq_avail.loc[:, "sqrt_V"] <= q99, :]
# Drop genes that are zero in both assays.
zero_rnaseq_genes = rnaseq_avail["gene_id"][rnaseq_avail["sqrt_V"] == 0].to_list()
zero_all_genes = list(set(zero_hiplex_genes) & set(zero_rnaseq_genes))
rnaseq_wgc_raw = rnaseq_avail.merge(wgc_raw, left_on="gene_id", right_on="pos", how="inner")
rnaseq_wgc_raw = rnaseq_wgc_raw.set_index("gene_id")
rnaseq_wgc_raw = rnaseq_wgc_raw.loc[~rnaseq_wgc_raw.index.isin(zero_all_genes), :]
rnaseq_wgc = rnaseq_wgc_raw.drop(columns=["pos"])
rnaseq_wgc_train, rnaseq_wgc_test = train_test_split(rnaseq_wgc, test_size=0.2, random_state=42)
rnaseq_wgc_train_X = rnaseq_wgc_train.drop(columns=["sqrt_V"])
rnaseq_wgc_train_y = rnaseq_wgc_train.loc[:, "sqrt_V"].values
rnaseq_wgc_test_X = rnaseq_wgc_test.drop(columns=["sqrt_V"])
rnaseq_wgc_test_y = rnaseq_wgc_test.loc[:, "sqrt_V"].values
rnaseq_wgc_all_X = rnaseq_wgc.drop(columns=["sqrt_V"])
rnaseq_wgc_all_y = rnaseq_wgc.loc[:, "sqrt_V"].values

# Register the matrices under (gene set, design, fragment type) for later cells.
data.setdefault(gene_select_name, {}).setdefault(model_design, {})[frag_type] = {
    "rnaseq_wgc_train_X": rnaseq_wgc_train_X,
    "rnaseq_wgc_train_y": rnaseq_wgc_train_y,
    "rnaseq_wgc_test_X": rnaseq_wgc_test_X,
    "rnaseq_wgc_test_y": rnaseq_wgc_test_y,
    "rnaseq_wgc_all_X": rnaseq_wgc_all_X,
    "rnaseq_wgc_all_y": rnaseq_wgc_all_y,
}


In [None]:

# cross-validation: 5-fold CV of the tuned random forest and a linear baseline
model_design = "rnaseq_vs_hiplex_rm_outlier_log"
gene_select_name = "coding_all"
frag_type = "mixed"
random_seed = 42
model_params_file = f'/dcs05/hongkai/data/next_cutntag/script/explainability/{model_design}/{gene_select_name}.json'
with open(model_params_file) as fp:
    model_params = json.load(fp)

# One estimator per entry: "rf" -> RandomForestRegressor, any other name -> LinearRegression.
all_models = {gene_select_name: {}}
for model_spec in model_params:
    model_name = model_spec['model_name']
    if model_name == "rf":
        model = RandomForestRegressor(random_state=random_seed)
    else:
        model = LinearRegression()
    params = model_spec['best params']
    if len(params) > 0:
        params = remove_model_in_param_grid(params)
        print(params)
        model.set_params(**params)
    all_models[gene_select_name][model_name] = model

print(gene_select_name)
print(frag_type)
models = {"random_forest": all_models[gene_select_name]["rf"],
          "lr": all_models[gene_select_name]["lr"]}
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Per-model, per-fold metrics: {model_name: {"rmse": [...], "pearson": [...], "spearman": [...]}}
cv_scores = {}
for model_name, model in models.items():
    rmse_scores, cors, spearman_cors = [], [], []
    for fold_idx, (train_index, test_index) in enumerate(kf.split(rnaseq_wgc_all_X)):
        X_train, X_test = rnaseq_wgc_all_X.iloc[train_index], rnaseq_wgc_all_X.iloc[test_index]
        y_train, y_test = rnaseq_wgc_all_y[train_index], rnaseq_wgc_all_y[test_index]

        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

        # Predicted vs observed, coloured by the local point density.
        xy = np.vstack([y_pred, y_test])
        density = gaussian_kde(xy)(xy)
        fig, ax = plt.subplots()
        ax.scatter(y_pred, y_test, c=density, s=5)
        ax.set_xlabel('y_pred')
        ax.set_ylabel('y_test')
        # Fold number (the old title embedded the whole test-index array).
        ax.set_title(f'{model_name} fold {fold_idx + 1}')
        plt.show()
        plt.close(fig)

        rmse_scores.append(root_mean_squared_error(y_test, y_pred))
        cors.append(scipy.stats.pearsonr(y_test, y_pred)[0])
        spearman_cors.append(scipy.stats.spearmanr(y_test, y_pred)[0])

    cv_scores[model_name] = {"rmse": rmse_scores, "pearson": cors, "spearman": spearman_cors}
    print(model_name)
    print(f'Root Mean Squared Error for each fold: {rmse_scores}')
    print(f'Average Root Mean Squared Error: {np.mean(rmse_scores)}')
    print(f'Pearson for each fold: {cors}')
    print(f'Average Mean Pearson: {np.mean(cors)}\nAverage Mean Spearman: {np.mean(spearman_cors)}')

In [None]:
# kmeans cluster: group genes by their Hi-Plex feature profile; k is picked at
# the knee of the within-cluster sum of squares (WCSS) curve
model_design = "rnaseq_vs_hiplex_rm_outlier_log"
# Copy of the modelling feature matrix from the load cell. The "cluster" column
# added below must not leak into that matrix.
wgc_raw = rnaseq_wgc_all_X.copy()
k_range = range(2, 11)
wcss = []
for n_clusters in k_range:
    kmeans = KMeans(n_clusters=n_clusters, init='k-means++', random_state=42, max_iter=300, n_init=10)
    kmeans.fit(wgc_raw)
    wcss.append(kmeans.inertia_)  # inertia_ is the WCSS
kn = KneeLocator(k_range, wcss, curve='convex', direction='decreasing')
print(kn.knee)
best_n_clusters = kn.knee
if best_n_clusters is None:
    raise ValueError("KneeLocator found no knee in the WCSS curve; set best_n_clusters manually")

fig, ax = plt.subplots()
ax.plot(k_range, wcss)
ax.axvline(best_n_clusters, linestyle='dashed')
ax.set_title('Elbow Method')
ax.set_xlabel('Number of clusters')
ax.set_ylabel('WCSS')
plt.show()

cluster_method = KMeans(n_clusters=best_n_clusters, random_state=42)
# Fit the model
cluster_method.fit(wgc_raw)
# Assign cluster label
wgc_raw.loc[:, "cluster"] = cluster_method.labels_


In [None]:
# Relabel clusters 0..k-1 by increasing mean gene expression.
rnaseq_dir = "/dcs05/hongkai/data/next_cutntag/bulk/RNA-seq"
rnaseq_filename = "RNA_seq_TPM_all.csv"
rnaseq_dir_filename = os.path.join(rnaseq_dir, rnaseq_filename)
rnaseq = pd.read_csv(rnaseq_dir_filename, sep=",", header=0, index_col=0)
# NOTE(review): this uses the file's "sqrt_V1V2" column. The load cell's model
# target is log10(V1V2 + 1) under the same "sqrt_V" name. Confirm which scale
# should order the clusters.
rnaseq.loc[:, "sqrt_V"] = rnaseq.loc[:, "sqrt_V1V2"]

wgc_raw.loc[:, "sqrt_V"] = rnaseq.loc[wgc_raw.index, "sqrt_V"]
mean_expr_by_cluster = wgc_raw.groupby('cluster')['sqrt_V'].mean().sort_values(ascending=True)
gene_exp_order = mean_expr_by_cluster.index.to_list()
# old label -> rank of its mean expression (lowest mean -> 0)
swap_dict = {old_label: new_label for new_label, old_label in enumerate(gene_exp_order)}
print(mean_expr_by_cluster)
print(swap_dict)
wgc_raw['cluster'] = wgc_raw['cluster'].map(swap_dict)

In [None]:
# Persist the per-gene table (features, cluster, sqrt_V) for later sessions.
cluster_out_dir = f"/dcs05/hongkai/data/next_cutntag/bulk/explainability/{model_design}/fig/coding_all/"
os.makedirs(cluster_out_dir, exist_ok=True)
wgc_raw.to_csv(os.path.join(cluster_out_dir, "cluster.csv"))

In [None]:
# Reload the saved cluster table so later cells do not require re-clustering.
wgc_raw = pd.read_csv(
    os.path.join(f"/dcs05/hongkai/data/next_cutntag/bulk/explainability/{model_design}/fig/coding_all", "cluster.csv"),
    index_col=0,
)


In [None]:
# Distinct cluster labels in ascending order.
cluster_ids = np.sort(wgc_raw['cluster'].unique())

In [None]:
# Feature colour table (presumably one plot colour per feature; confirm schema).
feature_colors_df = pd.read_csv(
    os.path.join("/dcs05/hongkai/data/next_cutntag/bulk/explainability/rnaseq_vs_hiplex_rm_outlier_log/fig", "feature_colors.csv"),
    index_col=0,
)
feature_colors_df

In [None]:
# Root of the per-seed outputs (<seed>/top_features.csv, <seed>/top_features_pdp.pkl, ...).
result_dir = os.path.join("/dcs05/hongkai/data/next_cutntag/bulk/explainability", model_design, "")

In [None]:
# Average feature importance over the 10 trained models (seeds); keep the top 20.
random_seeds = [0, 1, 7, 42, 123, 999, 1234, 1337, 2021, 31415]
feature_importance_list = []
for random_seed in random_seeds:
    feature_importance = pd.read_csv(f"{result_dir}/{random_seed}/top_features.csv")
    feature_importance = feature_importance.rename(columns={'importance': f'{random_seed}_importance'})
    feature_importance_list.append(feature_importance)
# Inner merge on "feature": only features reported by every seed are kept.
feature_importance_all = reduce(lambda left, right: pd.merge(left, right, on='feature'), feature_importance_list)
feature_importance_all = feature_importance_all.set_index("feature")
feature_importance_all['importance'] = feature_importance_all.mean(axis=1)

# Sort by mean importance, largest first; .loc[:19] is label-inclusive -> 20 rows.
feature_importance = (
    feature_importance_all
    .sort_values(by='importance', ascending=False)
    .reset_index()
    .loc[:19, ["feature", "importance"]]
    .copy()
)

# Full feature names as a detached list. The "feature" column is overwritten with
# shorthand names below. A Series view of it may or may not see that write,
# depending on the pandas version. Later cells use these names as pdp_results keys.
top_features = feature_importance["feature"].tolist()
feature_importance["feature"] = map_target_names(top_features, target_pair_mapping_df)
print(feature_importance)

In [None]:
# feature importance plotting: bars on a sqrt(importance) axis; tick labels show
# the original importance values
feature_importance["norm_feature_importance"] = np.sqrt(feature_importance["importance"])

if "colors" not in feature_importance.columns:
    # No earlier cell adds a "colors" column, so the palette lookup raised KeyError.
    # NOTE(review): assumes feature_colors_df is indexed by feature name (shorthand
    # or full) with a "colors" column -- confirm; unmatched features are gray.
    if "colors" in feature_colors_df.columns:
        color_lookup = dict(zip(feature_colors_df.index, feature_colors_df["colors"]))
    else:
        color_lookup = {}
    feature_importance["colors"] = [
        color_lookup.get(short_name, color_lookup.get(full_name, "gray"))
        for short_name, full_name in zip(feature_importance["feature"], top_features)
    ]

fig, ax = plt.subplots(figsize=(6, 6))
sns.barplot(
    data=feature_importance,
    y='feature',
    x='norm_feature_importance',
    palette=feature_importance['colors'].tolist(),
    ax=ax,
)
ax.grid(False)
title_name = ""
ax.set_title(title_name, fontsize=25, weight='bold')
ax.set_xlabel('Feature Importance', fontsize=15)
ax.set_ylabel('', fontsize=13)
ax.tick_params(axis='x', labelsize=15)
ax.tick_params(axis='y', labelsize=15)

# Undo the sqrt transform in the tick labels. A formatter stays in sync with the
# tick locator, unlike set_xticklabels() applied to get_xticks().
ax.xaxis.set_major_formatter(mpl.ticker.FuncFormatter(lambda value, _pos: f"{value**2:.2f}"))

fig.tight_layout()
importance_dir = f"/dcs05/hongkai/data/next_cutntag/bulk/explainability/{model_design}/fig/coding_all/"
os.makedirs(importance_dir, exist_ok=True)
fig.savefig(os.path.join(importance_dir, "feature_importance_scaled.pdf"), transparent=True)
plt.show()

In [None]:
# load pdp result calculated from 10 trained models, plus each seed's gene table
# annotated with cluster and CpG category
result_dir = f"/dcs05/hongkai/data/next_cutntag/bulk/explainability/{model_design}/"
random_seeds = [0, 1, 7, 42, 123, 999, 1234, 1337, 2021, 31415]

pdp_results = {}
for random_seed in random_seeds:
    pdp_result_dir = f"{result_dir}/{random_seed}/top_features_pdp.pkl"
    # pickle.load can execute code embedded in the file; only load trusted outputs.
    with open(pdp_result_dir, 'rb') as f:
        pdp_results[random_seed] = pickle.load(f)

# Rows of cluster_results_stacked line up with get_average_lines(), which stacks
# curves in pdp_results' seed order. Both loops iterate random_seeds.
# A dedicated name is used so the modelling matrix rnaseq_wgc_all_X is not clobbered.
cluster_results = []
for random_seed in random_seeds:
    seed_X = pd.read_csv(f"{result_dir}/{random_seed}/rnaseq_wgc_all_X.csv", index_col=0)
    seed_X["cluster"] = wgc_raw.loc[seed_X.index, "cluster"]
    seed_X["category"] = "coding_cpg"
    seed_X.loc[seed_X.index.isin(gene_select_dict["coding_non_cpg"]), "category"] = "coding_non_cpg"
    cluster_results.append(seed_X)
cluster_results_stacked = pd.concat(cluster_results, ignore_index=True)

In [None]:
# draw pdp plot for each epitope pair in each cluster: one row per top feature,
# columns = centred / not centred
gene_select_name = "coding_all"
cluster_colors = ["#D76532", "#F6C546", "#8574A4", "#5583C2", "#8CB463"]
n_rows = len(top_features)  # one row per rm_target_pair
n_cols = 2  # two columns: centered and not_centered
fig, axes = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 5 * n_rows), squeeze=False)
cluster_labels = cluster_results_stacked["cluster"].to_numpy()

for col_idx, plt_type in enumerate(["centered", "not_centered"]):
    for row_idx, rm_target_pair in enumerate(top_features):
        axis = axes[row_idx, col_idx]
        print(f"Plotting: gene_select_name={gene_select_name}, plt_type={plt_type}, rm_target_pair={rm_target_pair}")
        # NOTE(review): grid taken from the last loaded seed; assumes all seeds share it.
        x = pdp_results[random_seed]["coding_all"][rm_target_pair]["grid_values"][0]
        # Stacked curves of all seeds; the same for every cluster, so computed once.
        all_lines = get_average_lines(pdp_results, rm_target_pair)

        axis.set_xlim(xmin=-0.1, xmax=1.5)
        axis.set_xlabel('Hi-Plex Signal', fontsize=20)
        axis.set_ylabel('Gene Expression', fontsize=20)

        texts = []
        for i, cluster_id in enumerate(cluster_ids):
            line_color = cluster_colors[i % len(cluster_colors)]
            y = all_lines[cluster_labels == cluster_id, :].mean(axis=0)
            if plt_type == "centered":
                if min(x) < 1:
                    # Anchor the curve at the grid point closest to x = 0.
                    y = y - y[np.argmin(np.abs(x))]
                else:
                    y = y - y[0]
            y = y[:len(x)]

            axis.plot(x, y, linewidth=1.8, label=f'cluster {cluster_id + 1}', color=line_color)
            texts.append(axis.text(x[-1] + 0.0125,
                                   y[-1],
                                   f'cluster {cluster_id+1}',
                                   verticalalignment='center',
                                   color=line_color,
                                   fontsize=22))

        adjust_text(texts, ax=axis)
        axis.tick_params(axis='x', labelsize=18)
        axis.tick_params(axis='y', labelsize=18)
        axis.spines['top'].set_visible(False)
        axis.spines['right'].set_visible(False)
        rm_target_pair_short = map_target_names([rm_target_pair], target_pair_mapping_df)[0]
        axis.set_title(f"{rm_target_pair_short}\n({plt_type})", fontsize=28)
        axis.grid(False)

fig.tight_layout()
save_dir = f"/dcs05/hongkai/data/next_cutntag/bulk/explainability/{model_design}/fig/pdp_grid/"
os.makedirs(save_dir, exist_ok=True)
# Save before plt.show(): some backends close the figure once it is shown.
fig.savefig(os.path.join(save_dir, f"{gene_select_name}_pdp_grid.pdf"))
plt.show()
plt.close(fig)


In [None]:
# draw pdp plot for each epitope pair in each cluster with selected epitope pairs:
# one figure per cluster, centred curves on a signed-sqrt y axis
gene_select_name = "coding_all"
plt_type = "centered"
selected_pairs = ["H3K27me3-H3K4me3", "H3K4me1-H3K4me3", "H3K27ac-H3K4me3", "H3K27me3-H3K9me3"]
colors = ["#D76532", "#F6C546", "#8574A4", "#4C9F70"]  # one per selected pair
# Per-feature x cut-off: quantile(1) = maximum value in the reference cluster table.
target_pair_threshold = cluster_result.loc[:, ~cluster_result.columns.isin(["sqrt_V", "pos", "cluster", "category"])].quantile(1)
cluster_labels = cluster_results_stacked["cluster"].to_numpy()
# Stacked curves per pair; independent of the cluster, so computed once.
pair_lines = {pair: get_average_lines(pdp_results, pair) for pair in selected_pairs}
save_dir = f"/dcs05/hongkai/data/next_cutntag/bulk/explainability/{model_design}/fig/coding_all/sqrt_coding_all_{plt_type}_same_range/"
os.makedirs(save_dir, exist_ok=True)


def inv_signed_sqrt(y):
    """Inverse of y -> sign(y) * sqrt(|y|); labels the y axis in original units."""
    return np.sign(y) * (y ** 2)


for cluster_id in cluster_ids:
    fig, axis = plt.subplots(figsize=(6, 6))
    axis.set_ylim(ymin=-0.5, ymax=0.75)
    texts = []  # labels of all pairs, so adjust_text sees every one of them
    for row_idx, rm_target_pair in enumerate(selected_pairs):
        print(f"Plotting: gene_select_name={gene_select_name}, plt_type={plt_type}, rm_target_pair={rm_target_pair}")
        grid_values = pdp_results[random_seed]["coding_all"][rm_target_pair]["grid_values"][0]
        x = grid_values[grid_values <= target_pair_threshold[rm_target_pair]]
        if min(grid_values) < 0:
            axis.set_xlim(xmin=-4, xmax=4)
            axis.set_xlabel('log2FC(Hi-Plex Signal)', fontsize=20)
            axis.set_ylabel('log2FC(Gene Expression)', fontsize=20)
        else:
            axis.set_xlim(xmin=-0.05, xmax=1.5)
            axis.set_xlabel('Hi-Plex Signal', fontsize=20)
            axis.set_ylabel('Gene Expression', fontsize=20)

        y = pair_lines[rm_target_pair][cluster_labels == cluster_id, :].mean(axis=0)
        if plt_type == "centered":
            if min(grid_values) < 1:
                y = y - y[np.argmin(np.abs(x))]
            else:
                y = y - y[0]
        y = y[:len(x)]  # drop grid points beyond the x cut-off
        # Signed square root compresses large effects while keeping their sign.
        y = np.sign(y) * np.sqrt(np.abs(y))

        line_color = colors[row_idx]
        axis.plot(x, y, linewidth=1.8, label=rm_target_pair, color=line_color, linestyle='-')
        axis.tick_params(axis='x', labelsize=18)
        axis.tick_params(axis='y', labelsize=18)
        texts.append(axis.text(x[-1] - 0.5,
                               y[-1] + 0.025,
                               f'{rm_target_pair}',
                               verticalalignment='center',
                               color=line_color,
                               fontsize=18))

    adjust_text(texts, ax=axis)
    axis.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(lambda val, pos: f"{inv_signed_sqrt(val):.2f}"))
    axis.spines['top'].set_visible(False)
    axis.spines['right'].set_visible(False)
    axis.set_title(f"Cluster {cluster_id+1}", fontsize=28)
    axis.grid(False)
    fig.tight_layout()
    # Save before plt.show(): some backends close the figure once it is shown.
    fig.savefig(os.path.join(save_dir, f"{cluster_id}.pdf"))
    plt.show()
    plt.close(fig)


In [None]:
# PDP plots for all epitope pairs: mean curve over all genes and seeds per top
# feature, centred at the grid point closest to x = 0
width = 10
height = 10
model_design = "rnaseq_vs_hiplex_rm_outlier_log"
ys_final = []  # curve value at the last grid point, one per feature
colors = [
    "#1b9e77",  # dark teal green
    "#d95f02",  # burnt orange
    "#7570b3",  # muted indigo
    "#e7298a",  # deep magenta
    "#66a61e",  # olive green
    "#e6ab02",  # golden mustard
    "#a6761d",  # brown ochre
    "#666666",  # dark gray
    "#1f78b4",  # steel blue
    "#6a3d9a",  # deep purple
    "#b15928",  # rust brown
    "#01665e",  # forest teal
    "#8c510a",  # dark amber brown
    "#2c3e50",  # dark slate blue-gray
    "#4b0082",  # indigo (very dark purple)
    "#003f5c",  # deep navy blue
    "#2f4f4f",  # dark slate gray
    "#800000",  # maroon
    "#191970",  # midnight blue
    "#3c1053",  # deep violet
]
for gene_select_name in ["coding_all"]:
    for model_name in ["random_forest"]:
        fig = plt.figure(figsize=(width, height))
        axis1 = fig.add_subplot(111)
        # Effective limits of the previous set_xlim/set_ylim sequence (zoomed y range).
        axis1.set_xlim(xmin=-0.1, xmax=1.5)
        axis1.set_ylim(ymin=-0.1, ymax=0.1)
        target_pairs = list(top_features)
        short_target_pairs = map_target_names(target_pairs, target_pair_mapping_df)
        texts = []
        line_alpha = 0.8
        for i, (target_pair, short_name) in enumerate(zip(target_pairs, short_target_pairs)):
            x = pdp_results[random_seed]["coding_all"][target_pair]["grid_values"][0]
            y = get_average_lines(pdp_results, target_pair).mean(axis=0)
            if model_design == "rnaseq_vs_hiplex":
                y = y - y[0]
            else:
                y = y - y[np.argmin(np.abs(x))]
            ys_final.append(y[-1])
            line_color = colors[i % len(colors)]  # cycle if there are more than 20 features
            axis1.plot(x, y, linewidth=2, linestyle="solid", label=target_pair,
                       color=line_color, alpha=line_alpha)
            texts.append(axis1.text(x[-1] + 0.0125,
                                    y[-1],
                                    short_name,
                                    verticalalignment='center',
                                    color=line_color,
                                    alpha=line_alpha,
                                    fontsize=22))
        adjust_text(texts, ax=axis1, expand=(1.2, 2))
        if model_design == "rnaseq_vs_hiplex":
            axis1.set_xlabel('Hi-Plex Signal', fontsize=20)
            axis1.set_ylabel('Gene Expression', fontsize=20)
        else:
            # NOTE(review): the load cell builds min-max scaled log10 features and a
            # log10(TPM + 1) target; confirm these log2-FC labels fit this design.
            axis1.set_xlabel('log2(hiplex reads) FC', fontsize=20)
            axis1.set_ylabel('log2(RNAseq TPM) FC', fontsize=20)
        axis1.spines['top'].set_visible(False)
        axis1.spines['right'].set_visible(False)
        axis1.tick_params(axis='x', labelsize=18)
        axis1.tick_params(axis='y', labelsize=18)
        title_name = "Coding Genes" if gene_select_name == "coding_all" else gene_select_name
        axis1.set_title(title_name, fontsize=28)
        save_dir = f"/dcs05/hongkai/data/next_cutntag/bulk/explainability/{model_design}/fig/{gene_select_name}/"
        os.makedirs(save_dir, exist_ok=True)
        print(f'{save_dir}/summary_{width}x{height}.pdf')
        # Saving is currently disabled:
        # fig.savefig(f'{save_dir}/summary_{width}x{height}.pdf', bbox_inches="tight", transparent=True)
        plt.show()

In [None]:
# effect size calculation: per cluster and top feature, the area under the centred
# mean PDP curve, divided by that feature's std across clusters
top_features = list(top_features)
gene_select_name = 'coding_all'
plt_type = "centered"
# Defined here instead of relying on save_dir leaked from an earlier cell.
save_dir = f"/dcs05/hongkai/data/next_cutntag/bulk/explainability/{model_design}/fig/{gene_select_name}/"
os.makedirs(save_dir, exist_ok=True)
cluster_labels = cluster_results_stacked["cluster"].to_numpy()
gene_categories = cluster_results_stacked["category"].to_numpy()

for effect_size_method in ["auc"]:
    effect_size_df = pd.DataFrame(0.0, index=cluster_ids, columns=top_features)
    for rm_target_pair in top_features:
        x = pdp_results[random_seed]["coding_all"][rm_target_pair]["grid_values"][0]
        # Stacked curves; independent of the cluster, so computed once per feature.
        all_lines = get_average_lines(pdp_results, rm_target_pair)
        for cluster_id in cluster_ids:
            mask = cluster_labels == cluster_id
            if gene_select_name != "coding_all":
                mask &= gene_categories == gene_select_name
            y = all_lines[mask, :].mean(axis=0)
            if plt_type == "centered":
                if min(x) < 1:
                    y = y - y[np.argmin(np.abs(x))]
                else:
                    y = y - y[0]
            effect_size_df.loc[cluster_id, rm_target_pair] = calculate_effect_size(x, y, method=effect_size_method)
    # Column-wise scaling by the across-cluster standard deviation (pandas default ddof=1).
    effect_size_df = effect_size_df / effect_size_df.std()
    effect_size_df = effect_size_df.sort_index()
    effect_size_df.columns = map_target_names(effect_size_df.columns.tolist(), target_pair_mapping_df)
    # 1-based cluster numbers, matching the "Cluster N" plot titles.
    effect_size_df.index = effect_size_df.index + 1
    effect_size_df.to_csv(f'{save_dir}/effect_size_{effect_size_method}.csv')
