In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from functools import partial

import jax.numpy as np
from jax import lax, nn, random, vmap
from jax._src.nn.functions import normalize
from jax.experimental import stax
from jax.nn import sigmoid
from jax.nn.initializers import glorot_normal, normal
from jax.random import normal as norm
from jax import lax

In [3]:
from patch_gnn.data import load_ghesquire
import pandas as pd
from pyprojroot import here
import pickle as pkl
from patch_gnn.splitting import train_test_split
from jax import random
from patch_gnn.seqops import one_hot
from patch_gnn.unirep import unirep_reps
from patch_gnn.graph import graph_tensors
from patch_gnn.models import MPNN, DeepMPNN
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import explained_variance_score as evs
import matplotlib.pyplot as plt 
from sklearn.metrics import mean_squared_error as mse
import pickle as pkl
from patch_gnn.graph import met_position
import seaborn as sns



##### Load dataset, split into train and test, then one hot encode it

In [4]:
# load graphs
graph_pickle_path = here() / "data/ghesquire_2011/graphs.pkl"
with open(graph_pickle_path, "rb") as f:
    graphs = pkl.load(f)

# load data
data = load_ghesquire()

# filter data based on graphs.keys() and other metrics
filtered = (
   data
   .query("`accession-sequence` in @graphs.keys()")
   .query("ox_fwd_logit < 2.0")
   .join_apply(met_position, "met_position")
)

# split data into train and testing
key = random.PRNGKey(490)
train_df, test_df = train_test_split(key, filtered) # 70% training, 30% testing
print(f"the shape of train_df and test_df are {train_df.shape}, {test_df.shape}")

# pad sequence to 50 length and one hot encode it
padding_length = 50
train_oh = one_hot(train_df, padding_length) 
test_oh = one_hot(test_df,padding_length)
print(f"the shape of train_oh and test_oh are {train_oh.shape}, {test_oh.shape}")

the shape of train_df and test_df are (258, 18), (111, 18)
the shape of train_oh and test_oh are (258, 1050), (111, 1050)


##### reshape data to fit LSTM and get target values

In [5]:
lstm_train_oh = train_oh.reshape(train_oh.shape[0], padding_length, 21)
lstm_test_oh = test_oh.reshape(test_oh.shape[0], padding_length, 21)
print(lstm_train_oh.shape, lstm_test_oh.shape)
train_target = train_df['ox_fwd_logit'].values
test_target = test_df['ox_fwd_logit'].values

(258, 50, 21) (111, 50, 21)


In [6]:
lstm_train_oh[0:1,:,:].shape, train_target[0:1,].shape

((1, 50, 21), (1,))

#### LSTM based on GRU's case study

In [7]:
def AAEmbedding(embedding_dims: int , E_init=glorot_normal(), **kwargs):
    """
    Initial n-dimensional embedding of each amino-acid
    """

    def init_fun(rng, input_shape):
        """
        Generates the inital AA embedding matrix.
        `input_shape`:
            one-hot encoded AA sequence -> (n_aa, n_unique_aa)
        `output_dims`:
            embedded sequence -> (n_aa, embedding_dims)
        `emb_matrix`:
            embedding matrix -> (n_unique_aa, embedding_dims)
        """
        emb_matrix = E_init(rng, (input_shape[-1], embedding_dims))
        output_dims = (-1, embedding_dims)

        return output_dims, emb_matrix

    def apply_fun(params, inputs, **kwargs):
        """
        Embed a single AA sequence
        """
        emb_matrix = params
        # (n_aa, n_unique_aa) * (n_unique_aa, embedding_dims) => (n_aa, embedding_dims) # noqa: E501
        return np.matmul(inputs, emb_matrix)

    return init_fun, apply_fun

