In [2]:
from functools import partial # for use with vmap
import jax
import jax.numpy as jnp
import haiku as hk
import jax.scipy.stats.norm as norm
import optax
from jax_unirep.layers import AAEmbedding, mLSTM, mLSTMAvgHidden
from jax_unirep.utils import load_params, load_embedding, seq_to_oh
from jax_unirep.utils import *
from jax_unirep import get_reps
import matplotlib.pyplot as plt




In [44]:
key = jax.random.PRNGKey(0)

def forward(x):
    mlp = hk.nets.MLP([256, 32, 2])
    return mlp(x)

forward  = hk.without_apply_rng(hk.transform(forward))

class MLP:

    def __init__(self, key, forward):
        self.key = key
        self.forward = forward

    def deep_ensemble_loss(self, params, ins, labels):
        outs = self.forward.apply(params, ins)
        means = outs[0]
        stds = outs[1]
        n_log_likelihoods = 0.5*jnp.log(jnp.abs(stds)) + 0.5*(labels-means)**2/jnp.abs(stds)

        return n_log_likelihoods[0]

    def adv_loss_func(self, params, seqs, labels, loss_func):
        epsilon = 1e-3
        grad_inputs = jax.grad(loss_func, 1)(params, seqs, labels)
        seqs_ = seqs + epsilon * jnp.sign(grad_inputs)

        return loss_func(params, seqs, labels) + loss_func(params, seqs_, labels)

    def train_mlp(self, external_keys, seqs, labels):
        learning_rate = 1e-2
        n_training_steps = 2
        opt_init, opt_update = optax.chain(
            optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-4),
            optax.scale(-learning_rate) # minus sign -- minimizing the loss
        )
        self.key, key_ = jax.random.split(self.key, num=2)
        self.params = self.forward.init(self.key, jax.random.normal(key_, shape=(1900,)))
        opt_state = opt_init(self.params)

        loss_trace=[]
        for step in range(n_training_steps):
            loss, grad=jax.value_and_grad(self.adv_loss_func)(self.params, seqs, labels, self.deep_ensemble_loss)
            loss_trace.append(loss)

            updates, opt_state = opt_update(grad, opt_state, self.params)
            self.params = optax.apply_updates(self.params, updates)
        outs = self.forward.apply(self.params, seqs)

        #joint_outs = model_stack(outs)
        return loss_trace, outs

    def batch(self, seqs, labels):
        self.ensemble_seqs = jnp.tile(seqs, (5, 1 ,1))
        self.ensemble_labels = jax.lax.broadcast(labels, (5,))[...,jnp.newaxis]
        self.b_training_mlp = jax.vmap(self.train_mlp, (None, 0, 0), (0, 0))
        self.bb_training_mlp = jax.vmap(self.b_training_mlp, (0, 0, 0), (0, 0))
        self.external_keys = jax.random.split(self.key, num=5)
        self.external_keys = jnp.reshape(self.external_keys, (5, -1))



    def model_stack(self):
        mu = jnp.mean(self.outs[..., 0], axis=0)
        std = jnp.mean(self.outs[...,1] + self.outs[...,0]**2,axis=0) - mu**2
        return mu, std

    def call_train(self):
        self.loss_trace, self.outs = self.bb_training_mlp(self.external_keys, self.ensemble_seqs, self.ensemble_labels) # batched
        self.joint_outs = self.model_stack()
        
    def apply_(self, new_model, seq):
        outs = self.forward.apply(self.params,seq)
        print(outs.shape)
        mu,std = outs
        #mu = jnp.mean(outs[0], axis=0)
        #std = jnp.mean(outs[1] + outs[0]**2,axis=0) - mu**2
        return mu, std



