Skip to content

Commit

Permalink
Add median to batched auto-guides. (#1749)
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Feb 27, 2024
1 parent d6f9897 commit 00f43d0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
13 changes: 12 additions & 1 deletion numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1862,9 +1862,12 @@ def _get_batched_posterior(self):
def _get_posterior(self):
return dist.TransformedDistribution(
self._get_batched_posterior(),
ReshapeTransform((self.latent_dim,), self._batch_shape + self._event_shape),
self._get_reshape_transform(),
)

def _get_reshape_transform(self) -> ReshapeTransform:
return ReshapeTransform((self.latent_dim,), self._batch_shape + self._event_shape)


class AutoBatchedMultivariateNormal(AutoBatchedMixin, AutoContinuous):
"""
Expand Down Expand Up @@ -1911,6 +1914,10 @@ def _get_batched_posterior(self):
)
return dist.MultivariateNormal(loc, scale_tril=scale_tril)

def median(self, params):
loc = self._get_reshape_transform()(params["{}_loc".format(self.prefix)])
return self._unpack_and_constrain(loc, params)


class AutoLowRankMultivariateNormal(AutoContinuous):
"""
Expand Down Expand Up @@ -2039,6 +2046,10 @@ def _get_batched_posterior(self):
cov_factor = cov_factor * scale[..., None]
return dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)

def median(self, params):
loc = self._get_reshape_transform()(params["{}_loc".format(self.prefix)])
return self._unpack_and_constrain(loc, params)


class AutoLaplaceApproximation(AutoContinuous):
r"""
Expand Down
24 changes: 24 additions & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,3 +1311,27 @@ def model(n, m):

with pytest.raises(ValueError, match="Expected 2 batch dimensions"):
auto_class(model, batch_ndim=2)(3, 3)


@pytest.mark.parametrize(
"auto_class", [AutoBatchedMultivariateNormal, AutoBatchedLowRankMultivariateNormal]
)
def test_auto_batched_median(auto_class) -> None:
def model():
distribution = dist.Normal().expand([7]).to_event(1)
with numpyro.plate("n", 3):
x = numpyro.sample("x", distribution)
with numpyro.plate("m", 3):
y = numpyro.sample("y", distribution)
return x, y

guide = auto_class(model)
with numpyro.handlers.seed(rng_seed=0):
trace = numpyro.handlers.trace(guide).get_trace()
params = {
key: value["value"] for key, value in trace.items() if value["type"] == "param"
}
median = guide.median(params)
assert jnp.allclose(
jnp.stack(list(median.values())).ravel(), params["auto_loc"].ravel()
)

0 comments on commit 00f43d0

Please sign in to comment.