In [1]:
%matplotlib inline

import time

import matplotlib.pyplot as plt
import numpy as onp
import seaborn as sns; sns.set(palette="bright")
import tqdm

import jax.numpy as np
from jax import jit, random
from jax.config import config; config.update("jax_platform_name", "cpu")
from jax.scipy.special import logit, logsumexp
from jax.tree_util import tree_map, tree_multimap

import numpyro.distributions as dist
from numpyro.handlers import sample
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc

In [2]:
# Problem size: number of hidden categories, vocabulary size, and how many
# items are labelled (supervised) vs. unlabelled (unsupervised).
num_categories, num_words = 3, 10
num_supervised_data, num_unsupervised_data = 1, 5

# Seeded PRNG; split off independent keys for the ground-truth parameters.
rng = random.PRNGKey(1)
rng, rng_transition, rng_emission = random.split(rng, 3)

# Dirichlet concentrations: flat over next categories, 0.1 (< 1) over words.
transition_prior = np.ones(num_categories)
emission_prior = np.full((num_words,), 0.1)

# Ground-truth parameters: one Dirichlet draw per category.
transition_prob = dist.dirichlet.rvs(transition_prior, size=num_categories,
                                     random_state=rng_transition)
emission_prob = dist.dirichlet.rvs(emission_prior, size=num_categories,
                                   random_state=rng_emission)

In [3]:
def equilibrium(mc_matrix):
    """Return the stationary distribution of a row-stochastic transition matrix.

    Uses the identity pi = 1^T (I - P + J)^{-1}, where J is the all-ones
    matrix. It is computed here as the row sums of the inverse of
    (I - P^T + J).
    """
    size = mc_matrix.shape[0]
    # Adding the scalar 1 broadcasts it over every entry, i.e. adds J.
    system = np.identity(size) - mc_matrix.T + 1
    inverse = onp.linalg.inv(system)
    return np.sum(inverse, axis=-1)

start_prob = equilibrium(transition_prob)

# Simulate a supervised sequence followed by an unsupervised sequence.
# Each sequence starts from the chain's stationary distribution. Within a
# sequence, categories follow `transition_prob` and each category emits one word.
categories, words = [], []
for t in range(num_supervised_data + num_unsupervised_data):
    rng, rng_transition, rng_emission = random.split(rng, 3)
    if t == 0 or t == num_supervised_data:
        # First item of a sequence: draw from the stationary distribution.
        category = dist.categorical.rvs(start_prob, random_state=rng_transition)
    else:
        category = dist.categorical.rvs(transition_prob[category], random_state=rng_transition)
    word = dist.categorical.rvs(emission_prob[category], random_state=rng_emission)
    categories.append(category)
    words.append(word)
categories, words = np.stack(categories), np.stack(words)

# Split into supervised data (categories and words observed) and
# unsupervised data (only words observed). The word arrays must be sliced
# from `words`, not `categories`.
supervised_categories = categories[:num_supervised_data]
supervised_words = words[:num_supervised_data]
unsupervised_words = words[num_supervised_data:]
assert supervised_words.shape[0] == num_supervised_data
assert unsupervised_words.shape[0] == num_unsupervised_data

In [4]:
def plot_posterior(posterior):
    """Plot a KDE of the posterior samples of every `transition_prob[i, j]`.

    Each curve is labelled with the ground-truth value taken from the global
    `transition_prob` that was used to simulate the data.

    Args:
        posterior: one of two forms.
            - A dict of stacked samples with a "transition_prob" entry, expected
              shape (num_samples, num_categories, num_categories). This is the
              dict form produced by the `transform` passed to `mcmc` in this
              notebook.
            - An object exposing `.marginal([...]).support()`, the interface the
              original version of this function used. Presumably a Pyro-style
              posterior; confirm.
    """
    if isinstance(posterior, dict):
        trace_transition_prob = posterior["transition_prob"]
    else:
        # Build the marginal distribution for `transition_prob` and take its support.
        marginal = posterior.marginal(["transition_prob"])
        trace_transition_prob = marginal.support()["transition_prob"]  # shape: num_samples x 3 x 3

    plt.figure(figsize=(10, 6))
    for i in range(num_categories):
        for j in range(num_categories):
            sns.distplot(trace_transition_prob[:, i, j], hist=False, kde_kws={"lw": 2},
                         label="transition_prob[{}, {}], true value = {:.2f}"
                         .format(i, j, transition_prob[i, j]))
    plt.xlabel("Probability", fontsize=13)
    # hist=False draws only a KDE, so the y-axis is a density, not a count.
    plt.ylabel("Density", fontsize=13)
    plt.title("Transition probability posterior", fontsize=15)

