### MLP in JAX

In [122]:
import haiku as hk
import jax
import jax.numpy as jnp
import jax.scipy.stats.norm as norm
from jax_unirep import get_reps
import optax

In [123]:
# multi-model mlp
# input
# model_batch X sequence_batch X 1900
# output
# model_batch X sequence_batch X 2 (mean, std)
def forward(x):
    mlp = hk.nets.MLP([1900, 256, 32, 2])
    return mlp(x)

forward = hk.transform(forward)

In [95]:
# make data
seqs = ['MSADDGIVYCQMRQGPWEFHIVTVESSAYDWVVVPGARIALDKYNAACEQHWSCILRRGIDQKPYAPDMLKCQCSDMCHPSDSFTWEIDAEAWYCNTDNLFTGIALYKNNDDYPDWYPIRCLKHKNVTAAQVPLVHFNDNKFTHHVHNDMPACDFKFFKTPTVRHACQFGSIYHSKQSRMDYSDLMQDEKAKHLKESHNVVPDDGIIIDPYMDILFGGRMNNREHCAKNE',
        'EKMHIKESATRMGFQYEYKLPYCIWAFIIGRAWHFVSLHGDQWDCWKMTFVIYSACSNGHIDGCEVQHANLSSGVLPARWFDAFQQNMKGFHKMKCGGFCTYAFLWGLAMRIYVRNMGNLAIYQNGGTSEWLTEFWYRLAGAVWPFKQFSINGECEHFWWSFHPFTLFDNPPAKDRNVTAYLHFDAHFYSIAMVWLMSPVVKGDSPVNCCAVDVEQSGESWALLNNWCAP',
        'HSFHKYKHGNWKSEGDQCLKVGQLRDECPQVNTPMYCSWGPHYFSIFHWIIPVAKAYHMLHNIEQQVYRCHWQERYKELHDATKTHQLEWSFGKSVWCAHCKPYIGWYRSPAGWHMPIKPPATKNLWVVRHKSKRKEGTISWENTLTCVWFHEICYGHGVCHQVHPWVVDSNEEYEMQWMETEVGECSYPAERQGAWYSFTQQQKWICIHVCNMSSGRVFCWYVLQLFRN',
        'LDHAVLKILQAMGPWNNRVEHPRLGKRSTEWPAAIYEGEPRWRLKCDTTATYYKAFETRWYNCHMTLTCWWHGATIRSKLTTMCMMVTNGYRDFYRYNDWKGRKATKHHPMVCIYEILWIAFMGCLHMWAGARVSKIWVGFCIFFASCLQMSPLKDWHNKCAFGRNNPLGMKGWGMMIGNSFCHIVHEMDNKYYAGAPVDEPFMYNQQVFGFGAMHCLCMADFCNEWGIQ',
        'PERHHYIGFHCYMQLDAIQQNPHWNAHVLFRAFDYVSNYWTWITMYDKYQGFLGIYVTSCKVHEHGACKHCHWPICYDCGQHADKMLWRKSFALHGQSHAYRPLWDRDLTGVLGISIDLNQGIKVAEAEGEILYCNVTDMTVMMHQSVGVFWCHDMAYPQWTDWYSSDNMMNSIPEISHMKNYRVTMVHEPLFIWECVSEWTENAEHEGHLITVGSTGGKWDTGMEREVM',
        'DPSQTIHCGTTGMSWGTMFKRSYILIIRYGTPEATCPCIVNCQIVVYWGCMFKKDRDPRGTPIQSTENFFKHAMMEPSYAGGTAHMEKEIEYRSQDSWHAYFSYWVKVWCYVCIALSQIPNVAHHGMHLHASPEDKKCANNWRFRYVAFIRIAHGCSWCYRECYNFRYDRYIAWNPVHLESVPEWWAHPAFEIVKDTVDDNQYSGADERQGDPIGGQPCLLCATWEDSWT',
        'LIDLFSLTRKFSRMPCRHNMNESYKEEWCETNNVKEYPHEQLLDKRYDIITLDGCKRMYCRSESQRITELHFIRYNMLCWPDRCIPLQYSQYESNMPPPMMRMWGCYHFGTLMFMSYAMPPTGEKREIVGGKEDHSGDLEDAFTDEDFNMDPAHQDYRHIAGTWHEPMFEIRMRYELTCNNMWSPIYANNAGMKQLTICNNDKICPTEGRRRQREIFNYKLHGRDQCQHI',
        'SCDVGPHPLHGQCTGMAKQVMETANIPQCPIDDHVTRATMGLIDAGACDRDRVCVREIWNVYYDKSTMKIIMDPPSDTCKHKSFYGDMMSHQQMGWLSECIIANMQHNLPWQLWESWMIHSEICMIKQRKVMMFCGIQSKYTEDFARFHPFILANTQYIIFKRPTTWPRVYAFLHRCMVLGWSAYGMTAMIPNTKETIKLAHCEKWPLTGSYTPSFVIFDGWLARKCQWP',
        'DWIEHVHTFWVLMFISNYPQIVCGLINQIEPWKSKFHSLAGFNQGCQCEKNYQGPIQAINGINQLVTITTPINNQENVDKKPHPGSVHTKSDAITLRFNQGVHNIFMWDMATQGRASIPFLNNMNGGGLTDYSWEQVVTCHCHMTNDLELDPQMLYMWWIVSANAWMVNGMRRQHMACHWAQWEGFRWPRYVQSVPMKVLLTTQKIHWMQYFREKFCFILMKWQGYWYTV',
        'RHWRAPLLMYRDKEVQITWHFRFMYHCDALTCSEVHCHARNFMVFGYSTPQNYNPVILYWVTWANTCLTPKGAYCARQMRMYATVTMSKINQMTITYLVDRQRQHWGLAFRSDNTCNHKWYLKHRCKVWNWGWLIDCYDLDRNLPKQVSRNQSSKSLRDLFNYIHYHWAMLPINIYCYSGDIWTTISTDDQFHIPTFIPCGKTVHEDLQPYEMCGMWHQCEDADYTMQPV',
]
labels = jnp.array([25.217391304347824,
                    15.652173913043478,
                    23.478260869565219,
                    22.173913043478262,
                    23.913043478260871,
                    24.782608695652176,
                    26.956521739130434,
                    17.391304347826086,
                    19.130434782608695,
                    26.521739130434781
                   ])
