### MLP in JAX

In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
from jax_unirep import get_reps
import optax



In [2]:
# Multi-model MLP (one network per ensemble member).
# Input shape:  model_batch x sequence_batch x 1900
# Output shape: model_batch x sequence_batch x 2 (mean, std)
def forward(x):
    """Apply the MLP to ``x`` and return its raw output.

    The network is a Haiku MLP whose successive layer output sizes are
    1900, 256, 32 and 2; the two final outputs are treated elsewhere in
    this file as (mean, std).
    """
    layer_sizes = [1900, 256, 32, 2]
    network = hk.nets.MLP(layer_sizes)
    return network(x)


# Rebind ``forward`` to the transformed object exposing ``init`` / ``apply``.
forward = hk.transform(forward)

In [31]:
# Toy dataset: four poly-alanine sequences (lengths 5 to 8) with one scalar
# label each.
seqs = ['A' * length for length in range(5, 9)]
labels = jnp.array([5.4, 10.2, 25.0, 27.3])
# Only the first item returned by get_reps is kept as the feature matrix.
unirep_seq = get_reps(seqs)[0]
print(unirep_seq.shape)

(4, 1900)


In [4]:
# Replicate features and labels once per ensemble member by prepending an
# axis of size 5; labels additionally get a trailing singleton axis.
ensemble_unirep_seq = jnp.broadcast_to(unirep_seq, (5, *unirep_seq.shape))
ensemble_labels = jnp.broadcast_to(labels, (5, *labels.shape))[..., jnp.newaxis]

In [25]:
# Ensemble size and training batch size.
model_batch = 5
seq_batch = 4  # training batch size; unrelated to the seqprop sequence batch

# Root PRNG key, then one sub-key per ensemble member laid out as
# (model_batch, key_width).
rng = jax.random.PRNGKey(37)
batch_keys = jax.random.split(rng, num=model_batch).reshape((model_batch, -1))

In [79]:
def deep_ensemble_loss(params, ins, labels): # labels are in batches
    """Gaussian negative log-likelihood (constant term dropped) for one input.

    Args:
        params: Haiku parameters passed to ``forward.apply``.
        ins: a single input for which ``forward.apply`` returns an array
            whose element 0 is the predicted mean and element 1 the
            predicted std. Batching is expected to be added from outside
            (e.g. with ``jax.vmap``).
        labels: array of targets; the loss of its first element is returned.

    Returns:
        ``0.5 * log(var) + 0.5 * (label - mean)**2 / var`` for the first
        label, where ``var = std**2`` plus a small floor.
    """
    # Floor on the variance: the raw std output is unconstrained, and a value
    # of exactly zero would give log(0) and a division by zero (inf/NaN loss
    # and gradients).
    min_variance = 1e-9

    # The module-level ``rng`` is passed only because ``apply`` takes a key.
    outs = forward.apply(params, rng, ins)
    means, stds = outs[0], outs[1]
    variances = stds**2 + min_variance
    n_log_likelihoods = (0.5 * jnp.log(variances)
                         + 0.5 * (labels - means)**2 / variances)
    return n_log_likelihoods[0]

In [85]:
def train_mlp(key, seqs, labels, learning_rate=0.01, n_training_steps=10):
    """Train one MLP on ``deep_ensemble_loss`` with an Adam-style optimiser.

    Args:
        key: PRNG key; split into one key for parameter initialisation and
            one for the random dummy input used to trace ``forward.init``.
        seqs: model input forwarded to the loss and to ``forward.apply``.
        labels: targets forwarded to ``deep_ensemble_loss``.
        learning_rate: step size applied to the Adam-scaled gradients.
        n_training_steps: number of gradient steps to run.

    Returns:
        ``(params, loss_trace, outs)``: the trained parameters, the list of
        per-step loss values (computed before each update), and the model
        output on ``seqs`` after training.
    """
    optimizer = optax.chain(
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        optax.scale(-learning_rate))  # minus sign -- minimizing the loss

    # Initialise the model's parameters and the optimiser's state.
    # The `state` of an optimiser contains all statistics used by the
    # stateful transformations in the `chain` (in this case just
    # `scale_by_adam`).
    key1, key2 = jax.random.split(key, num=2)
    params = forward.init(key1, jax.random.normal(key2, shape=(1900,)))
    opt_state = optimizer.init(params)

    # Build the value-and-gradient function once instead of on every step.
    loss_and_grad = jax.value_and_grad(deep_ensemble_loss)

    loss_trace = []
    for step in range(n_training_steps):
        loss, grad = loss_and_grad(params, seqs, labels)
        loss_trace.append(loss)
        # NOTE: when this function runs under jax.vmap, `loss` is a tracer,
        # so this prints the tracer's repr rather than a plain number.
        print(f'Loss[{step}] = {loss}')
        updates, opt_state = optimizer.update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)
    outs = forward.apply(params, key1, seqs)
    return params, loss_trace, outs

