In [None]:
import dill
import pathlib as pl

import os
from tqdm.notebook import tqdm

import scanpy as sc

import pandas as pd

import palettable

import numpy as np

In [None]:
def pretty_ax(ax, linew: float=1.5):
    """Hide the top/right spines of ``ax`` and thicken the bottom/left ones.

    Parameters
    ----------
    ax
        Object exposing ``spines`` (indexable by side name) and ``tick_params``;
        in this notebook a matplotlib ``Axes``.
    linew
        Line width applied to the bottom and left spines. Defaults to 1.5, the
        value that used to be hard-coded here. The keyword mirrors the later
        redefinition of this helper in the plotting section, so both versions
        accept the same calls.

    Returns
    -------
    None. ``ax`` is modified in place.
    """
    for side in ("right", "top"):
        ax.spines[side].set_visible(False)
    # Tick marks only along the bottom edge; tick labels stay on both axes.
    ax.tick_params(
        axis='both',
        which='both',
        bottom=True,
        top=False,
        left=False,
        labelbottom=True,
        labelleft=True)
    for side in ("bottom", "left"):
        ax.spines[side].set_linewidth(linew)

In [None]:
# Directory that the loading cell below scans; each entry is expected to hold a 'scplus_obj.pkl'.
# The path is a placeholder -- point it at the real results directory before running.
work_dir = pl.Path("/add/path/here/")

# Annotation table (placeholder path). Later cells read its `sample_id` and `refined_annotations`
# columns and rewrite its index to line up with each sample's `metadata_cell` index.
refined_wcancer = pd.read_csv("/add/path/here/refined_annotations_wsampleid.csv",index_col=0)

In [None]:
# Load one pickled SCENIC+ object per sample directory found under `work_dir`,
# keyed by the directory's stem.
all_scplus = {}
for f in tqdm(work_dir.iterdir()):
    # Plain files (logs, OS metadata, ...) cannot contain 'scplus_obj.pkl'; skip them
    # instead of failing when the path below is opened.
    if not f.is_dir():
        continue
    sample_name = f.stem
    # This sample is excluded from the analysis (the reason is not recorded in this notebook).
    if sample_name=="CCG1153_4411":
        continue
    print(sample_name)
    # SECURITY NOTE: dill/pickle can execute arbitrary code while loading -- only load
    # files from a trusted source.
    # The context manager closes the handle even if unpickling raises.
    with open(f / 'scplus_obj.pkl', 'rb') as infile:
        all_scplus[sample_name] = dill.load(infile)

In [None]:
# Format the eGRNs of every sample; the result is stored under the `key_added` name
# ('eRegulon_metadata'), which the next cell reads from `.uns`.
from scenicplus.utils import format_egrns
for sample_name in tqdm(all_scplus):
    format_egrns(all_scplus[sample_name], eregulons_key = 'eRegulons_importance', TF2G_key = 'TF2G_adj', key_added = 'eRegulon_metadata')

In [None]:
# Show the eRegulon metadata rows whose TF is BNC2.
# NOTE: `sample_name` is whatever the previous loop left behind (the last key of `all_scplus`),
# so this cell depends on leftover kernel state.
all_scplus[sample_name].uns['eRegulon_metadata'][all_scplus[sample_name].uns['eRegulon_metadata'].TF=="BNC2"]

In [None]:
# Convert each sample's eRegulon metadata into signatures stored under 'eRegulon_signatures'.
# NOTE(review): the star import hides where names come from; later cells appear to rely on
# names it brings in -- confirm before replacing it with explicit imports.
from scenicplus.eregulon_enrichment import *
for sample_name in tqdm(all_scplus):
    get_eRegulons_as_signatures(all_scplus[sample_name], eRegulon_metadata_key='eRegulon_metadata', key_added='eRegulon_signatures')

In [None]:
# Score the region-based eRegulon signatures of every sample and print the wall-clock
# minutes spent per sample.
# NOTE(review): `make_rankings` / `score_eRegulons` are not imported by name; presumably they
# come from one of the star imports -- confirm.
from scenicplus.cistromes import *
import time

for sample_name in tqdm(all_scplus):
    start_time = time.time()
    region_ranking = make_rankings(all_scplus[sample_name], target='region')
    # Score region regulons
    score_eRegulons(all_scplus[sample_name],
                    ranking = region_ranking,
                    eRegulon_signatures_key = 'eRegulon_signatures',
                    key_added = 'eRegulon_AUC',
                    enrichment_type= 'region',
                    auc_threshold = 0.05,
                    normalize = False,
                    n_cpu = 1)
    # Elapsed time for this sample, printed in minutes.
    tm = time.time()-start_time
    print(sample_name,tm/60)

In [None]:
## Score transcriptome layer
# Gene-based ranking: same procedure as the region-based cell, with target='gene'.
# Both cells pass key_added='eRegulon_AUC'; later cells read both the 'Gene_based' and
# 'Region_based' entries from that key.
from scenicplus.cistromes import *
import time

for sample_name in tqdm(all_scplus):
    start_time = time.time()
    gene_ranking = make_rankings(all_scplus[sample_name], target='gene')
    # Score gene regulons
    score_eRegulons(all_scplus[sample_name],
                    gene_ranking,
                    eRegulon_signatures_key = 'eRegulon_signatures',
                    key_added = 'eRegulon_AUC',
                    enrichment_type = 'gene',
                    auc_threshold = 0.05,
                    normalize= False,
                    n_cpu = 1)
    # Elapsed time for this sample, printed in minutes.
    tm = time.time()-start_time
    print(sample_name,tm/60)