unirep_seq = get_reps(seqs)[0]
print(unirep_seq.shape)

(10, 1900)


In [97]:
# make data duplicate model_batch times
#deep_ensemble_unirep_seq = unirep_seq[jnp.newaxis, ...]
ensemble_unirep_seq = jax.lax.broadcast(unirep_seq, (5,))
ensemble_labels = jax.lax.broadcast(labels, (5,))[...,jnp.newaxis]
#deep_ensemble_unirep_seq = unirep_seq[jnp.newaxis, ...]
print(ensemble_unirep_seq.shape)
#print(ensemble_labels.shape)

(5, 10, 1900)


In [98]:
rng = jax.random.PRNGKey(37)
model_batch = 5
seq_batch = 10 # nothing related to the seqprop sequence batch, it's training batch
#x = jnp.ones([model_batch, seq_batch, 1900]) # input
#params = forward.init(rng, jnp.ones([1900]))
#outs = forward.apply(params, rng, ensemble_unirep_seq)
batch_keys = jax.random.split(rng, num=model_batch)
batch_keys = jnp.reshape(batch_keys, (model_batch, -1))
#batch_keys.reshape()

In [99]:
def deep_ensemble_loss(params, ins, labels): # labels are in batches
    # in batches
    outs = forward.apply(params, rng, ins)
    means = outs[0]
    #print(means)
    stds = outs[1]
    n_log_likelihoods = 0.5*jnp.log(jnp.abs(stds)) + 0.5*(labels-means)**2/jnp.abs(stds)
    #print(n_log_likelihoods[0])
    return n_log_likelihoods[0]

In [119]:
def adv_loss_func(params, seqs, labels, loss_func):
    epsilon = 1e-3
    grad_inputs = jax.grad(loss_func, 1)(params, seqs, labels)
    seqs_ = seqs + epsilon * jnp.sign(grad_inputs)
    
    return loss_func(params, seqs, labels) + loss_func(params, seqs_, labels)

def train_mlp(key, seqs, labels):
    learning_rate = 1e-2
    n_training_steps = 25
    opt_init, opt_update = optax.chain(
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-4),
        optax.scale(-learning_rate)) # minus sign -- minimizing the loss

  # Initialise the model's parameters and the optimiser's state.
  # The `state` of an optimiser contains all statistics used by the
  # stateful transformations in the `chain` (in this case just `scale_by_adam`).
    key, key2 = jax.random.split(key, num=2)
    params = forward.init(key, jax.random.normal(key2, shape=(1900,)))
    opt_state = opt_init(params)

    loss_trace = []
    for step in range(n_training_steps):
        # generate adversarial example
        #loss_func = adv_loss_func(params, seqs, labels, deep_ensemble_loss)
        
        loss, grad = jax.value_and_grad(adv_loss_func)(params, seqs, labels, deep_ensemble_loss)
        loss_trace.append(loss)
        print(f'Loss[{step}] = {loss}')
        updates, opt_state = opt_update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)
    outs = forward.apply(params, key, seqs)
    def model_stack(outs):
        mu = jnp.mean(outs[:,:,0], axis=0)
        std = jnp.mean(outs[:,:,1] + outs[:,:,0]**2,axis=0) - mu**2
        return mu, std
    joint_outs = model_stack(outs)
    return loss_trace, joint_outs

