In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys 
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc 
import muon as mu
from statsmodels.stats.multitest import multipletests
from adjustText import adjust_text
from scipy.stats import percentileofscore, pearsonr, mannwhitneyu
from collections import defaultdict


# Make the project-local helper modules (signature_heatmaps, factor_labels) importable.
# Set MGLIA_UTILS_DIR to run elsewhere; the default is the original author's checkout.
sys.path.append(os.path.abspath(os.environ.get(
    "MGLIA_UTILS_DIR", '/Users/reetmishra/GitHub/kampmann/mglia_regulators/utils')))

import signature_heatmaps as signature_heatmaps
import factor_labels as factor_labels


# Read in Data

In [None]:
data_dir = "/Users/reetmishra/GitHub/kampmann/mglia_regulators/data/after_metadata/"
cite_6tf_path = os.path.join(data_dir, "cite_6tf_cleaned.h5mu")
cite_imgl_path = os.path.join(data_dir, "cite_imgl_cleaned.h5mu")
merged_6tf_path = os.path.join(data_dir, "adata_merged_6tf.h5ad")

In [None]:
mdata_dict = {}
mdata_dict['cite_6tf'] = mu.read_h5mu(cite_6tf_path)
mdata_dict['cite_imgl'] = mu.read_h5mu(cite_imgl_path)

adata_dict = {}
adata_dict['merged_6tf'] = sc.read_h5ad(merged_6tf_path)
adata_dict['cite_6tf'] = mdata_dict['cite_6tf'].mod['rna'].copy()
adata_dict['cite_imgl'] = mdata_dict['cite_imgl'].mod['rna'].copy()

In [None]:
signature_cols_ordered = ['homeostatic_score_ucell',
 'interferon_score_ucell',
 'chemokine_score_ucell',
 'antigen_presenting_score_ucell',
 'dam_score_ucell',
 'lipid_dam_score_ucell']

# Masks: drop the non-targeting_g5 and FOXK1_g2 guides, then keep cells with Mixscale score >= 0

In [None]:
guides_to_exclude = ['FOXK1_g2', 'non-targeting_g5']

adata_6tf_clean = adata_dict['merged_6tf'][~adata_dict['merged_6tf'].obs['guide'].isin(["non-targeting_g5", "FOXK1_g2"])]
adata_imgl_clean = adata_dict['cite_imgl'][~adata_dict['cite_imgl'].obs['guide'].isin(["non-targeting_g5", "FOXK1_g2"])]
adata_6tf_clean.shape, adata_imgl_clean.shape

In [None]:
print(mdata_dict['cite_6tf'].shape, mdata_dict['cite_imgl'].shape)
mdata_6tf_clean = mdata_dict['cite_6tf'][~mdata_dict['cite_6tf'].mod['rna'].obs['guide'].isin(["non-targeting_g5", "FOXK1_g2"])]
mdata_imgl_clean = mdata_dict['cite_imgl'][~mdata_dict['cite_imgl'].mod['rna'].obs['guide'].isin(["non-targeting_g5", "FOXK1_g2"])]
mdata_6tf_clean.shape, mdata_imgl_clean.shape

In [None]:
mixscale_col = "new_mixscale_score"

adata_imgl_clean.obs[mixscale_col] = new_mixscale_scores_imgl['x'].tolist()
adata_6tf_clean.obs[mixscale_col] = new_mixscale_scores_6tf['x'].tolist()

In [None]:
adata_6tf_masked = adata_6tf_clean[adata_6tf_clean.obs[mixscale_col] >= 0].copy()
adata_imgl_masked = adata_imgl_clean[adata_imgl_clean.obs[mixscale_col] >= 0].copy()
adata_6tf_masked.shape, adata_imgl_masked.shape

# Heatmaps

In [None]:
# Signature-score UMAPs for both guide-filtered datasets (one column of panels each,
# shared colour range 0-0.7). The loop replaces two copy-pasted cells.
sns.set_theme('poster', style="white", palette="viridis")
for adata_clean in (adata_6tf_clean, adata_imgl_clean):
    fig = sc.pl.umap(adata_clean,
                     color=signature_cols_ordered,
                     ncols=1,
                     cmap="viridis",
                     size=30,
                     vmin = 0,
                     vmax = 0.7,
                     frameon=False,
                     title=[''] * len(signature_cols_ordered),
                     show=False,
                     return_fig=True)
    # plt.show() renders each figure in turn; fig.show() can emit a non-GUI-backend
    # warning under the inline backend.
    plt.show()

In [None]:
# Clustermaps ("Median Difference for <dataset>") built with the project's signature_heatmaps
# helpers, one per Mixscale-filtered dataset. The loop replaces two copy-pasted cells and
# leaves name / adata / g bound to the last dataset, as before.
sns.set_theme("talk")
for name, adata in (("merged_6tf", adata_6tf_masked), ("cite_imgl", adata_imgl_masked)):
    all_median_differences, annotations, row_labels = signature_heatmaps.calculate_median_differences_and_annotations(
        adata,
        signature_cols_ordered,
        annot_with_num=False,
        by_guide=False)
    # Only the FDR annotation columns are drawn on the heatmap.
    annotations_fdr = annotations[annotations.columns[annotations.columns.str.endswith("fdr")]]

    g = signature_heatmaps.plot_clustermap(
            data=all_median_differences,
            annotations=annotations_fdr,
            row_labels=row_labels,
            title=f"Median Difference for {name}",
            xlabel="Functional Scores",
            ylabel="Gene",
            alphabetical=True,
            vmin=-0.03,
            vmax=0.03,
            figsize=(4,20),
            by_guide = False,
            rowC = True,
            colC=True,
    )

