In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import jax
from jax import vmap
import jax.numpy as jnp
from jax.experimental import optimizers
from functools import partial
import jax_unirep
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMAvgHidden
from jax_unirep.utils import load_params, load_embedding, seq_to_oh
from jax_unirep.utils import *
from jax_unirep import get_reps



In [2]:
ALPHABET_Unirep = ['-','M','R','H','K','D','E','S','T','N','Q','C','U','G','P','A','V','I','F','Y','W','L','O','X','Z','B','J','start','stop']
ALPHABET = ['A','R','N','D','C','Q','E','G','H','I', 'L','K','M','F','P','S','T','W','Y','V']
def vectorize(pep):
    '''Takes a string of amino acids and encodes it to an L x 20 one-hot vector,
    where L is the length of the peptide.'''
    vec = jnp.zeros((len(pep), 20))
    for i, letter in enumerate(pep):
        vec = jax.ops.index_update(vec, jax.ops.index[i, ALPHABET.index(letter)], 1.)
    return vec

def vec_to_seq(pep_vector):  # From Rainier's code
    seq = ''
    # expect a 2D numpy array (pep_length x 20), give the string it represents
    for letter in pep_vector[:int(jnp.sum(pep_vector))]:
        idx = jnp.argmax(letter)
        if letter[idx] == 0:
            break
        seq += ALPHABET[idx]
    return seq

### Unirep related functions

In [3]:
def differentiable_jax_unirep(ohc_seq):
    emb_params = load_embedding()
    seq_embeddings = []
    for oh_vec in ohc_seq:
        seq_embedding = jnp.squeeze(jnp.stack([jnp.matmul(oh_vec, emb_params)], axis=0))
        seq_embeddings.append(seq_embedding)
    seq_embeddings = jnp.stack(seq_embeddings, axis=0)
    _, mLSTM_apply_fun = mLSTM(1900)
    weight_params = load_params()[1]
    def apply_fun_vmapped(x):
        return mLSTM_apply_fun(params=weight_params, inputs=x)
    h_final, c_final, outputs = vmap(apply_fun_vmapped)(seq_embeddings)
    #h_final, _, outputs = jax.vmap(partial(mLSTM_apply_fun, weight_params))(seq_embedding)
    h_avg = jnp.mean(outputs, axis=1)
    return h_avg

def index_trans(oh, alphabet, alphabet_unirep):
    matrix = jnp.zeros((len(alphabet), 26))
    for idx, aa in enumerate(alphabet):
        matrix = jax.ops.index_update(matrix, tuple([idx, alphabet_unirep.index(aa)]), 1.)
    start_char = jnp.zeros((1, 26))
    start_char = jax.ops.index_update(start_char, (0, 24), 1.)
    oh_unirep = jnp.einsum('ij,jk->ik', oh, matrix)
    oh_unirep = jnp.vstack((start_char, oh_unirep))
    return oh_unirep

### Sampling Layer

In [4]:
@jax.partial(jax.custom_jvp, nondiff_argnums=(0,))
def disc_ss(key, logits):
    key, sub_key = jax.random.split(key, num=2)
    sampled_onehot = jax.nn.one_hot(jax.random.categorical(key, logits), logits.shape[-1])
    return sampled_onehot

# customized gradient for back propagation
@disc_ss.defjvp
def disc_ss_jvp(key, primals, tangents):
    key, subkey = jax.random.split(key, num=2)
    logits = primals[0]
    logits_dot = tangents[0]
    primal_out = disc_ss(key, logits)
    _, tangent_out = jax.jvp(jax.nn.softmax, primals, tangents)
    return primal_out, tangent_out

### Norm layer

In [5]:
def norm_layer(logits, r, b):
    epsilon = 1e-5
    M, N = jnp.shape(logits)
    miu = jnp.sum(logits) / (M*N)
    std = jnp.sqrt(jnp.sum((logits - miu)**2) / (M*N))
    norm_logits = (logits - miu) / (std**2 + epsilon)
    scaled_logits = norm_logits * r + b
    return scaled_logits

### Batch Version Forward Seqprop

In [6]:
def forward_seqprop_batch(key, logits_batch, r_batch, b_batch):
    batch_size = len(logits_batch)
    sampled_vec_batch = []
    norm_logits_batch = []
    for i in range(batch_size):
        logits = logits_batch[i]
        r = r_batch[i]
        b = b_batch[i]
        norm_logits = norm_layer(logits, r, b) # same dimension as logits
        sampled_vec = disc_ss(key, norm_logits)
        sampled_vec_batch.append(sampled_vec)
        norm_logits_batch.append(norm_logits)
        
    return sampled_vec_batch, norm_logits_batch

### Batch Version Loss function