In [88]:
# Vectorise training in two nested vmap stages:
#   inner: one shared key (in_axes None), mapped over axis 0 of inputs/labels;
#   outer: mapped over axis 0 of keys, inputs and labels.
b_train_mlp = jax.vmap(train_mlp, in_axes=(None, 0, 0), out_axes=(0, 0, 0))
bb_train_mlp = jax.vmap(b_train_mlp, in_axes=(0, 0, 0), out_axes=(0, 0, 0))

In [89]:
# call training process
# Runs the doubly vmapped trainer: the outer vmap walks axis 0 of all three
# arguments, the inner vmap walks the next axis of the inputs and labels
# while sharing the key. Every returned leaf therefore has two leading
# batch axes.
# Variant using only the inner vmap with the single root key:
#params, loss_trace, outs = b_train_mlp(rng, unirep_seq, labels)
params, loss_trace, outs= bb_train_mlp(batch_keys, ensemble_unirep_seq, ensemble_labels)

Loss[0] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[4])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[1.03615609e+05, 1.90608422e+05, 7.56650875e+05,
                                        7.62263375e+05],
                                       [1.41960328e+05, 4.98721656e+05, 3.56300525e+06,
                                        5.12011350e+06],
                                       [1.18992419e+03, 4.79595215e+03, 3.27389434e+04,
                                        4.43107656e+04],
                                       [1.85071172e+05, 2.67819281e+05, 1.04371462e+06,
                                        1.15667650e+06],
                                       [2.03935605e+04, 1.19289914e+05, 1.01147244e+06,
                                        1.39738962e+06]], dtype=float32)
                    batch_dim = 0
       batch_dim = 0
Loss[1] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/

In [90]:
# Inspect the training results; only the parameter tree is printed here.
#print(loss_trace)
#print(outs)
print(params)

FlatMap({
  'mlp/~/linear_0': FlatMap({
                      'b': DeviceArray([[[ 0.        ,  0.        ,  0.        , ..., -0.04187571,
                                          -0.04187571,  0.        ],
                                         [ 0.        ,  0.        ,  0.        , ..., -0.04187571,
                                          -0.04187571,  0.        ],
                                         [ 0.        ,  0.        ,  0.        , ..., -0.04187571,
                                          -0.04187571,  0.        ],
                                         [ 0.        ,  0.        ,  0.        , ..., -0.04187571,
                                           0.04187571,  0.        ]],
                           
                                        [[ 0.        , -0.04187571,  0.        , ...,  0.        ,
                                           0.        ,  0.        ],
                                         [ 0.        , -0.04187571,  0.        , ...,  0.  

In [None]:
def surrogate(params, key, X):
    """Predict mean and std for ``X`` with the trained network.

    Args:
        params: Haiku parameters passed to ``forward.apply``.
        key: PRNG key forwarded to ``forward.apply``.
        X: model input; the network's outputs are read from the last axis
            of the result, so both a single input and a batch are accepted.

    Returns:
        ``(predict, std)``: the first and second network outputs.
    """
    # Evaluate the X argument itself, not the module-level ``seqs`` list of
    # raw sequence strings.
    outs = forward.apply(params, key, X)
    # Index the last axis rather than unpacking along axis 0, which only
    # works for an unbatched, length-2 output.
    return outs[..., 0], outs[..., 1]


def acquisition(X, Xsamples, model, key=None):
    """Probability-of-improvement score for each candidate in ``Xsamples``.

    Args:
        X: inputs evaluated so far; the best surrogate mean over them is the
            incumbent to beat.
        Xsamples: candidate inputs to score.
        model: Haiku parameters handed to ``surrogate``.
        key: optional PRNG key forwarded to ``surrogate``. Defaults to
            ``None`` (assumes Haiku's ``apply`` accepts ``None`` for a
            network that draws no random numbers -- TODO confirm).

    Returns:
        Array of ``P(candidate mean > incumbent)`` under a normal
        distribution with the surrogate's mean and std.
    """
    # Local import: the file's import cell does not bring ``norm`` into scope.
    from jax.scipy.stats import norm

    # calculate the best surrogate score found so far
    yhat, _ = surrogate(model, key, X)
    best = jnp.max(yhat)
    # calculate mean and stdev via surrogate function
    mu, std = surrogate(model, key, Xsamples)
    # The training loss only depends on std**2, so the sign of the raw std
    # output is arbitrary; use its magnitude.
    std = jnp.abs(std)
    # calculate the probability of improvement
    probs = norm.cdf((mu - best) / (std + 1E-9))
    return probs