-
Notifications
You must be signed in to change notification settings - Fork 223
/
funnel.py
155 lines (127 loc) · 4.64 KB
/
funnel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
"""
Example: Neal's Funnel
======================
This example, which is adapted from [1], illustrates how to leverage non-centered
parameterization using the :class:`~numpyro.handlers.reparam` handler.
We will examine the difference between two types of parameterizations on the
10-dimensional Neal's funnel distribution. As we will see, HMC runs into trouble at
the neck of the funnel if centered parameterization is used. On the contrary,
the problem can be solved by using non-centered parameterization.
Using non-centered parameterization through :class:`~numpyro.infer.reparam.LocScaleReparam`
or :class:`~numpyro.infer.reparam.TransformReparam` in NumPyro has the same effect as
the automatic reparameterisation technique introduced in [2].
**References:**
1. *Stan User's Guide*, https://mc-stan.org/docs/2_19/stan-users-guide/reparameterization-section.html
2. Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019), "Automatic
Reparameterisation of Probabilistic Programs", (https://arxiv.org/abs/1906.03028)
.. image:: ../_static/img/examples/funnel.png
:align: center
"""
import argparse
import os
import matplotlib.pyplot as plt
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam
def model(dim=10):
    """Neal's funnel in its centered form.

    Draws ``y ~ Normal(0, 3)`` and then ``x ~ Normal(0, exp(y / 2))``, where
    ``x`` has ``dim - 1`` components that all share the scale set by ``y``.

    :param dim: total dimension of the funnel (``y`` plus ``dim - 1``
        components of ``x``).
    """
    log_variance = numpyro.sample("y", dist.Normal(0, 3))
    x_loc = jnp.zeros(dim - 1)
    x_scale = jnp.exp(log_variance / 2)
    numpyro.sample("x", dist.Normal(x_loc, x_scale))
# Same model, with site "x" rewritten by LocScaleReparam; this is the
# non-centered parameterization that the example compares against ``model``.
reparam_model = reparam(
    model,
    config={"x": LocScaleReparam(centered=0)},
)
def run_inference(model, args, rng_key):
    """Run NUTS on ``model``, print a summary and return the ``MCMC`` object.

    :param model: model callable handed to the ``NUTS`` kernel.
    :param args: namespace providing ``num_warmup``, ``num_samples`` and
        ``num_chains``.
    :param rng_key: PRNG key passed to ``MCMC.run``.
    :return: the ``MCMC`` instance after sampling has finished.
    """
    # The progress bar is turned off whenever NUMPYRO_SPHINXBUILD is present
    # in the environment.
    show_progress = "NUMPYRO_SPHINXBUILD" not in os.environ
    sampler = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=show_progress,
    )
    sampler.run(rng_key)
    # Deterministic sites are kept in the printed summary.
    sampler.print_summary(exclude_deterministic=False)
    return sampler
def main(args):
    """Sample the funnel under both parameterizations and plot the draws.

    Runs ``run_inference`` on the centered ``model`` and on ``reparam_model``
    with the same PRNG key, then plots ``y`` against ``x[0]`` for each run on
    two stacked axes, separating draws by the ``"diverging"`` extra field, and
    saves the figure to ``funnel_plot.pdf``.

    :param args: namespace providing ``num_warmup``, ``num_samples`` and
        ``num_chains`` (forwarded to ``run_inference``).
    """
    rng_key = random.PRNGKey(0)

    def scatter(axis, draws, mask, **style):
        # Plot y against the first coordinate of x for the draws picked by mask.
        axis.plot(draws["x"][mask, 0], draws["y"][mask], "o", **style)

    # do inference with centered parameterization
    print(
        "============================= Centered Parameterization =============================="
    )
    centered_mcmc = run_inference(model, args, rng_key)
    centered_draws = centered_mcmc.get_samples()
    centered_div = centered_mcmc.get_extra_fields()["diverging"]

    # do inference with non-centered parameterization
    print(
        "\n=========================== Non-centered Parameterization ============================"
    )
    noncentered_mcmc = run_inference(reparam_model, args, rng_key)
    noncentered_draws = noncentered_mcmc.get_samples()
    noncentered_div = noncentered_mcmc.get_extra_fields()["diverging"]

    # collect deterministic sites
    predictive = Predictive(
        reparam_model, noncentered_draws, return_sites=["x", "y"]
    )
    noncentered_draws = predictive(random.PRNGKey(1))

    # make plots
    fig, (top_ax, bottom_ax) = plt.subplots(
        nrows=2, ncols=1, sharex=True, figsize=(8, 8), constrained_layout=True
    )
    # Both panels use the same window so the two runs are directly comparable.
    shared_limits = {"xlim": (-20, 20), "ylim": (-9, 9)}

    scatter(
        top_ax,
        centered_draws,
        ~centered_div,
        color="darkred",
        alpha=0.3,
        label="Non-diverging",
    )
    scatter(top_ax, centered_draws, centered_div, color="lime", label="Diverging")
    top_ax.set(
        **shared_limits,
        ylabel="y",
        title="Funnel samples with centered parameterization",
    )
    top_ax.legend()

    scatter(
        bottom_ax,
        noncentered_draws,
        ~noncentered_div,
        color="darkred",
        alpha=0.3,
    )
    scatter(bottom_ax, noncentered_draws, noncentered_div, color="lime")
    bottom_ax.set(
        **shared_limits,
        xlabel="x[0]",
        ylabel="y",
        title="Funnel samples with non-centered parameterization",
    )

    plt.savefig("funnel_plot.pdf")
if __name__ == "__main__":
    # The example is pinned to the NumPyro release it was written against.
    assert numpyro.__version__.startswith("0.9.2")
    parser = argparse.ArgumentParser(
        description="Non-centered reparameterization example"
    )
    # With nargs="?", a flag given without a value stores `const`, which is
    # None unless set explicitly. Mirroring the default in `const` makes a
    # bare flag fall back to the default instead of handing None to MCMC.
    parser.add_argument(
        "-n", "--num-samples", nargs="?", const=1000, default=1000, type=int
    )
    parser.add_argument(
        "--num-warmup", nargs="?", const=1000, default=1000, type=int
    )
    parser.add_argument("--num-chains", nargs="?", const=1, default=1, type=int)
    parser.add_argument(
        "--device", default="cpu", type=str, help='use "cpu" or "gpu".'
    )
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    # Expose one host device per requested chain.
    numpyro.set_host_device_count(args.num_chains)

    main(args)