# Signature Point Plots with Error Bars

In [None]:
def bootstrapped_ci(data: list, num_replicates: int = 10000, seed: int = 42, verbose: bool = True):
    """Bootstrap a 95% percentile confidence interval for the median of ``data``.

    Parameters
    ----------
    data : array-like of numbers
        Sample whose median is bootstrapped (1-D).
    num_replicates : int
        Number of resamples, each drawn with replacement and the same size as ``data``.
    seed : int
        Seed for a *local* ``np.random.RandomState``. The draws are identical to the
        previous ``np.random.seed(seed)`` + ``np.random.choice`` implementation, but the
        global NumPy RNG is no longer reseeded as a side effect.
    verbose : bool
        Print the median and CI (the previous, default behaviour).

    Returns
    -------
    (median, lower_bound, upper_bound)
        Sample median and the 2.5th / 97.5th percentiles of the bootstrap medians.
        All three are NaN when ``data`` is empty.
    """
    data_arr = np.asarray(data, dtype=float)
    if data_arr.size == 0:
        return np.nan, np.nan, np.nan

    rng = np.random.RandomState(seed)
    bootstrap_medians = np.array([
        np.median(rng.choice(data_arr, size=data_arr.size, replace=True))
        for _ in range(num_replicates)
    ])

    lower_bound, upper_bound = np.percentile(bootstrap_medians, [2.5, 97.5])
    median = np.median(data_arr)

    if verbose:
        print(f"Original median: {median}")
        print(f"95% Confidence Interval for the median: [{lower_bound:.2f}, {upper_bound:.2f}]")
    return median, lower_bound, upper_bound


    

In [None]:
def ntc_percentile_effect_v2(df_long, factor_col="Signature", guide_col="perturbed_guide", ntc_label="NTC"):
    """Quantify each guide's shift relative to the NTC distribution, per factor.

    For every level of ``factor_col``, each row's ``Score`` is converted to its percentile
    within that factor's NTC scores. This is ``scipy.stats.percentileofscore(..., kind="mean")``
    scaled to [0, 1]. For each guide it reports:
    - the median percentile and a bootstrapped 95% CI (via ``bootstrapped_ci``);
    - a two-sided Mann-Whitney U test of the raw scores against NTC.

    Parameters
    ----------
    df_long : pandas.DataFrame
        Long table with columns ``factor_col``, ``guide_col`` and ``Score``
        (assumed NaN-free).
    factor_col, guide_col : str
        Grouping columns.
    ntc_label : str
        Value of ``guide_col`` marking non-targeting control rows.

    Returns
    -------
    pandas.DataFrame
        One row per (factor, guide) with these columns:
        - MedianPct: median percentile within NTC.
        - PctShift: percentage points vs the NTC median.
        - The CIs of both, plus the NTC CIs.
        - ``pval`` and the Benjamini-Hochberg ``qval``.

        NTC rows get NaN ``pval``/``qval``: NTC vs itself is not a hypothesis test, and
        including it would inflate the multiple-testing family.

    Raises
    ------
    ValueError
        If a factor has no NTC rows (no reference distribution).
    """
    out = []
    # observed=True skips unused categorical levels (no empty groups).
    for f, sub_f in df_long.groupby(factor_col, observed=True):
        ntc = sub_f.loc[sub_f[guide_col] == ntc_label, "Score"].to_numpy()
        if ntc.size == 0:
            raise ValueError(f"No {ntc_label!r} rows for {factor_col}={f!r}; cannot compute NTC percentiles.")

        # Vectorised percentileofscore(ntc, x, kind="mean") / 100:
        # ((#ntc < x) + (#ntc <= x)) / (2 * len(ntc)), in O(n log m) instead of O(n * m).
        ntc_sorted = np.sort(ntc)
        scores = sub_f["Score"].to_numpy()
        n_lt = np.searchsorted(ntc_sorted, scores, side="left")
        n_le = np.searchsorted(ntc_sorted, scores, side="right")
        sub_f = sub_f.assign(P=(n_lt + n_le) / (2.0 * ntc.size))

        ntc_med, ntc_ci_low, ntc_ci_high = bootstrapped_ci(sub_f.loc[sub_f[guide_col] == ntc_label, "P"])
        for g, sg in sub_f.groupby(guide_col, observed=True):
            p_med, p_ci_low, p_ci_high = bootstrapped_ci(sg["P"])

            if g == ntc_label:
                pval = np.nan  # NTC vs itself: not tested, excluded from FDR below
            else:
                pval = mannwhitneyu(sg["Score"], ntc, alternative="two-sided", method="auto").pvalue

            out.append({
                "Factor": f,
                guide_col: g,
                "MedianPct": p_med,
                "PctShift": (p_med-ntc_med)*100,
                "pval": pval,
                "MedianPct_CI_low": p_ci_low,
                "MedianPct_CI_high": p_ci_high,
                "PctShift_CI_low": (p_ci_low - ntc_med) * 100,
                "PctShift_CI_high": (p_ci_high - ntc_med) * 100,
                "NTC_MedianPct_CI_low": ntc_ci_low,
                "NTC_MedianPct_CI_high": ntc_ci_high,
                "NTC_PctShift_CI_low": (ntc_ci_low - ntc_med) * 100,
                "NTC_PctShift_CI_high": (ntc_ci_high - ntc_med) * 100,
            })
    eff = pd.DataFrame(out)
    if not eff.empty:
        # BH across all tested (factor, guide) pairs; untested rows keep NaN.
        eff["qval"] = np.nan
        tested = eff["pval"].notna()
        if tested.any():
            eff.loc[tested, "qval"] = multipletests(eff.loc[tested, "pval"], method="fdr_bh")[1]
    return eff

