In [1]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import functools
import pickle
from operator import add
import matplotlib as mpl
from wazy.utils import *
from wazy.mlp import *
from jax_unirep import get_reps
import wazy
import os



In [2]:
AA_list = ['A','R','N','D','C','Q','E','G','H','I','L','K','M','F','P','S','T','W','Y','V','B','Z','X','*']
blosum62 = np.loadtxt("blosum62.txt", dtype='i', delimiter=' ')
min62 = jnp.min(blosum62)
blosum62 = blosum62 - min62
avg62 = jnp.sum(blosum62)/len(blosum62)**2
sum62 = 0.
for row in blosum62:
    for aa in row:
        sum62 += (aa-avg62)**2
std62 = jnp.sqrt(sum62 / len(blosum62)**2)
def blosum(seq1, seq2):
    seqlist1 = list(seq1)
    seqlist2 = list(seq2)
    score = 0.
    for i in range(len(seqlist1)):
        idx1 = AA_list.index(seqlist1[i])
        idx2 = AA_list.index(seqlist2[i])
        score += blosum62[idx1][idx2]/std62
        #jax.nn.sigmoid(score/len(seqlist1))
    return score/len(seqlist1)

In [3]:
target_seq = 'TARGETPEPTIDE'
L = len(target_seq)
N = 15
repeats = 1
key = jax.random.PRNGKey(0)
print(blosum(target_seq, 'G' * L))

1.2428473


In [4]:
start = 'G' * L
rand_results = [0 for i in range(N)]
for r in range(repeats):
    boa.tell(key, start, blosum(target_seq, start))
    best = 0

    for i in range(N):
        s = ''.join(np.random.choice(AA_list[:20], size=(L,)))
        y = blosum(s, target_seq)
        best = max(y, best)
        if best == y:
            print(i, best)
        rand_results[i] += y
rand_results = [r / repeats for r in rand_results]

0 2.2016153


In [5]:
boa = wazy.MCMCAlgorithm(L)

start = 'G' * L
mcmc_results = [0 for i in range(N)]
for r in range(repeats):
    boa.tell(key, start, blosum(target_seq, start))
    best = 0

    for i in range(N):
        key,_ = jax.random.split(key)
        s, a = boa.ask(key)
        y = blosum(s, target_seq)
        best = max(y, best)
        boa.tell(key, s, y)
        if best == y:
            print(i, best)
        mcmc_results[i] += y
mcmc_results = [r / repeats for r in mcmc_results]

0 1.2428473
1 1.3493769
2 1.4203968
3 1.6689664
5 1.6689664


In [6]:
boa = wazy.BOAlgorithm(model_config=wazy.EnsembleBlockConfig(pretrained=False), 
                       alg_config=wazy.AlgConfig())

start = 'G' * L
ohc_results = [0 for i in range(N)]
for r in range(repeats):
    boa.tell(key, start, blosum(target_seq, start))
    best = 0
    for i in range(N):
        key,_ = jax.random.split(key)
        s, a = boa.ask(key)
        yhat, std, _ = boa.predict(key, s)
        y = blosum(s, target_seq)
        best = max(y, best)
        boa.tell(key, s, y)
        print(i, s, yhat, std, y, best)
        ohc_results[i] += y
ohc_results = [r / repeats for r in ohc_results]

  lax_internal._check_user_dtype_supported(dtype, "zeros")


