In [1]:
# Move up one directory before anything else runs; presumably so that
# `import src` and the relative "outputs/..." run directories used below resolve.
%cd ..

/mnt/ceph/users/zzhang/CRISPR_pred/crispr_kinn


In [2]:
from silence_tensorflow import silence_tensorflow
silence_tensorflow()
import numpy as np
import pandas as pd
from tqdm import tqdm
import scipy.stats as ss
from sklearn.metrics import average_precision_score, roc_auc_score
import seaborn as sns
import matplotlib.pyplot as plt
import tensorflow as tf
import src
import amber
import os

Using TensorFlow backend.


In [3]:
from src.neural_network_builder import KineticNeuralNetworkBuilder
from src.reload import reload_from_dir
from src import crispr_kinn_predict
from src.crispr_kinn_predict import predict_on_dataframe, plot_dataframe, \
    get_letter_index, featurize_alignment

# Absolute project directory used as the prefix for the baseline dataset CSVs
# read further down. NOTE(review): machine-specific path; keep the trailing
# slash, since it is used as a plain string prefix.
wd = "/mnt/home/zzhang/ceph/CRISPR_pred/crispr_kinn/"

In [4]:
# Empty results table; one row per (dataset, model) is added by the
# evaluation cells below.
metrics = pd.DataFrame(
    columns=[
        'data',   # dataset name
        'model',  # model / ensemble identifier
        'auroc',
        'aupr',
    ]
)

In [5]:
# trainEnv parameters; in this notebook only `model_fn` is read back
# (when the saved KINN runs are reloaded).
evo_params = {
    'model_fn': KineticNeuralNetworkBuilder,  # alternative: KineticEigenModelBuilder
    'samps_per_gen': 10,  # how many arcs to sample in each generation; important
    'max_gen': 200,
    'patience': 50,
    'n_warmup_gen': 0,
    # 'train_data': (x_train, y_train),   # left disabled
    # 'test_data': (x_test, y_test),      # left disabled
}

# manager configs
manager_kwargs = dict(
    # Factory for the output layer: natural log divided by ln(10), i.e. log10
    # of the incoming tensor.  (change the clip as well)
    output_op=lambda: tf.keras.layers.Lambda(
        lambda x: tf.math.log(x) / np.log(10), name="output_log"
    ),
    n_feats=25,  # remember to change this!!
    n_channels=9,
    batch_size=128,
    epochs=30,
    earlystop=10,
    verbose=0,
)

In [6]:
# Single TF1-style session; it is handed to every reload_from_dir call below.
sess = tf.Session()
# Run directories of the KINN models to evaluate (rep1-5 x gRNA1-2), built
# from one template instead of ten hand-copied strings.
kinn_path_template = (
    "outputs/2022-05-21/"
    "KINN-wtCas9_cleave_rate_log-finkelstein-0-rep{rep}-gRNA{grna}/"
)
# Order is significant (metric rows are paired with model_ids by position):
# all gRNA1 replicates first, then all gRNA2 replicates.
kinn_paths = [
    kinn_path_template.format(rep=rep, grna=grna)
    for grna in (1, 2)
    for rep in range(1, 6)
]

# Previously listed alternatives, kept disabled:
#   outputs/2022-05-30/KINN-wtCas9_cleave_rate_log-uniform-{4,5,6}-rep{1,2,3}-gRNA{1,2}/

# A model id is the last four '-'-separated fields of the run directory name,
# e.g. "finkelstein_0_rep1_gRNA1". normpath strips a trailing slash first, so
# the id no longer depends on whether the path was written with one.
model_ids = [
    '_'.join(os.path.basename(os.path.normpath(path)).split('-')[-4:])
    for path in kinn_paths
]

# Reload one KINN per run directory, in the same order as kinn_paths.
# Every model is restored into the shared session `sess`.
kinns = [
    reload_from_dir(
        wd=run_dir,
        sess=sess,
        manager_kwargs=manager_kwargs,
        model_fn=evo_params['model_fn'],
    )
    for run_dir in kinn_paths
]


In [7]:
# load amber CNNs

