In [1]:
import jax
import jax.numpy as jnp

In [2]:
class PytreeLSTMCell():
    """Single LSTM cell whose parameters are stored as plain attributes.

    Holds an input-to-hidden matrix, a hidden-to-hidden matrix and a bias;
    calling the instance performs one recurrence step.
    """

    def __init__(self, weight_ih, weight_hh, bias):
        # Parameters are stored as given; nothing is copied or validated.
        self.weight_ih, self.weight_hh, self.bias = weight_ih, weight_hh, bias

    def __call__(self, inputs, h, c):
        """Advance the cell by one step.

        Args:
            inputs: input vector, multiplied by ``weight_ih``.
            h: previous hidden state, multiplied by ``weight_hh``.
            c: previous cell state.

        Returns:
            Tuple ``(new_h, new_c)``.
        """
        # One affine map yields all four gate pre-activations; they are then
        # cut into equal quarters along the last axis, in the order
        # input gate, forget gate, candidate, output gate.
        pre_act = self.weight_ih @ inputs + self.weight_hh @ h + self.bias
        in_pre, forget_pre, cand_pre, out_pre = jnp.split(pre_act, 4, axis=-1)

        in_gate = jax.nn.sigmoid(in_pre)
        forget_gate = jax.nn.sigmoid(forget_pre)
        candidate = jnp.tanh(cand_pre)
        out_gate = jax.nn.sigmoid(out_pre)

        new_c = forget_gate * c + in_gate * candidate
        new_h = out_gate * jnp.tanh(new_c)
        return (new_h, new_c)

# Register the cell as a pytree node: its three parameters are the children
# and there is no auxiliary (static) data.
jax.tree_util.register_pytree_node(
    PytreeLSTMCell,
    lambda cell: ((cell.weight_ih, cell.weight_hh, cell.bias), None),
    lambda _aux, children: PytreeLSTMCell(*children),
)

class PytreeLSTMLM():
    """Language model made of a recurrent cell, an embedding matrix and an initial cell state.

    The embedding matrix is used both to look up the input vector of a token
    (row indexing) and to score the next token (``embeddings @ h``).
    """

    def __init__(self, cell, embeddings, c_0):
        self.cell, self.embeddings, self.c_0 = cell, embeddings, c_0

    @property
    def hc_0(self):
        """Initial ``(h, c)`` state; ``h`` is the ``tanh`` of the stored ``c_0``."""
        initial_c = self.c_0
        return (jnp.tanh(initial_c), initial_c)

    @jax.jit
    def forward(self, seq, hc):
        """Sum the negative log-likelihood of ``seq`` starting from state ``hc``.

        The method is wrapped in ``jax.jit``, so ``self`` is passed to the
        compiled function as its first argument.

        Args:
            seq: iterable of token indices.
            hc: ``(h, c)`` state to start from.

        Returns:
            Tuple ``(loss, hc)`` with the accumulated loss and the state after
            the last token.
        """
        nll = 0.
        state = hc
        for token in seq:
            # Score every row of the embedding matrix against the current h,
            # then charge the log-probability of the observed token.
            log_probs = jax.nn.log_softmax(self.embeddings @ state[0])
            nll = nll - log_probs[token]
            state = self.cell(self.embeddings[token, :], *state)
        return nll, state

    def greedy_argmax(self, hc, length=6):
        """Decode ``length`` tokens, always taking the highest-scoring one.

        Args:
            hc: ``(h, c)`` state to start decoding from.
            length: number of tokens to produce.

        Returns:
            List of ``length`` Python ints.
        """
        state = hc
        picked = []
        for _ in range(length):
            best = jnp.argmax(self.embeddings @ state[0])
            picked.append(int(best))
            # Feed the chosen token back in as the next input.
            state = self.cell(self.embeddings[best, :], *state)
        return picked

# The cell is returned as a single child; its weights need not be flattened here.
# PytreeLSTMCell is itself a registered pytree, so JAX flattens/unflattens it recursively.
def flatten_whole_lstmlm(lm):
    """Pytree flatten hook for the language model.

    Args:
        lm: object exposing ``cell``, ``embeddings`` and ``c_0`` attributes.

    Returns:
        Pair ``(children, aux_data)``: the children are the cell, the
        embeddings and the initial cell state, in that order; there is no
        auxiliary data.
    """
    children = (lm.cell, lm.embeddings, lm.c_0)
    aux_data = None
    return children, aux_data

