In [1]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import functools
import pickle
from operator import add
import matplotlib as mpl
from wazy.utils import *
from wazy.mlp import *
from jax_unirep import get_reps
import wazy
import os



In [2]:
AA_list = ['A','R','N','D','C','Q','E','G','H','I','L','K','M','F','P','S','T','W','Y','V','B','Z','X','*']
blosum62 = np.loadtxt("blosum62.txt", dtype='i', delimiter=' ')
min62 = jnp.min(blosum62)
blosum62 = blosum62 - min62
avg62 = jnp.sum(blosum62)/24/24
#(blosum62 - jnp.min(blosum62)) / (jnp.max(blosum62) - jnp.min(blosum62))
sum62 = 0.
for row in blosum62:
    for aa in row:
        sum62 += (aa-avg62)**2
std62 = jnp.sqrt(sum62 / 24/24)
def blosum(seq1, seq2):
    seqlist1 = list(seq1)
    seqlist2 = list(seq2)
    score = 0.
    for i in range(len(seqlist1)):
        idx1 = AA_list.index(seqlist1[i])
        idx2 = AA_list.index(seqlist2[i])
        score += blosum62[idx1][idx2]/std62
        #jax.nn.sigmoid(score/len(seqlist1))
    return score/len(seqlist1)

In [3]:
target_seq = 'TARGETPEPTIDE'
L = len(target_seq)
N = 50
repeats = 3
key = jax.random.PRNGKey(0)
print(blosum(target_seq, 'G' * L))

1.2428473


In [4]:
boa = wazy.MCMCAlgorithm(L)

start = 'G' * L
rand_results = [0 for i in range(N)]
for r in range(repeats):
    boa.tell(key, start, blosum(target_seq, start))
    best = 0

    for i in range(N):
        s = ''.join(np.random.choice(AA_list[:20], size=(L,)))
        y = blosum(s, target_seq)
        best = max(y, best)
        if best == y:
            print(i, best)
        rand_results[i] += y
rand_results = [r / repeats for r in rand_results]

0 1.7044765
17 1.7044765
32 1.7399863
39 1.7754962
0 1.6334566
3 1.7044765
33 2.2371254
0 1.2073373
1 1.7754962


In [5]:
boa = wazy.MCMCAlgorithm(L)

start = 'G' * L
mcmc_results = [0 for i in range(N)]
for r in range(repeats):
    boa.tell(key, start, blosum(target_seq, start))
    best = 0

    for i in range(N):
        key,_ = jax.random.split(key)
        s, a = boa.ask(key)
        y = blosum(s, target_seq)
        best = max(y, best)
        boa.tell(key, s, y)
        if best == y:
            print(i, best)
        mcmc_results[i] += y
mcmc_results = [r / repeats for r in mcmc_results]

0 1.2428473
1 1.3493769
2 1.4203968
3 1.6689664
5 1.6689664
15 1.9530457
16 2.0240655
17 2.0240655
18 2.0240655
19 2.0950854
0 1.2073373
1 1.2428473
2 1.2783571
3 1.4914168
4 1.5269266
6 1.5624366
7 1.5979464
8 1.6334565
9 1.7044762
10 1.7399863
12 1.8110061
16 1.8110061
0 1.2428473
1 1.2783571
2 1.3493769
4 1.3493772
5 1.3493772
8 1.455907
31 1.5624368
32 1.5979466


In [6]:
boa = wazy.BOAlgorithm(model_config=wazy.EnsembleBlockConfig(pretrained=False), 
                       alg_config=wazy.AlgConfig())

start = 'G' * L
ohc_results = [0 for i in range(N)]
for r in range(repeats):
    boa.tell(key, start, blosum(target_seq, start))
    best = 0
    for i in range(N):
        key,_ = jax.random.split(key)
        s, a = boa.ask(key)
        yhat, std, _ = boa.predict(key, s)
        y = blosum(s, target_seq)
        best = max(y, best)
        boa.tell(key, s, y)
        print(i, s, yhat, std, y, best)
        ohc_results[i] += y
ohc_results = [r / repeats for r in ohc_results]

  lax_internal._check_user_dtype_supported(dtype, "zeros")


