In [1]:
import jax
import jax.numpy as np
from jax.experimental import stax
import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import load_dataset, JSBCHORALES

In [2]:
# Load JSB Chorales and inspect a single batch: its array shape and the
# spread of (unpadded) sequence lengths within it.
init, get_batch = load_dataset(JSBCHORALES, batch_size=32)
ds_count, ds_indxs = init()
first_batch = get_batch(0, ds_indxs)
seqs, seqs_rev, lengths = first_batch
shortest, longest = min(lengths), max(lengths)
print("Sequences: ", seqs.shape)
print("Length min: ", shortest, "max: ", longest)

Sequences:  (32, 129, 4)
Length min:  33 max:  109


In [3]:
def emitter(hidden_dim1, hidden_dim2, out_dim):
    """Emission MLP mapping a latent vector to values in (0, 1).

    Architecture: Dense(hidden_dim1) -> ReLU -> Dense(hidden_dim2) -> ReLU
    -> Dense(out_dim) -> Sigmoid.

    Args:
        hidden_dim1: width of the first hidden layer.
        hidden_dim2: width of the second hidden layer.
        out_dim: size of the output layer.

    Returns:
        A stax ``(init_fun, apply_fun)`` pair.
    """
    # stax.Relu / stax.Sigmoid are already (init_fun, apply_fun) pairs created
    # by stax.elementwise -- they are layers, not layer constructors, so they
    # must not be called.
    return stax.serial(
        stax.Dense(hidden_dim1), stax.Relu,
        stax.Dense(hidden_dim2), stax.Relu,
        stax.Dense(out_dim), stax.Sigmoid,
    )

In [4]:
def transition(gate_hidden_dim, prop_mean_hidden_dim, out_dim):
    """Gated transition network producing a mean and a positive scale.

    For an input ``z`` the layer computes::

        g     = sigmoid(Dense(relu(Dense(z))))   # gate, in (0, 1)
        h     = Dense(relu(Dense(z)))            # proposed mean
        mu    = (1 - g) * Dense(z) + g * h       # gated blend with a linear map
        sigma = softplus(Dense(relu(h)))         # non-negative scale

    and returns ``stack([mu, sigma], axis=-1)``, whose shape is
    ``input_shape[:-1] + (out_dim, 2)``.

    Args:
        gate_hidden_dim: hidden width of the gate MLP.
        prop_mean_hidden_dim: hidden width of the proposed-mean MLP.
        out_dim: size of the output dimension of ``mu`` and ``sigma``.

    Returns:
        A stax ``(init_fun, apply_fun)`` pair.
    """
    # Note: stax.Relu / stax.Sigmoid / stax.Softplus are ready-made layers
    # (init_fun, apply_fun) tuples; calling them raises TypeError.
    gate_init_fun, gate_apply_fun = stax.serial(
        stax.Dense(gate_hidden_dim), stax.Relu,
        stax.Dense(out_dim), stax.Sigmoid
    )

    prop_mean_init_fun, prop_mean_apply_fun = stax.serial(
        stax.Dense(prop_mean_hidden_dim), stax.Relu,
        stax.Dense(out_dim)
    )

    mean_init_fun, mean_apply_fun = stax.Dense(out_dim)

    stddev_init_fun, stddev_apply_fun = stax.serial(
        stax.Relu, stax.Dense(out_dim),
        stax.Softplus
    )

    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, 2)
        k1, k2, k3, k4 = jax.random.split(rng, num=4)
        _, gate_params = gate_init_fun(k1, input_shape)
        prop_mean_output_shape, prop_mean_params = prop_mean_init_fun(k2, input_shape)
        _, mean_params = mean_init_fun(k3, input_shape)
        # The scale network consumes the proposed mean, so it is initialized
        # from the proposed-mean output shape rather than the input shape.
        _, stddev_params = stddev_init_fun(k4, prop_mean_output_shape)
        return output_shape, (gate_params, prop_mean_params,
                              mean_params, stddev_params)

    def apply_fun(params, inputs, **kwargs):
        gate_params, prop_mean_params, mean_params, stddev_params = params
        gt = gate_apply_fun(gate_params, inputs)
        ht = prop_mean_apply_fun(prop_mean_params, inputs)
        mut = (1 - gt) * mean_apply_fun(mean_params, inputs) + gt * ht
        sigmat = stddev_apply_fun(stddev_params, ht)
        return np.stack([mut, sigmat], axis=-1)

    return init_fun, apply_fun

In [5]:
def combiner(out_dim):
    """Combine a previous latent with an RNN hidden state by averaging.

    ``apply_fun(params, (prev_input, hidden_input))`` returns
    ``mean(stack([tanh(Dense(prev_input)), hidden_input], -1), -1)``, i.e. the
    element-wise average of the transformed previous latent and the hidden
    state. ``hidden_input`` must have the same shape as the Dense output
    (``prev_input.shape[:-1] + (out_dim,)``) because ``np.stack`` requires it.

    Args:
        out_dim: output size of the Dense layer applied to ``prev_input``.

    Returns:
        A stax ``(init_fun, apply_fun)`` pair. ``init_fun`` expects
        ``input_shape = (prev_shape, hidden_shape)``.
    """
    # stax.Tanh is already an (init_fun, apply_fun) layer; do not call it.
    prev_init_fun, prev_apply_fun = stax.serial(
        stax.Dense(out_dim),
        stax.Tanh
    )

    def init_fun(rng, input_shape):
        # hidden_shape is unused: the hidden state passes through unchanged.
        prev_shape, hidden_shape = input_shape
        output_shape, prev_params = prev_init_fun(rng, prev_shape)
        return output_shape, prev_params

    def apply_fun(params, inputs, **kwargs):
        prev_params = params
        prev_input, hidden_input = inputs
        prev_contr = prev_apply_fun(prev_params, prev_input)
        return np.mean(np.stack([prev_contr, hidden_input], axis=-1), axis=-1)

    return init_fun, apply_fun

