In [18]:
import numpy as np
import pandas as pd
import sys
import os
import gc

import pickle
from collections import defaultdict
from tqdm import tqdm
from glob import glob
from sklearn.metrics import roc_auc_score
from joblib import Parallel, delayed
from sklearn.utils import resample

import scipy.stats

sys.path.append("/home/icb/sergey.vilov/workspace/MLM/utils") 
from misc import model_alias, rna_models

In [19]:
# Root of all analysis inputs/outputs. Keep the trailing slash: downstream paths
# are built by plain string concatenation (data_dir + 'sub/dir/file').
data_dir = (
    '/lustre/groups/epigenereg01/workspace/'
    'projects/vale/mlm/'
)

In [None]:
import gzip
import re

gtf_gz = '/ictstr01/groups/epigenereg01/projects/predict-gxe-eqtl/results/scooby-scPower/scooby_training_data/hg38/gencode.v32.annotation.gtf.gz'

# Collect CDS intervals as 0-based, half-open [start, stop) so they match the
# `start <= pos < stop` membership test used by is_included.
# NOTE(review): assumes variant `pos` values are 0-based too — confirm.
exons = []

with gzip.open(gtf_gz, 'rt', encoding='utf-8') as f:
    for line in f:
        if line.startswith('#'):
            continue
        chrom, _, fragment, start, end, _, strand, _, info = line.split('\t')
        if fragment == 'CDS':
            # GTF is 1-based and closed [start, end] -> 0-based half-open [start-1, end).
            # (Using end-1 here would drop the last base of every CDS under a half-open test.)
            exons.append((chrom, int(start) - 1, int(end)))

exons_df = pd.DataFrame(exons, columns=['chrom', 'start', 'stop'])
# Per-chromosome lists of [start, stop] pairs, sorted by start.
exons_df = (
    exons_df.sort_values(by=['chrom', 'start'])
    .groupby('chrom')
    .apply(lambda x: x.values.tolist(), include_groups=False)
)

In [None]:
def get_auc(df, model_name, scores):
    '''
    Direction-agnostic ROC AUC of each available "<model_name>-<score_name>" column.

    df: DataFrame with a binary `label` column and score columns.
    model_name: prefix of the score columns.
    scores: iterable of score-name suffixes; suffixes without a matching column are skipped.

    Rows where either the score or the label is NaN are ignored. Each AUC is reported
    as max(auc, 1 - auc), so a score that is anti-correlated with the label counts as
    informative. If fewer than two classes remain after filtering, the AUC is undefined
    and NaN is reported instead of letting roc_auc_score raise (e.g. a distance bin
    that contains only positives).

    Returns a Series indexed by "<model_name>-<score_name>".
    '''
    dataset_scores = {}
    labels = df.label.values
    for score_name in scores:
        col = model_name + '-' + score_name
        if col not in df.columns:
            continue
        values = df[col].values
        keep = ~(np.isnan(values) | np.isnan(labels))
        y, X = labels[keep], values[keep]
        if len(np.unique(y)) < 2:
            dataset_scores[col] = np.nan
            continue
        auc = roc_auc_score(y, X)
        dataset_scores[col] = max(auc, 1 - auc)
    return pd.Series(dataset_scores)

def is_included(chrom_pos, intervals_df):
    '''
    Flag which positions on one chromosome fall inside any interval.

    chrom_pos: Series of positions whose .name is the chromosome (as produced by
        `variants_df.groupby('chrom').pos.apply(...)`).
    intervals_df: Series indexed by chromosome whose values are lists of
        [start, stop] pairs, interpreted as half-open [start, stop).

    Returns a Series of 1/0 aligned to chrom_pos.index. Intervals may be unsorted
    or overlapping; a chromosome absent from intervals_df yields all zeros.
    '''
    from bisect import bisect_right
    from itertools import accumulate

    intervals = sorted(intervals_df.get(chrom_pos.name, []))
    starts = [start for start, _ in intervals]
    # max_stops[i] is the furthest stop reached by intervals[0..i]; a position is
    # covered iff some interval starting at or before it reaches past it.
    max_stops = list(accumulate((stop for _, stop in intervals), max))

    res = []
    for pos in chrom_pos.values:
        n_before = bisect_right(starts, pos)  # number of intervals with start <= pos
        res.append(1 if n_before and max_stops[n_before - 1] > pos else 0)
    return pd.Series(res, index=chrom_pos.index)