0 GRGGVLSGGGCGG 0.38505948 0.72121453 1.384887 1.384887
1 HFGGHYQGGKDGI 0.24415572 0.71787035 1.3138671 1.384887
2 GQYGGWSGKFCGG 0.32090068 0.70313585 1.384887 1.384887
3 GQKGCGGGNVCGG 0.28943405 0.750027 1.4559067 1.4559067
4 GGKVGNGGWATGT 0.31988755 0.7379295 1.2783571 1.4559067
5 CMTGIGGVEMGTA 0.17337285 0.72886044 1.3493772 1.4559067
6 GQGTCGMGITTLG 0.33012044 0.7314469 1.0652976 1.4559067
7 LFEGREGGGITTW 0.20824865 0.7198456 1.4914169 1.4914169
8 KGSGCEGGHGPSW 0.32075843 0.7221307 1.3138671 1.4914169
9 CFRGRTGGFTAPC 0.27469903 0.7210311 1.9885558 1.9885558
10 GGSQLTPGFGPYG 0.29721975 0.71933436 1.420397 1.9885558
11 NQWGGPGGGMGGG 0.37274107 0.7315319 1.3138671 1.9885558
12 FQKPCQGGFHWPG 0.2928308 0.72147584 0.99427783 1.9885558
13 LNKGRATGNTSPN 0.25474367 0.7199106 1.9175358 1.9885558
14 GIVGGNGGFWMRG 0.2788469 0.739749 1.3138671 1.9885558
15 HFEGDQPGWCGGD 0.68920577 0.7962394 1.8465161 1.9885558
16 DGEGRTGGFGNDA 0.7041578 0.78151095 1.9175359 1.9885558
17 GPVECQGPFGATC 0.6544836 

42 HAGGNHGIGYGGE 1.4093939 0.02263701 1.6689664 2.1305957
43 LPGGKGNLNGGEE 1.3546559 0.016234279 1.6689664 2.1305957
44 TGGGNGSRPGSGT 1.4104865 0.028320432 2.0950856 2.1305957
45 GLPHNGGNSGGRQ 0.94661254 0.08183533 1.2073374 2.1305957
46 FAGGDTFLTWAKL 1.3865554 0.058394194 1.7754964 2.1305957
47 FGGGVCYAKWIKH 1.4119701 0.03277886 1.6689663 2.1305957
48 FFTGCGGGNTCKR 1.4384314 0.03122425 1.5624366 2.1305957
49 FAEGGWGPNTGDG 1.664513 0.02377081 1.9885558 2.1305957


In [7]:
boa = wazy.BOAlgorithm(model_config=wazy.EnsembleBlockConfig(pretrained=False), 
                       alg_config=wazy.AlgConfig())

start = 'G' * L
ohc_g_results = [0 for i in range(N)]
for r in range(repeats):
    boa.tell(key, start, blosum(target_seq, start))
    best = 0
    for i in range(N):
        key,_ = jax.random.split(key)
        s, a = boa.ask(key, aq_fxn='max')
        yhat, std, _ = boa.predict(key, s)
        y = blosum(s, target_seq)
        best = max(y, best)
        boa.tell(key, s, y)
        print(i, s, yhat, std, y, best)
        ohc_g_results[i] += y
ohc_g_results = [r / repeats for r in ohc_g_results]

