Skip to content

Commit

Permalink
Add inference utilities to transform between unconstrained and constr…
Browse files Browse the repository at this point in the history
…ained space (#1564)

Improve and simplify constrain_fn and unconstrain_fn implementation

Add missing doctstrings

Constrain/unconstrain functions now always consider param sites

Fix syntax for lint tests

Fix syntax for lint tests

Fix syntax for lint tests
  • Loading branch information
aymgal committed Mar 30, 2023
1 parent 9f4f3eb commit abe456c
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 0 deletions.
44 changes: 44 additions & 0 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ def substitute_fn(site):
if site["type"] == "sample":
with helpful_support_errors(site):
return biject_to(site["fn"].support)(params[site["name"]])
elif site["type"] == "param":
constraint = site["kwargs"].pop("constraint", constraints.real)
with helpful_support_errors(site):
return biject_to(constraint)(params[site["name"]])
else:
return params[site["name"]]

Expand All @@ -193,6 +197,42 @@ def substitute_fn(site):
}


def get_transforms(model, model_args, model_kwargs, params):
"""
(EXPERIMENTAL INTERFACE) Retrieve (inverse) transforms via biject_to()
given a NumPyro model. This function supports 'param' sites.
NB: Parameter values are only used to retrieve the model trace.
:param model: a callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of values keyed by site names.
:return: `dict` of transformation keyed by site names.
"""
substituted_model = substitute(model, data=params)
transforms, _, _, _ = _get_model_transforms(
substituted_model, model_args, model_kwargs
)
return transforms


def unconstrain_fn(model, model_args, model_kwargs, params):
"""
(EXPERIMENTAL INTERFACE) Given a NumPyro model and a dict of parameters,
this function applies the right transformation to convert parameter values
from constrained space to unconstrained space.
:param model: a callable containing NumPyro primitives.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param dict params: dictionary of constrained values keyed by site
names.
:return: `dict` of transformation keyed by site names.
"""
transforms = get_transforms(model, model_args, model_kwargs, params)
return transform_fn(transforms, params, invert=True)


def _unconstrain_reparam(params, site):
name = site["name"]
if name in params:
Expand Down Expand Up @@ -449,6 +489,10 @@ def _get_model_transforms(model, model_args=(), model_kwargs=None):
for arg in args:
if not isinstance(getattr(support, arg), (int, float)):
replay_model = True
elif v["type"] == "param":
constraint = v["kwargs"].pop("constraint", constraints.real)
with helpful_support_errors(v, raise_warnings=True):
inv_transforms[k] = biject_to(constraint)
elif v["type"] == "deterministic":
replay_model = True
return inv_transforms, replay_model, has_enumerate_support, model_trace
Expand Down
47 changes: 47 additions & 0 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
log_likelihood,
potential_energy,
transform_fn,
unconstrain_fn,
)
import numpyro.optim as optim

Expand Down Expand Up @@ -220,6 +221,52 @@ def model():
assert_allclose(actual_potential_energy, expected_potential_energy)


def test_constrain_unconstrain():
x_prior = dist.HalfNormal(2)
y_prior = dist.LogNormal(scale=3.0) # transformed distribution
z_constraint = constraints.positive

def model():
numpyro.sample("x", x_prior)
numpyro.sample("y", y_prior)
numpyro.param("z", init_value=2.0, constraint=z_constraint)

params = {"x": jnp.array(-5.0), "y": jnp.array(7.0), "z": jnp.array(3.0)}
model = handlers.seed(model, random.PRNGKey(0))
inv_transforms = {
"x": biject_to(x_prior.support),
"y": biject_to(y_prior.support),
"z": biject_to(z_constraint),
}
expected_constrained_samples = partial(transform_fn, inv_transforms)(params)
transforms = {
"x": biject_to(x_prior.support).inv,
"y": biject_to(y_prior.support).inv,
"z": biject_to(z_constraint).inv,
}
expected_unconstrained_samples = partial(transform_fn, transforms)(
expected_constrained_samples
)

actual_constrained_samples = constrain_fn(model, (), {}, params)
actual_unconstrained_samples = unconstrain_fn(
model, (), {}, actual_constrained_samples
)

assert_allclose(expected_constrained_samples["x"], actual_constrained_samples["x"])
assert_allclose(expected_constrained_samples["y"], actual_constrained_samples["y"])
assert_allclose(expected_constrained_samples["z"], actual_constrained_samples["z"])
assert_allclose(
expected_unconstrained_samples["x"], actual_unconstrained_samples["x"]
)
assert_allclose(
expected_unconstrained_samples["y"], actual_unconstrained_samples["y"]
)
assert_allclose(
expected_unconstrained_samples["z"], actual_unconstrained_samples["z"]
)


def test_model_with_mask_false():
def model():
x = numpyro.sample("x", dist.Normal())
Expand Down

0 comments on commit abe456c

Please sign in to comment.