### MLP in JAX

In [33]:
import haiku as hk
import jax
import jax.numpy as jnp
from jax_unirep import get_reps

In [26]:
# multi-model mlp
# input
# model_batch X sequence_batch X 1900
# output
# model_batch X sequence_batch X 2 (mean, std)
def forward(x):
    mlp = hk.nets.MLP([1900, 256, 32, 2])
    return mlp(x)

forward = hk.transform(forward)

In [53]:
# make data
seqs = ['AFKV', 'TUIS', 'PLKAH', 'URNLA']
labels = jnp.array([25.4, 10.2, 37.0, 7.3])
unirep_seq = get_reps(seqs)[0]

In [81]:
# make data duplicate model_batch times
#deep_ensemble_unirep_seq = unirep_seq[jnp.newaxis, ...]
ensemble_unirep_seq = jax.lax.broadcast(unirep_seq, (5,))
ensemble_labels = jax.lax.broadcast(labels, (5,))[...,jnp.newaxis]
#deep_ensemble_unirep_seq = unirep_seq[jnp.newaxis, ...]
print(deep_ensemble_unirep_seq.shape)
print(ensemble_labels.shape)

(5, 4, 1900)
(5, 4, 1)


In [121]:
rng = jax.random.PRNGKey(37)
model_batch = 5
seq_batch = 4 # nothing related to the seqprop sequence batch, it's training batch
#x = jnp.ones([model_batch, seq_batch, 1900]) # input
params = forward.init(rng, jnp.ones([1900]))
#outs = forward.apply(params, rng, ensemble_unirep_seq)
batch_keys = jax.random.split(rng, num=seq_batch*model_batch)
batch_keys = jnp.reshape(batch_keys, (model_batch, seq_batch, -1))
#batch_keys.reshape()

In [122]:
def deep_ensemble_loss(params, ins, labels): # labels are in batches
    # in batches
    outs = forward.apply(params, rng, ins)
    means = outs[0]
    print(means)
    stds = outs[1]
    n_log_likelihoods = 0.5*jnp.log(stds**2) + 0.5*(labels-means)**2/stds**2
    return n_log_likelihoods[0]

In [123]:
LEARNING_RATE = 0.003
def train_mlp(key, seqs, labels):
    params = forward.init(rng, jax.random.normal(key, shape=(jnp.shape(seqs))))
    
    #@jax.jit
    def update(params, x, y):
        grads = jax.grad(deep_ensemble_loss)(params, x, y)
        return jax.tree_multimap(
            lambda p, g: p - LEARNING_RATE * g, params, grads
      )
    for _ in range(10):
        params = update(params, seqs, labels)
    #outs = forward.apply(params, key, ins)
    
    return params

In [124]:
#train_mlp(rng, unirep_seq[0], 0.5)

In [125]:
b_train_mlp = jax.vmap(train_mlp, (0,0, 0),0)
bb_train_mlp = jax.vmap(b_train_mlp, (0,0, 0),0)

In [126]:
bb_train_mlp(batch_keys, ensemble_unirep_seq, ensemble_labels)

Traced<ShapedArray(float32[])>with<JVPTrace(level=4/0)>
  with primal = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
                  with val = Traced<ShapedArray(float32[4])>with<BatchTrace(level=1/0)>
                               with val = DeviceArray([[0.0472665 , 0.09724212, 0.0451394 , 0.09971958],
                                                       [0.0472665 , 0.09724212, 0.0451394 , 0.09971958],
                                                       [0.0472665 , 0.09724212, 0.0451394 , 0.09971958],
                                                       [0.0472665 , 0.09724212, 0.0451394 , 0.09971958],
                                                       [0.0472665 , 0.09724212, 0.04513941, 0.0997196 ]],            dtype=float32)
                                    batch_dim = 0
                       batch_dim = 0
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=3/0)>
Traced<ShapedArray(float32[])>with<JVPTrace(level=4/0)>
  with primal = T

Traced<ShapedArray(float32[])>with<JVPTrace(level=4/0)>
  with primal = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
                  with val = Traced<ShapedArray(float32[4])>with<BatchTrace(level=1/0)>
                               with val = DeviceArray([[          nan, 4.4692246e+11,           nan, 1.6706400e+13],
                                                       [          nan, 4.4692246e+11,           nan, 1.6706400e+13],
                                                       [          nan, 4.4692246e+11,           nan, 1.6706400e+13],
                                                       [          nan, 4.4692246e+11,           nan, 1.6706400e+13],
                                                       [          nan, 4.4692246e+11,           nan, 1.6706380e+13]],            dtype=float32)
                                    batch_dim = 0
                       batch_dim = 0
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=3/0)>


FlatMap({
  'mlp/~/linear_0': FlatMap({
                      'b': DeviceArray([[[ 0.0000000e+00,            nan,            nan, ...,
                                           0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
                                         [ 0.0000000e+00,  2.8082264e+01, -7.3755441e+00, ...,
                                           0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
                                         [ 0.0000000e+00, -1.0345431e+04,            nan, ...,
                                           0.0000000e+00,  0.0000000e+00, -1.9428212e+03],
                                         [ 0.0000000e+00,  9.3455368e+01, -5.1615295e+00, ...,
                                           0.0000000e+00,  0.0000000e+00,  0.0000000e+00]],
                           
                                        [[ 0.0000000e+00,            nan,            nan, ...,
                                           0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
 