### MLP in JAX

In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
from jax_unirep import get_reps
import optax



In [2]:
# multi-model mlp
# input
# model_batch X sequence_batch X 1900
# output
# model_batch X sequence_batch X 2 (mean, std)
def forward(x):
    mlp = hk.nets.MLP([1900, 256, 32, 2])
    return mlp(x)

forward = hk.transform(forward)

In [3]:
# make data
seqs = ['AAAAA', 'AAAAAA', 'AAAAAAA', 'AAAAAAAA']
labels = jnp.array([5.4, 10.2, 25.0, 27.3])
unirep_seq = get_reps(seqs)[0]

In [4]:
# make data duplicate model_batch times
#deep_ensemble_unirep_seq = unirep_seq[jnp.newaxis, ...]
ensemble_unirep_seq = jax.lax.broadcast(unirep_seq, (5,))
ensemble_labels = jax.lax.broadcast(labels, (5,))[...,jnp.newaxis]
#deep_ensemble_unirep_seq = unirep_seq[jnp.newaxis, ...]
#print(deep_ensemble_unirep_seq.shape)
#print(ensemble_labels.shape)

In [52]:
rng = jax.random.PRNGKey(37)
model_batch = 5
seq_batch = 4 # nothing related to the seqprop sequence batch, it's training batch
#x = jnp.ones([model_batch, seq_batch, 1900]) # input
#params = forward.init(rng, jnp.ones([1900]))
#outs = forward.apply(params, rng, ensemble_unirep_seq)
batch_keys = jax.random.split(rng, num=seq_batch*model_batch)
batch_keys = jnp.reshape(batch_keys, (model_batch, seq_batch, -1))
#batch_keys.reshape()

[[[2193506619 1836845627]
  [2640627218 2731996434]
  [1162524604 3933337620]
  [1698034610  156788376]]

 [[1517479688 3054226544]
  [  44168207  790790252]
  [2847106455 1905325981]
  [3732109485 1890434876]]

 [[3920083628  641085216]
  [ 317469417 2871877061]
  [1862523578 1593348394]
  [1011846268 3348669339]]

 [[1165058815 2518368341]
  [ 366627617 3721941943]
  [ 916183925 1491602017]
  [1961424384 1989074513]]

 [[2086938857  173443644]
  [1610144698 1642220118]
  [ 672792170  282660665]
  [ 408629254 3048786725]]]


In [6]:
def deep_ensemble_loss(params, ins, labels): # labels are in batches
    # in batches
    outs = forward.apply(params, rng, ins)
    means = outs[0]
    #print(means)
    stds = outs[1]
    n_log_likelihoods = 0.5*jnp.log(stds**2) + 0.5*(labels-means)**2/stds**2
    print(n_log_likelihoods[0])
    return n_log_likelihoods[0]

In [43]:
def train_mlp(key, seqs, labels):
    learning_rate = 0.01
    n_training_steps = 10
    opt_init, opt_update = optax.chain(
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        optax.scale(-learning_rate)) # minus sign -- minimizing the loss

  # Initialise the model's parameters and the optimiser's state.
  # The `state` of an optimiser contains all statistics used by the
  # stateful transformations in the `chain` (in this case just `scale_by_adam`).
    params = forward.init(rng, jax.random.normal(key, shape=(jnp.shape(seqs))))
    opt_state = opt_init(params)

    loss_trace = []
    for step in range(n_training_steps):
        loss, grad = jax.value_and_grad(deep_ensemble_loss)(params, seqs, labels)
        loss_trace.append(loss)
        print(f'Loss[{step}] = {loss}')
        updates, opt_state = opt_update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)
    
    return params, loss_trace

In [44]:
# batch twice with vmap
b_train_mlp = jax.vmap(train_mlp, (0, 0, 0),(0, 0))
bb_train_mlp = jax.vmap(b_train_mlp, (0, 0, 0),(0, 0))

In [49]:
# call training process
params, loss_trace = bb_train_mlp(batch_keys, ensemble_unirep_seq, ensemble_labels)

Traced<ShapedArray(float32[])>with<JVPTrace(level=4/0)>
  with primal = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
                  with val = Traced<ShapedArray(float32[4])>with<BatchTrace(level=1/0)>
                               with val = DeviceArray([[  55007.5 ,  283934.94, 2374346.8 , 1976548.5 ],
                                                       [  55007.5 ,  283934.94, 2374346.8 , 1976548.5 ],
                                                       [  55007.5 ,  283934.94, 2374346.8 , 1976548.5 ],
                                                       [  55007.5 ,  283934.94, 2374346.8 , 1976548.5 ],
                                                       [  55007.5 ,  283934.94, 2374346.8 , 1976548.5 ]],            dtype=float32)
                                    batch_dim = 0
                       batch_dim = 0
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=3/0)>
Loss[0] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  wit

       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=3/0)>
Loss[5] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[4])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[5.699456, 5.545487, 5.78509 , 5.673466],
                                       [5.699456, 5.545487, 5.78509 , 5.673466],
                                       [5.699456, 5.545487, 5.78509 , 5.673466],
                                       [5.699456, 5.545487, 5.78509 , 5.673466],
                                       [5.699456, 5.545487, 5.78509 , 5.673466]], dtype=float32)
                    batch_dim = 0
       batch_dim = 0
Traced<ShapedArray(float32[])>with<JVPTrace(level=4/0)>
  with primal = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
                  with val = Traced<ShapedArray(float32[4])>with<BatchTrace(level=1/0)>
                               with val = DeviceArray([[5.6542053, 5.4837255, 5.8483114, 5.50506

In [51]:
print(loss_trace)

[DeviceArray([[  55007.5 ,  283934.94, 2374346.8 , 1976548.5 ],
             [  55007.5 ,  283934.94, 2374346.8 , 1976548.5 ],
             [  55007.5 ,  283934.94, 2374346.8 , 1976548.5 ],
             [  55007.5 ,  283934.94, 2374346.8 , 1976548.5 ],
             [  55007.5 ,  283934.94, 2374346.8 , 1976548.5 ]],            dtype=float32), DeviceArray([[3.719528 , 3.5745597, 3.608096 , 3.642284 ],
             [3.719528 , 3.5745597, 3.608096 , 3.642284 ],
             [3.719528 , 3.5745597, 3.608096 , 3.642284 ],
             [3.719528 , 3.5745597, 3.608096 , 3.642284 ],
             [3.719528 , 3.5745597, 3.608096 , 3.642284 ]], dtype=float32), DeviceArray([[4.7581215, 4.659507 , 4.652138 , 4.6990595],
             [4.7581215, 4.659507 , 4.652138 , 4.6990595],
             [4.7581215, 4.659507 , 4.652138 , 4.6990595],
             [4.7581215, 4.659507 , 4.652138 , 4.6990595],
             [4.7581215, 4.659507 , 4.652138 , 4.6990595]], dtype=float32), DeviceArray([[5.308935 , 5.19972

In [None]:
def acquisition(X, Xsamples, model):
    # calculate the best surrogate score found so far
    yhat, _ = surrogate(model, X)
    best = max(yhat)
    # calculate mean and stdev via surrogate function
    mu, std = surrogate(model, Xsamples)
    mu = mu[:, 0]
    # calculate the probability of improvement
    probs = norm.cdf((mu - best) / (std+1E-9))
    return probs