In [None]:
# For every sample: attach the refined annotations to `metadata_cell`, drop the cells labelled
# "Other" in `ACC_highlevel_wcancer`, and subset both AUC tables to the remaining cells.
# NOTE(review): this cell is not idempotent -- each run concatenates another
# `refined_annotations` column onto `metadata_cell`.
# NOTE(review): later cells read `metadata_cell.refined_wcancer`, while the column added here is
# named `refined_annotations`; confirm that `refined_wcancer` already exists in `metadata_cell`.
for sample_name in tqdm(all_scplus):

    patrefined = refined_wcancer[refined_wcancer.sample_id==sample_name].copy()
    # Replace the last two characters of each index entry with "-<sample_name>";
    # presumably this matches the cell naming used in `metadata_cell` -- verify.
    patrefined.index = patrefined.index.str[:-2] + "-" + sample_name

    # `.loc[...]` raises KeyError if any cell of this sample is missing from the annotation table.
    all_scplus[sample_name].metadata_cell = pd.concat([all_scplus[sample_name].metadata_cell,patrefined.refined_annotations.loc[all_scplus[sample_name].metadata_cell.index]],axis=1)

    subset_cells = all_scplus[sample_name].metadata_cell[~all_scplus[sample_name].metadata_cell["ACC_highlevel_wcancer"].isin(["Other"])].index
    
    all_scplus[sample_name].subset(cells=subset_cells)
    
    # The AUC tables are subset explicitly here so they cover the same cells as the object.
    all_scplus[sample_name].uns["eRegulon_AUC"]["Gene_based"] = all_scplus[sample_name].uns["eRegulon_AUC"]["Gene_based"].loc[subset_cells]
    
    all_scplus[sample_name].uns["eRegulon_AUC"]["Region_based"] = all_scplus[sample_name].uns["eRegulon_AUC"]["Region_based"].loc[subset_cells]

In [None]:
# Mapping from refined cell-type labels to coarse lineage labels.
# NOTE(review): "Adipocytes" maps to "Stromal/Muscle" whereas the muscle types map to "Muscle";
# confirm the two distinct labels are intended.
highlevel_refined = {"Hepatocyte": "Epithelial", 
                     "Carcinoma": "Carcinoma", 
                     "Fibroblast": "Fibroblast", 
                     "Quiescent endothelial cells": "Endothelial", 
                     "Smooth muscle": "Muscle", 
                     "Skeletal muscle": "Muscle",
                     "TAM2": "Myeloid", "TAM3": "Myeloid",
                     "TCD4": "Lymphoid", 
                     "Inflammatory CAF": "Fibroblast", 
                     "Adipose CAF": "Fibroblast",
                     "HGF-CAF": "Fibroblast",
                     "TAM1": "Myeloid", 
                     "Myeloid-HighMT": "Unknown/technical", 
                     "Angiogenic EC": "Endothelial", 
                     "Quiescent EC": "Endothelial", 
                     "Venous EC": "Endothelial",
                     "TCD8": "Lymphoid", 
                     "B": "Lymphoid", 
                     "DC": "Myeloid", 
                     "Hepatic EC": "Endothelial", 
                     "Kupffer cells": "Myeloid", 
                     "NK": "Lymphoid", 
                     "Treg": "Lymphoid", 
                     "StrMus-HighMT": "Unknown/technical", 
                     "T-HighMT": "Unknown/technical", 
                     "Mast": "Myeloid", 
                     "Adipocytes": "Stromal/Muscle", 
                     "Endo-HighMT": "Unknown/technical"}

# Add a `highlevel_wcancer` column by replacing the values of `refined_wcancer` through the
# mapping above; `replace` leaves labels that are not keys of the mapping unchanged.
for sample_name in tqdm(all_scplus):
    all_scplus[sample_name].metadata_cell["highlevel_wcancer"] = all_scplus[sample_name].metadata_cell.refined_wcancer.replace(highlevel_refined)

In [None]:
# Generate pseudobulks
# For every sample: pool labels with fewer than 10 cells into "Other", store the result as
# `refined_wcancer_red`, then build pseudobulks for both AUC layers with a fixed seed.
import time

for sample_name in tqdm(all_scplus):
    start_time = time.time()
    df = all_scplus[sample_name].metadata_cell.refined_wcancer.copy()
    # Count the labels once per sample. Calling value_counts() inside the lambda recomputed
    # the counts for every single cell, which is quadratic in the number of cells.
    label_counts = df.value_counts()
    df = df.apply(lambda x: x if label_counts.loc[x]>=10 else "Other")
    all_scplus[sample_name].metadata_cell["refined_wcancer_red"] = df
    # Same parameters for both layers; only the signature key differs.
    for signature_key in ('Gene_based', 'Region_based'):
        generate_pseudobulks(all_scplus[sample_name],
                             variable = "refined_wcancer_red",
                             auc_key = 'eRegulon_AUC',
                             signature_key = signature_key,
                             nr_cells = 5,
                             nr_pseudobulks = 100,
                             seed=555)
    # Elapsed time for this sample, printed in minutes.
    tm = time.time()-start_time
    print(tm/60)

In [None]:
# Correlation between TF and eRegulons
# Runs once per AUC layer on the `refined_wcancer_red` grouping; the results are stored under
# the `out_key` names given below.
import time

for sample_name in tqdm(all_scplus):
    start_time = time.time()
    TF_cistrome_correlation(all_scplus[sample_name],
                            variable = 'refined_wcancer_red',
                            auc_key = 'eRegulon_AUC',
                            signature_key = 'Gene_based',
                            out_key = 'ACC_refined_wcancer_eGRN_gene_based')
    TF_cistrome_correlation(all_scplus[sample_name],
                            variable = 'refined_wcancer_red',
                            auc_key = 'eRegulon_AUC',
                            signature_key = 'Region_based',
                            out_key = 'ACC_refined_wcancer_eGRN_region_based')
    # Elapsed time for this sample, printed in minutes.
    tm = time.time()-start_time
    print(tm/60)

