Skip to content

Commit

Permalink
update hmm and ucbadmit with new api (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and neerajprad committed May 29, 2019
1 parent 692788c commit 0262a16
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 35 deletions.
33 changes: 10 additions & 23 deletions examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
from jax.scipy.special import logsumexp

import numpyro.distributions as dist
from numpyro.diagnostics import summary
from numpyro.handlers import sample
from numpyro.hmc_util import initialize_model
from numpyro.mcmc import hmc
from numpyro.util import fori_collect
from numpyro.mcmc import mcmc


"""
Expand Down Expand Up @@ -117,21 +115,6 @@ def semi_supervised_hmm(transition_prior, emission_prior,
return sample('forward_log_prob', dist.Multinomial(logits=-log_prob), obs=0)


def run_inference(transition_prior, emission_prior, supervised_categories, supervised_words,
unsupervised_words, rng, args):
init_params, potential_fn, constrain_fn = initialize_model(
rng,
semi_supervised_hmm,
transition_prior, emission_prior, supervised_categories,
supervised_words, unsupervised_words,
)
init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS')
hmc_state = init_kernel(init_params, args.num_warmup)
hmc_states = fori_collect(args.num_samples, sample_kernel, hmc_state,
transform=lambda state: constrain_fn(state.z))
return hmc_states


def print_results(posterior, transition_prob, emission_prob):
header = semi_supervised_hmm.__name__ + ' - TRAIN'
columns = ['', 'ActualProb', 'Pred(p25)', 'Pred(p50)', 'Pred(p75)']
Expand Down Expand Up @@ -165,11 +148,15 @@ def main(args):
num_unsupervised_data=args.num_unsupervised,
)
print('Starting inference...')
zs = run_inference(transition_prior, emission_prior,
supervised_categories, supervised_words, unsupervised_words,
random.PRNGKey(2), args)
summary(zs)
print_results(zs, transition_prob, emission_prob)
init_params, potential_fn, constrain_fn = initialize_model(
random.PRNGKey(2),
semi_supervised_hmm,
transition_prior, emission_prior, supervised_categories,
supervised_words, unsupervised_words,
)
samples = mcmc(args.num_warmup, args.num_samples, init_params,
potential_fn=potential_fn, constrain_fn=constrain_fn)
print_results(samples, transition_prob, emission_prob)


if __name__ == '__main__':
Expand Down
22 changes: 10 additions & 12 deletions examples/ucbadmit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as onp

import jax.numpy as np
import jax.random as random
from jax import random, vmap
from jax.config import config as jax_config

import numpyro.distributions as dist
Expand Down Expand Up @@ -52,20 +52,18 @@
"""


# TODO: Remove broadcasting logic when support for `pyro.plate` is available.
def glmm(dept, male, applications, admit):
v_mu = sample('v_mu', dist.Normal(0, np.array([4., 1.])))

sigma = sample('sigma', dist.HalfNormal(np.ones(2)))
L_Rho = sample('L_Rho', dist.LKJCholesky(2))
scale_tril = np.expand_dims(sigma, axis=-1) * L_Rho
scale_tril = sigma[..., np.newaxis] * L_Rho
# non-centered parameterization
num_dept = len(onp.unique(dept))
z = sample('z', dist.Normal(np.zeros((num_dept, 2)), 1))
v = np.squeeze(np.matmul(np.expand_dims(scale_tril, axis=-3), np.expand_dims(z, axis=-1)),
axis=-1)
v = np.dot(scale_tril, z.T).T

logits = v_mu[..., :1] + v[..., dept, 0] + (v_mu[..., 1:] + v[..., dept, 1]) * male
logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
sample('admit', dist.Binomial(applications, logits=logits), obs=admit)


Expand All @@ -79,13 +77,10 @@ def run_inference(dept, male, applications, admit, rng, args):
return hmc_states


def predict(dept, male, applications, admit, z, rng):
header = glmm.__name__ + ' - TRAIN'
def predict(dept, male, applications, z, rng):
model = substitute(seed(glmm, rng), z)
model_trace = trace(model).get_trace(dept, male, applications, admit=None)
predictions = model_trace['admit']['fn'].probs
probs = admit / applications
print_results('=' * 30 + header + '=' * 30, predictions, dept, male, probs)
return model_trace['admit']['fn'].probs


def print_results(header, preds, dept, male, probs):
Expand All @@ -105,7 +100,10 @@ def main(args):
dept, male, applications, admit = fetch_train()
rng, rng_predict = random.split(random.PRNGKey(1))
zs = run_inference(dept, male, applications, admit, rng, args)
predict(dept, male, applications, admit, zs, rng_predict)
rngs = random.split(rng_predict, args.num_samples)
pred_probs = vmap(lambda z, rng: predict(dept, male, applications, z, rng))(zs, rngs)
header = '=' * 30 + 'glmm - TRAIN' + '=' * 30
print_results(header, pred_probs, dept, male, admit / applications)


if __name__ == '__main__':
Expand Down

0 comments on commit 0262a16

Please sign in to comment.