In [1]:
import os
import glob
import pickle
import pandas as pd
import numpy as np
from dask.diagnostics import ProgressBar
from arboreto.utils import load_tf_names
from arboreto.algo import grnboost2
from ctxcore.rnkdb import FeatherRankingDatabase as RankingDatabase
from pyscenic.utils import modules_from_adjacencies, load_motifs
from pyscenic.prune import prune2df, df2regulons
from pyscenic.aucell import aucell
import seaborn as sns
import scanpy as sc
from pyscenic.rss import regulon_specificity_scores

# load adata for count matrix

In [2]:
adata = sc.read('/home/jovyan/scripts/renal_covid_19/steroid_pipeline/trajectory_with_Rik_model_output/wave2_steroid_2021_covid_GPLVM.h5ad')

In [3]:
adata

AnnData object with n_obs × n_vars = 17256 × 32913
    obs: 'annot4', 'centre', 'sample_id', 'sample_id_broad', 'sample_date', 'sample_date_yr', 'pool', 'pool_broad', 'haniffa_broad_predLabel', 'orig.ident', 'merged_souporcell_cluster', 'merged_souporcell_status', 'patient_id', 'case_control', 'WHO_severity', 'WHO_temp_severity', 'sex', 'calc_age', 'discharge_date', 'date_positive_swab', 'date_first_symptoms', 'admission_date', 'ethnicity', 'individual_id', 'pseudotime_GPLVM', 'pseudobatch_GPLVM', 'steroid_date', 'days_from_steroid', 'time_from_first_symptoms', 'time_from_positive_swab', 'time_from_infection'
    var: 'GEX'

In [4]:
'C141' in list(adata.obs['sample_id_broad'].unique())

False

In [5]:
adata.obs['days_from_steroid']=adata.obs['days_from_steroid'].astype(float)

In [6]:
adata.obs['steroid_timeline']='unknown'
adata.obs.loc[adata.obs["days_from_steroid"]<=0, "steroid_timeline"] = 'before_steroid'
adata.obs.loc[adata.obs["days_from_steroid"]>0, "steroid_timeline"] = 'after_steroid'

In [7]:
adata.obs['steroid_timeline']=adata.obs['steroid_timeline'].astype('category')

In [8]:
adata.obs['steroid_timeline'].unique()

['after_steroid', 'before_steroid']
Categories (2, object): ['after_steroid', 'before_steroid']

In [9]:
list(adata.obs['annot4'].unique())

['HSPC',
 'Platelet',
 'RBC',
 'CD14mono',
 'CD16mono',
 'Int.mono',
 'CD14mono_anti_inflammatory',
 'CD14mono_IFN',
 'CD14mono_activated',
 'CD16mono_IFN',
 'CD16mono_C1',
 'DC3_IFN',
 'DC2',
 'pDC',
 'DC3',
 'DC1',
 'ASDC']

In [10]:
# Keep only the four CD14+ monocyte annotations. Boolean indexing returns an
# AnnData *view*; `.copy()` makes it an independent object so the in-place
# scanpy steps further down (normalize_total / log1p) do not act on a view.
CD14_STATES = ['CD14mono', 'CD14mono_IFN', 'CD14mono_activated', 'CD14mono_anti_inflammatory']
adata = adata[adata.obs['annot4'].isin(CD14_STATES)].copy()

In [11]:
adata.obs['annot4'].unique()

['CD14mono', 'CD14mono_anti_inflammatory', 'CD14mono_IFN', 'CD14mono_activated']
Categories (4, object): ['CD14mono', 'CD14mono_IFN', 'CD14mono_activated', 'CD14mono_anti_inflammatory']

In [12]:
# Dense expression matrix as a DataFrame: rows labelled by adata.obs.index,
# columns labelled by adata.var.index.
ex_matrix = pd.DataFrame(
    adata.X.toarray(),
    index=adata.obs.index,
    columns=adata.var.index,
)

In [13]:
# define data folder and files
DATA_FOLDER="/home/jovyan/scripts/renal_covid_19/steroid_pipeline/regulon_analysis/"
RESOURCES_FOLDER="/lustre/scratch117/cellgen/team298/win/regulon_own_data/"
DATABASES_GLOB = os.path.join(RESOURCES_FOLDER, "hg38__refseq-r80__10kb_up_and_down_tss.mc9nr.feather")#from https://resources.aertslab.org/cistarget/
MOTIF_ANNOTATIONS_FNAME = os.path.join(RESOURCES_FOLDER, "motifs-v9-nr.hgnc-m0.001-o0.0.tbl") #from https://resources.aertslab.org/cistarget/
MM_TFS_FNAME = os.path.join(RESOURCES_FOLDER, 'lambert2018.txt') # from https://github.com/aertslab/pySCENIC/blob/master/resources/lambert2018.txt
REGULONS_FNAME = os.path.join(DATA_FOLDER, "CD14_wave2_steroid_2021.p") #CD14_wave2_steroid_2021.p
MOTIFS_FNAME = os.path.join(DATA_FOLDER, "CD14_wave2_steroid_2021.csv")


In [None]:
DATABASES_GLOB

In [None]:
ex_matrix.head()

# Calculate adjacencies: co-expression between TFs and candidate target genes