def read_fasta(fasta):
    '''
    Read a (multi-record) FASTA file.

    Returns a defaultdict(str) mapping each header (text after '>', trailing whitespace
    stripped) to its sequence, with that record's lines concatenated in file order.
    Records keep the order in which they first receive sequence data. A repeated header
    extends the earlier record. Blank lines are ignored.

    Raises ValueError if sequence data appears before the first header.
    '''
    # Collect line chunks and join once at the end; `seqs[name] += line` on a dict
    # entry re-copies the growing string on every line (quadratic).
    chunks = defaultdict(list)
    seq_name = None

    with open(fasta, 'r') as f:
        for line in f:
            line = line.rstrip()
            if not line:
                continue
            if line.startswith('>'):
                seq_name = line[1:]
            elif seq_name is None:
                raise ValueError(f'{fasta}: sequence data before the first FASTA header')
            else:
                chunks[seq_name].append(line)

    seqs = defaultdict(str)
    for name, parts in chunks.items():
        seqs[name] = ''.join(parts)
    return seqs

# Watson-Crick complements, case-preserving (lower case marks soft-masked bases).
rc_dict = {
    'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A',
    'a': 't', 'c': 'g', 'g': 'c', 't': 'a',
}

def reverse_complement(seq):
    '''
    Return the reverse complement of `seq`.

    Characters without an entry in rc_dict (e.g. 'N', alignment gaps '-') are copied unchanged.
    '''
    return ''.join(rc_dict.get(base, base) for base in reversed(seq))

In [20]:
# Annotated variant table, indexed by variant id.
variants_df = pd.read_csv(
    data_dir + 'all_scores/nomargin/variant_scores_ann.tsv.gz',
    sep='\t',
    index_col='var_id',
)

In [None]:
# Length (nt) of the window starting at each motif `pos`.
motif_len = 5

eclip_df = pd.read_csv(data_dir + 'motif_analysis/eclip.tsv', sep='\t')
# Keep rows flagged eCLIP_RBNS; each motif becomes a half-open [start, start + motif_len).
# .assign avoids writing into a filtered view of eclip_df.
motifs_df = eclip_df[eclip_df.eCLIP_RBNS == True].rename(columns={'pos': 'start'})
motifs_df = motifs_df.assign(stop=motifs_df.start + motif_len)
# Sort by start so each chromosome's interval list is ordered, matching exons_df;
# the interval lookup assumes per-chromosome intervals sorted by start.
motifs_df = (
    motifs_df.sort_values(by=['chrom', 'start'])
    .groupby('chrom')[['start', 'stop']]
    .apply(lambda x: x.values.tolist(), include_groups=False)
)

In [None]:
# 1/0 flag per variant: does it fall inside an RBP motif interval?
# (SeriesGroupBy.apply forwards the keyword argument to is_included.)
motifs_isec = (
    variants_df.groupby('chrom')['pos']
    .apply(is_included, intervals_df=motifs_df)
    .rename('in_rbp_motif')
)

In [None]:
#cds_isec = variants_df.groupby('chrom').pos.apply(lambda x:is_included(x,exons_df)).rename('is_coding')

In [None]:
# Fraction of variants lying inside an annotated CDS interval.
# cds_isec is computed here so this cell runs on a fresh kernel
# (it was otherwise never defined).
cds_isec = variants_df.groupby('chrom').pos.apply(lambda x: is_included(x, exons_df)).rename('is_coding')
cds_isec.mean()

In [None]:
#variants_df=variants_df.merge(cds_isec.reset_index())