In [7]:
def loss_func_batch(target_rep, sampled_vec_batch):
    batch_size = len(sampled_vec_batch)
    sampled_vec_unirep_batch = []
    losses = []
    for i in range(batch_size):
        sampled_vec_unirep = index_trans(sampled_vec_batch[i], ALPHABET, ALPHABET_Unirep)
        sampled_vec_unirep_batch.append(sampled_vec_unirep)
    h_avg= differentiable_jax_unirep(sampled_vec_unirep_batch)
    for i in range(h_avg.shape[0]):
        losses.append(1-jnp.sum(jnp.vdot(h_avg[i], target_rep))/jnp.sqrt(jnp.sum(h_avg[i]**2)*jnp.sum(target_rep**2)))
    # loss = jnp.mean(((target_rep - h_avg)/target_rep)**2)   # mean square error
    #loss = 1-jnp.sum(jnp.vdot(h_avg, target_rep))/jnp.sqrt(jnp.sum(h_avg**2)*jnp.sum(target_rep**2))
    return losses

In [8]:
def g_loss_func(key, logits, r, b, target_rep):
    sampled_vec, _ = forward_seqprop_batch(key, logits, r, b)
    return loss_func_batch(target_rep, sampled_vec)

In [82]:
target_char = ['G','I','G','A','V','L','K','V','L','T','T','G','L','P','A','L','I','S','W','I','K','R','K','R','Q','Q']
oh_vec = vectorize(target_char)
target_seq = ['GIGAVLKVLTTGLPALISWIKRKRQQ']
target_rep = get_reps(target_seq)[0]
key = jax.random.PRNGKey(0)

# batch the logits 32
batch_num = 32
logits_batch = []
r_batch = []
b_batch = []
# make batched logits
for _ in range(batch_num):
    key, logits_key, r_key, b_key = jax.random.split(key, num=4)
    logits_batch.append(jax.random.normal(logits_key, shape=jnp.shape(oh_vec)))
    r_batch.append(jax.random.normal(r_key))
    b_batch.append(jax.random.normal(b_key))


In [78]:
sampled_vec_batch, norm_logits_batch = forward_seqprop_batch(key, logits_batch, r_batch, b_batch)
loss = loss_func_batch(target_rep, sampled_vec_batch)

In [91]:
print(loss)

[DeviceArray(0.19891185, dtype=float32), DeviceArray(0.24761719, dtype=float32), DeviceArray(0.247576, dtype=float32), DeviceArray(0.3009013, dtype=float32), DeviceArray(0.2752024, dtype=float32), DeviceArray(0.30030912, dtype=float32), DeviceArray(0.28053582, dtype=float32), DeviceArray(0.24470818, dtype=float32), DeviceArray(0.13517243, dtype=float32), DeviceArray(0.2691852, dtype=float32), DeviceArray(0.28933328, dtype=float32), DeviceArray(0.23477793, dtype=float32), DeviceArray(0.27657723, dtype=float32), DeviceArray(0.2933095, dtype=float32), DeviceArray(0.2606218, dtype=float32), DeviceArray(0.28960413, dtype=float32), DeviceArray(0.3630033, dtype=float32), DeviceArray(0.29853022, dtype=float32), DeviceArray(0.22240508, dtype=float32), DeviceArray(0.38732547, dtype=float32), DeviceArray(0.20633137, dtype=float32), DeviceArray(0.31671095, dtype=float32), DeviceArray(0.23193097, dtype=float32), DeviceArray(0.2935465, dtype=float32), DeviceArray(0.27611536, dtype=float32), DeviceAr

In [92]:
def train_seqprop_batch(key, target_rep, init_logits_batch, init_r_batch, init_b_batch, iter_num=600):
    opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2, b1=0.9, b2=0.8)
    #opt_init, opt_update, get_params = optimizers.adagrad(step_size=1e-2)
    opt_state = opt_init((init_logits_batch, init_r_batch, init_b_batch)) # initial state
    logits_trace = []
    loss_trace = []

    @jax.jit
    def step(key, i, opt_state):
        key, subkey = jax.random.split(key, num=2)
        p = get_params(opt_state)
        logits, r, b = p
        sampled_vec, norm_logits = forward_seqprop_batch(key, logits, r, b)
        loss = loss_func_batch(target_rep, sampled_vec)
        #g = jax.grad(g_loss_func, (1,2,3))(key, logits, r, b, target_rep)
        g = jax.jacfwd(g_loss_func, (1,2,3))(key, logits, r, b, target_rep)
        print(g)
        return opt_update(i, g, opt_state), loss

    for step_idx in range(iter_num):
        print(step_idx)
        opt_state, loss = step(key, step_idx, opt_state)
        print(loss)
        loss_trace.append(loss)
        mid_logits, mid_r, mid_b = get_params(opt_state)
        logits_trace.append(mid_logits)
    final_logits, final_r, final_b = get_params(opt_state)
    sampled_vec, _ = forward_seqprop(key, final_logits, final_r, final_b)
    return sampled_vec, final_logits, logits_trace, loss_trace

In [None]:
sampled_vec, final_logits, logits_trace, loss_trace = train_seqprop_batch(key, target_rep, logits_batch, r_batch, b_batch, iter_num = 600)

0
