Skip to content

Commit

Permalink
fix _loc_scale method in AutoMultivariateNormal (#3233)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjankowiak committed Jun 23, 2023
1 parent ea44053 commit 7e3d62e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
5 changes: 4 additions & 1 deletion pyro/infer/autoguide/guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,10 @@ def get_posterior(self, *args, **kwargs):
return dist.MultivariateNormal(self.loc, scale_tril=scale_tril)

def _loc_scale(self, *args, **kwargs):
return self.loc, self.scale * self.scale_tril.diag()
scale_tril = self.scale[..., None] * self.scale_tril
scale = scale_tril.pow(2).sum(-1).sqrt()
assert scale.shape == self.loc.shape
return self.loc, scale


class AutoDiagonalNormal(AutoContinuous):
Expand Down
9 changes: 8 additions & 1 deletion tests/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,9 @@ def AutoGuideList_x(model):
@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO])
def test_quantiles(auto_class, Elbo):
def model():
pyro.sample("x", dist.Normal(0.0, 1.0))
pyro.sample("y", dist.LogNormal(0.0, 1.0))
pyro.sample("z", dist.Beta(2.0, 2.0).expand([2]).to_event(1))
pyro.sample("x", dist.Normal(0.0, 1.0))

guide = auto_class(model)
optim = Adam({"lr": 0.05, "betas": (0.8, 0.99)})
Expand All @@ -543,6 +543,13 @@ def model():
if auto_class is AutoLaplaceApproximation:
guide = guide.laplace_approximation()

if hasattr(auto_class, "get_posterior"):
posterior = guide.get_posterior()
posterior_scale = posterior.variance[-1].sqrt()
q = guide.quantiles([0.158655, 0.8413447])
quantile_scale = 0.5 * (q["x"][1] - q["x"][0]) # only x is unconstrained
assert_close(quantile_scale, posterior_scale, atol=1.0e-6)

quantiles = guide.quantiles([0.1, 0.5, 0.9])
median = guide.median()
for name in ["x", "y", "z"]:
Expand Down

0 comments on commit 7e3d62e

Please sign in to comment.