In [None]:
tf_names = load_tf_names(MM_TFS_FNAME)

In [None]:
db_fnames = glob.glob(DATABASES_GLOB)
def name(fname):
    """Return the base name of `fname` with its directory and last extension removed."""
    stem, _extension = os.path.splitext(os.path.basename(fname))
    return stem
dbs = [RankingDatabase(fname=fname, name=name(fname)) for fname in db_fnames]
dbs

In [None]:
adjacencies = grnboost2(ex_matrix, tf_names=tf_names, verbose=True)


In [None]:
adjacencies

In [None]:
adjacencies.to_csv('/home/jovyan/scripts/renal_covid_19/steroid_pipeline/regulon_analysis/adjacencies_CD14_wave2_steroid_2021.csv')

In [None]:
# Derive modules from adjacencies
modules = list(modules_from_adjacencies(adjacencies, ex_matrix))

# Prune modules for enriched motifs

In [None]:
# Calculate a list of enriched motifs and the corresponding target genes for all modules.
with ProgressBar():
    df = prune2df(dbs, modules, MOTIF_ANNOTATIONS_FNAME)
# Create regulons from this table of enriched motifs.
regulons = df2regulons(df)

In [None]:
df.shape

In [None]:
df.to_csv(MOTIFS_FNAME)
with open(REGULONS_FNAME, "wb") as f:
    pickle.dump(regulons, f)


# Gene lists for STRING-DB analysis

In [None]:
# getting high var

In [None]:
adata

In [14]:
adata.obs['annot4'].unique()

['CD14mono', 'CD14mono_anti_inflammatory', 'CD14mono_IFN', 'CD14mono_activated']
Categories (4, object): ['CD14mono', 'CD14mono_IFN', 'CD14mono_activated', 'CD14mono_anti_inflammatory']

In [None]:
sc.pp.normalize_total(adata, target_sum=1e4)

In [None]:
sc.pp.log1p(adata)

In [None]:
# find highly variable genes
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5,n_top_genes=2000)
sc.pl.highly_variable_genes(adata)

In [None]:
# remove vdj and light chain constant genes from highly variable genes, and also the viral reads
import re
# Vectorised: un-flag every gene whose name matches the pattern in a single
# masked assignment instead of writing to adata.var one row at a time.
excluded_from_hvg = adata.var.index.str.contains('^IG[HKL][VDJC]|VIRAL', regex=True)
adata.var.loc[excluded_from_hvg, 'highly_variable'] = False
sc.pl.highly_variable_genes(adata)

In [None]:
adata

In [None]:
# transfer to .raw slot
adata.raw = adata

In [None]:
adata = adata[:, adata.var.highly_variable].copy()
adata

In [None]:
highvar= (list(adata.var.index))

In [15]:
with open(REGULONS_FNAME,'rb') as f:
    data = pickle.load(f)

In [None]:
#data[0]

In [None]:
data

In [72]:
# Map each selected regulon name to the list of its target genes
# (the keys of the regulon's gene2weight mapping).
genes = {
    regulon.name: list(regulon.gene2weight.keys())
    for regulon in data
    if regulon.name in ['MAZ(+)']
}

In [73]:
genes

