Skip to content

Commit

Permalink
Fix provenance for jax 0.4.4 (#1543)
Browse files Browse the repository at this point in the history
* fix provenance after jax 0.4.4

* fix typo

* run black

* clean up args_struct

* Fix track_nonreparam logic

* only substitute reparam sites

* revise compute_downstream_costs tests

* fix typo in test_compute_downstream

* add debug code

* add repr for provenance array

* revise eval_provenance logic

* Add comments for provenance logic

* add debug info to match eval_shape functionality

* fix typo for pjit rule

* address comments

* revert get_latents logic

* use frozenset().union
  • Loading branch information
fehiepsi committed Mar 6, 2023
1 parent 7789f39 commit ac05f92
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 241 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ jobs:
pip freeze
- name: Test with pytest
run: |
CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/
CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/
test-inference:
Expand Down Expand Up @@ -103,6 +103,7 @@ jobs:
run: |
pytest -vs --durations=20 test/infer/test_mcmc.py
pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py
pytest -vs --durations=20 test/contrib
- name: Test x64
run: |
JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64
Expand Down
127 changes: 47 additions & 80 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
from numpyro.distributions import ExpandedDistribution, MaskedDistribution
from numpyro.distributions.kl import kl_divergence
from numpyro.distributions.util import scale_and_mask
from numpyro.handlers import Messenger, replay, seed, substitute, trace
from numpyro.handlers import replay, seed, substitute, trace
from numpyro.infer.util import (
_without_rsample_stop_gradient,
get_importance_trace,
is_identically_one,
log_density,
)
from numpyro.ops.provenance import eval_provenance, get_provenance
from numpyro.ops.provenance import eval_provenance
from numpyro.util import _validate_model, check_model_guide_match, find_stack_level


Expand Down Expand Up @@ -535,63 +535,6 @@ def _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes):
return downstream_costs, downstream_guide_cost_nodes


class track_nonreparam(Messenger):
    """
    Tag the values of non-reparameterizable sample sites with their own name.
    Intended to be used with ``eval_provenance``.

    A site is tagged when it is a ``sample`` site, is not observed, and its
    distribution reports ``has_rsample`` as false.

    **References:**

    1. *Nonstandard Interpretations of Probabilistic Programs for Efficient Inference*,
       David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind

    **Example:**

    .. doctest::

       >>> import jax.numpy as jnp
       >>> import numpyro
       >>> import numpyro.distributions as dist
       >>> from numpyro.infer.elbo import track_nonreparam
       >>> from numpyro.ops.provenance import eval_provenance, get_provenance
       >>> from numpyro.handlers import seed, trace

       >>> def model():
       ...     probs_a = jnp.array([0.3, 0.7])
       ...     probs_b = jnp.array([[0.1, 0.9], [0.8, 0.2]])
       ...     probs_c = jnp.array([[0.5, 0.5], [0.6, 0.4]])
       ...     a = numpyro.sample("a", dist.Categorical(probs_a))
       ...     b = numpyro.sample("b", dist.Categorical(probs_b[a]))
       ...     numpyro.sample("c", dist.Categorical(probs_c[b]), obs=jnp.array(0))

       >>> def get_log_probs():
       ...     seeded_model = seed(model, rng_seed=0)
       ...     model_tr = trace(seeded_model).get_trace()
       ...     return {
       ...         name: site["fn"].log_prob(site["value"])
       ...         for name, site in model_tr.items()
       ...         if site["type"] == "sample"
       ...     }

       >>> model_deps = get_provenance(eval_provenance(track_nonreparam(get_log_probs)))
       >>> print(model_deps)  # doctest: +SKIP
       {'a': frozenset({'a'}), 'b': frozenset({'a', 'b'}), 'c': frozenset({'a', 'b'})}
    """

    def postprocess_message(self, msg):
        """Add the site name to the ``_provenance`` entry of the site value."""
        # Only sample sites are candidates for tagging.
        if msg["type"] != "sample":
            return
        # Observed sites and sites with a reparameterized sampler are left alone.
        if msg["is_observed"] or msg["fn"].has_rsample:
            return
        site_tag = frozenset({msg["name"]})
        # The provenance set lives in the abstract value's ``named_shape``
        # mapping under the "_provenance" key; a missing entry counts as empty.
        existing = msg["value"].aval.named_shape.get("_provenance", frozenset())
        msg["value"].aval.named_shape["_provenance"] = existing | site_tag