seqs = ['MSADDGIVYCQMRQGPWEFHIVTVESSAYDWVVVPGARIALDKYNAACEQHWSCILRRGIDQKPYAPDMLKCQCSDMCHPSDSFTWEIDAEAWYCNTDNLFTGIALYKNNDDYPDWYPIRCLKHKNVTAAQVPLVHFNDNKFTHHVHNDMPACDFKFFKTPTVRHACQFGSIYHSKQSRMDYSDLMQDEKAKHLKESHNVVPDDGIIIDPYMDILFGGRMNNREHCAKNE',
        'EKMHIKESATRMGFQYEYKLPYCIWAFIIGRAWHFVSLHGDQWDCWKMTFVIYSACSNGHIDGCEVQHANLSSGVLPARWFDAFQQNMKGFHKMKCGGFCTYAFLWGLAMRIYVRNMGNLAIYQNGGTSEWLTEFWYRLAGAVWPFKQFSINGECEHFWWSFHPFTLFDNPPAKDRNVTAYLHFDAHFYSIAMVWLMSPVVKGDSPVNCCAVDVEQSGESWALLNNWCAP',
        'HSFHKYKHGNWKSEGDQCLKVGQLRDECPQVNTPMYCSWGPHYFSIFHWIIPVAKAYHMLHNIEQQVYRCHWQERYKELHDATKTHQLEWSFGKSVWCAHCKPYIGWYRSPAGWHMPIKPPATKNLWVVRHKSKRKEGTISWENTLTCVWFHEICYGHGVCHQVHPWVVDSNEEYEMQWMETEVGECSYPAERQGAWYSFTQQQKWICIHVCNMSSGRVFCWYVLQLFRN',
        'LDHAVLKILQAMGPWNNRVEHPRLGKRSTEWPAAIYEGEPRWRLKCDTTATYYKAFETRWYNCHMTLTCWWHGATIRSKLTTMCMMVTNGYRDFYRYNDWKGRKATKHHPMVCIYEILWIAFMGCLHMWAGARVSKIWVGFCIFFASCLQMSPLKDWHNKCAFGRNNPLGMKGWGMMIGNSFCHIVHEMDNKYYAGAPVDEPFMYNQQVFGFGAMHCLCMADFCNEWGIQ',
        'PERHHYIGFHCYMQLDAIQQNPHWNAHVLFRAFDYVSNYWTWITMYDKYQGFLGIYVTSCKVHEHGACKHCHWPICYDCGQHADKMLWRKSFALHGQSHAYRPLWDRDLTGVLGISIDLNQGIKVAEAEGEILYCNVTDMTVMMHQSVGVFWCHDMAYPQWTDWYSSDNMMNSIPEISHMKNYRVTMVHEPLFIWECVSEWTENAEHEGHLITVGSTGGKWDTGMEREVM',
        'DPSQTIHCGTTGMSWGTMFKRSYILIIRYGTPEATCPCIVNCQIVVYWGCMFKKDRDPRGTPIQSTENFFKHAMMEPSYAGGTAHMEKEIEYRSQDSWHAYFSYWVKVWCYVCIALSQIPNVAHHGMHLHASPEDKKCANNWRFRYVAFIRIAHGCSWCYRECYNFRYDRYIAWNPVHLESVPEWWAHPAFEIVKDTVDDNQYSGADERQGDPIGGQPCLLCATWEDSWT',
        'LIDLFSLTRKFSRMPCRHNMNESYKEEWCETNNVKEYPHEQLLDKRYDIITLDGCKRMYCRSESQRITELHFIRYNMLCWPDRCIPLQYSQYESNMPPPMMRMWGCYHFGTLMFMSYAMPPTGEKREIVGGKEDHSGDLEDAFTDEDFNMDPAHQDYRHIAGTWHEPMFEIRMRYELTCNNMWSPIYANNAGMKQLTICNNDKICPTEGRRRQREIFNYKLHGRDQCQHI',
        'SCDVGPHPLHGQCTGMAKQVMETANIPQCPIDDHVTRATMGLIDAGACDRDRVCVREIWNVYYDKSTMKIIMDPPSDTCKHKSFYGDMMSHQQMGWLSECIIANMQHNLPWQLWESWMIHSEICMIKQRKVMMFCGIQSKYTEDFARFHPFILANTQYIIFKRPTTWPRVYAFLHRCMVLGWSAYGMTAMIPNTKETIKLAHCEKWPLTGSYTPSFVIFDGWLARKCQWP',
        'DWIEHVHTFWVLMFISNYPQIVCGLINQIEPWKSKFHSLAGFNQGCQCEKNYQGPIQAINGINQLVTITTPINNQENVDKKPHPGSVHTKSDAITLRFNQGVHNIFMWDMATQGRASIPFLNNMNGGGLTDYSWEQVVTCHCHMTNDLELDPQMLYMWWIVSANAWMVNGMRRQHMACHWAQWEGFRWPRYVQSVPMKVLLTTQKIHWMQYFREKFCFILMKWQGYWYTV',
        'RHWRAPLLMYRDKEVQITWHFRFMYHCDALTCSEVHCHARNFMVFGYSTPQNYNPVILYWVTWANTCLTPKGAYCARQMRMYATVTMSKINQMTITYLVDRQRQHWGLAFRSDNTCNHKWYLKHRCKVWNWGWLIDCYDLDRNLPKQVSRNQSSKSLRDLFNYIHYHWAMLPINIYCYSGDIWTTISTDDQFHIPTFIPCGKTVHEDLQPYEMCGMWHQCEDADYTMQPV',
]
labels = jnp.array([25.217391304347824,
                    15.652173913043478,
                    23.478260869565219,
                    22.173913043478262,
                    23.913043478260871,
                    24.782608695652176,
                    26.956521739130434,
                    17.391304347826086,
                    19.130434782608695,
                    26.521739130434781
                   ])
seqs = get_reps(seqs)[0]

model = MLP(key, forward)
model.batch(seqs, labels)
model.call_train()



In [49]:
#model.apply_(seqs[0])
def forward2(x):
    mlp2 = hk.nets.MLP([256, 32, 2])
    return mlp2(x)