In [5]:
def forward_log_prob(prev_log_prob, curr_word, transition_log_prob, emission_log_prob):
    """Advance the HMM forward recursion by one step, in log space.

    Args:
        prev_log_prob: log forward messages from the previous step, indexed by
            the previous category j.
        curr_word: index of the word observed at the current step.
        transition_log_prob: log transition matrix; entry [j, k] is the log
            probability of moving from category j to category k.
        emission_log_prob: log emission matrix; column `curr_word` holds the log
            probability of emitting that word from each category k.

    Returns:
        New log forward messages indexed by the current category k, equal to
        logsumexp_j(prev[j] + trans[j, k]) + emit[k, curr_word].
    """
    emit_current = emission_log_prob[:, curr_word]
    # scores[j, k] = emit[k, w] + trans[j, k] + prev[j]
    scores = (emit_current[None, :] + transition_log_prob) + prev_log_prob[:, None]
    # Marginalise out the previous category j.
    return logsumexp(scores, axis=0)

In [6]:
def semi_supervised_hmm(supervised_categories, supervised_words, unsupervised_words):
    """Semi-supervised HMM with Dirichlet priors on transition and emission rows.

    Latent sites:
        transition_prob: (num_categories, num_categories); one Dirichlet per row.
        emission_prob: (num_categories, num_words); one Dirichlet per row.

    Supervised data contribute observed `category_t` and `word_t` sites. The
    unsupervised words contribute their marginal likelihood, computed with the
    forward algorithm over the latent categories.

    Returns:
        The observed value (1) of the last `forward_prob*` site.
    """
    transition_prob = sample("transition_prob", dist.dirichlet(
        np.broadcast_to(transition_prior, (num_categories, num_categories))))
    emission_prob = sample("emission_prob", dist.dirichlet(
        np.broadcast_to(emission_prior, (num_categories, num_words))))

    # Supervised part: the first category is taken as given, with no initial
    # distribution; later categories follow the transitions. Reading
    # supervised_categories[0] inside the loop makes an empty supervised set a no-op.
    for t in range(len(supervised_words)):
        if t == 0:
            category = supervised_categories[0]
        else:
            category = sample("category_{}".format(t), dist.categorical(transition_prob[category]),
                              obs=supervised_categories[t])
        sample("word_{}".format(t), dist.categorical(emission_prob[category]),
               obs=supervised_words[t])

    # Unsupervised part: forward algorithm for log p(unsupervised_words | params).
    transition_log_prob = np.log(transition_prob)
    emission_log_prob = np.log(emission_prob)
    log_prob = emission_log_prob[:, unsupervised_words[0]]
    for t in range(1, len(unsupervised_words)):
        log_prob = forward_log_prob(log_prob, unsupervised_words[t],
                                    transition_log_prob, emission_log_prob)
    total_log_prob = logsumexp(log_prob, axis=0)

    # total_log_prob is added to the joint density through Bernoulli(p) sites
    # observed at 1; each such site contributes log p.
    # A single site with p = exp(total_log_prob) underflows in float32 once
    # total_log_prob drops below about -87. That happens for long sequences,
    # and the clip would then turn the term into a constant with zero gradient.
    # Spreading the log-likelihood evenly over one site per unsupervised word
    # keeps each p representable, and the site log-probabilities still sum to
    # total_log_prob.
    num_sites = len(unsupervised_words)
    site_prob = np.clip(np.exp(total_log_prob / num_sites),
                        a_min=np.finfo(log_prob.dtype).tiny)
    for t in range(num_sites):
        name = "forward_prob" if t == 0 else "forward_prob_{}".format(t)
        observed = sample(name, dist.bernoulli(site_prob), obs=1)
    return observed

In [7]:
def mcmc(sample_kernel, state, num_samples, transform):
    """Run `sample_kernel` for `num_samples` transitions and stack the outputs.

    Args:
        sample_kernel: function mapping a sampler state to the next state.
        state: initial sampler state.
        num_samples: number of transitions to run; must be positive.
        transform: applied to each new state after every leaf has been given a
            leading axis of size 1; should return a pytree of arrays.

    Returns:
        The pytree returned by `transform`, with each leaf concatenated along
        axis 0 across all transitions.

    Raises:
        ValueError: if `num_samples` is not positive.
    """
    if num_samples <= 0:
        raise ValueError("num_samples must be positive, got {}".format(num_samples))
    collected = []
    for _ in tqdm.tqdm(range(num_samples)):
        state = sample_kernel(state)
        collected.append(transform(tree_map(lambda x: np.expand_dims(x, axis=0), state)))
    # Concatenate once at the end. Concatenating inside the loop re-copies the
    # whole history on every iteration, which is O(num_samples^2) work.
    return tree_multimap(lambda *xs: np.concatenate(xs) if xs[0] is not None else None,
                         *collected)

In [11]:
# Trace `semi_supervised_hmm` on the simulated data (PRNG seed 2) to obtain the
# initial parameters, the potential function and the parameter transform used by HMC.
init_params, potential_fn, transform_fn = initialize_model(
    random.PRNGKey(2),
    semi_supervised_hmm,
    (supervised_categories, supervised_words, unsupervised_words),
    {},
)
# Build the NUTS sampler (init/sample kernel pair) from the potential function.
init_kernel, sample_kernel = hmc(potential_fn, algo="NUTS")

