## SCENIC+ Mullerian duct mesenchymal cells

### method benchmarking

In [None]:
#supress warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import sys
import os

In [None]:
# Get chromosome sizes (for hg38 here)
import pyranges as pr
import requests
import pandas as pd
# Download the hg38 chromosome-size table (one row per sequence: name, length)
target_url='http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.chrom.sizes'
chromsizes = pd.read_csv(target_url, sep='\t', header=None, names=['Chromosome', 'End'])
# Every sequence is described as the interval [0, length)
chromsizes.insert(1, 'Start', 0)
# Exceptionally in this case, to agree with CellRangerARC annotations:
# 'v' becomes '.', and for names containing '_' only the second
# '_'-separated field is kept; names without '_' are left as they are.
renamed_sequences = []
for ucsc_name in chromsizes['Chromosome']:
    name_fields = ucsc_name.replace('v', '.').split('_')
    renamed_sequences.append(name_fields[1] if len(name_fields) > 1 else name_fields[0])
chromsizes['Chromosome'] = renamed_sequences
chromsizes = pr.PyRanges(chromsizes)

In [None]:
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42

In [None]:
chromsizes

### Add cell type annotation information 

The barcode metadata should be provided as a pd.DataFrame.

* The index of the pandas DataFrame should correspond to BARCODE (e.g. ATGTCTGATAGA-1; additional tags are possible using ___, e.g. ATGTCTGATAGA-1___sample_1), and the DataFrame must contain a ‘sample_id’ column indicating the sample of origin. It is also possible to use another separator pattern (e.g. -), but it then has to be specified in the function.

* Alternative: add a column named ‘barcode’ to the metadata with the corresponding cell barcodes (in this case the name of the cells will not be used to infer the barcode id). This is the option we use in this tutorial as well.

In [None]:
females_late = pd.read_csv("/nfs/team292/vl6/FetalReproductiveTract/ATAC_QC/ArchR/females_late/umap_coords.csv", index_col = 0)
females_late.head()

In [None]:
cell_data = females_late.copy()
cell_data.shape

In [None]:
cell_data['predictedGroup_Un'].value_counts(dropna = False)

In [None]:
# Keep only the four mesenchymal populations analysed in this notebook.
# The explicit .copy() makes the result an independent DataFrame, so adding
# columns to it later does not raise pandas' SettingWithCopyWarning.
mese_labels = ['Fallopian Mese', 'Uterus Mese', 'Cervix Mese', 'Upper Vagina Mese']
cell_data = cell_data[cell_data['predictedGroup_Un'].isin(mese_labels)].copy()

In [None]:
import numpy as np

In [None]:
cell_data.shape

In [None]:
cell_data.tail()

In [None]:
import seaborn as sns 
import matplotlib.pyplot as plt

# Boxplot of 'predictedScore_Un' for each 'predictedGroup_Un' label
sns.set(rc={'figure.figsize': (7, 5)}, font_scale=1)
sns.set_style("whitegrid")
ax = sns.boxplot(
    data=cell_data,
    x='predictedGroup_Un',
    y='predictedScore_Un',
    hue='predictedGroup_Un',
    orient='v',
    width=0.8,
    dodge=True,
    fliersize=2,
)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
ax.set_ylabel('predictedScore_Un')
ax.set_xlabel('predictedGroup_Un')
ax.grid(False)
# Dashed guide at 0.5, the score cut-off applied to cell_data further down
ax.axhline(y=0.5, color='gray', linestyle='--')
# Legend is placed outside the axes, to the right
plt.legend(title='predictedGroup', bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.1)
fig = plt.gcf()
plt.show()
plt.clf()
plt.close()

In [None]:
# Harmonised cluster names are the predicted labels with the spaces removed
mapping_dict = {
    predicted_label: predicted_label.replace(' ', '')
    for predicted_label in ('Fallopian Mese', 'Uterus Mese', 'Cervix Mese', 'Upper Vagina Mese')
}
harmonised_labels = cell_data['predictedGroup_Un'].map(mapping_dict)
cell_data['HarmonisedClusters'] = harmonised_labels

In [None]:
cell_data.shape

In [None]:
cell_data[['Sample']].value_counts()

In [None]:
# Plot colour for each harmonised cluster label
color_palette = dict(
    FallopianMese='orange',
    UterusMese='orangered',
    CervixMese='palevioletred',
    UpperVaginaMese='lightpink',
)

In [None]:
cell_data[['Sample', 'HarmonisedClusters']].value_counts()

In [None]:
import seaborn as sns 
import matplotlib.pyplot as plt

# Boxplot of 'predictedScore_Un' per sample, split by harmonised cluster
sns.set(rc={'figure.figsize':(8, 5)}, font_scale=1)
sns.set_style("whitegrid")
ax = sns.boxplot(x = 'Sample', y = 'predictedScore_Un', hue = 'HarmonisedClusters', data = cell_data, width = 0.8, palette = color_palette, orient = 'v', dodge = True, fliersize = 2)
ax.set_xticklabels(ax.get_xticklabels(),rotation = 90)
ax.set_ylabel('predictedScore')
ax.set_xlabel('sample')
ax.grid(False)
# Dashed guide at 0.5, the score cut-off applied to cell_data in the next step
ax.axhline(y=0.5, color = 'gray', linestyle =  '--')
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.1, title = 'HarmonisedClusters')
fig = plt.gcf()
# Save BEFORE show/clf/close: plt.clf() wipes the current figure, so saving
# afterwards only works when the backend has already detached the figure
# (as the notebook inline backend does); elsewhere the PDF would be empty.
fig.savefig('./boxplot_differentiatingmullerianmesenchyme.pdf', bbox_inches = 'tight')
plt.show()
plt.clf()
plt.close()

In [None]:
cell_data = cell_data[cell_data['predictedScore_Un'] >= 0.5]

In [None]:
cell_data[['HarmonisedClusters']].value_counts()

In [None]:
cell_data[['donor']].value_counts()

In [None]:
cell_data[['stage']].value_counts()

In [None]:
import random
from itertools import chain
def downsample(df, labels, n, seed=None):
    """Cap every group of ``df[labels]`` at ``n`` rows by random subsampling.

    Groups with at most ``n`` rows are kept whole. For each larger group,
    ``n`` distinct index labels are drawn at random and only rows carrying
    those labels are kept. Rows whose label is missing (NaN) are kept.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose index identifies the rows (here: cell barcodes).
    labels : str
        Name of the column holding the group label of each row.
    n : int
        Maximum number of index labels to keep per group.
    seed : int, optional
        Seed for a private random generator, for reproducible draws.
        With the default (None) the global ``random`` state is used.

    Returns
    -------
    pandas.DataFrame
        The retained rows of ``df``, in their original order.

    Raises
    ------
    ValueError
        If a group with more than ``n`` rows has fewer than ``n`` distinct
        index labels (only possible with a duplicated index).
    """
    rng = random if seed is None else random.Random(seed)

    group_sizes = df[labels].value_counts()
    # Groups with > n cells
    cl2downsample = group_sizes.index[group_sizes.values > n]

    # Index labels of rows that belong to groups small enough to keep whole
    in_small_group = ~df[labels].isin(cl2downsample)
    keep = set(df.index[in_small_group.to_numpy()])

    # Randomly draw n index labels from each oversized group
    for cl in cl2downsample:
        print(cl)
        members = df.index[(df[labels] == cl).to_numpy()]
        # random.sample() needs a sequence (sets are rejected on
        # Python >= 3.11); dict.fromkeys de-duplicates while preserving order
        # so a seeded draw is reproducible.
        keep.update(rng.sample(list(dict.fromkeys(members)), n))

    # Set-based membership keeps the final filter linear in the row count
    return df[df.index.isin(keep)]

In [None]:
cell_data_downsampled = downsample(cell_data, 'HarmonisedClusters', 1500)

In [None]:
cell_data_downsampled['donor'].value_counts()

In [None]:
cell_data_downsampled['stage'].value_counts()

In [None]:
cell_data_downsampled[['HarmonisedClusters']].value_counts()

### Try first without downsampling

In [None]:
#cell_data_downsampled = cell_data.copy()

In [None]:
cell_data_downsampled.head()

In [None]:
import numpy as np

In [None]:
np.unique(cell_data_downsampled['Sample'])

In [None]:
cell_data_downsampled.shape

In [None]:
cell_data_downsampled['HarmonisedClusters'].value_counts()

In [None]:
cell_data_downsampled['Sample'].value_counts()

In [None]:
# Drop the two excluded samples in one pass. The .copy() makes the result an
# independent DataFrame, so the columns added to it afterwards do not raise
# pandas' SettingWithCopyWarning.
excluded_samples = ['HD_F_GON12449010', 'HD_F_GON12877982']
cell_data_downsampled = cell_data_downsampled[~cell_data_downsampled['Sample'].isin(excluded_samples)].copy()

In [None]:
cell_data_downsampled['Sample'].value_counts()

In [None]:
cell_data_downsampled['HarmonisedClusters'].value_counts(dropna = False)

In [None]:
cell_data_downsampled['barcode'] = [x.split('#')[1] for x in cell_data_downsampled.index.tolist()]

In [None]:
cell_data_downsampled['index'] = cell_data_downsampled['barcode'] + '___' + cell_data_downsampled['Sample'].astype(str)

In [None]:
cell_data_downsampled = cell_data_downsampled.set_index('index')

In [None]:
cell_data_downsampled.head()

### Generate pseudobulk files per cell type

Now we have all the ingredients we need to generate the pseudobulk files. With this function we will generate fragments files per group and the corresponding bigwigs. The mandatory input to this function are: 
 * The annotation dataframe (input_data) 
 * The variable used to group the cells (here `HarmonisedClusters`)
 * The chromosome sizes 
 * The paths to where the bed and bigwig files will be written
 * A dictionary indicating the fragments file corresponding to each sample. The sample ids used as keys in this dictionary must match the sample ids in the annotation data frame!

The output will be two dictionaries containing the paths to the bed and bigwig files, respectively, to each group.

In [None]:
np.unique(cell_data_downsampled['Sample'])

In [None]:
## Path to fragments files of samples
# Each sample has its own folder holding a 'fragments.tsv.gz'; scATAC samples
# live under ATAC_QC/data, the two multiome samples under MULTIOME_QC/data.
atac_data_dir = '/nfs/team292/vl6/FetalReproductiveTract/ATAC_QC/data/'
multiome_data_dir = '/nfs/team292/vl6/FetalReproductiveTract/MULTIOME_QC/data/'

atac_samples = [
    'HD_F_GON11282675',
    'HD_F_GON11389960',
    'HD_F_GON11389961',
    'HD_F_GON12449011',
    'HD_F_GON11282676',
    'HD_F_GON12877983',
    'HD_F_GON12877984',
    'HD_F_GON14609874',
    'HD_F_GON14666992',
    'HD_F_GON13941947',
    'HD_F_GON13941946',
]
multiome_samples = [
    'HCA_F_GON11173192_and_HCA_F_GON11212447',  # 12 PCW (Hrv103)
    'HD_F_GON13077785_and_HD_F_GON13094224',
]

fragments_dict = {}
for sample_name in atac_samples:
    fragments_dict[sample_name] = atac_data_dir + sample_name + '/fragments.tsv.gz'
for sample_name in multiome_samples:
    fragments_dict[sample_name] = multiome_data_dir + sample_name + '/fragments.tsv.gz'

In [None]:
outDir = '/lustre/scratch126/cellgen/team292/vl6/pycistopic/mullerian_mese_withvagina_post9pcw/'
tmpDir = '/lustre/scratch126/cellgen/team292/vl6/pycistopic/temp/'

In [None]:
from pycisTopic.pseudobulk_peak_calling import *
import os
# Output folders for the per-group fragment (bed) and coverage (bigwig) files.
# Create them up front (no-op if they already exist) so the export does not
# fail on a fresh outDir.
pseudobulk_bed_dir = outDir + 'consensus_peak_calling/pseudobulk_bed_files/'
pseudobulk_bw_dir = outDir + 'consensus_peak_calling/pseudobulk_bw_files/'
for pseudobulk_dir in (pseudobulk_bed_dir, pseudobulk_bw_dir):
    os.makedirs(pseudobulk_dir, exist_ok=True)

# Cells are grouped by 'HarmonisedClusters'; 'Sample' selects the fragments
# file in fragments_dict, and '___' separates barcode from sample in the index.
bw_paths, bed_paths = export_pseudobulk(input_data = cell_data_downsampled,
                 variable = 'HarmonisedClusters',
                 sample_id_col = 'Sample',
                 chromsizes = chromsizes,
                 bed_path = pseudobulk_bed_dir,
                 bigwig_path = pseudobulk_bw_dir,
                 path_to_fragments = fragments_dict,
                 n_cpu = 1,
                 normalize_bigwig = True,
                 remove_duplicates = True,
                 #_temp_dir = tmpDir + 'ray_spill',
                 split_pattern = '___')

In [None]:
# Save
import pickle
# Both path dictionaries are pickled next to the pseudobulk bed files
pickle_targets = {
    outDir + 'consensus_peak_calling/pseudobulk_bed_files/bed_paths.pkl': bed_paths,
    outDir + 'consensus_peak_calling/pseudobulk_bed_files/bw_paths.pkl': bw_paths,
}
for pkl_path, path_dict in pickle_targets.items():
    with open(pkl_path, 'wb') as f:
        pickle.dump(path_dict, f)

### Calling peaks with MACS2

In [None]:
from pycisTopic.pseudobulk_peak_calling import *
# Location of the MACS2 executable and the folder for its output
macs_path='/opt/conda/envs/scenicplus/bin/macs2'
macs_outdir = outDir + 'consensus_peak_calling/MACS/'
# Idempotent: creates the folder on a first run, does nothing on a re-run
os.makedirs(macs_outdir, exist_ok=True)

In [None]:
#sys.stderr = open(os.devnull, "w")  # silence stderr

In [None]:
#ray.shutdown()

In [None]:
# Run peak calling
# Runs pycisTopic's peak_calling with the MACS2 executable at `macs_path` on
# the pseudobulk files listed in `bed_paths`; output goes to `macs_outdir`.
# The returned dictionary is pickled in the following cell.
narrow_peaks_dict = peak_calling(macs_path,
                                 bed_paths,
                                 macs_outdir,
                                 genome_size='hs',
                                 n_cpu=1,
                                 input_format='BEDPE',
                                 shift=73,
                                 ext_size=146,
                                 keep_dup = 'all',
                                 q_value = 0.05,
                                 #_temp_dir = tmpDir + 'ray_spill'
                                )
# NOTE(review): the matching "silence stderr" line is commented out in an
# earlier cell, so this statement only rebinds sys.stderr to the interpreter's
# original stream. Inside a notebook kernel that may not be the stream the
# notebook displays - confirm this reset is still wanted.
sys.stderr = sys.__stderr__  # unsilence stderr

In [None]:
# Save
import pickle
with open(outDir + 'consensus_peak_calling/MACS/narrow_peaks_dict.pkl', 'wb') as f:
  pickle.dump(narrow_peaks_dict, f)

### Deriving consensus peaks with iterative overlapping

Finally, it is time to derive the consensus peaks. To do so, we use the TCGA iterative peak filtering approach. First, each summit is extended by `peak_half_width` in each direction, and then we iteratively filter out less significant peaks that overlap with a more significant one. During this procedure peaks will be merged and, depending on the number of peaks included in them, different processes will happen:

 * 1 peak: The original peak region will be kept

 * 2 peaks: The original peak region with the highest score will be kept

 * 3 or more peaks: The original peak region with the most significant score will be taken, and all the original peak regions in this merged peak region that overlap with the significant peak region will be removed. The process is repeated with the next most significant peak (if it was not removed already) until all peaks are processed.

This process will happen twice, first in each pseudobulk peaks; and after peak score normalization, to process all peaks together.

In [None]:
path_to_blacklist = '/nfs/team292/vl6/scenicplus/pycisTopic/blacklist/hg38-blacklist.v2.bed'

In [None]:
from pycisTopic.iterative_peak_calling import *
import contextlib
# Other param
peak_half_width = 250
# Get consensus peaks
# stderr is silenced only for the duration of the call. The context managers
# close the devnull handle and put back whatever sys.stderr was before (the
# notebook's own stream included), even if get_consensus_peaks raises.
with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull):
    consensus_peaks = get_consensus_peaks(narrow_peaks_dict, peak_half_width, chromsizes=chromsizes, path_to_blacklist=path_to_blacklist)

In [None]:
# Write to bed
consensus_peaks.to_bed(path= outDir + 'consensus_peak_calling/consensus_regions.bed', keep=True, compression='infer', chain=False)

### Quality control 

The next step is to perform QC on the scATAC-seq samples (in this case, one fragments file per sample in `fragments_dict`). Several measurements and visualizations are produced in this step:

 * Barcode rank plot

 * Duplication rate

 * Insertion size

 * TSS enrichment

 * Fraction of Reads In Peaks (FRIP)

To calculate the TSS enrichment we need to provide TSS annotations. You can easily download them via pybiomart.

In [None]:
# Get TSS annotations
import pybiomart as pbm
# Query Ensembl BioMart for one row per transcript: chromosome, TSS position,
# strand, gene name and transcript biotype.
dataset = pbm.Dataset(name='hsapiens_gene_ensembl',  host='http://www.ensembl.org')
annot = dataset.query(attributes=['chromosome_name', 'transcription_start_site', 'strand', 'external_gene_name', 'transcript_biotype'])
annot['Chromosome/scaffold name'] = annot['Chromosome/scaffold name'].to_numpy(dtype = str)
# Drop rows whose sequence name contains CHR, GL, JH or MT.
# (Named `non_standard` rather than `filter` so the builtin is not shadowed.)
non_standard = annot['Chromosome/scaffold name'].str.contains('CHR|GL|JH|MT')
# .copy() so the column assignment below acts on an independent frame
annot = annot[~non_standard].copy()
# Prefix names with 'chr'. regex=True is required: pandas >= 2.0 treats the
# pattern as a literal string by default, which would leave names unchanged.
annot['Chromosome/scaffold name'] = annot['Chromosome/scaffold name'].str.replace(r'(\b\S)', r'chr\1', regex=True)
annot.columns=['Chromosome', 'Start', 'Strand', 'Gene', 'Transcript_type']
annot = annot[annot.Transcript_type == 'protein_coding']

In [None]:
annot.tail()

In [None]:
#ray.shutdown()

In [None]:
fragments_dict