forward2  = hk.without_apply_rng(hk.transform(forward2))
params = forward2.init(key, jax.random.normal(key, shape=(1900,)))
out = forward2.apply(model.params, seqs[0])
print(out)

Traced<ShapedArray(float32[2])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[10,2])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[[-3.2417736 ,  2.048891  ],
                                        [-0.713571  ,  2.1600096 ],
                                        [ 0.12612917, -6.1499877 ],
                                        [-0.9862055 ,  2.7809763 ],
                                        [ 0.78969663, -5.394996  ],
                                        [ 0.62117404, -4.25655   ],
                                        [ 0.5410639 , -5.2603903 ],
                                        [-1.2641287 ,  3.5802088 ],
                                        [-1.1202598 ,  3.6278844 ],
                                        [-2.8280272 ,  5.2859383 ]],
                          
                                       [[-3.2417736 ,  2.048891  ],
                                        [-0.713571  ,  2.1600096 ],
                      

In [None]:
model = MLP(key, forward)
def end2end(key, model, forward, seqs, labels):
    model.batch(seqs, labels)
    model.call_train()
    return self.joint_outs

In [35]:
#print(seqs.shape)
#ensemble_unirep_seq = jnp.tile(seqs, (5,1,1))
#print(ensemble_unirep_seq.shape)
print(seqs[0])
outs = forward.apply(model.params, seqs[0])
print(outs)

[ 0.00536008 -0.09600413  0.04200451 ...  0.03934921  0.10036227
 -0.00582207]
Traced<ShapedArray(float32[2])>with<BatchTrace(level=2/0)>
  with val = Traced<ShapedArray(float32[10,2])>with<BatchTrace(level=1/0)>
               with val = DeviceArray([[[-3.2417736 ,  2.048891  ],
                                        [-0.713571  ,  2.1600096 ],
                                        [ 0.12612917, -6.1499877 ],
                                        [-0.9862055 ,  2.7809763 ],
                                        [ 0.78969663, -5.394996  ],
                                        [ 0.62117404, -4.25655   ],
                                        [ 0.5410639 , -5.2603903 ],
                                        [-1.2641287 ,  3.5802088 ],
                                        [-1.1202598 ,  3.6278844 ],
                                        [-2.8280272 ,  5.2859383 ]],
                          
                                       [[-3.2417736 ,  2.048891  ],
           

In [33]:
print(model.params['mlp/~/linear_1']['w'].shape)
print(model.joint_outs)

(256, 32)
(DeviceArray([-3.2417731 , -1.1716001 ,  0.15587394, -1.3546982 ,
              0.9841746 ,  0.8379841 ,  0.62949824, -1.5348262 ,
             -1.3603796 , -3.2406297 ], dtype=float32), DeviceArray([ 2.048892 ,  3.5862346, -6.970621 ,  3.8595724, -6.62054  ,
             -5.538826 , -6.020439 ,  4.405252 ,  4.4307556,  6.0628967],            dtype=float32))


In [None]:
def bayesian_ei(model, X):
    mu, std = model.joint_outs
    best = jnp.max(mu)
    epsilon = 0.1
    z = (mu-best-epsilon)/std
    return (mu-best-epsilon)*norm.cdf(z) + std*norm.pdf(z)

def optimizer(model, init_vec):
    ei = bayesian_ei(model)
    eta = 1e-2
    n_steps = 100
    opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2, b1=0.8, b2=0.9, eps=1e-5)
    opt_state = opt_init(init_vec)
    
    @jax.jit
    def step(i, opt_state):
        vec1900 = get_params(opt_state)
        outs = model.forward.apply(model.params, vec1900)
        
        
    
    
def train_seqprop_adam(key, target_rep, init_logits, init_r, init_b, iter_num=100):
    opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2, b1=0.8, b2=0.9, eps=1e-5)
    #opt_init, opt_update, get_params = optimizers.adagrad(step_size=1e-2)
    opt_state = opt_init((init_logits, init_r, init_b)) # initial state
    logits_trace = []
    loss_trace = []

    @jax.jit
    def step(key, i, opt_state):
        key, subkey = jax.random.split(key, num=2)
        p = get_params(opt_state)
        logits, r, b = p
        
        sampled_vec, norm_logits = forward_seqprop(key, logits, r, b)
        loss = loss_func(target_rep, sampled_vec)
        g = jax.grad(g_loss_func, (1,2,3))(key, logits, r, b, target_rep)
        return opt_update(i, g, opt_state), loss

    for step_idx in range(iter_num):
        #print(step_idx)
        opt_state, loss = step(key, step_idx, opt_state)
        #print(loss)
        loss_trace.append(loss)
        mid_logits, mid_r, mid_b = get_params(opt_state)
        logits_trace.append(mid_logits)
    final_logits, final_r, final_b = get_params(opt_state)
    sampled_vec, _ = forward_seqprop(key, final_logits, final_r, final_b)
    return sampled_vec, final_logits, logits_trace, loss_trace