In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import functools
import pickle
from operator import add
import matplotlib as mpl
from wazy.utils import *
from wazy.mlp import *
from jax_unirep import get_reps
import wazy
import os

In [None]:
# Residue alphabet: position k labels row/column k of blosum62.txt.
# The first 20 entries are the residues sampled by the random baseline (AA_list[:20]);
# B, Z, X and '*' follow.
AA_list = ['A','R','N','D','C','Q','E','G','H','I','L','K','M','F','P','S','T','W','Y','V','B','Z','X','*']
blosum62 = np.loadtxt("blosum62.txt", dtype='i', delimiter=' ')
# Fail early if the file layout does not match the alphabet used for lookups in blosum().
assert blosum62.shape == (len(AA_list), len(AA_list)), (
    f"expected a {len(AA_list)}x{len(AA_list)} matrix, got {blosum62.shape}")
# Shift so the smallest entry becomes 0, making every substitution score non-negative.
min62 = jnp.min(blosum62)
blosum62 = blosum62 - min62
# Population mean / standard deviation over all matrix entries (vectorized);
# std62 is the scale that blosum() divides scores by.
avg62 = jnp.mean(blosum62)
sum62 = jnp.sum((blosum62 - avg62) ** 2)
std62 = jnp.sqrt(sum62 / blosum62.size)
def blosum(seq1, seq2, matrix=None, alphabet=None, scale=None):
    """Mean per-position substitution score between two aligned sequences.

    Each position contributes ``matrix[alphabet.index(a), alphabet.index(b)]``.
    The contributions are summed, divided by ``scale`` and averaged over the
    sequence length.

    Args:
        seq1, seq2: Equal-length, non-empty sequences (e.g. strings) of symbols from
            ``alphabet``.
        matrix: Square substitution matrix indexed by alphabet position. Defaults to
            the module-level ``blosum62`` (shifted so its minimum is 0).
        alphabet: Sequence of symbols labelling the matrix rows/columns. Defaults to
            ``AA_list``.
        scale: Divisor applied to the summed score. Defaults to ``std62``.

    Returns:
        A scalar: ``sum(matrix[i, j]) / scale / len(seq1)``.

    Raises:
        ValueError: if the sequences differ in length, are empty, or contain a symbol
            not in ``alphabet``.
    """
    if matrix is None:
        matrix = blosum62
    if alphabet is None:
        alphabet = AA_list
    if scale is None:
        scale = std62
    if len(seq1) != len(seq2):
        raise ValueError(
            f"sequences must have equal length, got {len(seq1)} and {len(seq2)}")
    if len(seq1) == 0:
        raise ValueError("sequences must be non-empty")
    # O(1) symbol -> row/column lookup instead of a linear list.index per residue.
    position = {aa: k for k, aa in enumerate(alphabet)}
    try:
        rows = [position[aa] for aa in seq1]
        cols = [position[aa] for aa in seq2]
    except KeyError as err:
        raise ValueError(f"unknown residue {err.args[0]!r}") from None
    # One host-side gather instead of 2*len(seq1) individual jax indexing ops.
    total = np.asarray(matrix)[rows, cols].sum()
    return total / scale / len(seq1)

In [None]:
# Benchmark configuration shared by every optimization method below.
target_seq = 'TARGETPEPTIDE'
L = len(target_seq)  # length of every candidate sequence
N = 5                # optimization steps per run
repeats = 1          # independent runs averaged per method
key = jax.random.PRNGKey(0)
# Also seed NumPy's global RNG so any np.random-based step is reproducible
# under Restart Kernel -> Run All.
np.random.seed(0)
# Reference point: score of the all-'G' starting sequence used by most methods.
print(blosum(target_seq, 'G' * L))

In [None]:
# Random-search baseline: score N uniformly random sequences drawn from the
# first 20 residues of AA_list. No model is involved, so no optimizer is built.
rand_results = [0.0] * N
# Local seeded generator: re-running this cell reproduces the same sequences.
rng = np.random.default_rng(0)
for r in range(repeats):
    best = float('-inf')
    for i in range(N):
        s = ''.join(rng.choice(AA_list[:20], size=L))
        y = blosum(s, target_seq)
        best = max(y, best)
        if best == y:
            print(i, best)
        rand_results[i] += y
# Average each step's score over the independent repeats.
rand_results = [r / repeats for r in rand_results]

In [None]:
# MCMC baseline: wazy.MCMCAlgorithm proposes each next sequence.
start = 'G' * L
mcmc_results = [0.0] * N
# Work on a local key so this cell does not advance the global `key`;
# results then do not depend on which cells ran before.
run_key = key
for r in range(repeats):
    # Fresh optimizer per repeat so repeats are independent (no shared observations).
    boa = wazy.MCMCAlgorithm(L)
    run_key, tell_key = jax.random.split(run_key)
    boa.tell(tell_key, start, blosum(target_seq, start))
    best = float('-inf')

    for i in range(N):
        # Separate subkeys for ask/tell: a JAX key should never be consumed twice.
        run_key, ask_key, tell_key = jax.random.split(run_key, 3)
        s, _ = boa.ask(ask_key)
        y = blosum(s, target_seq)
        best = max(y, best)
        boa.tell(tell_key, s, y)
        if best == y:
            print(i, best)
        mcmc_results[i] += y
# Average each step's score over the independent repeats.
mcmc_results = [r / repeats for r in mcmc_results]

