Skip to content

Commit

Permalink
Using torch transforms to run HMC on distributions with constrained s…
Browse files Browse the repository at this point in the history
…upport (#740)
  • Loading branch information
neerajprad authored and fritzo committed Mar 6, 2018
1 parent e3f0adf commit 5981247
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 10 deletions.
45 changes: 37 additions & 8 deletions pyro/infer/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import OrderedDict

import torch
from torch.distributions import biject_to, constraints

import pyro
import pyro.distributions as dist
Expand All @@ -28,12 +29,20 @@ class HMC(TraceKernel):
:param int num_steps: The number of discrete steps over which to simulate
Hamiltonian dynamics. The state at the end of the trajectory is
returned as the proposal.
:param dict transforms: Optional dictionary that specifies a transform
for a sample site with constrained support to unconstrained space. The
transform should be invertible, and implement `log_abs_det_jacobian`.
If not specified and the model has sites with constrained support,
automatic transformations will be applied, as specified in
:mod:`torch.distributions.constraint_registry`.
"""

def __init__(self, model, step_size=0.5, num_steps=3):
def __init__(self, model, step_size=0.5, num_steps=3, transforms=None):
self.model = model
self.step_size = step_size
self.num_steps = num_steps
self.transforms = {} if transforms is None else transforms
self._automatic_transform_enabled = True if transforms is None else False
self._reset()
super(HMC, self).__init__()

Expand All @@ -49,7 +58,17 @@ def _kinetic_energy(self, r):
return 0.5 * torch.sum(torch.stack([r[name]**2 for name in r]))

def _potential_energy(self, z):
return -self._get_trace(z).log_pdf()
# Since the model is specified in the constrained space, transform the
# unconstrained R.V.s `z` to the constrained space.
z_constrained = z.copy()
for name, transform in self.transforms.items():
z_constrained[name] = transform.inv(z_constrained[name])
trace = self._get_trace(z_constrained)
potential_energy = -trace.log_pdf()
# adjust by the jacobian for this transformation.
for name, transform in self.transforms.items():
potential_energy += transform.log_abs_det_jacobian(z_constrained[name], z[name]).sum()
return potential_energy

def _energy(self, z, r):
return self._kinetic_energy(r) + self._potential_energy(z)
Expand All @@ -63,9 +82,6 @@ def _reset(self):
self._prototype_trace = None

def _validate_trace(self, trace):
for name, node in trace.iter_stochastic_nodes():
if not node['fn'].reparameterized:
raise ValueError('Found non-reparameterized node in the model at site: {}'.format(name))
trace_log_pdf = trace.log_pdf()
if is_nan(trace_log_pdf) or is_inf(trace_log_pdf):
raise ValueError('Model specification incorrect - trace log pdf is NaN, Inf or 0.')
Expand All @@ -85,16 +101,25 @@ def setup(self, *args, **kwargs):
r_mu = torch.zeros_like(node['value'])
r_sigma = torch.ones_like(node['value'])
self._r_dist[name] = dist.Normal(mu=r_mu, sigma=r_sigma)
if node['fn'].support is not constraints.real and self._automatic_transform_enabled:
self.transforms[name] = biject_to(node['fn'].support).inv
self._validate_trace(self._prototype_trace)

def cleanup(self):
self._reset()

def sample(self, trace):
z = {name: node['value'] for name, node in trace.iter_stochastic_nodes()}
r = {name: pyro.sample('r_{}_t={}'.format(name, self._t), self._r_dist[name]) for name in self._r_dist}
z_new, r_new = velocity_verlet(z, r, self._potential_energy, self.step_size, self.num_steps)
# apply Metropolis correction
# automatically transform `z` to unconstrained space, if needed.
for name, transform in self.transforms.items():
z[name] = transform(z[name])
r = {name: pyro.sample('r_{}_t={}'.format(name, self._t), self._r_dist[name])
for name in self._r_dist}
z_new, r_new = velocity_verlet(z, r,
self._potential_energy,
self.step_size,
self.num_steps)
# apply Metropolis correction.
energy_proposal = self._energy(z_new, r_new)
energy_current = self._energy(z, r)
delta_energy = energy_proposal - energy_current
Expand All @@ -103,6 +128,10 @@ def sample(self, trace):
self._accept_cnt += 1
z = z_new
self._t += 1

# get trace with the constrained values for `z`.
for name, transform in self.transforms.items():
z[name] = transform.inv(z[name])
return self._get_trace(z)

def diagnostics(self):
Expand Down
6 changes: 6 additions & 0 deletions pyro/infer/mcmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth):

def sample(self, trace):
z = {name: node["value"] for name, node in trace.iter_stochastic_nodes()}
# automatically transform `z` to unconstrained space, if needed.
for name, transform in self.transforms.items():
z[name] = transform(z[name])
r = {name: pyro.sample("r_{}_t={}".format(name, self._t), self._r_dist[name]) for name in self._r_dist}

# Ideally, following a symplectic integrator trajectory, the energy is constant.
Expand Down Expand Up @@ -222,4 +225,7 @@ def sample(self, trace):
if is_accepted:
self._accept_cnt += 1
self._t += 1
# get trace with the constrained values for `z`.
for name, transform in self.transforms.items():
z[name] = transform.inv(z[name])
return self._get_trace(z)
59 changes: 58 additions & 1 deletion tests/infer/mcmc/test_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest
import torch
from torch.autograd import Variable
from torch.autograd import Variable, variable

import pyro
import pyro.distributions as dist
Expand Down Expand Up @@ -177,3 +177,60 @@ def model(data):
posterior.append(trace.nodes['beta']['value'])
posterior_mean = torch.mean(torch.stack(posterior), 0)
assert_equal(rmse(true_coefs, posterior_mean).item(), 0.0, prec=0.05)


def test_bernoulli_beta():
def model(data):
alpha = pyro.param('alpha', variable([1.1, 1.1], requires_grad=True))
beta = pyro.param('beta', variable([1.1, 1.1], requires_grad=True))
p_latent = pyro.sample('p_latent', dist.Beta(alpha, beta))
pyro.observe('obs', dist.Bernoulli(p_latent), data)
return p_latent

hmc_kernel = HMC(model, step_size=0.02, num_steps=3)
mcmc_run = MCMC(hmc_kernel, num_samples=800, warmup_steps=500)
posterior = []
true_probs = variable([0.9, 0.1])
data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,))))
for trace, _ in mcmc_run._traces(data):
posterior.append(trace.nodes['p_latent']['value'])
posterior_mean = torch.mean(torch.stack(posterior), 0)
assert_equal(posterior_mean, true_probs, prec=0.01)


def test_normal_gamma():
def model(data):
rate = pyro.param('rate', variable([1.0, 1.0], requires_grad=True))
concentration = pyro.param('conc', variable([1.0, 1.0], requires_grad=True))
p_latent = pyro.sample('p_latent', dist.Gamma(rate, concentration))
pyro.observe("obs", dist.Normal(3, p_latent), data)
return p_latent

hmc_kernel = HMC(model, step_size=0.01, num_steps=3)
mcmc_run = MCMC(hmc_kernel, num_samples=200, warmup_steps=100)
posterior = []
true_std = variable([0.5, 2])
data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,))))
for trace, _ in mcmc_run._traces(data):
posterior.append(trace.nodes['p_latent']['value'])
posterior_mean = torch.mean(torch.stack(posterior), 0)
assert_equal(posterior_mean, true_std, prec=0.02)


@pytest.mark.xfail(reason='log_abs_det_jacobian not implemented for StickBreakingTransform')
def test_categorical_dirichlet():
def model(data):
concentration = pyro.param('conc', variable([1.0, 1.0, 1.0], requires_grad=True))
p_latent = pyro.sample('p_latent', dist.Dirichlet(concentration))
pyro.observe("obs", dist.Categorical(p_latent), data)
return p_latent

hmc_kernel = HMC(model, step_size=0.01, num_steps=3)
mcmc_run = MCMC(hmc_kernel, num_samples=200, warmup_steps=100)
posterior = []
true_probs = variable([0.1, 0.6, 0.3])
data = dist.Categorical(true_probs).sample(sample_shape=(torch.Size((2000,))))
for trace, _ in mcmc_run._traces(data):
posterior.append(trace.nodes['p_latent']['value'])
posterior_mean = torch.mean(torch.stack(posterior), 0)
assert_equal(posterior_mean, true_probs, prec=0.02)
40 changes: 39 additions & 1 deletion tests/infer/mcmc/test_nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest
import torch
from torch.autograd import Variable
from torch.autograd import Variable, variable

import pyro
import pyro.distributions as dist
Expand Down Expand Up @@ -91,3 +91,41 @@ def model(data):
posterior.append(trace.nodes['beta']['value'])
posterior_mean = torch.mean(torch.stack(posterior), 0)
assert_equal(rmse(true_coefs, posterior_mean).item(), 0.0, prec=0.05)


def test_bernoulli_beta():
def model(data):
alpha = pyro.param('alpha', Variable(torch.Tensor([1.1, 1.1]), requires_grad=True))
beta = pyro.param('beta', Variable(torch.Tensor([1.1, 1.1]), requires_grad=True))
p_latent = pyro.sample("p_latent", dist.Beta(alpha, beta))
pyro.observe("obs", dist.Bernoulli(p_latent), data)
return p_latent

nuts_kernel = NUTS(model, step_size=0.02)
mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=100)
posterior = []
true_probs = Variable(torch.Tensor([0.9, 0.1]))
data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((1000,))))
for trace, _ in mcmc_run._traces(data):
posterior.append(trace.nodes['p_latent']['value'])
posterior_mean = torch.mean(torch.stack(posterior), 0)
assert_equal(posterior_mean.data, true_probs.data, prec=0.01)


def test_normal_gamma():
def model(data):
rate = pyro.param('rate', variable([1.0, 1.0], requires_grad=True))
concentration = pyro.param('conc', variable([1.0, 1.0], requires_grad=True))
p_latent = pyro.sample('p_latent', dist.Gamma(rate, concentration))
pyro.observe("obs", dist.Normal(3, p_latent), data)
return p_latent

nuts_kernel = NUTS(model, step_size=0.01)
mcmc_run = MCMC(nuts_kernel, num_samples=200, warmup_steps=100)
posterior = []
true_std = variable([0.5, 2])
data = dist.Normal(3, true_std).sample(sample_shape=(torch.Size((2000,))))
for trace, _ in mcmc_run._traces(data):
posterior.append(trace.nodes['p_latent']['value'])
posterior_mean = torch.mean(torch.stack(posterior), 0)
assert_equal(posterior_mean, true_std, prec=0.02)

0 comments on commit 5981247

Please sign in to comment.