Document InferDict and Message typed dicts (#3323)
fritzo committed Feb 9, 2024
1 parent 6337ced commit ca2f93c
Showing 1 changed file with 92 additions and 0 deletions.
92 changes: 92 additions & 0 deletions pyro/poutine/runtime.py
@@ -45,8 +45,48 @@
class InferDict(TypedDict, total=False):
"""
A dictionary that contains information about inference.
This can be used to configure per-site inference strategies, e.g.::
    pyro.sample(
        "x",
        dist.Bernoulli(0.5),
        infer={"enumerate": "parallel"},
    )
Keys:
enumerate (str):
If one of the strings "sequential" or "parallel", enables
enumeration. Parallel enumeration is generally faster but requires
broadcasting-safe operations and static structure.
expand (bool):
Whether to expand the distribution during enumeration. Defaults to
False if missing.
is_auxiliary (bool):
Whether the sample site is auxiliary, e.g. for use in guides that
deterministically transform auxiliary variables. Defaults to False
if missing.
is_observed (bool):
Whether the sample site is observed (i.e. not latent). Defaults to
False if missing.
num_samples (int):
The number of samples to draw. Defaults to 1 if missing.
obs (optional torch.Tensor):
The observed value, or None for latent variables. Defaults to None
if missing.
prior (optional torch.distributions.Distribution):
(internal) For use in GuideMessenger to store the model's prior
distribution (conditioned on upstream sites).
tmc (str):
Which Tensor Monte Carlo approximation to use in TraceTMC_ELBO,
either "diagonal" or "mixture".
was_observed (bool):
(internal) Whether the sample site was originally observed, in the
context of inference via Reweighted Wake Sleep or Compiled
Sequential Importance Sampling.
"""

enumerate: Literal["sequential", "parallel"]
expand: bool
is_auxiliary: bool
is_observed: bool
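
A minimal usage sketch, not part of this commit, showing how the infer keys documented above are passed to pyro.sample; the model and site names here are illustrative:

import torch

import pyro
import pyro.distributions as dist


def model():
    # Enumerate the discrete site "z" in parallel without expanding it.
    z = pyro.sample(
        "z",
        dist.Categorical(torch.ones(3) / 3),
        infer={"enumerate": "parallel", "expand": False},
    )
    # Condition a downstream site on an observed value.
    pyro.sample("x", dist.Normal(z.float(), 1.0), obs=torch.tensor(0.5))
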
@@ -66,6 +106,58 @@ class InferDict(TypedDict, total=False):


class Message(TypedDict, Generic[_P, _T], total=False):
"""
Pyro's internal message type for effect handling.
Messages are stored in trace objects, e.g.::
    trace.nodes["my_site_name"]  # This is a Message.
Keys:
type (str):
The message type, typically one of the strings "sample", "param",
"plate", or "markov", but possibly custom.
name (str):
The site name, typically naming a sample or parameter.
fn (callable):
The distribution or function used to generate the sample.
is_observed (bool):
A flag to indicate whether the value is observed.
args (tuple):
Positional arguments to the distribution or function.
kwargs (dict):
Keyword arguments to the distribution or function.
value (torch.Tensor):
The value of the sample (either observed or sampled).
scale (torch.Tensor):
A scaling factor for the log probability.
mask (bool torch.Tensor):
A bool or boolean-valued tensor used to mask the log probability.
cond_indep_stack (tuple):
The site's local stack of conditional independence metadata.
Immutable.
done (bool):
A flag to indicate whether the message has been handled.
stop (bool):
A flag to stop further processing of the message.
continuation (callable):
A function to call after processing the message.
infer (optional InferDict):
A dictionary of inference parameters.
obs (torch.Tensor):
The observed value.
log_prob (torch.Tensor):
The log probability of the sample.
log_prob_sum (torch.Tensor):
The scalar sum of the log probability over all elements.
unscaled_log_prob (torch.Tensor):
The unscaled log probability.
score_parts (pyro.distributions.ScoreParts):
A collection of score parts.
packed (Message):
A packed message, used during enumeration.
"""

type: str
name: Optional[str]
fn: Callable[_P, _T]
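
A minimal sketch, not part of this commit, of how the Message fields documented above surface on a recorded trace; the model below is illustrative and uses only public Pyro APIs (poutine.trace and Trace.compute_log_prob):

import pyro
import pyro.distributions as dist
from pyro import poutine


def model():
    pyro.sample("my_site_name", dist.Normal(0.0, 1.0))


trace = poutine.trace(model).get_trace()
msg = trace.nodes["my_site_name"]  # A Message (TypedDict) as documented above.
print(msg["type"], msg["name"], msg["is_observed"])  # sample my_site_name False

trace.compute_log_prob()  # Populates log_prob, unscaled_log_prob, log_prob_sum.
print(msg["log_prob"], msg["log_prob_sum"])
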
