Skip to content

Commit

Permalink
dispatch _promote_batch_shape_expanded to Independent (#1630)
Browse files Browse the repository at this point in the history
* dispatch _promote_batch_shape_expanded to Independent

* formatting

* union type workaround

* linting

* separate register

* add test of scan plate mask

* lint with flake8 and black

* isort

* _promote_batch_shape_independent and log_density in test

* import sorting

* shape assertions
  • Loading branch information
deoxyribose committed Aug 28, 2023
1 parent 8958bb2 commit ca96eca
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
10 changes: 10 additions & 0 deletions numpyro/distributions/batch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from numpyro.distributions.distribution import (
Distribution,
ExpandedDistribution,
Independent,
MaskedDistribution,
Unit,
)
Expand Down Expand Up @@ -572,6 +573,15 @@ def _promote_batch_shape_masked(d: MaskedDistribution):
return new_self


@promote_batch_shape.register
def _promote_batch_shape_independent(d: Independent):
new_self = copy.copy(d)
new_base_dist = promote_batch_shape(d.base_dist)
new_self._batch_shape = new_base_dist.batch_shape[: d.event_dim]
new_self.base_dist = new_base_dist
return new_self


@promote_batch_shape.register
def _promote_batch_shape_unit(d: Unit):
return d
35 changes: 33 additions & 2 deletions test/contrib/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import numpyro
from numpyro.contrib.control_flow import cond, scan
import numpyro.distributions as dist
from numpyro.handlers import seed, substitute, trace
from numpyro.handlers import mask, seed, substitute, trace
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
from numpyro.infer.util import potential_energy
from numpyro.infer.util import log_density, potential_energy


def test_scan():
Expand Down Expand Up @@ -210,3 +210,34 @@ def transition_fn(c, val):
tr = numpyro.handlers.trace(model).get_trace()
assert tr["x"]["value"].shape == (10, 1)
assert tr["x"]["fn"].log_prob(tr["x"]["value"]).shape == (10, 3)


def test_scan_plate_mask():
def model(y=None, T=10):
def transition(carry, y_curr):
x_prev, t = carry
with numpyro.plate("N", 10, dim=-1):
with mask(mask=(t < T)):
x_curr = numpyro.sample(
"x",
dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1),
)
y_curr = numpyro.sample(
"y",
dist.Normal(x_curr, jnp.ones((10, 3))).to_event(1),
obs=y_curr,
)
return (x_curr, t + 1), None

x0 = numpyro.sample(
"x_0", dist.Normal(jnp.zeros((10, 3)), jnp.ones((10, 3))).to_event(1)
)

x, t = scan(transition, (x0, 0), y, length=T)
return (x, y)

with numpyro.handlers.seed(rng_seed=0):
model_density, model_trace = log_density(model, (None, 10), {}, {})
assert model_density
assert model_trace["x"]["fn"].batch_shape == (10,)
assert model_trace["x"]["fn"].event_shape == (3,)

0 comments on commit ca96eca

Please sign in to comment.