In [1]:
import numpy as np
import pandas as pd
import scipy
import sklearn.metrics
import os

import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

from scipy import stats
import statsmodels.stats.multicomp as mc

from statannotations.Annotator import Annotator

In [2]:
# Default font sizes applied to every plot drawn later in this notebook.
mpl.rcParams.update({
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'axes.titlesize': 14,
    'axes.labelsize': 16,
})

# Auxiliary functions

In [3]:
def compute_score(df):
    """Pearson correlation between observed and predicted values.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide the columns ``y_true`` and ``y_pred``.

    Returns
    -------
    tuple of (float, float)
        Pearson r and the half-width of its confidence interval, as returned by
        ``PearsonRResult.confidence_interval()`` at its default (95%) level.
    """
    result = scipy.stats.pearsonr(df.y_true, df.y_pred)
    interval = result.confidence_interval()
    # The interval is not symmetric around r; its half-width is only a
    # symmetric summary of the spread, used for the "r ± err" display.
    half_width = (interval.high - interval.low) / 2
    return (result.statistic, half_width)
    
def get_best_models(df):
    """Return the best-scoring model and all models whose intervals overlap it.

    Parameters
    ----------
    df : pandas.Series
        Indexed by model name; each value is a ``(score, err)`` tuple as produced by
        ``compute_score`` (Pearson r and the half-width of its confidence interval).

    Returns
    -------
    list
        Model names, in the order of ``df``, whose interval ``[score - err, score + err]``
        strictly overlaps the interval of the highest-scoring model. This includes the best
        model itself whenever its ``err`` is positive. Models with a non-finite score or
        error are never selected. An empty list is returned when no model has finite values.
    """

    def is_overlap(a, b):
        # Strict overlap of two intervals (lo, hi): merely touching endpoints do not count.
        return min(a[1], b[1]) > max(a[0], b[0])

    # NaN entries (e.g. when pearsonr is undefined for a group) would corrupt the tuple
    # ordering used to pick the best model, and NaN bounds are silently dropped by
    # Python's min/max, which would make such a model look "overlapping".
    finite = {model: (score, err) for model, (score, err) in df.items()
              if np.isfinite(score) and np.isfinite(err)}
    if not finite:
        return []

    # Tuples compare lexicographically: highest score wins, ties go to the larger err.
    best_score, best_err = max(finite.values())
    best_interval = (best_score - best_err, best_score + best_err)

    return [model for model, (score, err) in finite.items()
            if is_overlap(best_interval, (score - err, score + err))]

def highlight_ns(x, best_models):
    """Bold the best model and every model not significantly different from it.

    Meant for ``DataFrame.style.apply`` with the default ``axis=0``: ``x`` is one
    column of the results table, and its label ``x.name`` is the key into ``best_models``.

    Parameters
    ----------
    x : pandas.Series
        One column of the results table, indexed by model name.
    best_models : dict
        Maps each column label to the list of models to highlight.

    Returns
    -------
    list of str
        One CSS string per row: ``'font-weight: bold'`` or ``''``.
    """
    column_key = x.name
    styles = []
    for model in x.index:
        is_best = model in best_models[column_key]
        styles.append('font-weight: bold' if is_best else '')
    return styles

# Collect predictions

In [28]:
# Run configuration. `regressor` is either 'Ridge' or 'SVR'; `onlyref` selects the
# onlyref_<value> subdirectory of the prediction tree.
regressor, onlyref = 'Ridge', 1

In [29]:
# Root of the prediction tree. Set the SIEGEL_2022_PREDICTIONS environment variable to run
# outside the original cluster; the default is the original absolute path.
PREDICTIONS_ROOT = os.environ.get(
    'SIEGEL_2022_PREDICTIONS',
    '/lustre/groups/epigenereg01/workspace/projects/vale/MLM/siegel_2022/predictions',
)
# The trailing '' keeps a final separator: file names are appended to data_dir by plain
# string concatenation when the predictions are loaded.
data_dir = os.path.join(PREDICTIONS_ROOT, f'onlyref_{onlyref}', regressor, '')

In [30]:
# Models to compare, in the order used for tables and plots.
# An earlier, longer list also included '4mers' and 'word2vec'.
models = [
    'DNABERT',
    'DNABERT-2',
    'Species-agnostic',
    'Species-aware',
    'NT-MS-v2-500M',
    '5mers',
    'effective_length',
]

In [31]:
# Load per-sequence predictions for every (response, cell type, model) combination.
# res[(response, cell_type)] -> one long DataFrame with a `model` column identifying the source file.
# Combinations with no prediction file at all are reported and left out of `res`.

# 0-based column positions read from each TSV and the names assigned to them.
PRED_USECOLS = [2, 7, 8, 37, 38, 39]
PRED_NAMES = ['ids', 'iscontrol', 'parent_control_oligo', 'y_true', 'chrom', 'y_pred']

res = {}

