In [1]:
import jax
import jax.numpy as np
from jax.experimental import stax
import numpyro
import numpyro.distributions as dist

In [2]:
def emitter(hidden_dim1, hidden_dim2, out_dim):
    """Build the emission network as a stax ``(init_fun, apply_fun)`` pair.

    Architecture: Dense(hidden_dim1) -> Relu -> Dense(hidden_dim2) -> Relu
    -> Dense(out_dim) -> Sigmoid, so every output entry lies in [0, 1].

    Args:
        hidden_dim1: width of the first hidden Dense layer.
        hidden_dim2: width of the second hidden Dense layer.
        out_dim: size of the last axis of the output.

    Returns:
        ``(init_fun, apply_fun)`` as produced by ``stax.serial``;
        ``init_fun(rng, input_shape)`` reports output shape
        ``input_shape[:-1] + (out_dim,)``.
    """
    # stax.Relu / stax.Sigmoid are already (init_fun, apply_fun) pairs created by
    # stax.elementwise, so they are passed uncalled (calling them raises TypeError).
    return stax.serial(
        stax.Dense(hidden_dim1), stax.Relu,
        stax.Dense(hidden_dim2), stax.Relu,
        stax.Dense(out_dim), stax.Sigmoid,
    )

In [4]:
def transition(gate_hidden_dim, prop_mean_hidden_dim, out_dim):
    """Build a gated transition network as a stax ``(init_fun, apply_fun)`` pair.

    For an input ``z`` the network computes
        g     = Sigmoid(Dense(out_dim)(Relu(Dense(gate_hidden_dim)(z))))   # gate
        h     = Dense(out_dim)(Relu(Dense(prop_mean_hidden_dim)(z)))       # proposed mean
        mu    = (1 - g) * Dense(out_dim)(z) + g * h
        sigma = Softplus(Dense(out_dim)(Relu(h)))
    and stacks ``mu`` and ``sigma`` along a new trailing axis.

    Args:
        gate_hidden_dim: hidden width of the gate sub-network.
        prop_mean_hidden_dim: hidden width of the proposed-mean sub-network.
        out_dim: size of ``mu`` / ``sigma`` along their last axis.

    Returns:
        ``(init_fun, apply_fun)``. ``init_fun(rng, input_shape)`` returns
        ``(input_shape[:-1] + (out_dim, 2), params)`` with
        ``params = (gate_params, prop_mean_params, mean_params, stddev_params)``.
        ``apply_fun(params, inputs)`` returns an array of that shape whose
        ``[..., 0]`` slice is ``mu`` and ``[..., 1]`` slice is ``sigma``
        (non-negative, via softplus).
    """
    # Note: stax.Relu / stax.Sigmoid / stax.Softplus are already
    # (init_fun, apply_fun) pairs, so they are passed uncalled.
    gate_init_fun, gate_apply_fun = stax.serial(
        stax.Dense(gate_hidden_dim), stax.Relu,
        stax.Dense(out_dim), stax.Sigmoid
    )

    prop_mean_init_fun, prop_mean_apply_fun = stax.serial(
        stax.Dense(prop_mean_hidden_dim), stax.Relu,
        stax.Dense(out_dim)
    )

    # Direct linear map of the input, mixed with the proposed mean by the gate.
    mean_init_fun, mean_apply_fun = stax.Dense(out_dim)

    # Operates on the proposed mean h, not on the raw input.
    stddev_init_fun, stddev_apply_fun = stax.serial(
        stax.Relu, stax.Dense(out_dim),
        stax.Softplus
    )

    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, 2)
        k1, k2, k3, k4 = jax.random.split(rng, num=4)
        _, gate_params = gate_init_fun(k1, input_shape)
        prop_mean_output_shape, prop_mean_params = prop_mean_init_fun(k2, input_shape)
        _, mean_params = mean_init_fun(k3, input_shape)
        # The stddev head consumes h, so it is initialised on h's shape.
        _, stddev_params = stddev_init_fun(k4, prop_mean_output_shape)
        return output_shape, (gate_params, prop_mean_params,
                              mean_params, stddev_params)

    def apply_fun(params, inputs, **kwargs):
        gate_params, prop_mean_params, mean_params, stddev_params = params
        gt = gate_apply_fun(gate_params, inputs)
        ht = prop_mean_apply_fun(prop_mean_params, inputs)
        # Gate interpolates between a linear map of the input and the proposed mean.
        mut = (1 - gt) * mean_apply_fun(mean_params, inputs) + gt * ht
        sigmat = stddev_apply_fun(stddev_params, ht)
        return np.stack([mut, sigmat], axis=-1)

    return init_fun, apply_fun