In [None]:
df_long_6tf = factor_labels.build_df_long(adata_6tf_masked, descriptive_cols= signature_cols_ordered, by_guide=False,diff_type_name="Signature")
df_long_imgl = factor_labels.build_df_long(adata_imgl_masked, descriptive_cols= signature_cols_ordered, by_guide=False,diff_type_name="Signature")


In [None]:
percentile_df_6tf = ntc_percentile_effect_v2(df_long_6tf, factor_col="Signature", guide_col = "perturbed_guide")
percentile_df_imgl = ntc_percentile_effect_v2(df_long_imgl, factor_col="Signature", guide_col = "perturbed_guide")


In [None]:
percentile_df_dnmt1 = percentile_df_6tf[percentile_df_6tf['perturbed_guide'].isin(['DNMT1_g1', 'DNMT1_g2', 'NTC'])]
percentile_df_stat2 = percentile_df_6tf[percentile_df_6tf['perturbed_guide'].isin(['STAT2_g1', 'STAT2_g2', 'NTC'])]
percentile_df_prdm1 = percentile_df_imgl[percentile_df_imgl['perturbed_guide'].isin(['PRDM1_g1', 'PRDM1_g2', 'NTC'])]

In [None]:
percentile_df_dnmt1['perturbed_gene'] = np.where(percentile_df_dnmt1['perturbed_guide'] == "NTC", "NTC", "DNMT1")
percentile_df_stat2['perturbed_gene'] =  np.where(percentile_df_stat2['perturbed_guide'] == "NTC", "NTC", "STAT2")
percentile_df_prdm1['perturbed_gene'] =  np.where(percentile_df_prdm1['perturbed_guide'] == "NTC", "NTC", "PRDM1")
percentile_df_prdm1

In [None]:
def _p_to_stars(p):
    if not np.isfinite(p):
        return ""
    if p < 0.001:
        return "***"
    elif p < 0.01:
        return "**"
    elif p < 0.05:
        return "*"
    return ""

In [None]:
# Significance labels from the BH q-values. .assign() returns new frames instead of setting a
# column on a frame that may be a slice of percentile_df_6tf / percentile_df_imgl.
percentile_df_dnmt1 = percentile_df_dnmt1.assign(sig=percentile_df_dnmt1['qval'].map(_p_to_stars))
percentile_df_stat2 = percentile_df_stat2.assign(sig=percentile_df_stat2['qval'].map(_p_to_stars))
percentile_df_prdm1 = percentile_df_prdm1.assign(sig=percentile_df_prdm1['qval'].map(_p_to_stars))
# Only a cell's last expression is rendered, so display the other two explicitly.
display(percentile_df_dnmt1, percentile_df_stat2)
percentile_df_prdm1

In [None]:
def _draw_ci_errorbar(ax, rows, x_val, y_col, y_order, low_col, high_col, color):
    """Draw horizontal CI bars for ``rows`` at the categorical y positions given by ``y_order``."""
    if rows.empty:
        return
    x = rows[x_val].to_numpy()
    err_low = np.abs(x - rows[low_col].to_numpy())
    err_high = np.abs(rows[high_col].to_numpy() - x)
    # The categorical point plot draws category k at y == k, so map labels to their index.
    y_pos = [y_order.index(y) for y in rows[y_col]]
    ax.errorbar(
        x=x,
        y=y_pos,
        xerr=[err_low, err_high],
        fmt='none',
        ecolor=color,
        elinewidth=2.5,
    )


