# Probabilistic model building genetic algorithm

In [1]:
%cd /mnt/ceph/users/zzhang/CRISPR_pred/crispr_kinn

/mnt/ceph/users/zzhang/CRISPR_pred/crispr_kinn


In [2]:
from src.kinetic_model import KineticModel, modelSpace_to_modelParams
from src.neural_network_builder import KineticNeuralNetworkBuilder
from notebooks.runAmber import get_uniform_ms, get_finkelstein_ms, get_reward_pipeline

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-qq7o42x9 because the default path (/home/zzhang/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.
Using TensorFlow backend.


0.1.1-ga


In [3]:
import warnings
warnings.filterwarnings('ignore')
import time
from datetime import datetime

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

import scipy.stats as ss
import pandas as pd
import numpy as np
from tqdm import tqdm
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import shutil
import os
import pickle
import gc
from sklearn.model_selection import train_test_split

## Load data

In [4]:
# Feature matrix and multi-column label matrix.
x = np.load('./data/compiled_X.npy')
y = np.load('./data/compiled_Y.npy')

# Map every label name (one per line of the annotation file) to its line
# index; that index is used below to pick a column of `y`.
with open('./data/y_col_annot.txt', 'r') as annot_file:
    label_names = [line.strip() for line in annot_file]
label_annot = {name: col for col, name in enumerate(label_names)}

# 80/20 split with a fixed seed.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=777)
label_annot

{'wtCas9_cleave_rate_log': 0,
 'Cas9_enh_cleave_rate_log': 1,
 'Cas9_hypa_cleave_rate_log': 2,
 'Cas9_HF1_cleave_rate_log': 3,
 'wtCas9_cleave_rate_log_specificity': 4,
 'Cas9_enh_cleave_rate_log_specificity': 5,
 'Cas9_hypa_cleave_rate_log_specificity': 6,
 'Cas9_HF1_cleave_rate_log_specificity': 7,
 'wtCas9_ndABA': 8,
 'Cas9_enh_ndABA': 9,
 'Cas9_hypa_ndABA': 10,
 'Cas9_HF1_ndABA': 11}

In [5]:
# Column of `y` used as the regression target for the search and re-training below.
target_idx = label_annot['wtCas9_cleave_rate_log']

## Setup AMBER

In [6]:
import amber
print(amber.__version__)
from amber.architect import pmbga
from amber.architect import ModelSpace, Operation

0.1.1-ga


In [7]:
#kinn_model_space = get_uniform_ms(n_states=4, st_win_size=7)

#print(kinn_model_space)

In [8]:
# Build the search space with the project helper `get_finkelstein_ms`
# (imported from notebooks.runAmber) and print its summary.
kinn_model_space = get_finkelstein_ms()
print(kinn_model_space)

StateSpace with 7 layers and 1 total combinations


In [9]:
# Search controller over `kinn_model_space`; it is sampled with get_action(),
# fed rewards with store(), and updated with train() in the loop below.
controller = pmbga.ProbaModelBuildGeneticAlgo(
            model_space=kinn_model_space,
            buffer_type='population',
            buffer_size=50,  # buffer size controls the max history going back
            batch_size=1,   # batch size does not matter in this case; all arcs will be retrieved
        )

## Components before they are implemented in AMBER

## A fancy For-Loop that does the work for `amber.architect.trainEnv`

In [10]:
# trainEnv parameters
samps_per_gen = 10   # how many arcs to sample in each generation; important
max_gen = 1500       # upper bound on the number of generations in the search loop
epsilon = 0.05       # convergence threshold on the distribution delta; the check that uses it is currently commented out
patience = 200       # stop after this many consecutive generations without a new best reward
n_warmup_gen = -1    # generations below this index skip controller training; -1 means no warm-up
wd = "outputs/notebook"  # working directory for model files and the training history

In [11]:
def compute_eps(model_space_probs, old_probs=None, n_samples=10000,
                percentiles=(10, 20, 30, 40, 50, 60, 70, 80, 90)):
    """Measure how far the sampled distributions moved since the last call.

    Every distribution in `model_space_probs` is sampled `n_samples` times and
    summarised by the requested percentiles. The same percentiles are taken
    from the previous samples in `old_probs`, and the mean absolute difference
    is averaged over all distributions.

    Parameters
    ----------
    model_space_probs : mapping
        Maps a key to an object exposing ``sample(size=...)``.
    old_probs : mapping or None
        Samples returned by a previous call, keyed like `model_space_probs`.
        When None, the mean absolute percentile value itself is used, so the
        returned delta is only a placeholder.
    n_samples : int
        Number of draws per distribution (default 10000, as before).
    percentiles : sequence of numbers
        Percentiles used to summarise each distribution (default 10..90).

    Returns
    -------
    (float, dict)
        The averaged delta and the fresh samples, keyed like the input; pass
        the dict back as `old_probs` on the next call.
    """
    pct = list(percentiles)
    delta = []
    samp_probs = {}
    for p in model_space_probs:
        samp_probs[p] = model_space_probs[p].sample(size=n_samples)
        n = np.percentile(samp_probs[p], pct)
        if old_probs is None:
            delta.append(np.mean(np.abs(n)))
        else:
            o = np.percentile(old_probs[p], pct)
            delta.append(np.mean(np.abs(o - n)))
    return np.mean(delta), samp_probs

# get prior probas: sample every distribution once before the search loop runs,
# so the first generation's delta is measured against this initial state
_, old_probs = compute_eps(controller.model_space_probs)

In [None]:
# Search loop: each generation samples `samps_per_gen` architectures, scores
# them with `get_reward_pipeline`, stores the rewards in the controller, then
# trains the controller and logs summary statistics.
hist = []        # one record per evaluated architecture
pc_cnt = 0       # consecutive generations without a new best reward
best_indv = 0    # best test reward seen so far
stat_columns = ['Generation', 'GenAvg', 'Best', 'PostVar']
stat_records = []  # one dict per trained generation
stat_df = pd.DataFrame(columns=stat_columns)
for generation in range(max_gen):
    try:
        start = time.time()
        has_impr = False
        for _ in range(samps_per_gen):
            # get arc
            arc, _ = controller.get_action()
            # get reward; any failure propagates and aborts the search
            # (a fallback of reward 0 on ValueError is deliberately not used)
            test_pcc = get_reward_pipeline(
                arc,
                x_train=x_train,
                y_train=y_train[:, target_idx],
                x_test=x_test,
                y_test=y_test[:, target_idx],
                wd=wd)
            rate_df = None
            # update best: keep the files written for this arc as the best model so far
            if test_pcc > best_indv:
                best_indv = test_pcc
                has_impr = True
                shutil.move(os.path.join(wd, "bestmodel.h5"), os.path.join(wd, "AmberSearchBestModel.h5"))
                shutil.move(os.path.join(wd, "model_params.pkl"), os.path.join(wd, "AmberSearchBestModel_config.pkl"))

            # store
            _ = controller.store(action=arc, reward=test_pcc)
            hist.append({'gen': generation, 'arc':arc, 'test_pcc': test_pcc, 'rate_df': rate_df})
        end = time.time()
        if generation < n_warmup_gen:
            print(f"Gen {generation} < {n_warmup_gen} warmup.. skipped - Time %.2f" % (end-start), flush=True)
            continue
        _ = controller.train(episode=generation, working_dir=".")
        delta, old_probs = compute_eps(controller.model_space_probs, old_probs)
        post_vars = [np.var(x.sample(size=100)) for _, x in controller.model_space_probs.items()]
        post_var = np.mean(post_vars)
        gen_avg = controller.buffer.r_bias  # reported as "Mean fitness" in the log line
        # Collect rows in a list and rebuild the frame, instead of the
        # deprecated DataFrame.append; stat_df stays current on any exit.
        stat_records.append({
            'Generation': generation,
            'GenAvg': gen_avg,
            'Best': best_indv,
            'PostVar': post_var
        })
        stat_df = pd.DataFrame(stat_records, columns=stat_columns)
        print("[%s] Gen %i - Mean fitness %.3f - Best %.4f - PostVar %.3f - Eps %.3f - Time %.2f" % (
            datetime.now().strftime("%H:%M:%S"),
            generation,
            gen_avg,
            best_indv,
            post_var,
            delta,
            end-start), flush=True)
        # Convergence-based stop on `epsilon` is disabled on purpose:
        #if delta < epsilon:
        #    print("stop due to convergence criteria")
        #    break
        # the patience counter resets whenever this generation improved the best reward
        pc_cnt = 0 if has_impr else pc_cnt+1
        if pc_cnt >= patience:
            print("early-stop due to max patience w/o improvement")
            break
    except KeyboardInterrupt:
        # allow a manual stop while keeping everything collected so far
        print("user interrupted")
        break

datapoints:  6 / total:  10
[23:46:21] Gen 0 - Mean fitness 0.719 - Best 0.7837 - PostVar 6.004 - Eps 0.600 - Time 140.22
datapoints:  10 / total:  20
[23:48:43] Gen 1 - Mean fitness 0.726 - Best 0.7837 - PostVar 6.309 - Eps 0.459 - Time 141.25
datapoints:  17 / total:  30
[23:51:09] Gen 2 - Mean fitness 0.710 - Best 0.7837 - PostVar 5.776 - Eps 0.400 - Time 145.37
datapoints:  20 / total:  40
[23:53:27] Gen 3 - Mean fitness 0.725 - Best 0.7837 - PostVar 5.786 - Eps 0.334 - Time 136.56
datapoints:  27 / total:  50
[23:55:37] Gen 4 - Mean fitness 0.718 - Best 0.7838 - PostVar 5.879 - Eps 0.149 - Time 129.70
datapoints:  25 / total:  60
[23:57:35] Gen 5 - Mean fitness 0.744 - Best 0.7838 - PostVar 5.300 - Eps 0.467 - Time 116.48
datapoints:  31 / total:  70
[23:59:32] Gen 6 - Mean fitness 0.742 - Best 0.7838 - PostVar 5.521 - Eps 0.163 - Time 116.47
datapoints:  43 / total:  80
[00:01:36] Gen 7 - Mean fitness 0.734 - Best 0.7841 - PostVar 5.520 - Eps 0.185 - Time 123.77
datapoints:  51 /

datapoints:  202 / total:  500
[02:10:37] Gen 66 - Mean fitness 0.768 - Best 0.8032 - PostVar 5.402 - Eps 0.148 - Time 104.72
datapoints:  318 / total:  500
[02:12:23] Gen 67 - Mean fitness 0.760 - Best 0.8032 - PostVar 4.538 - Eps 0.089 - Time 105.17
datapoints:  247 / total:  500
[02:14:14] Gen 68 - Mean fitness 0.765 - Best 0.8032 - PostVar 4.663 - Eps 0.059 - Time 110.20
datapoints:  317 / total:  500
[02:16:19] Gen 69 - Mean fitness 0.761 - Best 0.8032 - PostVar 4.753 - Eps 0.081 - Time 124.57
datapoints:  269 / total:  500
[02:18:56] Gen 70 - Mean fitness 0.764 - Best 0.8032 - PostVar 4.332 - Eps 0.133 - Time 156.16
datapoints:  188 / total:  500
[02:21:30] Gen 71 - Mean fitness 0.769 - Best 0.8032 - PostVar 4.869 - Eps 0.132 - Time 153.37
datapoints:  272 / total:  500
[02:24:08] Gen 72 - Mean fitness 0.764 - Best 0.8032 - PostVar 4.612 - Eps 0.150 - Time 156.08
datapoints:  438 / total:  500
[02:27:12] Gen 73 - Mean fitness 0.740 - Best 0.8032 - PostVar 4.901 - Eps 0.170 - Time

datapoints:  224 / total:  500
[04:17:11] Gen 131 - Mean fitness 0.774 - Best 0.8131 - PostVar 3.988 - Eps 0.044 - Time 103.67
datapoints:  233 / total:  500
[04:19:00] Gen 132 - Mean fitness 0.774 - Best 0.8131 - PostVar 4.082 - Eps 0.022 - Time 108.17
datapoints:  205 / total:  500
[04:20:55] Gen 133 - Mean fitness 0.775 - Best 0.8131 - PostVar 4.178 - Eps 0.045 - Time 114.56
datapoints:  283 / total:  500
[04:22:37] Gen 134 - Mean fitness 0.771 - Best 0.8131 - PostVar 4.005 - Eps 0.045 - Time 101.24
datapoints:  290 / total:  500
[04:24:32] Gen 135 - Mean fitness 0.771 - Best 0.8131 - PostVar 3.857 - Eps 0.072 - Time 113.89
datapoints:  187 / total:  500
[04:26:25] Gen 136 - Mean fitness 0.777 - Best 0.8131 - PostVar 4.098 - Eps 0.094 - Time 112.07
datapoints:  243 / total:  500
[04:28:10] Gen 137 - Mean fitness 0.774 - Best 0.8131 - PostVar 3.943 - Eps 0.081 - Time 104.61
datapoints:  154 / total:  500
[04:30:02] Gen 138 - Mean fitness 0.778 - Best 0.8131 - PostVar 3.748 - Eps 0.09

datapoints:  145 / total:  500
[06:19:19] Gen 196 - Mean fitness 0.783 - Best 0.8140 - PostVar 2.815 - Eps 0.119 - Time 117.39
datapoints:  155 / total:  500
[06:21:06] Gen 197 - Mean fitness 0.783 - Best 0.8140 - PostVar 2.862 - Eps 0.030 - Time 106.10
datapoints:  107 / total:  500
[06:23:11] Gen 198 - Mean fitness 0.787 - Best 0.8140 - PostVar 3.217 - Eps 0.081 - Time 123.85
datapoints:  409 / total:  500
[06:24:59] Gen 199 - Mean fitness 0.770 - Best 0.8140 - PostVar 3.600 - Eps 0.184 - Time 107.12
datapoints:  259 / total:  500
[06:26:48] Gen 200 - Mean fitness 0.779 - Best 0.8140 - PostVar 2.949 - Eps 0.061 - Time 108.53
datapoints:  175 / total:  500
[06:28:46] Gen 201 - Mean fitness 0.782 - Best 0.8140 - PostVar 3.124 - Eps 0.089 - Time 116.69
datapoints:  346 / total:  500
[06:30:30] Gen 202 - Mean fitness 0.774 - Best 0.8140 - PostVar 3.395 - Eps 0.111 - Time 103.89
datapoints:  126 / total:  500
[06:32:30] Gen 203 - Mean fitness 0.786 - Best 0.8140 - PostVar 3.376 - Eps 0.18

datapoints:  210 / total:  500
[08:27:17] Gen 261 - Mean fitness 0.786 - Best 0.8140 - PostVar 2.435 - Eps 0.067 - Time 114.75
datapoints:  161 / total:  500
[08:29:14] Gen 262 - Mean fitness 0.789 - Best 0.8140 - PostVar 2.346 - Eps 0.015 - Time 116.18
datapoints:  155 / total:  500
[08:31:17] Gen 263 - Mean fitness 0.790 - Best 0.8140 - PostVar 2.859 - Eps 0.044 - Time 122.39
datapoints:  195 / total:  500
[08:33:22] Gen 264 - Mean fitness 0.787 - Best 0.8140 - PostVar 2.533 - Eps 0.081 - Time 124.42
datapoints:  357 / total:  500
[08:35:12] Gen 265 - Mean fitness 0.779 - Best 0.8140 - PostVar 2.807 - Eps 0.074 - Time 109.26
datapoints:  274 / total:  500
[08:37:08] Gen 266 - Mean fitness 0.783 - Best 0.8140 - PostVar 2.050 - Eps 0.089 - Time 114.92
datapoints:  171 / total:  500
[08:39:09] Gen 267 - Mean fitness 0.788 - Best 0.8140 - PostVar 2.394 - Eps 0.059 - Time 119.83
datapoints:  122 / total:  500
[08:41:12] Gen 268 - Mean fitness 0.793 - Best 0.8140 - PostVar 2.687 - Eps 0.04

datapoints:  339 / total:  500
[10:34:10] Gen 326 - Mean fitness 0.783 - Best 0.8140 - PostVar 2.046 - Eps 0.026 - Time 113.38
datapoints:  269 / total:  500
[10:36:10] Gen 327 - Mean fitness 0.786 - Best 0.8140 - PostVar 2.049 - Eps 0.048 - Time 118.78
datapoints:  266 / total:  500
[10:38:07] Gen 328 - Mean fitness 0.786 - Best 0.8140 - PostVar 2.138 - Eps 0.037 - Time 116.22
datapoints:  237 / total:  500
[10:40:01] Gen 329 - Mean fitness 0.788 - Best 0.8172 - PostVar 2.559 - Eps 0.022 - Time 113.88
datapoints:  238 / total:  500
[10:41:55] Gen 330 - Mean fitness 0.788 - Best 0.8172 - PostVar 1.943 - Eps 0.007 - Time 112.60
datapoints:  189 / total:  500
[10:43:46] Gen 331 - Mean fitness 0.791 - Best 0.8172 - PostVar 2.148 - Eps 0.037 - Time 109.97
datapoints:  256 / total:  500
[10:45:36] Gen 332 - Mean fitness 0.787 - Best 0.8172 - PostVar 1.909 - Eps 0.036 - Time 109.73
datapoints:  134 / total:  500
[10:47:38] Gen 333 - Mean fitness 0.795 - Best 0.8172 - PostVar 2.074 - Eps 0.08

In [None]:
# All evaluated architectures, highest test reward first.
pd.DataFrame(hist).sort_values('test_pcc', ascending=False)

In [None]:
# Print the highest-reward architecture, one element per line.
best_row = pd.DataFrame(hist).sort_values('test_pcc', ascending=False).head(1)
best_arc = best_row['arc'].values[0]
print("\n".join(str(element) for element in best_arc))

In [None]:
# Write the search history to a TSV file. Each architecture is flattened to a
# string of "start-end" spans joined by '|', one span per element, where
# end = RANGE_ST + RANGE_D.
hist_df = pd.DataFrame(hist)
arc_strings = []
for entry in hist_df['arc']:
    spans = []
    for element in entry:
        attrs = element.Layer_attributes
        range_start = attrs['RANGE_ST']
        spans.append(f"{range_start}-{range_start + attrs['RANGE_D']}")
    arc_strings.append('|'.join(spans))
hist_df['arc'] = arc_strings
hist_df = hist_df.drop(columns=['rate_df'])
hist_df.to_csv(os.path.join(wd, "train_history.tsv"), sep="\t", index=False)

In [None]:
%matplotlib inline

# Reward per generation: the 'GenAvg' and 'Best' columns of stat_df.
ax = stat_df.plot.line(x='Generation', y=['GenAvg', 'Best'])
ax.set_ylabel("Reward (Pearson correlation)")
ax.set_xlabel("Generation")
# Uncomment to write the figure to disk:
#plt.savefig("reward_vs_time.png")

In [None]:
# START SITE
# One panel per rate: samples from each 'RANGE_ST' distribution of the trained
# controller, overlaid on that distribution's prior.
fig, axs_ = plt.subplots(3,3, figsize=(15,15))
axs = axs_.ravel()  # flat view, so a panel can be picked by a single index
for k, dist in controller.model_space_probs.items():
    if k[-1] != 'RANGE_ST':
        continue
    try:
        d = dist.sample(size=1000)
    except Exception as err:
        # Best-effort plotting: skip a distribution that cannot be sampled,
        # but report it instead of hiding the failure. KeyboardInterrupt is
        # no longer swallowed.
        print(f"skipping {k}: sampling failed ({err!r})")
        continue
    ax = axs[k[0]]  # the first key element selects the panel (shown as 'Rate ID')
    sns.distplot(d, label="Post", ax=ax)
    sns.distplot(dist.prior_dist, label="Prior", ax=ax)
    ax.legend()  # make the Post/Prior labels visible
    ax.set_title(
        ' '.join(['Rate ID', str(k[0]), '\nPosterior mean', str(np.mean(d))]))

    #_ = ax.set_xlim(0,50)

fig.suptitle('range start')
fig.tight_layout()
#fig.savefig("range_st.png")

In [None]:
# CONV RANGE
# One panel per rate: samples from each 'RANGE_D' distribution of the trained
# controller, overlaid on that distribution's prior.
fig, axs_ = plt.subplots(3,3, figsize=(15,15))
axs = axs_.ravel()  # flat view, so a panel can be picked by a single index
for k, dist in controller.model_space_probs.items():
    if k[-1] != 'RANGE_D':
        continue
    d = dist.sample(size=1000)
    ax = axs[k[0]]  # the first key element selects the panel (shown as 'Rate ID')
    # label both curves, matching the range-start figure
    sns.distplot(d, label="Post", ax=ax)
    sns.distplot(dist.prior_dist, label="Prior", ax=ax)
    ax.legend()
    ax.set_title(
        ' '.join(['Rate ID', str(k[0]), '\nPosterior mean', str(np.mean(d))]))
fig.suptitle('range length')
fig.tight_layout()
#fig.savefig("range_d.png")

In [None]:
# KERNEL SIZE
# One panel per rate: samples from each 'kernel_size' distribution of the
# trained controller, overlaid on that distribution's prior.
fig, axs_ = plt.subplots(3,3, figsize=(15,15))
axs = axs_.ravel()  # flat view, so a panel can be picked by a single index
for k, dist in controller.model_space_probs.items():
    if k[-1] != 'kernel_size':
        continue
    d = dist.sample(size=1000)
    ax = axs[k[0]]  # the first key element selects the panel (shown as 'Rate ID')
    # label both curves, matching the range-start figure
    sns.distplot(d, label="Post", ax=ax)
    sns.distplot(dist.prior_dist, label="Prior", ax=ax)
    ax.legend()
    ax.set_title(
        ' '.join(['Rate ID', str(k[0]), '\nPosterior mean', str(np.mean(d))]))
    #_ = ax.set_xlim(0,20)
fig.suptitle('kernel size')
fig.tight_layout()

In [None]:
# reload and re-train to full convergence
%run notebooks/reload

In [None]:
# Show the configuration that the search loop saved for its best model.
# The file is opened in a `with` block so the handle is always closed.
# NOTE: unpickling can execute arbitrary code; only load files written by this notebook.
with open(os.path.join(wd, "AmberSearchBestModel_config.pkl"), "rb") as config_file:
    best_model_config = pickle.load(config_file)
best_model_config

In [None]:
# Rebuild the model builder from the files under outputs/notebook and keep a
# handle on its model.
# NOTE(review): `reload_from_dir` is presumably defined by `%run notebooks/reload`
# above; confirm. The directory is hard-coded here rather than taken from `wd`.
mb = reload_from_dir("outputs/notebook", replace_conv_by_fc=False, n_channels=8)
model = mb.model

In [None]:
# Re-train the reloaded model with early stopping, restore the best
# checkpoint, and score it on the held-out split.
x_train_b = mb.blockify_seq_ohe(x_train)
x_test_b = mb.blockify_seq_ohe(x_test)

# The best weights seen during training are written to, and later restored from, this file.
retrain_ckpt_path = os.path.join(wd, "bestmodel.h5")
checkpointer = ModelCheckpoint(
    filepath=retrain_ckpt_path,
    mode='min',
    verbose=0,
    save_best_only=True,
    save_weights_only=True,
)
earlystopper = EarlyStopping(monitor="val_loss", mode='min', patience=15, verbose=0)

model.fit(
    x_train_b,
    y_train[:,target_idx],
    batch_size=128,
    validation_split=0.2,
    callbacks=[checkpointer, earlystopper],
    epochs=225,
    verbose=2,
)
model.load_weights(retrain_ckpt_path)

# Pearson correlation between test predictions and observations.
y_hat = model.predict(x_test_b).flatten()
test_pcc = ss.pearsonr(y_hat, y_test[:,target_idx])[0]

In [None]:
# Show the attribute dict of every rate object on the reloaded builder's `kinn`.
[str(x.__dict__) for x in mb.kinn.rates]

In [None]:
# Index the model's layers by name for the weight look-ups below.
layer_dict = {layer.name: layer for layer in model.layers}

In [None]:
# First weight array of layer 'conv_k0', rounded to 3 decimals
# (presumably the convolution kernel; confirm).
np.around(layer_dict['conv_k0'].get_weights()[0],3)

In [None]:
# First weight array of layer 'conv_k1', rounded to 3 decimals
# (presumably the convolution kernel; confirm).
np.around(layer_dict['conv_k1'].get_weights()[0],3)

In [None]:
# Final evaluation on the held-out split: predictions are clipped to [-5, -1]
# and compared with the observed target column.
x_test_b = mb.blockify_seq_ohe(x_test)
y_obs = y_test[:,target_idx]
y_hat = np.clip(model.predict(x_test_b).flatten(), -5, -1)

h = sns.jointplot(x=y_obs, y=y_hat)
h.set_axis_labels("obs", "pred", fontsize=16)

print("spearman", ss.spearmanr(y_hat, y_obs))
p = ss.pearsonr(y_hat, y_obs)
print("pearson", p)
h.fig.suptitle("Testing prediction, pcc=%.3f"%p[0], fontsize=16)