In [None]:
# Correlation between region based regulons and gene based regulons
def select_regulons_filter(scplus_obj, corr_lim= 0.2):
    """Select eRegulons whose gene-based and region-based AUCs agree.

    The selection is written to ``scplus_obj.uns['selected_eRegulons']`` as two lists of
    AUC column names, under the keys ``'Gene_based'`` and ``'Region_based'``.

    Steps:
      1. Pair gene-based and region-based AUC columns on their name up to the ``'_('``
         size suffix, correlate each pair across cells and keep those with
         ``abs(r) > corr_lim``.
      2. Keep only activators, i.e. names containing ``'+_+'`` or ``'+_-'``.
      3. Drop an ``extended`` eRegulon when its non-extended counterpart was kept as well.
      4. Keep only eRegulons whose gene-based column reports more than 10 genes
         (the integer inside the ``'_(<n>g)'`` suffix).

    The region-based list is restricted to the eRegulons that survive step 4.
    Previously ``keep_all`` was computed but never used, so the region-based list
    ignored the gene-count filter and could disagree with the gene-based list.

    Parameters
    ----------
    scplus_obj
        Object with a dict-like ``uns`` holding the DataFrames
        ``uns['eRegulon_AUC']['Gene_based']`` and ``uns['eRegulon_AUC']['Region_based']``.
    corr_lim
        Minimum absolute correlation between the two AUC layers.

    Returns
    -------
    The same ``scplus_obj`` (modified in place).
    """
    gene_auc = scplus_obj.uns['eRegulon_AUC']['Gene_based']
    region_auc = scplus_obj.uns['eRegulon_AUC']['Region_based']

    # Work on renamed copies so the stored AUC tables keep their full column names.
    df1 = gene_auc.copy()
    df2 = region_auc.copy()
    df1.columns = [x.split('_(')[0] for x in df1.columns]
    df2.columns = [x.split('_(')[0] for x in df2.columns]
    correlations = df1.corrwith(df2, axis = 0)
    correlations = correlations[abs(correlations) > corr_lim]
    # Keep only activators
    keep = [x for x in correlations.index if '+_+' in x or '+_-' in x]
    # Keep extended if not direct
    direct = {x for x in keep if 'extended' not in x}
    keep_extended = {x for x in keep
                     if 'extended' in x and x.replace('extended_', '') not in direct}
    keep = direct | keep_extended
    # Keep regulons with more than 10 genes
    keep_gene = [x for x in gene_auc.columns
                 if x.split('_(')[0] in keep
                 and int(x.split('_(')[1].replace('g)', '')) > 10]
    keep_all = {x.split('_(')[0] for x in keep_gene}
    # Use the gene-count-filtered names so both layers describe the same eRegulons.
    keep_region = [x for x in region_auc.columns if x.split('_(')[0] in keep_all]
    scplus_obj.uns['selected_eRegulons'] = {}
    scplus_obj.uns['selected_eRegulons']['Gene_based'] = keep_gene
    scplus_obj.uns['selected_eRegulons']['Region_based'] = keep_region
    return scplus_obj

In [None]:
# Number of unfiltered cistromes in one sample.
# NOTE: `sample_name` is a leftover from an earlier loop (the last key of `all_scplus`).
len(all_scplus[sample_name].uns["Cistromes"]["Unfiltered"].keys())

In [None]:
# Run the eRegulon selection on every sample with the default correlation threshold; the
# function stores its result in `.uns['selected_eRegulons']` and returns the same object.
for sample_name in tqdm(all_scplus):
    all_scplus[sample_name] = select_regulons_filter(all_scplus[sample_name])

# Prioritizing TFs

In [None]:
# TF catalogue (placeholder path); only rows flagged "Yes" in the "Is TF?" column are used.
all_tfs = pd.read_csv("/add/path/here/DatabaseExtract_v_1.01.csv",index_col=0)

# Unique HGNC symbols of those TFs.
list_tfs = all_tfs[all_tfs["Is TF?"]=="Yes"]["HGNC symbol"].unique()

# Per-cell cNMF scores (placeholder path); later cells read the `cNMF_<k>_score`, `sample_id`
# and `highlevel_refined` columns of this table.
scores = pd.read_csv("/add/path/here/adata_cNMF_scores.csv",index_col=0)

## For Carcinoma cells

In [None]:
from scipy.stats import pearsonr
from statsmodels.stats.multitest import multipletests
from adjustText import adjust_text

def get_corr_and_p(all_plot_dfs: pd.DataFrame):
    """Correlate every feature column with each of the five cNMF score columns.

    Parameters
    ----------
    all_plot_dfs
        Table with one row per cell. It must contain the columns ``cNMF_1_score`` to
        ``cNMF_5_score``, ``sample_id`` and ``highlevel_refined``; every remaining column
        is treated as a feature.

    Returns
    -------
    tuple of three DataFrames, each indexed by feature with one column per cNMF score:
        the correlation coefficients (pandas ``corrwith`` with its default method),
        the p-values from ``scipy.stats.pearsonr``, and the corrected p-values returned
        by ``multipletests`` with its default method, applied per score.
    """
    score_cols = ["cNMF_1_score", "cNMF_2_score", "cNMF_3_score",
                  "cNMF_4_score", "cNMF_5_score"]
    # Everything that is neither a score nor cell metadata is a feature.
    full_auc = all_plot_dfs.drop(score_cols + ["sample_id", "highlevel_refined"],axis=1).copy()

    tf_z_corr = []
    tf_z_ps = []
    tf_z_qs = []
    for score in score_cols:
        corrs = full_auc.corrwith(all_plot_dfs[score])
        # A callable `method` lets corrwith return the p-value instead of the coefficient.
        ps = full_auc.corrwith(all_plot_dfs[score], method=lambda x, y: pearsonr(x, y)[1])
        # Series.ravel() is deprecated in recent pandas; to_numpy() gives the same 1-D array.
        qs = pd.DataFrame(multipletests(ps.to_numpy())[1], index=ps.index, columns=[score])
        corrs.name = score
        ps.name = score

        tf_z_corr.append(corrs)
        tf_z_ps.append(ps)
        tf_z_qs.append(qs)

    return pd.concat(tf_z_corr,axis=1), pd.concat(tf_z_ps,axis=1), pd.concat(tf_z_qs,axis=1)