In [120]:
# batch twice with vmap
b_train_mlp = jax.vmap(train_mlp, (None, 0, 0),(0, 0))
bb_train_mlp = jax.vmap(b_train_mlp, (0, 0, 0),(0, 0))

In [121]:
# call training process
#params, loss_trace, outs = b_train_mlp(rng, unirep_seq, labels)
loss_trace, joint_outs = bb_train_mlp(batch_keys, ensemble_unirep_seq, ensemble_labels)

Loss[0] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[10])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[ 29679.479 ,  15054.434 ,  58980.406 ,  79798.1   ,
                                         46560.125 ,  29364.066 ,  62689.22  ,  55668.18  ,
                                        346394.12  ,  48799.402 ],
                                       [ 32338.64  ,  22045.902 ,  19727.695 ,  40523.23  ,
                                         29636.791 ,  26669.    ,  20501.46  ,  12082.07  ,
                                         48065.64  ,  83550.016 ],
                                       [ 22104.133 ,   6663.216 ,  18111.775 ,   8339.1455,
                                         21545.512 ,  14573.02  ,  24013.762 ,  10015.458 ,
                                         13725.695 ,  18139.023 ],
                                       [ 80278.22  ,  55856.89  ,  60111.47  , 118057.18  ,
              

       batch_dim = 0
Loss[7] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[10])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[ 7.953397 ,  6.927745 ,  9.9121895,  6.649955 , 11.948759 ,
                                        14.422711 ,  5.8057756,  6.898861 , 34.58297  ,  9.645617 ],
                                       [11.41423  , 11.835065 , 15.571138 ,  7.443117 , 14.433104 ,
                                        19.843735 , 16.60033  ,  8.70101  ,  8.126274 , 16.243805 ],
                                       [ 6.6127806,  7.2023745,  7.1173797,  7.807339 ,  7.0333767,
                                         6.685045 ,  6.7382135,  9.467621 ,  6.389475 ,  7.39784  ],
                                       [26.996449 , 20.33168  , 54.306507 , 10.306734 , 39.33034  ,
                                        20.89112  , 35.76461  , 61.562325 , 55.86208  , 49.916786 ],
                                  

       batch_dim = 0
Loss[14] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[10])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[ 9.045431 ,  8.669388 ,  7.2189035,  7.283632 ,  8.323056 ,
                                        10.031789 ,  8.616293 ,  6.7380996, 15.871378 ,  7.1829777],
                                       [ 7.845849 ,  7.6117296,  9.229811 ,  8.752951 ,  8.912332 ,
                                        12.426262 ,  9.443331 ,  7.4157314,  7.0460396, 10.386066 ],
                                       [10.211369 ,  8.520031 ,  9.764282 ,  9.36692  ,  9.649442 ,
                                         8.419582 ,  9.665142 ,  8.087975 ,  8.627354 , 10.372922 ],
                                       [16.544415 ,  8.629042 , 19.742119 ,  8.294771 , 23.002487 ,
                                        13.2876835, 12.631309 , 25.10633  , 20.630226 , 20.865337 ],
                                 

       batch_dim = 0
Loss[21] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[10])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[ 7.64912  ,  7.4834166,  7.161188 ,  7.8665643,  7.824604 ,
                                         8.797516 ,  7.6719017,  7.230053 , 11.291676 ,  7.238161 ],
                                       [ 7.4414673,  7.4425516,  7.703515 ,  7.779322 ,  7.737194 ,
                                        10.053865 ,  7.729328 ,  8.818529 ,  7.985178 ,  8.910851 ],
                                       [ 8.326874 ,  8.365593 ,  8.533541 ,  8.057083 ,  8.582642 ,
                                         8.476882 ,  7.821847 ,  7.8002496,  7.996052 ,  8.035376 ],
                                       [11.163536 ,  7.3773756,  7.3016233,  7.535657 , 11.403984 ,
                                         9.553177 ,  7.490673 , 11.685973 ,  8.532706 ,  8.527232 ],
                                 

IndexError: Too many indices for array: 3 non-None/Ellipsis indices for dim 1.

