Skip to content

Commit

Permalink
Uses init_loc_fn to initialize mixture particles (#1612)
Browse files Browse the repository at this point in the history
* changed `steinvi` to use `init_loc_fn` for all particles.

* removed unused import

* sketched `_find_init_params`

* the `|` operator requires Python >= 3.10

* removed unused imports

* reduced the smoke test (it took too long to run) and added a custom guide test.

* removed tests covered by new test cases.

* fixed imports.
  • Loading branch information
OlaRonning committed Jul 10, 2023
1 parent a66391d commit 9618266
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 97 deletions.
113 changes: 62 additions & 51 deletions numpyro/contrib/einstein/steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from copy import deepcopy
import functools
from functools import partial
from itertools import chain
Expand All @@ -20,8 +21,7 @@
get_parameter_transform,
)
from numpyro.contrib.funsor import config_enumerate, enum
from numpyro.distributions import Distribution, Normal
from numpyro.distributions.constraints import real
from numpyro.distributions import Distribution
from numpyro.distributions.transforms import IdentityTransform
from numpyro.infer.autoguide import AutoGuide
from numpyro.infer.util import _guess_max_plate_nesting, transform_fn
Expand Down Expand Up @@ -102,6 +102,38 @@ def __init__(
enum=True,
**static_kwargs,
):
if isinstance(guide, AutoGuide):
not_comptaible_guides = [
"AutoIAFNormal",
"AutoBNAFNormal",
"AutoDAIS",
"AutoSemiDAIS",
"AutoSurrogateLikelihoodDAIS",
]
guide_name = guide.__class__.__name__
assert guide_name not in not_comptaible_guides, (
f"SteinVI currently not compatible with {guide_name}. "
f"If you have a use case, feel free to open an issue."
)

init_loc_error_message = (
"SteinVI is not compatible with init_to_feasible, init_to_value, "
"and init_to_uniform with radius=0. If you have a use case, "
"feel free to open an issue."
)
if isinstance(guide.init_loc_fn, partial):
init_fn_name = guide.init_loc_fn.func.__name__
if init_fn_name == "init_to_uniform":
assert (
guide.init_loc_fn.keywords.get("radius", None) != 0
), init_loc_error_message
else:
init_fn_name = guide.init_loc_fn.__name__
assert init_fn_name not in [
"init_to_feasible",
"init_to_value",
], init_loc_error_message

self._inference_model = model
self.model = model
self.guide = guide
Expand All @@ -112,7 +144,7 @@ def __init__(
)
self.kernel_fn = kernel_fn
self.static_kwargs = static_kwargs
self.num_particles = num_stein_particles
self.num_stein_particles = num_stein_particles
self.loss_temperature = loss_temperature
self.repulsion_temperature = repulsion_temperature
self.enum = enum
Expand Down Expand Up @@ -167,48 +199,21 @@ def _calc_particle_info(self, uparams, num_particles, start_index=0):
start_index = end_index
return res, end_index

def _find_init_params(self, particle_seed, inner_guide, inner_guide_trace):
def extract_info(site):
nonlocal particle_seed
name = site["name"]
value = site["value"]
constraint = site["kwargs"].get("constraint", real)
transform = get_parameter_transform(site)
if (
isinstance(inner_guide, AutoGuide)
and "_".join((inner_guide.prefix, "loc")) in name
):
site_key, particle_seed = random.split(particle_seed)
unconstrained_shape = transform.inverse_shape(value.shape)
init_value = jnp.expand_dims(
transform.inv(value), 0
) + Normal( # Add gaussian noise
scale=0.1
).sample(
particle_seed, (self.num_particles, *unconstrained_shape)
)
init_value = transform(init_value)

else:
site_fn = site["fn"]
site_args = site["args"]
site_key, particle_seed = random.split(particle_seed)
def _find_init_params(self, particle_seed, inner_guide, model_args, model_kwargs):
def local_trace(key):
guide = deepcopy(inner_guide)

def _reinit(seed):
with handlers.seed(rng_seed=seed):
return site_fn(*site_args)
with handlers.seed(rng_seed=key), handlers.trace() as mixture_trace:
guide(*model_args, **model_kwargs)

init_value = vmap(_reinit)(
random.split(particle_seed, self.num_particles)
)
return init_value, constraint
init_params = {
name: site["value"]
for name, site in mixture_trace.items()
if site.get("type") == "param"
}
return init_params

init_params = {
name: extract_info(site)
for name, site in inner_guide_trace.items()
if site.get("type") == "param"
}
return init_params
return vmap(local_trace)(random.split(particle_seed, self.num_stein_particles))

def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
# 0. Separate model and guide parameters, since only guide parameters are updated using Stein
Expand Down Expand Up @@ -352,7 +357,7 @@ def _update_force(attr_force, rep_force, jac):
vmap(single_particle_grad)(
stein_particles, attractive_force, repulsive_force
)
/ self.num_particles
/ self.num_stein_particles
)

# 5. Decompose the monolithic particle forces back to concrete parameter values
Expand All @@ -372,19 +377,25 @@ def init(self, rng_key: KeyArray, *args, **kwargs):
:param kwargs: Keyword arguments to the model / guide.
:return: initial :data:`SteinVIState`
"""
rng_key, kernel_seed, model_seed, guide_seed = random.split(rng_key, 4)
model_init = handlers.seed(self.model, model_seed)
guide_init = handlers.seed(self.guide, guide_seed)
guide_trace = handlers.trace(guide_init).get_trace(
*args, **kwargs, **self.static_kwargs

rng_key, kernel_seed, model_seed, guide_seed, particle_seed = random.split(
rng_key, 5
)

model_init = handlers.seed(self.model, model_seed)
model_trace = handlers.trace(model_init).get_trace(
*args, **kwargs, **self.static_kwargs
)
rng_key, particle_seed = random.split(rng_key)

guide_init_params = self._find_init_params(
particle_seed, self.guide, guide_trace
particle_seed, self.guide, args, kwargs
)

guide_init = handlers.seed(self.guide, guide_seed)
guide_trace = handlers.trace(guide_init).get_trace(
*args, **kwargs, **self.static_kwargs
)

params = {}
transforms = {}
inv_transforms = {}
Expand Down Expand Up @@ -415,7 +426,7 @@ def init(self, rng_key: KeyArray, *args, **kwargs):
"particle_transform", IdentityTransform()
)
if site["name"] in guide_init_params:
pval, _ = guide_init_params[site["name"]]
pval = guide_init_params[site["name"]]
if self.non_mixture_params_fn(site["name"]):
pval = tree_map(lambda x: x[0], pval)
else:
Expand Down

0 comments on commit 9618266

Please sign in to comment.