In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import functools
import pickle
from operator import add
import matplotlib as mpl
from wazy.utils import *
from wazy.mlp import *
from jax_unirep import get_reps
import wazy
import os
import haiku as hk

In [None]:
# One-letter residue symbols. A symbol's position in this list is the
# row/column index used to look it up in the substitution matrix and the
# column index of the per-residue count table.
AA_list = list("ARNDCQEGHILKMFPSTWYVBZX*")
# Substitution matrix indexed by positions in AA_list.
# NOTE(review): the variables are named "92" but the file loaded is
# blosum62.txt -- confirm which matrix is intended.
blosum92 = np.loadtxt("./blosum62.txt", dtype="i", delimiter=" ")

# Mean and population standard deviation over every matrix entry; `blosum`
# uses them to z-score individual substitution scores. These are vectorized
# reductions over the actual number of entries, replacing an element-by-element
# Python loop with a hard-coded 24 x 24 entry count.
avg92 = jnp.sum(blosum92) / blosum92.size
# Sum of squared deviations from the mean (kept as a module-level name).
sum92 = jnp.sum(jnp.square(jnp.subtract(blosum92, avg92)))
std92 = jnp.sqrt(sum92 / blosum92.size)


def blosum(seq1, seq2):
    """Return the mean z-scored substitution score between two sequences.

    The sequences are compared position by position over the length of
    `seq1`. Each pair of residues is looked up in `blosum92` through its
    index in `AA_list`, shifted by `avg92` and scaled by `std92`; the
    per-position values are summed and divided by `len(seq1)`.

    Residues of `seq2` beyond `len(seq1)` are ignored. A `seq2` shorter than
    `seq1` raises IndexError, a symbol missing from `AA_list` raises
    ValueError, and an empty `seq1` raises ZeroDivisionError.
    """
    reference = list(seq1)
    candidate = list(seq2)
    n_positions = len(reference)

    def _position_score(pos):
        # Look up the reference residue first, then the candidate residue.
        row = AA_list.index(reference[pos])
        col = AA_list.index(candidate[pos])
        return (blosum92[row][col] - avg92) / std92

    total = sum((_position_score(pos) for pos in range(n_positions)), 0.0)
    return total / n_positions

In [None]:
# Reference sequence that get_blosum_labels scores every candidate against.
target_seq = "TARGETPEPTIDE"
# JAX PRNG key created from seed 0.
key = jax.random.PRNGKey(0)

In [None]:
# Read every line of the sequence file; only the first line is used below.
with open("../10kseqs.txt") as f:
    readfile = f.readlines()

# Split the first line on single spaces and drop the final piece.
first_line = readfile[0]
random_seqs = first_line.split(" ")[:-1]


def get_blosum_labels(seqs):
    """Score every sequence in `seqs` against `target_seq` with `blosum`.

    Returns a NumPy array holding one score per input sequence, in input
    order.
    """
    return np.array([blosum(target_seq, candidate) for candidate in seqs])


def get_count_labels(seqs):
    """Return column 0 of the table produced by `get_aanum(seqs)`.

    `get_aanum` orders its columns like `AA_list`, so each label is the
    number of times the first `AA_list` symbol occurs in a sequence.
    """
    counts = get_aanum(seqs)
    first_column = counts[:, 0]
    return first_column


def get_aanum(seqs):
    """Count how often each `AA_list` symbol occurs in each sequence.

    Returns a JAX array with one row per sequence and one column per entry
    of `AA_list`; every count is stored as a float.
    """

    def _count_row(seq):
        residues = list(seq)
        return [float(residues.count(symbol)) for symbol in AA_list]

    return jnp.array([_count_row(seq) for seq in seqs])


