Skip to content

Commit

Permalink
Add init to mean strategy (#1550)
Browse files Browse the repository at this point in the history
* added init_to_mean

* Add init_to_mean strategy

* Fix merge duplicate init_to_mean

* fallback to median instead

---------

Co-authored-by: Vitalii Kleshchevnikov <vk7@sanger.ac.uk>
  • Loading branch information
fehiepsi and vitkl committed Mar 8, 2023
1 parent 94121ac commit 737d7d9
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/source/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ init_to_feasible
^^^^^^^^^^^^^^^^
.. autofunction:: numpyro.infer.initialization.init_to_feasible

init_to_mean
^^^^^^^^^^^^
.. autofunction:: numpyro.infer.initialization.init_to_mean

init_to_median
^^^^^^^^^^^^^^
.. autofunction:: numpyro.infer.initialization.init_to_median
Expand Down
2 changes: 2 additions & 0 deletions numpyro/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from numpyro.infer.hmc_gibbs import HMCECS, DiscreteHMCGibbs, HMCGibbs
from numpyro.infer.initialization import (
init_to_feasible,
init_to_mean,
init_to_median,
init_to_sample,
init_to_uniform,
Expand All @@ -30,6 +31,7 @@
__all__ = [
"autoguide",
"init_to_feasible",
"init_to_mean",
"init_to_median",
"init_to_sample",
"init_to_uniform",
Expand Down
30 changes: 30 additions & 0 deletions numpyro/infer/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,36 @@ def init_to_median(site=None, num_samples=15):
return init_to_uniform(site)


def init_to_mean(site=None):
"""
Initialize to the prior mean. For priors with no `.mean` property implemented,
we defer to the :func:`init_to_median` strategy.
"""
if site is None:
return partial(init_to_mean)

if (
site["type"] == "sample"
and not site["is_observed"]
and not site["fn"].support.is_discrete
):
if site["value"] is not None:
warnings.warn(
f"init_to_mean() skipping initialization of site '{site['name']}'"
" which already stores a value.",
stacklevel=find_stack_level(),
)
return site["value"]
try:
# Try .mean property.
value = site["fn"].mean
sample_shape = site["kwargs"].get("sample_shape")
if sample_shape:
value = jnp.broadcast_to(value, sample_shape + jnp.shape(value))
except (NotImplementedError, ValueError):
return init_to_median(site)


def init_to_sample(site=None):
"""
Initialize to a prior sample. For priors with no `.sample` method implemented,
Expand Down
2 changes: 2 additions & 0 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.initialization import (
init_to_feasible,
init_to_mean,
init_to_median,
init_to_sample,
init_to_uniform,
Expand Down Expand Up @@ -240,6 +241,7 @@ def model():
init_to_uniform(radius=3),
init_to_value(values={"tau": 0.7}),
init_to_feasible,
init_to_mean,
init_to_median,
init_to_sample,
init_to_uniform,
Expand Down

0 comments on commit 737d7d9

Please sign in to comment.