-
Notifications
You must be signed in to change notification settings - Fork 269
[draft] Event dim labelling primitive for arviz dim labelling #2012
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1f508c3
c596cf5
8a61990
67ed1fa
87bdb9b
184c2c8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
from functools import partial | ||
import itertools | ||
from pathlib import Path | ||
from typing import Optional | ||
from typing import Dict, List, Optional | ||
|
||
import jax | ||
|
||
|
@@ -645,6 +645,79 @@ def render_model( | |
return graph | ||
|
||
|
||
def get_site_dims( | ||
model, model_args=None, model_kwargs=None | ||
) -> Dict[str, Dict[str, List[str]]]: | ||
"""Infers named dimensions based on plates and label_event_dim's for sample and deterministic | ||
sites. This returns a nested dictionary, where each site has a dictionary containing a list of | ||
batch_dims and event_dims | ||
|
||
For example for the model:: | ||
|
||
def model(X, z, y): | ||
n_groups = len(np.unique(z)) | ||
m = numpyro.sample('m', dist.Normal(0, 1)) | ||
sd = numpyro.sample('sd', dist.LogNormal(m, 1)) | ||
with numpyro.label_event_dim("groups", n_groups): | ||
gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(n_groups,))) | ||
|
||
with numpyro.plate('N', len(X)): | ||
mu = numpyro.deterministic("mu", m + gamma[z]) | ||
numpyro.sample('obs', dist.Normal(m, sd), obs=y) | ||
|
||
the site dims are:: | ||
|
||
{'gamma': {'batch_dims': [], 'event_dims': ["groups"]}, | ||
'mu': {'batch_dims': ["N"], 'event_dims': []}, | ||
'obs': {'batch_dims': ["N"], 'event_dims': []}} | ||
|
||
:param callable model: A model to inspect. | ||
:param model_args: Optional tuple of model args. | ||
:param model_kwargs: Optional dict of model kwargs. | ||
:rtype: dict | ||
""" | ||
model_args = model_args or () | ||
model_kwargs = model_kwargs or {} | ||
|
||
def _get_dist_name(fn): | ||
if isinstance( | ||
fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution) | ||
): | ||
return _get_dist_name(fn.base_dist) | ||
return type(fn).__name__ | ||
|
||
def get_trace(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is the same as the get_trace function in |
||
# We use `init_to_sample` to get around ImproperUniform distribution, | ||
# which does not have `sample` method. | ||
subs_model = handlers.substitute( | ||
handlers.seed(model, 0), | ||
substitute_fn=init_to_sample, | ||
) | ||
trace = handlers.trace(subs_model).get_trace(*model_args, **model_kwargs) | ||
# Work around an issue where jax.eval_shape does not work | ||
# for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`) | ||
# Here we will remove `fn` and store its name in the trace. | ||
for name, site in trace.items(): | ||
if site["type"] == "sample": | ||
site["fn_name"] = _get_dist_name(site.pop("fn")) | ||
elif site["type"] == "deterministic": | ||
site["fn_name"] = "Deterministic" | ||
return PytreeTrace(trace) | ||
|
||
# We use eval_shape to avoid any array computation. | ||
trace = jax.eval_shape(get_trace).trace | ||
|
||
named_dims = {} | ||
|
||
for name, site in trace.items(): | ||
batch_dims = [frame.name for frame in site["cond_indep_stack"]] | ||
event_dims = [frame.name for frame in site.get("dep_stack", [])] | ||
if site["type"] in ["sample", "deterministic"] and (batch_dims or event_dims): | ||
named_dims[name] = {"batch_dims": batch_dims, "event_dims": event_dims} | ||
|
||
return named_dims | ||
|
||
|
||
__all__ = [ | ||
"get_dependencies", | ||
"get_model_relations", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
from contextlib import ExitStack, contextmanager | ||
import functools | ||
from types import TracebackType | ||
from typing import Any, Callable, Generator, Optional, Union, cast | ||
from typing import Any, Callable, Generator, List, Optional, Tuple, Union, cast | ||
import warnings | ||
|
||
import jax | ||
|
@@ -24,6 +24,7 @@ | |
|
||
|
||
CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"]) | ||
DepStackFrame = namedtuple("DepStackFrame", ["name", "dim", "size"]) | ||
|
||
|
||
def default_process_message(msg: Message) -> None: | ||
|
@@ -649,6 +650,105 @@ def plate_stack( | |
yield | ||
|
||
|
||
class label_event_dim(plate): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. might be good to break the dependence on |
||
"""This labels event dimensions, modeled after numpyro.plate. Unlike numpyro.plate, it will not | ||
change the shape of any sites within its context. | ||
|
||
Labelled event dims can be found in `dep_stack` in the model trace. | ||
|
||
**example** | ||
|
||
with label_event_dim("groups", n_groups): | ||
alpha = numpyro.sample("alpha", dist.ZeroSumNormal(1, event_shape=(n_groups,))) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
name: str, | ||
size: int, | ||
subsample_size: Optional[int] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. still considering if |
||
dim: Optional[int] = None, | ||
) -> None: | ||
self.name = name | ||
self.size = size | ||
if dim is not None and dim >= 0: | ||
raise ValueError("dim arg must be negative.") | ||
self.dim, self._indices = self._subsample( | ||
self.name, self.size, subsample_size, dim | ||
) | ||
self.subsample_size = self._indices.shape[0] | ||
|
||
# We'll try only adding our pseudoplate to the CondIndepStack without doing anything else | ||
def process_message(self, msg: Message) -> None: | ||
if msg["type"] not in ("param", "sample", "plate", "deterministic"): | ||
if msg["type"] == "control_flow": | ||
raise NotImplementedError( | ||
"Cannot use control flow primitive under a `plate` primitive." | ||
" Please move those `plate` statements into the control flow" | ||
" body function. See `scan` documentation for more information." | ||
) | ||
return | ||
|
||
if ( | ||
"block_plates" in msg.get("infer", {}) | ||
and self.name in msg["infer"]["block_plates"] | ||
): | ||
return | ||
|
||
frame = DepStackFrame(self.name, self.dim, self.subsample_size) | ||
msg["dep_stack"] = msg.get("dep_stack", []) + [frame] | ||
|
||
if msg["type"] == "deterministic": | ||
return | ||
|
||
def _get_event_shape(self, dep_stack: List[DepStackFrame]) -> Tuple[int, ...]: | ||
n_dims = max(-f.dim for f in dep_stack) | ||
event_shape = [1] * n_dims | ||
for f in dep_stack: | ||
event_shape[f.dim] = f.size | ||
return tuple(event_shape) | ||
|
||
# We need to make sure dims get arranged properly when there are multiple plates | ||
@staticmethod | ||
def _subsample(name, size, subsample_size, dim): | ||
msg = { | ||
"type": "plate", | ||
"fn": _subsample_fn, | ||
"name": name, | ||
"args": (size, subsample_size), | ||
"kwargs": {"rng_key": None}, | ||
"value": ( | ||
None | ||
if (subsample_size is not None and size != subsample_size) | ||
else jnp.arange(size) | ||
), | ||
"scale": 1.0, | ||
"cond_indep_stack": [], | ||
"dep_stack": [], | ||
} | ||
apply_stack(msg) | ||
subsample = msg["value"] | ||
subsample_size = msg["args"][1] | ||
if subsample_size is not None and subsample_size != subsample.shape[0]: | ||
warnings.warn( | ||
"subsample_size does not match len(subsample), {} vs {}.".format( | ||
subsample_size, len(subsample) | ||
) | ||
+ " Did you accidentally use different subsample_size in the model and guide?", | ||
stacklevel=find_stack_level(), | ||
) | ||
dep_stack = msg["dep_stack"] | ||
occupied_dims = {f.dim for f in dep_stack} | ||
if dim is None: | ||
new_dim = -1 | ||
while new_dim in occupied_dims: | ||
new_dim -= 1 | ||
dim = new_dim | ||
else: | ||
assert dim not in occupied_dims | ||
return dim, subsample | ||
|
||
|
||
def factor(name: str, log_factor: ArrayLike) -> None: | ||
""" | ||
Factor statement to add arbitrary log probability factor to a | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
naming ideas appreciated, ie maybe it should be
get_named_dims
instead?