def draw_pointplot(
    df,
    guides_to_use,
    figsize_num,
    by_guide=True,
    diff_type_name="Signature",
    difference=False,
    nrows=3,
    ncols=2,
    palette = None,
    orderlist = None,
    x_val = "PctShift",
    filename = "temp.svg"
):
    """One point-plot panel per target: each guide's statistic vs NTC, with bootstrap CI bars.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows from ntc_percentile_effect_v2, optionally with a 'perturbed_gene' column.
        Needs these columns:
        - ``x_val``, ``{x_val}_CI_low``, ``{x_val}_CI_high``;
        - ``NTC_{x_val}_CI_low``, ``NTC_{x_val}_CI_high``;
        - ``diff_type_name`` and 'perturbed_guide'.
    guides_to_use : list of str
        Panel targets, matched against the guide/gene column. Their '<target>_g1' and
        '<target>_g2' guides get coloured CI bars.
    figsize_num : tuple
        Figure size.
    by_guide : bool
        Match targets on 'perturbed_guide' (True) or 'perturbed_gene' (False).
    diff_type_name : str
        Column plotted on the categorical y axis.
    difference : bool
        Unused; kept for backward compatibility.
    nrows, ncols : int
        Panel grid; panels beyond ``len(guides_to_use)`` are removed.
    palette : dict, optional
        Guide -> colour. Guides missing from it get black CI bars.
    orderlist : list, optional
        Y-axis category order; defaults to the sorted unique values of ``diff_type_name``.
        The same order is used for both the points and the CI bars.
    x_val : str
        Statistic on the x axis ('PctShift' or 'MedianPct').
    filename : str
        Path the figure is saved to.

    Raises
    ------
    ValueError
        If there are more targets than panels.
    """
    guide_col = "perturbed_guide" if by_guide else "perturbed_gene"
    colors = palette if palette is not None else {}
    x_labels = {
        "PctShift": "Median Score Shift vs NTC (pp)",
        "MedianPct": "Median Percentile within NTC",
    }

    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=figsize_num, sharex=True, sharey=True, dpi=300,
        squeeze=False,  # always a 2-D array, even for a 1x1 grid
    )
    axes = axes.flatten()
    if len(guides_to_use) > len(axes):
        raise ValueError(f"{len(guides_to_use)} targets but only {len(axes)} panels (nrows * ncols).")

    for ax, guide in zip(axes, guides_to_use):
        subset_guide = df[
            (df[guide_col] == guide) | (df[guide_col] == "NTC")
        ]
        # One explicit order for points and CI bars. Previously, with orderlist=None, the points
        # used data order while the bars used sorted order, so they could land on different rows.
        y_order = list(orderlist) if orderlist is not None else sorted(subset_guide[diff_type_name].unique())

        sns.pointplot(
            data=subset_guide,
            x=x_val,
            y=diff_type_name,
            hue="perturbed_guide",
            dodge=False,
            join=False,
            palette=palette,
            scale=1.5,
            ax=ax,
            errorbar=None,
            order=y_order,
        )
        ax.grid(True, which='both', axis='both', linestyle='--', linewidth=0.5, alpha=0.7)
        ax.set_title(f"{guide}")
        ax.set_xlabel(x_labels.get(x_val, x_val))
        ax.set_ylabel(f"{diff_type_name}")

        # NTC CI (grey) first, then each knockdown guide's CI in its palette colour.
        _draw_ci_errorbar(
            ax, subset_guide[subset_guide[guide_col] == "NTC"], x_val, diff_type_name, y_order,
            f"NTC_{x_val}_CI_low", f"NTC_{x_val}_CI_high", '#CCCCCC',
        )
        for suffix in ("_g1", "_g2"):
            guide_name = guide + suffix
            _draw_ci_errorbar(
                ax, subset_guide[subset_guide["perturbed_guide"] == guide_name], x_val, diff_type_name,
                y_order, f"{x_val}_CI_low", f"{x_val}_CI_high", colors.get(guide_name, "black"),
            )

        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    for ax in axes[len(guides_to_use):]:
        fig.delaxes(ax)

    fig.tight_layout()
    fig.savefig(filename)
    plt.show()

In [None]:
my_palette_dnmt1 = { "DNMT1_g1": "#A685C4","DNMT1_g2": "#5C466F", "NTC":"#CCCCCC"}
my_palette_stat2 =  { "STAT2_g1": "#6495E0","STAT2_g2": "#3B5989", "NTC":"#CCCCCC"}
my_palette_prdm1 = { "PRDM1_g1": "#E36769","PRDM1_g2": "#9D2937", "NTC":"#CCCCCC"}

In [None]:
ordered_list_dnmt1 = list(percentile_df_dnmt1[percentile_df_dnmt1['perturbed_guide'] != "NTC"].sort_values("MedianPct").Factor.unique())
ordered_list_stat2 = list(percentile_df_stat2[percentile_df_stat2['perturbed_guide'] != "NTC"].sort_values("MedianPct").Factor.unique())
ordered_list_prdm1 = list(percentile_df_prdm1[percentile_df_prdm1['perturbed_guide'] != "NTC"].sort_values("MedianPct").Factor.unique())


In [None]:
sns.set_theme("notebook", "white")

draw_pointplot(percentile_df_dnmt1, ['DNMT1'], (12,3), by_guide=False, difference=True,
                             diff_type_name="Factor", nrows=1, ncols=2, palette = my_palette_dnmt1,
                             orderlist = ordered_list_dnmt1)

draw_pointplot(percentile_df_stat2, ['STAT2'], (12,3), by_guide=False, difference=True,
                             diff_type_name="Factor", nrows=1, ncols=2, palette = my_palette_stat2,
                             orderlist = ordered_list_stat2)

draw_pointplot(percentile_df_prdm1, ['PRDM1'], (12,3), by_guide=False, difference=True,
                             diff_type_name="Factor", nrows=1, ncols=2, palette = my_palette_prdm1,
                             orderlist = ordered_list_prdm1)



In [None]:
sns.set_theme("notebook", "white")
draw_pointplot(percentile_df_dnmt1, ['DNMT1'], (12,3), by_guide=False, difference=True,
                             diff_type_name="Factor", nrows=1, ncols=2, palette = my_palette_dnmt1, x_val = "MedianPct",
                             orderlist = ordered_list_dnmt1)

draw_pointplot(percentile_df_stat2, ['STAT2'], (12,3), by_guide=False, difference=True,
                             diff_type_name="Factor", nrows=1, ncols=2, palette = my_palette_stat2, x_val = "MedianPct",
                             orderlist = ordered_list_stat2)