def get_importance_log_probs(model, guide, args, kwargs, params):
"""
Returns log probabilities at each site for the guide and the model that is run against it.
Expand All @@ -610,6 +553,43 @@ def get_importance_log_probs(model, guide, args, kwargs, params):
return model_log_probs, guide_log_probs


def _substitute_nonreparam(data, msg):
if msg["name"] in data and not msg["fn"].has_rsample:
value = msg["fn"](*msg["args"], **msg["kwargs"])
value = 0 * value + data[msg["name"]]
return value


def _get_latents(model, guide, args, kwargs, params):
    """Collect the values of all unobserved sample sites of the guide and model.

    Both callables are wrapped with ``substitute(..., data=params)`` and seeded
    with ``rng_seed=0``. The guide is traced first and the model is then traced
    while being replayed against the guide trace.

    :param model: model callable.
    :param guide: guide callable.
    :param args: positional arguments forwarded to both callables.
    :param kwargs: keyword arguments forwarded to both callables.
    :param params: mapping passed as ``data`` to ``substitute``.
    :return: dict mapping site name to site value for every ``sample`` site
        that is not marked as observed.
    """
    seeded_model = seed(substitute(model, data=params), rng_seed=0)
    seeded_guide = seed(substitute(guide, data=params), rng_seed=0)
    guide_trace = trace(seeded_guide).get_trace(*args, **kwargs)
    merged_trace = trace(replay(seeded_model, guide_trace)).get_trace(*args, **kwargs)
    # Guide entries take precedence over model entries that share a name.
    merged_trace.update(guide_trace)

    latent_values = {}
    for site_name, site in merged_trace.items():
        if site["type"] != "sample":
            continue
        # A site without an "is_observed" entry is treated as unobserved.
        if site.get("is_observed", False):
            continue
        latent_values[site_name] = site["value"]
    return latent_values


def get_nonreparam_deps(model, guide, args, kwargs, param_map, latents=None):
    """Find dependencies on non-reparameterizable sample sites for each cost term
    in the model and the guide.

    :param model: model callable.
    :param guide: guide callable.
    :param args: positional arguments forwarded to the model and the guide.
    :param kwargs: keyword arguments forwarded to the model and the guide.
    :param param_map: parameter values forwarded to ``_get_latents`` and
        ``get_importance_log_probs``.
    :param latents: optional mapping from site name to latent value. When
        ``None``, it is obtained by running ``_get_latents`` under
        ``eval_shape``.
    :return: the pair ``(model_deps, guide_deps)`` produced by
        ``eval_provenance``.
    """
    if latents is None:
        latents_fn = partial(_get_latents, model, guide, args, kwargs, param_map)
        latents = eval_shape(latents_fn)

    def log_probs_given_latents(**latent_values):
        # Only sites without a reparameterized sampler are replaced; see
        # ``_substitute_nonreparam``.
        replace_site = partial(_substitute_nonreparam, latent_values)
        model_with_latents = substitute(
            seed(model, rng_seed=0), substitute_fn=replace_site
        )
        guide_with_latents = substitute(
            seed(guide, rng_seed=0), substitute_fn=replace_site
        )
        return get_importance_log_probs(
            model_with_latents, guide_with_latents, args, kwargs, param_map
        )

    # Each latent is passed as a keyword argument under its site name.
    model_deps, guide_deps = eval_provenance(log_probs_given_latents, **latents)
    return model_deps, guide_deps


class TraceGraph_ELBO(ELBO):
"""
A TraceGraph implementation of ELBO-based SVI. The gradient estimator
Expand Down Expand Up @@ -661,8 +641,12 @@ def single_particle_elbo(rng_key):
check_model_guide_match(model_trace, guide_trace)
_validate_model(model_trace, plate_warning="strict")

latents = {}
for name, site in guide_trace.items():
if site["type"] == "sample" and not site.get("is_observed", False):
latents[name] = site["value"]
model_deps, guide_deps = get_nonreparam_deps(
seeded_model, seeded_guide, args, kwargs, param_map
model, guide, args, kwargs, param_map, latents=latents
)

elbo = 0.0
Expand Down Expand Up @@ -837,23 +821,6 @@ def _partition(model_sum_deps, sum_vars):
return components


