In [1]:
import jax 
import jax.numpy as jnp
import haiku as hk

In [2]:
class BiLSTM(hk.Module):
    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        self.output_size = output_size
    
    def __call__(self, seqs): # batch size X sequence length X embedding dim
        batch_size = seqs.shape[0]
        fwd_core = hk.LSTM(16)
        bwd_core = hk.LSTM(16)
        fwd_outs, fwd_state = hk.dynamic_unroll(fwd_core, seqs, fwd_core.initial_state(batch_size), time_major=False)
        bwd_outs, bwd_state = hk.dynamic_unroll(bwd_core, jnp.flip(seqs, axis=-1), bwd_core.initial_state(batch_size), time_major=False)
        outs = jnp.concatenate([fwd_outs, bwd_outs], axis=0)
        return hk.BatchApply(hk.Linear(self.output_size))(outs), fwd_state, bwd_state
 
        

In [3]:
seqs = jnp.ones((8, 10, 16))

In [4]:
def bi_lstm(seqs):
    out = BiLSTM(2)(seqs)
    return out

bi_lstm = hk.transform(bi_lstm)

In [7]:
key = jax.random.PRNGKey(0)
params = bi_lstm.init(key, seqs)
out, _, _ = bi_lstm.apply(params, key, seqs)

In [9]:
print(out.shape)

(16, 10, 2)