In [None]:
# Build one table of log-normalised TF expression plus cNMF scores for the carcinoma cells of
# all samples. The same per-sample steps are repeated further down when `all_plot_dfs` is built.
trajec_df = []
for sample_name in all_scplus:
    gex_adata = sc.AnnData(all_scplus[sample_name].to_df("EXP").copy())
    gex_adata = gex_adata[:,gex_adata.var_names.intersection(list_tfs)].copy()
    # NOTE(review): normalisation runs after restricting to TF genes, so per-cell totals are
    # computed over TFs only; the TME expression cell normalises before subsetting -- confirm
    # the difference is intended.
    sc.pp.normalize_total(adata=gex_adata, target_sum=10000)
    sc.pp.log1p(gex_adata)

    patscores = scores[(scores.sample_id==sample_name) & (scores.highlevel_refined=="Carcinoma")].copy()
    # Replace the last two characters of each index entry with "-<sample_name>";
    # presumably this matches the cell names of the SCENIC+ object -- verify.
    patscores.index = patscores.index.str[:-2] + "-" + sample_name
    # Keep only cells present in both the expression matrix and the carcinoma score table.
    gex_adata = gex_adata[gex_adata.obs_names.intersection(patscores.index)].copy()
    gex_adata.obs = pd.concat([gex_adata.obs, patscores.loc[gex_adata.obs_names.intersection(patscores.index)]],axis=1)
    
    gex_df = pd.DataFrame(gex_adata.X.copy(),
                          index=gex_adata.obs_names,
                          columns=gex_adata.var_names)
    plot_df = pd.concat([gex_df,
                     gex_adata.obs],axis=1)

    trajec_df.append(plot_df)
# The list is replaced by the concatenated table, so re-running later cells needs this cell first.
trajec_df = pd.concat(trajec_df)

In [None]:
# Order the cells along the cNMF_4-minus-cNMF_5 score difference.
trajec_df["Diff"] = trajec_df["cNMF_4_score"] - trajec_df["cNMF_5_score"]

trajec_df = trajec_df.sort_values("Diff").copy()

# Split the range of `Diff` into 10 equal-width bins labelled 1..10 (1 = lowest difference).
trajec_df["Bin"] = pd.cut(trajec_df["Diff"], bins=10, labels=np.arange(1,11))

In [None]:
# Bar plot of TF expression per `Diff` bin for a hand-picked list of TFs, one SVG per TF.
# Bin 1 holds the lowest cNMF_4 - cNMF_5 differences, hence the "cNMF5 to cNMF4" axis label.
# NOTE(review): `plt` / `sns` are not imported explicitly in the visible cells; presumably they
# arrive through one of the `from scenicplus... import *` cells -- confirm or import them
# explicitly at the top of the notebook.

# savefig does not create missing directories; make sure the output folder exists first.
os.makedirs("figures/malignant", exist_ok=True)

for TF in ["FOXO3", "HNF4A", "KLF3", "ZBTB43", "THRB", "CREB3L1", "MAFK", "PPARD", "LCOR", "FOXO1", "SMAD3", 
           "MECOM", "HMGA2","GRHL2", "ZKSCAN1"]:
    fig, ax = plt.subplots(1,1,figsize=(2,2))
    sns.barplot(data=trajec_df, x="Bin", y=TF, ax=ax)
    ax.set_xlabel("cNMF$_5$ to cNMF$_4$")
    # The bin numbers carry no meaning of their own, so the x ticks are hidden.
    ax.set_xticks([])
    pretty_ax(ax)
    fig.savefig(f"figures/malignant/{TF}_expression_cNMF_4_to_cNMF_5.svg", dpi=200, bbox_inches='tight')

In [None]:
# Layer 1 of 3: TF expression vs. cNMF scores in carcinoma cells.
# The per-sample steps repeat those used to build `trajec_df` above.
# NOTE: `all_plot_dfs` is reused as a scratch name throughout the notebook; each stage overwrites it.
all_plot_dfs = []
for sample_name in all_scplus:
    gex_adata = sc.AnnData(all_scplus[sample_name].to_df("EXP").copy())
    gex_adata = gex_adata[:,gex_adata.var_names.intersection(list_tfs)].copy()
    # NOTE(review): normalisation runs after restricting to TF genes (totals are TF-only);
    # confirm this is intended.
    sc.pp.normalize_total(adata=gex_adata, target_sum=10000)
    sc.pp.log1p(gex_adata)

    patscores = scores[(scores.sample_id==sample_name) & (scores.highlevel_refined=="Carcinoma")].copy()
    # Replace the last two characters of each index entry with "-<sample_name>".
    patscores.index = patscores.index.str[:-2] + "-" + sample_name
    # Keep only cells present in both the expression matrix and the carcinoma score table.
    gex_adata = gex_adata[gex_adata.obs_names.intersection(patscores.index)].copy()
    gex_adata.obs = pd.concat([gex_adata.obs, patscores.loc[gex_adata.obs_names.intersection(patscores.index)]],axis=1)
    
    gex_df = pd.DataFrame(gex_adata.X.copy(),
                          index=gex_adata.obs_names,
                          columns=gex_adata.var_names)
    plot_df = pd.concat([gex_df,
                     gex_adata.obs],axis=1)

    all_plot_dfs.append(plot_df)
all_plot_dfs = pd.concat(all_plot_dfs)

# Correlation, p-value and corrected p-value of every TF's expression against each cNMF score.
tf_z_corr, tf_z_ps, tf_z_qs = get_corr_and_p(all_plot_dfs)