draw_pointplot(percentile_df_prdm1, ['PRDM1'], (12,3), by_guide=False, difference=True,
                             diff_type_name="Factor", nrows=1, ncols=2, palette = my_palette_prdm1, x_val = "MedianPct",
                             orderlist = ordered_list_prdm1)


# KDE plots of each signature score: knockdown guides vs. NTC

In [None]:
# Per-signature KDEs of knockdown guides vs NTC on size-matched cell sets.
# A seeded generator makes the random downsampling reproducible under Restart & Run All.
# (The old unseeded np.random.choice depended on whatever global RNG state was left over.)
KDE_SUBSAMPLE_SEED = 0
kde_rng = np.random.default_rng(KDE_SUBSAMPLE_SEED)


def _balanced_gene_vs_ntc(adata, gene, rng):
    """Return `gene` + NTC cells, downsampling (without replacement) the larger group to the smaller one's size."""
    sub = adata[adata.obs['perturbed_gene'].isin(["NTC", gene])]
    gene_cells = sub[sub.obs['perturbed_gene'] == gene]
    ntc_cells = sub[sub.obs['perturbed_gene'] == "NTC"]
    n_keep = min(gene_cells.n_obs, ntc_cells.n_obs)
    gene_cells = gene_cells[np.sort(rng.choice(gene_cells.n_obs, n_keep, replace=False))]
    ntc_cells = ntc_cells[np.sort(rng.choice(ntc_cells.n_obs, n_keep, replace=False))]
    # sc.concat (anndata.concat) replaces the deprecated AnnData.concatenate.
    return sc.concat([gene_cells, ntc_cells])


# (gene, source AnnData, palette). PRDM1 stays last so `sub_df` ends as the PRDM1 subset
# that the next cell summarises.
kde_panels = [
    ("STAT2", adata_6tf_masked, my_palette_stat2),
    ("DNMT1", adata_6tf_masked, my_palette_dnmt1),
    ("PRDM1", adata_imgl_masked, my_palette_prdm1),
]

sns.set_theme("poster", "white")
for gene, adata_source, palette in kde_panels:
    sub_df = _balanced_gene_vs_ntc(adata_source, gene, kde_rng)
    for signature in signature_cols_ordered:
        fig, ax = plt.subplots(figsize=(5,5), dpi=300)
        sns.kdeplot(sub_df.obs,
                    x=signature,
                    hue="perturbed_guide",
                    legend=False,
                    palette=palette,
                    fill=False,
                    linewidth = 6,
                    ax=ax)
        ax.set_title(f"{gene} vs NTC")

In [None]:
# Summary statistics of the signature scores in sub_df, the last balanced subset built by the
# KDE cells above (the PRDM1-vs-NTC cells).
sub_df.obs[signature_cols_ordered].describe()

# Enrichment Score Line Plots

In [None]:
# Cell counts after the guide exclusion and Mixscale cutoff.
adata_6tf_masked.shape, adata_imgl_masked.shape

In [None]:
# Fixed colour per UCell signature score.
# NOTE(review): same mapping as color_dict defined further down; consider keeping only one.
signature_palette =  {
    "dam_score_ucell": "#BF0063",
    "homeostatic_score_ucell": "#107B35",
    "interferon_score_ucell": "#5FADAF",
    "chemokine_score_ucell": "#A335C2",
    "lipid_dam_score_ucell": "#E76333",
    "antigen_presenting_score_ucell": "#FDAC10"
    }

# UMAP

In [None]:
sc.pl.umap(adata_6tf_masked, color=["AIF1"], cmap="viridis", size=30)

In [None]:
sns.set_theme('poster', style='white', palette='viridis')
fig = sc.pl.umap(
    adata_dict['merged_6tf'],
    color=["AIF1", "CSF1R"],
    ncols=1,
    cmap='viridis',
    size=30,
    vmin=0,
    vmax=4,
    frameon=False,
    title=[''] * len(signature_cols_ordered),
    show=False,
    return_fig=True
)
fig.show()

In [None]:
sns.set_theme('poster', style='white', palette='viridis')
fig = sc.pl.umap(
    adata_dict['cite_imgl'],
    color=["AIF1", "CSF1R"],
    ncols=1,
    cmap='viridis',
    size=30,
    vmin=0,
    vmax=4,
    frameon=False,
    title=[''] * len(signature_cols_ordered),
    show=False,
    return_fig=True
)
fig.show()

# Protein heatmaps

In [None]:
mdata_6tf_masked.mod['prot'].var[mdata_6tf_masked.mod['prot'].var.index == "None"]
mdata_imgl_masked.mod['prot'].var[mdata_imgl_masked.mod['prot'].var.index == "None"]

In [None]:
mdata_masked_dict = {}
mdata_masked_dict['cite_6tf'] = mdata_6tf_masked
mdata_masked_dict['cite_imgl'] = mdata_imgl_masked