0 EGGGGMGTGGMGF 0.30973092 0.71209973 1.4914168 1.4914168
1 LGFGGSGCKVAGD 0.24549417 0.7320166 1.6334565 1.6334565
2 LPGGGPGKMVGGG 0.3558564 0.74687177 1.4559067 1.6334565
3 LMEGGGGAAVCVG 0.28897166 0.7402934 1.4914168 1.6334565
4 SGTGCVGGMGFNC 0.23117833 0.74305147 1.5269266 1.6334565
5 AMGGGRGGGGIFL 0.3335875 0.7389836 1.4914168 1.6334565
6 CMTQGVHPKVPFS 0.16355287 0.7359596 1.2428473 1.6334565
7 LGFDEQDCKVAGS 0.2698772 0.75284016 1.5269266 1.6334565
8 LGFAGSDGKVPGM 0.26915753 0.736858 1.313867 1.6334565
9 MYLAGSGCMVRFG 0.28380886 0.7272063 1.0652976 1.6334565
10 LGLGGSGMKVQLP 0.35644692 0.7708671 1.455907 1.6334565
11 GGTEGSGCKGWGG 0.36159617 0.75988 1.1008075 1.6334565
12 VMTGCQSSMVLAQ 0.19133745 0.7052118 1.775496 1.775496
13 MGLGPGGGVVQFQ 0.33593994 0.7340721 1.4914168 1.775496
14 TNCDGETSVVDMG 0.19629242 0.7212751 1.3138671 1.775496
15 GMNGTQSSYVGGM 0.72949713 0.7120635 1.4914168 1.775496
16 GITGEPKTVGQGG 0.73928785 0.6947618 1.6334566 1.775496
17 GAQLNSKAMGNCW 0.34386018 0.7080

42 GPGFEEGFGVCGH 0.92095166 0.03317201 1.384887 2.1661055
43 WLWAYTPSDGDGS 1.2054037 0.04010129 1.7399864 2.1661055
44 GGLGGATAKGGGW 1.4397701 0.013702631 1.3848871 2.1661055
45 FGTGCGGVGGAGW 1.1351417 0.03287208 1.2783573 2.1661055
46 GGDPWMTSHGFGG 1.2526408 0.031085849 1.2073373 2.1661055
47 DFAGIFFWWGAPM 0.8837263 0.08956081 1.1363175 2.1661055
48 SGCEEAKLTGLGC 1.3194727 0.033792973 1.5269266 2.1661055
49 APRGGDHCGGLGD 1.2895784 0.018296838 1.8465159 2.1661055


In [None]:
boa = wazy.BOAlgorithm(model_config=wazy.EnsembleBlockConfig(pretrained=True), 
                       alg_config=wazy.AlgConfig())

start = 'G' * L
pre_results = [0 for i in range(N)]
for r in range(repeats):
    boa.tell(key, start, blosum(target_seq, start))
    best = 0
    for i in range(N):
        key,_ = jax.random.split(key)
        s, a = boa.ask(key)
        yhat, std, _ = boa.predict(key, s)
        y = blosum(s, target_seq)
        best = max(y, best)
        boa.tell(key, s, y)
        print(i, s, yhat, std, y, best)
        pre_results[i] += y
pre_results = [r / repeats for r in pre_results]

2022-09-14 12:20:41.441211: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:61] Constant folding an instruction is taking > 1s:

  %reduce-window.3 = f32[60,7600]{1,0} reduce-window(f32[1900,7600]{1,0} %constant.434, f32[] %constant.18), window={size=32x1 stride=32x1 pad=10_10x0_0}, to_apply=%region_32.1610

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2022-09-14 12:20:48.363985: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:61] Constant folding an instruction is taking > 2s:

  %reduce-window.5 = f32[60,7600]{1,0} reduce-window(f32[1900,7600]{1,0} %constant.639, f32[

0 GGGGGGGGAGGGA 1.1712836 0.054262638 1.3138671 1.3138671
1 RGGGGGGGGGGGA 1.1502995 0.059482217 1.3138671 1.3138671
2 CGPGGGQQQGGGA 0.9413517 0.118350685 1.5269268 1.5269268
3 GGGGGGCAAGGGA 1.2854325 0.0713588 1.3138671 1.5269268
4 RGGGGGGGAGGGA 1.3112258 0.07609081 1.3493772 1.5269268


In [None]:
plt.plot(ohc_g_results, label='OH Greedy')
plt.plot(ohc_results, label='OH')
plt.plot(mcmc_results, label='MCMC')
plt.plot(rand_results, label='Random')
plt.legend()
plt.show()

In [None]:
def curbest(x):
    return [np.max(x[:i]) for i in range(1, len(x) + 1)]
plt.plot(curbest(ohc_g_results), label='OH Greedy')
plt.plot(curbest(ohc_results), label='OH')
plt.plot(curbest(mcmc_results), label='MCMC')
plt.plot(curbest(rand_results), label='Random')
plt.legend()
plt.show()