In [8]:
def LSTM(out_dim, W_init=glorot_normal(), b_init=normal()):
    """
    one directional LSTM, see math here https://d2l.ai/chapter_recurrent-modern/lstm.html
    :params out_dim: number of output neurons associated with an input of a single time point
    
    """
    def init_fun(rng, input_shape):
        """
        initialize LSTM layer for stax
        :param rng: The PRNGKey (from JAX) for random number generation _reproducibility_.
        :params input_shape: (num_time_steps/n_letters, embeddings)
        """
        # initial hidden state and memory state
        hidden = b_init(rng, (1, out_dim)) # denote by H in the formula #b_init(rng, (input_shape[0], out_dim)), instead None can be any number
        memory = b_init(rng, (1, out_dim)) # denote by C in the formula #b_init(rng, (input_shape[0], out_dim))
        # input gate
        k1, k2, k3 = random.split(rng, num=3)
        input_W, input_U, input_b = (
            W_init(k1, (input_shape[-1], out_dim)),
            W_init(k2, (out_dim, out_dim)),
            b_init(k3, (out_dim,)),
        )
        # forget gate
        k1, k2, k3 = random.split(rng, num=3)
        forget_W, forget_U, forget_b = (
            W_init(k1, (input_shape[-1], out_dim)),
            W_init(k2, (out_dim, out_dim)),
            b_init(k3, (out_dim,)),
        )
        # output gate
        k1, k2, k3 = random.split(rng, num=3)
        output_W, output_U, output_b = (
            W_init(k1, (input_shape[-1], out_dim)),
            W_init(k2, (out_dim, out_dim)),
            b_init(k3, (out_dim,)),
        )
        # current memory 
        k1, k2, k3 = random.split(rng, num=3)
        candidate_m_W, candidate_m_U, candidate_m_b = (
            W_init(k1, (input_shape[-1], out_dim)),
            W_init(k2, (out_dim, out_dim)),
            b_init(k3, (out_dim,)),
        )
        
        # Input dim 0 - batch dimension, 1 - time dimension (before scan moveaxis)
        output_shape = (input_shape[0], out_dim )#,input_shape[0],
        return (output_shape,
               ((hidden, memory),
               (input_W, input_U, input_b),
               (forget_W, forget_U, forget_b),
               (output_W, output_U, output_b),
               (candidate_m_W, candidate_m_U, candidate_m_b)),) # this tuple is (output_shape, params)
    
    def apply_fun(params, inputs, **kwargs):
        """ Loop over the time steps of the input sequence """
        h_0, m_0 = params[0] # initial hidden and memory
        
        def apply_fun_scan(params, carry, inp):
            """ 
            Perform single step update of the network
            carry: a tuple with (hidden, memory)
            :param inp: of shape 
            """
            (hidden, memory) = carry
            (i,j), (input_W, input_U, input_b), (forget_W, forget_U, forget_b), (
                output_W, output_U, output_b),(
                candidate_m_W, candidate_m_U, candidate_m_b) = params
            # shape annotation: (1,embedding)*(embedding, outdim) +(1, outdim)*(outdim, outdim) => (1, outdim)
            input_gate = sigmoid(np.dot(inp, input_W) +
                                  np.dot(hidden, input_U) + input_b)
            # shape annotation: (1,embedding)*(embedding, outdim) +(1, outdim)*(outdim, outdim)=> (1, outdim)
            forget_gate = sigmoid(np.dot(inp, forget_W) +
                                 np.dot(hidden, forget_U) + forget_b)
            # shape annotation: (1,embedding)*(embedding, outdim) +(1, outdim)*(outdim, outdim)=> (1, outdim)
            output_gate = sigmoid(np.dot(inp, output_W) +
                                 np.dot(hidden, output_U) + output_b)
            # shape annotation: (1,embedding)*(embedding, outdim) +(1, outdim)*(outdim, outdim)=> (1, outdim)
            candidate_memory = np.tanh(np.dot(inp, candidate_m_W) +
                                 np.dot(hidden, candidate_m_U) + candidate_m_b)
            # shape annotation: (1,outdim)@(1,outdim) + 1,outdim)@(1,outdim)=> (1, outdim)
            current_memory = np.multiply(forget_gate, memory) + np.multiply(input_gate, candidate_memory)
            # shape annotation: (1,outdim)@(1, outdim)=> (1, outdim)
            current_hidden = np.multiply(output_gate, np.tanh(current_memory))
            hidden = current_hidden
            memory = current_memory

            return (hidden, memory), hidden

        f = partial(apply_fun_scan, params) # this is a function
        (h_final, m_final), output = lax.scan(f, init=(h_0, m_0), xs=inputs)
        
        
        return (h_final, m_final), output

    return init_fun, apply_fun





#### next step we want to generate some inputs for LSTM layer for testing

In [9]:
from jax import value_and_grad, jit
from jax.experimental import stax
from patch_gnn.training import mseloss
from jax.experimental.optimizers import adam

# Initialize the network and perform a forward pass
vanilla_lstm_init_fun, vanilla_lstm_apply_fun = stax.serial(
    AAEmbedding(64), 
    stax.Sigmoid, 
    LSTM(32),
    stax.Dense(64),
    stax.Dense(1))#Dense(lstm_train_oh.shape[0])