In [None]:
from pycisTopic.qc import *
## Set regions. We will use the consensus peaks we have just called, but we could also use the bulk peaks per sample instead for this step
# Every sample uses the same consensus-region file. The mapping is derived
# from fragments_dict so the two dictionaries always have identical keys.
consensus_regions_bed = outDir + 'consensus_peak_calling/consensus_regions.bed'
path_to_regions = {sample: consensus_regions_bed for sample in fragments_dict}

# Returns per-barcode metrics (metadata_bc) and per-sample profile data
# (profile_data_dict); both are pickled in the following cell.
metadata_bc, profile_data_dict = compute_qc_stats(fragments_dict = fragments_dict,
                tss_annotation = annot,
                stats=['barcode_rank_plot', 'duplicate_rate', 'insert_size_distribution', 'profile_tss', 'frip'],
                label_list = None,
                path_to_regions = path_to_regions,
                n_cpu = 1,
                valid_bc = None,
                n_frag = 100,
                n_bc = None,
                tss_flank_window = 1000,
                tss_window = 50,
                tss_minimum_signal_window = 100,
                tss_rolling_window = 10,
                remove_duplicates = True,
                #_temp_dir = '/nfs/team292/vl6/symtopic/'
                )

In [None]:
import os
import pickle
# Idempotent: creates the folder on a first run, does nothing on a re-run
os.makedirs(outDir + 'quality_control', exist_ok=True)

# Per-barcode QC metrics
with open(outDir + 'quality_control/metadata_bc.pkl', 'wb') as f:
  pickle.dump(metadata_bc, f)

# Per-sample QC profiles
with open(outDir + 'quality_control/profile_data_dict.pkl', 'wb') as f:
  pickle.dump(profile_data_dict, f)

### Sample-level statistics

Once the QC metrics have been computed you can visualize the results at the sample level and the barcode level. Sample-level statistics can be used to assess the overall quality of the sample, while barcode-level statistics can be used to differentiate good-quality cells from the rest. The sample-level graphs include:

 * **Barcode rank plot**: The barcode rank plot shows the distribution of non-duplicate reads and which barcodes were inferred to be associated with cells. A steep drop-off (‘knee’) is indicative of good separation between the cell-associated barcodes and the barcodes associated with empty partitions.

 * **Insertion size**: ATAC-seq requires a proper pair of Tn5 transposase cutting events at the ends of DNA. In the nucleosome-free open chromatin regions, many molecules of Tn5 can kick in and chop the DNA into small pieces; around nucleosome-occupied regions, and Tn5 can only access the linker regions. Therefore, in a good ATAC-seq library, you should expect to see a sharp peak at the <100 bp region (open chromatin), and a peak at ~200bp region (mono-nucleosome), and other larger peaks (multi-nucleosomes). A clear nucleosome pattern indicates a good quality of the experiment.

 * **Sample TSS enrichment**: The TSS enrichment calculation is a signal to noise calculation. The reads around a reference set of TSSs are collected to form an aggregate distribution of reads centered on the TSSs and extending to 1000 bp in either direction (for a total of 2000bp). This distribution is then normalized by taking the average read depth in the 100 bps at each of the end flanks of the distribution (for a total of 200bp of averaged data) and calculating a fold change at each position over that average read depth. This means that the flanks should start at 1, and if there is high read signal at transcription start sites (highly open regions of the genome) there should be an increase in signal up to a peak in the middle.

 * **FRIP distribution**: Fraction of all mapped reads that fall into the called peak regions, i.e. usable reads in significantly enriched peaks divided by all usable reads. A low FRIP indicates that many reads form part of the background, and so that the data is noisy.

 * **Duplication rate**: A fragment is considered “usable” if it uniquely maps to the genome and remains after removing PCR duplicates (defined as two fragments that map to the same genomic position and have the same unique molecular identifier). The duplication rate serves to estimate the amount of usable reads per barcode. High duplication rates may indicate over-sequencing or lack of fragments after transposition and encapsulation. We recommend using duplicate_rate_as_hexbin = True when working with big fragments files.

In [None]:
# Load sample metrics
import pickle
# The context manager closes the file even if unpickling fails
with open(outDir + 'quality_control/profile_data_dict.pkl', 'rb') as infile:
    profile_data_dict = pickle.load(infile)

In [None]:
from pycisTopic.qc import *
plot_sample_metrics(profile_data_dict,
           insert_size_distribution_xlim=[0,600],
           ncol=2,
           plot=True,
           save= outDir + 'quality_control/sample_metrics.pdf',
           duplicate_rate_as_hexbin = True)

### Barcode level statistics 

Barcode-level statistics can be used to select high quality cells. Typical measurements that can be used are:

 * **Total number of (unique) fragments**

 * **TSS enrichment**: The TSS enrichment score for each barcode, i.e. the value of the enrichment profile at position 0 (the TSS). Noisy cells will have a low TSS enrichment.

 * **FRIP**: The fraction of reads in peaks for each barcode. Noisy cells have low FRIP values. However, this filter should be used with nuance, as it depends on the quality of the original peaks. For example, if there is a rare population in the sample, its specific peaks may be missed by peak calling algorithms, causing a decrease in their FRIP values.

In [None]:
# Load barcode metrics
import pickle
# The context manager closes the file even if unpickling fails
with open(outDir + 'quality_control/metadata_bc.pkl', 'rb') as infile:
    metadata_bc = pickle.load(infile)

In [None]:
# Return figure to plot together with other metrics, and cells passing filters. Figure will be saved as pdf.
from pycisTopic.qc import *
# One figure per sample and metric; the FRIP and TSS plots also return the
# barcodes that pass the thresholds drawn on them.
FRIP_NR_FRAG_fig, FRIP_NR_FRAG_filter = {}, {}
TSS_NR_FRAG_fig, TSS_NR_FRAG_filter = {}, {}
DR_NR_FRAG_fig = {}

# Settings common to the three plots: x axis is the log number of unique
# fragments with a lower bound of 3, and figures are returned, not displayed.
shared_plot_kwargs = dict(
    var_x='Log_unique_nr_frag',
    min_x=3,
    max_x=None,
    max_y=None,
    return_fig=True,
    plot=False,
)

for sample, sample_metrics in metadata_bc.items():
    # FRIP vs number of fragments (FRIP >= 0.4); figure saved as pdf
    frip_fig, frip_cells = plot_barcode_metrics(
        sample_metrics,
        var_y='FRIP',
        min_y=0.4,
        return_cells=True,
        save= outDir + 'quality_control/barcode_metrics_FRIP-VS-NRFRAG_'+sample+'.pdf',
        **shared_plot_kwargs)
    FRIP_NR_FRAG_fig[sample] = frip_fig
    FRIP_NR_FRAG_filter[sample] = frip_cells

    # TSS enrichment vs number of fragments (TSS enrichment >= 4); figure saved as pdf
    tss_fig, tss_cells = plot_barcode_metrics(
        sample_metrics,
        var_y='TSS_enrichment',
        min_y=4,
        return_cells=True,
        save= outDir + 'quality_control/barcode_metrics_TSS-VS-NRFRAG_'+sample+'.pdf',
        **shared_plot_kwargs)
    TSS_NR_FRAG_fig[sample] = tss_fig
    TSS_NR_FRAG_filter[sample] = tss_cells

    # Duplication rate vs number of fragments: figure only, no cells returned
    # (no filter applied for the duplication rate per barcode), not saved
    DR_NR_FRAG_fig[sample] = plot_barcode_metrics(
        sample_metrics,
        var_y='Dupl_rate',
        min_y=None,
        return_cells=False,
        plot_as_hexbin = True,
        **shared_plot_kwargs)

In [None]:
# # Plot barcode stats in one figure
# fig=plt.figure(figsize=(40, 100))
# i=1
# for sample in FRIP_NR_FRAG_fig.keys():
#     plt.subplot(9, 3, i)
#     plt.gca().set_title(sample, fontsize=30)
#     i += 1
#     img = fig2img(FRIP_NR_FRAG_fig[sample]) #To convert figures to png to plot together, see .utils.py. This converts the figure to png.
#     plt.imshow(img)
#     plt.axis('off')
#     plt.subplot(10, 3, i)
#     plt.gca().set_title(sample, fontsize=30)
#     i += 1
#     img = fig2img(TSS_NR_FRAG_fig[sample])
#     plt.imshow(img)
#     plt.axis('off')
#     plt.subplot(10, 3, i)
#     plt.gca().set_title(sample, fontsize=30)
#     i += 1
#     img = fig2img(DR_NR_FRAG_fig[sample])
#     plt.imshow(img)
#     plt.axis('off')
# plt.savefig(outDir + 'quality_control/combined_qc.pdf')

In [None]:
cell_data_downsampled.head()

In [None]:
# Per sample, the unique barcodes present in the annotated cell table
sel_cells_dict = {}
sample_of_cell = cell_data_downsampled['Sample']
for sample in np.unique(sample_of_cell):
    sample_barcodes = cell_data_downsampled[sample_of_cell == sample]['barcode']
    sel_cells_dict[sample] = list(set(sample_barcodes))
    n_selected = len(sel_cells_dict[sample])
    print(f"{n_selected} barcodes passed filters for sample {sample}")

In [None]:
4+3

In [None]:
import pickle
with open(outDir +'/quality_control/bc_passing_filters.pkl', 'wb') as f:
  pickle.dump(sel_cells_dict, f)

### Create cisTopic object

In this step a fragment count matrix will be generated, recording the number of fragments in each region for each barcode. For multiple samples, you can add additional entries in `fragments_dict`, and a cisTopic object will be generated per sample. As regions, we will use the consensus peaks derived from the scRNA-seq annotations. This cisTopic object will contain:

 * **Path/s to fragment file/s (if generated from fragments files)**

 * **Fragment count matrix and binary accessibility matrix**

 * **Cell and region metadata**

In [None]:
import pickle
# Metrics
with open(outDir + 'quality_control/metadata_bc.pkl', 'rb') as infile:
    metadata_bc = pickle.load(infile)
# Valid barcodes
with open(outDir +'/quality_control/bc_passing_filters.pkl', 'rb') as infile:
    bc_passing_filters = pickle.load(infile)

In [None]:
# Path to regions
path_to_regions = outDir + 'consensus_peak_calling/consensus_regions.bed'
path_to_blacklist = '/nfs/team292/vl6/scenicplus/pycisTopic/blacklist/hg38-blacklist.v2.bed'

In [None]:
#Create objects
from pycisTopic.cistopic_class import *
# Build one cisTopic object per sample. Samples that have a fragments file
# but no selected barcodes (none of their cells are left in the annotated
# table) are skipped with a message instead of failing with a KeyError.
cistopic_obj_list = []
for key, sample_fragments in fragments_dict.items():
    if key not in bc_passing_filters:
        print(f"Skipping sample {key}: no selected barcodes")
        continue
    cistopic_obj_list.append(
        create_cistopic_object_from_fragments(path_to_fragments=sample_fragments,
                                               path_to_regions=path_to_regions,
                                               path_to_blacklist=path_to_blacklist,
                                               metrics=metadata_bc[key],
                                               valid_bc=bc_passing_filters[key],
                                               n_cpu=1,
                                               project=key))

In [None]:
cistopic_obj = merge(cistopic_obj_list)

In [None]:
print(cistopic_obj)

In [None]:
# Save
with open(outDir + 'cisTopicObject.pkl', 'wb') as f:
  pickle.dump(cistopic_obj, f)

In [None]:
# Load cisTopic object
import pickle
# The context manager closes the file even if unpickling fails
with open(outDir + 'cisTopicObject.pkl', 'rb') as infile:
    cistopic_obj = pickle.load(infile)

In [None]:
cistopic_obj.add_cell_data(cell_data_downsampled)

In [None]:
print(cistopic_obj)

In [None]:
cistopic_obj.cell_data['Sample'].value_counts(dropna = False)

In [None]:
cistopic_obj.cell_data.HarmonisedClusters = cistopic_obj.cell_data.HarmonisedClusters.astype(str)

In [None]:
high_quality = cistopic_obj.cell_data[cistopic_obj.cell_data.HarmonisedClusters != 'nan'].index.tolist()
cistopic_obj = cistopic_obj.subset(high_quality, copy=True)

In [None]:
cistopic_obj.cell_data['HarmonisedClusters'].value_counts(dropna = False)

In [None]:
# Save
with open(outDir + 'cisTopicObject.pkl', 'wb') as f:
  pickle.dump(cistopic_obj, f)

#### Since the early sample is male and the older are only female, exclude Y chromosome regions

In [None]:
cistopic_obj.region_data

In [None]:
ychrom = cistopic_obj.region_data[cistopic_obj.region_data['Chromosome'] == 'chrY'].index.to_list()

In [None]:
len(ychrom)

In [None]:
# All regions that are not on chrY. A set gives O(1) membership tests, so the
# scan over all regions stays linear.
ychrom_lookup = set(ychrom)
nonychrom = [i for i in cistopic_obj.region_data.index.to_list() if i not in ychrom_lookup]
len(nonychrom)

In [None]:
cistopic_obj = cistopic_obj.subset(regions = nonychrom, copy = True)

In [None]:
cistopic_obj.region_data

In [None]:
# Save
with open(outDir + 'cisTopicObject.pkl', 'wb') as f:
  pickle.dump(cistopic_obj, f)

### Run LDA models 

The next step is to run the LDA models. There are two types of LDA models (with Collapsed Gibbs Sampling) you can run:

 * **Serial LDA**: The parallelization is done between models rather than within each model. Recommended for small-to-medium-sized data sets in which several models with different numbers of topics are being tested. You can run these models with `run_cgs_models()`.

 * **Parallel LDA with MALLET**: The parallelization is done within each model. Recommended for large data sets where a few models with different numbers of topics are being tested. If working on a cluster, we recommend submitting one job per model so they can run simultaneously. You can run it with `run_cgs_models_mallet()`.

In [None]:
# Load cisTopic object
import pickle
# The context manager closes the file even if unpickling fails
with open(outDir + 'cisTopicObject.pkl', 'rb') as infile:
    cistopic_obj = pickle.load(infile)

In [None]:
cistopic_obj.cell_data.head()

In [None]:
outDir

In [None]:
from pycisTopic.cistopic_class import *
import os
import pickle
# Configure path Mallet
path_to_mallet_binary='/nfs/team292/vl6/scenicplus/Mallet/bin/mallet'
os.environ['MALLET_MEMORY'] = '300G'
# Make sure the destination folder exists BEFORE the long model run, so the
# result is not lost to a missing directory when it is written at the end.
os.makedirs(outDir + 'models/', exist_ok=True)
# Scratch folder used both for Mallet's temporary files and for its saved output
mallet_scratch_dir = '/lustre/scratch126/cellgen/team292/vl6/pycistopic/temp/'
# Run models
models=run_cgs_models_mallet(path_to_mallet_binary,
                    cistopic_obj,
                    n_topics=list(range(2, 51, 2)),  # 2, 4, ..., 50
                    n_cpu=24,
                    n_iter=150,
                    random_state=555,
                    alpha=50,
                    alpha_by_topic=True,
                    eta=0.1,
                    eta_by_topic=False,
                    tmp_path=mallet_scratch_dir, #Use SCRATCH if many models or big data set
                    save_path=mallet_scratch_dir)

# Save
with open(outDir + 'models/mallet.pkl', 'wb') as f:
  pickle.dump(models, f)

In [None]:
# Save
#with open(outDir + 'models/mallet.pkl', 'wb') as f:
#  pickle.dump(models, f)

### Model selection 

There are several methods that can be used for model selection:

 * **Minmo_2011**: Uses the average model coherence as calculated by Mimno et al (2011). In order to reduce the impact of the number of topics, we calculate the average coherence based on the top selected average values. The better the model, the higher coherence.

 * **Log-likelihood**: Uses the log-likelihood in the last iteration as calculated by Griffiths and Steyvers (2004). The better the model, the higher the log-likelihood.

 * **Arun_2010**: Uses a density-based metric as in Arun et al (2010) using the topic-region distribution, the cell-topic distribution and the cell coverage. The better the model, the lower the metric.

 * **Cao_Juan_2009**: Uses a divergence-based metric as in Cao Juan et al (2009) using the topic-region distribution. The better the model, the lower the metric.

For scATAC-seq data models, the most helpful methods are Minmo (topic coherence) and the log-likelihood in the last iteration.

In [None]:
outDir

In [None]:
import pickle
# Load cisTopic object
with open(outDir + 'cisTopicObject.pkl', 'rb') as infile:
    cistopic_obj = pickle.load(infile)
# Load models
with open(outDir + 'models/mallet.pkl', 'rb') as infile:
    models = pickle.load(infile)

In [None]:
numTopics = 24
from pycisTopic.lda_models import *
model=evaluate_models(models,
                     select_model=numTopics,
                     return_model=True,
                     metrics=['Arun_2010','Cao_Juan_2009', 'Minmo_2011', 'loglikelihood'],
                     plot_metrics=False,
                     save= outDir + 'models/model_selection.pdf')

In [None]:
# Add model to cisTopicObject
cistopic_obj.add_LDA_model(model)

In [None]:
# Save
with open(outDir + 'cisTopicObject.pkl', 'wb') as f:
  pickle.dump(cistopic_obj, f)

In [None]:
# Load cisTopic object (now carrying the selected LDA model)
import pickle

# The context manager closes the file even if unpickling raises
with open(outDir + 'cisTopicObject.pkl', 'rb') as infile:
    cistopic_obj = pickle.load(infile)

In [None]:
print(cistopic_obj)

In [None]:
cistopic_obj.fragment_matrix.shape

In [None]:
cistopic_obj.cell_data.shape

In [None]:
cistopic_obj.region_data.shape

In [None]:
from pycisTopic.clust_vis import *
run_umap(cistopic_obj,
                 target  = 'cell', scale=False)
run_tsne(cistopic_obj,
                 target  = 'cell', scale=False)

In [None]:
from pycisTopic.clust_vis import *
plot_metadata(cistopic_obj,
                 reduction_name='UMAP',
                 variables=['HarmonisedClusters', 'Sample', 'stage', 'predictedScore'], # Labels from RNA and new clusters
                 target='cell', num_columns=2,
                 text_size=10,
                 dot_size=5,
                 figsize=(10,10),
                 save= outDir + 'visualization/umap_dimensionality_reduction_label_uncorrected.pdf')


In [None]:
from pycisTopic.clust_vis import *
plot_metadata(cistopic_obj,
                 reduction_name='tSNE',
                 variables=['HarmonisedClusters', 'Sample', 'stage', 'predictedScore'], # Labels from RNA and new clusters
                 target='cell', num_columns=2,
                 text_size=10,
                 dot_size=5,
                 figsize=(10,10),
                 save= outDir + 'visualization/tsne_dimensionality_reduction_label_uncorrected.pdf')