# Layer 2 of 3: gene-based eRegulon AUC vs. cNMF scores in carcinoma cells.
all_plot_dfs = []
for sample_name in tqdm(all_scplus):
    auc_alltfs = []
    patscores = scores[(scores.sample_id==sample_name) & (scores.highlevel_refined=="Carcinoma")].copy()
    # Replace the last two characters of each index entry with "-<sample_name>".
    patscores.index = patscores.index.str[:-2] + "-" + sample_name

    # Look the AUC table up once per sample instead of copying the whole table for every TF.
    sample_auc = all_scplus[sample_name].uns["eRegulon_AUC"]["Gene_based"]
    for tf in list_tfs:
        # eRegulon columns are named "<TF>_..."; match on that prefix. A substring test
        # (`str.contains(tf)`) also picks up regulons of other TFs whose symbol merely
        # contains this one (e.g. "JUN" inside "JUNB") and interprets the symbol as a regex.
        auc_df = sample_auc.loc[:, sample_auc.columns.str.startswith(f"{tf}_")]
        if auc_df.shape[1]==0:
            continue

        # With several regulons for one TF, drop the extended, "+_-" and "-_-" variants.
        if auc_df.shape[1]>1:
            auc_df = auc_df.loc[:,~auc_df.columns.str.contains("extended")]
            auc_df = auc_df.loc[:,~auc_df.columns.str.contains("+_-", regex=False)]
            auc_df = auc_df.loc[:,~auc_df.columns.str.contains("-_-", regex=False)]

        auc_alltfs.append(auc_df)
    auc_alltfs = pd.concat(auc_alltfs,axis=1)
    # Average identically named columns. DataFrame.groupby(axis=1) is deprecated, so the
    # grouping is done on the transpose.
    auc_alltfs = auc_alltfs.T.groupby(level=0).mean().T
    # Label each column by its TF symbol (the text before the first underscore).
    auc_alltfs.columns = auc_alltfs.columns.str.split("_").str[0]
    # Joining with the carcinoma scores and dropping NaN rows keeps only scored carcinoma cells.
    auc_alltfs = pd.concat([auc_alltfs, patscores],
                           axis=1).dropna()
    all_plot_dfs.append(auc_alltfs)

# TFs without a regulon in a given sample get AUC 0 for that sample's cells.
all_plot_dfs = pd.concat(all_plot_dfs).fillna(0)

gene_based_z_corr, gene_based_z_ps, gene_based_z_qs = get_corr_and_p(all_plot_dfs)

# Layer 3 of 3: region-based eRegulon AUC vs. cNMF scores in carcinoma cells.
all_plot_dfs = []
for sample_name in tqdm(all_scplus):
    auc_alltfs = []
    patscores = scores[(scores.sample_id==sample_name) & (scores.highlevel_refined=="Carcinoma")].copy()
    # Replace the last two characters of each index entry with "-<sample_name>".
    patscores.index = patscores.index.str[:-2] + "-" + sample_name

    # Look the AUC table up once per sample instead of copying the whole table for every TF.
    sample_auc = all_scplus[sample_name].uns["eRegulon_AUC"]["Region_based"]
    for tf in list_tfs:
        # eRegulon columns are named "<TF>_..."; match on that prefix. A substring test
        # (`str.contains(tf)`) also picks up regulons of other TFs whose symbol merely
        # contains this one (e.g. "JUN" inside "JUNB") and interprets the symbol as a regex.
        auc_df = sample_auc.loc[:, sample_auc.columns.str.startswith(f"{tf}_")]
        if auc_df.shape[1]==0:
            continue

        # With several regulons for one TF, drop the extended, "+_-" and "-_-" variants.
        if auc_df.shape[1]>1:
            auc_df = auc_df.loc[:,~auc_df.columns.str.contains("extended")]
            auc_df = auc_df.loc[:,~auc_df.columns.str.contains("+_-", regex=False)]
            auc_df = auc_df.loc[:,~auc_df.columns.str.contains("-_-", regex=False)]

        auc_alltfs.append(auc_df)
    auc_alltfs = pd.concat(auc_alltfs,axis=1)
    # Average identically named columns. DataFrame.groupby(axis=1) is deprecated, so the
    # grouping is done on the transpose.
    auc_alltfs = auc_alltfs.T.groupby(level=0).mean().T
    # Label each column by its TF symbol (the text before the first underscore).
    auc_alltfs.columns = auc_alltfs.columns.str.split("_").str[0]
    # Joining with the carcinoma scores and dropping NaN rows keeps only scored carcinoma cells.
    auc_alltfs = pd.concat([auc_alltfs, patscores],
                           axis=1).dropna()
    all_plot_dfs.append(auc_alltfs)

# TFs without a regulon in a given sample get AUC 0 for that sample's cells.
all_plot_dfs = pd.concat(all_plot_dfs).fillna(0)

region_based_z_corr, region_based_z_ps, region_based_z_qs = get_corr_and_p(all_plot_dfs)

# Combine the three correlation layers per cNMF score and pick the top TFs.
# triad_corrs[score]: one row per TF, columns = TF expression r, gene-based r, region-based r.
# toptfs[score]: up to 20 TFs with r > 0.1 and corrected p < 0.1 in all three layers,
# ranked by their mean correlation across the layers.
triad_corrs = {}
toptfs = {}
for score in ["cNMF_1_score", "cNMF_2_score", "cNMF_3_score",
                              "cNMF_4_score", "cNMF_5_score", ]:
    triad_corrs[score] = pd.concat([tf_z_corr[score],
           gene_based_z_corr[score],
           region_based_z_corr[score]],axis=1)
    seltf = tf_z_corr[(tf_z_corr[score]>0.1) & (tf_z_qs[score]<0.1)].index
    selgene = gene_based_z_corr[(gene_based_z_corr[score]>0.1) & (gene_based_z_qs[score]<0.1)].index
    selreg = region_based_z_corr[(region_based_z_corr[score]>0.1) & (region_based_z_qs[score]<0.1)].index
    
    # TFs passing the thresholds in every layer.
    common = seltf.intersection(selgene)
    common = common.intersection(selreg)
    toptfs[score] = triad_corrs[score].dropna().mean(axis=1).loc[common].sort_values(ascending=False).head(20).index

In [None]:
# One column per cNMF score listing its top TFs. Wrapping each entry in a Series lets
# pandas pad the shorter columns with NaN when the lists differ in length.
df_toptfs = pd.DataFrame(
    {score_key: pd.Series(tf_index) for score_key, tf_index in toptfs.items()}
)

In [None]:
# Write one CSV per cNMF score with the three correlation columns renamed for readability
# (placeholder output path).
for state in triad_corrs:
    df = triad_corrs[state].copy()
    df.columns = ["TF GEX r", "eReg. Gene r", "eReg. Reg. r"]
    df.to_csv(f"/add/path/here/{state}_triad_corr.csv")

In [None]:
# Save the top-TF table (placeholder output path).
pd.DataFrame(df_toptfs).to_csv("/add/path/here/toptfs_top20.csv")

## For TME cells

