# Libraries

In [None]:
import os
import sys
import re
import pickle
import random
import subprocess
import time
import threading
import shutil
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, wait, ALL_COMPLETED, as_completed
from datetime import datetime, timedelta
from multiprocessing import Process, Pool

import numpy as np
import pandas as pd
import anndata as ad
import h5py
# import Bio
# from Bio import motifs
# import pysam
import pyranges
import pybedtools
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import sklearn
from sklearn import preprocessing
import scipy
from scipy import io
import scanpy as sc
from sklearn.cluster import KMeans
# from adjustText import adjust_text
# import episcanpy
import ruamel.yaml
yaml = ruamel.yaml.YAML(typ="safe")
yaml.default_flow_style = False
from matplotlib_venn import venn3, venn2, venn3_unweighted, venn2_unweighted

import SCRIP
from SCRIP.utilities import utils
from SCRIP.utilities.utils import print_log, safe_makedirs, excute_info, read_pickle, read_SingleCellExperiment_rds, store_to_pickle

import warnings
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)
# warnings.simplefilter(action='ignore', category=subprocess.)

# Global matplotlib defaults applied to every figure created after this point:
# square 8x8 figures, Arial at normal weight, and no top/right axis spines.
plt.rcParams.update({
    'figure.figsize': [8.0, 8.0],
    'font.size' : 15,
    'font.family': 'Arial',
    'font.style' : 'normal',
    'font.weight':'normal',
    'figure.titleweight': 'normal',
    'axes.labelsize': 14 ,
    'axes.titleweight': 'normal',
    'axes.labelweight': 'normal',
    'axes.spines.right': False,
    'axes.spines.top': False,
})

# Custom 256-step colormap: each RGB channel is interpolated linearly from
# 220/256 (light grey) to the target colour (34, 7, 141)/256; alpha stays 1.
N = 256
vals = np.column_stack([
    np.linspace(220/256, 34/256, N),   # red channel
    np.linspace(220/256, 7/256, N),    # green channel
    np.linspace(220/256, 141/256, N),  # blue channel
    np.ones(N),                        # alpha channel
])
regulation_cmp = mpl.colors.ListedColormap(vals)

In [None]:
import anndata2ri

# Function definitions

In [None]:
from sklearn.neighbors import KDTree
from sklearn.neighbors import BallTree

def find_nearest_cells(q_point, tree, n_neighbor):
    """Return the indices of the ``n_neighbor`` nearest neighbours of ``q_point``.

    ``tree.query`` is asked for one extra neighbour and the first hit of the
    first query row is dropped (presumably the query point itself when it is
    part of the tree -- confirm against the caller).
    """
    neighbor_indices = tree.query(q_point, k=n_neighbor + 1)[1]
    first_query_hits = neighbor_indices[0]
    return first_query_hits[1:]

def cal_neighbor_cell_peak_mat(sub_mat, input_mat, tree, pc_table, impute_n, start_idx, i):
    """Fill ``sub_mat`` with neighbour-summed rows of ``input_mat``.

    Row ``r`` of ``sub_mat`` corresponds to cell ``start_idx + r``. For each
    such cell the ``impute_n`` nearest cells are looked up in ``tree`` using the
    cell's row of ``pc_table``, and their rows of ``input_mat`` are summed.

    ``i`` is accepted but not used in the body.
    Returns the (mutated) ``sub_mat``.
    """
    for row_offset in range(sub_mat.shape[0]):
        query_point = np.reshape(pc_table[start_idx + row_offset, :], (1, -1))
        nearest_bc_idx = find_nearest_cells(query_point, tree, n_neighbor=impute_n)
        sub_mat[row_offset, :] = input_mat[nearest_bc_idx, :].sum(0)
    return sub_mat