In [None]:
cistopic_obj.cell_data.Sample.value_counts()

In [None]:
cistopic_obj.cell_data.HarmonisedClusters.value_counts()

In [None]:
color_palette

In [None]:
from pycisTopic.clust_vis import *

plot_metadata(cistopic_obj,
                 reduction_name='UMAP',
                 variables=['Sample', 'HarmonisedClusters'], # Labels from RNA and new clusters
                 target='cell', num_columns=2,
                 text_size=10,
                 dot_size=2,
                 figsize=(10,5),
               color_dictionary = {
                                   'HarmonisedClusters' : {
'FallopianMese': 'darkorange',
 'UterusMese': 'orangered',
 'CervixMese': 'palevioletred',
 'UpperVaginaMese': 'lightpink'}},
                 save= outDir + 'visualization/umap_dimensionality_reduction_label_uncorrected2.pdf')

In [None]:
plot_topic(cistopic_obj,
            reduction_name = 'UMAP',
            target = 'cell',
            num_columns=5,
            save= outDir + 'visualization/umap_dimensionality_reduction_topic_uncorrected.pdf')

In [None]:
from pycisTopic.clust_vis import *
cell_topic_heatmap(cistopic_obj,
                     variables = ['HarmonisedClusters'],
                     scale = False,
                     legend_loc_x = 1.05,
                     legend_loc_y = -1.2,
                     legend_dist_y = -1,
                     figsize=(5,10),
                   color_dict = {'HarmonisedClusters' : {
 'FallopianMese': 'darkorange',
 'UterusMese': 'orangered',
 'CervixMese': 'palevioletred',
 'UpperVaginaMese': 'lightpink'}},
                     save = outDir + 'visualization/heatmap_topic_contr.pdf')

### Harmony 

In [None]:
cistopic_obj.cell_data['donor'].value_counts(dropna = False)


In [None]:
# Harmony
harmony(cistopic_obj, 'donor', random_state=555, theta = 0)
# UMAP
run_umap(cistopic_obj, reduction_name='harmony_UMAP',
                 target  = 'cell', harmony=True)
run_tsne(cistopic_obj, reduction_name='harmony_tSNE',
                 target  = 'cell', harmony=True)

In [None]:
plot_metadata(cistopic_obj,
                 reduction_name='harmony_UMAP',
                 variables=[ 'HarmonisedClusters', 'donor', 'stage', 'predictedScore'], # Labels from RNA and new clusters
                 target='cell', num_columns=2,
                 text_size=10,
                 dot_size=5,
                 figsize=(10,10),
              color_dictionary = {
                                   'HarmonisedClusters' : {'FallopianMese': 'darkorange',
 'UterusMese': 'orangered',
 'CervixMese': 'palevioletred',
 'UpperVaginaMese': 'lightpink'} 
                                },
                 save= outDir + 'visualization/umap_dimensionality_reduction_label_corrected.pdf')

In [None]:
plot_metadata(cistopic_obj,
                 reduction_name='harmony_tSNE',
                 variables=[ 'HarmonisedClusters', 'donor', 'stage', 'predictedScore'], # Labels from RNA and new clusters
                 target='cell', num_columns=2,
                 text_size=10,
                 dot_size=5,
                 figsize=(10,10),
              color_dictionary = {
                                   'HarmonisedClusters' : {'FallopianMese': 'darkorange',
 'UterusMese': 'orangered',
 'CervixMese': 'palevioletred',
 'UpperVaginaMese': 'lightpink'}},
                 save= outDir + 'visualization/tsne_dimensionality_reduction_label_corrected.pdf')

In [None]:
plot_topic(cistopic_obj,
            reduction_name = 'harmony_tSNE',
            target = 'cell',
            num_columns=5,
            save= outDir + 'visualization/tsne_dimensionality_reduction_topic_corrected.pdf')

In [None]:
from pycisTopic.clust_vis import *
find_clusters(cistopic_obj,
                 target  = 'cell',
                  harmony = True,
                 k = 12,
                 res = [0.1, 0.3, 0.7],
                 prefix = 'pycisTopic_',
                 scale = True,
                 split_pattern = '-')

In [None]:
plot_metadata(cistopic_obj,
                 reduction_name = 'harmony_tSNE',
                 variables=['HarmonisedClusters', 'pycisTopic_leiden_12_0.1', 'pycisTopic_leiden_12_0.3', 'pycisTopic_leiden_12_0.7'], # Labels from RNA and new clusters
                 target='cell', num_columns=2,
                 text_size=10,
                 dot_size=5,
                 figsize=(10,10),
                 save= outDir + 'visualization/tsne_dimensionality_reduction_clustering.pdf')

In [None]:
color_palette

In [None]:
# Manual mapping from Leiden cluster id (k=12, resolution 0.7) to a region label.
# Several clusters can share the same label.
leiden_key = 'pycisTopic_leiden_12_0.7'
cluster_to_label = {
    '0': 'FallopianMese',
    '1': 'UpperVaginaMese',
    '2': 'CervixMese',
    '3': 'UterusMese',
    '4': 'UterusMese',
    '5': 'UpperVaginaMese',
    '6': 'FallopianMese',
    '7': 'UpperVaginaMese',
    '8': 'CervixMese',
}
annot_dict_lowres = {leiden_key: cluster_to_label}

# Direct dict indexing is deliberate: a cluster id missing from the mapping
# raises KeyError instead of silently producing an unlabelled cell.
cluster_ids = cistopic_obj.cell_data[leiden_key].tolist()
cistopic_obj.cell_data['mese_mullerian_lowres'] = [cluster_to_label[cluster_id] for cluster_id in cluster_ids]



In [None]:
plot_metadata(cistopic_obj,
                 reduction_name = 'harmony_tSNE',
                 variables=['mese_mullerian_lowres'], # Labels from RNA and new clusters
                 target='cell', num_columns=2,
                 text_size=10,
                 dot_size=5,
                 figsize=(10,5),
              color_dictionary = {
                                   'mese_mullerian_lowres' : {'FallopianMese': 'orange',
 'UterusMese': 'orangered',
 'CervixMese': 'palevioletred',
 'UpperVaginaMese': 'lightpink'}},
                 save= outDir + 'visualization/tsne_dimensionality_reduction_lowres.pdf')

In [None]:
# Save
with open(outDir + 'cisTopicObject_clean.pkl', 'wb') as f:
  pickle.dump(cistopic_obj, f)


In [None]:
# Load the annotated ("clean") cisTopic object
import pickle

# The context manager closes the file even if unpickling raises
with open(outDir + 'cisTopicObject_clean.pkl', 'rb') as infile:
    cistopic_obj = pickle.load(infile)

In [None]:
from pycisTopic.clust_vis import *

In [None]:
outDir

In [None]:
plot_metadata(cistopic_obj,
                 reduction_name = 'harmony_tSNE',
                 variables=['mese_mullerian_lowres'], # Labels from RNA and new clusters
                 target='cell', 
              num_columns=1,
                 text_size=10,
                 dot_size=2,
                 figsize=(5,5),
              show_label = False, 
              show_legend = False,
              color_dictionary = {
                                   'mese_mullerian_lowres' : {'FallopianMese': 'orange',
 'UterusMese': 'orangered',
 'CervixMese': 'palevioletred',
 'UpperVaginaMese': 'lightpink'}},
                 save= outDir + 'visualization/tsne_dimensionality_reduction_lowres.pdf')

In [None]:
plot_metadata(cistopic_obj,
                 reduction_name = 'harmony_tSNE',
                 variables=['stage'], # Labels from RNA and new clusters
                 target='cell', 
              num_columns=1,
                 text_size=10,
                 dot_size=2,
                 figsize=(6,5),
              show_label = False, 
              show_legend = False,
                 save= outDir + 'visualization/tsne_dimensionality_reduction_lowres_stage.pdf')

In [None]:
plot_metadata(cistopic_obj,
                 reduction_name = 'harmony_tSNE',
                 variables=['donor'], # Labels from RNA and new clusters
                 target='cell', 
              num_columns=1,
                 text_size=10,
                 dot_size=2,
                 figsize=(6,5),
              show_label = False, 
              show_legend = True,
                 save= outDir + 'visualization/tsne_dimensionality_reduction_lowres_donor.pdf')

In [None]:
# os.mkdir(outDir+'topic_binarization')
from pycisTopic.topic_binarization import *
region_bin_topics = binarize_topics(cistopic_obj, method='otsu', ntop=3000, plot=True, num_columns=5, save= outDir + 'topic_binarization/otsu.pdf')


In [None]:
binarized_cell_topic = binarize_topics(cistopic_obj, target='cell', method='li', plot=True, num_columns=5, nbins=100)


In [None]:
from pycisTopic.topic_qc import *
topic_qc_metrics = compute_topic_metrics(cistopic_obj)

In [None]:
fig_dict={}
fig_dict['CoherenceVSAssignments']=plot_topic_qc(topic_qc_metrics, var_x='Coherence', var_y='Log10_Assignments', var_color='Gini_index', plot=False, return_fig=True)
fig_dict['AssignmentsVSCells_in_bin']=plot_topic_qc(topic_qc_metrics, var_x='Log10_Assignments', var_y='Cells_in_binarized_topic', var_color='Gini_index', plot=False, return_fig=True)
fig_dict['CoherenceVSCells_in_bin']=plot_topic_qc(topic_qc_metrics, var_x='Coherence', var_y='Cells_in_binarized_topic', var_color='Gini_index', plot=False, return_fig=True)
fig_dict['CoherenceVSRegions_in_bin']=plot_topic_qc(topic_qc_metrics, var_x='Coherence', var_y='Regions_in_binarized_topic', var_color='Gini_index', plot=False, return_fig=True)
fig_dict['CoherenceVSMarginal_dist']=plot_topic_qc(topic_qc_metrics, var_x='Coherence', var_y='Marginal_topic_dist', var_color='Gini_index', plot=False, return_fig=True)
fig_dict['CoherenceVSGini_index']=plot_topic_qc(topic_qc_metrics, var_x='Coherence', var_y='Gini_index', var_color='Gini_index', plot=False, return_fig=True)


In [None]:
# Plot topic stats in one figure: each QC figure is rasterised and placed in
# one cell of a 2 x 3 grid.
fig=plt.figure(figsize=(40, 43))
# enumerate(start=1) supplies the 1-based subplot index without a manual counter
for i, qc_fig in enumerate(fig_dict.values(), start=1):
    plt.subplot(2, 3, i)
    img = fig2img(qc_fig) #To convert figures to png to plot together, see .utils.py. This converts the figure to png.
    plt.imshow(img)
    plt.axis('off')
# Negative hspace pulls the two rows together to remove the blank band between them
plt.subplots_adjust(wspace=0, hspace=-0.70)
fig.savefig(outDir + 'topic_binarization/Topic_qc.pdf', bbox_inches='tight')
plt.show()

In [None]:
topic_annot = topic_annotation(cistopic_obj, annot_var='mese_mullerian_lowres', binarized_cell_topic=binarized_cell_topic, general_topic_thr = 0.2)
topic_qc_metrics = pd.concat([topic_annot[['mese_mullerian_lowres', 'Ratio_cells_in_topic', 'Ratio_group_in_population']], topic_qc_metrics], axis=1)
topic_qc_metrics.head()


In [None]:
# Save
with open(outDir + 'topic_binarization/Topic_qc_metrics_annot.pkl', 'wb') as f:
  pickle.dump(topic_qc_metrics, f)
with open(outDir + 'topic_binarization/binarized_cell_topic.pkl', 'wb') as f:
  pickle.dump(binarized_cell_topic, f)
with open(outDir + 'topic_binarization/binarized_topic_region.pkl', 'wb') as f:
  pickle.dump(region_bin_topics, f)

### Differentially Accessible Regions

In [None]:
# Load the annotated ("clean") cisTopic object
import pickle

# The context manager closes the file even if unpickling raises
with open(outDir + 'cisTopicObject_clean.pkl', 'rb') as infile:
    cistopic_obj = pickle.load(infile)

In [None]:
from pycisTopic.diff_features import *
imputed_acc_obj = impute_accessibility(cistopic_obj, selected_cells=None, selected_regions=None, scale_factor=10**6)

In [None]:
include = set(imputed_acc_obj.feature_names) & set(cistopic_obj.region_data.index.to_list())
len(include)

In [None]:
diff = set(imputed_acc_obj.feature_names) - set(cistopic_obj.region_data.index.to_list())

In [None]:
diff

In [None]:
imputed_acc_obj = imputed_acc_obj.subset(features = list(include), copy = True)

In [None]:
str(imputed_acc_obj)

In [None]:
normalized_imputed_acc_obj = normalize_scores(imputed_acc_obj, scale_factor=10**4)

In [None]:
# os.mkdir(outDir + 'DARs/')
variable_regions = find_highly_variable_features(normalized_imputed_acc_obj,
                                           min_disp = 0.05,
                                           min_mean = 0.0125,
                                           max_mean = 3,
                                           max_disp = np.inf,
                                           n_bins=20,
                                           n_top_features=None,
                                           plot=True,
                                           save= outDir + 'DARs/HVR_plot.pdf')

In [None]:
len(variable_regions)

In [None]:
markers_dict= find_diff_features(cistopic_obj,
                      imputed_acc_obj,
                      variable='mese_mullerian_lowres',
                      var_features=variable_regions,
                      contrasts=None,
                      adjpval_thr=0.05,
                      log2fc_thr=np.log2(1.5),
                      n_cpu=10)

In [None]:
# Report how many differential features were found for each group.
# A plain loop is used because print() is called only for its side effect.
for group, markers in markers_dict.items():
    print(group + ': ' + str(len(markers)))

In [None]:
# Save
with open(outDir + 'DARs/Imputed_accessibility.pkl', 'wb') as f:
  pickle.dump(imputed_acc_obj, f)
with open(outDir + 'DARs/DARs.pkl', 'wb') as f:
  pickle.dump(markers_dict, f)
with open(outDir + 'DARs/variable_regions.pkl', 'wb') as f:
  pickle.dump(variable_regions, f)

In [None]:
from pycisTopic.clust_vis import *
plot_imputed_features(cistopic_obj,
                    reduction_name='harmony_tSNE',
                    imputed_data=imputed_acc_obj,
                    features=[markers_dict[x].index.tolist()[0] for x in ['FallopianMese',
 'UterusMese',
 'CervixMese',
 'UpperVaginaMese']],
                    scale=False,
                    num_columns=3,
                    save= outDir + 'DARs/example_best_DARs.pdf')

### Gene activity scores

In [None]:
# Load the inputs for the gene activity section.
# Context managers close each file even if unpickling raises.
import pickle

# Load cisTopic object
with open(outDir + 'cisTopicObject_clean.pkl', 'rb') as infile:
    cistopic_obj = pickle.load(infile)
# Load imputed accessibility
with open(outDir + 'DARs/Imputed_accessibility.pkl', 'rb') as infile:
    imputed_acc_obj = pickle.load(infile)
# Load DARs
with open(outDir + 'DARs/DARs.pkl', 'rb') as infile:
    DARs_dict = pickle.load(infile)

In [None]:
str(imputed_acc_obj)

In [None]:
# Get TSS annotations from Ensembl BioMart and convert them to a PyRanges
# object restricted to protein-coding transcripts.
import pybiomart as pbm
import pyranges as pr
# For mouse
#dataset = pbm.Dataset(name='mmusculus_gene_ensembl',  host='http://www.ensembl.org')
# For human (hg38)
dataset = pbm.Dataset(name='hsapiens_gene_ensembl',  host='http://www.ensembl.org')
# For human (hg19)
#dataset = pbm.Dataset(name='hsapiens_gene_ensembl',  host='http://grch37.ensembl.org/')
# For fly
#dataset = pbm.Dataset(name='dmelanogaster_gene_ensembl',  host='http://www.ensembl.org')
annot = dataset.query(attributes=['chromosome_name', 'start_position', 'end_position', 'strand', 'external_gene_name', 'transcription_start_site', 'transcript_biotype'])
# Prefix the Ensembl chromosome names with 'chr'
annot['Chromosome/scaffold name'] = 'chr' + annot['Chromosome/scaffold name'].astype(str)
annot.columns=['Chromosome', 'Start', 'End', 'Strand', 'Gene','Transcription_Start_Site', 'Transcript_type']
# .copy() detaches the filtered rows from the full table, so the column
# assignment below modifies a real frame rather than a temporary
annot = annot[annot.Transcript_type == 'protein_coding'].copy()
# Recode the numeric strand (1 / -1) as '+' / '-'. Assigning the whole column
# replaces the chained `annot.Strand[mask] = ...` form, which triggers
# SettingWithCopyWarning and has no effect under pandas copy-on-write.
annot['Strand'] = annot['Strand'].replace({1: '+', -1: '-'})
# Rows with any missing field are dropped before building the PyRanges
pr_annotation = pr.PyRanges(annot.dropna(axis = 0))

In [None]:

# Get chromosome sizes
import pandas as pd
import requests
target_url='http://hgdownload.cse.ucsc.edu/goldenPath/hg38/bigZips/hg38.chrom.sizes'
chromsizes=pd.read_csv(target_url, sep='\t', header=None)
chromsizes.columns=['Chromosome', 'End']
chromsizes['Start']=[0]*chromsizes.shape[0]
chromsizes=chromsizes.loc[:,['Chromosome', 'Start', 'End']]
chromsizes=pr.PyRanges(chromsizes)

In [None]:
# Compute gene activity scores from the imputed accessibility. Because
# return_weights=True, the call returns two objects, unpacked here as the gene
# activity object and the weights.
from pycisTopic.gene_activity import *
gene_act, weigths = get_gene_activity(imputed_acc_obj, # Region-cell probabilities
                pr_annotation, # Gene annotation
                chromsizes, # Chromosome size
                use_gene_boundaries=True, # Whether to use the whole search space or stop when encountering another gene
                upstream=[1000, 100000], # Search space upstream. The minimum means that even if there is a gene right next to it
                                      #these bp will be taken (1kbp here)
                downstream=[1000,100000], # Search space downstream
                distance_weight=True, # Whether to add a distance weight (an exponential function, the weight will decrease with distance)
                decay_rate=1, # Exponent for the distance exponential function (the higher the faster will be the decrease)
                extend_gene_body_upstream=10000, # Number of bp upstream immune to the distance weight (their value will be maximum for
                                      #this weight)
                extend_gene_body_downstream=500, # Number of bp downstream immune to the distance weight
                gene_size_weight=False, # Whether to add a weight based on the length of the gene
                gene_size_scale_factor='median', # Dividend to calculate the gene size weight. Default is the median value of all genes
                                      #in the genome
                remove_promoters=False, # Whether to remove promoters when computing gene activity scores
                average_scores=True, # Whether to divide by the total number of regions assigned to a gene when calculating the gene
                                      #activity score
                scale_factor=1, # Value to multiply for the final gene activity matrix
                extend_tss=[10,10], # Space to consider a promoter
                gini_weight = True, # Whether to add a gini index weight. The more unique the region is, the higher this weight will be
                return_weights= True, # Whether to return the final weights
                project='Gene_activity') # Project name for the gene activity object