In [None]:
# Bayesian optimization with EnsembleBlockConfig(pretrained=True) and the default acquisition.
# NOTE(review): this run starts from 'R' * L while every other method starts from
# 'G' * L - confirm this is intentional before comparing curves.
start = 'R' * L
pre_results = [0.0] * N
# Local key: re-running this cell does not advance the global `key`.
run_key = key
for r in range(repeats):
    # Fresh optimizer per repeat so repeats are independent (no shared observations).
    boa = wazy.BOAlgorithm(model_config=wazy.EnsembleBlockConfig(pretrained=True),
                           alg_config=wazy.AlgConfig())
    run_key, tell_key = jax.random.split(run_key)
    boa.tell(tell_key, start, blosum(target_seq, start))
    best = float('-inf')
    for i in range(N):
        # Separate subkeys per call: a JAX key should never be consumed twice.
        run_key, ask_key, pred_key, tell_key = jax.random.split(run_key, 4)
        s, _ = boa.ask(ask_key)
        # Queried before tell(), so yhat/std reflect the model before it sees s.
        yhat, std, _ = boa.predict(pred_key, s)
        y = blosum(s, target_seq)
        best = max(y, best)
        boa.tell(tell_key, s, y)
        print(i, s, yhat, std, y, best)
        pre_results[i] += y
# Average each step's score over the independent repeats.
pre_results = [r / repeats for r in pre_results]

In [None]:
# Bayesian optimization with EnsembleBlockConfig(pretrained=False) and the default acquisition.
start = 'G' * L
ohc_results = [0.0] * N
# Local key: re-running this cell does not advance the global `key`.
run_key = key
for r in range(repeats):
    # Fresh optimizer per repeat so repeats are independent (no shared observations).
    boa = wazy.BOAlgorithm(model_config=wazy.EnsembleBlockConfig(pretrained=False),
                           alg_config=wazy.AlgConfig())
    run_key, tell_key = jax.random.split(run_key)
    boa.tell(tell_key, start, blosum(target_seq, start))
    best = float('-inf')
    for i in range(N):
        # Separate subkeys per call: a JAX key should never be consumed twice.
        run_key, ask_key, pred_key, tell_key = jax.random.split(run_key, 4)
        s, _ = boa.ask(ask_key)
        # Queried before tell(), so yhat/std reflect the model before it sees s.
        yhat, std, _ = boa.predict(pred_key, s)
        y = blosum(s, target_seq)
        best = max(y, best)
        boa.tell(tell_key, s, y)
        print(i, s, yhat, std, y, best)
        ohc_results[i] += y
# Average each step's score over the independent repeats.
ohc_results = [r / repeats for r in ohc_results]

In [None]:
# Same model as the OH run, but ask() uses aq_fxn='max' (greedy acquisition).
start = 'G' * L
ohc_g_results = [0.0] * N
# Local key: re-running this cell does not advance the global `key`.
run_key = key
for r in range(repeats):
    # Fresh optimizer per repeat so repeats are independent (no shared observations).
    boa = wazy.BOAlgorithm(model_config=wazy.EnsembleBlockConfig(pretrained=False),
                           alg_config=wazy.AlgConfig())
    run_key, tell_key = jax.random.split(run_key)
    boa.tell(tell_key, start, blosum(target_seq, start))
    best = float('-inf')
    for i in range(N):
        # Separate subkeys per call: a JAX key should never be consumed twice.
        run_key, ask_key, pred_key, tell_key = jax.random.split(run_key, 4)
        s, _ = boa.ask(ask_key, aq_fxn='max')
        # Queried before tell(), so yhat/std reflect the model before it sees s.
        yhat, std, _ = boa.predict(pred_key, s)
        y = blosum(s, target_seq)
        best = max(y, best)
        boa.tell(tell_key, s, y)
        print(i, s, yhat, std, y, best)
        ohc_g_results[i] += y
# Average each step's score over the independent repeats.
ohc_g_results = [r / repeats for r in ohc_g_results]

In [None]:
# Score of the sequence proposed at each step, averaged over repeats, for every method.
# 'Pretrained' is appended last so the existing methods keep their colors.
fig, ax = plt.subplots()
for label, results in [('OH Greedy', ohc_g_results),
                       ('OH', ohc_results),
                       ('MCMC', mcmc_results),
                       ('Random', rand_results),
                       ('Pretrained', pre_results)]:
    ax.plot(results, label=label)
ax.set(xlabel='Iteration', ylabel='Scaled BLOSUM62 score vs. target',
       title='Score of proposed sequence per iteration')
ax.legend()
plt.show()

In [None]:
def curbest(x):
    """Running maximum of ``x``: element i is ``max(x[:i + 1])``.

    Returns a list with the same length as ``x`` (an empty list for empty input).
    Computed in a single O(n) pass instead of taking the max of every prefix.
    """
    return list(np.maximum.accumulate(np.asarray(x)))
# Best-so-far curves for every method ('Pretrained' appended last to keep colors stable).
# NOTE(review): with repeats > 1 this is the running max of the *averaged* curve,
# which is not the average of per-repeat best-so-far curves; store per-repeat
# traces if the latter is what should be reported.
fig, ax = plt.subplots()
for label, results in [('OH Greedy', ohc_g_results),
                       ('OH', ohc_results),
                       ('MCMC', mcmc_results),
                       ('Random', rand_results),
                       ('Pretrained', pre_results)]:
    ax.plot(curbest(results), label=label)
ax.set(xlabel='Iteration', ylabel='Best scaled BLOSUM62 score so far',
       title='Best-so-far score per iteration')
ax.legend()
plt.show()