def get_flat_ohc(seqs):
    """Encode each sequence with `encode_seq` and flatten every encoding.

    Each sequence is turned into a list of characters, passed to
    `encode_seq`, and the result's `.flatten()` becomes one row of the
    returned JAX array.
    NOTE(review): `encode_seq` is presumably a one-hot encoder (per the
    "ohc" name) -- confirm against its definition.
    """
    flat_rows = []
    for seq in seqs:
        encoded = encode_seq(list(seq))
        flat_rows.append(encoded.flatten())
    return jnp.array(flat_rows)


def get_ohc(seqs):
    """Encode each sequence with `encode_seq` and stack the encodings.

    Each sequence is turned into a list of characters before being passed
    to `encode_seq`; the per-sequence results are combined into one JAX
    array whose leading axis runs over the sequences.
    NOTE(review): `encode_seq` is presumably a one-hot encoder (per the
    "ohc" name) -- confirm against its definition.
    """
    encoded_seqs = []
    for seq in seqs:
        encoded_seqs.append(encode_seq(list(seq)))
    return jnp.array(encoded_seqs)


# Number of representations handed to the model per call in get_results and
# get_single_results.
batch_size = 8


def get_results(key, params, rep_list):
    """Run `forward_t` over `rep_list` in consecutive batches of `batch_size`.

    For every full batch, element 0 of the model output is collected as the
    mean and element 1 as the standard deviation.

    Returns a pair of NumPy arrays `(means, stds)` with one entry per batch.

    NOTE(review): only `len(rep_list) // batch_size` full batches are
    evaluated, so any trailing items that do not fill a batch are skipped.
    """
    n_full_batches = len(rep_list) // batch_size
    means, stds = [], []
    for batch_idx in range(n_full_batches):
        start = batch_idx * batch_size
        stop = (batch_idx + 1) * batch_size
        prediction = forward_t.apply(params, key, rep_list[start:stop])
        means.append(prediction[0])
        stds.append(prediction[1])
    return np.array(means), np.array(stds)


def get_single_results(key, params, rep_list):
    """Run `naive_forward_t` over `rep_list` in batches of `batch_size`.

    Returns a NumPy array built from the model output of every full batch,
    in batch order.

    NOTE(review): only `len(rep_list) // batch_size` full batches are
    evaluated, so any trailing items that do not fill a batch are skipped.
    """
    n_full_batches = len(rep_list) // batch_size
    yhats = [
        naive_forward_t.apply(
            params, key, rep_list[b * batch_size : (b + 1) * batch_size]
        )
        for b in range(n_full_batches)
    ]
    return np.array(yhats)

In [None]:
# Build validation / test / train subsets from random_seqs.
#
# The indices are drawn once, without replacement, from a seeded NumPy
# generator. This makes the split reproducible on a fresh kernel, keeps the
# three subsets from sharing any position of random_seqs, and avoids relying
# on a `random` module that this notebook never imports itself.
# Requires len(random_seqs) >= N_VALIDATION + N_TEST + N_TRAIN.
# NOTE(review): if random_seqs itself contains repeated sequences, the same
# sequence text can still appear in more than one subset.
SPLIT_SEED = 0
N_VALIDATION, N_TEST, N_TRAIN = 50, 500, 100

split_rng = np.random.default_rng(SPLIT_SEED)
split_idx = split_rng.choice(
    len(random_seqs), size=N_VALIDATION + N_TEST + N_TRAIN, replace=False
)
validation_idx = split_idx[:N_VALIDATION]
test_idx = split_idx[N_VALIDATION : N_VALIDATION + N_TEST]
train_idx = split_idx[N_VALIDATION + N_TEST :]

validation_seqs = [random_seqs[i] for i in validation_idx]
validation_ohc = get_ohc(validation_seqs)
validation_labels = get_blosum_labels(validation_seqs)
test_seqs = [random_seqs[i] for i in test_idx]
test_ohc = get_ohc(test_seqs)
test_labels = get_blosum_labels(test_seqs)
train_seqs = [random_seqs[i] for i in train_idx]
train_ohc = get_ohc(train_seqs)
train_labels = get_blosum_labels(train_seqs)