In [None]:
markers_dict= find_diff_features(cistopic_obj,
                      gene_act,
                      variable='mese_mullerian_lowres',
                      var_features=None,
                      contrasts=None,
                      adjpval_thr=0.05,
                      log2fc_thr=np.log2(1.5),
                      n_cpu=1,
                      #_temp_dir=tmpDir + 'ray_spill'
                                )

In [None]:
# os.mkdir(outDir+'DAGs')
from pycisTopic.clust_vis import *
plot_imputed_features(cistopic_obj,
                    reduction_name='harmony_tSNE',
                    imputed_data=gene_act,
                    features=['LGR5', 'TSPAN8', 'CD36', 'ITGBL1', 'HMGA2', 'KRT8', 'KRT18', 'ATF3', 'KLF2', 'ITGA4', 'SEMA3A', 'NR4A1', 'MAFF', 'CSRNP1',
                             'HOXA9', 'HOXD9', 'HOXA10', 'HOXD10', 'HOXA11', 'HOXD11', 'HOXA7', 'HOXC8', 'HOXC6', 'HOXC5', 'HOXC4',
                              'ETV4', 'CRABP1', 'CNTN1', 'TMEM163', 'ZAP70', 'MMP28', 'HOXA13', 'SRD5A2', 'WIF1'],
                    scale=True,
                    num_columns=4, cmap = 'jet',
                    save= outDir + 'DAGs/example_best_DAGs.pdf')

In [None]:
# Report how many differential features were found for each group.
# A plain loop is used because print() is called only for its side effect.
for group, markers in markers_dict.items():
    print(group + ': ' + str(len(markers)))


In [None]:
# Save
with open(outDir + 'DAGs/Gene_activity.pkl', 'wb') as f:
  pickle.dump(gene_act, f)
with open(outDir + 'DAGs/DAGs.pkl', 'wb') as f:
  pickle.dump(markers_dict, f)

### Label transfer

In [None]:
# # Load cisTopic object
# import pickle
# infile = open(outDir + 'cisTopicObject_clean.pkl', 'rb')
# cistopic_obj = pickle.load(infile)
# infile.close()

In [None]:
# cistopic_obj.cell_data

In [None]:
# # Prepare RNA
# from loomxpy.loomxpy import SCopeLoom
# from pycisTopic.loom import *
# import itertools
# import anndata
# import scanpy as sc
# rna_anndata = sc.read('/nfs/team292/vl6/FetalReproductiveTract/mullerian_mese_late_downsampled.h5ad')
# rna_anndata

In [None]:
# rna_anndata.obs['mese_mullerian_highres'].value_counts()

In [None]:
# # Recode RNA 
# recode = {'Mesenchymal_FallopianTube_late' : 'MesenchymalFallopianTubeLate', 'Mesenchymal_Uterus_late' : 'MesenchymalUterusLate', 'Mesenchymal_FallopianTube_early' : 'MesenchymalFallopianTubeEarly', 'Mesenchymal_Uterus_early' : 'MesenchymalUterusEarly', 
#          'Mesenchymal_MüllerianDuct' : 'MesenchymalMüllerianDuct'}

In [None]:
# rna_anndata.obs['mese_mullerian_highres'] = rna_anndata.obs['mese_mullerian_highres'].map(recode)

In [None]:
# rna_anndata = anndata.AnnData(X = rna_anndata.raw.X, var = rna_anndata.raw.var, obs = rna_anndata.obs)

In [None]:
# rna_anndata.obs['mese_mullerian_highres'].value_counts()

In [None]:
# # Prepare ATAC
# import pickle
# infile = open(outDir + 'DAGs/Gene_activity.pkl', 'rb') #Here I am using pycisTopic gene activity matrix, but could be any :)
# gene_act = pickle.load(infile)
# infile.close()
# atac_anndata = anndata.AnnData(X=gene_act.mtx.T, obs=pd.DataFrame(index=gene_act.cell_names), var=pd.DataFrame(index=gene_act.feature_names))
# atac_anndata.obs = cistopic_obj.cell_data

In [None]:
# atac_anndata

In [None]:
# atac_anndata.obs['mese_mullerian_highres'].value_counts()

In [None]:
# from pycisTopic.label_transfer import *
# label_dict = label_transfer(rna_anndata,
#                   atac_anndata,
#                   labels_to_transfer = ['mese_mullerian_highres'],
#                   variable_genes = True,
#                   methods = ['ingest', 'harmony', 'bbknn', 'scanorama', 'cca'],
#                   return_label_weights = False,
#                   #_temp_dir= ''
#                            )

In [None]:
# label_dict_x=[label_dict[key] for key in label_dict.keys()]
# label_pd = pd.concat(label_dict_x, axis=1, sort=False)
# label_pd.index = cistopic_obj.cell_names
# label_pd.columns = ['pycisTopic_' + x for x in label_pd.columns]
# cistopic_obj.add_cell_data(label_pd, split_pattern = '-')

In [None]:
# from pycisTopic.clust_vis import *
# plot_metadata(cistopic_obj,
#              reduction_name='harmony_tSNE',
#              variables= label_pd.columns.to_list(),
#              remove_nan=True,
#              cmap=cm.viridis,
#              seed=555,
#              num_columns=3,
#              color_dictionary={},
#              save= outDir + 'DAGs/label_transfer.pdf')

### pycisTarget

In [None]:
outDir

In [None]:
# Build the region sets (binarized topics and DARs) passed to pycistarget.
import pickle
import re
import pyranges as pr
from pycistarget.utils import *

# Load region binarized topics (context managers close the files even on error)
with open(outDir+'topic_binarization/binarized_topic_region.pkl', 'rb') as infile:
    binarized_topic_region = pickle.load(infile)
# Load DARs
with open(outDir+'DARs/DARs.pkl', 'rb') as infile:
    DARs_dict = pickle.load(infile)

# Format region sets: one PyRanges per topic / per DAR group, built from the region names
region_sets = {}
region_sets['Topics'] = {key: pr.PyRanges(region_names_to_coordinates(binarized_topic_region[key].index.tolist())) for key in binarized_topic_region.keys()}
# DAR group names are sanitised: every run of non-alphanumeric characters becomes '_'
region_sets['DARs'] = {re.sub('[^A-Za-z0-9]+', '_', key): pr.PyRanges(region_names_to_coordinates(DARs_dict[key].index.tolist())) for key in DARs_dict.keys()}

In [None]:
outDir

In [None]:
len('/lustre/scratch126/cellgen/team292/vl6/tmp/session_2023-01-23_22-04-33_639589_39002/sockets/plasma_store')

In [None]:
# Run pycistarget
# run_without_promoters = True, will run the methods in all regions + the region sets without promoters
import os
os.chdir('/nfs/team292/vl6/scenicplus/src/')
from scenicplus.wrappers.run_pycistarget import *
run_pycistarget(region_sets,
                 ctx_db_path = '/nfs/team292/vl6/scenicplus/hg38_screen_v10_clust.regions_vs_motifs.rankings.feather',
                 species = 'homo_sapiens',
                 save_path = '/lustre/scratch126/cellgen/team292/vl6/pycistarget/mullerian_mese_withvagina_post9pcw/',
                 dem_db_path = '/nfs/team292/vl6/scenicplus/hg38_screen_v10_clust.regions_vs_motifs.scores.feather',
                 run_without_promoters = True,
                 biomart_host = 'http://www.ensembl.org',
                 promoter_space = 500,
                 ctx_auc_threshold = 0.005,
                 ctx_nes_threshold = 3.0,
                 ctx_rank_threshold = 0.05,
                 dem_log2fc_thr = 0.5,
                 dem_motif_hit_thr = 3.0,
                 dem_max_bg_regions = 500,
                 path_to_motif_annotations = '/nfs/team292/vl6/scenicplus/motifs-v10nr_clust-nr.hgnc-m0.001-o0.0.tbl',
                 annotation_version = 'v10nr_clust',
                 annotation = ['Direct_annot', 'Orthology_annot'],
                 n_cpu = 1,
                 #_temp_dir = '/lustre/scratch126/cellgen/team292/vl6/pycistarget/temp/'
               )

In [None]:
save_path = '/lustre/scratch126/cellgen/team292/vl6/pycistarget/mullerian_mese_withvagina_post9pcw/'

In [None]:
import dill
import os

# Load the motif enrichment results written by run_pycistarget.
# The context manager closes the file, which `dill.load(open(...))` never did.
with open(os.path.join(save_path, 'menr.pkl'), 'rb') as infile:
    menr = dill.load(infile)

In [None]:
menr.keys()

In [None]:
outDir = '/lustre/scratch126/cellgen/team292/vl6/pycistopic/mullerian_mese_withvagina_post9pcw/'
outDir

### Infer eGRNs

In [None]:
import dill
import scanpy as sc
import os
import warnings
warnings.filterwarnings("ignore")
import pandas
import pyranges
# NOTE(review): the original comment said stderr is set to null here to silence
# ray, but no code in this cell redirects stderr; only `sys` is imported.
import sys

# scRNA-seq data used as the expression side of SCENIC+
adata = sc.read_h5ad('/nfs/team292/vl6/FetalReproductiveTract/mullerian_mese_late_post10pcw.h5ad')

# The context manager closes the file, which `dill.load(open(...))` never did
with open(os.path.join(outDir, 'cisTopicObject_clean.pkl'), 'rb') as infile:
    cistopic_obj = dill.load(infile)


In [None]:
import dill
import scanpy as sc
import os
import warnings
warnings.filterwarnings("ignore")
import pandas
import pyranges
# NOTE(review): the original comment said stderr is set to null here to silence
# ray, but no code in this cell redirects stderr; only `sys` is imported.
import sys


# The context manager closes the file, which `dill.load(open(...))` never did
with open(os.path.join(outDir, 'cisTopicObject_clean.pkl'), 'rb') as infile:
    cistopic_obj = dill.load(infile)
cistopic_obj

In [None]:
cell_metadata = cistopic_obj.cell_data

In [None]:
cell_metadata.head()

In [None]:
cell_metadata.to_csv(outDir + "cell_metadata_for_cicero.csv")

In [None]:
peak_metadata = cistopic_obj.region_data
peak_metadata.head()

In [None]:
peak_metadata.to_csv(outDir + "region_metadata_for_cicero.csv")

In [None]:
cell_metadata.shape

In [None]:
lowdim = pd.DataFrame(index=cell_metadata.index, columns=['tsne1', 'tsne2'])

In [None]:
lowdim['tsne1'] = lowdim.index.map(cistopic_obj.projections['cell']['harmony_tSNE']['tSNE_1'].to_dict())
lowdim['tsne2'] = lowdim.index.map(cistopic_obj.projections['cell']['harmony_tSNE']['tSNE_2'].to_dict())

In [None]:
lowdim.head()

In [None]:
lowdim.to_csv(outDir + "tsne_harmony_for_cicero.csv")

In [None]:
from scipy.io import mmwrite

In [None]:
count_matrix = cistopic_obj.binary_matrix
count_matrix.shape

In [None]:
mmwrite(outDir + 'fragment_matrix_for_cicero.mtx', count_matrix)

In [None]:
outDir

In [None]:
adata.X[20:30, 20:30].toarray()

In [None]:
adata.raw.X.shape

In [None]:
# Find common genes between adata.raw and adata
common_genes = adata.var_names.intersection(adata.raw.var_names)

# Subset adata.raw to include only the common genes
adata_raw_common = adata.raw[:, common_genes]


In [None]:
adata_raw_common.shape

In [None]:
adata.layers["raw_count"] = adata_raw_common.X.copy()

In [None]:
adata.layers["raw_count"][20:25, 20:25].toarray()

In [None]:
adata.obs.head()

In [None]:
adata.var['highly_variable'].value_counts()

In [None]:
import pickle

# Load the list from the file
with open('/lustre/scratch126/cellgen/team292/vl6/VISIUM/tot_spatially_variable_genes_mullerian_mese.pkl', 'rb') as f:
    spatially_variable_genes = pickle.load(f)

print(len(spatially_variable_genes))


## Take union of HVGs and spatially variable genes for CellOracle modelling

In [None]:
# Step 1: Extract the genes flagged as highly variable in adata.var
highly_variable_genes = adata.var_names[adata.var['highly_variable'] == True]

# Step 2: Take the union with the spatially variable genes
# (both are converted to sets, so a gene present in both is kept once)
genes_union = set(highly_variable_genes).union(set(spatially_variable_genes))

# Step 3: Convert back to a list and print the number of genes in the union
genes_union_list = list(genes_union)
print(len(genes_union_list))

In [None]:
adata.obs['mese_mullerian_lowres'].value_counts()

In [None]:
# Recode RNA 
recode = {'Fallopian Mese' : 'FallopianMese',
          'Uterus Mese' : 'UterusMese',  
         'Cervix Mese' : 'CervixMese', 
         'Upper Vagina Mese' : 'UpperVaginaMese'}
adata.obs['mese_mullerian_lowres'] = adata.obs['mese_mullerian_lowres'].map(recode)

In [None]:
adata.obs['mese_mullerian_lowres'].value_counts(dropna = False)

In [None]:
sc.pl.umap(adata, color = 'mese_mullerian_lowres')

In [None]:
adata.obs['mese_mullerian_lowres'].value_counts()

In [None]:
# Random downsampling per cell type 
import random
import pandas as pd
from itertools import chain
def downsample(adata, labels, n, overrides=None, seed=None):
    """Randomly downsample over-represented groups of cells.

    Every group in ``adata.obs[labels]`` holding more than ``n`` cells is
    reduced to a random subset. Groups with ``n`` cells or fewer, and cells
    whose label is missing, are kept in full.

    Parameters
    ----------
    adata
        AnnData-like object exposing ``obs`` (a DataFrame), ``obs_names`` and
        boolean-mask indexing along the cell axis.
    labels
        Name of the ``adata.obs`` column holding the group labels.
    n
        Maximum number of cells kept for a group.
    overrides
        Optional ``{label: target}`` mapping giving a different target size
        for specific groups that exceed ``n``. Defaults to
        ``{'Mese_ExtraGonad': 9000}``, the special case that used to be
        hard-coded in the body.
    seed
        Optional seed for a private random generator, making the selection
        reproducible. When ``None`` the global ``random`` state is used.

    Returns
    -------
    ``adata`` indexed by a boolean mask of the retained cells, so the
    original cell order is preserved.
    """
    if overrides is None:
        overrides = {'Mese_ExtraGonad': 9000}
    rng = random if seed is None else random.Random(seed)

    group_of_cell = adata.obs[labels]
    group_sizes = group_of_cell.value_counts()

    # Find clusters with > n cells
    cl2downsample = group_sizes.index[group_sizes.values > n]

    # Keep every barcode from small clusters (missing labels never match
    # `isin`, so unlabelled cells are kept as well)
    in_large_group = group_of_cell.isin(cl2downsample).values
    keep = set(adata.obs_names[~in_large_group])

    # Randomly sample the target number of cells from each large cluster
    for cl in cl2downsample:
        print(cl)
        # random.sample needs a sequence (sets are rejected on Python >= 3.11);
        # a list in obs order also makes seeded runs reproducible
        cl_sample = list(adata.obs_names[(group_of_cell == cl).values])
        # Never request more cells than the cluster holds
        target = min(overrides.get(cl, n), len(cl_sample))
        keep.update(rng.sample(cl_sample, target))

    # Set-based membership keeps the final filter linear in the number of cells
    return adata[adata.obs_names.isin(keep)]

In [None]:
adata = downsample(adata, 'mese_mullerian_lowres', 2000)

In [None]:
sc.pl.umap(adata, color = 'mese_mullerian_lowres')

In [None]:
adata.shape

In [None]:
adata = adata[:, genes_union_list]
adata.shape

In [None]:
adata.write(outDir + 'scrnaseq_for_celloracle.h5ad')

In [None]:
# Remove the per-batch annotation columns from adata.var (presumably left
# behind when the individual samples were concatenated -- TODO confirm).
# Batches 0-79 carry GeneID/GeneName/n_cells columns and batches 80-89 carry
# gene_ids/feature_types columns; the names are generated rather than typed
# out, which yields the same 260 names as the previous hand-written list.
to_del = [
    f'{field}-{batch}'
    for batch in range(80)
    for field in ('GeneID', 'GeneName', 'n_cells')
]
to_del += [
    f'{field}-{batch}'
    for batch in range(80, 90)
    for field in ('gene_ids', 'feature_types')
]
# A missing column raises KeyError, exactly as before
for d in to_del:
    del adata.var[d]

In [None]:
adata

In [None]:
import anndata 
adata = anndata.AnnData(X = adata.raw.X, var = adata.raw.var, obs = adata.obs)

In [None]:
str(cistopic_obj)

In [None]:
cistopic_obj.cell_data.head()

In [None]:
cistopic_obj.cell_data.columns

In [None]:
cistopic_obj.region_data.head()

In [None]:
import pickle
# The context manager closes the handle even when unpickling raises.
with open(outDir + 'DARs/Imputed_accessibility.pkl', 'rb') as infile:
    imputed_acc_obj = pickle.load(infile)

In [None]:
imputed_acc_obj

In [None]:
from scenicplus.scenicplus_class import create_SCENICPLUS_object
import numpy as np
scplus_obj = create_SCENICPLUS_object(
        GEX_anndata = adata,
        cisTopic_obj = cistopic_obj,
        imputed_acc_obj = imputed_acc_obj,
        menr = menr,
        multi_ome_mode = False,
        nr_cells_per_metacells = 20,
        key_to_group_by = 'mese_mullerian_lowres')

In [None]:
print(scplus_obj)

In [None]:
from scenicplus.preprocessing.filtering import *

In [None]:
filter_genes(scplus_obj, min_pct = 10)
filter_regions(scplus_obj, min_pct = 10)

