Skip to content

Commit

Permalink
Consider time history=0 as plate (#1443)
Browse files Browse the repository at this point in the history
* Consider time history=0 as plate

* lint
  • Loading branch information
fehiepsi committed Jul 1, 2022
1 parent 9fd29ab commit e10bf59
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
2 changes: 2 additions & 0 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@


def _subs_wrapper(subs_map, i, length, site):
if site["type"] != "sample":
return
value = None
if isinstance(subs_map, dict) and site["name"] in subs_map:
value = subs_map[site["name"]]
Expand Down
8 changes: 6 additions & 2 deletions numpyro/contrib/funsor/infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
time_to_init_vars = defaultdict(frozenset) # PP... variables
time_to_markov_dims = defaultdict(frozenset) # dimensions at markov sites
sum_vars, prod_vars = frozenset(), frozenset()
history = 1
history = 0
log_measures = {}
for site in model_trace.values():
if site["type"] == "sample":
Expand All @@ -186,10 +186,14 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
for dim, name in dim_to_name.items():
if name.startswith("_time"):
time_dim = funsor.Variable(name, funsor.Bint[log_prob.shape[dim]])
time_to_factors[time_dim].append(log_prob_factor)
history = max(
history, max(_get_shift(s) for s in dim_to_name.values())
)
if history == 0:
log_factors.append(log_prob_factor)
prod_vars |= frozenset({name})
else:
time_to_factors[time_dim].append(log_prob_factor)
time_to_init_vars[time_dim] |= frozenset(
s for s in dim_to_name.values() if s.startswith("_PREV_")
)
Expand Down
24 changes: 24 additions & 0 deletions test/contrib/test_funsor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from numpy.testing import assert_allclose
import pytest

import jax
from jax import random
import jax.numpy as jnp

Expand Down Expand Up @@ -516,6 +517,29 @@ def transition_fn(carry, y):
assert_allclose(actual_x_curr, expected_x_curr)


def test_scan_enum_history_0():
def model(ys):
z = numpyro.sample("z", dist.Bernoulli(0.2), infer={"enumerate": "parallel"})

def transition_fn(c, y):
numpyro.sample("y", dist.Normal(z, 1), obs=y)
return None, None

scan(transition_fn, None, ys)

actual, trace = log_density(
model=enum(model, first_available_dim=-1),
model_args=(jnp.arange(3),),
model_kwargs={},
params={},
)
z_factor = trace["z"]["fn"].log_prob(trace["z"]["value"])
prev_y_factor = trace["_PREV_y"]["fn"].log_prob(trace["_PREV_y"]["value"])
y_factor = trace["y"]["fn"].log_prob(trace["y"]["value"]).sum(0)
expected = jax.nn.logsumexp(z_factor + prev_y_factor + y_factor)
assert_allclose(actual, expected)


def test_missing_plate(monkeypatch):
K, N = 3, 1000

Expand Down

0 comments on commit e10bf59

Please sign in to comment.