def get_nonreparam_deps(model, guide, args, kwargs, param_map):
    """Find dependencies on non-reparameterizable sample sites for each cost term
    in the model and the guide.

    ``get_importance_log_probs`` is wrapped in ``track_nonreparam``, evaluated
    through ``eval_provenance``, and the result is passed to ``get_provenance``.

    :param model: model callable.
    :param guide: guide callable.
    :param args: positional arguments forwarded to the model and the guide.
    :param kwargs: keyword arguments forwarded to the model and the guide.
    :param param_map: parameter values forwarded to ``get_importance_log_probs``.
    :return: the pair ``(model_deps, guide_deps)``.
    """
    tracked_log_probs = track_nonreparam(get_importance_log_probs)
    # Bind every argument so the callable can be evaluated with no inputs.
    bound_log_probs = partial(
        tracked_log_probs, model, guide, args, kwargs, param_map
    )
    model_deps, guide_deps = get_provenance(eval_provenance(bound_log_probs))
    return model_deps, guide_deps


def guess_max_plate_nesting(model, guide, args, kwargs, param_map):
"""Guess maximum plate nesting by performing jax shape inference."""
model_shapes, guide_shapes = eval_shape(
Expand Down Expand Up @@ -917,15 +884,15 @@ def single_particle_elbo(rng_key):
if self.max_plate_nesting == float("inf"):
seeded_model = seed(model, model_seed)
seeded_guide = seed(guide, guide_seed)
# XXX: We can extract abstract latents here such that they
# can be reused in get_nonreparam_deps below.
self.max_plate_nesting = guess_max_plate_nesting(
seeded_model, seeded_guide, args, kwargs, param_map
)

seeded_model = seed(model, model_seed)
seeded_guide = seed(guide, guide_seed)
# get dependencies on nonreparametrizable variables
model_deps, guide_deps = get_nonreparam_deps(
seeded_model, seeded_guide, args, kwargs, param_map
model, guide, args, kwargs, param_map
)
# get descendants of variables in the guide
guide_desc = defaultdict(frozenset)
Expand Down
27 changes: 9 additions & 18 deletions numpyro/infer/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer.initialization import init_to_sample
from numpyro.ops.provenance import ProvenanceArray, eval_provenance, get_provenance
from numpyro.ops.provenance import eval_provenance
from numpyro.ops.pytree import PytreeTrace


Expand Down Expand Up @@ -55,7 +55,7 @@ def get_trace():
return jax.eval_shape(get_trace).trace


def _get_log_probs(model, model_args, model_kwargs, sample):
def _get_log_probs(model, model_args, model_kwargs, **sample):
# Note: We use seed 0 for parameter initialization.
with handlers.trace() as tr, handlers.seed(rng_seed=0), handlers.substitute(
data=sample
Expand Down Expand Up @@ -199,14 +199,12 @@ def model_3():

# Find direct prior dependencies among latent and observed sites.
samples = {
name: ProvenanceArray(site["value"], frozenset({name}))
name: site["value"]
for name, site in trace.items()
if site["type"] == "sample" and not site["is_observed"]
}
sample_deps = get_provenance(
eval_provenance(
partial(_get_log_probs, model, model_args, model_kwargs), samples
)
sample_deps = eval_provenance(
partial(_get_log_probs, model, model_args, model_kwargs), **samples
)
prior_dependencies = {n: {n: set()} for n in plates} # no deps yet
for i, downstream in enumerate(sample_sites):
Expand Down Expand Up @@ -360,7 +358,7 @@ def _resolve_plate_samples(plate_samples):
k: [name for name in trace if name in v] for k, v in plate_samples.items()
}

def get_log_probs(sample):
def get_log_probs(**sample):
class substitute_deterministic(handlers.substitute):
def process_message(self, msg):
if msg["type"] == "deterministic":
Expand All @@ -384,24 +382,17 @@ def process_message(self, msg):
return provenance_arrays

samples = {
name: ProvenanceArray(site["value"], frozenset({name}))
name: site["value"]
for name, site in trace.items()
if (site["type"] == "sample" and not site["is_observed"])
or site["type"] == "deterministic"
}

params = {
name: jax.tree_util.tree_map(
lambda x: ProvenanceArray(x, frozenset({name})), site["value"]
)
for name, site in trace.items()
if site["type"] == "param"
name: site["value"] for name, site in trace.items() if site["type"] == "param"
}

sample_and_params = {**samples, **params}
sample_params_deps = get_provenance(
eval_provenance(get_log_probs, sample_and_params)
)
sample_params_deps = eval_provenance(get_log_probs, **samples, **params)

sample_sample = {}
sample_param = {}
Expand Down

0 comments on commit ac05f92

Please sign in to comment.