Skip to content

Commit

Permalink
Add simple RandomWalkKernel (#3311)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjankowiak committed Jan 5, 2024
1 parent 5d920aa commit f4a6168
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/source/pyro.infer.mcmc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ NUTS
:undoc-members:
:show-inheritance:

RandomWalkKernel
----------------

.. autoclass:: pyro.infer.mcmc.RandomWalkKernel
:members:
:undoc-members:
:show-inheritance:

BlockMassMatrix
---------------

Expand Down
2 changes: 2 additions & 0 deletions pyro/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc.hmc import HMC
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.mcmc.rwkernel import RandomWalkKernel
from pyro.infer.predictive import Predictive
from pyro.infer.renyi_elbo import RenyiELBO
from pyro.infer.rws import ReweightedWakeSleep
Expand Down Expand Up @@ -45,6 +46,7 @@
"MCMC",
"NUTS",
"Predictive",
"RandomWalkKernel",
"RBFSteinKernel",
"RenyiELBO",
"ReweightedWakeSleep",
Expand Down
2 changes: 2 additions & 0 deletions pyro/infer/mcmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from pyro.infer.mcmc.api import MCMC, StreamingMCMC
from pyro.infer.mcmc.hmc import HMC
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.mcmc.rwkernel import RandomWalkKernel

__all__ = [
"ArrowheadMassMatrix",
"BlockMassMatrix",
"HMC",
"MCMC",
"NUTS",
"RandomWalkKernel",
"StreamingMCMC",
]
143 changes: 143 additions & 0 deletions pyro/infer/mcmc/rwkernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
from collections import OrderedDict

import torch

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
from pyro.infer.mcmc.util import initialize_model


class RandomWalkKernel(MCMCKernel):
r"""
Simple gradient-free kernel that utilizes an isotropic gaussian random walk in the unconstrained
latent space of the model. The step size that controls the variance of the kernel is adapted during
warm-up with a simple adaptation scheme that targets a user-provided acceptance probability.
:param model: Python callable containing Pyro primitives.
:param float init_step_size: A positive float that controls the initial step size. Defaults to 0.1.
:param float target_accept_prob: The target acceptance probability used during adaptation of
the step size. Defaults to 0.234.
Example:
>>> true_coefs = torch.tensor([1., 2., 3.])
>>> data = torch.randn(2000, 3)
>>> dim = 3
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample()
>>>
>>> def model(data):
... coefs_mean = torch.zeros(dim)
... coefs = pyro.sample('beta', dist.Normal(coefs_mean, torch.ones(3)))
... y = pyro.sample('y', dist.Bernoulli(logits=(coefs * data).sum(-1)), obs=labels)
... return y
>>>
>>> hmc_kernel = RandomWalkKernel(model, init_step_size=0.2)
>>> mcmc = MCMC(hmc_kernel, num_samples=200, warmup_steps=100)
>>> mcmc.run(data)
>>> mcmc.get_samples()['beta'].mean(0) # doctest: +SKIP
tensor([ 0.9819, 1.9258, 2.9737])
"""

def __init__(
self, model, init_step_size: float = 0.1, target_accept_prob: float = 0.234
):
if not isinstance(init_step_size, float) or init_step_size <= 0.0:
raise ValueError("init_step_size must be a positive float.")

if (
not isinstance(target_accept_prob, float)
or target_accept_prob <= 0.0
or target_accept_prob >= 1.0
):
raise ValueError(
"target_accept_prob must be a float in the interval (0, 1)."
)

self.model = model
self.init_step_size = init_step_size
self.target_accept_prob = target_accept_prob

self._t = 0
self._log_step_size = math.log(init_step_size)
self._accept_cnt = 0
self._mean_accept_prob = 0.0
super().__init__()

def setup(self, warmup_steps, *args, **kwargs):
self._warmup_steps = warmup_steps
(
self._initial_params,
self.potential_fn,
self.transforms,
self._prototype_trace,
) = initialize_model(
self.model,
model_args=args,
model_kwargs=kwargs,
)
self._energy_last = self.potential_fn(self._initial_params)

def sample(self, params):
step_size = math.exp(self._log_step_size)
new_params = {
k: v + step_size * torch.randn(v.shape, dtype=v.dtype, device=v.device)
for k, v in params.items()
}
energy_proposal = self.potential_fn(new_params)
delta_energy = energy_proposal - self._energy_last

accept_prob = (-delta_energy).exp().clamp(max=1.0).item()
rand = pyro.sample(
"rand_t={}".format(self._t),
dist.Uniform(0.0, 1.0),
)
accepted = False
if rand < accept_prob:
accepted = True
params = new_params
self._energy_last = energy_proposal

if self._t <= self._warmup_steps:
adaptation_speed = max(0.001, 0.1 / math.sqrt(1 + self._t))
self._log_step_size += adaptation_speed * (
accept_prob - self.target_accept_prob
)

self._t += 1

if self._t > self._warmup_steps:
n = self._t - self._warmup_steps
if accepted:
self._accept_cnt += 1
else:
n = self._t

self._mean_accept_prob += (accept_prob - self._mean_accept_prob) / n

return params.copy()

@property
def initial_params(self):
return self._initial_params

@initial_params.setter
def initial_params(self, params):
self._initial_params = params

def logging(self):
return OrderedDict(
[
("step size", "{:.2e}".format(math.exp(self._log_step_size))),
("acc. prob", "{:.3f}".format(self._mean_accept_prob)),
]
)

def diagnostics(self):
return {
"acceptance rate": self._accept_cnt / (self._t - self._warmup_steps),
}
40 changes: 40 additions & 0 deletions tests/infer/mcmc/test_rwkernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch

import pyro
import pyro.distributions as dist
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc.rwkernel import RandomWalkKernel
from tests.common import assert_equal


def test_beta_bernoulli():
alpha = torch.tensor([1.1, 2.2])
beta = torch.tensor([1.1, 2.2])

def model(data):
p_latent = pyro.sample("p_latent", dist.Beta(alpha, beta))
with pyro.plate("data", data.shape[0], dim=-2):
pyro.sample("obs", dist.Bernoulli(p_latent), obs=data)

num_data = 5
true_probs = torch.tensor([0.9, 0.1])
data = dist.Bernoulli(true_probs).sample(sample_shape=(torch.Size((num_data,))))

kernel = RandomWalkKernel(model)
mcmc = MCMC(kernel, num_samples=2000, warmup_steps=500)
mcmc.run(data)
samples = mcmc.get_samples()

data_sum = data.sum(0)
alpha_post = alpha + data_sum
beta_post = beta + num_data - data_sum
expected_mean = alpha_post / (alpha_post + beta_post)
expected_var = (
expected_mean.pow(2) * beta_post / (alpha_post * (1 + alpha_post + beta_post))
)

assert_equal(samples["p_latent"].mean(0), expected_mean, prec=0.03)
assert_equal(samples["p_latent"].var(0), expected_var, prec=0.005)

0 comments on commit f4a6168

Please sign in to comment.