In [None]:
# Compute regulon specificity scores per `refined_wcancer_red` label on the gene-based layer,
# restricted to the eRegulons chosen by `select_regulons_filter`. The next cell reads the
# result from `.uns["RSS"]['refined_wcancer_red_gene_based']`.
from scenicplus.RSS import *
for sample_name in tqdm(all_scplus):
    
    regulon_specificity_scores(all_scplus[sample_name],
                         'refined_wcancer_red',
                         signature_keys=['Gene_based'],
                         selected_regulons=all_scplus[sample_name].uns['selected_eRegulons']['Gene_based'],
                         out_key_suffix='_gene_based',
                         scale=False)

In [None]:
# For every TME cell type, collect each sample's 10 highest-RSS eRegulons.
# rank_df_TME[ct] ends up as a DataFrame (rows = TF symbol, columns = sample) or stays an
# empty list when no sample contains that cell type.
rank_df_TME = {}
for ct in ["Hepatic EC", "Kupffer cells", "TAM2", "TAM1", "Quiescent EC", "Inflammatory CAF", "HGF-CAF", "Fibroblast", 
           "DC", "B", "TCD4", "TCD8", "Angiogenic EC","NK","Treg","Skeletal muscle", "Smooth muscle"]:
    print(ct)
    rank_df_TME[ct] = []
    for sample_name in tqdm(all_scplus):
        df = all_scplus[sample_name].uns["RSS"]['refined_wcancer_red_gene_based']
        if ct in df.index:
            print(sample_name)
            seldf = df.loc[ct,:].sort_values(ascending=False).head(10)
            #seldf = seldf[seldf>0.15]
            seldf.name = sample_name
            # Reduce regulon names to the TF symbol (text before the first underscore) and
            # take the median when one TF contributes several regulons.
            seldf.index = seldf.index.str.split("_").str[0]
            rank_df_TME[ct].append(seldf.groupby(level=0).median())
    if len(rank_df_TME[ct])==0:
        continue
    else:
        rank_df_TME[ct] = pd.concat(rank_df_TME[ct],axis=1)

In [None]:
# Keep, per cell type, the TFs that appear in the top-10 list of at least two samples
# (i.e. have a non-NaN value in two or more sample columns).
seltfs_TME = {}
for state in rank_df_TME.keys():
    # Cell types never seen are still empty lists and are skipped here.
    if len(rank_df_TME[state])>0:
        seltfs_TME[state] = rank_df_TME[state][(~rank_df_TME[state].isna()).sum(axis=1)>=2]

In [None]:
# Sorted, de-duplicated union of the selected TF symbols across all cell types.
all_seltfs_TME = []
for state in seltfs_TME:
    all_seltfs_TME.append(seltfs_TME[state].index)
all_seltfs_TME = np.unique(np.hstack(all_seltfs_TME))

# Plotting SCENIC+ results

## cNMF results

In [None]:
def pretty_ax(ax, linew: float=1.5):
    """Apply the notebook's minimal axis style to ``ax`` in place.

    The right and top spines are hidden, tick marks are kept on the bottom edge only
    (tick labels stay on both axes), and the bottom and left spines are drawn with
    line width ``linew``.
    """
    spines = ax.spines
    spines['right'].set_visible(False)
    spines['top'].set_visible(False)

    tick_options = dict(
        axis='both',
        which='both',
        bottom=True,
        top=False,
        left=False,
        labelbottom=True,
        labelleft=True,
    )
    ax.tick_params(**tick_options)

    for edge in ("bottom", "left"):
        spines[edge].set_linewidth(linew)

In [None]:
# One scatter plot per cNMF program: gene-based vs. region-based eRegulon correlation,
# coloured by the TF-expression correlation, with the selected top TFs labelled in red.

# savefig does not create missing directories; make sure the output folder exists first.
os.makedirs("figures", exist_ok=True)

for i,state in enumerate([f"cNMF_{i}" for i in range(1,6)]):

    df = triad_corrs[f"{state}_score"].copy().dropna()
    df.columns = ["TF GEX r", "eReg. Gene r", "eReg. Region r"]

    fig, ax = plt.subplots(1,1,figsize=(3,3))

    sns.scatterplot(data=df, x="eReg. Gene r", y="eReg. Region r", hue="TF GEX r",
                    palette="vlag", hue_norm=(-0.4,0.4), ax=ax)
    # Legend on the axes just created (the explicit form of the previous plt.legend call).
    ax.legend(frameon=False, bbox_to_anchor=(1,1,0,0),
              title="TF GEX r", fontsize=13, prop={"size": 13})
    # `i` is zero-based, the program numbering is one-based.
    ax.set_title(f"cNMF$_{i+1}$ top TFs", fontsize=13)

    ax.xaxis.set_tick_params(labelsize=13)
    ax.yaxis.set_tick_params(labelsize=13)
    ax.set_xlabel("Gene-based eReg. r", fontsize=13)
    ax.set_ylabel("Region-based eReg. r", fontsize=13)
    # A single styling call is enough; the earlier default-width call was overridden by this one.
    pretty_ax(ax, linew=3)

    TFs_candidate = toptfs[f"{state}_score"]
    if len(TFs_candidate)>0:
        texts = []
        for g in TFs_candidate:
            x = df.loc[g,"eReg. Gene r"]
            y = df.loc[g,"eReg. Region r"]
            texts.append(ax.text(x,y,g,fontsize=13,c="red"))
        # Let adjustText move the labels vertically only and draw a connector to each point.
        adjust_text(texts, only_move={'points':'y', 'texts':'y'}, arrowprops=dict(arrowstyle="-", color='r', lw=1.5))

    fig.savefig(f"figures/{state}_scenicplus_discovery.png",
            dpi=200, bbox_inches="tight")
    fig.savefig(f"figures/{state}_scenicplus_discovery.svg",
            dpi=200, bbox_inches="tight")

## TME results

In [None]:
# Output folder for the TME heatmaps saved below (also creates the parent "figures" folder).
os.makedirs("figures/TME", exist_ok=True)