In [None]:
# Merge cistromes (all)
from scenicplus.cistromes import *
import time
start_time = time.time()
merge_cistromes(scplus_obj)
# Store the duration under its own name: rebinding `time` to a float would break
# every later `time.time()` call that relies on the module (e.g. in functions
# defined in other cells).
elapsed = time.time() - start_time
print(elapsed/60)

In [None]:
ensembl_version_dict = {'105': 'http://www.ensembl.org',
                        '104': 'http://may2021.archive.ensembl.org/',
                        '103': 'http://feb2021.archive.ensembl.org/',
                        '102': 'http://nov2020.archive.ensembl.org/',
                        '101': 'http://aug2020.archive.ensembl.org/',
                        '100': 'http://apr2020.archive.ensembl.org/',
                        '99': 'http://jan2020.archive.ensembl.org/',
                        '98': 'http://sep2019.archive.ensembl.org/',
                        '97': 'http://jul2019.archive.ensembl.org/',
                        '96': 'http://apr2019.archive.ensembl.org/',
                        '95': 'http://jan2019.archive.ensembl.org/',
                        '94': 'http://oct2018.archive.ensembl.org/',
                        '93': 'http://jul2018.archive.ensembl.org/',
                        '92': 'http://apr2018.archive.ensembl.org/',
                        '91': 'http://dec2017.archive.ensembl.org/',
                        '90': 'http://aug2017.archive.ensembl.org/',
                        '89': 'http://may2017.archive.ensembl.org/',
                        '88': 'http://mar2017.archive.ensembl.org/',
                        '87': 'http://dec2016.archive.ensembl.org/',
                        '86': 'http://oct2016.archive.ensembl.org/',
                        '80': 'http://may2015.archive.ensembl.org/',
                        '77': 'http://oct2014.archive.ensembl.org/',
                        '75': 'http://feb2014.archive.ensembl.org/',
                        '54': 'http://may2009.archive.ensembl.org/'}

import pybiomart as pbm
def test_ensembl_host(scplus_obj, host, species):
    """Count how many gene names of `scplus_obj` occur in the annotation served by one BioMart host.

    Parameters
    ----------
    scplus_obj
        Object exposing a `gene_names` collection.
    host
        URL of the Ensembl BioMart host to query.
    species
        Species prefix of the BioMart dataset; `'_gene_ensembl'` is appended to it.

    Returns
    -------
    int
        Number of entries in `scplus_obj.gene_names` found among the gene names
        returned by the host. The same figure is also printed.
    """
    mart = pbm.Dataset(name=species+'_gene_ensembl',  host=host)
    annot = mart.query(
        attributes=['chromosome_name', 'transcription_start_site', 'strand',
                    'external_gene_name', 'transcript_biotype'])
    annot.columns = ['Chromosome', 'Start', 'Strand', 'Gene', 'Transcript_type']
    annot['Chromosome'] = annot['Chromosome'].astype('str')
    # Discard rows whose chromosome name contains CHR, GL, JH or MT.
    excluded = annot['Chromosome'].str.contains('CHR|GL|JH|MT')
    annot = annot[~excluded]
    known_genes = set(annot['Gene'].tolist())
    ov = sum(1 for gene in scplus_obj.gene_names if gene in known_genes)
    print('Genes recovered: ' + str(ov) + ' out of ' + str(len(scplus_obj.gene_names)))
    return ov

# Query every known Ensembl release and remember how many genes each one recovers.
n_overlap = {}
for version, host in ensembl_version_dict.items():
    print(f'host: {version}')
    try:
        n_overlap[version] = test_ensembl_host(scplus_obj, host, 'hsapiens')
    except Exception as err:
        # `except Exception` lets KeyboardInterrupt through, so the scan can be
        # aborted; the reason for the failure is reported instead of being hidden.
        print(f'Host not reachable ({type(err).__name__}: {err})')

if n_overlap:
    # On ties the first release in dictionary order wins.
    v = max(n_overlap, key=n_overlap.get)
    print(f"version: {v} has the largest overlap, use {ensembl_version_dict[v]} as biomart host")
else:
    print('No Ensembl host could be queried, so no biomart host can be suggested.')

In [None]:
tf_file = '/nfs/team292/vl6/scenicplus/allTFs_hg38.txt'

# Open the file in read mode
with open(tf_file, 'r') as file:
    # Read lines from the file and remove newline characters
    tfs = [line.strip() for line in file.readlines()]

In [None]:
len(tfs)

In [None]:
"ESRRG" in tfs

In [None]:
# tfs = [t for t in tfs if not t.startswith("ZNF")]

In [None]:
# len(tfs)

In [None]:
# # Specify the file path
# file_path = '/nfs/team292/vl6/scenicplus/nonZNF_TFs_hg38.txt'

# # Open the file in write mode
# with open(file_path, 'w') as file:
#     # Write each element of the list followed by a newline character
#     for element in tfs:
#         file.write(element + '\n')

In [None]:
biomart_host = "http://sep2019.archive.ensembl.org/"

In [None]:
from scenicplus.enhancer_to_gene import get_search_space, calculate_regions_to_genes_relationships, GBM_KWARGS
get_search_space(scplus_obj,
                 biomart_host = biomart_host,
                 species = 'hsapiens',
                 assembly = 'hg38',
                 upstream = [1000, 150000],
                 downstream = [1000, 150000])

In [None]:
calculate_regions_to_genes_relationships(scplus_obj,
                    ray_n_cpu = 20,
                    #_temp_dir = tmpDir,
                    importance_scoring_method = 'GBM',
                    importance_scoring_kwargs = GBM_KWARGS)

In [None]:
# Save
import pickle
with open(outDir + 'scplus_obj.pkl', 'wb') as f:
  pickle.dump(scplus_obj, f)

In [None]:
import pickle
# The context manager closes the handle even when unpickling raises.
with open(outDir + 'scplus_obj.pkl', 'rb') as infile:
    scplus_obj = pickle.load(infile)

In [None]:
print(scplus_obj)

In [None]:
scplus_obj.uns.keys()

In [None]:
def timestamp(dt):
    """Format a datetime as a fixed-width `YYYYMMDD_HHMMSS` string.

    Every field is zero-padded so that different moments can never produce the
    same text (without padding, 2023-01-11 and 2023-11-01 both render as
    `2023111`).

    Parameters
    ----------
    dt
        Object exposing `year`, `month`, `day`, `hour`, `minute` and `second`
        attributes, e.g. a `datetime.datetime`.

    Returns
    -------
    str
        The formatted timestamp, resolved to the second.
    """
    return (f"{dt.year:04d}{dt.month:02d}{dt.day:02d}"
            f"_{dt.hour:02d}{dt.minute:02d}{dt.second:02d}")

In [None]:
"""Link transcription factors (TFs) to genes based on co-expression of TF and target genes.

Both linear methods (spearman or pearson correlation) and non-linear methods (random forest or gradient boosting) are used to link TF to genes.

The correlation methods are used to separate TFs which are inferred to have a positive influence on gene expression (i.e. positive correlation) 
and TFs which are inferred to have a negative influence on gene expression (i.e. negative correlation).

"""


import logging
import os
import shutil
import sys
import tempfile
import time
from datetime import datetime

import joblib
import numpy as np
import pandas as pd
import scipy.sparse
from arboreto.algo import _prepare_input
from arboreto.core import (EARLY_STOP_WINDOW_LENGTH, RF_KWARGS, SGBM_KWARGS,
                           infer_partial_network, to_tf_matrix)
from arboreto.utils import load_tf_names
from tqdm import tqdm

from scenicplus.scenicplus_class import SCENICPLUS
from scenicplus.utils import _create_idx_pairs, masked_rho4pairs

COLUMN_NAME_TARGET = "target"
COLUMN_NAME_WEIGHT = "importance"
COLUMN_NAME_REGULATION = "regulation"
COLUMN_NAME_CORRELATION = "rho"
COLUMN_NAME_TF = "TF"
COLUMN_NAME_SCORE_1 = "importance_x_rho"
COLUMN_NAME_SCORE_2 = "importance_x_abs_rho"
RHO_THRESHOLD = 0.03

# Create logger
level = logging.INFO
format = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
handlers = [logging.StreamHandler(stream=sys.stdout)]
logging.basicConfig(level=level, format=format, handlers=handlers)
log = logging.getLogger('TF2G')

def _inject_TF_as_its_own_target(
    scplus_obj: SCENICPLUS = None,
    TF2G_adj: pd.DataFrame = None, 
    ex_mtx: pd.DataFrame = None,
    rho_threshold = RHO_THRESHOLD, 
    TF2G_key = 'TF2G_adj', 
    out_key = 'TF2G_adj',
    inplace = True,
    increase_importance_by = 0.00001) -> pd.DataFrame:
    """
    Append one TF -> TF row per transcription factor to a TF-to-gene adjacency table.

    The importance of each added row is the largest importance found for that TF
    plus `increase_importance_by`. Correlation columns for the added rows are
    computed with `_add_correlation`.

    Parameters
    ----------
    scplus_obj
        SCENIC+ object; adjacencies are read from `.uns[TF2G_key]` and the expression
        matrix from `.to_df(layer='EXP')`. Mutually exclusive with `TF2G_adj`.
    TF2G_adj
        Adjacency dataframe with at least 'TF', 'target' and 'importance' columns.
        Mutually exclusive with `scplus_obj`.
    ex_mtx
        Expression matrix, required together with `TF2G_adj`
        (ignored when `scplus_obj` is given).
    rho_threshold
        Threshold forwarded to `_add_correlation`.
    TF2G_key
        Key in `scplus_obj.uns` to read the adjacencies from.
    out_key
        Key in `scplus_obj.uns` under which the result is stored when `inplace` is True.
    inplace
        Store the result in `scplus_obj.uns[out_key]` (and return None) instead of
        returning it. Only possible when `scplus_obj` is given.
    increase_importance_by
        Amount added to the per-TF maximum importance.

    Returns
    -------
    pd.DataFrame or None
        The original adjacencies (sorted by TF) followed by the TF -> TF rows, with a
        fresh integer index; None when `inplace` is True.

    Raises
    ------
    ValueError
        If neither or both of `scplus_obj` / `TF2G_adj` are given, if `inplace` is
        requested without a `scplus_obj`, or if `TF2G_adj` is given without `ex_mtx`.
    """
    if scplus_obj is None and TF2G_adj is None:
        raise ValueError('Either provide a SCENIC+ object of a pd.DataFrame with TF to gene adjecencies!')
    if scplus_obj is not None and TF2G_adj is not None:
        raise ValueError('Either provide a SCENIC+ object of a pd.DataFrame with TF to gene adjecencies! Not both!')
    if scplus_obj is None and inplace:
        # Without an object there is nowhere to store the result; fail with a clear
        # message instead of an AttributeError on None at the very end.
        raise ValueError('inplace = True requires a SCENIC+ object; set inplace = False when providing TF2G_adj.')
    if scplus_obj is None and ex_mtx is None:
        raise ValueError('An expression matrix (ex_mtx) is required when providing TF2G_adj.')

    log.info(f"Warning: adding TFs as their own target to adjecencies matrix. Importance values will be max + {increase_importance_by}")

    if scplus_obj is not None:
        origin_TF2G_adj = scplus_obj.uns[TF2G_key]
        ex_mtx = scplus_obj.to_df(layer='EXP')
    else:
        origin_TF2G_adj = TF2G_adj

    origin_TF2G_adj = origin_TF2G_adj.sort_values('TF')
    # Reduce only the importance column; the other columns are not needed here.
    max_importances = origin_TF2G_adj.groupby('TF')['importance'].max()

    # Take the TFs from the groupby index so the added rows come in a
    # deterministic order rather than in set-iteration order.
    TFs_in_adj = max_importances.index.to_list()
    TF_to_TF_adj = pd.DataFrame(
                    data = {"TF": TFs_in_adj,
                            "target": TFs_in_adj,
                            "importance": max_importances.values + increase_importance_by})
    TF_to_TF_adj = _add_correlation(
            adjacencies=TF_to_TF_adj,
            ex_mtx = ex_mtx,
            rho_threshold=rho_threshold)

    new_TF2G_adj = pd.concat([origin_TF2G_adj, TF_to_TF_adj]).reset_index(drop = True)
    if inplace:
        scplus_obj.uns[out_key] = new_TF2G_adj
        return None
    return new_TF2G_adj


def load_TF2G_adj_from_file(SCENICPLUS_obj: SCENICPLUS,
                            f_adj: str,
                            inplace=True,
                            key='TF2G_adj',
                            rho_threshold=RHO_THRESHOLD):
    """
    Function to load TF2G adjacencies from file

    Rows whose TF or target is not among the gene names of `SCENICPLUS_obj` are
    discarded, correlation columns are added when the file has none, every TF is
    added as its own target and the two importance x rho scores are filled in.

    Parameters
    ----------
    SCENICPLUS_obj
        An instance of :class:`~scenicplus.scenicplus_class.SCENICPLUS`
    f_adj
        File path to a tab-separated TF2G adjacencies matrix
    inplace
        Boolean specifying wether or not to store adjacencies matrix in `SCENICPLUS_obj` under slot .uns[key].
        Default: True
    key
        String specifying where in the .uns slot to store the adjacencies matrix in `SCENICPLUS_obj`
        Default: "TF2G_adj"
    rho_threshold
        A floating point number specifying from which absolute value to consider a correlation coefficient positive or negative.
        Default: 0.03

    Returns
    -------
    pd.DataFrame or None
        The processed adjacencies when `inplace` is False, otherwise None.
    """
    log.info(f'Reading file: {f_adj}')
    df_TF_gene_adj = pd.read_csv(f_adj, sep='\t')
    # only keep relevant entries (vectorised membership test against a set)
    known_genes = set(SCENICPLUS_obj.gene_names)
    idx_to_keep = np.logical_and(df_TF_gene_adj['TF'].isin(known_genes).values,
                                 df_TF_gene_adj['target'].isin(known_genes).values)
    df_TF_gene_adj_subset = df_TF_gene_adj.loc[idx_to_keep]

    # The expression matrix is needed up to twice below; build it once.
    ex_mtx = SCENICPLUS_obj.to_df(layer='EXP')
    if COLUMN_NAME_CORRELATION not in df_TF_gene_adj_subset.columns:
        log.info('Adding correlation coefficients to adjacencies.')
        df_TF_gene_adj_subset = _add_correlation(
            adjacencies=df_TF_gene_adj_subset,
            ex_mtx=ex_mtx,
            rho_threshold=rho_threshold)
    # Forward rho_threshold so the added TF -> TF rows use the same cut-off as the
    # rows handled above.
    df_TF_gene_adj_subset = _inject_TF_as_its_own_target(
        TF2G_adj=df_TF_gene_adj_subset, 
        inplace = False, 
        ex_mtx = ex_mtx,
        rho_threshold = rho_threshold)

    importance_x_rho = df_TF_gene_adj_subset[COLUMN_NAME_CORRELATION] * \
        df_TF_gene_adj_subset[COLUMN_NAME_WEIGHT]
    if COLUMN_NAME_SCORE_1 not in df_TF_gene_adj_subset.columns:
        log.info('Adding importance x rho scores to adjacencies.')
        df_TF_gene_adj_subset[COLUMN_NAME_SCORE_1] = importance_x_rho
    else:
        # The file already carried this score, so the TF -> TF rows added above
        # have no value for it; fill only those gaps and keep the file's values.
        df_TF_gene_adj_subset[COLUMN_NAME_SCORE_1] = \
            df_TF_gene_adj_subset[COLUMN_NAME_SCORE_1].fillna(importance_x_rho)

    importance_x_abs_rho = abs(
        df_TF_gene_adj_subset[COLUMN_NAME_CORRELATION]) * abs(df_TF_gene_adj_subset[COLUMN_NAME_WEIGHT])
    if COLUMN_NAME_SCORE_2 not in df_TF_gene_adj_subset.columns:
        log.info('Adding importance x |rho| scores to adjacencies.')
        df_TF_gene_adj_subset[COLUMN_NAME_SCORE_2] = importance_x_abs_rho
    else:
        df_TF_gene_adj_subset[COLUMN_NAME_SCORE_2] = \
            df_TF_gene_adj_subset[COLUMN_NAME_SCORE_2].fillna(importance_x_abs_rho)

    if inplace:
        log.info(f'Storing adjacencies in .uns["{key}"].')
        SCENICPLUS_obj.uns[key] = df_TF_gene_adj_subset
    else:
        return df_TF_gene_adj_subset


def _add_correlation(
        adjacencies: pd.DataFrame,
        ex_mtx: pd.DataFrame,
        rho_threshold=RHO_THRESHOLD,
        mask_dropouts=False):
    """
    Add correlation in expression levels between target and factor.

    Parameters
    ----------
    adjacencies: pd.DataFrame
        The dataframe with the TF-target links.
    ex_mtx: pd.DataFrame
        The expression matrix (n_cells x n_genes).
    rho_threshold: float
        The threshold on the correlation to decide if a target gene is activated
        (rho > `rho_threshold`) or repressed (rho < -`rho_threshold`).
    mask_dropouts: boolean
        Do not use cells in which either the expression of the TF or the target gene is 0 when
        calculating the correlation between a TF-target pair.

    Returns
    -------
    pd.DataFrame
        A new dataframe holding the TF, target and importance values of `adjacencies`
        plus a regulation column (1, 0 or -1) and a correlation column. Other columns
        of `adjacencies` and its index are not carried over.

    Raises
    ------
    KeyError
        If `mask_dropouts` is False and a TF or target of `adjacencies` is not a
        column of `ex_mtx`.
    """
    assert rho_threshold > 0, "rho_threshold should be greater than 0."

    # Calculate Pearson correlation to infer repression or activation.
    if mask_dropouts:
        ex_mtx = ex_mtx.sort_index(axis=1)
        col_idx_pairs = _create_idx_pairs(adjacencies, ex_mtx)
        rhos = masked_rho4pairs(ex_mtx.values, col_idx_pairs, 0.0)
    else:
        genes = list(set(adjacencies[COLUMN_NAME_TF]).union(
            set(adjacencies[COLUMN_NAME_TARGET])))
        ex_mtx = ex_mtx[ex_mtx.columns[ex_mtx.columns.isin(genes)]]
        # np.corrcoef collapses a single-gene input to a scalar; keep it 2-D so the
        # positional lookup below works for any number of genes.
        corr = np.atleast_2d(np.corrcoef(ex_mtx.values.T))
        # Translate gene names to matrix positions once and fetch all coefficients
        # in a single indexing operation rather than one lookup per row.
        # NOTE(review): get_indexer needs unique column names in ex_mtx.
        tf_pos = ex_mtx.columns.get_indexer(adjacencies[COLUMN_NAME_TF])
        target_pos = ex_mtx.columns.get_indexer(adjacencies[COLUMN_NAME_TARGET])
        # get_indexer marks unknown names with -1, which would silently index the
        # last row/column; report them instead.
        unknown = (tf_pos < 0) | (target_pos < 0)
        if unknown.any():
            missing = set(adjacencies[COLUMN_NAME_TF].values[tf_pos < 0]).union(
                set(adjacencies[COLUMN_NAME_TARGET].values[target_pos < 0]))
            raise KeyError(
                f"Genes not found in the expression matrix: {sorted(map(str, missing))}")
        rhos = corr[tf_pos, target_pos]

    regulations = (rhos > rho_threshold).astype(
        int) - (rhos < -rho_threshold).astype(int)
    return pd.DataFrame(
        data={
            COLUMN_NAME_TF: adjacencies[COLUMN_NAME_TF].values,
            COLUMN_NAME_TARGET: adjacencies[COLUMN_NAME_TARGET].values,
            COLUMN_NAME_WEIGHT: adjacencies[COLUMN_NAME_WEIGHT].values,
            COLUMN_NAME_REGULATION: regulations,
            COLUMN_NAME_CORRELATION: rhos,
        }
    )