def reload_cnn_from_dir(wd):
    """Load the best saved CNN from a search run directory.

    ``<wd>/train_history.csv`` is read as a header-less, comma-separated
    table. The row with the largest value in column 2 is selected and its
    column-0 value is used as the trial id.
    NOTE(review): column 2 is presumably the reward/validation score written
    by the search -- confirm against the train-history format.

    Parameters
    ----------
    wd : str
        Run directory containing ``train_history.csv`` and a ``weights``
        sub-directory.

    Returns
    -------
    Whatever ``tf.keras.models.load_model`` returns for
    ``<wd>/weights/trial_<id>/bestmodel.h5``.
    """
    train_hist = pd.read_table(os.path.join(wd, "train_history.csv"), sep=",", header=None)
    # Extract a plain integer. Keeping the 1-element Series and formatting it
    # with "%i" relies on implicit int(Series), which newer pandas deprecates.
    best_trial_id = int(train_hist.sort_values(2, ascending=False).iloc[0, 0])
    return tf.keras.models.load_model(
        os.path.join(wd, "weights", "trial_%i" % best_trial_id, "bestmodel.h5")
    )

In [8]:
# CNN run directories: replicate index varies slowest, gRNA index fastest.
fp_temp = "outputs/2022-10-15/CNN-wtCas9_cleave_rate_log-rep{rep}-gRNA{grna}/"
cnn_paths = [
    fp_temp.format(rep=rep, grna=grna)
    for rep in range(1, 6)
    for grna in (1, 2)
]
# Load the best model of every run, keeping the order of cnn_paths.
cnns = [reload_cnn_from_dir(run_dir) for run_dir in cnn_paths]


In [9]:
# CNN ids are "CNN_" plus the last two '-'-separated fields of the run
# directory name (e.g. "CNN_rep1_gRNA1"). normpath strips a trailing slash
# first, so the id does not depend on how the path string ends.
cnn_model_ids = [
    'CNN_' + '_'.join(os.path.basename(os.path.normpath(path)).split('-')[-2:])
    for path in cnn_paths
]

Kleinstiver et al., Nature 2015 - 5 gRNAs, GUIDE-seq in vivo

In [10]:
# Kleinstiver 5-gRNA table from the CRISPR_Net baseline data folder.
kl_df = pd.read_csv(
    wd + "baselines/CRISPR_Net/data/Dataset_II_mismatch/dataset_II-5/Kleinstiver_5gRNA_wholeDataset.csv"
)
kl_df['off_seq'] = kl_df['off_seq'].str.upper()
# mismatch-only data has the same sgRNA_seq and sgRNA_type
kl_df['sgRNA_type'] = kl_df['sgRNA_seq']

# One alignment per row: the strings in the first two columns, each reversed.
alignments = [
    row.str[::-1].tolist()
    for _idx, row in tqdm(kl_df.iloc[:, [0, 1]].iterrows(), total=len(kl_df))
]
ltidx = get_letter_index(build_indel=True)
fea = featurize_alignment(alignments, ltidx)

100%|██████████| 95829/95829 [00:20<00:00, 4705.35it/s]


In [11]:
# Score every KINN on the Kleinstiver features; kl_df.label is the ground truth.
preds, aucs, auprs = [], [], []

for kinn in kinns:
    scores = kinn.predict(fea)
    preds.append(scores)
    aucs.append(roc_auc_score(y_true=kl_df.label, y_score=scores))
    auprs.append(average_precision_score(y_true=kl_df.label, y_score=scores))

In [12]:
# add average of predictions
# Flatten each model's prediction to 1-D before stacking, so axis 0 is always
# the model axis. A bare np.array(preds).squeeze() would also drop that axis
# when only one model is loaded, and mean(axis=0) would then average over
# samples instead of models. The ensemble is computed once and reused.
ensemble_pred = np.stack([np.ravel(p) for p in preds]).mean(axis=0)

auprs.append(average_precision_score(y_true=kl_df.label, y_score=ensemble_pred))
aucs.append(roc_auc_score(y_true=kl_df.label, y_score=ensemble_pred))

In [13]:
# Add one row per KINN plus one for their ensemble (the last entry of
# aucs/auprs). pd.concat is used because DataFrame.append is deprecated and
# no longer exists in pandas 2.x.
metrics = pd.concat([
    metrics,
    pd.DataFrame({
        'data': ['Kleinstiver_5gRNA']*(len(model_ids)+1),
        'model': model_ids + ['kinn_ensemble'],
        'auroc': aucs,
        'aupr': auprs,
    }),
], ignore_index=True)
#metrics

In [14]:
# Same evaluation as for the KINNs, now for every CNN on the Kleinstiver
# features; the three lists are reset first.
preds, aucs, auprs = [], [], []