In [None]:
# Hand-picked column order (TF symbols) for the TME heatmaps below.
tf_order = ["ERG", "ELK3", "GATA4",'SMAD9',"HLX", "ETS1", "FLI1", 'KLF2', "RUNX3",'SOX17',
            "IKZF1","IKZF3","NFATC3",
            "IRF8","SPI1","IRF5","RBPJ","ETV5", "MAFB","MITF","MEF2C",
            "BNC2","RUNX1","RUNX2","PRRX1","NFATC4",'MEIS1',
            "RARB","SOX5","TCF4", ]

In [None]:
# Row order (cell-type labels) for the TME heatmaps below.
# NOTE(review): the list includes "cNMF_2", "cNMF_3", "cNMF_4" and "Carcinoma"; these must
# exist as `refined_wcancer_red` labels for the `.loc` lookups in the heatmap cells -- confirm.
celltype_order = ["Angiogenic EC", "Hepatic EC", "Quiescent EC", "B", "TCD4", "TCD8", "Treg", "NK",
                  "TAM1","TAM2", "DC","Kupffer cells",
                   "Inflammatory CAF", "HGF-CAF", "Fibroblast", "Smooth muscle", "Skeletal muscle",
                  "cNMF_2","cNMF_3","cNMF_4","Carcinoma"]

In [None]:
# Mean Z-scored expression of the selected TME TFs per cell-type label, pooled over samples.
# NOTE: `all_plot_dfs` is a scratch name that the following cells overwrite again.
all_plot_dfs = []
for sample_name in all_scplus:
    gex_adata = sc.AnnData(all_scplus[sample_name].to_df("EXP").copy())
    # Here normalisation uses all genes; the TF columns are selected afterwards.
    sc.pp.normalize_total(adata=gex_adata, target_sum=10000)
    sc.pp.log1p(gex_adata)
    # NOTE(review): assumes every selected TF is a column of this sample's matrix and that the
    # rows of to_df("EXP") follow `cell_names` -- confirm.
    gex_df = pd.DataFrame(gex_adata[:,all_seltfs_TME].X.copy(),
                          index=all_scplus[sample_name].cell_names,
                          columns=all_seltfs_TME)
    plot_df = pd.concat([gex_df,
                     all_scplus[sample_name].metadata_cell["refined_wcancer_red"]],axis=1)

    all_plot_dfs.append(plot_df)
all_plot_dfs = pd.concat(all_plot_dfs)

# Z-score each TF across all pooled cells; the last column is the cell-type label.
full_gex = all_plot_dfs.iloc[:,:-1].copy()
full_gex = (full_gex - full_gex.mean())/full_gex.std()
all_plot_dfs = pd.concat([full_gex,all_plot_dfs.iloc[:,-1]],axis=1)

# Average the Z-scores per cell-type label.
all_plot_dfs = all_plot_dfs.groupby(by="refined_wcancer_red").mean()

In [None]:
# Heatmap of the mean TF-expression Z-scores; rows and columns follow `celltype_order` / `tf_order`.
# Uses the `all_plot_dfs` produced by the cell directly above.
# NOTE(review): `.loc` with label lists raises KeyError when a label is missing -- confirm that
# all listed cell types and TFs are present.
fig, ax = plt.subplots(1,1,figsize=(10,8))
sns.heatmap(data=all_plot_dfs.loc[celltype_order,tf_order], cmap="vlag", linewidths=1, linecolor="grey",
            center=0, vmin=-1, vmax=2, ax=ax, cbar_kws={"label": "TF Z-score"})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_ylabel("")
fig.savefig("figures/TME/TME_scenic_TF_zscore.svg", dpi=200, bbox_inches="tight")

In [None]:
# Same TF-expression Z-scores, restricted to the three fibroblast populations and six TFs.
fig, ax = plt.subplots(1,1,figsize=(3,2))
sns.heatmap(data=all_plot_dfs.loc[["Inflammatory CAF","HGF-CAF","Fibroblast"],
            ["BNC2","RUNX1","RUNX2","PRRX1","NFATC4",'MEIS1',]], linewidths=1, linecolor="grey",
            cmap="vlag", center=0, vmin=-1, vmax=2, ax=ax, cbar_kws={"label": "TF Z-score"})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_ylabel("")
fig.savefig("figures/TME/TME_Fibroblasts_scenic_TF_zscore.svg", dpi=200, bbox_inches="tight")

In [None]:
# Mean Z-scored gene-based eRegulon AUC of the selected TME TFs per cell-type label.
# NOTE(review): when a TF has several regulons this cell drops the "+_-" and "-_+" variants,
# whereas the carcinoma cells above drop "+_-" and "-_-" -- confirm which is intended.
all_plot_dfs = []
for sample_name in all_scplus:
    auc_alltfs = []
    # Look the AUC table up once per sample instead of copying the whole table for every TF.
    sample_auc = all_scplus[sample_name].uns["eRegulon_AUC"]["Gene_based"]
    for tf in all_seltfs_TME:
        # eRegulon columns are named "<TF>_..."; match on that prefix. A substring test
        # (`str.contains(tf)`) also picks up regulons of other TFs whose symbol merely
        # contains this one, which yields duplicated columns.
        auc_df = sample_auc.loc[:, sample_auc.columns.str.startswith(f"{tf}_")].copy()
        if auc_df.shape[1]==0:
            continue

        if auc_df.shape[1]>1:
            auc_df = auc_df.loc[:,~auc_df.columns.str.contains("extended")]
            auc_df = auc_df.loc[:,~auc_df.columns.str.contains("+_-", regex=False)]
            auc_df = auc_df.loc[:,~auc_df.columns.str.contains("-_+", regex=False)]
        # Label each column by its TF symbol (the text before the first underscore).
        auc_df.columns = auc_df.columns.str.split("_").str[0]
        auc_alltfs.append(auc_df)
    auc_alltfs = pd.concat(auc_alltfs,axis=1)
    auc_alltfs = pd.concat([auc_alltfs, all_scplus[sample_name].metadata_cell["refined_wcancer_red"]],
                           axis=1)
    all_plot_dfs.append(auc_alltfs)