In [None]:
# All 3'UTR FASTA alignments are distributed across subfolders:
# map each alignment's sequence name to the subfolder (relative to fasta_path) holding it.

fasta_path = data_dir + 'fasta/aligned_3UTR/fa/'

fasta_dirs = []

# sorted() makes the mapping independent of filesystem listing order
for file in sorted(glob(fasta_path + '**/*.fa', recursive=True)):
    relative_path = os.path.relpath(file, fasta_path)
    # os.path.split works at any depth (folder may be '' or nested), unlike
    # unpacking str.split('/') into exactly two parts
    folder, file_name = os.path.split(relative_path)
    # strip only the trailing extension; str.replace would also remove an inner '.fa'
    seq_id = os.path.splitext(file_name)[0]
    fasta_dirs.append((folder, seq_id))

# Select the column explicitly: .squeeze() would collapse a one-row frame to a scalar.
fasta_dirs = pd.DataFrame(fasta_dirs, columns=['folder', 'seq_name']).set_index('seq_name')['folder']
# .loc[seq_name] later must return a single folder string
assert fasta_dirs.index.is_unique, 'same sequence name found in several subfolders'

fasta_dirs.head()

In [None]:
def jaccard(l1, l2):
    '''
    Jaccard similarity |A & B| / |A | B| of the distinct elements of two iterables.

    Raises ZeroDivisionError when both inputs are empty.
    '''
    first, second = set(l1), set(l2)
    shared = first & second
    combined = first | second
    return len(shared) / len(combined)

In [None]:
# For every variant, count the A/C/G/T observed at its position across all sequences
# of its 3'UTR alignment, and flag whether the first sequence (the one checked against
# the ref allele) is lower-case (soft-masked) there. '-' strand variants are looked up
# on the reverse complement of the alignment.
nucl_counts_df = []

# groupby(sort=False) visits sequences in first-appearance order (like .unique())
# without re-scanning the whole table for every sequence.
for seq_name, seq_variants in tqdm(variants_df.groupby('seq_name', sort=False)):

    file_path = fasta_path + fasta_dirs.loc[seq_name] + '/' + seq_name + '.fa'

    fasta_seqs = list(read_fasta(file_path).values())
    fasta_seqs_rc = [reverse_complement(seq) for seq in fasta_seqs]

    # sets, not lists: membership is tested once per variant
    repeat_idx = {idx for idx, nt in enumerate(fasta_seqs[0]) if nt.islower()}
    repeat_idx_rc = {idx for idx, nt in enumerate(fasta_seqs_rc[0]) if nt.islower()}

    fasta_seqs = [seq.upper() for seq in fasta_seqs]
    fasta_seqs_rc = [seq.upper() for seq in fasta_seqs_rc]

    for var_id, var in seq_variants.iterrows():
        if var.strand == '+':
            seqs, repeats = fasta_seqs, repeat_idx
        else:
            seqs, repeats = fasta_seqs_rc, repeat_idx_rc
        assert seqs[0][var.pos_rel] == var.ref, f'{var_id}: ref allele does not match the alignment'
        in_repeat = var.pos_rel in repeats
        column = [seq[var.pos_rel] for seq in seqs]
        counts = [column.count(nt) for nt in ('A', 'C', 'G', 'T')]
        # Disabled context-similarity feature (presumably the source of the
        # 'jaccard_context' column used in later cells):
        #ctx_ref = [seq[var.pos_rel-2:var.pos_rel]+seq[var.pos_rel+1:var.pos_rel+3] for seq in seqs if seq[var.pos_rel]==var.ref]
        #ctx_alt = [seq[var.pos_rel-2:var.pos_rel]+seq[var.pos_rel+1:var.pos_rel+3] for seq in seqs if seq[var.pos_rel]==var.alt]
        #jaccard_context = jaccard(ctx_ref,ctx_alt)
        nucl_counts_df.append((var_id, in_repeat, *counts))