for cnn in cnns:
    scores = cnn.predict(fea)
    preds.append(scores)
    aucs.append(roc_auc_score(y_true=kl_df.label, y_score=scores))
    auprs.append(average_precision_score(y_true=kl_df.label, y_score=scores))

In [15]:
# add average of predictions
# Flatten each model's prediction to 1-D before stacking, so axis 0 is always
# the model axis. A bare np.array(preds).squeeze() would also drop that axis
# when only one model is loaded, and mean(axis=0) would then average over
# samples instead of models. The ensemble is computed once and reused.
ensemble_pred = np.stack([np.ravel(p) for p in preds]).mean(axis=0)

auprs.append(average_precision_score(y_true=kl_df.label, y_score=ensemble_pred))
aucs.append(roc_auc_score(y_true=kl_df.label, y_score=ensemble_pred))

In [16]:
# Add one row per CNN plus one for their ensemble (the last entry of
# aucs/auprs). The row count is taken from cnn_model_ids, the list actually
# used for the 'model' column, rather than from the KINN id list.
# pd.concat is used because DataFrame.append is deprecated and no longer
# exists in pandas 2.x.
metrics = pd.concat([
    metrics,
    pd.DataFrame({
        'data': ['Kleinstiver_5gRNA']*(len(cnn_model_ids)+1),
        'model': cnn_model_ids + ['cnn_ensemble'],
        'auroc': aucs,
        'aupr': auprs,
    }),
], ignore_index=True)
metrics

Unnamed: 0,data,model,auroc,aupr
0,Kleinstiver_5gRNA,finkelstein_0_rep1_gRNA1,0.972203,0.121124
1,Kleinstiver_5gRNA,finkelstein_0_rep2_gRNA1,0.967521,0.124632
2,Kleinstiver_5gRNA,finkelstein_0_rep3_gRNA1,0.970767,0.125742
3,Kleinstiver_5gRNA,finkelstein_0_rep4_gRNA1,0.973076,0.172079
4,Kleinstiver_5gRNA,finkelstein_0_rep5_gRNA1,0.971717,0.137392
5,Kleinstiver_5gRNA,finkelstein_0_rep1_gRNA2,0.984515,0.199862
6,Kleinstiver_5gRNA,finkelstein_0_rep2_gRNA2,0.977871,0.214388
7,Kleinstiver_5gRNA,finkelstein_0_rep3_gRNA2,0.973184,0.127002
8,Kleinstiver_5gRNA,finkelstein_0_rep4_gRNA2,0.978864,0.181173
9,Kleinstiver_5gRNA,finkelstein_0_rep5_gRNA2,0.975924,0.192743


Listgarten et al., Nat BME 2018 - GUIDE-seq in vivo - Mismatch Only

In [17]:
# Listgarten 22-gRNA table from the CRISPR_Net baseline data folder.
ls_df = pd.read_csv(
    wd + "baselines/CRISPR_Net/data/Dataset_II_mismatch/dataset_II-6/Listgarten_22gRNA_wholeDataset.csv"
)
ls_df['off_seq'] = ls_df['off_seq'].str.upper()
# mismatch-only data has the same sgRNA_seq and sgRNA_type
ls_df['sgRNA_type'] = ls_df['sgRNA_seq']
# duplicate the read column under the capitalised name 'Read'
ls_df['Read'] = ls_df['read']

# One alignment per row: the strings in the first two columns, each reversed.
alignments = [
    row.str[::-1].tolist()
    for _idx, row in tqdm(ls_df.iloc[:, [0, 1]].iterrows(), total=len(ls_df))
]
ltidx = get_letter_index(build_indel=True)
fea = featurize_alignment(alignments, ltidx)

100%|██████████| 383463/383463 [01:21<00:00, 4696.82it/s]


In [18]:
# Score every KINN on the Listgarten features; ls_df.label is the ground truth.
preds, aucs, auprs = [], [], []

for kinn in kinns:
    scores = kinn.predict(fea)
    preds.append(scores)
    aucs.append(roc_auc_score(y_true=ls_df.label, y_score=scores))
    auprs.append(average_precision_score(y_true=ls_df.label, y_score=scores))

