### MLP in JAX

In [61]:
import haiku as hk
import jax
import jax.numpy as jnp
from jax_unirep import get_reps
import optax

In [62]:
# Single-network MLP. Ensemble (model) and sequence batching are added later by
# vmapping train_mlp, so after batching the outputs are
# (model_batch, sequence_batch, 2); output 0 is read as the mean and output 1
# as the variance term by the loss.
def forward(x):
    """Apply an MLP with layer output sizes [1900, 256, 32, 2] to ``x``.

    NOTE(review): ``hk.nets.MLP`` takes the *output* size of every layer, so
    the leading 1900 adds a 1900->1900 hidden layer rather than declaring the
    input width; confirm this extra layer is intended.
    """
    layer_sizes = [1900, 256, 32, 2]
    network = hk.nets.MLP(layer_sizes)
    return network(x)


forward = hk.transform(forward)

In [63]:
# Toy training set: poly-alanine peptides of length 5..8 with hand-written targets.
seqs = ['A' * length for length in range(5, 9)]
labels = jnp.array([5.4, 10.2, 25.0, 27.3])
# Element 0 of get_reps' result is used as the per-sequence embedding
# (printed shape is (4, 1900) in the recorded output).
unirep_seq = get_reps(seqs)[0]
print(unirep_seq.shape)

(4, 1900)


In [64]:
# Replicate inputs and labels once per ensemble member (5 copies, matching
# model_batch set in the next cell). Inputs become (5, 4, 1900); labels become
# (5, 4) and then gain a trailing singleton axis -> (5, 4, 1).
ensemble_unirep_seq = jax.lax.broadcast(unirep_seq, (5,))
ensemble_labels = jnp.expand_dims(jax.lax.broadcast(labels, (5,)), -1)

In [65]:
model_batch = 5  # number of ensemble members
seq_batch = 4  # training batch size; unrelated to the seqprop sequence batch

# One PRNG key per ensemble member, laid out as (model_batch, -1).
rng = jax.random.PRNGKey(37)
batch_keys = jnp.reshape(
    jax.random.split(rng, num=model_batch),
    (model_batch, -1),
)

In [81]:
def deep_ensemble_loss(params, ins, labels, min_var=1e-6):
    """Gaussian negative log-likelihood (up to a constant) of ``labels`` under the MLP.

    Output 0 of the network is the predicted mean; output 1 is used as a raw
    variance (sigma**2, not a standard deviation): the variance entering the
    likelihood is ``|output 1| + min_var``.

    Args:
        params: Haiku parameters for ``forward``.
        ins: inputs with features on the last axis; either one example or a
            batch of examples.
        labels: one target per example, reshaped to the shape of the predicted
            means (e.g. ``(1,)`` for a single example, ``(batch, 1)`` for a batch).
        min_var: floor added to the variance so ``log`` and the division stay
            finite when the network outputs a variance of exactly zero.

    Returns:
        Scalar loss; for batched ``ins`` the per-example losses are averaged.
    """
    outs = forward.apply(params, rng, ins)
    # Index the last axis so both single examples (2,) and batches (N, 2) work.
    means = outs[..., 0]
    variances = jnp.abs(outs[..., 1]) + min_var
    targets = jnp.reshape(labels, means.shape)
    nll = 0.5 * jnp.log(variances) + 0.5 * (targets - means) ** 2 / variances
    # jax.grad (used by adv_loss_func) needs a scalar output.
    return jnp.mean(nll)

In [82]:
def adv_loss_func(params, seqs, labels, loss_func, epsilon=1e-3):
    """Clean loss plus the loss on a fast-gradient-sign adversarial copy of ``seqs``.

    Args:
        params: parameters forwarded to ``loss_func``.
        seqs: inputs; perturbed in the direction that increases the loss.
        labels: targets forwarded to ``loss_func``.
        loss_func: callable ``loss_func(params, seqs, labels) -> scalar``.
        epsilon: size of the sign-gradient step applied to ``seqs``.

    Returns:
        ``loss_func(params, seqs, labels) + loss_func(params, seqs_adv, labels)``
        with ``seqs_adv = seqs + epsilon * sign(d loss / d seqs)``.
    """
    grad_inputs = jax.grad(loss_func, 1)(params, seqs, labels)
    # The perturbation is a constant for outer differentiation. sign() already
    # has a zero derivative, so gradients are unchanged; stop_gradient only
    # avoids differentiating through the inner jax.grad.
    perturbation = jax.lax.stop_gradient(epsilon * jnp.sign(grad_inputs))
    adv_seqs = seqs + perturbation
    return loss_func(params, seqs, labels) + loss_func(params, adv_seqs, labels)

def train_mlp(key, seqs, labels, learning_rate=1e-2, n_training_steps=10):
    """Train one MLP with the adversarial deep-ensemble loss.

    Args:
        key: PRNG key; a subkey split from it initialises the parameters.
        seqs: training inputs with features on the last axis.
        labels: training targets forwarded to ``deep_ensemble_loss``.
        learning_rate: step size applied after the Adam scaling.
        n_training_steps: number of full-batch gradient steps.

    Returns:
        ``(loss_trace, outs)``: list of the loss at each step (computed before
        that step's update) and the trained network's outputs on ``seqs``.
    """
    # Adam moment scaling followed by a negative step, i.e. minimise the loss.
    opt_init, opt_update = optax.chain(
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-4),
        optax.scale(-learning_rate),
    )

    # Initialise on ``seqs`` itself rather than a random 1900-d dummy input:
    # Haiku's default Linear initialisers depend only on the PRNG key and the
    # input's feature size, not on the input values, and this removes the
    # hard-coded input width.
    init_key, _ = jax.random.split(key, num=2)
    params = forward.init(init_key, seqs)
    opt_state = opt_init(params)

    # The differentiated function is the same every step, so build it once.
    loss_and_grad = jax.value_and_grad(adv_loss_func)

    loss_trace = []
    for step in range(n_training_steps):
        # Clean loss plus loss on a sign-gradient (adversarial) copy of seqs.
        loss, grads = loss_and_grad(params, seqs, labels, deep_ensemble_loss)
        loss_trace.append(loss)
        print(f'Loss[{step}] = {loss}')
        updates, opt_state = opt_update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

    outs = forward.apply(params, init_key, seqs)
    return loss_trace, outs