nucl_counts_df = pd.DataFrame(nucl_counts_df, columns=['var_id', 'in_repeat', 'A', 'C', 'G', 'T']).set_index('var_id')

In [None]:
# Attach the per-variant alignment counts. Any same-named columns already in the loaded
# table are dropped first, so the counts replace them instead of producing duplicate
# columns. This also makes the cell idempotent when re-run.
variants_df = pd.concat(
    [variants_df.drop(columns=nucl_counts_df.columns, errors='ignore'), nucl_counts_df],
    axis=1,
)

In [None]:
#variants_df.reset_index().to_csv(data_dir + 'all_scores/nomargin/variant_scores_ann.tsv.gz',sep='\t',index=None)

In [21]:
# Alignment depth at the variant: sum of the per-nucleotide A/C/G/T counts.
variants_df['depth'] = variants_df.loc[:, list('ACGT')].sum(axis=1)

In [22]:
models = (
    'NTv2-100M-3UTR',
    'CADD',
    'PhyloP-100way',
    'PhyloP-241way',
    'DNABERT-3UTR',
    'StateSpace',
)

# Pseudo-count that keeps log() finite when a probability is exactly 0.
epsilon = 1e-14
for model_name in models:
    # only models that provide a '-palt' column get a log(pref / palt) ratio
    if model_name + '-palt' not in variants_df.columns:
        continue
    variants_df[model_name + '-pratio'] = (
        np.log(variants_df[model_name + '-pref'] + epsilon)
        - np.log(variants_df[model_name + '-palt'] + epsilon)
    )

In [23]:
def allele_fraction(df, allele_col):
    '''
    Fraction of aligned sequences carrying the allele named in df[allele_col].

    For each row, take the count from the A/C/G/T column named by the allele and
    divide it by df.depth. This is a vectorised replacement for a row-wise apply.
    Alleles other than A/C/G/T give NaN instead of raising KeyError.
    '''
    allele_counts = pd.Series(np.nan, index=df.index)
    for nt in ('A', 'C', 'G', 'T'):
        allele_counts = allele_counts.mask(df[allele_col] == nt, df[nt])
    return allele_counts / df.depth

variants_df['zoo-pref'] = allele_fraction(variants_df, 'ref')
variants_df['zoo-palt'] = allele_fraction(variants_df, 'alt')

In [24]:
# Distance of each variant to the 3' and 5' ends of its sequence, taking strand into account.
on_plus = variants_df.strand == '+'
on_minus = variants_df.strand == '-'
to_end = variants_df.seq_end - variants_df.pos
from_start = variants_df.pos - variants_df.seq_start
variants_df['3prime_dist'] = np.where(on_plus, to_end, from_start)
variants_df['5prime_dist'] = np.where(on_minus, to_end, from_start)
del on_plus, on_minus, to_end, from_start

# Right-closed bins (lo, hi]; distances outside (0, 4096] are left without a bin (NaN).
bins = (0, 128, 256, 512, 1024, 2048, 4096)
n_bins = len(bins)  # number of bin edges (there are n_bins - 1 bins)

for dist_col, bin_col in (('3prime_dist', '3_prime_bin'), ('5prime_dist', '5_prime_bin')):
    variants_df[bin_col] = pd.cut(variants_df[dist_col], bins=bins)
del dist_col, bin_col

#variants_df['3_prime_bin'] = pd.qcut(variants_df['3prime_dist'],q=10,)
#variants_df['5_prime_bin'] = pd.qcut(variants_df['5prime_dist'],q=10,)


In [25]:
# CADD-sourced variants. `in_repeat > -1` only drops rows whose repeat flag is missing
# (the flag itself is presumably 0/1) — tighten it to exclude repeats if desired.
df = variants_df.loc[(variants_df['source'] == 'CADD') & (variants_df['in_repeat'] > -1)]

