Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/primitives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ plate_stack
-----------
.. autofunction:: numpyro.primitives.plate_stack

label_event_dim
---------------
.. autofunction:: numpyro.primitives.label_event_dim

subsample
---------
.. autofunction:: numpyro.primitives.subsample
Expand Down
4 changes: 4 additions & 0 deletions docs/source/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ get_model_relations
^^^^^^^^^^^^^^^^^^^
.. autofunction:: numpyro.infer.inspect.get_model_relations

get_site_dims
^^^^^^^^^^^^^
.. autofunction:: numpyro.infer.inspect.get_site_dims

Visualization Utilities
=======================

Expand Down
2 changes: 2 additions & 0 deletions numpyro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
deterministic,
factor,
get_mask,
label_event_dim,
module,
param,
plate,
Expand Down Expand Up @@ -50,6 +51,7 @@ def _filter_absl_cpu_warning(record):
"get_mask",
"handlers",
"infer",
"label_event_dim",
"module",
"ops",
"optim",
Expand Down
75 changes: 74 additions & 1 deletion numpyro/infer/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import partial
import itertools
from pathlib import Path
from typing import Optional
from typing import Dict, List, Optional

import jax

Expand Down Expand Up @@ -645,6 +645,79 @@ def render_model(
return graph


def get_site_dims(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

naming ideas appreciated, ie maybe it should be get_named_dims instead?

model, model_args=None, model_kwargs=None
) -> Dict[str, Dict[str, List[str]]]:
"""Infers named dimensions based on plates and label_event_dim's for sample and deterministic
sites. This returns a nested dictionary, where each site has a dictionary containing a list of
batch_dims and event_dims

For example for the model::

def model(X, z, y):
n_groups = len(np.unique(z))
m = numpyro.sample('m', dist.Normal(0, 1))
sd = numpyro.sample('sd', dist.LogNormal(m, 1))
with numpyro.label_event_dim("groups", n_groups):
gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(n_groups,)))

with numpyro.plate('N', len(X)):
mu = numpyro.deterministic("mu", m + gamma[z])
numpyro.sample('obs', dist.Normal(m, sd), obs=y)

the site dims are::

{'gamma': {'batch_dims': [], 'event_dims': ["groups"]},
'mu': {'batch_dims': ["N"], 'event_dims': []},
'obs': {'batch_dims': ["N"], 'event_dims': []}}

:param callable model: A model to inspect.
:param model_args: Optional tuple of model args.
:param model_kwargs: Optional dict of model kwargs.
:rtype: dict
"""
model_args = model_args or ()
model_kwargs = model_kwargs or {}

def _get_dist_name(fn):
if isinstance(
fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)
):
return _get_dist_name(fn.base_dist)
return type(fn).__name__

def get_trace():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the same as the get_trace function in get_model_relations. I could externalize it and call that within get_model_relations and here to reduce code (at the cost of making these 2 inspection tools share a dependency)

# We use `init_to_sample` to get around ImproperUniform distribution,
# which does not have `sample` method.
subs_model = handlers.substitute(
handlers.seed(model, 0),
substitute_fn=init_to_sample,
)
trace = handlers.trace(subs_model).get_trace(*model_args, **model_kwargs)
# Work around an issue where jax.eval_shape does not work
# for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
# Here we will remove `fn` and store its name in the trace.
for name, site in trace.items():
if site["type"] == "sample":
site["fn_name"] = _get_dist_name(site.pop("fn"))
elif site["type"] == "deterministic":
site["fn_name"] = "Deterministic"
return PytreeTrace(trace)

# We use eval_shape to avoid any array computation.
trace = jax.eval_shape(get_trace).trace

named_dims = {}

for name, site in trace.items():
batch_dims = [frame.name for frame in site["cond_indep_stack"]]
event_dims = [frame.name for frame in site.get("dep_stack", [])]
if site["type"] in ["sample", "deterministic"] and (batch_dims or event_dims):
named_dims[name] = {"batch_dims": batch_dims, "event_dims": event_dims}

return named_dims