def unflatten_whole_lstmlm(aux, weights):
    """Pytree unflatten hook: rebuild a ``PytreeLSTMLM`` from its children.

    Args:
        aux: auxiliary data from flattening; accepted but not used.
        weights: the children, unpacked positionally into the constructor.

    Returns:
        A new ``PytreeLSTMLM`` instance.
    """
    rebuilt = PytreeLSTMLM(*weights)
    return rebuilt

# Make PytreeLSTMLM a pytree node, using the flatten/unflatten hooks defined above.
jax.tree_util.register_pytree_node(PytreeLSTMLM, flatten_whole_lstmlm, unflatten_whole_lstmlm)

In [3]:
vocab_size = 43
hid_dim = 17

# Each random tensor is drawn from its own fixed PRNG key, so re-running this
# cell rebuilds the same initial model. Biases and the initial cell state
# start at zero.
lm = PytreeLSTMLM(
    cell=PytreeLSTMCell(
        weight_ih=jax.random.uniform(jax.random.PRNGKey(1234), shape=(4 * hid_dim, hid_dim)),
        weight_hh=jax.random.uniform(jax.random.PRNGKey(4321), shape=(4 * hid_dim, hid_dim)),
        bias=jnp.zeros(4 * hid_dim),
    ),
    embeddings=jax.random.uniform(jax.random.PRNGKey(123), shape=(vocab_size, hid_dim)),
    c_0=jnp.zeros(hid_dim),
)

In [14]:
# Inspect how JAX sees the model as a pytree; leaves are printed as `*`.
jax.tree_util.tree_structure(lm)

PyTreeDef(CustomNode(PytreeLSTMLM[None], [CustomNode(PytreeLSTMCell[None], [*, *, *]), *, *]))

In [4]:
# The single training sequence, as token indices (used to index embedding rows).
training_data = jnp.array([4, 8, 15, 16, 23, 42])

In [5]:
def update_combiner(p, g, lr=0.1):
    """Gradient-descent update for one parameter leaf: ``p - lr * g``.

    Args:
        p: current parameter value.
        g: gradient with respect to ``p``.
        lr: step size. The default, 0.1, is the value that used to be
            hard-coded, so two-argument calls behave as before.

    Returns:
        The updated parameter value.
    """
    return p - lr * g

In [6]:
def pure_loss_fn(lm, seq, hc):
    """Evaluate the model's loss on ``seq``.

    Args:
        lm: model exposing ``hc_0`` and ``forward(seq, hc)``.
        seq: token sequence handed to ``lm.forward``.
        hc: recurrent state to start from, or ``None`` to start from
            ``lm.hc_0``.

    Returns:
        Tuple ``(loss, hc)`` as produced by ``lm.forward``.
    """
    start_state = lm.hc_0 if hc is None else hc
    loss, end_state = lm.forward(seq, start_state)
    return loss, end_state

# Differentiate the loss with respect to the model (the first argument). The
# recurrent state returned next to the loss is passed through as auxiliary
# output and is not differentiated.
grad_fn = jax.value_and_grad(pure_loss_fn, has_aux=True)

print("Sample before:", lm.greedy_argmax(lm.hc_0))

bptt_length = 3   # tokens per gradient step
num_epochs = 101
log_every = 50    # report the summed loss on every 50th epoch
for epoch in range(num_epochs):
    should_log = epoch % log_every == 0
    totalloss = 0.
    hc = None  # None makes pure_loss_fn start each epoch from lm.hc_0
    for start in range(0, len(training_data), bptt_length):
        batch = training_data[start:start + bptt_length]
        # The state from the previous chunk is carried forward as the
        # starting state of the next one.
        (loss, hc), grad_lm = grad_fn(lm, batch, hc)
        if should_log:
            totalloss += loss.item()
        # `jax.tree_map` is deprecated and removed in newer JAX releases;
        # `jax.tree_util.tree_map` is the equivalent that also works on older ones.
        lm = jax.tree_util.tree_map(update_combiner, lm, grad_lm)
    # Decide on the epoch number, not on the truthiness of the loss, so a
    # loss of exactly 0 would still be reported.
    if should_log:
        print("Loss:", totalloss)

print("Sample after:", lm.greedy_argmax(lm.hc_0))

Sample before: [0, 25, 25, 25, 25, 25]
Loss: 26.58103370666504
Loss: 4.404338717460632
Loss: 2.6979196071624756
Sample after: [4, 8, 15, 16, 23, 42]
