Skip to content

Commit

Permalink
Fix quantile of mvn autoguides (#1622)
Browse files Browse the repository at this point in the history
* fix quantile of mvn autoguides

* add test for auto mvn quantiles
  • Loading branch information
fehiepsi committed Jul 30, 2023
1 parent 902623c commit 6b23534
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
4 changes: 3 additions & 1 deletion numpyro/distributions/mixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,9 @@ def __init__(
# distributions match, but for now we just check the type, since __eq__
# isn't consistently implemented for all support types.
support_type = type(component_distributions[0].support)
if any(type(d.support) != support_type for d in component_distributions[1:]):
if any(
type(d.support) is not support_type for d in component_distributions[1:]
):
raise ValueError("All component distributions must have the same support.")

self._mixing_distribution = mixing_distribution
Expand Down
12 changes: 6 additions & 6 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,9 +1647,9 @@ def median(self, params):
def quantiles(self, params, quantiles):
transform = self.get_transform(params)
quantiles = jnp.array(quantiles)[..., None]
latent = dist.Normal(transform.loc, jnp.diagonal(transform.scale_tril)).icdf(
quantiles
)
latent = dist.Normal(
transform.loc, jnp.linalg.norm(transform.scale_tril, axis=-1)
).icdf(quantiles)
return self._unpack_and_constrain(latent, params)


Expand Down Expand Up @@ -1828,9 +1828,9 @@ def median(self, params):
def quantiles(self, params, quantiles):
transform = self.get_transform(params)
quantiles = jnp.array(quantiles)[..., None]
latent = dist.Normal(transform.loc, jnp.diagonal(transform.scale_tril)).icdf(
quantiles
)
latent = dist.Normal(
transform.loc, jnp.linalg.norm(transform.scale_tril, axis=-1)
).icdf(quantiles)
return self._unpack_and_constrain(latent, params)


Expand Down
21 changes: 21 additions & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,27 @@ def body_fn(i, val):
assert_allclose(jnp.mean(posterior_samples["coefs"], 0), expected_coefs, rtol=0.1)


def test_mvn_quantile():
def model():
numpyro.sample("x", dist.Normal(0, 1).expand([2]).to_event(1))

guide = AutoMultivariateNormal(model)
with handlers.seed(rng_seed=random.PRNGKey(0)):
guide()
params = {
"auto_loc": jnp.zeros(2),
"auto_scale_tril": jnp.array([[1.0, 0.0], [0.5, 0.5]]),
}
actual = guide.quantiles(params, quantiles=0.3)["x"]
mvn = dist.MultivariateNormal(
params["auto_loc"], scale_tril=params["auto_scale_tril"]
)
expected = dist.Normal(params["auto_loc"], jnp.sqrt(mvn.variance)).icdf(
jnp.array(0.3)
)
assert_allclose(actual, expected, rtol=1e-5)


def test_iaf():
# test for substitute logic for exposed methods `sample_posterior` and `get_transforms`
N, dim = 3000, 3
Expand Down

0 comments on commit 6b23534

Please sign in to comment.