-
Notifications
You must be signed in to change notification settings - Fork 222
/
minipyro.py
67 lines (50 loc) · 2.15 KB
/
minipyro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
import argparse
from jax import random
import jax.numpy as jnp
from jax.random import PRNGKey
import numpyro
from numpyro import optim
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.util import fori_loop
def model(data):
    """Conjugate Normal model over a single latent location.

    The latent site ``"loc"`` has a ``Normal(0, 1)`` prior, and ``data`` is
    observed at site ``"obs"`` under ``Normal(loc, 1)``.

    Args:
        data: Observed values conditioned on at the ``"obs"`` site.
    """
    prior = dist.Normal(0.0, 1.0)
    loc = numpyro.sample("loc", prior)

    likelihood = dist.Normal(loc, 1.0)
    numpyro.sample("obs", likelihood, obs=data)
def guide(data):
    """Variational distribution: a Normal over the latent ``"loc"`` site.

    The location is the learnable parameter ``"guide_loc"``. The scale is
    learned in log space as ``"guide_scale_log"`` and exponentiated, so it
    stays positive while the optimizer works on an unconstrained value.
    Both parameters start at 0.0, which gives an initial scale of 1.

    Args:
        data: Not used here. SVI passes the same arguments to the model and
            the guide, so the guide accepts it as well.
    """
    loc_param = numpyro.param("guide_loc", 0.0)
    log_scale = numpyro.param("guide_scale_log", 0.0)
    numpyro.sample("loc", dist.Normal(loc_param, jnp.exp(log_scale)))
def main(args):
    """Fit the Normal guide to synthetic data with SVI and check the result.

    Args:
        args: Parsed command-line namespace. This function reads
            ``args.learning_rate`` (Adam step size) and ``args.num_steps``
            (number of SVI updates).

    Raises:
        AssertionError: If the learned ``guide_loc`` is not within the
            tolerance of the exact posterior mean.
    """
    true_loc = 3.0  # location used to generate the synthetic data
    num_data = 100  # number of synthetic observations
    tolerance = 0.1  # allowed gap between guide_loc and the exact answer

    # Give data generation and SVI initialisation their own PRNG keys.
    # Reusing one JAX key for both would correlate the two random streams.
    data_key, init_key = random.split(PRNGKey(0))

    # Generate some data.
    data = random.normal(data_key, shape=(num_data,)) + true_loc

    # Construct an SVI object so we can do variational inference on our
    # model/guide pair.
    adam = optim.Adam(args.learning_rate)
    svi = SVI(model, guide, adam, Trace_ELBO(num_particles=100))
    svi_state = svi.init(init_key, data)

    # Training loop: fori_loop carries the SVI state through num_steps updates.
    def body_fn(_step, state):
        # Only the updated state is carried forward; the step loss is unused.
        state, _loss = svi.update(state, data)
        return state

    svi_state = fori_loop(0, args.num_steps, body_fn, svi_state)

    # Report the final values of the variational parameters
    # in the guide after training.
    params = svi.get_params(svi_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # The model is conjugate: prior Normal(0, 1) on `loc` and unit-variance
    # Normal observations. The exact posterior mean is therefore
    # sum(data) / (N + 1). Comparing against it, rather than `true_loc`,
    # measures inference error only, not the noise in the data draw.
    # An explicit raise is used so the check still runs under `python -O`.
    posterior_loc = jnp.sum(data) / (data.shape[0] + 1)
    error = jnp.abs(params["guide_loc"] - posterior_loc)
    if not error < tolerance:  # written this way so a NaN error also fails
        raise AssertionError(
            "guide_loc = {} is not within {} of the exact posterior mean {}".format(
                params["guide_loc"], tolerance, posterior_loc
            )
        )
if __name__ == "__main__":
    # Check that the installed NumPyro version has the prefix this example
    # is pinned to.
    assert numpyro.__version__.startswith("0.9.2")

    cli_parser = argparse.ArgumentParser(description="Mini Pyro demo")
    # NOTE: main() never reads --full-pyro. The flag is accepted but has
    # no effect.
    cli_parser.add_argument(
        "-f", "--full-pyro", action="store_true", default=False
    )
    # Number of SVI update steps run by main()'s training loop.
    cli_parser.add_argument("-n", "--num-steps", default=1001, type=int)
    # Step size passed to the Adam optimizer in main().
    cli_parser.add_argument(
        "-lr", "--learning-rate", default=0.02, type=float
    )

    args = cli_parser.parse_args()
    main(args)