def calculate_TFs_to_genes_relationships(scplus_obj: SCENICPLUS,
                                         tf_file: str,
                                         method: str = 'GBM',
                                         n_cpu: int = 1,
                                         key: str = 'TF2G_adj',
                                         temp_dir = None):
    """
    A function to calculate TF to gene relationships using arboreto and correlation

    The resulting adjacencies (importance, correlation, the TF -> TF rows and the
    two importance x rho scores) are stored in `scplus_obj.uns[key]`; nothing is
    returned.

    Parameters
    ----------
    scplus_obj
        An instance of :class:`~scenicplus.scenicplus_class.SCENICPLUS`
    tf_file
        Path to a file specifying with genes are TFs
    method
        Whether to use Gradient Boosting Machines ('GBM') or random forest ('RF')
    n_cpu
        Number of cpus to use
    key
        String specifying where in the .uns slot to store the adjacencies matrix in :param:`SCENICPLUS_obj`
        default: "TF2G_adj"
    temp_dir
        Directory in which the expression and TF matrices are dumped for the
        duration of the run. When None, '/dev/shm' is used if it is writable and
        the system temporary directory otherwise.

    Raises
    ------
    ValueError
        If `method` is neither 'GBM' nor 'RF', or if `scplus_obj` contains
        duplicate gene names.
    """

    if method == 'GBM':
        method_params = [
            'GBM',      # regressor_type
            SGBM_KWARGS  # regressor_kwargs
        ]
    elif method == 'RF':
        method_params = [
            'RF',       # regressor_type
            RF_KWARGS   # regressor_kwargs
        ]
    else:
        # Reject unknown methods up front instead of failing on an unbound
        # variable once the workers start.
        raise ValueError(f"Unknown method '{method}', choose 'GBM' or 'RF'.")

    gene_names = list(scplus_obj.gene_names)
    if len(set(gene_names)) != len(gene_names):
        raise ValueError("scplus_obj contains duplicate gene names!")
    ex_matrix = scplus_obj.X_EXP

    tf_names = load_tf_names(tf_file)
    ex_matrix, gene_names, tf_names = _prepare_input(
        ex_matrix, gene_names, tf_names)
    tf_matrix, tf_matrix_gene_names = to_tf_matrix(
        ex_matrix, gene_names, tf_names)

    #convert ex_matrix, tf_matrix to np.array if necessary
    if isinstance(ex_matrix, np.matrix):
        ex_matrix = np.array(ex_matrix)
    elif scipy.sparse.issparse(ex_matrix):
        ex_matrix = ex_matrix.toarray()

    if isinstance(tf_matrix, np.matrix):
        tf_matrix = np.array(tf_matrix)
    elif scipy.sparse.issparse(tf_matrix):
        tf_matrix = tf_matrix.toarray()

    log.info('Calculating TF-to-gene importance')
    start_time = time.time()

    if temp_dir is None:
        if os.access('/dev/shm', os.W_OK):
            temp_dir = '/dev/shm'
        else:
            temp_dir = tempfile.gettempdir()

    # Build both temporary file names once so that dump, load and cleanup are
    # guaranteed to refer to the same paths.
    run_id = timestamp(datetime.now())
    ex_matrix_path = os.path.join(temp_dir, f'scenicplus_ex_matrix_{run_id}')
    tf_matrix_path = os.path.join(temp_dir, f'scenicplus_tf_matrix_{run_id}')

    try:
        joblib.dump(ex_matrix, ex_matrix_path)
        joblib.dump(tf_matrix, tf_matrix_path)
        # Load the dumps back as read-only memory maps for the workers.
        ex_matrix_memmap = joblib.load(ex_matrix_path, mmap_mode = 'r')
        tf_matrix_memmap = joblib.load(tf_matrix_path, mmap_mode = 'r')

        def pf_inter_partial_network(target_gene_name):
            return infer_partial_network(
                target_gene_name = target_gene_name,
                target_gene_expression = ex_matrix_memmap[
                    :, gene_names.index(target_gene_name)],
                regressor_type = method_params[0],
                regressor_kwargs = method_params[1],
                tf_matrix = tf_matrix_memmap,
                tf_matrix_gene_names = tf_matrix_gene_names,
                include_meta = False,
                early_stop_window_length = EARLY_STOP_WINDOW_LENGTH,
                seed = 666)

        TF_to_genes = joblib.Parallel(
            n_jobs = n_cpu)(
                joblib.delayed(pf_inter_partial_network)(gene)
                for gene in tqdm(
                    gene_names, 
                    total=len(gene_names), 
                    desc=f'Running using {n_cpu} cores'))
    finally:
        # Single cleanup point for success and failure. Cleaning in both an
        # `except` and a `finally` clause removes the files twice on error, and
        # the second removal raises FileNotFoundError, hiding the real problem.
        for path in (ex_matrix_path, tf_matrix_path):
            if os.path.exists(path):
                os.remove(path)

    adj = pd.concat(TF_to_genes).sort_values(by='importance', ascending=False)
    log.info('Took {} seconds'.format(time.time() - start_time))
    start_time = time.time()
    log.info('Adding correlation coefficients to adjacencies.')
    # Build the expression dataframe once and reuse it for both helper calls.
    ex_matrix = scplus_obj.to_df(layer = 'EXP') 
    adj = _add_correlation(adj, ex_matrix)
    adj = _inject_TF_as_its_own_target(
        TF2G_adj=adj, 
        inplace = False, 
        ex_mtx = ex_matrix)
    log.info('Adding importance x rho scores to adjacencies.')
    adj[COLUMN_NAME_SCORE_1] = adj[COLUMN_NAME_CORRELATION] * \
        adj[COLUMN_NAME_WEIGHT]
    adj[COLUMN_NAME_SCORE_2] = abs(
        adj[COLUMN_NAME_CORRELATION]) * abs(adj[COLUMN_NAME_WEIGHT])
    log.info('Took {} seconds'.format(time.time() - start_time))
    scplus_obj.uns[key] = adj

In [None]:
#from scenicplus.TF_to_gene import *
tf_file = '/nfs/team292/vl6/scenicplus/allTFs_hg38.txt'
calculate_TFs_to_genes_relationships(scplus_obj,
                    tf_file = tf_file,
                    n_cpu = 20,
                    method = 'GBM',
                    key= 'TF2G_adj')

In [None]:
# Save
import pickle
with open(outDir + 'scplus_obj.pkl', 'wb') as f:
  pickle.dump(scplus_obj, f)

In [None]:
import pickle
# The context manager closes the handle even when unpickling raises.
with open(outDir + 'scplus_obj.pkl', 'rb') as infile:
    scplus_obj = pickle.load(infile)

In [None]:
outDir

In [None]:
scplus_obj.uns

In [None]:
from scenicplus.plotting import coverageplot

In [None]:
outDir

### Integrated multiome plot (not yet implemented)

Generate plots showing the chromatin profiles per group, region-to-gene relationships, and TF and gene expression to test the following hypotheses: 

 * As the Mullerian epithelium can change identity based on the surrounding mesenchyme, we can check whether the genes associated with Fallopian Tube identity are accessible despite not being expressed in the Uterus (and vice versa)
 
 * As the Wolffian epithelium can change identity based on the surrounding mesenchyme, we can check whether the genes associated with Epididymis identity are accessible despite not being expressed in the Uterus (and vice versa)

**Genes of interest** 

* **DLX5** (uterus) = chr7:97,020,396-97,024,831
* **ERP27** (fallopian tube) = chr12:14,914,039-14,938,537
* **MSX1** (uterus) = chr4:4,859,665-4,863,936
* **WNT11** (uterus) = chr11:76,186,325-76,206,502
* **EMX1** (wolffian) = chr2:72,917,519-72,934,891
* **MARCH11** (wolffian) = 
* **CALB1** (wolffian) = chr8:90,063,299-90,095,475
* **AVPR1A** (wolffian) = chr12:63,142,759-63,151,201
* **LEFTY1** (vas deferens) = chr1:225,886,282-225,889,146
* **CLDN2** (epididymis) = chrX:106,900,164-106,929,580
* **GLYAT** (epididymis) = chr11:58,708,757-58,731,943
* **SPAG11B** (epididymis) = chr8:7,450,603-7,463,542
* **SPINK2** (epididymis) = chr4:56,809,861-56,821,742
* **MGAM** (epididymis) = chr7:141,995,879-142,106,747


In [None]:
scplus_obj.uns.keys()

In [None]:
# Load functions
from scenicplus.grn_builder.gsea_approach3 import build_grn

In [None]:
build_grn

In [None]:
build_grn(scplus_obj,
         min_target_genes = 10,
         adj_pval_thr = 1,
         min_regions_per_gene = 0,
         quantiles = (0.85, 0.90, 0.95),
         top_n_regionTogenes_per_gene = (5, 10, 15),
         top_n_regionTogenes_per_region = (),
         binarize_using_basc = True,
         rho_dichotomize_tf2g = True,
         rho_dichotomize_r2g = True,
         rho_dichotomize_eregulon = True,
         rho_threshold = 0.05,
         keep_extended_motif_annot = True,
         merge_eRegulons = True,
         order_regions_to_genes_by = 'importance',
         order_TFs_to_genes_by = 'importance',
         key_added = 'eRegulons_importance',
         cistromes_key = 'Unfiltered',
         disable_tqdm = False, #If running in notebook, set to True
         ray_n_cpu = 20,
         #_temp_dir = '/lustre/scratch117/cellgen/team292/vl6/'
         )

In [None]:
import dill
with open(outDir + 'scplus_obj2.pkl', 'wb') as f:
  dill.dump(scplus_obj, f)

In [None]:
3+4

In [None]:
import dill
# The context manager closes the handle even when loading raises.
with open(outDir + 'scplus_obj2.pkl', 'rb') as infile:
    scplus_obj = dill.load(infile)

In [None]:
print(scplus_obj)

In [None]:
scplus_obj.uns.keys()

In [None]:
from scenicplus.utils import format_egrns
format_egrns(scplus_obj, eregulons_key = 'eRegulons_importance', TF2G_key = 'TF2G_adj', key_added = 'eRegulon_metadata')


In [None]:
scplus_obj.uns['eRegulon_metadata'][40:50]


In [None]:
len(scplus_obj.uns['eRegulons_importance'])

In [None]:
# Format eRegulons
from scenicplus.eregulon_enrichment import *
get_eRegulons_as_signatures(scplus_obj, eRegulon_metadata_key='eRegulon_metadata', key_added='eRegulon_signatures')

In [None]:
## Score chromatin layer
# Region based ranking
from scenicplus.cistromes import *
import time
start_time = time.time()
region_ranking = make_rankings(scplus_obj, target='region')
# Score region regulons
score_eRegulons(scplus_obj,
                ranking = region_ranking,
                eRegulon_signatures_key = 'eRegulon_signatures',
                key_added = 'eRegulon_AUC',
                enrichment_type= 'region',
                auc_threshold = 0.05,
                normalize = False,
                n_cpu = 10)
# Store the duration under its own name so the `time` module stays usable for
# later cells and for functions that call time.time().
elapsed = time.time() - start_time
print(elapsed/60)

In [None]:
## Score transcriptome layer
# Gene based ranking
from scenicplus.cistromes import *
import time
start_time = time.time()
gene_ranking = make_rankings(scplus_obj, target='gene')
# Score gene regulons
score_eRegulons(scplus_obj,
                ranking = gene_ranking,
                eRegulon_signatures_key = 'eRegulon_signatures',
                key_added = 'eRegulon_AUC',
                enrichment_type = 'gene',
                auc_threshold = 0.05,
                normalize= False,
                n_cpu = 10)
# Store the duration under its own name so the `time` module stays usable for
# later cells and for functions that call time.time().
elapsed = time.time() - start_time
print(elapsed/60)

In [None]:
# Generate pseudobulks
import time
start_time = time.time()
# Same settings for both signature types; only the signature key differs.
for signature_key in ('Gene_based', 'Region_based'):
    generate_pseudobulks(scplus_obj,
                         variable = 'mese_mullerian_lowres',
                         auc_key = 'eRegulon_AUC',
                         signature_key = signature_key,
                         nr_cells = 5,
                         nr_pseudobulks = 100,
                         seed=555)
# Store the duration under its own name so the `time` module stays usable for
# later cells and for functions that call time.time().
elapsed = time.time() - start_time
print(elapsed/60)

In [None]:
# Correlation between TF and eRegulons
import time
start_time = time.time()
TF_cistrome_correlation(scplus_obj,
                        variable = 'mese_mullerian_lowres',
                        auc_key = 'eRegulon_AUC',
                        signature_key = 'Gene_based',
                        out_key = 'mese_mullerian_lowres_eGRN_gene_based')
TF_cistrome_correlation(scplus_obj,
                        variable = 'mese_mullerian_lowres',
                        auc_key = 'eRegulon_AUC',
                        signature_key = 'Region_based',
                        out_key = 'mese_mullerian_lowres_eGRN_region_based')
# Store the duration under its own name so the `time` module stays usable for
# later cells and for functions that call time.time().
elapsed = time.time() - start_time
print(elapsed/60)

In [None]:
scplus_obj

In [None]:
color_dict = {'FallopianMese': 'orange',
 'UterusMese': 'orangered',
 'CervixMese': 'palevioletred',
 'UpperVaginaMese': 'lightpink'}

In [None]:
# Region based
%matplotlib inline
import seaborn as sns
sns.set_style("white")
categories = sorted(set(scplus_obj.metadata_cell['mese_mullerian_lowres']))
print(categories)
print(color_dict)
prune_plot(scplus_obj,
           'HOXA10_+_+',
           pseudobulk_variable = 'mese_mullerian_lowres',
           show_dot_plot = True,
           show_line_plot = False,
           color_dict = color_dict,
           use_pseudobulk = True,
           auc_key = 'eRegulon_AUC',
           signature_key = 'Region_based',
           seed=555)

In [None]:
# Gene based
%matplotlib inline
sns.set_style("white")

prune_plot(scplus_obj,
           'HOXA10_+_+',
           pseudobulk_variable = 'mese_mullerian_lowres',
           show_dot_plot = True,
           show_line_plot = False,
           color_dict = color_dict,
           use_pseudobulk = True,
           auc_key = 'eRegulon_AUC',
           signature_key = 'Gene_based',
           seed=555)

In [None]:
# Region based
%matplotlib inline
import seaborn as sns
sns.set_style("white")
categories = sorted(set(scplus_obj.metadata_cell['mese_mullerian_lowres']))
print(categories)
print(color_dict)
prune_plot(scplus_obj,
           'HOXC8_+_+',
           pseudobulk_variable = 'mese_mullerian_lowres',
           show_dot_plot = True,
           show_line_plot = False,
           color_dict = color_dict,
           use_pseudobulk = True,
           auc_key = 'eRegulon_AUC',
           signature_key = 'Region_based',
           seed=555)

In [None]:
# Gene based
%matplotlib inline
import seaborn as sns
sns.set_style("white")
categories = sorted(set(scplus_obj.metadata_cell['mese_mullerian_lowres']))
print(categories)
print(color_dict)
prune_plot(scplus_obj,
           'HOXC8_+_+',
           pseudobulk_variable = 'mese_mullerian_lowres',
           show_dot_plot = True,
           show_line_plot = False,
           color_dict = color_dict,
           use_pseudobulk = True,
           auc_key = 'eRegulon_AUC',
           signature_key = 'Gene_based',
           seed=555)

In [None]:
# Region based
%matplotlib inline
import seaborn as sns
sns.set_style("white")
categories = sorted(set(scplus_obj.metadata_cell['mese_mullerian_lowres']))
print(categories)
print(color_dict)
prune_plot(scplus_obj,
           'HOXC6_+_+',
           pseudobulk_variable = 'mese_mullerian_lowres',
           show_dot_plot = True,
           show_line_plot = False,
           color_dict = color_dict,
           use_pseudobulk = True,
           auc_key = 'eRegulon_AUC',
           signature_key = 'Region_based',
           seed=555)

In [None]:
# Gene based
%matplotlib inline
import seaborn as sns
sns.set_style("white")
categories = sorted(set(scplus_obj.metadata_cell['mese_mullerian_lowres']))
print(categories)
print(color_dict)
prune_plot(scplus_obj,
           'HOXC6_+_+',
           pseudobulk_variable = 'mese_mullerian_lowres',
           show_dot_plot = True,
           show_line_plot = False,
           color_dict = color_dict,
           use_pseudobulk = True,
           auc_key = 'eRegulon_AUC',
           signature_key = 'Gene_based',
           seed=555)

In [None]:
# Region based
%matplotlib inline
import seaborn as sns
sns.set_style("white")
categories = sorted(set(scplus_obj.metadata_cell['mese_mullerian_lowres']))
print(categories)
print(color_dict)
prune_plot(scplus_obj,
           'HOXA13_+_+',
           pseudobulk_variable = 'mese_mullerian_lowres',
           show_dot_plot = True,
           show_line_plot = False,
           color_dict = color_dict,
           use_pseudobulk = True,
           auc_key = 'eRegulon_AUC',
           signature_key = 'Region_based',
           seed=555)