{'MAZ(+)': ['BAG6',
  'EPN1',
  'NOSIP',
  'CAP1',
  'UQCRQ',
  'ARPC3',
  'ROMO1',
  'NDUFS3',
  'SEC24C',
  'COX5B',
  'RPL27A',
  'UBL5',
  'WDR83OS',
  'RPS19',
  'COG3',
  'CIB1',
  'EIF3I',
  'SLC35C2',
  'CDC123',
  'POLDIP2',
  'UQCRH',
  'LMAN2',
  'ABHD17A',
  'SF3B4',
  'ANAPC15',
  'PATL1',
  'SRP14',
  'RAB11A',
  'SNRPE',
  'URM1',
  'KAT5',
  'TADA3',
  'CHMP2A',
  'EIF3L',
  'NOC3L',
  'NIPBL',
  'MARK2',
  'RPL37',
  'MRPL3',
  'SART1',
  'DR1',
  'C11orf24',
  'PDAP1',
  'AP1M1',
  'GNAI2',
  'RPL32',
  'DHX36',
  'RPS21',
  'SNRPD1',
  'CHCHD5',
  'FKBP2',
  'STARD3',
  'BANF1',
  'EIF1AD',
  'EIF3M',
  'GABARAPL2',
  'PUF60',
  'RPS14',
  'SLC30A7',
  'NORAD',
  'UBALD1',
  'LSM1',
  'ACTB',
  'HINT1',
  'VMA21',
  'WDR45',
  'ELK3',
  'LARP4',
  'PET100',
  'SCNM1',
  'GSPT1',
  'HSP90AB1',
  'RPL6',
  'CFL1',
  'C14orf119',
  'EIF6',
  'UBR1',
  'CELF1',
  'SLC35A4',
  'RPL7L1',
  'TINF2',
  'UFC1',
  'MRPS21',
  'SNRPB',
  'RPS6',
  'PEX16',
  'PHB2',
  'SPG21',


In [74]:
# Pool the target genes of every regulon in `genes` and drop duplicates.
allgenes = list(set(gene for targets in genes.values() for gene in targets))

In [75]:
len(allgenes)

239

In [76]:
len(allgenes)

239

In [77]:
import csv
newfilePath = '/home/jovyan/scripts/renal_covid_19/steroid_pipeline_corrected/steroid_pipeline/regulon_analysis/target_genes1.csv'
# Write all pooled target genes as a single CSV row.
# newline='' is required by the csv module so that the writer controls line
# endings itself (otherwise extra blank lines can appear on some platforms).
with open(newfilePath, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(allgenes)

In [None]:
len(allgenes)

In [None]:
d = [allgenes,highvar]

In [None]:
selected = set.intersection(*map(set,d))

In [None]:
len(list(selected))

In [None]:
to_plot = list(selected)

In [None]:
len(to_plot)

In [None]:
to_plot_YY1 = allgenes

In [None]:
to_plot_YY1

In [None]:
to_plot_AHR = to_plot

In [None]:
to_plot_ATF4 = to_plot

In [None]:
to_plot_STAT1 = allgenes

In [None]:
to_plot_STAT2 = allgenes

In [None]:
to_plot_IRF7 = allgenes

In [None]:
final = to_plot_STAT1 + to_plot_YY1

In [None]:
to_plot_STAT1[1:300]

In [None]:
all_nonYY1 =to_plot_STAT1+to_plot_STAT2+to_plot_IRF7

In [None]:
s1y =list(set.intersection(*map(set,[all_nonYY1,to_plot_YY1])))

In [None]:
(s1y)

In [None]:
len(all_nonYY1)

In [None]:
len(to_plot_STAT1+to_plot_STAT2+to_plot_IRF7)

In [None]:
s2y =list(set.intersection(*map(set,[to_plot_STAT2,to_plot_YY1])))

In [None]:
s2y

In [None]:
i7y =list(set.intersection(*map(set,[to_plot_IRF7,to_plot_YY1])))

In [None]:
(i7y)

In [None]:
list_to_check =['CORO1A',
'ACTR3',
'EFHD2',
'ARPC2',
'ARPC3',
'RCSD1',
'CAPZA1',
'DBNL',
'CAP1',
'CFL1',
'CORO1C',
'VASP',
'PFN1',
'MYH9',
'ACTB',
'DYNC1H1',
'CBX5',
'H3F3A',
'CDK13',
'SMARCA2',
'POLR2A',
'CHD1',
'SUPT4H1',
'STK24',
'MST4',
'ARF1',
'RAB11A',
'GDI2',
'ATP6V0E1',
'SUPT16H',
'PRPSAP1',
'PRPS1',
'SF3B2',
'SRRM1',
'LSM1',
'PPIG',
'PRPF6',
'RBM25',
'PRPF4B',
'SF3B6',
'SNRPE',
'LSM3',
'SLU7',
'MAGOH',
'ZMAT2',
'SF3B4',
'SEPT7',
'CDC42EP4',
'SEPT6',
'SEPT9',
'ATP6V1D',
'ATP6V1A',
'RAB7A',
'MON1B',
'SMARCA5',
'BAZ2A',
'RAB2A',
'RCOR1',
'PHF21A',
'OSBPL1A',
'FUBP3',
'SYNCRIP',
'FUS',
'HNRNPA3',
'HNRNPD',
'HNRNPM',
'HNRNPK',
'QKI',
'STAG2',
'RAD21',
'LPXN',
'TCF3',
'ID2',
'RALY',
'HNRNPL',
'LSM14A',
'HNRNPF',
'PABPC1',
'EIF4E',
'HNRNPUL1',
'COX6A1',
'UQCRC2',
'NDUFS2',
'COX8A',
'COX5B',
'NDUFS4',
'CALM1',
'PPP3CA',
'LARP4B',
'ERP44',
'PRDX4',
'KHDRBS1',
'RBM15',
'PTK2B',
'MEF2C',
'GAB2',
'GRB2',
'PIK3CA',
'SRC',
'ADAM15',
'CBL',
'BTK',
'YWHAB',
'NEDD9',
'RBBP6',
'C20orf27',
'PPP1CC',
'ERBB2IP',
'TGFBR1',
'SMAD3',
'FOXO3',
'GTF2I',
'GNB2L1',
'EIF3C',
'EIF3I',
'RPS21',
'RPS7',
'RPS12',
'RPS14',
'PLEC',
'TRMT112',
'VDAC1',
'IMP3',
'HEATR1',
'BTF3',
'RPL28',
'RPL29',
'RPL23',
'SRPR',
'MRPL27',
'EIF6',
'NSA2',
'UBC',
'USP46',
'TP53BP1',
'AZI2',
'TAX1BP1',
'JOSD2',
'UBL7',
'OTUB1',
'UBE2D3',
'RNF185',
'USP3',
'MYCBP2',
'PSMA3',
'XIAP',
'RNF4',
'HTRA2',
'POMP',
'IRAK1',
'TAB2',
'PSMB3',
'PSMB7',
'EIF1',
'PTPN1',
'PTPRE'


]

In [None]:
for i in range (0,len(s1y)):
    print ('gene name is',s1y[i])
    print ('STAT1',s1y[i] in to_plot_STAT1)
    #print ('STAT2',s1y[i] in to_plot_STAT2)
    #print ('IRF7',s1y[i] in to_plot_IRF7)
    print ('YY1',s1y[i] in to_plot_YY1)
   # print ('ATF4',s1y[i] in to_plot_ATF4)
    #print ('AHR',s1y[i] in to_plot_AHR)
    print ('=====================')

In [None]:
list1 = [to_plot_STAT1,to_plot_YY1]

In [None]:
list2 = [to_plot_STAT1,to_plot_ATF4]

In [None]:
list3 = [to_plot_STAT1,to_plot_AHR]

In [None]:
list4 = [to_plot_STAT2,to_plot_YY1]

In [None]:
list5 = [to_plot_STAT2,to_plot_ATF4]

In [None]:
list6 = [to_plot_STAT2,to_plot_AHR]

In [None]:
list7 = [to_plot_IRF7,to_plot_YY1]

In [None]:
list8 = [to_plot_IRF7,to_plot_ATF4]

In [None]:
list9 = [to_plot_IRF7,to_plot_AHR]

In [None]:
list_to_intersect1 =list(set.intersection(*map(set,list1)))

In [None]:
list_to_intersect2 =list(set.intersection(*map(set,list2)))

In [None]:
list_to_intersect3 =list(set.intersection(*map(set,list3)))

In [None]:
list_to_intersect4 =list(set.intersection(*map(set,list4)))

In [None]:
list_to_intersect5 =list(set.intersection(*map(set,list5)))

In [None]:
list_to_intersect6 =list(set.intersection(*map(set,list6)))

In [None]:
list_to_intersect7 =list(set.intersection(*map(set,list7)))

In [None]:
list_to_intersect8 =list(set.intersection(*map(set,list8)))

In [None]:
list_to_intersect9 =list(set.intersection(*map(set,list9)))

In [None]:
len(list_to_intersect9)

In [None]:
final_to_plot = list_to_intersect1 +list_to_intersect2+list_to_intersect3+list_to_intersect4 +list_to_intersect5 +list_to_intersect6 +list_to_intersect7 +list_to_intersect8 +list_to_intersect9

In [None]:
final_plot = list(set(final_to_plot))

In [None]:
len(final_plot)

In [None]:
final_plot

In [None]:
name = 'CORO1A'

In [None]:
name in to_plot_YY1

In [None]:
name in to_plot_ATF4

In [None]:
name in to_plot_AHR

In [None]:
# %%%%%%%%%%%%%%%%%%%%%%%  (visual separator; commented out so this cell no longer raises when the notebook is run top to bottom)

In [None]:
name in to_plot_STAT1

In [None]:
name in to_plot_STAT2

In [None]:
name in to_plot_IRF7

In [None]:
# `highvar1` is not defined in any visible cell; the highly-variable gene list
# built earlier in this notebook is named `highvar`.
# NOTE(review): confirm that `highvar` is the list that was intended here.
to_check = [final_plot, highvar]

In [None]:
to_check1 =list(set.intersection(*map(set,to_check)))

In [None]:
to_check1

In [None]:
tosave = list(set(aaaa))

In [None]:
len(tosave)

In [None]:
import csv

# NOTE(review): `aaaa` is not defined in the visible part of this notebook —
# confirm which gene list is meant to be exported before running this cell.
# newline='' is required by the csv module so the writer manages line endings itself.
with open('/home/jovyan/scripts/renal_covid_19/steroid_pipeline/strem_DB.csv', 'w', newline='') as myfile:
    wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
    wr.writerow(list(aaaa))

# Calculate AUC: regulon enrichment in individual cells

In [None]:
# Enrichment of a regulon is measured as the Area Under the recovery Curve (AUC) of the genes that define this regulon.
auc_mtx = aucell(ex_matrix, regulons, num_workers=4)



In [None]:
# Save regulon enrichment to csv 
auc_mtx.to_csv('/home/jovyan/scripts/renal_covid_19/steroid_pipeline/regulon_analysis/CD14_wave2_steroid_2021_auc_mtx.csv')

In [None]:
#sns.clustermap(auc_mtx, figsize=(16,16))

In [None]:
adata

In [None]:
auc_mtx.shape

# post scenic

# RSS

In [None]:
#reload saved auc_mtx
auc_mtx = pd.read_csv('/home/jovyan/scripts/renal_covid_19/steroid_pipeline/regulon_analysis/CD14_wave2_steroid_2021_auc_mtx.csv')

In [None]:
auc_mtx

In [None]:
auc_mtx.set_index('Cell',inplace=True)
auc_mtx.head()

In [2]:
adata = sc.read('/home/jovyan/scripts/renal_covid_19/steroid_pipeline/regulon_analysis/adata_CD14_wave2_steroid_2021_auc_mtx.h5ad')

In [3]:
adata.obs['case_control'].unique()

['POSITIVE', 'RECOVERY']
Categories (2, object): ['POSITIVE', 'RECOVERY']

In [None]:
# Calculate regulon Specificity Score
rss_cellType = regulon_specificity_scores(auc_mtx, adata.obs['steroid_timeline'])
rss_cellType

In [None]:
#rss_cellType.to_csv('/lustre/scratch117/cellgen/team298/win/for_lisa/regulons_before_after_steroid.csv')

In [None]:
adata.obs['annot4'].unique()

In [None]:
import matplotlib.pyplot as plt
from adjustText import adjust_text
from pyscenic.plotting import plot_rss
# RSS panel plot: one subplot per steroid_timeline category
# (the RSS table was computed against adata.obs['steroid_timeline'], not cell types).
plt.rcParams.update({'font.size': 18})
# Sorted unique category labels; also reused by later cells in this notebook.
cats = sorted(list(set(adata.obs['steroid_timeline'])))

fig = plt.figure(figsize=(8, 8))
for c,num in zip(cats, range(1,len(cats)+1)):
    # RSS values of every regulon for this category (used for the y-limits below).
    x=rss_cellType.T[c]

    # Panels are laid out on a fixed 1 x 2 grid.
    ax = fig.add_subplot(1,2,num)
    plot_rss(rss_cellType, c, top_n=5, max_n=None, ax=ax)
    # Pad the y-range by 5% of the data range on each side.
    ax.set_ylim( x.min()-(x.max()-x.min())*0.05 , x.max()+(x.max()-x.min())*0.05 )
    for t in ax.texts:
        t.set_fontsize(12)
    # Per-axes labels are cleared; shared labels are added at figure level below.
    ax.set_ylabel('')
    ax.set_xlabel('')
    # Reposition the regulon labels and connect them to their points with light grey lines.
    adjust_text(ax.texts, autoalign='xy', ha='right', va='bottom', arrowprops=dict(arrowstyle='-',color='lightgrey'), precision=0.001 )
 
fig.text(0.5, 0.0, 'Regulon', ha='center', va='center', size='x-large')
fig.text(0.00, 0.5, 'Regulon specificity score (RSS)', ha='center', va='center', rotation='vertical', size='x-large')
plt.tight_layout()
# NOTE(review): these rcParams are changed after the figure has been drawn;
# presumably they only influence figures created afterwards — confirm intent.
plt.rcParams.update({
    'figure.autolayout': True,
        'figure.titlesize': 'large' ,
        'axes.labelsize': 'large',
        'axes.titlesize':'large',
        'xtick.labelsize':'large',
        'ytick.labelsize':'large'
        })

#plt.show()
# Relative path: the PDF is written to the current working directory.
plt.savefig('regulon_poster.pdf',bbox_inches="tight",dpi=300)

# customised plot

In [None]:
T_rss_cellType = rss_cellType.T
T_rss_cellType.head()

In [None]:
before_top =T_rss_cellType.sort_values('before_steroid',ascending=False)
before_top.head(n=10)

In [None]:
before_top.index[0:10]

In [None]:
after_top =T_rss_cellType.sort_values('after_steroid',ascending=False)
after_top.head(n=10)

In [None]:
after_top['row_num'] = np.arange(len(after_top))
after_top

In [None]:
after_top['row_names']=after_top.index

In [None]:
after_top.set_index('row_num',inplace=True)


In [None]:
# For each of the first five regulons in `before_top`, look up its index label
# (the row_num set as index above) in `after_top`.
ind = [
    after_top.index[after_top.row_names == regulon][0]
    for regulon in list(before_top.index[0:5])
]
    

In [None]:
ind

In [None]:
additional_plot=list(before_top.index[0:5])

In [None]:
import matplotlib.pyplot as plt
from adjustText import adjust_text
from pyscenic.plotting import plot_rss
# RSS panel for the 'after_steroid' group only. In addition to its own top-5
# regulons, the regulons named in `additional_plot` are labelled at the
# x-positions given in `ind`.
# NOTE(review): plot_rss1 is defined in a later cell of this notebook; that
# cell has to be executed before this one.
plt.rcParams.update({'font.size': 18})
# Use a dedicated name instead of reassigning the global `cats`: the later
# heatmap cells build a colour lookup from `cats` for every steroid_timeline
# value, and would fail if `cats` only contained 'after_steroid'.
panel_cats = ['after_steroid']

fig = plt.figure(figsize=(8, 8))
for num, c in enumerate(panel_cats, start=1):

    # RSS values of every regulon for this category (used for the y-limits below).
    x = rss_cellType.T[c]
    y_pad = (x.max() - x.min()) * 0.05

    # Same 1 x 2 grid as the two-panel figure, so the panel keeps its size.
    ax = fig.add_subplot(1, 2, num)
    plot_rss1(rss_cellType, c, ind, additional_plot, top_n=5, max_n=None, ax=ax)
    ax.set_ylim(x.min() - y_pad, x.max() + y_pad)
    for t in ax.texts:
        t.set_fontsize(12)
    # Per-axes labels are cleared; shared labels are added at figure level below.
    ax.set_ylabel('')
    ax.set_xlabel('')
    adjust_text(ax.texts, autoalign='xy', ha='right', va='bottom', arrowprops=dict(arrowstyle='-',color='lightgrey'), precision=0.001 )

fig.text(0.5, 0.0, 'Regulon', ha='center', va='center', size='x-large')
fig.text(0.00, 0.5, 'Regulon specificity score (RSS)', ha='center', va='center', rotation='vertical', size='x-large')
plt.tight_layout()
# NOTE(review): these rcParams are changed after the figure has been drawn;
# presumably they only influence figures created afterwards — confirm intent.
plt.rcParams.update({
    'figure.autolayout': True,
    'figure.titlesize': 'large',
    'axes.labelsize': 'large',
    'axes.titlesize': 'large',
    'xtick.labelsize': 'medium',
    'ytick.labelsize': 'medium',
})


# Relative path: the PDF is written to the current working directory.
plt.savefig('regulon_combined.pdf',bbox_inches="tight",dpi=300)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from math import ceil, floor
def plot_rss1(rss, cell_type,ind,list_to_plot,top_n=5, max_n=None, ax=None):
    """
    Plot ranked RSS values for one category and label selected regulons.

    The column ``cell_type`` of ``rss.T`` is sorted in descending order and
    plotted against rank. The first ``top_n`` regulons are marked and labelled
    (red text). The regulons named in ``list_to_plot`` are additionally marked
    and labelled (green text) at the x-positions supplied in ``ind``.

    :param rss: table indexed so that ``rss.T[cell_type]`` gives one value per regulon.
    :param cell_type: label of the category to plot; also used as the axes title.
    :param ind: x-positions for the extra labels, one per entry of
        ``list_to_plot`` and in the same order.
    :param list_to_plot: regulon names to label in addition to the top ones;
        each must be present in the (possibly truncated) ranked series.
    :param top_n: number of top-ranked regulons to label.
    :param max_n: number of ranked regulons to draw; defaults to ``rss.shape[1]``.
    :param ax: matplotlib axes to draw on; a new 4x4-inch figure is created if None.
    :return: None; everything is drawn on ``ax``.
    """
    if ax is None:
        ax = plt.subplots(1, 1, figsize=(4, 4))[1]
    if max_n is None:
        max_n = rss.shape[1]

    ranked = rss.T[cell_type].sort_values(ascending=False)[0:max_n]
    # Horizontal gap between a marked point and its text label.
    label_offset = max_n / 25

    ax.plot(np.arange(len(ranked)), ranked, '.')
    y_low = floor(ranked.min() * 100.0) / 100.0
    y_high = ceil(ranked.max() * 100.0) / 100.0
    ax.set_ylim([y_low, y_high])
    ax.set_ylabel('RSS')
    ax.set_xlabel('Regulon')
    ax.set_title(cell_type)
    ax.set_xticklabels([])

    def _mark(x_pos, score, label, text_style):
        # Re-draw the point in red and put the regulon name to its right.
        ax.plot([x_pos, x_pos], [score, score], 'r.')
        ax.text(
            x_pos + label_offset,
            score,
            label,
            fontdict=text_style,
            horizontalalignment='left',
            verticalalignment='center',
        )

    # Top-ranked regulons: x-position is simply the rank.
    top_style = {
        'color': 'red',
        'weight': 'normal',
        'size': 2,
    }
    leaders = ranked[0:top_n]
    for rank, (regulon_name, rss_val) in enumerate(zip(leaders.index, leaders.values)):
        _mark(rank, rss_val, regulon_name, top_style)

    # Extra regulons: x-position comes from the caller-supplied `ind`.
    extra_style = {
        'color': 'green',
        'weight': 'normal',
        'size': 1,
    }
    extras = ranked[list_to_plot]
    for slot, (regulon_name, rss_val) in enumerate(zip(extras.index, extras.values)):
        _mark(ind[slot], rss_val, regulon_name, extra_style)

# heatmap

In [None]:
# Union of the five highest-RSS regulons of every category in `cats`
# (duplicates across categories are collapsed).
topreg = list(set(
    regulon
    for c in cats
    for regulon in rss_cellType.T[c].sort_values(ascending=False)[:5].index
))

In [None]:
# Z-score each regulon (column) across cells so regulons can be compared:
# subtract the column mean and divide by the population standard deviation
# (ddof=0). Done on the whole frame at once — inserting the columns one at a
# time fragments the DataFrame and triggers pandas' PerformanceWarning.
auc_mtx_Z = (auc_mtx - auc_mtx.mean()) / auc_mtx.std(ddof=0)

In [None]:
# Generate heatmap
def palplot(pal, names, colors=None, size=1):
    """
    Draw a one-row strip of colour swatches with a text label on each swatch.

    :param pal: sequence of colours, one swatch per entry.
    :param names: labels written at the centre of the swatches, in order.
    :param colors: text colour per label; black ('k') for all labels if None.
    :param size: height of the strip and width of each swatch, in inches.
    :return: the matplotlib figure holding the strip.
    """
    n_swatches = len(pal)
    label_colors = n_swatches * ['k'] if colors is None else colors

    strip, ax = plt.subplots(1, 1, figsize=(n_swatches * size, size))
    ax.imshow(
        np.arange(n_swatches).reshape(1, n_swatches),
        cmap=mpl.colors.ListedColormap(list(pal)),
        interpolation="nearest",
        aspect="auto",
    )
    # Ticks sit on the swatch borders; their labels are hidden.
    ax.set_xticks(np.arange(n_swatches) - .5)
    ax.set_yticks([-.5, .5])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    for position, (label, label_color) in enumerate(zip(names, label_colors)):
        ax.text(
            0.0 + position,
            0.0,
            label,
            color=label_color,
            horizontalalignment='center',
            verticalalignment='center',
        )
    return strip

In [None]:
colors = sns.color_palette('bright',n_colors=len(cats) )
colorsd = dict( zip( cats, colors ))
colormap = [ colorsd[x] for x in adata.obs['steroid_timeline'] ]

In [None]:
import matplotlib as mpl
sns.set()
sns.set(font_scale=0.8)
fig = palplot( colors, cats, size=1.0)

In [None]:
sns.set(font_scale=1.2)
g = sns.clustermap(auc_mtx_Z[topreg], annot=False,  square=False,  linecolor='gray',
    yticklabels=False, xticklabels=True, vmin=-2, vmax=6, row_colors=colormap,row_cluster=True,
    cmap="YlGnBu", figsize=(21,16) )
g.cax.set_visible(True)
g.ax_heatmap.set_ylabel('')
g.ax_heatmap.set_xlabel('')

# display motif

In [None]:
BASE_URL = "http://motifcollections.aertslab.org/v9/logos/"
def display_logos(df: pd.DataFrame, top_target_genes: int = 3, base_url: str = BASE_URL):
    """
    Display the first rows of an enriched-motif table as HTML, with a logo image per motif.

    A copy of ``df`` gets an extra ("Enrichment", COLUMN_NAME_LOGO) column holding
    an ``<img>`` tag built from ``base_url`` and the motif id, and its
    ("Enrichment", COLUMN_NAME_TARGETS) column is shortened to
    ``top_target_genes`` entries per row. ``df.head()`` of the result is rendered.

    :param df: motif table with an index level named COLUMN_NAME_MOTIF_ID and a
        column ("Enrichment", COLUMN_NAME_TARGETS). The input frame is not modified.
    :param top_target_genes: number of target-gene entries kept per row.
    :param base_url: URL prefix under which ``<motif_id>.png`` logos are served.
    :return: None; the table is shown via IPython's ``display``.

    NOTE(review): COLUMN_NAME_LOGO, COLUMN_NAME_MOTIF_ID and COLUMN_NAME_TARGETS
    are not defined in the visible part of this notebook; they must exist as
    globals before this function is called.
    """
    # Imported locally so the function does not depend on cells further down
    # the notebook having been executed first.
    import operator as op
    from IPython.display import HTML, display

    # Make sure the original dataframe is not altered.
    df = df.copy()

    # Add column with URLs to sequence logo.
    def create_url(motif_id):
        return '<img src="{}{}.png" style="max-height:124px;"></img>'.format(base_url, motif_id)
    df[("Enrichment", COLUMN_NAME_LOGO)] = list(map(create_url, df.index.get_level_values(COLUMN_NAME_MOTIF_ID)))

    # Truncate TargetGenes.
    # NOTE(review): entries are sorted ascending by their second element before
    # truncation; if entries are (gene, weight) pairs this keeps the
    # lowest-weight genes — confirm this is intended.
    def truncate(col_val):
        return sorted(col_val, key=op.itemgetter(1))[:top_target_genes]
    df[("Enrichment", COLUMN_NAME_TARGETS)] = list(map(truncate, df[("Enrichment", COLUMN_NAME_TARGETS)]))

    # Render without truncating cell contents (the <img> tags are long).
    # `None` is the supported "no limit" value (-1 is deprecated), and the
    # context manager restores the previous setting even if rendering fails.
    with pd.option_context('display.max_colwidth', None):
        display(HTML(df.head().to_html(escape=False)))

In [None]:
#BASE_URL = "http://motifcollections.aertslab.org/v9/logos/"
def fetch_logo(regulon, base_url = BASE_URL):
    """
    Build an HTML ``<img>`` tag for a regulon's motif logo.

    :param regulon: object whose ``context`` is an iterable of strings.
    :param base_url: URL prefix prepended to the logo file name.
    :return: an ``<img>`` tag for the first context entry ending in '.png',
        or an empty string when there is no such entry.
    """
    logo_file = next((entry for entry in regulon.context if entry.endswith('.png')), None)
    if logo_file is None:
        return ""
    return '<img src="{}{}" style="max-height:124px;"></img>'.format(base_url, logo_file)

In [None]:
with open(REGULONS_FNAME,'rb') as f:
    regulons = pickle.load(f)

In [None]:
import operator as op
df_regulons = pd.DataFrame(data=[list(map(op.attrgetter('name'), regulons)),
                                 list(map(len, regulons)),
                                 list(map(fetch_logo, regulons))], index=['name', 'count', 'logo']).T

In [None]:
df_regulons.head(10)

In [None]:
df_regulon_to_plot = df_regulons[df_regulons['name'].isin(['STAT1(+)','STAT2(+)','IRF7(+)','YY1(+)','AHR(+)','ATF4(+)','IRF9(+)'])]

In [None]:
df_regulon_to_plot

In [None]:
from IPython.display import HTML, display
# Render the selected regulons as HTML without truncating cell contents, so
# the <img> tags in the 'logo' column stay intact.
# `None` is the supported "no limit" value for display.max_colwidth (-1 is
# deprecated); option_context restores the previous setting on exit, even if
# rendering raises.
with pd.option_context('display.max_colwidth', None):
    display(HTML(df_regulon_to_plot.to_html(escape=False)))

# AUC and density

In [None]:
from pyscenic.binarization import binarize
binary_mtx, auc_thresholds = binarize( auc_mtx, num_workers=25 )
binary_mtx.head()

In [None]:
# select regulons:
import matplotlib.pyplot as plt
r = [ 'STAT1(+)', 'STAT2(+)', 'AHR(+)' ]

fig, axs = plt.subplots(1, 3, figsize=(12, 4), dpi=150, sharey=False)
for i,ax in enumerate(axs):
    sns.distplot(auc_mtx[ r[i] ], ax=ax, norm_hist=True, bins=100)
    ax.plot( [ auc_thresholds[ r[i] ] ]*2, ax.get_ylim(), 'r:')
    ax.title.set_text( r[i] )
    ax.set_xlabel('')
    
fig.text(-0.01, 0.5, 'Frequency', ha='center', va='center', rotation='vertical', size='large')
fig.text(0.5, -0.01, 'AUC', ha='center', va='center', rotation='horizontal', size='large')

fig.tight_layout()
fig.savefig('/home/jovyan/scripts/renal_covid_19/steroid_pipeline/regulon_analysis/auc_CD14_wave2_steroid_2021_auc_mtx.png', dpi=600, bbox_inches='tight')

In [None]:
type(auc_mtx)

In [None]:
adata.obs['steroid_timeline'].unique()

In [None]:
before = adata[adata.obs['steroid_timeline']=='before_steroid']

In [None]:
after = adata[adata.obs['steroid_timeline']=='after_steroid']

In [None]:
auc_mtx_before = auc_mtx[auc_mtx.index.isin(before.obs.index)]

In [None]:
auc_mtx_before.shape

In [None]:
auc_mtx_after = auc_mtx[auc_mtx.index.isin(after.obs.index)]

In [None]:
auc_mtx_after.shape

In [None]:
auc_mtx.shape

In [None]:
binary_mtx_before, auc_thresholds_before = binarize( auc_mtx_before, num_workers=25 )
binary_mtx_before.head()

In [None]:
binary_mtx_after, auc_thresholds_after = binarize( auc_mtx_after, num_workers=25 )
binary_mtx_after.head()

In [None]:
# select regulons:
import matplotlib.pyplot as plt
r = ['STAT1(+)', 'STAT2(+)', 'IRF7(+)','YY1(+)','ATF4(+)','AHR(+)'  ]

fig, axs = plt.subplots(1, 6, figsize=(12, 4), dpi=150, sharey=False)
for i,ax in enumerate(axs):
    sns.distplot(auc_mtx_before[ r[i] ], ax=ax, norm_hist=True, bins=100)
    ax.plot( [ auc_thresholds_before[ r[i] ] ]*2, ax.get_ylim(), 'r:')
    ax.title.set_text( r[i] )
    ax.set_xlabel('')
    
fig.text(-0.01, 0.5, 'Frequency', ha='center', va='center', rotation='vertical', size='large')
fig.text(0.5, -0.01, 'AUC', ha='center', va='center', rotation='horizontal', size='large')

In [None]:
# select regulons:
import matplotlib.pyplot as plt
r = [ 'STAT1(+)', 'STAT2(+)', 'IRF7(+)','YY1(+)','ATF4(+)','AHR(+)' ]

fig, axs = plt.subplots(1, 6, figsize=(12, 4), dpi=150, sharey=False)
for i,ax in enumerate(axs):
    sns.distplot(auc_mtx_after[ r[i] ], ax=ax, norm_hist=True, bins=100)
    ax.plot( [ auc_thresholds_after[ r[i] ] ]*2, ax.get_ylim(), 'r:')
    ax.title.set_text( r[i] )
    ax.set_xlabel('')
    
fig.text(-0.01, 0.5, 'Frequency', ha='center', va='center', rotation='vertical', size='large')
fig.text(0.5, -0.01, 'AUC', ha='center', va='center', rotation='horizontal', size='large')

In [None]:
auc_thresholds_after

In [None]:
# select regulons:
plt.rcParams.update({'font.size': 12})
import matplotlib.pyplot as plt
r = [ 'STAT1(+)', 'STAT2(+)', 'IRF7(+)','YY1(+)' ]

fig, axs = plt.subplots(1, 4, figsize=(12, 4), dpi=150, sharey=False)
for i,ax in enumerate(axs):
    sns.distplot(auc_mtx_before[ r[i] ], ax=ax, norm_hist=True, bins=100,color='green',label='before')
    sns.distplot(auc_mtx_after[ r[i] ], ax=ax, norm_hist=True, bins=100,color='red',label='after')
    
    #ax.plot( [ auc_thresholds_after[ r[i] ] ]*2, ax.get_ylim(), 'r:')
    ax.title.set_text( r[i] )
    ax.set_xlabel('')
    ax.legend()
    
fig.text(-0.01, 0.5, 'Frequency', ha='center', va='center', rotation='vertical', size='large')
fig.text(0.5, -0.01, 'AUC', ha='center', va='center', rotation='horizontal', size='large')
plt.savefig('regulon_auc.pdf',bbox_inches="tight",dpi=300)