In [None]:
# LSTM
class LSTM(hk.Module):
    """Haiku module mapping a batch of sequences to `output_size` values each.

    The input goes through Linear(32) + ReLU, is unrolled through an
    `hk.LSTM(16)` core, and the last entry along axis -2 of the unrolled
    outputs is passed through Linear(32) -> ReLU -> Linear(16) -> ReLU ->
    Linear(output_size).
    """

    def __init__(self, output_size, name=None):
        """Store the width of the final linear layer.

        Args:
            output_size: number of units of the last `hk.Linear`.
            name: optional module name forwarded to `hk.Module`.
        """
        super().__init__(name=name)
        self.output_size = output_size

    def __call__(self, x):  # x: batch size x sequence length x embedding dim
        """Apply the network to `x` and return the final linear layer's output."""
        # Local value read from the input; it shadows the notebook-level
        # `batch_size` inside this method.
        batch_size = x.shape[0]
        core = hk.LSTM(16)
        # Input projection followed by ReLU.
        x = hk.BatchApply(hk.Linear(32), num_dims=1)(x)
        x = jax.nn.relu(x)
        # Unroll the recurrent core from its initial state; time_major=False
        # because the leading axis of `x` is the batch (see signature comment).
        # The returned `state` is not used afterwards.
        outs, state = hk.dynamic_unroll(
            core, x, core.initial_state(batch_size), reverse=False, time_major=False
        )
        # Keep only the last entry along the second-to-last axis.
        # NOTE(review): presumably the final time step -- confirm.
        outs = jnp.take(outs, -1, axis=-2)
        # Dense head: 32 -> ReLU -> 16 -> ReLU -> output_size.
        outs = hk.BatchApply(hk.Linear(32), num_dims=1)(outs)
        outs = jax.nn.relu(outs)
        outs = hk.BatchApply(hk.Linear(16), num_dims=1)(outs)
        outs = jax.nn.relu(outs)
        return hk.BatchApply(hk.Linear(self.output_size), num_dims=1)(outs)

In [None]:
def forward(x):
    """Construct an `LSTM` with a single output unit and apply it to `x`."""
    return LSTM(1)(x)


# Haiku-transformed version of `forward`.
lstm = hk.transform(forward)

In [None]:
def l(params, key, seqs, labels):
    """Mean squared error of the transformed LSTM on one batch.

    Args:
        params: parameters accepted by `lstm.apply`.
        key: PRNG key forwarded to `lstm.apply`.
        seqs: batch of encoded sequences fed to the network.
        labels: target values; must hold exactly one value per prediction.

    Returns:
        Scalar mean of the squared differences between predictions and labels.
    """
    # `forward` is the un-transformed function and only accepts `x`, so calling
    # it with (params, key, seqs) raises TypeError. Parameters and the PRNG key
    # must go through the transformed object's `apply`.
    yhat = lstm.apply(params, key, seqs)
    # `forward` builds LSTM(1), whose output has a trailing axis of size 1.
    # Align predictions with `labels` so the subtraction is element-wise
    # instead of silently broadcasting to a (batch, batch) matrix.
    yhat = jnp.reshape(yhat, jnp.shape(labels))
    return jnp.mean((yhat - labels) ** 2)

In [None]:
def train(params, seqs, labels, forward):
    # optimizer
    opt_init, opt_update = optax.chain(
        optax.scale_by_adam(
            b1=0.8,
            b2=0.9,
            eps=1e-3,
        ),
        optax.scale(-1e-3),  # minus sign -- minimizing the loss
    )
    if params == None:
        batch_seqs = jnp.ones((8, 13, 20))
        params = forward_t.init(key, batch_seqs)
    opt_state = opt_init(params)

    @jax.jit
    def train_step(opt_state, params, key, seq, label):
        loss, grad = jax.value_and_grad(l, 0)(params, key, seq, label)
        updates, opt_state = opt_update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)
        return opt_state, params, loss