In [None]:
# Gene based
%matplotlib inline
import seaborn as sns
sns.set_style("white")
categories = sorted(set(scplus_obj.metadata_cell['mese_mullerian_lowres']))
print(categories)
print(color_dict)
prune_plot(scplus_obj,
           'HOXA13_+_+',
           pseudobulk_variable = 'mese_mullerian_lowres',
           show_dot_plot = True,
           show_line_plot = False,
           color_dict = color_dict,
           use_pseudobulk = True,
           auc_key = 'eRegulon_AUC',
           signature_key = 'Gene_based',
           seed=555)

### Identification of high quality regions

In [None]:
# Correlation between region based regulons and gene based regulons
import pandas
df1 = scplus_obj.uns['eRegulon_AUC']['Gene_based'].copy()
df2 = scplus_obj.uns['eRegulon_AUC']['Region_based'].copy()
# Strip the trailing "_(<n>g)" / "_(<n>r)" size suffix so that the gene-based and
# region-based columns of the same eRegulon share one name.
df1.columns = [x.split('_(')[0] for x in df1.columns]
df2.columns = [x.split('_(')[0] for x in df2.columns]
correlations = df1.corrwith(df2, axis = 0)
correlations = correlations[abs(correlations) > 0.6]
# Keep only R2G +
keep = [x for x in correlations.index if '+_+' in x] + [x for x in correlations.index if '-_+' in x]
# Keep extended if not direct
extended = [x for x in keep if 'extended' in x]
direct = [x for x in keep if not 'extended' in x]
keep_extended = [x for x in extended if not x.replace('extended_', '') in direct]
keep = direct + keep_extended
# Keep regulons with more than 10 genes
keep_gene = [x for x in scplus_obj.uns['eRegulon_AUC']['Gene_based'].columns if x.split('_(')[0] in keep]
keep_gene = [x for x in keep_gene if (int(x.split('_(')[1].replace('g)', '')) > 10)]
keep_all = [x.split('_(')[0] for x in keep_gene]
# Select the region-based regulons through keep_all (the names that passed the
# gene-count filter) rather than through the unfiltered `keep`, so that both
# selections describe the same set of eRegulons.
keep_region = [x for x in scplus_obj.uns['eRegulon_AUC']['Region_based'].columns if x.split('_(')[0] in keep_all]
scplus_obj.uns['selected_eRegulons'] = {}
scplus_obj.uns['selected_eRegulons']['Gene_based'] = keep_gene
scplus_obj.uns['selected_eRegulons']['Region_based'] = keep_region

In [None]:
print(len(keep_gene))
print(len(keep_region))

In [None]:
%matplotlib inline


In [None]:
from scenicplus.plotting.correlation_plot import *
correlation_heatmap(scplus_obj,
                    auc_key = 'eRegulon_AUC',
                    signature_keys = ['Gene_based'],
                    selected_regulons = scplus_obj.uns['selected_eRegulons']['Gene_based'],
                    fcluster_threshold = 0.1,
                    fontsize = 8, 
                   save = outDir + 'correlation_heatmap.pdf')

In [None]:
#from scenicplus.plotting.correlation_plot import *
jaccard_heatmap(scplus_obj,
                    gene_or_region_based = 'Gene_based',
                    signature_key = 'eRegulon_signatures',
                    selected_regulons = scplus_obj.uns['selected_eRegulons']['Gene_based'],
                    fcluster_threshold = 0.1,
                    fontsize = 8,
                    method='intersect', 
               save = outDir + 'jaccard_heatmap.pdf')

In [None]:
binarize_AUC(scplus_obj,
             auc_key='eRegulon_AUC',
             out_key='eRegulon_AUC_thresholds',
             signature_keys=['Gene_based', 'Region_based'],
             n_cpu=20)

In [None]:
import dill
# Checkpoint: serialise the SCENIC+ object to outDir.
checkpoint_path = outDir + 'scplus_obj2.pkl'
with open(checkpoint_path, 'wb') as f:
    dill.dump(scplus_obj, f)

In [None]:
import dill
# Reload the pickled SCENIC+ object. The context manager closes the file even
# if unpickling raises, which the manual open()/close() pair did not.
with open(outDir + 'scplus_obj2.pkl', 'rb') as infile:
    scplus_obj = dill.load(infile)

In [None]:
from scenicplus.dimensionality_reduction import *
# UMAP and t-SNE on the eRegulon signatures (scale=True), using the default
# reduction names.
# NOTE(review): both signature sets are requested but only the Gene_based
# selection is passed as selected_regulons - confirm this is intended.
for run_embedding in (run_eRegulons_umap, run_eRegulons_tsne):
    run_embedding(
        scplus_obj,
        scale=True,
        signature_keys=['Gene_based', 'Region_based'],
        selected_regulons=scplus_obj.uns['selected_eRegulons']['Gene_based'],
    )

In [None]:
# One UMAP and one t-SNE per signature type, each restricted to its own
# selected eRegulons and stored under its own reduction name.
for signature, umap_name, tsne_name in (
    ('Gene_based', 'eRegulons_UMAP_gb', 'eRegulons_tSNE_gb'),
    ('Region_based', 'eRegulons_UMAP_rb', 'eRegulons_tSNE_rb'),
):
    run_eRegulons_umap(
        scplus_obj,
        scale=True,
        signature_keys=[signature],
        reduction_name=umap_name,
        selected_regulons=scplus_obj.uns['selected_eRegulons'][signature],
    )
    run_eRegulons_tsne(
        scplus_obj,
        scale=True,
        signature_keys=[signature],
        reduction_name=tsne_name,
        selected_regulons=scplus_obj.uns['selected_eRegulons'][signature],
    )

In [None]:
from scenicplus.dimensionality_reduction import *

In [None]:
from scenicplus.dimensionality_reduction import *
# Region-based eRegulon UMAP coloured by 'mese_mullerian_lowres', saved as PDF.
# A custom palette can be supplied with
#   color_dictionary = {'mese_mullerian_lowres' : color_dict}
annotation_columns = ['mese_mullerian_lowres']
plot_metadata(
    scplus_obj,
    reduction_name='eRegulons_UMAP_rb',
    variables=annotation_columns,
    num_columns=1,
    text_size=10,
    dot_size=5,
    figsize=(5, 5),
    save=outDir + 'umap_regulons.pdf',
)


In [None]:
from scenicplus.dimensionality_reduction import *
# Region-based eRegulon t-SNE coloured by 'mese_mullerian_lowres', saved as PDF.
# A custom palette can be supplied with
#   color_dictionary = {'mese_mullerian_lowres' : color_dict}
annotation_columns = ['mese_mullerian_lowres']
plot_metadata(
    scplus_obj,
    reduction_name='eRegulons_tSNE_rb',
    variables=annotation_columns,
    num_columns=1,
    text_size=10,
    dot_size=5,
    figsize=(5, 5),
    save=outDir + 'tsne_regulons.pdf',
)


In [None]:
# Cluster cells on both eRegulon signature types with k = 10 at three
# resolutions; the results are named with the 'SCENIC+_' prefix.
cluster_resolutions = [0.6, 1.2, 1.5]
find_clusters(
    scplus_obj,
    signature_keys=['Gene_based', 'Region_based'],
    k=10,
    res=cluster_resolutions,
    prefix='SCENIC+_',
    scale=True,
)

In [None]:
# Side-by-side view on the region-based t-SNE: the cell-type annotation next
# to the SCENIC+ clustering column 'SCENIC+_leiden_10_0.6'.
panels = ['mese_mullerian_lowres', 'SCENIC+_leiden_10_0.6']
plot_metadata(
    scplus_obj,
    reduction_name='eRegulons_tSNE_rb',
    variables=panels,
    num_columns=2,
    text_size=10,
    dot_size=5,
)

In [None]:
from scenicplus.RSS import *
# Regulon specificity scores per 'mese_mullerian_lowres' group for the
# selected gene-based eRegulons, unscaled; the cells below read the result
# from scplus_obj.uns['RSS']['mese_mullerian_lowres_gene_based'].
selected_gene_regulons = scplus_obj.uns['selected_eRegulons']['Gene_based']
regulon_specificity_scores(
    scplus_obj,
    'mese_mullerian_lowres',
    signature_keys=['Gene_based'],
    selected_regulons=selected_gene_regulons,
    out_key_suffix='_gene_based',
    scale=False,
)

In [None]:
scplus_obj.uns['RSS']

In [None]:
# RSS plot with the top 10 regulons per group, two panels per row, saved as PDF.
# Optional restriction kept for reference:
#   selected_groups = ['MeseMullerianFallopianTube', 'MeseMullerianUterus']
plot_rss(
    scplus_obj,
    'mese_mullerian_lowres_gene_based',
    num_columns=2,
    top_n=10,
    figsize=(12, 12),
    fontsize=12,
    save=outDir + 'rss_importances.pdf',
)

In [None]:
# Gene-based RSS matrix with its rows put into a fixed cell-type order.
mat = scplus_obj.uns['RSS']['mese_mullerian_lowres_gene_based']
new_indices = ['FallopianMese', 'UterusMese', 'CervixMese', 'UpperVaginaMese']
mat_reordered = mat.reindex(index=new_indices)

In [None]:
mat_reordered

In [None]:
scplus_obj.uns['selected_eRegulons']

In [None]:
# Select only activators.
# eRegulon names carry sign tokens between underscores, e.g.
# 'HOXA9_+_+_(570g)'. Matching the '_-_' token instead of a bare '-' keeps TFs
# whose gene symbol itself contains a hyphen (e.g. NKX2-1) from being
# classified as repressors.
regs = scplus_obj.uns['selected_eRegulons']['Gene_based']
repressors = [r for r in regs if '_-_' in r]
repressor_names = set(repressors)
activators = [r for r in regs if r not in repressor_names]

In [None]:
activators

In [None]:
len(activators)

In [None]:
# Rank the activator eRegulons by RSS within each cell type and keep the
# 20 highest-scoring ones per cell type.
mat_reordered_activators = mat_reordered.loc[:, activators]
print(mat_reordered_activators.shape)
activators_per_celltype = {
    cell_type: []
    for cell_type in ('FallopianMese', 'UterusMese', 'CervixMese', 'UpperVaginaMese')
}

# One row of the RSS matrix per cell type.
for index, row in mat_reordered_activators.iterrows():
    print(index)
    # Column labels of the 20 largest RSS values in this row.
    top_columns = row.nlargest(20).index.tolist()
    print(top_columns)
    activators_per_celltype[index] += top_columns

    

In [None]:
# Sorted union of the per-cell-type top activators. Flatten explicitly rather
# than relying on numpy converting the list of lists into a 2-D array, which
# only happens when every list has the same length.
top_activators = sorted({name for names in activators_per_celltype.values() for name in names})

In [None]:
len(top_activators)

In [None]:
# RSS values restricted to the top activator columns; displayed as cell output.
mat_reordered_top_activators = mat_reordered_activators.loc[:, top_activators]
mat_reordered_top_activators

In [None]:
import matplotlib as mpl
# Font type 42 (TrueType) keeps text in saved PDFs editable.
mpl.rcParams.update({'pdf.fonttype': 42})

In [None]:
# Activator eRegulons whose name begins with 'HOX'.
hox_tfs = [name for name in activators if name[:3] == 'HOX']

In [None]:
hox_tfs

In [None]:
# Hand-ordered HOX activator eRegulons used for the HOX dot plot.
hox_tfs = [
    'HOXC4_+_+_(33g)',
    'HOXA5_+_+_(67g)',
    'HOXC6_+_+_(61g)',
    'HOXA7_+_+_(58g)',
    'HOXC8_+_+_(90g)',
    'HOXA9_+_+_(570g)',
    'HOXD9_+_+_(186g)',
    'HOXA10_+_+_(399g)',
    'HOXD10_+_+_(259g)',
    'HOXA11_+_+_(274g)',
    'HOXA13_+_+_(443g)',
    'HOXD13_+_+_(321g)',
]

In [None]:
from scenicplus.plotting.dotplot import *
# Horizontal dot plot of the HOX eRegulons: size matrix is the reordered RSS
# table, colour matrix is scplus_obj.to_df('EXP'), both scaled.
cell_type_order = ['FallopianMese', 'UterusMese', 'CervixMese', 'UpperVaginaMese']
heatmap_dotplot(
    scplus_obj=scplus_obj,
    size_matrix=mat_reordered,
    color_matrix=scplus_obj.to_df('EXP'),
    scale_size_matrix=True,
    scale_color_matrix=True,
    group_variable='mese_mullerian_lowres',
    subset_eRegulons=hox_tfs,
    figsize=(10, 1.8),
    orientation='horizontal',
    split_repressor_activator=True,
    index_order=cell_type_order,
    save=outDir + 'mese_importances_heatmap_hox.pdf',
)

In [None]:
from scenicplus.plotting.dotplot import *
# Vertical dot plot of the top activator eRegulons: size matrix is the
# reordered RSS table, colour matrix is scplus_obj.to_df('EXP'), both scaled.
cell_type_order = ['FallopianMese', 'UterusMese', 'CervixMese', 'UpperVaginaMese']
heatmap_dotplot(
    scplus_obj=scplus_obj,
    size_matrix=mat_reordered,
    color_matrix=scplus_obj.to_df('EXP'),
    scale_size_matrix=True,
    scale_color_matrix=True,
    group_variable='mese_mullerian_lowres',
    subset_eRegulons=top_activators,
    figsize=(5.5, 24),
    orientation='vertical',
    split_repressor_activator=True,
    index_order=cell_type_order,
    save=outDir + 'mese_importances_heatmap.pdf',
)

In [None]:
outDir

In [None]:
# Rank the activator eRegulons by RSS within each cell type, take the 20
# highest-scoring ones, then drop the HOX regulons from that selection.
mat_reordered_activators = mat_reordered.loc[:, activators]
print(mat_reordered_activators.shape)
activators_per_celltype = {
    cell_type: []
    for cell_type in ('FallopianMese', 'UterusMese', 'CervixMese', 'UpperVaginaMese')
}

# One row of the RSS matrix per cell type.
for index, row in mat_reordered_activators.iterrows():
    print(index)
    # Column labels of the 20 largest RSS values in this row.
    top_columns = row.nlargest(20).index.tolist()
    hox = list(filter(lambda name: name.startswith('HOX'), top_columns))
    top_columns = [name for name in top_columns if name not in hox]
    print(top_columns)
    activators_per_celltype[index] += top_columns

In [None]:
# One list of top activators per cell type (flattened in the next cell).
# The lists can differ in length once HOX regulons are removed, so np.unique
# is avoided: depending on the lengths and the numpy version it raises,
# returns the lists, or flattens them to strings.
top_activators = [list(names) for names in activators_per_celltype.values()]

In [None]:
# Flatten the per-cell-type lists into one list without duplicates, keeping
# first-seen order. Built directly from activators_per_celltype so the result
# does not depend on the shape top_activators had beforehand.
top_activators = list(dict.fromkeys(
    name for names in activators_per_celltype.values() for name in names
))

In [None]:
# Drop every regulon that is listed in hox_tfs.
hox_regulons = set(hox_tfs)
top_activators = [name for name in top_activators if name not in hox_regulons]

In [None]:
# Drop regulons whose name begins with 'ZNF'.
top_activators = list(filter(lambda name: not name.startswith('ZNF'), top_activators))

In [None]:
len(top_activators)

In [None]:
# TF symbols compared against the top activators below.
spatially_variable_tfs = [
    'PROX1', 'GATA6', 'NFATC2', 'LEF1', 'FOXL2', 'MEIS2',
    'EMX2', 'FOXO1', 'ESR1', 'RORB', 'HMGA2', 'MSX1',
    'AR', 'TWIST1', 'ESRRG', 'RUNX1', 'PRRX2', 'TWIST2', 'LBX2',
    'PBX3', 'AHR', 'EVX1', 'EVX2', 'IRF6', 'NR0B1', 'ISL1', 'HMBOX1', 'ASCL2',
    'TBX18',
]

In [None]:
# TF symbol = text before the first underscore of each eRegulon name.
top_activators_tfs = [name.partition('_')[0] for name in top_activators]

In [None]:
len(top_activators_tfs)

In [None]:
# Top-activator TFs that also appear in spatially_variable_tfs.
variable_tf_set = set(spatially_variable_tfs)
top_activators_tfs_variable = [tf for tf in top_activators_tfs if tf in variable_tf_set]

In [None]:
top_activators_tfs_variable

In [None]:
# Drop regulons whose name begins with 'ZNF' (same filter as applied earlier;
# running it again leaves an already filtered list unchanged).
top_activators = list(filter(lambda name: not name.startswith('ZNF'), top_activators))

In [None]:
len(top_activators)

In [None]:
# spatially_variable = ['GATA6_+_+_(150g)', 'PROX1_+_+_(49g)','NFATC2_+_+_(308g)',
#                       'FOXL2_+_+_(193g)', 'EMX2_+_+_(268g)', 'FOXO1_+_+_(299g)', 
#                     'HMGA2_+_+_(281g)', 
#                       'PBX3_+_+_(86g)', 'PRRX2_+_+_(126g)', 'EVX1_+_+_(116g)', 'EVX2_+_+_(74g)',  'LBX2_+_+_(39g)', 
#                        'AR_+_+_(68g)', 'AHR_+_+_(141g)', 'ISL1_+_+_(55g)', 'TCF21_+_+_(144g)',
#                     'ASCL2_+_+_(68g)', 'TWIST2_extended_+_+_(149g)', 'IRF6_extended_+_+_(55g)'
#                      ]

In [None]:
# mat_reordered_spatially_variable = mat_reordered[spatially_variable]
# mat_reordered_spatially_variable.shape

In [None]:
# RSS values restricted to the filtered top activator columns; displayed as
# cell output.
mat_reordered_top_activators = mat_reordered_activators.loc[:, top_activators]
mat_reordered_top_activators

In [None]:
from scenicplus.plotting.dotplot import *
# Horizontal dot plot of the filtered top activators, sorted by colour value:
# size matrix is the reordered RSS table, colour matrix is
# scplus_obj.to_df('EXP'), both scaled.
cell_type_order = ['FallopianMese', 'UterusMese', 'CervixMese', 'UpperVaginaMese']
heatmap_dotplot(
    scplus_obj=scplus_obj,
    size_matrix=mat_reordered,
    color_matrix=scplus_obj.to_df('EXP'),
    scale_size_matrix=True,
    scale_color_matrix=True,
    group_variable='mese_mullerian_lowres',
    subset_eRegulons=top_activators,
    figsize=(25, 4),
    orientation='horizontal',
    split_repressor_activator=True,
    sort_by='color_val',
    index_order=cell_type_order,
    save=outDir + 'mese_importances_heatmap_top25.pdf',
)