In [6]:
def gru(hidden_dim, W_init=stax.glorot_normal()):
    """stax-style GRU layer with per-sequence length masking.

    ``init_fun(rng, (input_shape, hidden_shape))`` creates, for each of the
    update / reset / candidate paths, a bias-free hidden-to-hidden matrix of
    shape ``(hidden_shape[-1], hidden_dim)`` (drawn with ``W_init``) and an
    input ``Dense(hidden_dim)`` layer initialized from ``input_shape``.

    ``apply_fun(params, (inps, lengths, init_hidden))`` scans over the leading
    axis of ``inps`` (time-major, e.g. ``(T, B, D)``). At step ``i`` a
    sequence's hidden state is updated only while ``i < lengths``; after that
    the previous hidden state is carried forward unchanged. ``lengths`` must
    broadcast against the hidden state's leading dimensions, e.g. shape
    ``(B,)`` for ``init_hidden`` of shape ``(B, hidden_dim)``.

    Returns the ``jax.lax.scan`` result ``(final_hidden, all_hiddens)``, where
    ``all_hiddens`` stacks the per-step hidden states along axis 0.
    """
    # Inspired by https://github.com/google/jax/pull/2298
    input_update_init_fun, input_update_apply_fun = stax.Dense(hidden_dim)
    input_reset_init_fun, input_reset_apply_fun = stax.Dense(hidden_dim)
    input_output_init_fun, input_output_apply_fun = stax.Dense(hidden_dim)

    def init_fun(rng, input_shape):
        input_shape, hidden_shape = input_shape
        output_shape = input_shape[:-1] + (hidden_dim,)
        rng, k1, k2 = jax.random.split(rng, num=3)
        hidden_update_w = W_init(k1, (hidden_shape[-1], hidden_dim))
        _, input_update_params = input_update_init_fun(k2, input_shape)

        rng, k1, k2 = jax.random.split(rng, num=3)
        hidden_reset_w = W_init(k1, (hidden_shape[-1], hidden_dim))
        # The reset gate's Dense is applied to the input vectors, so it must be
        # initialized from input_shape (not hidden_shape).
        _, input_reset_params = input_reset_init_fun(k2, input_shape)

        rng, k1, k2 = jax.random.split(rng, num=3)
        hidden_output_w = W_init(k1, (hidden_shape[-1], hidden_dim))
        _, input_output_params = input_output_init_fun(k2, input_shape)

        return output_shape, (hidden_update_w, input_update_params,
                              hidden_reset_w, input_reset_params,
                              hidden_output_w, input_output_params)

    def apply_fun(params, inputs, **kwargs):
        (hidden_update_w, input_update_params,
         hidden_reset_w, input_reset_params,
         hidden_output_w, input_output_params) = params
        inps, lengths, init_hidden = inputs
        # lengths is per sequence (not per time step), so it is closed over
        # rather than scanned alongside the time-major inputs.
        lengths = np.asarray(lengths)

        def apply_fun_single(prev_hidden, inp):
            i, inpv = inp
            update_gate = stax.sigmoid(input_update_apply_fun(input_update_params, inpv) +
                                       np.dot(prev_hidden, hidden_update_w))
            reset_gate = stax.sigmoid(input_reset_apply_fun(input_reset_params, inpv) +
                                      np.dot(prev_hidden, hidden_reset_w))
            output_gate = update_gate * prev_hidden + (1 - update_gate) * np.tanh(
                input_output_apply_fun(input_output_params, inpv) +
                np.dot(reset_gate * prev_hidden, hidden_output_w))
            # Sequences already past their length keep their last hidden state.
            still_active = np.expand_dims(i < lengths, -1)
            hidden = np.where(still_active, output_gate, prev_hidden)
            return hidden, hidden

        steps = np.arange(inps.shape[0])
        return jax.lax.scan(apply_fun_single, init_hidden, (steps, inps))
    return init_fun, apply_fun

In [None]:
def model(seqs, seqs_rev, lengths, *,
          latent_dim=100, emission_dim=100, transition_dim=200, **kwargs):
    """Deep Markov Model -- work in progress, currently returns None.

    Args:
        seqs: batch-major sequences; the first two axes are read as
            ``(batch_size, max_seq_length)``.
        seqs_rev: reversed sequences (not used yet).
        lengths: sequence lengths (not used yet).
        latent_dim: size of the latent state ``z``.
        emission_dim: emitter hidden width (not used yet).
        transition_dim: transition hidden width (not used yet).
    """
    batch_size, max_seq_length, *_ = seqs.shape
    # Convert to time-major layout so a jax.lax.scan iterates over time steps.
    # swapaxes exchanges only the first two axes and works for any rank >= 2,
    # unlike np.transpose(seqs, axes=(0, 1)), which fails for rank-3 input.
    seqs = np.swapaxes(seqs, 0, 1)
    z0 = np.zeros((batch_size, latent_dim))

    def dmm_iter(prev_z, ):
        # TODO: implement one DMM step (transition -> sample z_t -> emit x_t).
        pass
    # TODO: scan dmm_iter over time and register the numpyro sample sites.
    pass