#_, params = dense_lstm_init_fun(key, (lstm_train_oh.shape[0], padding_length, 21))
_, params = vanilla_lstm_init_fun(key,  (padding_length, 21))

#def mse_loss(params, inputs, targets):
#    """ Calculate the Mean Squared Error Prediction Loss. """
#    preds = vanilla_lstm_apply_fun(params, inputs)
#    return np.mean((preds - targets)**2)
# the following update function is the same as step function in patch_gnn.training
def fit(
    #sequences: Iterable[str],
    num_epochs: int,
    model_func: Callable = vanilla_lstm_apply_fun,
    params: Any = None,
    batch_size: int = 25,
    step_size: float = 0.0001,
    #holdout_seqs: Optional[Iterable[str]] = None,
    #proj_name: str = "temp",
    #epochs_per_print: int = 1,
    backend: str = "cpu",
):
    @jit
    def step(i, x, y, state):
        """ Perform one timestep forward pass, calculate the MSE & perform a SGD step. """
        params = get_params(state)
        g = grad(partial(loss, model_func))(params, x, y)
        state = update(i, g, state)
        return state

    init, update, get_params = adam(step_size=step_size)
    get_params = jit(get_params)
    state = init(params)
    
    for epoch in range(num_epochs):
    for batch_idx in range(lstm_train_oh.shape[0]):
        x_in = lstm_train_oh[batch_idx, :, :]
        y = train_target[batch_idx:batch_idx+1, ]
    
    


In [10]:
from jax.experimental.optimizers import adam
# Defining an optimizer in Jax
step_size = 1e-3
opt_init, opt_update, get_params = adam(step_size)
opt_state = opt_init(params)

num_epochs = 10

In [13]:
# Loop over the training epochs
loss = []
for epoch in range(num_epochs):
    for batch_idx in range(lstm_train_oh.shape[0]):
        x_in = lstm_train_oh[batch_idx, :, :]
        y = train_target[batch_idx:batch_idx+1, ]
        print(type(x_in), type(y))
        loss, grads = value_and_grad(mse_loss)(params, x_in, y)
        print(loss, grads)
        #params, opt_state, loss = update(params, x_in, y, opt_state)
        #train_loss_log.append(loss)





<class 'numpy.ndarray'> <class 'numpy.ndarray'>


TypeError: dot requires ndarray or scalar arguments, got <class 'tuple'> at position 0.

##### below part is what we want to do ultimately

In [None]:
from jax.experimental import stax
from jax.experimental.optimizers import adam
class VanillaLSTM:
    """Vanilla shallow LSTM model in sklearn-compatible format.

    Forward direction LSTM + linear regression on top.
    """

    def __init__(
        self,
        one_hot_encoded_shape,
        #node_feature_shape,
        #num_adjacency,
        num_training_steps: int = 100,
        optimizer_step_size=1e-5,
    ):
        """
        :param one_hot_encoded_shape: (batch_size, timestep, embedding), timestep is the padding_length
        
        """
        model_init_fun, model_apply_fun = stax.serial(
            Dense(64), 
            stax.Sigmoid, 
            LSTM(32), 
            Dense(1)
        )
        self.model_apply_fun = model_apply_fun

        self.optimizer = adam(step_size=optimizer_step_size)

        output_shape, params = model_init_fun(
            PRNGKey(42), input_shape=(*one_hot_encoded_shape)
        )

        self.params = params
        self.num_training_steps = num_training_steps
        self.state_history = []
        self.loss_history = []

    def fit(self, X, y):
        """Fit model.

        :param X: tuple(timestep, embedding)
        :param y: vector(values to predict)
        """
        if len(y.shape) == 1:
            y = np.reshape(y, (-1, 1))
        init, update, get_params = self.optimizer
        training_step = partial(
            step,
            loss_fun=mseloss,
            apply_fun=self.model_apply_fun,
            update_fun=update,
            get_params=get_params,
            inputs=X,
            outputs=y,
        )
        training_step = jit(training_step)

        state = init(self.params)

        for i in tqdm(range(self.num_training_steps)):
            state, loss = training_step(i, state)
            self.state_history.append(state)
            self.loss_history.append(loss)

        self.params = get_params(state)
        return self

    def predict(self, X, checkpoint: int = None):
        """
        predict
        :param X: tuple(adjacency, node_features)
        """
        params = self.params
        if checkpoint:
            _, _, get_params = self.optimizer
            params = get_params(self.state_history[checkpoint])
        return vmap(partial(self.model_apply_fun, params))(X)

