# Deep Markov Model
Based on "Structured Inference Networks for Nonlinear State Space Models" by Krishnan, Shalit and Sontag. (AAAI 2017)

In [1]:
import jax
import jax.numpy as np
from jax.experimental import stax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, ELBO
from numpyro.optim import Adam
from numpyro.examples.datasets import load_dataset, JSBCHORALES

In [2]:
init, get_batch = load_dataset(JSBCHORALES, batch_size=32)
ds_count, ds_indxs = init()
seqs, seqs_rev, lengths = get_batch(0, ds_indxs)
print("Sequences: ", seqs.shape)
print("Length min: ", min(lengths), "max: ", max(lengths))

Sequences:  (32, 129, 4)
Length min:  33 max:  108


In [3]:
def _one_hot_chorales(seqs, num_nodes=88):
    return np.sum(np.array((seqs[..., None] == np.arange(num_nodes + 1)), 'int'),axis=-2)[..., 1:]
_one_hot_chorales(seqs[:, 0]).shape

(32, 88)

## DMM Neural Components

In [4]:
def Emitter(hidden_dim1, hidden_dim2, out_dim):
    return stax.serial(
        stax.Dense(hidden_dim1), stax.Relu,
        stax.Dense(hidden_dim2), stax.Relu,
        stax.Dense(out_dim), stax.Sigmoid
    )

In [5]:
def Transition(gate_hidden_dim, prop_mean_hidden_dim, out_dim):
    gate_init_fun, gate_apply_fun = stax.serial(
        stax.Dense(gate_hidden_dim), stax.Relu,
        stax.Dense(out_dim), stax.Sigmoid
    )

    prop_mean_init_fun, prop_mean_apply_fun = stax.serial(
        stax.Dense(prop_mean_hidden_dim), stax.Relu,
        stax.Dense(out_dim)
    )

    mean_init_fun, mean_apply_fun = stax.Dense(out_dim)

    stddev_init_fun, stddev_apply_fun = stax.serial(
        stax.Relu, stax.Dense(out_dim),
        stax.Softplus
    )

    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, 2)
        k1, k2, k3, k4 = jax.random.split(rng, num=4)
        _, gate_params = gate_init_fun(k1, input_shape)
        prop_mean_output_shape, prop_mean_params = prop_mean_init_fun(k2, input_shape)
        _, mean_params = mean_init_fun(k3, input_shape)
        _, stddev_params = stddev_init_fun(k4, prop_mean_output_shape)
        return output_shape, (gate_params, prop_mean_params, 
                              mean_params, stddev_params)

    def apply_fun(params, inputs, **kwargs):
        gate_params, prop_mean_params, mean_params, stddev_params = params
        gt = gate_apply_fun(gate_params, inputs)
        ht = prop_mean_apply_fun(prop_mean_params, inputs)
        mut = (1 - gt) * mean_apply_fun(mean_params, inputs) + gt * ht
        sigmat = stddev_apply_fun(stddev_params, ht)
        return mut, sigmat
    
    return init_fun, apply_fun

In [29]:
def Combiner(hidden_dim, out_dim):
    comb_init_fun, comb_apply_fun = stax.serial(
        stax.Dense(hidden_dim),
        stax.Tanh
    )

    mean_init_fun, mean_apply_fun = stax.Dense(out_dim)

    stddev_init_fun, stddev_apply_fun = stax.serial(
        stax.Dense(out_dim),
        stax.Softplus
    )

    def init_fun(rng, input_shape):
        prev_shape, hidden_shape = input_shape
        output_shape = input_shape[:-1] + (out_dim, 2)
        k1, k2, k3 = jax.random.split(rng, num=3)
        comb_shape, comb_params = comb_init_fun(k1, prev_shape)
        _, mean_params = mean_init_fun(k2, comb_shape)
        _, stddev_params = stddev_init_fun(k3, comb_shape)
        return output_shape, (comb_params, mean_params, stddev_params)

    def apply_fun(params, inputs, **kwargs):
        comb_params, mean_params, stddev_params = params
        prev_input, hidden_input = inputs
        print(prev_input.shape)
        print(hidden_input.shape)
        prev_contr = comb_apply_fun(comb_params, prev_input)
        comb = np.mean(np.stack([prev_contr, hidden_input], axis=-1), axis=-1)
        mut = mean_apply_fun(mean_params, comb)
        sigmat = stddev_apply_fun(stddev_params, comb)
        return mut, sigmat
    return init_fun, apply_fun