In [None]:
outDir 

In [None]:
import dill
# Checkpoint: serialise the SCENIC+ object to outDir.
checkpoint_path = outDir + 'scplus_obj2.pkl'
with open(checkpoint_path, 'wb') as f:
    dill.dump(scplus_obj, f)

In [None]:
import dill
# Reload the pickled SCENIC+ object. The context manager closes the file even
# if unpickling raises, which the manual open()/close() pair did not.
with open(outDir + 'scplus_obj2.pkl', 'rb') as infile:
    scplus_obj = dill.load(infile)

## Visualisations in scanpy-compatible format for figures

In [None]:
# Load the cleaned cisTopic object; the context manager closes the file
# handle, which the inline open() call left open.
with open(os.path.join(outDir, 'cisTopicObject_clean.pkl'), 'rb') as cistopic_file:
    cistopic_obj = dill.load(cistopic_file)

In [None]:
import scanpy

In [None]:
# Work on a copy of the cisTopic cell metadata so that columns added below
# do not end up in cistopic_obj.cell_data.
annots = cistopic_obj.cell_data.copy()

In [None]:
# Attach the harmony t-SNE coordinates to the metadata, matched on the index.
harmony_tsne = cistopic_obj.projections['cell']['harmony_tSNE']
for column, axis in (('tsne1', 'tSNE_1'), ('tsne2', 'tSNE_2')):
    annots[column] = annots.index.map(harmony_tsne[axis].to_dict())

In [None]:
annots.shape

In [None]:
# Write the annots table to a CSV file in outDir.
annots.to_csv(outDir + 'mull_mese_embedding.csv')

### Network analysis

In [None]:
# eRegulon metadata table; the cells below read its 'TF' and 'Gene' columns.
df = scplus_obj.uns['eRegulon_metadata']

#### Fallopian tube mesenchyme

In [None]:
# Candidate genes for the fallopian tube network; only those that are
# targets of the chosen TFs are kept further down.
spatially_variable_interactors = [
    'LGR5',
    'NTRK2',
    'CD36',
    'CD55',
    'ALDH1A2',
    'DLK1',
    'NRG1',
    'WNT4',
    'BMP4',
    'BMP7',
]

In [None]:
# TFs whose eRegulons seed the fallopian tube network (used to look up target
# genes and passed as subset_eRegulons below).
tfs = ['HOXA5', 'HOXC5', 'HOXA7', 'HOXC6']

In [None]:
import numpy as np

In [None]:
# Unique target genes of the selected TFs according to the eRegulon metadata.
tf_rows = df['TF'].isin(tfs)
targets = np.unique(df.loc[tf_rows, 'Gene'].tolist())

In [None]:
len(targets)

In [None]:
# Keep the candidate interactors that are targets of the selected TFs.
final = list(filter(lambda gene: gene in targets, spatially_variable_interactors))
print(final)

In [None]:
len(final)

In [None]:
from scenicplus.networks import *
import networkx as nx
# Network tables for the selected TFs, limited to the retained interactor
# genes, with differential expression and accessibility computed per
# 'mese_mullerian_lowres' group.
subset_genes = final
nx_tables = create_nx_tables(
    scplus_obj,
    eRegulon_metadata_key='eRegulon_metadata',
    subset_eRegulons=tfs,
    subset_regions=None,
    subset_genes=subset_genes,
    add_differential_gene_expression=True,
    add_differential_region_accessibility=True,
    differential_variable=['mese_mullerian_lowres'],
)


In [None]:
tfs

In [None]:
from scenicplus.networks import *


def _fixed_per_node_type(attribute, tf, gene, region):
    """Return the per-node-type styling dict for a fixed-value attribute."""
    return {
        node_type: {'variable': attribute, attribute: value}
        for node_type, value in (('TF', tf), ('Gene', gene), ('Region', region))
    }


# Fallopian tube network: TF-to-region edges coloured per TF, region-to-gene
# edges coloured by R2G_rho; gene and region nodes shaded by the
# FallopianMese Log2FC column.
_tf_colors = {'HOXA7': 'orchid', 'HOXA5': 'orchid', 'HOXC6': 'orchid', 'HOXC5': 'orchid'}
_log2fc = 'mese_mullerian_lowres_Log2FC_FallopianMese'

G_kk, pos_kk, edge_tables_kk, node_tables_kk = create_nx_graph(
    nx_tables,
    use_edge_tables=['TF2R', 'R2G'],
    color_edge_by={
        'TF2R': {'variable': 'TF', 'category_color': dict(_tf_colors)},
        'R2G': {'variable': 'R2G_rho', 'continuous_color': 'viridis', 'v_min': -1, 'v_max': 1},
    },
    transparency_edge_by={'R2G': {'variable': 'R2G_importance', 'min_alpha': 0.6, 'v_min': 0}},
    width_edge_by={'R2G': {'variable': 'R2G_importance', 'max_size': 1.5, 'min_size': 1}},
    color_node_by={
        'TF': {'variable': 'TF', 'category_color': dict(_tf_colors)},
        'Gene': {'variable': _log2fc, 'continuous_color': 'Blues'},
        'Region': {'variable': _log2fc, 'continuous_color': 'Blues'},
    },
    transparency_node_by={
        'Region': {'variable': _log2fc, 'min_alpha': 0.2},
        'Gene': {'variable': _log2fc, 'min_alpha': 0.2},
    },
    size_node_by=_fixed_per_node_type('fixed_size', 60, 50, 30),
    shape_node_by=_fixed_per_node_type('fixed_shape', 'ellipse', 'ellipse', 'diamond'),
    # A label size of 0.0 hides the region labels in the drawing cell.
    label_size_by=_fixed_per_node_type('fixed_label_size', 15.0, 10.0, 0.0),
    label_color_by=_fixed_per_node_type('fixed_label_color', 'black', 'black', 'darkgray'),
    layout='kamada_kawai_layout',
    scale_position_by=500,
)

In [None]:
edge_tables_kk

In [None]:
import matplotlib as mpl
# Font type 42 (TrueType) keeps text in saved PDFs editable.
mpl.rcParams.update({'pdf.fonttype': 42})

In [None]:
# Draw the fallopian tube network: nodes (all with the 'D' marker), edges,
# then a text label for every node whose font size is non-zero; save as PDF.
node_attr = {key: nx.get_node_attributes(G_kk, key) for key in ('color', 'size', 'font', 'label')}
edge_attr = {key: nx.get_edge_attributes(G_kk, key) for key in ('color', 'width')}
nx.draw_networkx_nodes(
    G_kk, pos_kk,
    node_color=node_attr['color'].values(),
    node_size=list(node_attr['size'].values()),
    node_shape='D',
)
nx.draw_networkx_edges(
    G_kk, pos_kk,
    edge_color=edge_attr['color'].values(),
    width=list(edge_attr['width'].values()),
)
# (label, font) pairs for the nodes that should be labelled.
labelled = [
    (label, font)
    for font, label in zip(node_attr['font'].values(), node_attr['label'].values())
    if font['size'] != 0.0
]
fontsize_d = {label: font['size'] for label, font in labelled}
fontcolor_d = {label: font['color'] for label, font in labelled}
for node, (x, y) in pos_kk.items():
    if node not in fontsize_d:
        continue
    plt.text(x, y, node, fontsize=fontsize_d[node], color=fontcolor_d[node], ha='center', va='center')
ax = plt.gca()
ax.margins(0.11)
plt.tight_layout()
plt.axis("off")
plt.savefig('/home/jovyan/network_scenicplus_mull_mese_fallopiantube.pdf', bbox_inches='tight', dpi=1000)
plt.show()

#### Uterus mesenchyme

In [None]:
# Candidate genes for the uterus network; only those that are targets of the
# chosen TFs are kept further down. 'FLRT2' replaces the misspelt 'FLTR2',
# which is not a gene symbol and was therefore always filtered out.
spatially_variable_interactors = ['WNT4', 'WNT5A', 'CDH3', 'FLRT2', 'GRIA4',
                                  'TGM2', 'ALDH1A1', 'LRRTM1', 'NRP1', 'RORB']

In [None]:
# TFs whose eRegulons seed the uterus network (used to look up target genes
# and passed as subset_eRegulons below).
tfs = ['HOXA10', 'HOXA9', 'HOXA11']

In [None]:
# Unique target genes of the selected TFs according to the eRegulon metadata.
tf_rows = df['TF'].isin(tfs)
targets = np.unique(df.loc[tf_rows, 'Gene'].tolist())

In [None]:
# Keep the candidate interactors that are targets of the selected TFs.
final = list(filter(lambda gene: gene in targets, spatially_variable_interactors))
print(final)

In [None]:
from scenicplus.networks import *
import networkx as nx
# Network tables for the selected TFs, limited to the retained interactor
# genes, with differential expression and accessibility computed per
# 'mese_mullerian_lowres' group.
subset_genes = final
nx_tables = create_nx_tables(
    scplus_obj,
    eRegulon_metadata_key='eRegulon_metadata',
    subset_eRegulons=tfs,
    subset_regions=None,
    subset_genes=subset_genes,
    add_differential_gene_expression=True,
    add_differential_region_accessibility=True,
    differential_variable=['mese_mullerian_lowres'],
)


In [None]:
from scenicplus.networks import *


def _fixed_per_node_type(attribute, tf, gene, region):
    """Return the per-node-type styling dict for a fixed-value attribute."""
    return {
        node_type: {'variable': attribute, attribute: value}
        for node_type, value in (('TF', tf), ('Gene', gene), ('Region', region))
    }


# Uterus network: TF-to-region edges coloured per TF, region-to-gene edges
# coloured by R2G_rho; gene and region nodes shaded by the UterusMese
# Log2FC column.
_tf_colors = {'HOXA10': 'orange', 'HOXA11': 'orange', 'HOXA9': 'orange'}
_log2fc = 'mese_mullerian_lowres_Log2FC_UterusMese'

G_kk, pos_kk, edge_tables_kk, node_tables_kk = create_nx_graph(
    nx_tables,
    use_edge_tables=['TF2R', 'R2G'],
    color_edge_by={
        'TF2R': {'variable': 'TF', 'category_color': dict(_tf_colors)},
        'R2G': {'variable': 'R2G_rho', 'continuous_color': 'viridis', 'v_min': -1, 'v_max': 1},
    },
    transparency_edge_by={'R2G': {'variable': 'R2G_importance', 'min_alpha': 0.4, 'v_min': 0}},
    width_edge_by={'R2G': {'variable': 'R2G_importance', 'max_size': 1.5, 'min_size': 1}},
    color_node_by={
        'TF': {'variable': 'TF', 'category_color': dict(_tf_colors)},
        'Gene': {'variable': _log2fc, 'continuous_color': 'Blues'},
        'Region': {'variable': _log2fc, 'continuous_color': 'Blues'},
    },
    transparency_node_by={
        'Region': {'variable': _log2fc, 'min_alpha': 0.2},
        'Gene': {'variable': _log2fc, 'min_alpha': 0.2},
    },
    size_node_by=_fixed_per_node_type('fixed_size', 60, 50, 30),
    shape_node_by=_fixed_per_node_type('fixed_shape', 'ellipse', 'ellipse', 'diamond'),
    # A label size of 0.0 hides the region labels in the drawing cell.
    label_size_by=_fixed_per_node_type('fixed_label_size', 20.0, 15.0, 0.0),
    label_color_by=_fixed_per_node_type('fixed_label_color', 'black', 'black', 'darkgray'),
    layout='kamada_kawai_layout',
    scale_position_by=500,
)

In [None]:
edge_tables_kk

In [None]:
# Draw the uterus network: nodes (all with the 'D' marker), edges, then a
# text label for every node whose font size is non-zero; save as PDF.
node_attr = {key: nx.get_node_attributes(G_kk, key) for key in ('color', 'size', 'font', 'label')}
edge_attr = {key: nx.get_edge_attributes(G_kk, key) for key in ('color', 'width')}
nx.draw_networkx_nodes(
    G_kk, pos_kk,
    node_color=node_attr['color'].values(),
    node_size=list(node_attr['size'].values()),
    node_shape='D',
)
nx.draw_networkx_edges(
    G_kk, pos_kk,
    edge_color=edge_attr['color'].values(),
    width=list(edge_attr['width'].values()),
)
# (label, font) pairs for the nodes that should be labelled.
labelled = [
    (label, font)
    for font, label in zip(node_attr['font'].values(), node_attr['label'].values())
    if font['size'] != 0.0
]
fontsize_d = {label: font['size'] for label, font in labelled}
fontcolor_d = {label: font['color'] for label, font in labelled}
for node, (x, y) in pos_kk.items():
    if node not in fontsize_d:
        continue
    plt.text(x, y, node, fontsize=fontsize_d[node], color=fontcolor_d[node], ha='center', va='center')
ax = plt.gca()
ax.margins(0.11)
plt.tight_layout()
plt.axis("off")
plt.savefig('/home/jovyan/network_scenicplus_mull_mese_uterus.pdf', bbox_inches='tight', dpi=1000)
plt.show()

#### Upper vagina mesenchyme

In [None]:
# Candidate genes for the upper vagina network; only those that are targets
# of the chosen TFs are kept further down.
spatially_variable_interactors = [
    'GDF7',
    'GDF10',
    'COL26A1',
    'TNC',
    'WIF1',
    'SFRP5',
    'IGF1',
    'BMP4',
    'BMP7',
]

In [None]:
# TFs whose eRegulons seed the upper vagina network (used to look up target
# genes and passed as subset_eRegulons below).
tfs = ['HOXA13', 'HOXD13']

In [None]:
# Unique target genes of the selected TFs according to the eRegulon metadata.
tf_rows = df['TF'].isin(tfs)
targets = np.unique(df.loc[tf_rows, 'Gene'].tolist())

In [None]:
# Keep the candidate interactors that are targets of the selected TFs.
final = list(filter(lambda gene: gene in targets, spatially_variable_interactors))
print(final)

In [None]:
from scenicplus.networks import *
import networkx as nx
# Network tables for the selected TFs, limited to the retained interactor
# genes, with differential expression and accessibility computed per
# 'mese_mullerian_lowres' group.
subset_genes = final
nx_tables = create_nx_tables(
    scplus_obj,
    eRegulon_metadata_key='eRegulon_metadata',
    subset_eRegulons=tfs,
    subset_regions=None,
    subset_genes=subset_genes,
    add_differential_gene_expression=True,
    add_differential_region_accessibility=True,
    differential_variable=['mese_mullerian_lowres'],
)


In [None]:
from scenicplus.networks import *


def _fixed_per_node_type(attribute, tf, gene, region):
    """Return the per-node-type styling dict for a fixed-value attribute."""
    return {
        node_type: {'variable': attribute, attribute: value}
        for node_type, value in (('TF', tf), ('Gene', gene), ('Region', region))
    }


# Upper vagina network: TF-to-region edges coloured per TF, region-to-gene
# edges coloured by R2G_rho; gene and region nodes shaded by the
# UpperVaginaMese Log2FC column.
_tf_colors = {'HOXA13': 'yellowgreen', 'HOXD13': 'yellowgreen'}
_log2fc = 'mese_mullerian_lowres_Log2FC_UpperVaginaMese'

G_kk, pos_kk, edge_tables_kk, node_tables_kk = create_nx_graph(
    nx_tables,
    use_edge_tables=['TF2R', 'R2G'],
    color_edge_by={
        'TF2R': {'variable': 'TF', 'category_color': dict(_tf_colors)},
        'R2G': {'variable': 'R2G_rho', 'continuous_color': 'viridis', 'v_min': -1, 'v_max': 1},
    },
    transparency_edge_by={'R2G': {'variable': 'R2G_importance', 'min_alpha': 0.4, 'v_min': 0}},
    width_edge_by={'R2G': {'variable': 'R2G_importance', 'max_size': 1.5, 'min_size': 1}},
    color_node_by={
        'TF': {'variable': 'TF', 'category_color': dict(_tf_colors)},
        'Gene': {'variable': _log2fc, 'continuous_color': 'Blues'},
        'Region': {'variable': _log2fc, 'continuous_color': 'Blues'},
    },
    transparency_node_by={
        'Region': {'variable': _log2fc, 'min_alpha': 0.2},
        'Gene': {'variable': _log2fc, 'min_alpha': 0.2},
    },
    size_node_by=_fixed_per_node_type('fixed_size', 60, 50, 30),
    shape_node_by=_fixed_per_node_type('fixed_shape', 'ellipse', 'ellipse', 'diamond'),
    # A label size of 0.0 hides the region labels in the drawing cell.
    label_size_by=_fixed_per_node_type('fixed_label_size', 20.0, 15.0, 0.0),
    label_color_by=_fixed_per_node_type('fixed_label_color', 'black', 'black', 'darkgray'),
    layout='kamada_kawai_layout',
    scale_position_by=500,
)

In [None]:
edge_tables_kk

In [None]:
# Draw the upper vagina network: nodes (all with the 'D' marker), edges, then
# a text label for every node whose font size is non-zero; save as PDF.
node_attr = {key: nx.get_node_attributes(G_kk, key) for key in ('color', 'size', 'font', 'label')}
edge_attr = {key: nx.get_edge_attributes(G_kk, key) for key in ('color', 'width')}
nx.draw_networkx_nodes(
    G_kk, pos_kk,
    node_color=node_attr['color'].values(),
    node_size=list(node_attr['size'].values()),
    node_shape='D',
)
nx.draw_networkx_edges(
    G_kk, pos_kk,
    edge_color=edge_attr['color'].values(),
    width=list(edge_attr['width'].values()),
)
# (label, font) pairs for the nodes that should be labelled.
labelled = [
    (label, font)
    for font, label in zip(node_attr['font'].values(), node_attr['label'].values())
    if font['size'] != 0.0
]
fontsize_d = {label: font['size'] for label, font in labelled}
fontcolor_d = {label: font['color'] for label, font in labelled}
for node, (x, y) in pos_kk.items():
    if node not in fontsize_d:
        continue
    plt.text(x, y, node, fontsize=fontsize_d[node], color=fontcolor_d[node], ha='center', va='center')
ax = plt.gca()
ax.margins(0.11)
plt.tight_layout()
plt.axis("off")
plt.savefig('/home/jovyan/network_scenicplus_mull_mese_uppervagina.pdf', bbox_inches='tight', dpi=1000)
plt.show()