In [1]:
import numpy as np
import pandas as pd
import sys
import os
import gc

import pickle
from collections import defaultdict
from tqdm import tqdm
from glob import glob
from sklearn.metrics import roc_auc_score

import scipy.stats

sys.path.append("/home/icb/sergey.vilov/workspace/MLM/utils") 
from misc import model_alias, rna_models

In [2]:
# Root directory for variant tables, model predictions and the final score table.
data_dir = "/lustre/groups/epigenereg01/workspace/projects/vale/mlm/"

In [3]:
def read_fasta(fasta):
    '''
    Read a FASTA file into a mapping {record name: upper-cased sequence}.

    The record name is the header line without the leading '>' and trailing
    whitespace. Multi-line sequences are concatenated, and records sharing a
    header are concatenated into one entry. Blank lines are ignored.

    Parameters
    ----------
    fasta : str or path-like
        Path to the FASTA file.

    Returns
    -------
    collections.defaultdict(str)
        Sequences keyed by record name; missing keys yield ''.

    Raises
    ------
    ValueError
        If sequence data appears before the first '>' header.
    '''
    # Collect per-record line chunks and join once at the end: repeated
    # string += on dict values is quadratic for long sequences.
    chunks = {}
    seq_name = None

    with open(fasta, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            if line.startswith('>'):
                seq_name = line[1:].rstrip()
                chunks.setdefault(seq_name, [])
                continue
            seq = line.rstrip().upper()
            if not seq:
                continue
            if seq_name is None:
                raise ValueError(f'{fasta}: sequence data on line {line_no} before the first ">" header')
            chunks[seq_name].append(seq)

    seqs = defaultdict(str)
    seqs.update((name, ''.join(parts)) for name, parts in chunks.items())
    return seqs

def reverse_complement(seq):
    '''
    Return the reverse complement of a nucleotide sequence.

    A<->T and C<->G are swapped with case preserved; any other character
    (e.g. N, '-') is copied through unchanged.
    '''
    pairs = dict(zip('ACGTacgt', 'TGCAtgca'))
    complemented = [pairs.get(base, base) for base in seq]
    return ''.join(complemented)[::-1]

def get_auc(df, model_name, scores):
    '''
    Direction-agnostic ROC AUC for each available score column of a model.

    For every name in `scores`, the column f'{model_name}-{score_name}' (if
    present in `df`) is used as classifier output against `df.label`. Rows
    where the score is NaN are dropped. The sign of a score is not fixed, so
    max(AUC, 1 - AUC) is reported.

    Returns
    -------
    pd.Series
        Float AUC per column name. The value is NaN when the non-missing rows do
        not contain both label classes, because AUC is undefined there.
    '''
    dataset_scores = {}
    for score_name in scores:
        col = model_name + '-' + score_name
        if col not in df.columns:
            continue
        X = df[col].values
        valid = ~np.isnan(X)
        y, X = df.label.values[valid], X[valid]
        if len(np.unique(y)) < 2:
            # roc_auc_score raises ValueError for a single class; NaN keeps
            # evaluation of small or one-sided subsets from aborting.
            dataset_scores[col] = np.nan
            continue
        score = roc_auc_score(y, X)
        dataset_scores[col] = max(score, 1 - score)
    return pd.Series(dataset_scores, dtype=float)

# Column index of each nucleotide in the per-position probability arrays.
mapping = {base: idx for idx, base in enumerate('ACGT')}

In [4]:
def get_model_probas(glob_path):
    '''
    Load prediction pickles matching `glob_path` and merge them per sequence.

    Each file holds a dict with parallel 'seq_names', 'probs' and 'seqs'
    entries and an optional 'left_shift'. A missing 'left_shift' is filled
    with zeros of shape (n_sequences, 1).

    Returns
    -------
    dict
        {seq_name: {'probs': ..., 'seq': ..., 'left_shift': ...}}. If a name
        occurs in several files, the file that sorts last wins.

    Note: pickle.load can execute arbitrary code; only use prediction files
    produced by a trusted pipeline.
    '''
    res = {}
    # glob() order is filesystem-dependent and later files overwrite earlier
    # ones on duplicate names, so fix the order for reproducible results.
    for probas_file in sorted(glob(glob_path)):
        with open(probas_file, 'rb') as f:
            fold_res = pickle.load(f)
        if 'left_shift' in fold_res:
            left_shift = fold_res['left_shift']
        else:
            left_shift = np.zeros((len(fold_res['seq_names']), 1))
        for seq_name, prob, seq, shift in zip(fold_res['seq_names'], fold_res['probs'],
                                              fold_res['seqs'], left_shift):
            res[seq_name] = {'probs': prob, 'seq': seq, 'left_shift': shift}
    return res

In [5]:
def get_model_scores(model_name, model_preds, variants=None):
    '''
    Compute per-SNP scores for one language model from its saved predictions.

    For each unique var_id (SNPs only), predictions are looked up in `model_preds`:

    * key f'{var_id}_ref' (and optionally f'{var_id}_alt'): the variant position
      is the lower-case base of the ref sequence, or len(seq) // 2 when the
      sequence is entirely upper case;
    * otherwise key `var.seq_name`: the variant position is `var.pos_rel`;
    * otherwise the variant is skipped.

    Only the first 4 columns of 'probs' are used, indexed in `mapping` order
    (A, C, G, T). For models in `rna_models` and '-' strand variants,
    sequences are reverse-complemented and probabilities are reversed along
    positions with the A<->T / C<->G columns swapped.

    Scores (columns prefixed with f'{model_name}-'):

    * 'VIS' when an alt prediction exists: for each position, the largest
      absolute change over the 4 bases in log2-odds between the ref and alt
      predictions, averaged over all positions except the variant;
    * otherwise 'palt_inv' = -log P(alt) and 'pref' = log P(ref), taken at the
      variant position of the ref prediction.

    Parameters
    ----------
    model_name : str
        Used as column prefix and for the `rna_models` membership check.
    model_preds : dict
        Output of `get_model_probas`: {name: {'probs': ..., 'seq': ..., ...}}.
    variants : pd.DataFrame, optional
        Variant table with var_id, vartype, ref, alt, strand, seq_name and
        pos_rel columns. Defaults to the notebook-level `utr_variants`.

    Returns
    -------
    pd.DataFrame
        One row per scored var_id (index), one column per score.
    '''
    epsilon = 1e-14

    if variants is None:
        variants = utr_variants

    res = defaultdict(dict)

    df = variants.drop_duplicates(subset='var_id').set_index('var_id')

    for var_id, var in tqdm(df.iterrows(), total=len(df), bar_format='{bar}|{percentage:3.0f}%'):

        if var.vartype != 'SNP':
            continue

        altseq, altprobs = None, None

        if var_id + '_ref' in model_preds:

            refseq = model_preds[var_id + '_ref']['seq']
            refprobs = model_preds[var_id + '_ref']['probs'][:, :4]

            if var_id + '_alt' in model_preds:
                altseq = model_preds[var_id + '_alt']['seq']
                altprobs = model_preds[var_id + '_alt']['probs'][:, :4]

            if refseq.isupper():
                varpos_rel = len(refseq) // 2
            else:
                # the variant base is marked in lower case
                varpos_rel = [idx for idx, c in enumerate(refseq) if c.islower()][0]
                if model_name in rna_models and var.strand == '-':
                    # make the position refer to the reverse-complemented sequence below
                    varpos_rel = len(refseq) - varpos_rel - 1

        elif var.seq_name in model_preds:

            refseq = model_preds[var.seq_name]['seq']
            refprobs = model_preds[var.seq_name]['probs'][:, :4]

            varpos_rel = var.pos_rel

        else:

            continue

        if model_name in rna_models and var.strand == '-':
            # reverse positions and swap A<->T, C<->G columns
            refseq = reverse_complement(refseq)
            refprobs = refprobs[::-1, [3, 2, 1, 0]]

            if altseq:
                altseq = reverse_complement(altseq)
                altprobs = altprobs[::-1, [3, 2, 1, 0]]

        assert refseq[varpos_rel].upper() == var.ref, f'{var}'
        refprobs = refprobs / refprobs.sum(1, keepdims=True)

        if altseq:

            assert altseq[varpos_rel].upper() == var.alt
            altprobs = altprobs / altprobs.sum(1, keepdims=True)

            ref_logodds = np.log2(refprobs + epsilon) - np.log2(1 - refprobs + epsilon)
            alt_logodds = np.log2(altprobs + epsilon) - np.log2(1 - altprobs + epsilon)
            # per position: largest absolute log-odds change over the 4 bases
            dependency = np.max(np.abs(ref_logodds - alt_logodds), axis=1)

            res[var_id][model_name + '-VIS'] = np.delete(dependency, varpos_rel).mean()

        else:

            refprob = refprobs[varpos_rel][mapping[var.ref]]
            altprob = refprobs[varpos_rel][mapping[var.alt]]

            res[var_id][model_name + '-palt_inv'] = -np.log(altprob + epsilon)
            res[var_id][model_name + '-pref'] = np.log(refprob + epsilon)

    return pd.DataFrame(res.values(), index=res.keys())

In [6]:
# Load the selected SNPs; pos_rel is the variant position relative to seq_start.
utr_variants = (
    pd.read_csv(data_dir + 'variants/selected/variants_snp.tsv', sep='\t')
    .assign(pos_rel=lambda variants: variants.pos - variants.seq_start)
)

# Conservation-based models

In [7]:
# Attach PhyloP scores (one '<model>-score' column per model); score-file
# positions are shifted by -1 to 0-based before the left join.
for model in ('PhyloP-100way', 'PhyloP-241way'):
    score_col = f'{model}-score'
    scores_file = f'{data_dir}variants/prefiltered/PhyloP/{model}.3utr.scores.tsv.gz'
    phylop_res = pd.read_csv(scores_file, sep='\t', header=None, names=['chrom', 'pos', score_col])
    phylop_res['pos'] -= 1  # to 0-based
    utr_variants = utr_variants.merge(phylop_res, how='left')
    del phylop_res

In [8]:
# Add CADD 1.7 scores (raw and PHRED) via a left join on the columns shared
# with utr_variants; score-file positions are shifted to 0-based.
for vartype in ('snps',):
    cadd_file = f'{data_dir}variants/prefiltered/CADD/CADD.3utr.{vartype}.scores.tsv.gz'
    cadd_res = pd.read_csv(cadd_file, sep='\t', header=None,
                           names=['chrom', 'pos', 'ref', 'alt', 'CADD-raw', 'CADD-phred'])
    cadd_res['pos'] -= 1  # to 0-based
    merged = utr_variants.merge(cadd_res, how='left')
    utr_variants = merged.drop_duplicates()
    del cadd_res, merged

In [9]:
gc.collect(2)  # full collection (generation 2, the default) after dropping the score tables

0

# Scores from LM probabilities

In [10]:
# Per-model probability predictions over the human 3'UTR sequences.
probs_path = f'{data_dir}human_3utr/probs/'

In [11]:
# Scores from each model's saved per-nucleotide probabilities (see get_model_scores).
for model_name in ('Zoo-AL', 'DNABERT', 'DNBT-3UTR-RNA', 'NT-MS-v2-100M', 'NT-3UTR-RNA',
                   'STSP-3UTR-RNA', 'STSP-3UTR-RNA-SA', 'STSP-3UTR-DNA', 'STSP-3UTR-RNA-HS',):

    print(model_name)

    predictions_glob = probs_path + '/' + model_alias[model_name] + '/predictions*.pickle'
    model_probas = get_model_probas(predictions_glob)
    print(f'{model_name} loaded, {len(model_probas)} sequences')

    res = get_model_scores(model_name, model_probas)

    # number of utr_variants rows (not unique var_ids) that received scores
    n_var_added = utr_variants.var_id.isin(res.index).sum()
    if n_var_added > 0:
        scored = res.reset_index(names='var_id')
        utr_variants = utr_variants.merge(scored, how='left')

    print(n_var_added, 'variants added')

Zoo-AL
Zoo-AL loaded, 18178 sequences


██████████|100%


70261 variants added
DNABERT
DNABERT loaded, 18134 sequences


██████████|100%


70097 variants added
DNBT-3UTR-RNA
DNBT-3UTR-RNA loaded, 18134 sequences


██████████|100%


70097 variants added
NT-MS-v2-100M
NT-MS-v2-100M loaded, 18178 sequences


██████████|100%


70261 variants added
NT-3UTR-RNA
NT-3UTR-RNA loaded, 18134 sequences


██████████|100%


70097 variants added
STSP-3UTR-RNA
STSP-3UTR-RNA loaded, 18134 sequences


██████████|100%


70097 variants added
STSP-3UTR-RNA-SA
STSP-3UTR-RNA-SA loaded, 18134 sequences


██████████|100%


70097 variants added
STSP-3UTR-DNA
STSP-3UTR-DNA loaded, 18178 sequences


██████████|100%


70261 variants added
STSP-3UTR-RNA-HS
STSP-3UTR-RNA-HS loaded, 18134 sequences


██████████|100%


70097 variants added


In [12]:
#zero_shot_dir = data_dir + 'variants/zero-shot-probs'
#
#
#zero_shot_dir_models = {'NT-3UTR-RNA': zero_shot_dir,
#                 'NT-MS-v2-100M': zero_shot_dir,
#                 'DNABERT': zero_shot_dir,
#                 'DNBT-3UTR-RNA': zero_shot_dir,
#                 'STSP-3UTR-RNA': data_dir + 'variants/embeddings/', 
#                 'STSP-3UTR-RNA-SA': data_dir + 'variants/embeddings/',
#                 'STSP-3UTR-DNA': data_dir + 'variants/embeddings/', 
#                 'STSP-3UTR-RNA-HS': data_dir + 'variants/embeddings/',
#                 'Zoo-AL': data_dir + 'human_3utr/probs/'}

In [13]:
#for model_name in ('Zoo-AL','DNABERT', 'DNBT-3UTR-RNA', 'NT-MS-v2-100M', 'NT-3UTR-RNA',
#                   'STSP-3UTR-RNA','STSP-3UTR-RNA-SA','STSP-3UTR-DNA','STSP-3UTR-RNA-HS',):
#
#    print(model_name)
#    
#    model_probas = get_model_probas(zero_shot_dir_models[model_name] + '/' + model_alias[model_name] + '/predictions*.pickle')
#    model_probas = {k:v for k,v in model_probas.items() if not k.endswith('_alt') }
#
#    print(f'{model_name} loaded, {len(model_probas)} sequences')
#    
#    res = get_model_scores(model_name, model_probas)
#
#    #res.columns = [x + '-zs' for x in res.columns]
#
#    n_var_added = utr_variants.var_id.isin(res.index).sum()
#
#    if n_var_added>0:
#        utr_variants=utr_variants.merge(res.reset_index(names='var_id'),how='left')
#
#    print(n_var_added, 'variants added')

In [14]:
# pratio = pref + palt_inv (about log P(ref) - log P(alt)) for every model that
# has likelihood scores in utr_variants.
for model_name in model_alias.keys():
    palt_inv_col = model_name + '-palt_inv'
    if palt_inv_col not in utr_variants.columns:
        continue
    pref_col = model_name + '-pref'
    utr_variants[model_name + '-pratio'] = utr_variants[pref_col] + utr_variants[palt_inv_col]

In [15]:
#scores=('pref','pratio','pref-zs','pratio-zs','score','raw')
#scores=('pratio','score','raw')
#res = []
#
#for model_name in ('Zoo-AL','DNABERT', 'DNABERT-3UTR', 'NT-MS-v2-100M', 'NTv2-100M-3UTR',
#                   'StateSpace', 'StateSpace-SA'):
#for model_name in ('NT-MS-v2-100M','NTv2-100M-3UTR','NTv2-100M-3UTR*','Zoo-AL','CADD','PhyloP-241way','PhyloP-100way'):
#    
#    res_model = utr_variants.groupby(['split']).apply(lambda x:get_auc(x,model_name,scores)).loc[['clinvar','gnomAD','eQTL-susie','CADD']]
#    res.append(res_model)
#
#pd.concat(res,axis=1)

# Variant influence score

In [16]:
# Predictions on ref/alt sequence pairs used for the variant influence score.
vis_dir = f'{data_dir}/variants/variant_influence_score/'

In [17]:
# Variant influence score (VIS): computed by get_model_scores for variants that
# have both '<var_id>_ref' and '<var_id>_alt' predictions.
for model_name in ('DNABERT', 'DNBT-3UTR-RNA', 'NT-MS-v2-100M', 'NT-3UTR-RNA',
                   'STSP-3UTR-RNA', 'STSP-3UTR-RNA-SA', 'STSP-3UTR-DNA', 'STSP-3UTR-RNA-HS',):

    print(model_name)

    vis_predictions_path = vis_dir + model_alias[model_name] + '/predictions*'
    model_probas = get_model_probas(vis_predictions_path)

    print(f'{model_name} loaded, {len(model_probas)} sequences')

    res = get_model_scores(model_name, model_probas)

    # number of utr_variants rows (not unique var_ids) that received scores
    n_var_added = utr_variants.var_id.isin(res.index).sum()

    if n_var_added > 0:
        utr_variants = utr_variants.merge(res.reset_index(names='var_id'), how='left')

    print(n_var_added, 'variants added')

DNABERT
DNABERT loaded, 30518 sequences


██████████|100%


21285 variants added
DNBT-3UTR-RNA
DNBT-3UTR-RNA loaded, 30518 sequences


██████████|100%


21285 variants added
NT-MS-v2-100M
NT-MS-v2-100M loaded, 30518 sequences


██████████|100%


21285 variants added
NT-3UTR-RNA
NT-3UTR-RNA loaded, 30518 sequences


██████████|100%


21285 variants added
STSP-3UTR-RNA
STSP-3UTR-RNA loaded, 30518 sequences


██████████|100%


21285 variants added
STSP-3UTR-RNA-SA
STSP-3UTR-RNA-SA loaded, 30518 sequences


██████████|100%


21285 variants added
STSP-3UTR-DNA
STSP-3UTR-DNA loaded, 30518 sequences


██████████|100%


21285 variants added
STSP-3UTR-RNA-HS
STSP-3UTR-RNA-HS loaded, 30518 sequences


██████████|100%


21285 variants added


# Scores from embeddings

In [18]:
def compute_embeddings_score(seq_names, embeddings, losses, model_name):
    '''
    Embedding- and loss-based scores for consecutive (ref, alt) sequence pairs.

    Items 2k and 2k+1 of `seq_names`, `embeddings` and `losses` must be the
    reference and alternative version of the same variant. Names have the
    form '<chrom>_<pos>_<ref>_<alt>_ref' and the same name with 'alt' in
    place of 'ref'.

    Parameters
    ----------
    seq_names : sequence of str
    embeddings : sequence of arrays
        One embedding per sequence (assumed 1-D: np.dot is used as the inner product).
    losses : sequence of float
        One loss value per sequence.
    model_name : str
        Prefix for the score columns.

    Returns
    -------
    pd.DataFrame
        One row per pair with columns chrom, pos (int), ref, alt and
        f'{model_name}-' + l1, l2, dot, cosine (between the ref and alt
        embeddings), loss_ref, loss_alt.

    Raises
    ------
    ValueError
        If the inputs cannot be split into matching ref/alt pairs.
    '''
    if len(embeddings) % 2:
        raise ValueError(f'expected (ref, alt) pairs, got an odd number of embeddings ({len(embeddings)})')
    res = []
    for idx in range(0, len(embeddings), 2):
        ref_name, alt_name = seq_names[idx], seq_names[idx + 1]
        if ref_name != alt_name.replace('alt', 'ref'):
            raise ValueError(f'sequences {idx} and {idx + 1} are not a ref/alt pair: {ref_name!r}, {alt_name!r}')
        emb_ref, emb_alt = embeddings[idx], embeddings[idx + 1]
        diff = emb_ref - emb_alt
        l2 = np.linalg.norm(diff)
        l1 = np.linalg.norm(diff, ord=1)
        dot = np.dot(emb_ref, emb_alt)
        cosine = dot / (np.linalg.norm(emb_ref) * np.linalg.norm(emb_alt))
        # Strip only the trailing '_ref' and split from the right, so that
        # chromosome names containing '_' are still parsed correctly.
        var_name = ref_name[:-len('_ref')] if ref_name.endswith('_ref') else ref_name
        chrom, pos, ref, alt = var_name.rsplit('_', 3)
        res.append((chrom, int(pos), ref, alt, l1, l2, dot, cosine, losses[idx], losses[idx + 1]))
    res = pd.DataFrame(res, columns=['chrom', 'pos', 'ref', 'alt',
        f'{model_name}-l1', f'{model_name}-l2', f'{model_name}-dot', f'{model_name}-cosine',
        f'{model_name}-loss_ref', f'{model_name}-loss_alt'])
    return res

In [19]:
emb_dir = data_dir + 'variants/embeddings/'

# Ref/alt embedding distances and losses per model, left-joined to
# utr_variants on the columns the two frames share.
for model_name in ('DNABERT', 'DNBT-3UTR-RNA', 'DNABERT2', 'DNABERT2-ZOO', 'DNBT2-3UTR-RNA', 'NT-MS-v2-100M',
                   'NT-3UTR-RNA', 'STSP-3UTR-RNA', 'STSP-3UTR-RNA-SA', 'STSP-3UTR-DNA', 'STSP-3UTR-RNA-HS',):

    print(model_name)

    # NOTE: pickle.load can execute code from the file; only use trusted pipeline outputs.
    with open(emb_dir + model_alias[model_name] + '/predictions.pickle', 'rb') as f:
        data = pickle.load(f)

    embeddings_scores = compute_embeddings_score(data['seq_names'], data['embeddings'], data['losses'], model_name)

    loss_ref = embeddings_scores[model_name + '-loss_ref']
    loss_alt = embeddings_scores[model_name + '-loss_alt']
    embeddings_scores[model_name + '-loss_diff'] = loss_alt - loss_ref

    utr_variants = utr_variants.merge(embeddings_scores, how='left')

    n_var_added = utr_variants[model_name + '-l1'].notna().sum()

    print(n_var_added, 'variants added')

DNABERT
70261 variants added
DNBT-3UTR-RNA
70261 variants added
DNABERT2
70261 variants added
DNABERT2-ZOO
70261 variants added
DNBT2-3UTR-RNA
70261 variants added
NT-MS-v2-100M
70261 variants added
NT-3UTR-RNA
70261 variants added
STSP-3UTR-RNA
70261 variants added
STSP-3UTR-RNA-SA
70261 variants added
STSP-3UTR-DNA
70261 variants added
STSP-3UTR-RNA-HS
70261 variants added


# Scores from supervised learning

In [20]:
# Supervised predictions: one TSV per (evaluation subset, model); the 'y_pred'
# column becomes '<model>-<classifier>'. Missing files are reported by path.
for classifier in ('MLP',):  # ('LogisticRegression', 'MLP')

    pred_dir = data_dir + f'variants/predictions/merge_embeddings_1/{classifier}/'

    for model_name in ('DNABERT', 'DNBT-3UTR-RNA', 'DNABERT2', 'DNABERT2-ZOO', 'DNBT2-3UTR-RNA', 'NT-MS-v2-100M',
                       'NT-3UTR-RNA', 'STSP-3UTR-RNA', 'STSP-3UTR-RNA-SA', 'STSP-3UTR-DNA', 'STSP-3UTR-RNA-HS',):

        print(model_name)

        pred_files = [pred_dir + subset + '-' + model_alias[model_name] + '.tsv'
                      for subset in ('CADD', 'clinvar', 'gnomAD', 'eQTL-susie',)]

        model_res = []
        for pred_res in pred_files:
            if not os.path.isfile(pred_res):
                print(pred_res)
                continue
            model_res.append(pd.read_csv(pred_res, sep='\t'))

        if not model_res:
            continue

        score_col = model_name + '-' + classifier
        model_res = pd.concat(model_res).rename(columns={'y_pred': score_col})
        utr_variants = utr_variants.merge(model_res, how='left')

        n_var_added = utr_variants[score_col].notna().sum()
        print(n_var_added, 'variants added')

DNABERT
70261 variants added
DNBT-3UTR-RNA
70261 variants added
DNABERT2
70261 variants added
DNABERT2-ZOO
70261 variants added
DNBT2-3UTR-RNA
70261 variants added
NT-MS-v2-100M
70261 variants added
NT-3UTR-RNA
70261 variants added
STSP-3UTR-RNA
70261 variants added
STSP-3UTR-RNA-SA
70261 variants added
STSP-3UTR-DNA
70261 variants added
STSP-3UTR-RNA-HS
70261 variants added


In [21]:
# Store infinite scores as missing values (NaN).
utr_variants = utr_variants.replace([np.inf, -np.inf], np.nan)

In [22]:
# Write all scores without the row index; the '.gz' suffix makes pandas gzip the output.
utr_variants.to_csv(f'{data_dir}all_scores/variant_scores.tsv.gz', index=False, sep='\t')