In [30]:
def GRU(hidden_dim, W_init=stax.glorot_normal()):
    # Inspired by https://github.com/google/jax/pull/2298
    input_update_init_fun, input_update_apply_fun = stax.Dense(hidden_dim)
    input_reset_init_fun, input_reset_apply_fun = stax.Dense(hidden_dim)
    input_output_init_fun, input_output_apply_fun = stax.Dense(hidden_dim)

    def init_fun(rng, input_shape):
        indv_input_shape = input_shape[1:]
        output_shape = input_shape[:-1] + (hidden_dim,)
        rng, k1, k2 = jax.random.split(rng, num=3)
        hidden_update_w = W_init(k1, (hidden_dim, hidden_dim))
        _, input_update_params = input_update_init_fun(k2, indv_input_shape)

        rng, k1, k2 = jax.random.split(rng, num=3)
        hidden_reset_w = W_init(k1, (hidden_dim, hidden_dim))
        _, input_reset_params = input_reset_init_fun(k2, indv_input_shape)

        rng, k1, k2 = jax.random.split(rng, num=3)
        hidden_output_w = W_init(k1, (hidden_dim, hidden_dim))
        _, input_output_params = input_output_init_fun(k2, indv_input_shape)

        return output_shape, (hidden_update_w, input_update_params,
                              hidden_reset_w, input_reset_params,
                              hidden_output_w, input_output_params)
    
    def apply_fun(params, inputs, **kwargs):
        (hidden_update_w, input_update_params,
         hidden_reset_w, input_reset_params,
         hidden_output_w, input_output_params) = params
        inps, lengths, init_hidden = inputs

        def apply_fun_single(prev_hidden, inp):
            i, inpv = inp
            inp_update = input_update_apply_fun(input_update_params, inpv)
            hidden_update = np.dot(prev_hidden, hidden_update_w)
            update_gate = stax.sigmoid(inp_update + hidden_update)
            reset_gate = stax.sigmoid(input_reset_apply_fun(input_reset_params, inpv) +
                                      np.dot(prev_hidden, hidden_reset_w))
            output_gate = update_gate * prev_hidden + (1 - update_gate) * np.tanh(
                input_output_apply_fun(input_output_params, inpv) + 
                np.dot(reset_gate * prev_hidden, hidden_output_w))
            hidden = np.where((i < lengths)[:, None], output_gate, prev_hidden)
            return hidden, hidden
        
        return jax.lax.scan(apply_fun_single, init_hidden, (np.arange(inps.shape[0]), inps))
    return init_fun, apply_fun

## Probabilistic model and guide

In [31]:
def model(seqs, seqs_rev, lengths, *,
          latent_dim=100, emission_dim=100, transition_dim=200,
          data_dim=88, gru_dim=400):
    batch_size, max_seq_length, *_ = seqs.shape
    seqs = np.transpose(seqs, axes=(0, 1))
    z0 = np.zeros((batch_size, latent_dim))
    transition = numpyro.module('transition', Transition(transition_dim, transition_dim, latent_dim),
                                input_shape=(batch_size, latent_dim))
    emitter = numpyro.module('emitter', Emitter(emission_dim, emission_dim, data_dim),
                                input_shape=(batch_size, latent_dim))
    def dmm_iter(prev_z, inp):
        i, x = inp
        zmean, zstddev = transition(prev_z)
        with numpyro.plate('data', batch_size):
            z = numpyro.sample(f'z_{i}', dist.Normal(zmean, zstddev))
            xprobs = emitter(z)
            oh_x = _one_hot_chorales(x)
            return z, numpyro.sample(f'x_{i}', dist.Bernoulli(xprobs).mask(i < lengths), obs=oh_x)
    return jax.lax.scan(dmm_iter, z0, (np.arange(seqs.shape[0]), seqs))

In [32]:
def guide(seqs, seqs_rev, lengths, *,
          latent_dim=100, emission_dim=100, transition_dim=200,
          data_dim=88, gru_dim=400):
    batch_size, max_seq_length, *_ = seqs.shape
    seqs = np.transpose(seqs, axes=(1, 0, 2))
    seqs_rev = np.transpose(seqs_rev, axes=(1, 0, 2))
    z0 = np.zeros((batch_size, latent_dim))
    h0 = numpyro.sample('h0', dist.Normal().expand((batch_size, gru_dim)))
    gru = numpyro.module('gru', GRU(gru_dim), input_shape=(max_seq_length, batch_size, data_dim))
    combiner = numpyro.module('combiner', Combiner(gru_dim, latent_dim),
                              input_shape=((batch_size, latent_dim), (batch_size, gru_dim)))
    _, hs = gru((_one_hot_chorales(seqs_rev), lengths, h0))
    def dmm_iter(prev_z, inp):
        i, h = inp
        with numpyro.plate('data', batch_size):
            zmean, zstddev = combiner((prev_z, h))
            z = numpyro.sample(f'z_{i}', dist.Normal(zmean, zstddev))
            return z, z
    return jax.lax.scan(dmm_iter, z0, (np.arange(hs.shape[0]), hs))

## Stochastic Variational Inference

In [33]:
svi = SVI(model, guide, ELBO(), Adam(8e-4))

In [34]:
rng_key = jax.random.PRNGKey(seed=142)
svi_state = svi.init(rng_key, seqs, seqs_rev, lengths)

(32, 400)
(32, 400)


TypeError: Incompatible shapes for dot: got (32, 400) and (100, 400).