In [None]:
def get_corr_pval(df1, df2):
    """Pearson correlation between every column of ``df1`` and every column of ``df2``.

    Parameters
    ----------
    df1, df2 : pandas.DataFrame
        Row-aligned tables (same observations, same order).

    Returns
    -------
    corr : pandas.DataFrame
        Pearson r, indexed by ``df1.columns`` x ``df2.columns``.
    adj_pval : pandas.DataFrame
        Benjamini-Hochberg adjusted p-values, corrected jointly over all pairs.
        Pairs with a NaN p-value stay NaN and are left out of the correction. ``pearsonr``
        returns NaN for a constant column, and a NaN fed to ``multipletests`` can propagate
        to the other adjusted values.
    """
    corr = pd.DataFrame(index=df1.columns, columns=df2.columns, dtype=float)
    pval = pd.DataFrame(index=df1.columns, columns=df2.columns, dtype=float)

    for c1 in df1.columns:
        for c2 in df2.columns:
            r, p = pearsonr(df1[c1], df2[c2])
            corr.loc[c1, c2] = r
            pval.loc[c1, c2] = p

    pvals_flat = pval.to_numpy(dtype=float).ravel()
    adj_flat = np.full(pvals_flat.shape, np.nan)
    finite = np.isfinite(pvals_flat)
    if finite.any():
        adj_flat[finite] = multipletests(pvals_flat[finite], method='fdr_bh')[1]
    adj_pval = pd.DataFrame(adj_flat.reshape(pval.shape), index=pval.index, columns=pval.columns)
    return corr, adj_pval

def significance_stars(pvals):
    """Element-wise significance labels for an array-like of p-values.

    '***' where p < 0.001, '**' where p < 0.01, '*' where p < 0.05, '' elsewhere
    (NaN included). Returns an object-dtype ndarray shaped like ``pvals``.
    """
    p = np.asarray(pvals, dtype=float)
    labels = np.select([p < 0.001, p < 0.01, p < 0.05], ["***", "**", "*"], default="")
    return labels.astype(object)



In [None]:
def plot_corr_heatmap(corr_df, pval_df, figname, vmin = -1, vmax = 1, cmap = "vlag",
                      alphabetical = True, figsize = (6, 40), zoom = False, flip=False,
                      row_cluster=True, col_cluster=True, name=None):
    """Protein-vs-RNA-signature correlation heatmap annotated with significance stars.

    Parameters
    ----------
    corr_df : pandas.DataFrame
        Correlations: proteins as rows, RNA signature scores as columns.
    pval_df : pandas.DataFrame or ndarray
        P-values of the same shape, converted to '*'/'**'/'***' via significance_stars.
    figname : str
        Output path (saved at dpi=300).
    vmin, vmax, cmap
        Colour scale. ``zoom=True`` overrides the range to [-0.2, 0.2].
    alphabetical : bool
        True draws a plain sns.heatmap in the given order; False draws a sns.clustermap.
    figsize : tuple
        (width, height). The clustermap branch adds 20 to the height.
    flip : bool
        Transpose (proteins on x). Figure size and axis labels are swapped to match.
    row_cluster, col_cluster : bool
        Forwarded to sns.clustermap (clustermap branch only).
    name : str, optional
        Dataset name appended to the title. The title used to read a global ``name``
        variable; pass it explicitly instead.
    """
    if zoom:
        vmin, vmax = -0.2, 0.2

    annot = significance_stars(pval_df)
    xlabel, ylabel = "RNA Program Scores", "Proteins"
    if flip:
        corr_df = corr_df.T
        annot = annot.T
        figsize = (figsize[1], figsize[0])
        xlabel, ylabel = ylabel, xlabel

    title = "Protein vs RNA Signature Score Correlation"
    if name is not None:
        title = f"{title}: {name}"

    heatmap_kws = dict(cmap=cmap, center=0, annot=annot, fmt="",
                       cbar_kws={"label": "Correlation"}, vmin=vmin, vmax=vmax)

    if alphabetical:
        fig, ax = plt.subplots(figsize=figsize, dpi=300)
        sns.heatmap(corr_df, ax=ax, **heatmap_kws)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        fig.tight_layout()
        fig.savefig(figname, dpi=300)
    else:
        g = sns.clustermap(corr_df, figsize=(figsize[0], figsize[1] + 20),
                           row_cluster=row_cluster, col_cluster=col_cluster, **heatmap_kws)
        # Label the heatmap axes explicitly. After clustermap the pyplot "current" axes is
        # typically not the heatmap, so plt.title/xlabel would land on the wrong axes.
        # No tight_layout here: it can disturb the layout clustermap sets up (e.g. its
        # manually positioned colour bar).
        g.ax_heatmap.set_xlabel(xlabel)
        g.ax_heatmap.set_ylabel(ylabel)
        g.ax_heatmap.figure.suptitle(title)
        g.savefig(figname, dpi=300)

    plt.show()

In [None]:
# Fixed colour per UCell signature score; plot_corr_volanos looks panels up by column name here.
color_dict = dict(
    dam_score_ucell="#BF0063",
    homeostatic_score_ucell="#107B35",
    interferon_score_ucell="#5FADAF",
    chemokine_score_ucell="#A335C2",
    lipid_dam_score_ucell="#E76333",
    antigen_presenting_score_ucell="#FDAC10",
)
  