for response in ('stability', 'steady_state'):
    for cell_type in ('Jurkat', 'Beas2B'):
        frames = []
        for model in models:
            res_tsv = data_dir + f'{cell_type}-{response}-{model}.tsv'
            if not os.path.isfile(res_tsv):
                continue
            # skiprows=1 skips the first line (presumably the file's own header); columns are
            # renamed through `names`.
            df = pd.read_csv(res_tsv, sep='\t', skiprows=1, usecols=PRED_USECOLS, names=PRED_NAMES)
            df['model'] = model
            frames.append(df)
        if not frames:
            print(response, cell_type, 'no prediction files found')
            continue
        # Each file has its own 0..n-1 index; ignore_index avoids duplicate row labels.
        res[(response, cell_type)] = pd.concat(frames, ignore_index=True)
        # Mean number of rows per model for this combination.
        n_per_model = res[(response, cell_type)].groupby('model').size().mean()
        print(response, cell_type, int(n_per_model))

stability Jurkat 4229
stability Beas2B 1110
steady_state Jurkat 4616
steady_state Beas2B 2418


# Visualize per-fold scores

In [None]:
# Pearson r per (model, chromosome) for Jurkat stability; chromosomes presumably act as the
# CV folds — TODO confirm. compute_score returns (r, ci_half_width); only r is kept so the
# `score` column is numeric and can be plotted and tested. Selecting the two value columns
# also keeps the grouping columns out of the applied function.
per_fold_scores = (
    res[('stability', 'Jurkat')]
    .groupby(['model', 'chrom'])[['y_true', 'y_pred']]
    .apply(lambda fold: compute_score(fold)[0])
    .rename('score')
)

In [11]:
# Move the (model, chrom) index levels into ordinary columns for seaborn, then sort.
per_fold_scores = per_fold_scores.reset_index()
per_fold_scores = per_fold_scores.sort_values(by=['model', 'chrom'])

In [None]:
# Per-fold Pearson r for each model (points + box), with BH-corrected Wilcoxon tests of a
# reference model against the baselines. The Wilcoxon test pairs observations by position,
# so it relies on rows being sorted by (model, chrom) and on every model having the same
# chromosomes.
fig, ax = plt.subplots(figsize=(4, 4))

# Plot only models that actually have per-fold scores, in the order of `models`.
available = set(per_fold_scores['model'])
plotted_models = [m for m in models if m in available]

sns.swarmplot(data=per_fold_scores, x="model", y="score", order=plotted_models, ax=ax)
sns.boxplot(data=per_fold_scores, x="model", y="score", order=plotted_models,
            boxprops={'facecolor': 'None'}, ax=ax)

# NOTE(review): earlier versions compared a model called 'MLM'; 'Species-aware' is assumed
# to be its current name — confirm. Pairs with a model that has no data are skipped,
# because Annotator rejects labels missing from `order`.
reference_model = 'Species-aware'
baselines = ['4mers', '5mers', 'word2vec', 'effective_length']
box_pairs = [(reference_model, baseline) for baseline in baselines
             if reference_model in available and baseline in available]

if box_pairs:
    annotator = Annotator(ax, box_pairs, data=per_fold_scores, x="model", y="score",
                          order=plotted_models)
    annotator.configure(test='Wilcoxon', text_format='star', loc='inside',
                        comparisons_correction="BH")
    annotator.apply_and_annotate()

ax.set_xlabel("")
ax.set_ylabel("Pearson r")
# Rotate only the model names; the numeric y tick labels stay horizontal.
ax.tick_params(axis='x', rotation=30)
ax.grid()

# Stability and steady-state prediction

In [27]:
# Pearson r ± CI half-width per model for each (cell type, response). The best model and the
# models whose intervals overlap it are shown in bold.
scores = {}
best_models = {}

for cell_type in ('Jurkat', 'Beas2B'):
    for response in ('steady_state', 'stability'):
        loaded = res.get((response, cell_type))
        if not isinstance(loaded, pd.DataFrame):
            # Nothing was loaded for this combination.
            continue
        column = (cell_type, response)
        scores[column] = loaded.groupby('model')[['y_true', 'y_pred']].apply(compute_score)
        best_models[column] = get_best_models(scores[column])

# (r, err) tuples -> "r±err" strings; missing entries (NaN) are left as they are.
preds_res = pd.DataFrame(scores).map(
    lambda x: f'{x[0]:.2f}±{x[1]:.2f}' if isinstance(x, tuple) else x)

# Keep the order of `models` but skip models without any predictions;
# preds_res.loc[models] would raise KeyError for them.
shown_models = [m for m in models if m in preds_res.index]
preds_res.loc[shown_models].style.apply(lambda x: highlight_ns(x, best_models))

Unnamed: 0_level_0,Jurkat,Jurkat,Beas2B,Beas2B
Unnamed: 0_level_1,steady_state,stability,steady_state,stability
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
DNABERT,0.08±0.02,0.07±0.02,0.10±0.03,0.08±0.04
DNABERT-2,0.17±0.02,0.22±0.02,0.11±0.03,0.26±0.04
Species-agnostic,0.25±0.02,0.32±0.02,0.31±0.03,0.43±0.04
Species-aware,0.27±0.02,0.33±0.02,0.28±0.03,0.46±0.03
NT-MS-v2-500M,0.15±0.02,0.24±0.02,0.20±0.03,0.27±0.04
5mers,0.24±0.02,0.43±0.02,0.19±0.03,0.39±0.04