In [19]:
# add average of predictions
# Flatten each model's prediction to 1-D before stacking, so axis 0 is always
# the model axis. A bare np.array(preds).squeeze() would also drop that axis
# when only one model is loaded, and mean(axis=0) would then average over
# samples instead of models. The ensemble is computed once and reused.
ensemble_pred = np.stack([np.ravel(p) for p in preds]).mean(axis=0)

auprs.append(average_precision_score(y_true=ls_df.label, y_score=ensemble_pred))
aucs.append(roc_auc_score(y_true=ls_df.label, y_score=ensemble_pred))

In [20]:
# Add one row per KINN plus one for their ensemble (the last entry of
# aucs/auprs). pd.concat is used because DataFrame.append is deprecated and
# no longer exists in pandas 2.x.
metrics = pd.concat([
    metrics,
    pd.DataFrame({
        'data': ['Listgarten_22gRNA']*(len(model_ids)+1),
        'model': model_ids + ['kinn_ensemble'],
        'auroc': aucs,
        'aupr': auprs,
    }),
], ignore_index=True)
#metrics

In [21]:
# Same evaluation as for the KINNs, now for every CNN on the Listgarten
# features; the three lists are reset first.
preds, aucs, auprs = [], [], []

for cnn in cnns:
    scores = cnn.predict(fea)
    preds.append(scores)
    aucs.append(roc_auc_score(y_true=ls_df.label, y_score=scores))
    auprs.append(average_precision_score(y_true=ls_df.label, y_score=scores))

In [22]:
# add average of predictions
# Flatten each model's prediction to 1-D before stacking, so axis 0 is always
# the model axis. A bare np.array(preds).squeeze() would also drop that axis
# when only one model is loaded, and mean(axis=0) would then average over
# samples instead of models. The ensemble is computed once and reused.
ensemble_pred = np.stack([np.ravel(p) for p in preds]).mean(axis=0)

auprs.append(average_precision_score(y_true=ls_df.label, y_score=ensemble_pred))
aucs.append(roc_auc_score(y_true=ls_df.label, y_score=ensemble_pred))

In [23]:
# Add one row per CNN plus one for their ensemble (the last entry of
# aucs/auprs). The row count is taken from cnn_model_ids, the list actually
# used for the 'model' column, rather than from the KINN id list.
# pd.concat is used because DataFrame.append is deprecated and no longer
# exists in pandas 2.x.
metrics = pd.concat([
    metrics,
    pd.DataFrame({
        'data': ['Listgarten_22gRNA']*(len(cnn_model_ids)+1),
        'model': cnn_model_ids + ['cnn_ensemble'],
        'auroc': aucs,
        'aupr': auprs,
    }),
], ignore_index=True)


In [24]:
# Display the accumulated metrics table; rows were added by the Kleinstiver
# and Listgarten evaluation cells above.
metrics

Unnamed: 0,data,model,auroc,aupr
0,Kleinstiver_5gRNA,finkelstein_0_rep1_gRNA1,0.972203,0.121124
1,Kleinstiver_5gRNA,finkelstein_0_rep2_gRNA1,0.967521,0.124632
2,Kleinstiver_5gRNA,finkelstein_0_rep3_gRNA1,0.970767,0.125742
3,Kleinstiver_5gRNA,finkelstein_0_rep4_gRNA1,0.973076,0.172079
4,Kleinstiver_5gRNA,finkelstein_0_rep5_gRNA1,0.971717,0.137392
5,Kleinstiver_5gRNA,finkelstein_0_rep1_gRNA2,0.984515,0.199862
6,Kleinstiver_5gRNA,finkelstein_0_rep2_gRNA2,0.977871,0.214388
7,Kleinstiver_5gRNA,finkelstein_0_rep3_gRNA2,0.973184,0.127002
8,Kleinstiver_5gRNA,finkelstein_0_rep4_gRNA2,0.978864,0.181173
9,Kleinstiver_5gRNA,finkelstein_0_rep5_gRNA2,0.975924,0.192743


In [25]:
# Record when the notebook was last run together with the Python, IPython and
# imported-package versions, for reproducibility.
%load_ext watermark
%watermark -n -u -v -iv -w

Last updated: Sun Oct 16 2022

Python implementation: CPython
Python version       : 3.7.9
IPython version      : 7.22.0

tensorflow: 1.15.0
amber     : 0.1.2
pandas    : 1.0.3
numpy     : 1.19.5
seaborn   : 0.11.1
scipy     : 1.7.3
matplotlib: 3.4.3
src       : 0.0.1

Watermark: 2.3.1