In [None]:
def plot_corr_volanos(corr_df, pval_df, type, figname, ncols=5, top_n=20):
    """Volcano-style panels of Pearson r vs -log10(BH FDR), one per column of ``corr_df``.

    Parameters
    ----------
    corr_df : pandas.DataFrame
        Correlations. Rows are features (the lowest-FDR ones are labelled); columns are panels.
        Each column name must be a key of the module-level ``color_dict``.
    pval_df : pandas.DataFrame
        RAW (unadjusted) p-values with the same columns. Each column is Benjamini-Hochberg
        adjusted here, so passing already-adjusted values corrects twice. Rows are matched
        to ``corr_df`` by index label when the indexes differ. NaN p-values are skipped
        by the correction and left unlabelled.
    type : str
        Word used in the figure title ("... for Each {type}").
    figname : str
        Output path (saved at dpi=300).
    ncols : int
        Panels per row.
    top_n : int
        Number of lowest-FDR features labelled per panel (default 20, as before).
    """
    y_cap = 500  # display cap on -log10(FDR); FDR == 0 gives +inf
    columns = list(corr_df.columns)
    nrows = int(np.ceil(len(columns) / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*5, nrows*4),
                             constrained_layout=True, squeeze=False)
    flat_axes = axes.ravel()

    for ax, col in zip(flat_axes, columns):
        x = corr_df[col]
        p_col = pval_df[col]
        if not p_col.index.equals(x.index):
            p_col = p_col.reindex(x.index)  # align by label, not position
        pvals = p_col.to_numpy(dtype=float)

        pvals_adj = np.full(pvals.shape, np.nan)
        finite = np.isfinite(pvals)
        if finite.any():
            pvals_adj[finite] = multipletests(pvals[finite], method='fdr_bh')[1]
        with np.errstate(divide="ignore"):
            y = pd.Series(-np.log10(pvals_adj), index=x.index)
        y_clipped = y.clip(upper=y_cap)

        ax.scatter(x, y_clipped, color=color_dict[col])

        # Label the top_n lowest-FDR features (argsort puts NaN last; those are skipped).
        label_positions = [pos for pos in np.argsort(pvals_adj) if np.isfinite(pvals_adj[pos])][:top_n]
        texts = []
        for rank, pos in enumerate(label_positions):
            y_val = y_clipped.iloc[pos]
            if y.iloc[pos] > y_cap:
                # Stagger labels of capped points so they do not all sit at y_cap.
                y_val = y_cap + rank * 2
            texts.append(
                ax.text(
                    x.iloc[pos],
                    y_val,
                    x.index[pos],
                    fontsize=10,
                    bbox=dict(boxstyle='round,pad=0.2', fc='lightblue', alpha=0.5, lw=0.5),
                ))
        if texts:
            adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle="->"))

        ax.set_xlabel('Pearson correlation')
        ax.set_ylabel('-log10(FDR)')
        ax.set_title(f'{col}')
        ax.grid(True, alpha=0.3)

    for ax in flat_axes[len(columns):]:
        fig.delaxes(ax)

    plt.suptitle(f'Correlation vs Significance for Each {type}', fontsize=16, y=1.02)
    plt.savefig(figname, dpi=300)
    plt.show()


In [None]:

corr_df_dict = {}
pval_df_dict = {}
for name, mdata in mdata_masked_dict.items():
    print(name, mdata.shape)

    prot_X = mdata['prot'].X
    # .X may be scipy-sparse or already dense; handle both.
    prot_expr = prot_X.toarray() if hasattr(prot_X, "toarray") else np.asarray(prot_X)
    prot_names = mdata['prot'].var_names
    rna_scores = mdata['rna'].obs.filter(regex='_score_ucell')
    rna_score_names = rna_scores.columns
    rna_mat = rna_scores.values
    # Rows of prot_expr and rna_mat are paired by position. This assumes both modalities hold
    # the same cells in the same order (TODO confirm for these MuData objects).

    corr_matrix = np.zeros((prot_expr.shape[1], rna_mat.shape[1]))
    pval_matrix = np.ones((prot_expr.shape[1], rna_mat.shape[1]))

    for i in range(prot_expr.shape[1]):
        for j in range(rna_mat.shape[1]):
            r, p = pearsonr(prot_expr[:, i], rna_mat[:, j])
            corr_matrix[i, j] = r
            pval_matrix[i, j] = p

    corr_df = pd.DataFrame(corr_matrix, index=prot_names, columns=rna_score_names).sort_index()
    pval_df = pd.DataFrame(pval_matrix, index=prot_names, columns=rna_score_names).sort_index()

    # Benjamini-Hochberg per signature column, over finite p-values only. pearsonr returns NaN
    # for constant proteins, and a NaN passed to multipletests can poison the other values.
    adj_pval_df = pval_df.copy()
    for col in pval_df.columns:
        raw = pval_df[col].to_numpy(dtype=float)
        adjusted = np.full(raw.shape, np.nan)
        finite = np.isfinite(raw)
        if finite.any():
            adjusted[finite] = multipletests(raw[finite], method='fdr_bh')[1]
        adj_pval_df[col] = adjusted

    pval_df_dict[name] = adj_pval_df
    corr_df_dict[name] = corr_df

    mdata.uns['corr_df_signature'] = corr_df
    mdata.uns['pval_df_signature'] = adj_pval_df

    sns.set_theme("poster", "white")

    # plot_corr_volanos BH-adjusts each column itself, so it gets the RAW p-values.
    # Passing adj_pval_df (as before) applied the FDR correction twice.
    plot_corr_volanos(corr_df, pval_df, "Signature", ncols=1, figname = f"{name}_protein_volcano.jpeg")