In [None]:
num_training_steps = 50
model_vanilla_lstm = VanillaLSTM(
    one_hot_encoded_shape = (padding_length, 21)
    num_training_steps=num_training_steps
)
model_vanilla_lstm.fit(lstm_train_oh, train_target)

## Case study: GRU
### Implementation of GRU from https://towardsdatascience.com/getting-started-with-jax-mlps-cnns-rnns-d0bc389bd683
##### formula of GRU can be found at https://www.google.com/search?q=GRU+formula&tbm=isch&source=iu&ictx=1&fir=VgQCUBNNXFcaTM%252CXHoefSnHEDF68M%252C_&vet=1&usg=AI4_-kSHwsjZClZeI1h2s23kwDj5ncW6Jg&sa=X&ved=2ahUKEwjcwZaW8vjvAhUHKFkFHc8NC8sQ9QF6BAgNEAE&biw=1680&bih=895#imgrc=VgQCUBNNXFcaTM
##### machanism of GRU can be found at https://d2l.ai/chapter_recurrent-modern/gru.html

In [None]:
from jax.nn import sigmoid
from jax.nn.initializers import glorot_normal, normal

from functools import partial
from jax import lax

def GRU(out_dim, W_init=glorot_normal(), b_init=normal()):
    def init_fun(rng, input_shape):
        """ Initialize the GRU layer for stax """
        # input shape is of (batch_size, num_time_steps, embeddings)
        hidden = b_init(rng, (input_shape[0], out_dim))

        k1, k2, k3 = random.split(rng, num=3)
        update_W, update_U, update_b = (
            W_init(k1, (input_shape[2], out_dim)),
            W_init(k2, (out_dim, out_dim)),
            b_init(k3, (out_dim,)),)

        k1, k2, k3 = random.split(rng, num=3)
        reset_W, reset_U, reset_b = (
            W_init(k1, (input_shape[2], out_dim)),
            W_init(k2, (out_dim, out_dim)),
            b_init(k3, (out_dim,)),)

        k1, k2, k3 = random.split(rng, num=3)
        out_W, out_U, out_b = (
            W_init(k1, (input_shape[2], out_dim)),
            W_init(k2, (out_dim, out_dim)),
            b_init(k3, (out_dim,)),)
        # Input dim 0 represents the batch dimension
        # Input dim 1 represents the time dimension (before scan moveaxis)
        output_shape = (input_shape[0], input_shape[1], out_dim)
        return (output_shape,
            (hidden,
             (update_W, update_U, update_b),
             (reset_W, reset_U, reset_b),
             (out_W, out_U, out_b),),)

    def apply_fun(params, inputs, **kwargs):
        """ Loop over the time steps of the input sequence """
        h = params[0]
        
        def apply_fun_scan(params, hidden, inp):
            """ Perform single step update of the network """
            _, (update_W, update_U, update_b), (reset_W, reset_U, reset_b), (
                out_W, out_U, out_b) = params

            update_gate = sigmoid(np.dot(inp, update_W) +
                                  np.dot(hidden, update_U) + update_b)
            reset_gate = sigmoid(np.dot(inp, reset_W) +
                                 np.dot(hidden, reset_U) + reset_b)
            output_gate = np.tanh(np.dot(inp, out_W) 
                                  + np.dot(np.multiply(reset_gate, hidden), out_U) 
                                  + out_b)
            output = np.multiply(update_gate, hidden) + np.multiply(1-update_gate, output_gate)
            hidden = output
            return hidden, hidden

        # Move the time dimension to position 0
        inputs = np.moveaxis(inputs, 1, 0)
        f = partial(apply_fun_scan, params)
        _, h_new = lax.scan(f, h, inputs)
        return h_new

    return init_fun, apply_fun

In [None]:
num_dims = 10              # Number of OU timesteps
batch_size = 64            # Batchsize
num_hidden_units = 12      # GRU cells in the RNN layer 

# Initialize the network and perform a forward pass
init_fun, gru_rnn = stax.serial(Dense(num_hidden_units), Relu,
                                GRU(num_hidden_units), Dense(1))
_, params = init_fun(key, (batch_size, num_dims, 1))

def mse_loss(params, inputs, targets):
    """ Calculate the Mean Squared Error Prediction Loss. """
    preds = gru_rnn(params, inputs)
    return np.mean((preds - targets)**2)

@jit
def update(params, x, y, opt_state):
    """ Perform a forward pass, calculate the MSE & perform a SGD step. """
    loss, grads = value_and_grad(mse_loss)(params, x, y)
    opt_state = opt_update(0, grads, opt_state)
    return get_params(opt_state), opt_state, loss