# TFs without a regulon in a given sample get AUC 0 for that sample's cells.
all_plot_dfs = pd.concat(all_plot_dfs).fillna(0)

# Z-score each TF across all pooled cells, then average per cell-type label.
full_auc = all_plot_dfs.drop("refined_wcancer_red",axis=1).copy()
full_auc = (full_auc - full_auc.mean())/full_auc.std()
all_plot_dfs = pd.concat([full_auc,all_plot_dfs["refined_wcancer_red"]],axis=1)

all_plot_dfs = all_plot_dfs.groupby(by="refined_wcancer_red").mean()

In [None]:
# Heatmap of the mean gene-based eRegulon AUC Z-scores; uses the `all_plot_dfs` produced by
# the cell directly above.
# NOTE(review): `.loc` with label lists raises KeyError when a label is missing -- confirm that
# all listed cell types and TFs are present.
fig, ax = plt.subplots(1,1,figsize=(8,8))
sns.heatmap(data=all_plot_dfs.loc[celltype_order,tf_order], cmap="vlag", center=0, linewidths=1, linecolor="grey",
            vmin=-1, vmax=2, ax=ax, cbar_kws={"label": "eRegulon Gene-based expression, Z-score"})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_ylabel("")
fig.savefig("figures/TME/TME_scenic_eRegulon_gene_based_zscore.svg", dpi=200, bbox_inches="tight")

In [None]:
# Gene-based eRegulon Z-scores for the three fibroblast populations and six TFs.
# NOTE(review): this panel uses vmax=3 while the other heatmaps use vmax=2 -- confirm the
# different colour scale is intended.
fig, ax = plt.subplots(1,1,figsize=(3,2))
sns.heatmap(data=all_plot_dfs.loc[["Inflammatory CAF","HGF-CAF","Fibroblast"],
            ["BNC2","RUNX1","RUNX2","PRRX1","NFATC4",'MEIS1',]], cmap="vlag", center=0, vmin=-1, vmax=3, linewidths=1, linecolor="grey",
            ax=ax, cbar_kws={"label": "eRegulon Gene-based expression, Z-score"})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_ylabel("")
fig.savefig("figures/TME/TME_Fibroblasts_eRegulon_gene_based_zscore.svg", dpi=200, bbox_inches="tight")

In [None]:
# Mean Z-scored region-based eRegulon AUC of the selected TME TFs per cell-type label.
# NOTE(review): when a TF has several regulons this cell drops the "+_-" and "-_+" variants,
# whereas the carcinoma cells above drop "+_-" and "-_-" -- confirm which is intended.
all_plot_dfs = []
for sample_name in all_scplus:
    auc_alltfs = []
    # Look the AUC table up once per sample instead of copying the whole table for every TF.
    sample_auc = all_scplus[sample_name].uns["eRegulon_AUC"]["Region_based"]
    for tf in all_seltfs_TME:
        # eRegulon columns are named "<TF>_..."; match on that prefix. A substring test
        # (`str.contains(tf)`) also picks up regulons of other TFs whose symbol merely
        # contains this one, which yields duplicated columns.
        auc_df = sample_auc.loc[:, sample_auc.columns.str.startswith(f"{tf}_")].copy()
        if auc_df.shape[1]==0:
            continue

        if auc_df.shape[1]>1:
            auc_df = auc_df.loc[:,~auc_df.columns.str.contains("extended")]
            auc_df = auc_df.loc[:,~auc_df.columns.str.contains("+_-", regex=False)]
            auc_df = auc_df.loc[:,~auc_df.columns.str.contains("-_+", regex=False)]
        # Label each column by its TF symbol (the text before the first underscore).
        auc_df.columns = auc_df.columns.str.split("_").str[0]
        auc_alltfs.append(auc_df)
    auc_alltfs = pd.concat(auc_alltfs,axis=1)
    auc_alltfs = pd.concat([auc_alltfs, all_scplus[sample_name].metadata_cell["refined_wcancer_red"]],
                           axis=1)
    all_plot_dfs.append(auc_alltfs)

# TFs without a regulon in a given sample get AUC 0 for that sample's cells.
all_plot_dfs = pd.concat(all_plot_dfs).fillna(0)

# Z-score each TF across all pooled cells, then average per cell-type label.
full_auc = all_plot_dfs.drop("refined_wcancer_red",axis=1).copy()
full_auc = (full_auc - full_auc.mean())/full_auc.std()
all_plot_dfs = pd.concat([full_auc,all_plot_dfs["refined_wcancer_red"]],axis=1)

all_plot_dfs = all_plot_dfs.groupby(by="refined_wcancer_red").mean()

In [None]:
# Heatmap of the mean region-based eRegulon AUC Z-scores; uses the `all_plot_dfs` produced by
# the cell directly above.
# NOTE(review): `.loc` with label lists raises KeyError when a label is missing -- confirm that
# all listed cell types and TFs are present.
fig, ax = plt.subplots(1,1,figsize=(8,8))
sns.heatmap(data=all_plot_dfs.loc[celltype_order,tf_order], cmap="vlag", center=0, vmin=-1, vmax=2, linewidths=1, linecolor="grey",
            ax=ax, cbar_kws={"label": "eRegulon Region-based expression, Z-score"})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_ylabel("")
fig.savefig("figures/TME/TME_scenic_eRegulon_region_based_zscore.svg", dpi=200, bbox_inches="tight")

In [None]:
# Region-based eRegulon Z-scores for the three fibroblast populations and six TFs.
fig, ax = plt.subplots(1,1,figsize=(3,2))
sns.heatmap(data=all_plot_dfs.loc[["Inflammatory CAF","HGF-CAF","Fibroblast"],
            ["BNC2","RUNX1","RUNX2","PRRX1","NFATC4",'MEIS1',]], cmap="vlag", center=0, vmin=-1, vmax=2, linewidths=1, linecolor="grey",
            ax=ax, cbar_kws={"label": "eRegulon Region-based expression, Z-score"})
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_ylabel("")
fig.savefig("figures/TME/TME_Fibroblasts_eRegulon_region_based_zscore.svg", dpi=200, bbox_inches="tight")