df.groupby(['label', '5_prime_bin'])[
    ['NTv2-100M-3UTR-' + stat for stat in ('MLP', 'pref', 'palt', 'pratio')]
    + ['zoo-pref', 'zoo-palt', 'depth', 'in_repeat', 'jaccard_context']
].mean()

  df.groupby(['label','5_prime_bin'])[['NTv2-100M-3UTR-MLP','NTv2-100M-3UTR-pref','NTv2-100M-3UTR-palt','NTv2-100M-3UTR-pratio','zoo-pref','zoo-palt','depth','in_repeat','jaccard_context']].mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,NTv2-100M-3UTR-MLP,NTv2-100M-3UTR-pref,NTv2-100M-3UTR-palt,NTv2-100M-3UTR-pratio,zoo-pref,zoo-palt,depth,in_repeat,jaccard_context
label,5_prime_bin,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0.0,"(0, 128]",0.461907,0.21478,0.464979,-1.148481,0.192163,0.67828,183.844802,0.136164,0.21329
0.0,"(128, 256]",0.460718,0.220975,0.414874,-0.901559,0.171829,0.696496,183.658333,0.185,0.223564
0.0,"(256, 512]",0.466792,0.227429,0.384431,-0.705984,0.168716,0.704687,174.955939,0.249042,0.225861
0.0,"(512, 1024]",0.474245,0.247272,0.366408,-0.506805,0.175418,0.714067,163.960317,0.327228,0.233309
0.0,"(1024, 2048]",0.460566,0.241399,0.345299,-0.453372,0.180329,0.714626,157.993475,0.357547,0.243538
0.0,"(2048, 4096]",0.475585,0.245542,0.345736,-0.430157,0.180028,0.714651,151.096625,0.389274,0.251727
1.0,"(0, 128]",0.535241,0.565298,0.159417,2.07606,0.782533,0.108534,188.92246,0.141711,0.144356
1.0,"(128, 256]",0.543894,0.461247,0.183315,1.43592,0.79774,0.098439,186.591483,0.228707,0.131683
1.0,"(256, 512]",0.54942,0.442443,0.195015,1.256255,0.782444,0.107099,186.910714,0.240179,0.137187
1.0,"(512, 1024]",0.540306,0.422737,0.200867,1.067655,0.793896,0.105521,175.403392,0.294973,0.137192


In [26]:
# Per-label, per-5'-bin means of DNABERT-3UTR scores alongside alignment features.
df.groupby(['label', '5_prime_bin'])[
    ['DNABERT-3UTR-' + stat for stat in ('MLP', 'pref', 'palt', 'pratio')]
    + ['zoo-pref', 'zoo-palt', 'depth', 'in_repeat', 'jaccard_context']
].mean()

  df.groupby(['label','5_prime_bin'])[['DNABERT-3UTR-MLP','DNABERT-3UTR-pref','DNABERT-3UTR-palt','DNABERT-3UTR-pratio','zoo-pref','zoo-palt','depth','in_repeat','jaccard_context']].mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,DNABERT-3UTR-MLP,DNABERT-3UTR-pref,DNABERT-3UTR-palt,DNABERT-3UTR-pratio,zoo-pref,zoo-palt,depth,in_repeat,jaccard_context
label,5_prime_bin,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0.0,"(0, 128]",0.435784,0.245913,0.280434,-0.095883,0.192163,0.67828,183.844802,0.136164,0.21329
0.0,"(128, 256]",0.44145,0.239693,0.305125,-0.234138,0.171829,0.696496,183.658333,0.185,0.223564
0.0,"(256, 512]",0.457007,0.246445,0.30846,-0.244379,0.168716,0.704687,174.955939,0.249042,0.225861
0.0,"(512, 1024]",0.453746,0.260309,0.317279,-0.224516,0.175418,0.714067,163.960317,0.327228,0.233309
0.0,"(1024, 2048]",0.44164,0.255732,0.312153,-0.245349,0.180329,0.714626,157.993475,0.357547,0.243538
0.0,"(2048, 4096]",0.447889,0.262704,0.321939,-0.239142,0.180028,0.714651,151.096625,0.389274,0.251727
1.0,"(0, 128]",0.577424,0.311677,0.228531,0.35303,0.782533,0.108534,188.92246,0.141711,0.144356
1.0,"(128, 256]",0.56926,0.322474,0.218098,0.452572,0.79774,0.098439,186.591483,0.228707,0.131683
1.0,"(256, 512]",0.565261,0.318424,0.233085,0.364616,0.782444,0.107099,186.910714,0.240179,0.137187
1.0,"(512, 1024]",0.559131,0.337275,0.227111,0.501234,0.793896,0.105521,175.403392,0.294973,0.137192