In [12]:
# Initialise the sampler state at `init_params` with no warmup and no step-size
# adaptation; the step size is instead set by hand below.
hmc_state, _, _ = init_kernel(init_params, num_warmup_steps=0, adapt_step_size=False,
                              run_warmup=False)

# JIT-compile the transition kernel; the first call below triggers compilation.
jsample_kernel = jit(sample_kernel)
start = time.time()
hmc_state = hmc_state.update(step_size=100.)  # HACK: force fast compiling!
# NOTE(review): presumably the huge step size makes the NUTS trajectory end almost
# immediately, so this timed call is dominated by compilation; confirm.
jsample_kernel(hmc_state)  # result discarded: this call only compiles/warms up
hmc_state = hmc_state.update(step_size=1.)
print("time to compile sample_kernel:", time.time() - start)

time to compile sample_kernel: 10.987462282180786


In [13]:
# Draw posterior samples. Each transition keeps only `transition_prob` (passed
# through transform_fn) and the number of leapfrog steps it used.
num_samples = 100
start = time.time()
hmc_states = mcmc(jsample_kernel, hmc_state, num_samples,
                  transform=lambda state: {"transition_prob": transform_fn(state.z)["transition_prob"],
                                           "num_steps": state.num_steps})
# Total leapfrog steps across all transitions.
num_leapfrogs = np.sum(hmc_states["num_steps"]).copy()
print("number of leapfrog steps:", num_leapfrogs)
# NOTE(review): the elapsed time also includes the transform and stacking work
# inside `mcmc`, so this over-estimates the cost of a single leapfrog step.
print("avg. time for each step :", (time.time() - start) / num_leapfrogs)

100%|██████████| 100/100 [00:09<00:00, 10.85it/s]

number of leapfrog steps: 100
avg. time for each step : 0.09226186275482177





### Stan

In [1]:
# `import urllib` alone does not load the `urllib.request` submodule; import it
# explicitly. This still binds the name `urllib`.
import urllib.request

import pystan

url = "https://raw.githubusercontent.com/stan-dev/example-models/master/misc/hmm/hmm-semisup.stan"
# Download the reference Stan program; `with` closes the HTTP response.
with urllib.request.urlopen(url) as response:
    stan_model = response.read().decode("utf-8")
print(stan_model)

data {
  int<lower=1> K;  // num categories
  int<lower=1> V;  // num words
  int<lower=0> T;  // num supervised items
  int<lower=1> T_unsup;  // num unsupervised items
  int<lower=1,upper=V> w[T]; // words
  int<lower=1,upper=K> z[T]; // categories
  int<lower=1,upper=V> u[T_unsup]; // unsup words
  vector<lower=0>[K] alpha;  // transit prior
  vector<lower=0>[V] beta;   // emit prior
}
parameters {
  simplex[K] theta[K];  // transit probs
  simplex[V] phi[K];    // emit probs
}
model {
  for (k in 1:K) 
    theta[k] ~ dirichlet(alpha);
  for (k in 1:K)
    phi[k] ~ dirichlet(beta);
  for (t in 1:T)
    w[t] ~ categorical(phi[z[t]]);
  for (t in 2:T)
    z[t] ~ categorical(theta[z[t-1]]);

  { 
    // forward algorithm computes log p(u|...)
    real acc[K];
    real gamma[T_unsup,K];
    for (k in 1:K)
      gamma[1,k] <- log(phi[k,u[1]]);
    for (t in 2:T_unsup) {
      for (k in 1:K) {
        for (j in 1:K)
          acc[j] <- gamma[t-1,j] + log(theta[j,k]) + log(phi[k,u[t]]);
  

In [2]:
%%time
# Compile the Stan program to a C++ model; this step is slow (tens of seconds of wall time).
model = pystan.StanModel(model_code=stan_model)

INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_14c781acf8dbbaa6f89d694cd310dd75 NOW.


CPU times: user 903 ms, sys: 40.6 ms, total: 944 ms
Wall time: 50 s


In [12]:
# Inputs for the Stan `data` block. Stan indexes categories and words from 1,
# hence the `+ 1` on the 0-based index arrays.
data = dict(
    K=num_categories,
    V=num_words,
    T=num_supervised_data,
    T_unsup=num_unsupervised_data,
    alpha=transition_prior,
    beta=emission_prior,
    w=supervised_words + 1,
    z=supervised_categories + 1,
    u=unsupervised_words + 1,
)

{'K': 3,
 'V': 10,
 'T': 100,
 'T_unsup': 500,
 'alpha': DeviceArray{float32[3]},
 'beta': DeviceArray{float32[10]},
 'w': DeviceArray{int32[100]},
 'z': DeviceArray{int32[100]},
 'u': DeviceArray{int32[500]}}

In [13]:
%%time
# Run Stan's sampler on the same data: a single chain with 200 iterations in total.
fit = model.sampling(data, chains=1, iter=200)



CPU times: user 31.6 s, sys: 68.7 ms, total: 31.7 s
Wall time: 31.6 s