In [83]:
# Vectorise train_mlp twice:
#   inner (b_train_mlp): key shared (None); seqs and labels mapped over axis 0.
#   outer (bb_train_mlp): keys, seqs and labels all mapped over axis 0.
# NOTE(review): with ensemble_unirep_seq shaped (5, 4, 1900), each train_mlp
# call sees a single sequence, so a separate network is trained per
# (model, sequence) pair on one example. Confirm this is intended rather than
# one network per ensemble member trained on all 4 sequences.
b_train_mlp = jax.vmap(train_mlp, in_axes=(None, 0, 0), out_axes=(0, 0))
bb_train_mlp = jax.vmap(b_train_mlp, in_axes=(0, 0, 0), out_axes=(0, 0))

In [84]:
# Train every ensemble member: loss_trace holds one batched loss per training
# step, outs the trained networks' predictions on their inputs.
loss_trace, outs = bb_train_mlp(batch_keys, ensemble_unirep_seq, ensemble_labels)

Loss[0] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[4])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[4.0061343e+03, 8.4170898e+03, 3.8728430e+04, 4.0700961e+04],
                                       [9.3068975e+03, 3.2290648e+04, 3.4838169e+05, 7.1373380e+06],
                                       [2.7320386e+02, 1.0468959e+03, 6.7169062e+03, 8.5491191e+03],
                                       [5.5152852e+03, 9.6688809e+03, 4.3855629e+04, 4.9892555e+04],
                                       [1.3133713e+03, 6.5267178e+03, 4.8829848e+04, 6.6872117e+04]],            dtype=float32)
                    batch_dim = 0
       batch_dim = 0
Loss[1] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[4])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[ 4.6296444,  3.2547464, 14.665989 , 17.375626 ],
                                       [27.6

In [85]:
# Joint (mixture) distribution across ensemble members.
def model_stack(outs):
    """Combine per-model Gaussian predictions into a mixture mean and variance.

    The ensemble is treated as a uniform mixture of Gaussians:
    ``mu = mean_m(mu_m)`` and ``var = mean_m(var_m + mu_m**2) - mu**2``.

    Args:
        outs: predictions with the ensemble-member axis first and the two
            network outputs on the last axis, e.g. ``(model_batch, seq_batch, 2)``.
            Output 0 is the mean; ``|output 1|`` is the variance, the same
            parameterisation deep_ensemble_loss trains with.

    Returns:
        ``(mu, var)`` with the model axis reduced; ``var`` is a variance, not a
        standard deviation.
    """
    means = outs[..., 0]
    # abs() matches the loss; the raw output can be negative.
    variances = jnp.abs(outs[..., 1])
    mu = jnp.mean(means, axis=0)
    var = jnp.mean(variances + means ** 2, axis=0) - mu ** 2
    return mu, var

# Display the ensemble's mixture mean and variance for each training sequence.
model_stack(outs)

(DeviceArray([  0.06314812,   9.401098  ,   3.087869  , -11.946028  ], dtype=float32),
 DeviceArray([1234.2212, 1737.8561, 3502.8018, 1322.3152], dtype=float32))

In [None]:
def surrogate(params, key, X):
    """Predict mean and standard deviation for inputs ``X`` with one trained MLP.

    Output 1 of the network is used as a variance by deep_ensemble_loss
    (``|output 1|`` plays sigma**2), so its square root is returned as the std.

    Args:
        params: Haiku parameters for ``forward``.
        key: PRNG key forwarded to ``forward.apply``; None is acceptable when
            the network draws no random numbers.
        X: inputs with features on the last axis.

    Returns:
        ``(mean, std)``, each with the output axis removed.
    """
    outs = forward.apply(params, key, X)
    mean = outs[..., 0]
    std = jnp.sqrt(jnp.abs(outs[..., 1]))
    return mean, std


def acquisition(X, Xsamples, model, key=None):
    """Probability-of-improvement acquisition score for candidate inputs.

    Args:
        X: inputs evaluated so far; their best predicted mean is the incumbent.
        Xsamples: candidate inputs to score.
        model: Haiku parameters for ``forward`` (passed to ``surrogate``).
        key: PRNG key forwarded to ``surrogate``; None works when ``forward``
            draws no random numbers.

    Returns:
        Per candidate, the Gaussian probability that its value exceeds the
        incumbent, using the surrogate's mean and standard deviation.
    """
    from jax.scipy.stats import norm

    # Best surrogate score found so far.
    yhat, _ = surrogate(model, key, X)
    best = jnp.max(yhat)
    # Predicted mean and spread for the candidates.
    mu, std = surrogate(model, key, Xsamples)
    # 1e-9 guards against division by zero when std == 0.
    return norm.cdf((mu - best) / (std + 1e-9))