def cal_neighbor_cell_peak_mat_batch(input_mat, impute_n=5, KD_leafsize=80, nPC = 50, n_cores=8):
    '''
    Sum each cell's row with the rows of its nearest neighbours, in parallel.

    input_mat:
        a csr sparse matrix, which can be obtained from adata.X
    impute_n:
        number of neighbouring cells whose rows are summed for each cell
    KD_leafsize:
        leaf size passed to the BallTree
    nPC:
        number of principal components used for the neighbour search
    n_cores:
        number of worker processes

    Returns a csr matrix with the same number of rows as ``input_mat``.
    '''
    print_log("Building KD tree...")
    # Honour nPC (it was previously ignored in favour of a hard-coded 50).
    pc_table = sc.tl.pca(input_mat, n_comps=nPC, svd_solver='arpack')
    tree = BallTree(pc_table, leaf_size=KD_leafsize)
    print_log("Calculating neighbors, divide into {n} chunks...".format(n=n_cores))
    cell_number = input_mat.shape[0]
    # A step of 0 (fewer cells than cores) would make range() raise ValueError.
    chunk_size = max(1, cell_number // n_cores)
    index_split = list(range(0, cell_number, chunk_size)) + [cell_number]
    # LIL format so the workers can assign whole rows cheaply.
    input_mat_lil = input_mat.tolil()
    input_mat_split = [input_mat_lil[start:stop, :] for start, stop in zip(index_split[:-1], index_split[1:])]
    args = [[sub_mat, input_mat, tree, pc_table, impute_n, index_split[i], i] for (i, sub_mat) in enumerate(input_mat_split)]
    with Pool(n_cores) as p:
        result = p.starmap(cal_neighbor_cell_peak_mat, args)
    cell_peak_csr = scipy.sparse.vstack(result).tocsr()
    print_log('Finished!')
    return cell_peak_csr

In [None]:
def generate_peak_list(cells, input_mat, peak_confidence=1):
    """Return the names of features present in at least ``peak_confidence`` of ``cells``.

    ``input_mat`` is sliced to ``cells`` and passed to ``sc.pp.filter_genes``
    with ``inplace=False``; the first element of its result is used to select
    entries of ``input_mat.var_names``.
    """
    selected_cells = input_mat[cells, :]
    filter_result = sc.pp.filter_genes(selected_cells, min_cells=peak_confidence, inplace=False)
    keep_mask = filter_result[0]
    return input_mat.var_names[keep_mask].to_list()


def generate_beds(file_path, cells, input_mat, peak_confidence=1):
    """Write the peaks of ``cells`` to a sorted, bgzip-compressed BED file.

    Peak names are split on their last two underscores into three columns,
    written to ``file_path``, then sorted and compressed to ``file_path + '.gz'``
    with the external ``sort`` and ``bgzip`` commands. The uncompressed file is
    removed afterwards. Nothing is written when there are no peaks.

    Returns ``[cell_barcode, number_of_peaks]`` where the barcode is the base
    name of ``file_path`` without its last four characters ('.bed').
    Raises ``subprocess.CalledProcessError`` if the shell command fails.
    """
    peaks = generate_peak_list(cells, input_mat, peak_confidence)
    cell_barcode = os.path.basename(file_path)[:-4]  # remove .bed
    if len(peaks) == 0:
        print_log('Warning: No peaks in {bed_path}, skip generation'.format(bed_path=file_path[:-4]))
    else:
        peaks = pd.DataFrame([p.rsplit("_", 2) for p in peaks])
        peaks.to_csv(file_path, sep="\t", header=None, index=None)
        # Paths are quoted like in the giggle search helpers, and the raw BED is
        # only removed when the sort | bgzip pipeline reported success.
        cmd = 'sort --buffer-size 2G -k1,1 -k2,2n -k3,3n "{bed_path}" | bgzip -c > "{bed_path}.gz"'.format(bed_path=file_path)
        cmd += ' && rm "{bed_path}"'.format(bed_path=file_path)
        subprocess.run(cmd, shell=True, check=True)
    return [cell_barcode, len(peaks)]


def generate_beds_by_matrix(cell_feature_adata, beds_path, peaks_number_path, n_cores):
    """Generate one compressed BED file per cell and record the peak counts.

    For every entry of ``cell_feature_adata.obs.index`` a ``generate_beds`` task
    is submitted to a thread pool of ``n_cores`` workers, writing to
    ``beds_path/<cell>.bed``. The ``[barcode, n_peaks]`` pairs returned by the
    tasks are written as a header-less TSV to ``peaks_number_path``.
    """
    safe_makedirs(beds_path)
    # The context manager waits for all tasks and shuts the pool down, so
    # worker threads are not leaked.
    with ThreadPoolExecutor(max_workers=n_cores) as executor:
        all_task = [
            executor.submit(generate_beds, beds_path + "/" + str(cell) + ".bed", cell, cell_feature_adata)
            for cell in cell_feature_adata.obs.index
        ]
    # Collect in submission order so the output file is deterministic;
    # result() re-raises any exception from a worker.
    pd.DataFrame([task.result() for task in all_task]).to_csv(peaks_number_path, header=None, index=None, sep='\t')
    return


def search_ref_factor(bed_path, result_path, index_path, factor):
    """Run ``giggle search`` for one query BED, restricted to one factor.

    The query ``bed_path`` is searched against ``index_path`` with ``-s`` and
    the filter ``-f <factor>_``; stdout is redirected to ``result_path``.
    Raises ``subprocess.CalledProcessError`` on a non-zero exit status.
    """
    # Alternative back-ends that were tried previously: igd, seqpare.
    giggle_cmd = 'giggle search -i "{index}" -q "{query}" -s -f {name}_ > "{out}"\n'.format(
        index=index_path, query=bed_path, name=factor, out=result_path)
    subprocess.run(giggle_cmd, shell=True, check=True)

def search_ref_factor_batch(bed_folder, result_folder, index_path, factor, n_cores=8, tp=''):
    """Run ``search_ref_factor`` for every file in ``bed_folder`` in parallel.

    Each file name has its last seven characters ('.bed.gz') replaced by
    '.txt' to form the result file inside ``result_folder``. ``tp`` is only
    used in the log messages.
    """
    print_log(f'Start searching beds from {tp} index ...')
    safe_makedirs(result_folder)
    jobs = [
        (os.path.join(bed_folder, bed_name),
         os.path.join(result_folder, bed_name[:-7] + '.txt'),  # strip '.bed.gz'
         index_path,
         factor)
        for bed_name in os.listdir(bed_folder)
    ]
    with Pool(n_cores) as workers:
        workers.starmap(search_ref_factor, jobs)
    print_log(f'Finished searching beds from {tp} index ...')

In [None]:
def search_ref(bed_path, result_path, index_path):
    """Run ``giggle search`` for one query BED against the whole index.

    The query ``bed_path`` is searched against ``index_path`` with ``-s``;
    stdout is redirected to ``result_path``.
    Raises ``subprocess.CalledProcessError`` on a non-zero exit status.
    """
    # Alternative back-ends that were tried previously: igd, seqpare.
    giggle_cmd = 'giggle search -i "{index}" -q "{query}" -s > "{out}"\n'.format(
        index=index_path, query=bed_path, out=result_path)
    subprocess.run(giggle_cmd, shell=True, check=True)

def search_ref_batch(bed_folder, result_folder, index_path, n_cores=8, tp=''):
    """Run ``search_ref`` for every file in ``bed_folder`` in parallel.

    Each file name has its last seven characters ('.bed.gz') replaced by
    '.txt' to form the result file inside ``result_folder``. ``tp`` is only
    used in the log messages.
    """
    print_log(f'Start searching beds from {tp} index ...')
    safe_makedirs(result_folder)
    jobs = [
        (os.path.join(bed_folder, bed_name),
         os.path.join(result_folder, bed_name[:-7] + '.txt'),  # strip '.bed.gz'
         index_path)
        for bed_name in os.listdir(bed_folder)
    ]
    with Pool(n_cores) as workers:
        workers.starmap(search_ref, jobs)
    print_log(f'Finished searching beds from {tp} index ...')


def read_search_result(files):
    """Collect the overlap counts of several giggle result files into one table.

    Each file is read as a header-less, tab-separated table (lines starting
    with '#' are skipped) whose first column is the index. Column 2 of every
    file becomes one output column, named after the file's base name without
    its last four characters ('.txt'). Rows follow the index of the first file;
    later files are aligned to it.

    The row labels are reduced to the base name of the index entry without its
    last seven characters ('.bed.gz').

    Returns an empty DataFrame when ``files`` is empty.
    """
    # Column layout of a result line:
    # 1 file_size 2 overlaps 3 odds_ratio 4 fishers_two_tail
    # 5 fishers_left_tail 6 fishers_right_tail 7 combo_score
    read_col = 2
    # Start from an empty frame so an empty chunk does not leave the result undefined.
    dataset_cell_score_df = pd.DataFrame()
    for file_idx, file_path in enumerate(files):
        cell_bc = os.path.basename(file_path)[:-4]  # remove suffix '.txt'
        dtframe = pd.read_csv(file_path, sep="\t", index_col=0, comment='#', header=None)
        if file_idx == 0:
            dataset_cell_score_df = dtframe.loc[:, [read_col]].rename(columns={read_col: cell_bc})
        else:
            dataset_cell_score_df[cell_bc] = dtframe.loc[:, read_col]
    # basename() keeps the file name whether or not the index entry carries a
    # directory prefix; the previous rsplit('/', 1)[0] kept the directory part.
    dataset_cell_score_df.index = [os.path.basename(str(i))[:-7] for i in dataset_cell_score_df.index]  # remove suffix '.bed.gz'
    return dataset_cell_score_df


def read_search_result_batch(path, n_cores=8, tp=''):
    """Read every search result file in ``path`` in parallel and join them.

    The file list is split into at most ``n_cores`` chunks, each chunk is read
    by ``read_search_result`` in a worker process, and the per-chunk tables are
    concatenated column-wise. ``tp`` is only used in the log message.

    Raises ``ValueError`` if ``path`` contains no files.
    """
    print_log(f"Reading searching results, using {n_cores} cores...")
    file_list = os.listdir(path)
    if len(file_list) == 0:
        raise ValueError(f"No search result files found in {path}")
    # array_split yields empty chunks when there are fewer files than cores;
    # drop them so no worker is asked to read an empty file list.
    result_split = [chunk for chunk in np.array_split(file_list, n_cores) if len(chunk) > 0]
    args = [[[os.path.join(path, j) for j in list_chunk]] for list_chunk in result_split]
    with Pool(min(n_cores, len(args))) as p:
        result = p.starmap(read_search_result, args)
    dataset_cell_score_df = pd.concat(result, axis=1)
    print_log(f"Finished reading {tp} index search result!")
    return dataset_cell_score_df

In [None]:
@excute_info('Getting the best reference for each cell.')
def get_factor_source(table):
    """For every factor, return the row label holding the maximum of each column.

    The factor of a row is the part of its index label before the first '_'.
    Rows are grouped by that factor and ``idxmax`` is taken per column. The
    input ``table`` is not modified.
    """
    ret_table = table.copy()
    ret_table.loc[:, "Factor"] = [row_label.split("_")[0] for row_label in ret_table.index]
    return ret_table.groupby("Factor").idxmax()


def cal_score(dataset_overlap_df, peaks_number):
    '''
    Turn raw overlap counts into normalised scores, in three steps:

    1. divide every row by the value in column 1 of ``peaks_number`` for that
       row label;
    2. divide every column by its sum and multiply by 1e4;
    3. divide every row by its mean.
    '''
    row_peak_counts = peaks_number.loc[dataset_overlap_df.index, 1]
    overlap_fraction = dataset_overlap_df.div(row_peak_counts, axis=0)
    column_scaled = overlap_fraction.div(overlap_fraction.sum(), axis=1) * 1e4
    return column_scaled.div(column_scaled.mean(1), axis=0)

In [None]:
def write_to_mtx(data, path):
    """Export ``data`` to ``path`` as a Matrix Market folder.

    Files written (all TSVs without header or index column):
    ``genes.tsv`` (``data.var`` index), ``barcodes.tsv`` (``data.obs`` index),
    ``metadata.tsv`` (``data.obs`` values) and ``matrix.mtx`` (``data.X``
    transposed). The folder is created when it does not exist.
    """
    if not os.path.exists(path):
        os.makedirs(path)
    tsv_options = dict(sep = "\t", index=False, header=False)
    genes_file = os.path.join(path, "genes.tsv" )
    barcodes_file = os.path.join(path, "barcodes.tsv")
    metadata_file = os.path.join(path, "metadata.tsv")
    pd.DataFrame(data.var.index).to_csv(genes_file, **tsv_options)
    pd.DataFrame(data.obs.index).to_csv(barcodes_file, **tsv_options)
    data.obs.to_csv(metadata_file, **tsv_options)
    io.mmwrite(os.path.join(path, "matrix.mtx"), data.X.T)

In [None]:
def geneInfoSimple(gene_bed):
    """Parse a tab-separated gene annotation file into TSS records.

    The first line is always skipped; blank lines and lines starting with '#'
    are ignored. Per line, column 2 is the chromosome, column 3 the strand,
    columns 4/5 the two coordinates and column 12 the gene name (0-based). The
    TSS is column 4 for '+' strand entries and column 5 otherwise.

    Returns ``(genes_info, genes_list)``:
    ``genes_info`` is a list of ``[chrom_without_'chr', tss, 1, gene_index]``
    and ``genes_list[gene_index]`` is the unique label ``name@chrom@tss``.
    Duplicate records are merged, and entries are numbered in sorted order so
    the numbering is reproducible between runs.
    """
    unique_genes = set()
    # The context manager closes the file even if a line fails to parse.
    with open(gene_bed, 'rt') as fhd:
        fhd.readline()  # skip the first line (header without a leading '#')
        for line in fhd:
            fields = line.strip().split('\t')
            if not fields[0] or fields[0].startswith('#'):
                continue
            tss = fields[4] if fields[3] == "+" else fields[5]
            # record layout: (chrom, tss, 1, gene_unique)
            unique_genes.add((fields[2].replace('chr', ''), int(tss), 1, "%s@%s@%s" % (fields[12], fields[2], tss)))
    genes_info = []
    genes_list = []
    for igene, (chrom, tss, flag, gene_unique) in enumerate(sorted(unique_genes)):
        genes_list.append(gene_unique)
        genes_info.append([chrom, tss, flag, igene])
    return genes_info, genes_list

def RP_Simple(peaks_info, genes_info, decay):
    """Compute the gene-by-peak regulatory potential weights.

    Both inputs are lists of ``[chrom, position, flag, index]`` records with
    ``flag == 1`` for genes and anything else for peaks. A peak contributes
    ``2 ** (-distance / decay)`` to a gene on the same chromosome when their
    distance is at most ``15 * decay``.

    Returns a ``scipy.sparse.dok_matrix`` of shape (n_genes, n_peaks).
    """
    max_distance = 15 * decay
    scores = scipy.sparse.dok_matrix((len(genes_info), len(peaks_info)), dtype=np.float64)

    # Genes currently in range, keyed by gene index -> [chrom, position].
    # The dict is deliberately shared by both sweeps below.
    active_genes = {}

    def sweep(ordered_records):
        # Walk the records in order: genes join the active set, peaks are
        # scored against every active gene and evict the out-of-range ones.
        for record in ordered_records:
            if record[2] == 1:
                active_genes[record[-1]] = [record[0], record[1]]
                continue
            for gene_index in list(active_genes.keys()):
                gene_chrom, gene_pos = active_genes[gene_index]
                distance = abs(record[1] - gene_pos)
                if (gene_chrom != record[0]) or (distance > max_distance):
                    del active_genes[gene_index]
                    continue
                scores[gene_index, record[-1]] = 2 ** (-(distance / decay))

    merged = genes_info + peaks_info
    merged.sort()
    sweep(merged)      # ascending: pairs each peak with genes sorted before it
    merged.reverse()
    sweep(merged)      # descending: pairs each peak with genes sorted after it

    return scores

In [None]:
def enhance(input_mat, impute_n=5, KD_leafsize=80, nPC = 50, path='SCRIPT/enhancement/', binarize=True, n_cores=8):
    '''
    Enhance a cell-by-peak matrix by summing each cell with its neighbours.

    input_mat:
        a csr sparse matrix
    impute_n, KD_leafsize, nPC, n_cores:
        forwarded to cal_neighbor_cell_peak_mat_batch
    path:
        output folder; the result is pickled to <path>/imputed.csr.pk
    binarize:
        when true, every value greater than 1 is set to 1

    Returns the enhanced csr matrix.
    '''
    safe_makedirs(path)
    imputed_csr = cal_neighbor_cell_peak_mat_batch(input_mat, impute_n=impute_n, KD_leafsize=KD_leafsize, nPC = nPC, n_cores=n_cores)
    if binarize:
        imputed_csr[imputed_csr>1] = 1
    # os.path.join keeps the file inside the folder even when ``path`` has no
    # trailing slash (plain concatenation produced '<path>imputed.csr.pk').
    utils.store_to_pickle(imputed_csr, os.path.join(path, 'imputed.csr.pk'))
    return imputed_csr


def impute(input_mat_adata, impute_factor, ref_path, bed_check=True, search_check=True, path='SCRIPT/imputation/', write_mtx=True, ref_baseline=500, remove_others_source=False, n_cores=8):
    '''
    Impute the signal of ``impute_factor`` from a cell-by-peak AnnData.

    Steps:
    1. write one BED file per cell (skipped when ``bed_check`` is true and the
       bed folder already exists);
    2. search the beds against the giggle index at ``ref_path``, restricted to
       ``impute_factor`` (skipped when ``search_check`` is true and the result
       folder already exists);
    3. score every reference dataset per cell and pick the best one per cell;
    4. keep the input peaks that overlap the chosen reference datasets.

    ``ref_baseline`` is accepted but not used in the body.

    Returns ``(chip_cell_peak, factor_score)``: an AnnData with a csr ``X``
    restricted to the retained peaks, and the score table from ``cal_score``.
    '''

    safe_makedirs(path)
    print(input_mat_adata.X.shape)

    beds_dir = path + '/imputed_beds/'
    peaks_number_file = path + '/imputed_beds_peaks_number.txt'
    results_dir = path + '/imputed_results_%s/' % impute_factor

    if bed_check and os.path.exists(beds_dir):
        print_log('Skip generate beds...')
    else:
        print_log('Generating beds...')
        generate_beds_by_matrix(input_mat_adata, beds_dir, peaks_number_file, n_cores)

    if search_check and os.path.exists(results_dir):
        print_log('Skip searching beds...')
    else:
        search_ref_factor_batch(beds_dir, results_dir, ref_path, impute_factor, n_cores)

    print_log('Calculating score...')
    factor_enrich = read_search_result_batch(results_dir, n_cores)

    peaks_length = pd.read_csv(os.path.join(ref_path, 'peaks_number.txt'), sep='\t', header=None, index_col=0)
    peaks_length_factor = peaks_length.loc[[i for i in peaks_length.index if i.startswith(impute_factor)], :].copy()
    factor_score = cal_score(factor_enrich, peaks_length_factor)

    factor_source = get_factor_source(factor_score)
    # os.path.join keeps the pickle inside ``path`` even without a trailing slash.
    store_to_pickle(factor_source, os.path.join(path, '%s_dataset_source.pk' % impute_factor))

    # Union of the reference peak sets chosen for at least one cell.
    chosen_datasets = factor_source.iloc[0,:].unique()
    chip_bed_list = [pybedtools.BedTool(os.path.join(ref_path, 'raw_beds', i + '.bed.gz')) for i in chosen_datasets]
    chip_bed = chip_bed_list[0].cat(*chip_bed_list[1:])
    # Input peaks are named chrom_start_end; turn them back into BED lines.
    data_bed = pybedtools.BedTool('\n'.join(['\t'.join(p.rsplit('_', maxsplit=2)) for p in input_mat_adata.var_names]), from_string=True)
    intersect_bed = data_bed.intersect(chip_bed, u=True)
    imputed_chip_peak = str(intersect_bed).replace('\t','_').split('\n')[0:-1]

    chip_cell_peak = input_mat_adata[:,imputed_chip_peak].copy()
    chip_cell_peak_df = chip_cell_peak.to_df()
    if remove_others_source:
        for i in chosen_datasets:
            cellbc = factor_source.columns[factor_source.iloc[0,:] == i]
            # NOTE(review): this reads <ref_path>/<dataset>.bed.gz while the
            # union above reads <ref_path>/raw_beds/<dataset>.bed.gz -- confirm
            # which location is intended.
            tmp_dataset_bed = pybedtools.BedTool(os.path.join(ref_path, i + '.bed.gz'))
            exclude_chip_peak = str(intersect_bed.intersect(tmp_dataset_bed, v=True)).replace('\t','_').split('\n')[0:-1]
            # Zero out peaks not supported by the dataset chosen for these cells.
            chip_cell_peak_df.loc[cellbc,exclude_chip_peak] = 0
    chip_cell_peak = sc.AnnData(chip_cell_peak_df)
    # Public constructor; the scipy.sparse.csr submodule is a deprecated private namespace.
    chip_cell_peak.X = scipy.sparse.csr_matrix(chip_cell_peak.X)
    print_log('Writing results...')
    if write_mtx:
        write_to_mtx(chip_cell_peak, path + '/imputed_%s_mtx/' % impute_factor)
    print_log('Finished!')
    return chip_cell_peak, factor_score

def count_to_gene_by_RP(input_adata, decay=100000, refgene_path='/fs/home/dongxin/Files/GRCm38_refgenes.txt'):
    """Convert a cell-by-peak AnnData into a cell-by-gene regulatory potential AnnData.

    Peak names in ``input_adata.var`` must look like ``chrom_start_end``; the
    peak midpoint is used as its position. Gene TSS records come from
    ``geneInfoSimple(refgene_path)`` and the gene-by-peak weights from
    ``RP_Simple`` with the given ``decay``.

    When a gene symbol has several TSS records, the record with the largest
    score summed over all cells is kept. Genes are ordered alphabetically.

    Returns an AnnData with cells as observations and gene symbols as variables.
    """
    cells_list = input_adata.obs.index.tolist()
    peaks_list = input_adata.var.index.tolist()

    genes_info, genes_list = geneInfoSimple(refgene_path)

    peaks_info = []
    for ipeak, peak in enumerate(peaks_list):
        chrom, start, end = peak.rsplit("_", maxsplit=2)
        # Strip 'chr' the same way geneInfoSimple does so chromosome keys match.
        peaks_info.append([chrom.replace('chr', ''), (int(start) + int(end)) / 2.0, 0, ipeak])

    genes_peaks_score_csr = RP_Simple(peaks_info, genes_info, decay).tocsr()
    # Use the matrix of the AnnData passed in (the previous code read an
    # unrelated global, ``chip_cell_peak``).
    genes_cells_score_csr = genes_peaks_score_csr.dot(input_adata.X.T)

    # Total score per TSS record, computed once for all rows.
    row_sums = np.asarray(genes_cells_score_csr.sum(axis=1)).ravel()

    # For each gene symbol keep the row index of its best-scoring TSS record.
    best_row = {}
    best_sum = {}
    for igene, gene in enumerate(genes_list):
        symbol = gene.split("@")[0]
        if row_sums[igene] > best_sum.get(symbol, float("-inf")):
            best_row[symbol] = igene
            best_sum[symbol] = row_sums[igene]

    gene_symbol = sorted(best_row.keys())
    matrix_row = [best_row[gene] for gene in gene_symbol]

    score_cells_matrix = genes_cells_score_csr[matrix_row, :]

    # cells_list is already a plain list; calling .tolist() on it raised AttributeError.
    RP_adata = ad.AnnData(score_cells_matrix.T, obs=pd.DataFrame(index=cells_list), var=pd.DataFrame(index=gene_symbol))
    return RP_adata
    

In [None]:
# Load the PBMC scATAC data set from an RDS file with the SCRIP helper.
atac = read_SingleCellExperiment_rds('example/PBMC/data/PBMC_ATAC_500bin/analysis/PBMC_TBMono_500bin.rds')

In [None]:
# Use '_' as the separator in feature names; the functions above split peak
# names on their last two underscores.
atac.var.index = [i.replace('-', '_') for i in atac.var.index]

In [None]:
# Name of the factor to impute.
impute_factor = 'H3K27ac'

In [None]:
# Cast the two QC columns to integers.
atac.obs['nFeature_ATAC'] = atac.obs['nFeature_ATAC'].astype(int)
atac.obs['nCount_ATAC'] = atac.obs['nCount_ATAC'].astype(int)

In [None]:
# Save the converted data set as h5ad next to the original RDS file.
atac.write_h5ad('example/PBMC/data/PBMC_ATAC_500bin/analysis/PBMC_TBMono_500bin.h5ad')

In [None]:
# Reload the data set from the h5ad file written above.
atac = ad.read_h5ad('example/PBMC/data/PBMC_ATAC_500bin/analysis/PBMC_TBMono_500bin.h5ad')

In [None]:
# Impute H3K27ac for the PBMC data against the human giggle index; existing
# bed/search result folders under ``path`` are reused (both checks are on).
chip_cell_peak_H3K27ac, factor_score = impute(atac, 'H3K27ac', '/fs/home/dongxin/Projects/SCRIPT/indices/human/hm_chip_qc_5fold_giggle/', 
                                              bed_check=True, search_check=True, path='example/histone/peak_base/cuttagpro/SCRIPT_PBMC/imputationPBMC1022/', 
                                              write_mtx=True, ref_baseline=500, remove_others_source=False, n_cores=64)

In [None]:
# Drop imputed peaks detected in fewer than 20 cells (modifies the AnnData in place).
sc.pp.filter_genes(chip_cell_peak_H3K27ac, min_cells=20)

In [None]:
# Inspect the 30 most frequent entries in the first row of the pickled H3K27ac
# dataset-source table (presumably the table stored by impute() -- confirm).
read_pickle('/fs/home/dongxin/Projects/SCRIPT/scATAC/example/histone/SCRIPT_1114_remove_others/imputation/H3K27ac_dataset_source.pk').iloc[0,:].value_counts()[0:30]

In [None]:
# Same inspection for the H3K4me3 dataset-source table of the SCRIPT_1114 run.
read_pickle('/fs/home/dongxin/Projects/SCRIPT/scATAC/example/histone/SCRIPT_1114/imputation/H3K4me3_dataset_source.pk').iloc[0,:].value_counts()[0:30]

## Plot

In [None]:
# Bulk gene-score tables for T cells and monocytes: tab-separated, no header,
# lines starting with '#' are skipped.
bulk_t_target = pd.read_csv('/fs/home/dongxin/Projects/SCRIPT/scATAC/example/histone/62350_gene_score_5fold_T.txt', comment = '#', sep='\t', header = None)
bulk_mono_target = pd.read_csv('/fs/home/dongxin/Projects/SCRIPT/scATAC/example/histone/34935_gene_score_5fold_Mono.txt', comment = '#', sep='\t', header = None)

In [None]:
# Unique values of column 6, in file order (presumably gene symbols -- confirm).
bulk_t_target_list = bulk_t_target[6].unique()
bulk_mono_target_list = bulk_mono_target[6].unique()

In [None]:
# Per value of column 6, the maximum of column 4 (presumably the score column -- confirm).
bulk_t_target_rp = bulk_t_target.groupby(6).max()[4]
bulk_mono_target_rp = bulk_mono_target.groupby(6).max()[4]

In [None]:
# Imputed H3K27ac RP matrix, transposed so that rows are the AnnData variables
# and columns are the observations (cells).
impute_RP = sc.read_h5ad('/fs/home/dongxin/Projects/SCRIPT/scATAC/example/histone/SCRIPT_1114_remove_others/imputation/H3K27ac_RP.h5ad').to_df().T

In [None]:
# Keep only the part of each column label before the first '-'.
impute_RP.columns = [i.split('-')[0] for i in impute_RP.columns]

In [None]:
# Barcode key table; later cells use its 'ATAC' and 'RNA' columns to translate barcodes.
keys = pd.read_csv('example/PBMC/barcode_key.txt', sep='\t', index_col=0)

In [None]:
# Index the key table by its 'ATAC' column.
keys.index = keys['ATAC']

In [None]:
# Metadata table of the PBMC analysis, indexed by its first column.
impute_metadata = pd.read_csv('example/PBMC/analysis/metadata.txt', sep='\t', index_col=0)

In [None]:
# RP matrix stored under the histone example folder, transposed like impute_RP
# (the name suggests it was computed from scATAC directly -- confirm).
atac_rp = sc.read_h5ad('/fs/home/dongxin/Projects/SCRIPT/scATAC/example/histone/H3K27ac_RP.h5ad').to_df().T

In [None]:
# Keep only the part of each column label before the first '-'.
atac_rp.columns = [i.split('-')[0] for i in atac_rp.columns]

In [None]:
# real_RP = pd.read_csv('/fs/home/dongxin/Projects/SCRIPT/scATAC/example/histone/peak_base/cuttagpro/SCRIPT_PBMC/real_RP.txt', sep='\t')
# store_to_pickle(real_RP, '/fs/home/dongxin/Projects/SCRIPT/scATAC/example/histone/peak_base/cuttagpro/SCRIPT_PBMC/real_RP.pk')

In [None]:
# Load the pickled RP table (cached from real_RP.txt by the commented-out cell above).
real_RP = read_pickle('/fs/home/dongxin/Projects/SCRIPT/scATAC/example/histone/peak_base/cuttagpro/SCRIPT_PBMC/real_RP.pk')

In [None]:
# Metadata for the real_RP cells; read without index_col, so the index is positional.
real_matadata = pd.read_csv('example/histone/peak_base/cuttagpro/SCRIPT_PBMC/real_meta_data.txt', sep='\t')

In [None]:
# Look up the metadata row of a single barcode.
impute_metadata.loc['CGTACTTCAAGCGAAG']

In [None]:
# List the cell type labels.
# NOTE(review): tmp_meta is only created in the following cell, so this cell
# fails when the notebook is run top to bottom.
tmp_meta['CellType'].unique()

In [None]:
# Load the PBMC cell metadata and re-index it by ATAC barcode.
tmp_meta = pd.read_csv('example/PBMC/pbmc_meta.txt', sep='\t', index_col=0)
# Temporarily index the key table by 'RNA' so each metadata label can be
# translated to its 'ATAC' value ...
keys.index = keys['RNA']
tmp_meta.index = [keys.loc[i,'ATAC'] for i in tmp_meta.index]
# ... then restore the 'ATAC' index used by the other cells.
keys.index = keys['ATAC']

In [None]:
# RP correlation

In [None]:
# Genes present in both the bulk T-cell targets and the imputed T-cell RP.
# Stored as a sorted list because the cells below use it as a pandas indexer,
# and recent pandas versions reject sets in that role.
# NOTE(review): impute_t_rp is only created in the following cell.
bulk_ovlp_target = sorted(set(bulk_t_target_rp.index).intersection(impute_t_rp.index))

In [None]:
# Barcodes of all T-cell subtypes in the scATAC metadata.
impute_t_bc = tmp_meta.index[tmp_meta['CellType'].isin(['naive_CD4_T_cells', 'memory_CD4_T_cells', 'naive_CD8_T_cells', 'effector_CD8_T_cells'])]
# Translate to the 'RNA' column of the key table and keep the barcodes present
# in the imputed RP matrix. A sorted list is used because recent pandas
# versions reject a set as a column indexer.
impute_t_bc = sorted(set(impute_RP.columns).intersection(keys.loc[impute_t_bc,'RNA']))
# Per-gene maximum RP over the selected cells.
impute_t_rp = impute_RP[impute_t_bc].max(1)

atac_t_rp = atac_rp[impute_t_bc].max(1)

# Same selection for the real data set (different cell type labels).
real_t_bc = real_matadata.index[real_matadata['Celltype'].isin(['CD4 T', 'CD8 T', 'other T'])]
# real_t_bc = real_matadata.index[ (real_matadata['Celltype'] == 'CD4 T')]
real_t_rp = real_RP[real_t_bc].max(1)

In [None]:
# Spearman correlation between imputed and real T-cell RP.
# NOTE(review): values are paired by position -- assumes both series list the
# same genes in the same order; confirm.
scipy.stats.spearmanr(impute_t_rp,real_t_rp)

In [None]:
# Spearman correlation between the atac_rp-derived and real T-cell RP (paired by position).
scipy.stats.spearmanr(atac_t_rp,real_t_rp)

In [None]:
# Real T-cell RP vs. bulk T-cell target RP, restricted to the overlapping genes.
scipy.stats.spearmanr(real_t_rp[bulk_ovlp_target],bulk_t_target_rp[bulk_ovlp_target])

In [None]:
# Imputed T-cell RP vs. bulk T-cell target RP, restricted to the overlapping genes.
scipy.stats.spearmanr(impute_t_rp[bulk_ovlp_target],bulk_t_target_rp[bulk_ovlp_target])

In [None]:
# atac_rp-derived T-cell RP vs. bulk T-cell target RP, restricted to the overlapping genes.
scipy.stats.spearmanr(atac_t_rp[bulk_ovlp_target],bulk_t_target_rp[bulk_ovlp_target])

In [None]:
# Barcodes of all monocyte subtypes in the scATAC metadata.
impute_mono_bc = tmp_meta.index[tmp_meta['CellType'].isin(['non-classical_monocytes', 'classical_monocytes', 'intermediate_monocytes'])]
# Translate to the 'RNA' column of the key table and keep the barcodes present
# in the imputed RP matrix. A sorted list is used because recent pandas
# versions reject a set as a column indexer.
impute_mono_bc = sorted(set(impute_RP.columns).intersection(keys.loc[impute_mono_bc,'RNA']))
# Per-gene maximum RP over the selected cells.
impute_mono_rp = impute_RP[impute_mono_bc].max(1)

atac_mono_rp = atac_rp[impute_mono_bc].max(1)

real_mono_bc = real_matadata.index[(real_matadata['Celltype'] == 'Mono')]
real_mono_rp = real_RP[real_mono_bc].max(1)

In [None]:
# Spearman correlation between imputed and real monocyte RP (paired by position).
scipy.stats.spearmanr(impute_mono_rp,real_mono_rp)

In [None]:
# Spearman correlation between the atac_rp-derived and real monocyte RP (paired by position).
scipy.stats.spearmanr(atac_mono_rp,real_mono_rp)

In [None]:
# Real monocyte RP vs. bulk target RP on the overlapping genes.
# NOTE(review): this compares against bulk_t_target_rp (the T-cell bulk table),
# not bulk_mono_target_rp -- looks like a copy-paste slip; confirm.
scipy.stats.spearmanr(real_mono_rp[bulk_ovlp_target],bulk_t_target_rp[bulk_ovlp_target])

In [None]:
# Hard-coded correlation values for the bar plot below. Columns are unnamed:
# 0 = method, 1 = cell type, 2 = correlation (presumably copied from the
# spearmanr outputs above -- confirm).
data_df = pd.DataFrame([['Imputed', 'T', 0.7226870110986734], ['Imputed', 'Mono', 0.6901540022681896], 
                        ['scATAC', 'T', 0.6412991747820043], ['scATAC', 'Mono', 0.6255047796652997], 
                        ['Bulk', 'T', 0.3856955331547345], ['Bulk', 'Mono', 0.3584068064358755]])

In [None]:
# Grouped bar plot of the correlations: one group per cell type (column 1),
# one bar per method (column 0), bar height = correlation (column 2).
fig, ax = plt.subplots(figsize=(6,6))
sns.barplot(x=1, y=2, data=data_df, hue=0, palette='Set3', ax=ax)
ax.set_xlabel('Cell Type')
ax.set_ylabel('RP Correlation with scCUT&Pro')
fig.show()
fig.savefig('Figures/RP_correlation_imputed_atac_bulk.pdf')

In [None]:
# target venn overlap

In [None]:
# Unweighted three-set Venn diagram for the T-cell targets, drawn from
# hard-coded region sizes, with enlarged count labels.
fig, ax = plt.subplots(figsize=(6,6))
out = venn3_unweighted([736, 637, 232, 858, 11, 110, 21], ('scCUT&Pro', 'SCRIP Imputed', 'Bulk'))
for region_label in out.subset_labels:
    # Regions without a label are stored as None.
    if region_label is not None:
        region_label.set_fontsize(20)
fig.show()
fig.savefig('Figures/T_target_venn.pdf')

In [None]:
# Unweighted three-set Venn diagram for the monocyte targets, drawn from
# hard-coded region sizes, with enlarged count labels.
fig, ax = plt.subplots(figsize=(6,6))
out = venn3_unweighted([661, 545, 232, 702, 75, 191, 32], ('scCUT&Pro', 'SCRIP Imputed', 'Bulk'))
for region_label in out.subset_labels:
    # Regions without a label are stored as None.
    if region_label is not None:
        region_label.set_fontsize(20)
fig.show()
fig.savefig('Figures/mono_target_venn.pdf')

In [None]:
# Write the 1000 genes with the highest imputed RP (among the bulk-overlapping
# genes) to one text file per cell type, one gene per line.
with open('/fs/home/dongxin/Projects/SCRIPT/scATAC/example/histone/SCRIPT_1114_remove_others/imputation/H3K27ac_T_rp_gene.txt', 'w+') as f:
    top_t_genes = impute_t_rp[bulk_ovlp_target].sort_values(ascending=False)[0:1000].index.tolist()
    f.writelines(f'{i}\n' for i in top_t_genes)
with open('/fs/home/dongxin/Projects/SCRIPT/scATAC/example/histone/SCRIPT_1114_remove_others/imputation/H3K27ac_mono_rp_gene.txt', 'w+') as f:
    top_mono_genes = impute_mono_rp[bulk_ovlp_target].sort_values(ascending=False)[0:1000].index.tolist()
    f.writelines(f'{i}\n' for i in top_mono_genes)

In [None]:
# Top-1000 monocyte genes by real RP and by imputed RP (both restricted to the
# bulk-overlapping genes), plus the first 1000 entries of the bulk monocyte list.
set1 = set(real_mono_rp[bulk_ovlp_target].sort_values(ascending=False)[0:1000].index)
set2 = set(impute_mono_rp[bulk_ovlp_target].sort_values(ascending=False)[0:1000].index)
set3 = set(bulk_mono_target_list[0:1000])

# Area-weighted Venn diagram of the three gene sets.
venn3([set1, set2, set3], ('scCUT&Pro', 'SCRIP Imputed', 'Bulk'))
plt.show()