0 DEETGILPQHGEG 0.16635673 0.71697116 1.2073374 1.2073374
1 NGEGHGATGGGGC 0.3151273 0.705463 1.4559067 1.4559067
2 EGGWSGLTGGGGC 0.3176662 0.77049345 0.99427783 1.4559067
3 HWNRNGPYTCGGC 0.21391058 0.70192385 1.3138671 1.4559067
4 GACGGVLTSWDGP 0.20742269 0.72061914 1.5269268 1.5269268
5 YGGGYFLETAEQC 0.22313021 0.7043353 1.5624366 1.5624366
6 TERLLDLYGWSGC 0.23618136 0.7423074 1.3138671 1.5624366
7 AMEKWTMWGPRMP 0.1212672 0.7202117 1.2783571 1.5624366
8 DNGTGGPEHEGPN 0.23690248 0.73243284 1.5979466 1.5979466
9 RNQNKGADQIGPC 0.19396403 0.72596264 1.3848871 1.5979466
10 HNAKDDVGGLPQF 0.14229819 0.7127397 1.1718274 1.5979466
11 DAVTSGGTQFYPW 0.2296536 0.72032154 1.3138671 1.5979466
12 NERKKFLGHGCMF 0.25480008 0.7507061 1.3138673 1.5979466
13 ATGTGGGYHTCPS 0.21307547 0.7203607 1.4559067 1.5979466
14 GVHTGYGEGMYWG 0.19227378 0.7269907 1.3138671 1.5979466


In [None]:
boa = wazy.BOAlgorithm(model_config=wazy.EnsembleBlockConfig(pretrained=False), 
                       alg_config=wazy.AlgConfig())

start = 'G' * L
ohc_g_results = [0 for i in range(N)]
for r in range(repeats):
    boa.tell(key, start, blosum(target_seq, start))
    best = 0
    for i in range(N):
        key,_ = jax.random.split(key)
        s, a = boa.ask(key, aq_fxn='max')
        yhat, std, _ = boa.predict(key, s)
        y = blosum(s, target_seq)
        best = max(y, best)
        boa.tell(key, s, y)
        print(i, s, yhat, std, y, best)
        ohc_g_results[i] += y
ohc_g_results = [r / repeats for r in ohc_g_results]

0 GYCGQHPGSINGA 0.21212211 0.6940221 1.7399865 1.7399865
1 NSGIGSEYEWASF 0.08664542 0.7074663 1.2783573 1.7399865
2 GESTRGTDLGAGL 0.14419763 0.7412526 1.2428473 1.7399865
3 YDGGGIMGHVEDG 0.10770092 0.7224255 1.5624365 1.7399865
4 YDCGAMLCIEEGF 0.28742588 0.71599394 1.1008075 1.7399865
5 EACGGDGGKGQKD 0.13477764 0.7422701 1.6334565 1.7399865
6 NDCGGQGRGWWIL 0.21228671 0.7329135 1.2428473 1.7399865
7 YEPGQGMSGPGGG 0.26128078 0.72771084 1.455907 1.7399865


In [None]:
boa = wazy.BOAlgorithm(model_config=wazy.EnsembleBlockConfig(pretrained=True), 
                       alg_config=wazy.AlgConfig())

start = 'A' * L
pre_results = [0 for i in range(N)]
for r in range(repeats):
    boa.tell(key, start, blosum(target_seq, start))
    best = 0
    for i in range(N):
        key,_ = jax.random.split(key)
        s, a = boa.ask(key)
        yhat, std, epi_std = boa.predict(key, s)
        y = blosum(s, target_seq)
        best = max(y, best)
        boa.tell(key, s, y)
        print(i, s, yhat, std, epi_std, y, best)
        pre_results[i] += y
pre_results = [r / repeats for r in pre_results]

In [None]:
plt.plot(ohc_g_results, label='OH Greedy')
plt.plot(ohc_results, label='OH')
plt.plot(mcmc_results, label='MCMC')
plt.plot(rand_results, label='Random')
plt.legend()
plt.show()

In [None]:
def curbest(x):
    return [np.max(x[:i]) for i in range(1, len(x) + 1)]
plt.plot(curbest(ohc_g_results), label='OH Greedy')
plt.plot(curbest(ohc_results), label='OH')
plt.plot(curbest(mcmc_results), label='MCMC')
plt.plot(curbest(rand_results), label='Random')
plt.legend()
plt.show()