In [None]:
# Write each dataset's correlation table and BH-adjusted p-value table (stored in .uns by the
# protein/RNA correlation loop) to CSVs in the working directory.
mdata_masked_dict['cite_6tf'].uns['corr_df_signature'].to_csv("cite_6tf_prot_corr_df_signature.csv")
mdata_masked_dict['cite_6tf'].uns['pval_df_signature'].to_csv("cite_6tf_prot_adj_pval_df_signature.csv")

mdata_masked_dict['cite_imgl'].uns['corr_df_signature'].to_csv("cite_imgl_prot_corr_df_signature.csv")
mdata_masked_dict['cite_imgl'].uns['pval_df_signature'].to_csv("cite_imgl_prot_adj_pval_df_signature.csv")

# EM-seq and RNA-seq correlation

In [None]:
# TODO(review): hypo_df, hyper_df and dnmt1_degs_df are never loaded in this notebook, so these
# cells depend on kernel state. Presumably hypo_df / hyper_df are the EM-seq hypo- and
# hyper-methylation tables and dnmt1_degs_df holds the DNMT1 RNA DE results. Load them
# explicitly so Restart & Run All works.
hypo_df.shape, hyper_df.shape

In [None]:
# Significant DNMT1 DEGs: adj_p_weight < 0.05.
dnmt1_degs_sig_df = dnmt1_degs_df[dnmt1_degs_df.adj_p_weight < 0.05]
dnmt1_degs_sig_df

In [None]:
# Split the significant DEGs by log2FC sign: 'pos' (log2FC > 0) and 'neg' (log2FC < 0).
dnmt1_rna_degs_dict = {}
dnmt1_rna_degs_dict['pos'] = dnmt1_degs_sig_df[dnmt1_degs_sig_df['log2FC'] > 0].gene_ID.tolist()
dnmt1_rna_degs_dict['neg'] = dnmt1_degs_sig_df[dnmt1_degs_sig_df['log2FC'] < 0].gene_ID.tolist()

In [None]:
# Rename the gene-name column to gene_ID, the key used to match against the DEG table.
# Re-running is harmless: rename ignores keys that are already gone.
hypo_df = hypo_df.rename(columns={"external_gene_name":"gene_ID"})
hyper_df = hyper_df.rename(columns={"external_gene_name":"gene_ID"})

In [None]:
# Rename min_fdr to adj_p_weight (the DEG table's column name). When the two are merged later,
# pandas suffixes them as adj_p_weight_x (this table) and adj_p_weight_y (DEG table).
hypo_df = hypo_df.rename(columns={"min_fdr":"adj_p_weight"})
hyper_df = hyper_df.rename(columns={"min_fdr":"adj_p_weight"})

In [None]:
overlap_results = defaultdict(dict)

for key, gene_list in dnmt1_rna_degs_dict.items():
    hypo_genes = set(hypo_df['gene_ID'].dropna())
    hyper_genes = set(hyper_df['gene_ID'].dropna())
    ref_genes = set(gene_list)
    overlap_hypo = hypo_genes & ref_genes
    overlap_hyper = hyper_genes & ref_genes
    overlap_results[key]['hypo'] = overlap_hypo
    overlap_results[key]['hyper'] = overlap_hyper
    overlap_results[key]['hypo_count'] = len(overlap_hypo)
    overlap_results[key]['hyper_count'] = len(overlap_hyper)

In [None]:
dnmt1_pos_genes = set(dnmt1_rna_degs_dict['pos'])
dnmt1_neg_genes = set(dnmt1_rna_degs_dict['neg'])

hypo_pos = hypo_df[hypo_df['gene_ID'].isin(dnmt1_pos_genes)]
hyper_pos = hyper_df[hyper_df['gene_ID'].isin(dnmt1_pos_genes)]

hypo_neg = hypo_df[hypo_df['gene_ID'].isin(dnmt1_neg_genes)]
hyper_neg = hyper_df[hyper_df['gene_ID'].isin(dnmt1_neg_genes)]

hypo_pos['source'] = 'hypo'
hyper_pos['source'] = 'hyper'
pos_merged = pd.concat([hypo_pos, hyper_pos], ignore_index=True)

hypo_neg['source'] = 'hypo'
hyper_neg['source'] = 'hyper'
neg_merged = pd.concat([hypo_neg, hyper_neg], ignore_index=True)


In [None]:
# Attach the significant DNMT1 DEG statistics to the matching methylation rows (inner join on
# gene_ID). Both inputs carry adj_p_weight, so pandas suffixes them: adj_p_weight_x is from
# pos_merged / neg_merged (the renamed min_fdr) and adj_p_weight_y is from dnmt1_degs_sig_df.
pos_merged_dnmt1 = pos_merged.merge(dnmt1_degs_sig_df, on="gene_ID")[["gene_ID", 'mean_diff', 'adj_p_weight_x', "source", "log2FC", 'adj_p_weight_y']]

neg_merged_dnmt1 = neg_merged.merge(dnmt1_degs_sig_df, on="gene_ID")[["gene_ID", 'mean_diff', 'adj_p_weight_x', "source", "log2FC", 'adj_p_weight_y']]

In [None]:
# Stack up- and down-regulated results.
# NOTE(review): the index is not reset, so index labels repeat across the two parts.
all_merged = pd.concat([pos_merged_dnmt1, neg_merged_dnmt1])