Skip to content

Commit

Permalink
Allow substitute deterministic sites (#1664)
Browse files Browse the repository at this point in the history
* allow substitute deterministic sites

* fix lint

* pop deterministic sites from posterior

* handle deterministic sites in cond
  • Loading branch information
fehiepsi committed Oct 21, 2023
1 parent 065aa43 commit 2416eb9
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 12 deletions.
5 changes: 2 additions & 3 deletions examples/stein_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
from sklearn.model_selection import train_test_split

import jax
from jax import random
import jax.numpy as jnp

Expand Down Expand Up @@ -195,9 +196,7 @@ def main(args):


if __name__ == "__main__":
from jax.config import config

config.update("jax_debug_nans", True)
jax.config.update("jax_debug_nans", True)

parser = argparse.ArgumentParser()
parser.add_argument("--subsample-size", type=int, default=100)
Expand Down
2 changes: 2 additions & 0 deletions numpyro/contrib/control_flow/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def _subs_wrapper(subs_map, site):
if isinstance(subs_map, dict) and site["name"] in subs_map:
return subs_map[site["name"]]
elif callable(subs_map):
if site["type"] == "deterministic":
return subs_map(site)
rng_key = site["kwargs"].get("rng_key")
subs_map = (
handlers.seed(subs_map, rng_seed=rng_key)
Expand Down
6 changes: 3 additions & 3 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,9 +794,9 @@ def __init__(self, fn=None, data=None, substitute_fn=None):
super(substitute, self).__init__(fn)

def process_message(self, msg):
if (msg["type"] not in ("sample", "param", "mutable", "plate")) or msg.get(
"_control_flow_done", False
):
if (
msg["type"] not in ("sample", "param", "mutable", "plate", "deterministic")
) or msg.get("_control_flow_done", False):
if msg["type"] == "control_flow":
if self.data is not None:
msg["kwargs"]["substitute_stack"].append(("substitute", self.data))
Expand Down
12 changes: 7 additions & 5 deletions numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from jax.api_util import flatten_fun, shaped_abstractify
import jax.core as core
from jax.experimental.pjit import pjit_p
import jax.util as util

try:
import jax.extend.linear_util as lu
except ImportError:
import jax.linear_util as lu

from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
from jax.interpreters.pxla import xla_pmap_p
import jax.numpy as jnp
Expand Down Expand Up @@ -40,7 +42,7 @@ def eval_provenance(fn, **kwargs):
args, in_tree = jax.tree_util.tree_flatten(((), kwargs))
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fn), in_tree)
# Abstract eval to get output pytree
avals = core.safe_map(shaped_abstractify, args)
avals = util.safe_map(shaped_abstractify, args)
# XXX: we split out the process of abstract evaluation and provenance tracking
# for simplicity. In principle, they can be merged so that we only need to walk
# through the equations once.
Expand Down Expand Up @@ -81,14 +83,14 @@ def write(v, p):
return
env[v] = read(v) | p

core.safe_map(write, jaxpr.invars, provenance_inputs)
util.safe_map(write, jaxpr.invars, provenance_inputs)
for eqn in jaxpr.eqns:
provenance_inputs = core.safe_map(read, eqn.invars)
provenance_inputs = util.safe_map(read, eqn.invars)
rule = track_deps_rules.get(eqn.primitive, _default_track_deps_rules)
provenance_outputs = rule(eqn, provenance_inputs)
core.safe_map(write, eqn.outvars, provenance_outputs)
util.safe_map(write, eqn.outvars, provenance_outputs)

return core.safe_map(read, jaxpr.outvars)
return util.safe_map(read, jaxpr.outvars)


track_deps_rules = {}
Expand Down
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import os

from jax.config import config
from jax import config

from numpyro.util import set_rng_seed

Expand Down
1 change: 1 addition & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ def model(y=None):
random.PRNGKey(0), params, sample_shape=(1000,)
)

posterior_samples.pop("z")
predictive = Predictive(model, posterior_samples, params=params)
predictive_samples = predictive(random.PRNGKey(0), y_test)

Expand Down

0 comments on commit 2416eb9

Please sign in to comment.