Skip to content

Commit

Permalink
Add init_params argument to svi.init() and svi.run() (#1561)
Browse files Browse the repository at this point in the history
* init_params

* add docstring

* fixes
  • Loading branch information
ordabayevy committed Mar 24, 2023
1 parent f66ba4f commit f98bc4d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
13 changes: 11 additions & 2 deletions numpyro/infer/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,26 +169,32 @@ def __init__(self, model, guide, optim, loss, **static_kwargs):

self.optim = optax_to_numpyro(optim)

def init(self, rng_key, *args, **kwargs):
def init(self, rng_key, *args, init_params=None, **kwargs):
"""
Gets the initial SVI state.
:param jax.random.PRNGKey rng_key: random number generator seed.
:param args: arguments to the model / guide (these can possibly vary during
the course of fitting).
:param dict init_params: if not None, initialize :class:`numpyro.param` sites with values from
this dictionary instead of using ``init_value`` in :class:`numpyro.param` primitives.
:param kwargs: keyword arguments to the model / guide (these can possibly vary
during the course of fitting).
:return: the initial :data:`SVIState`
"""
rng_key, model_seed, guide_seed = random.split(rng_key, 3)
model_init = seed(self.model, model_seed)
guide_init = seed(self.guide, guide_seed)
if init_params is not None:
guide_init = substitute(guide_init, init_params)
guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
init_guide_params = {
name: site["value"]
for name, site in guide_trace.items()
if site["type"] == "param"
}
if init_params is not None:
init_guide_params.update(init_params)
model_trace = trace(
substitute(replay(model_init, guide_trace), init_guide_params)
).get_trace(*args, **kwargs, **self.static_kwargs)
Expand Down Expand Up @@ -305,6 +311,7 @@ def run(
progress_bar=True,
stable_update=False,
init_state=None,
init_params=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -333,6 +340,8 @@ def run(
# continue from the end of the previous svi run rather than beginning again from iteration 0
svi_result = svi.run(random.PRNGKey(1), 2000, data, init_state=svi_result.state)
:param dict init_params: if not None, initialize :class:`numpyro.param` sites with values from
this dictionary instead of using ``init_value`` in :class:`numpyro.param` primitives.
:param kwargs: keyword arguments to the model / guide
:return: a namedtuple with fields `params` and `losses` where `params`
holds the optimized values at :class:`numpyro.param` sites,
Expand All @@ -351,7 +360,7 @@ def body_fn(svi_state, _):
return svi_state, loss

if init_state is None:
svi_state = self.init(rng_key, *args, **kwargs)
svi_state = self.init(rng_key, *args, init_params=init_params, **kwargs)
else:
svi_state = init_state
if progress_bar:
Expand Down
24 changes: 24 additions & 0 deletions test/infer/test_svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
from numpyro.util import fori_loop


def assert_equal(a, b, prec=0):
return jax.tree_util.tree_map(lambda a, b: assert_allclose(a, b, atol=prec), a, b)


@pytest.mark.parametrize("alpha", [0.0, 2.0])
def test_renyi_elbo(alpha):
def model(x):
Expand Down Expand Up @@ -224,6 +228,26 @@ def guide():
assert_allclose(svi_result.params["shared"], target_value, atol=0.1)


def test_init_params():
init_params = {"b": 1.0, "c": 2.0}

def model():
numpyro.param("a", 0.0)
# should receive initial value from init_params
numpyro.param("b")

def guide():
# should receive initial value from init_params
numpyro.param("c")

svi = SVI(model, guide, optim.Adam(0.01), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0), init_params=init_params)
params = svi.get_params(svi_state)
init_params["a"] = 0.0
# make sure init params ended up in the SVI state
assert_equal(params, init_params)


def test_elbo_dynamic_support():
x_prior = dist.TransformedDistribution(
dist.Normal(),
Expand Down

0 comments on commit f98bc4d

Please sign in to comment.