In [27]:
# Per-label, per-5'-bin means of StateSpace scores alongside alignment features.
df.groupby(['label', '5_prime_bin'])[
    ['StateSpace-' + stat for stat in ('MLP', 'pref', 'palt', 'pratio')]
    + ['zoo-pref', 'zoo-palt', 'depth', 'in_repeat', 'jaccard_context']
].mean()

  df.groupby(['label','5_prime_bin'])[['StateSpace-MLP','StateSpace-pref','StateSpace-palt','StateSpace-pratio','zoo-pref','zoo-palt','depth','in_repeat','jaccard_context']].mean()


Unnamed: 0_level_0,Unnamed: 1_level_0,StateSpace-MLP,StateSpace-pref,StateSpace-palt,StateSpace-pratio,zoo-pref,zoo-palt,depth,in_repeat,jaccard_context
label,5_prime_bin,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0.0,"(0, 128]",0.438371,0.246148,0.276528,-0.086293,0.192163,0.67828,183.844802,0.136164,0.21329
0.0,"(128, 256]",0.438561,0.241121,0.298413,-0.22724,0.171829,0.696496,183.658333,0.185,0.223564
0.0,"(256, 512]",0.452759,0.242226,0.307809,-0.265209,0.168716,0.704687,174.955939,0.249042,0.225861
0.0,"(512, 1024]",0.452295,0.255085,0.311451,-0.23027,0.175418,0.714067,163.960317,0.327228,0.233309
0.0,"(1024, 2048]",0.451086,0.250468,0.305111,-0.224261,0.180329,0.714626,157.993475,0.357547,0.243538
0.0,"(2048, 4096]",0.440415,0.255882,0.313863,-0.22541,0.180028,0.714651,151.096625,0.389274,0.251727
1.0,"(0, 128]",0.554616,0.301799,0.231959,0.301503,0.782533,0.108534,188.92246,0.141711,0.144356
1.0,"(128, 256]",0.555204,0.316013,0.21423,0.448742,0.79774,0.098439,186.591483,0.228707,0.131683
1.0,"(256, 512]",0.562768,0.313306,0.234183,0.341536,0.782444,0.107099,186.910714,0.240179,0.137187
1.0,"(512, 1024]",0.559169,0.334989,0.222216,0.505854,0.793896,0.105521,175.403392,0.294973,0.137192


In [None]:
# gnomAD variants with a definite label (1 = positive, 0 = negative), positives first.
# (The earlier `df = variants_df[split=='gnomAD'].copy()` was immediately overwritten
# below and has been removed.)
pos_df = variants_df[(variants_df.split == 'gnomAD') & (variants_df.label == 1)]
neg_df = variants_df[(variants_df.split == 'gnomAD') & (variants_df.label == 0)]
df = pd.concat([pos_df, neg_df])

# Alternative kept for reference — test whether coding variants affect the score
# (requires an `is_coding` column):
#df = variants_df[(variants_df.split=='CADD')&(variants_df.is_coding==0)].copy()




In [None]:
scores = ('pref', 'pratio', 'raw', 'score')
res = []

for model_name in models:
    # one row per 5'-distance bin, one column per available "<model>-<score>" AUC
    # (whole-dataset alternative: get_auc(df, model_name, scores))
    res_model = df.groupby(['5_prime_bin'], observed=False).apply(
        lambda grp: get_auc(grp, model_name, scores),
        include_groups=False,
    )
    res.append(res_model)

pd.concat(res, axis=1)