Skip to content

Commit

Permalink
promote shapes of scanned values (#1444)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Jun 28, 2022
1 parent 6ca6299 commit 9fd29ab
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
10 changes: 9 additions & 1 deletion numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,15 @@ def body_fn(wrapped_carry, x):
return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

wrapped_carry = device_put((0, rng_key, init))
return lax.scan(body_fn, wrapped_carry, xs, length=length, reverse=reverse)
last_carry, (pytree_trace, ys) = lax.scan(
body_fn, wrapped_carry, xs, length=length, reverse=reverse
)
for name, site in pytree_trace.trace.items():
if site["type"] != "sample":
continue
# we haven't promote shapes of values yet during `lax.scan`, so we do it here
site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])
return last_carry, (pytree_trace, ys)


def scan(f, init, xs, length=None, reverse=False, history=1):
Expand Down
14 changes: 14 additions & 0 deletions test/contrib/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,17 @@ def false_fun(_):
atol=0.1,
)
assert_allclose([x.mean(), x.std()], [2.0, jnp.sqrt(5.0)], atol=0.5)


def test_scan_promote():
def model():
def transition_fn(c, val):
with numpyro.plate("N", 3, dim=-1):
numpyro.sample("x", dist.Normal(0, 1), obs=1.0)
return None, None

scan(transition_fn, None, None, length=10)

tr = numpyro.handlers.trace(model).get_trace()
assert tr["x"]["value"].shape == (10, 1)
assert tr["x"]["fn"].log_prob(tr["x"]["value"]).shape == (10, 3)

0 comments on commit 9fd29ab

Please sign in to comment.