Loss[0] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[10])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[ 29679.479 ,  15054.434 ,  58980.406 ,  79798.1   ,
                                         46560.125 ,  29364.066 ,  62689.22  ,  55668.18  ,
                                        346394.12  ,  48799.402 ],
                                       [ 32338.64  ,  22045.902 ,  19727.695 ,  40523.23  ,
                                         29636.791 ,  26669.    ,  20501.46  ,  12082.07  ,
                                         48065.64  ,  83550.016 ],
                                       [ 22104.133 ,   6663.216 ,  18111.775 ,   8339.1455,
                                         21545.512 ,  14573.02  ,  24013.762 ,  10015.458 ,
                                         13725.695 ,  18139.023 ],
                                       [ 80278.22  ,  55856.89  ,  60111.47  , 118057.18  ,
              

KeyboardInterrupt: 

In [118]:
def ei(seqs, model):
#def ei(outs):
    mu, std = model_stack(keys, seqs, labels)
    best = max(mu)
    epsilon = 0.1
    z = (mu - best - epsilon) / std
    return (mu - best - epsilon) * norm.cdf(z) + std * norm.pdf(z)

ei(batch_keys, ensemble_unirep_seq, ensemble_labels)

Loss[0] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[10])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[ 29679.479 ,  15054.434 ,  58980.406 ,  79798.1   ,
                                         46560.125 ,  29364.066 ,  62689.22  ,  55668.18  ,
                                        346394.12  ,  48799.402 ],
                                       [ 32338.64  ,  22045.902 ,  19727.695 ,  40523.23  ,
                                         29636.791 ,  26669.    ,  20501.46  ,  12082.07  ,
                                         48065.64  ,  83550.016 ],
                                       [ 22104.133 ,   6663.216 ,  18111.775 ,   8339.1455,
                                         21545.512 ,  14573.02  ,  24013.762 ,  10015.458 ,
                                         13725.695 ,  18139.023 ],
                                       [ 80278.22  ,  55856.89  ,  60111.47  , 118057.18  ,
              

       batch_dim = 0
Loss[7] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[10])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[ 7.953397 ,  6.927745 ,  9.9121895,  6.649955 , 11.948759 ,
                                        14.422711 ,  5.8057756,  6.898861 , 34.58297  ,  9.645617 ],
                                       [11.41423  , 11.835065 , 15.571138 ,  7.443117 , 14.433104 ,
                                        19.843735 , 16.60033  ,  8.70101  ,  8.126274 , 16.243805 ],
                                       [ 6.6127806,  7.2023745,  7.1173797,  7.807339 ,  7.0333767,
                                         6.685045 ,  6.7382135,  9.467621 ,  6.389475 ,  7.39784  ],
                                       [26.996449 , 20.33168  , 54.306507 , 10.306734 , 39.33034  ,
                                        20.89112  , 35.76461  , 61.562325 , 55.86208  , 49.916786 ],
                                  

       batch_dim = 0
Loss[14] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[10])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[ 9.045431 ,  8.669388 ,  7.2189035,  7.283632 ,  8.323056 ,
                                        10.031789 ,  8.616293 ,  6.7380996, 15.871378 ,  7.1829777],
                                       [ 7.845849 ,  7.6117296,  9.229811 ,  8.752951 ,  8.912332 ,
                                        12.426262 ,  9.443331 ,  7.4157314,  7.0460396, 10.386066 ],
                                       [10.211369 ,  8.520031 ,  9.764282 ,  9.36692  ,  9.649442 ,
                                         8.419582 ,  9.665142 ,  8.087975 ,  8.627354 , 10.372922 ],
                                       [16.544415 ,  8.629042 , 19.742119 ,  8.294771 , 23.002487 ,
                                        13.2876835, 12.631309 , 25.10633  , 20.630226 , 20.865337 ],
                                 

       batch_dim = 0
Loss[21] = Traced<ShapedArray(float32[])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[10])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[ 7.64912  ,  7.4834166,  7.161188 ,  7.8665643,  7.824604 ,
                                         8.797516 ,  7.6719017,  7.230053 , 11.291676 ,  7.238161 ],
                                       [ 7.4414673,  7.4425516,  7.703515 ,  7.779322 ,  7.737194 ,
                                        10.053865 ,  7.729328 ,  8.818529 ,  7.985178 ,  8.910851 ],
                                       [ 8.326874 ,  8.365593 ,  8.533541 ,  8.057083 ,  8.582642 ,
                                         8.476882 ,  7.821847 ,  7.8002496,  7.996052 ,  8.035376 ],
                                       [11.163536 ,  7.3773756,  7.3016233,  7.535657 , 11.403984 ,
                                         9.553177 ,  7.490673 , 11.685973 ,  8.532706 ,  8.527232 ],
                                 

DeviceArray([-435.30835, -609.2503 ,  350.9254 , -497.05014, -231.78734,
             -265.169  ,  -44.51284, -175.4745 ,  499.16098,  -89.38466],            dtype=float32)