__all__ = [
"get_dependencies",
"get_model_relations",
Expand Down
102 changes: 101 additions & 1 deletion numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from contextlib import ExitStack, contextmanager
import functools
from types import TracebackType
from typing import Any, Callable, Generator, Optional, Union, cast
from typing import Any, Callable, Generator, List, Optional, Tuple, Union, cast
import warnings

import jax
Expand All @@ -24,6 +24,7 @@


CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "dim", "size"])
DepStackFrame = namedtuple("DepStackFrame", ["name", "dim", "size"])


def default_process_message(msg: Message) -> None:
Expand Down Expand Up @@ -649,6 +650,105 @@ def plate_stack(
yield


class label_event_dim(plate):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be good to break the dependence on plate and inherit Messenger instead - this was a quick hack to get things working

"""This labels event dimensions, modeled after numpyro.plate. Unlike numpyro.plate, it will not
change the shape of any sites within its context.

Labelled event dims can be found in `dep_stack` in the model trace.

**example**

with label_event_dim("groups", n_groups):
alpha = numpyro.sample("alpha", dist.ZeroSumNormal(1, event_shape=(n_groups,)))
"""

def __init__(
self,
name: str,
size: int,
subsample_size: Optional[int] = None,
Copy link
Contributor Author

@kylejcaron kylejcaron Mar 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still considering if size and subsample_size are needed. right now they dont do anything, but I could potentially use size for some form of shape validation

dim: Optional[int] = None,
) -> None:
self.name = name
self.size = size
if dim is not None and dim >= 0:
raise ValueError("dim arg must be negative.")
self.dim, self._indices = self._subsample(
self.name, self.size, subsample_size, dim
)
self.subsample_size = self._indices.shape[0]

# We'll try only adding our pseudoplate to the CondIndepStack without doing anything else
def process_message(self, msg: Message) -> None:
if msg["type"] not in ("param", "sample", "plate", "deterministic"):
if msg["type"] == "control_flow":
raise NotImplementedError(
"Cannot use control flow primitive under a `plate` primitive."
" Please move those `plate` statements into the control flow"
" body function. See `scan` documentation for more information."
)
return

if (
"block_plates" in msg.get("infer", {})
and self.name in msg["infer"]["block_plates"]
):
return

frame = DepStackFrame(self.name, self.dim, self.subsample_size)
msg["dep_stack"] = msg.get("dep_stack", []) + [frame]

if msg["type"] == "deterministic":
return

def _get_event_shape(self, dep_stack: List[DepStackFrame]) -> Tuple[int, ...]:
n_dims = max(-f.dim for f in dep_stack)
event_shape = [1] * n_dims
for f in dep_stack:
event_shape[f.dim] = f.size
return tuple(event_shape)

# We need to make sure dims get arranged properly when there are multiple plates
@staticmethod
def _subsample(name, size, subsample_size, dim):
msg = {
"type": "plate",
"fn": _subsample_fn,
"name": name,
"args": (size, subsample_size),
"kwargs": {"rng_key": None},
"value": (
None
if (subsample_size is not None and size != subsample_size)
else jnp.arange(size)
),
"scale": 1.0,
"cond_indep_stack": [],
"dep_stack": [],
}
apply_stack(msg)
subsample = msg["value"]
subsample_size = msg["args"][1]
if subsample_size is not None and subsample_size != subsample.shape[0]:
warnings.warn(
"subsample_size does not match len(subsample), {} vs {}.".format(
subsample_size, len(subsample)
)
+ " Did you accidentally use different subsample_size in the model and guide?",
stacklevel=find_stack_level(),
)
dep_stack = msg["dep_stack"]
occupied_dims = {f.dim for f in dep_stack}
if dim is None:
new_dim = -1
while new_dim in occupied_dims:
new_dim -= 1
dim = new_dim
else:
assert dim not in occupied_dims
return dim, subsample


def factor(name: str, log_factor: ArrayLike) -> None:
"""
Factor statement to add arbitrary log probability factor to a
Expand Down
108 changes: 107 additions & 1 deletion test/infer/test_inspect.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from functools import partial

import numpy as np

import jax
import jax.numpy as jnp

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer.inspect import get_dependencies
from numpyro.infer.initialization import init_to_sample
from numpyro.infer.inspect import get_dependencies, get_site_dims
from numpyro.ops.pytree import PytreeTrace


class NonreparameterizedNormal(dist.Normal):
Expand Down Expand Up @@ -434,3 +439,104 @@ def model():
},
}
assert actual == expected


def test_label_event_dim():
def get_trace(model):
# We use `init_to_sample` to get around ImproperUniform distribution,
# which does not have `sample` method.

def _get_dist_name(fn):
if isinstance(
fn,
(dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution),
):
return _get_dist_name(fn.base_dist)
return type(fn).__name__

subs_model = handlers.substitute(
handlers.seed(model, 0),
substitute_fn=init_to_sample,
)
trace = handlers.trace(subs_model).get_trace()
# Work around an issue where jax.eval_shape does not work
# for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
# Here we will remove `fn` and store its name in the trace.
for _, site in trace.items():
if site["type"] == "sample":
site["fn_name"] = _get_dist_name(site.pop("fn"))
elif site["type"] == "deterministic":
site["fn_name"] = "Deterministic"
return PytreeTrace(trace)

def model1():
with numpyro.label_event_dim("dim1", 10):
param = numpyro.sample("param", dist.ZeroSumNormal(1, event_shape=(10,)))
_ = numpyro.deterministic("transformed_param", param + 1)

def model2():
with numpyro.label_event_dim("dim2", 5), numpyro.label_event_dim("dim1", 10):
numpyro.sample("param", dist.ZeroSumNormal(1, event_shape=(10, 5)))

def model3():
with numpyro.plate("dim1", 5), numpyro.label_event_dim("dim2", 10):
numpyro.sample("param", dist.ZeroSumNormal(1, event_shape=(10,)))

# expected dims, dim names
expected_results = [
([-1], ["dim1"]),
([-2, -1], ["dim1", "dim2"]),
([-1], ["dim2"]),
]
for model, (exp_dims, exp_names) in zip([model1, model2, model3], expected_results):
trace = jax.eval_shape(partial(get_trace, model)).trace
sites = {
name: site
for name, site in trace.items()
if site["type"] in ["sample", "deterministic"]
}
for _, metadata in sites.items():
# make sure the new event dim stack is present
assert (
"dep_stack" in metadata.keys() and "cond_indep_stack" in metadata.keys()
)

# make sure the dims are as expected
assert exp_dims == [plate.dim for plate in metadata["dep_stack"]]

# make sure the dim names are as expected
assert exp_names == [plate.name for plate in metadata["dep_stack"]]


def test_get_site_dims():
def model1():
with numpyro.label_event_dim("dim1", 10):
param = numpyro.sample("param", dist.ZeroSumNormal(1, event_shape=(10,)))
_ = numpyro.deterministic("transformed_param", param + 1)
with numpyro.plate("obs_idx", 3):
_ = numpyro.sample(
"obs", dist.Normal(0, 1), obs=jnp.array([-1.0, 0.0, 1.0])
)

def model2():
_ = numpyro.sample("unplated_param", dist.Normal(0, 1))
with numpyro.label_event_dim("dim2", 5), numpyro.label_event_dim("dim1", 10):
numpyro.sample("param", dist.ZeroSumNormal(1, event_shape=(10, 5)))

def model3():
with numpyro.plate("dim1", 5), numpyro.label_event_dim("dim2", 10):
numpyro.sample("param", dist.ZeroSumNormal(1, event_shape=(10,)))

expected_results = [
{
"param": {"batch_dims": [], "event_dims": ["dim1"]},
"transformed_param": {"batch_dims": [], "event_dims": ["dim1"]},
"obs": {"batch_dims": ["obs_idx"], "event_dims": []},
},
{"param": {"batch_dims": [], "event_dims": ["dim1", "dim2"]}},
{"param": {"batch_dims": ["dim1"], "event_dims": ["dim2"]}},
]

for model, expected in zip([model1, model2, model3], expected_results):
site_dims = get_site_dims(model)
assert site_dims == expected
Loading