Skip to content

Commit

Permalink
Type annotate Trace, TraceMessenger, & pyro.poutine.guide (#3299)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy committed Dec 3, 2023
1 parent 1b80fc2 commit 7233cf9
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 77 deletions.
18 changes: 13 additions & 5 deletions pyro/distributions/score_parts.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
from typing import NamedTuple, Optional, Union

import torch

from pyro.distributions.util import scale_and_mask


class ScoreParts(
namedtuple("ScoreParts", ["log_prob", "score_function", "entropy_term"])
):
class ScoreParts(NamedTuple):
"""
This data structure stores terms used in stochastic gradient estimators that
combine the pathwise estimator and the score function estimator.
"""

def scale_and_mask(self, scale=1.0, mask=None):
log_prob: torch.Tensor
score_function: torch.Tensor
entropy_term: torch.Tensor

def scale_and_mask(
self,
scale: Union[float, torch.Tensor] = 1.0,
mask: Optional[torch.BoolTensor] = None,
) -> "ScoreParts":
"""
Scale and mask appropriate terms of a gradient estimator by a data multiplicity factor.
Note that the `score_function` term should not be scaled or masked.
Expand Down
2 changes: 1 addition & 1 deletion pyro/distributions/torch_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def infer_shapes(cls, **arg_shapes):
event_shape = torch.Size()
return batch_shape, event_shape

def expand(self, batch_shape, _instance=None):
def expand(self, batch_shape, _instance=None) -> "ExpandedDistribution":
"""
Returns a new :class:`ExpandedDistribution` instance with batch
dimensions expanded to `batch_shape`.
Expand Down
52 changes: 29 additions & 23 deletions pyro/poutine/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Callable, Dict, Tuple, Union
from typing import Callable, Dict, Optional, Tuple, Union

import torch

import pyro.distributions as dist
from pyro.distributions.distribution import Distribution

from .trace_messenger import TraceMessenger
from .trace_struct import Trace
from .util import prune_subsample_sites, site_is_subsample
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.runtime import Message
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites, site_is_subsample


class GuideMessenger(TraceMessenger, ABC):
Expand All @@ -21,19 +21,19 @@ class GuideMessenger(TraceMessenger, ABC):
Derived classes must implement the :meth:`get_posterior` method.
"""

def __init__(self, model: Callable):
def __init__(self, model: Callable) -> None:
super().__init__()
# Do not register model as submodule
self._model = (model,)

@property
def model(self):
def model(self) -> Callable:
return self._model[0]

def __getstate__(self):
def __getstate__(self) -> Dict[str, object]:
# Avoid pickling the trace.
state = super().__getstate__()
state.pop("trace")
state = self.__dict__.copy()
del state["trace"]
return state

def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore[override]
Expand All @@ -53,16 +53,19 @@ def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore[
del self.args_kwargs

model_trace, guide_trace = self.get_traces()
samples = {
name: site["value"]
for name, site in model_trace.nodes.items()
if site["type"] == "sample"
}
samples = {}
for name, site in model_trace.nodes.items():
if site["type"] == "sample":
assert isinstance(site["value"], torch.Tensor)
samples[name] = site["value"]
return samples

def _pyro_sample(self, msg):
def _pyro_sample(self, msg: Message) -> None:
if msg["is_observed"] or site_is_subsample(msg):
return
assert isinstance(msg["name"], str)
assert isinstance(msg["fn"], TorchDistributionMixin)
assert msg["infer"] is not None
prior = msg["fn"]
msg["infer"]["prior"] = prior
posterior = self.get_posterior(msg["name"], prior)
Expand All @@ -72,17 +75,20 @@ def _pyro_sample(self, msg):
posterior = posterior.expand(prior.batch_shape)
msg["fn"] = posterior

def _pyro_post_sample(self, msg):
def _pyro_post_sample(self, msg: Message) -> None:
# Manually apply outer plates.
assert msg["infer"] is not None
prior = msg["infer"].get("prior")
if prior is not None and prior.batch_shape != msg["fn"].batch_shape:
msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape)
if prior is not None:
assert isinstance(msg["fn"], TorchDistributionMixin)
if prior.batch_shape != msg["fn"].batch_shape:
msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape)
return super()._pyro_post_sample(msg)

@abstractmethod
def get_posterior(
self, name: str, prior: Distribution
) -> Union[Distribution, torch.Tensor]:
self, name: str, prior: TorchDistributionMixin
) -> Union[TorchDistributionMixin, torch.Tensor]:
"""
Abstract method to compute a posterior distribution or sample a
posterior value given a prior distribution conditioned on upstream
Expand Down Expand Up @@ -112,7 +118,7 @@ def get_posterior(
"""
raise NotImplementedError

def upstream_value(self, name: str):
def upstream_value(self, name: str) -> Optional[torch.Tensor]:
"""
For use in :meth:`get_posterior` .
Expand Down
9 changes: 9 additions & 0 deletions pyro/poutine/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
T = TypeVar("T")

if TYPE_CHECKING:
from pyro.distributions.score_parts import ScoreParts
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.indep_messenger import CondIndepStackFrame
from pyro.poutine.messenger import Messenger
Expand All @@ -49,9 +50,12 @@ class InferDict(TypedDict, total=False):
is_auxiliary: bool
is_observed: bool
num_samples: int
prior: TorchDistributionMixin
tmc: Literal["diagonal", "mixture"]
_deterministic: bool
_dim_to_symbol: Dict[int, str]
_do_not_trace: bool
_enumerate_symbol: str
_markov_scope: Optional[Dict[str, int]]
_enumerate_dim: int
_dim_to_id: Dict[int, int]
Expand All @@ -74,6 +78,11 @@ class Message(TypedDict, total=False):
continuation: Optional[Callable[[Message], None]]
infer: Optional[InferDict]
obs: Optional[torch.Tensor]
log_prob: torch.Tensor
log_prob_sum: torch.Tensor
unscaled_log_prob: torch.Tensor
score_parts: ScoreParts
packed: "Message"
_intervener_id: Optional[str]


Expand Down
45 changes: 29 additions & 16 deletions pyro/poutine/trace_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
# SPDX-License-Identifier: Apache-2.0

import sys
from typing import Any, Callable, Literal, Optional

from .messenger import Messenger
from .trace_struct import Trace
from .util import site_is_subsample
from typing_extensions import Self

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import site_is_subsample

def identify_dense_edges(trace):

def identify_dense_edges(trace: Trace) -> None:
"""
Modifies a trace in-place by adding all edges based on the
`cond_indep_stack` information stored at each site.
Expand Down Expand Up @@ -63,7 +67,11 @@ class TraceMessenger(Messenger):
:returns: stochastic function decorated with a :class:`~pyro.poutine.trace_messenger.TraceMessenger`
"""

def __init__(self, graph_type=None, param_only=None):
def __init__(
self,
graph_type: Optional[Literal["flat", "dense"]] = None,
param_only: Optional[bool] = None,
) -> None:
"""
:param string graph_type: string that specifies the type of graph
to construct (currently only "flat" or "dense" supported)
Expand All @@ -79,30 +87,31 @@ def __init__(self, graph_type=None, param_only=None):
self.param_only = param_only
self.trace = Trace(graph_type=self.graph_type)

def __enter__(self):
def __enter__(self) -> Self:
self.trace = Trace(graph_type=self.graph_type)
return super().__enter__()

def __exit__(self, *args, **kwargs):
def __exit__(self, *args, **kwargs) -> None:
"""
Adds appropriate edges based on cond_indep_stack information
upon exiting the context.
"""
if self.param_only:
for node in list(self.trace.nodes.values()):
if node["type"] != "param":
assert node["name"] is not None
self.trace.remove_node(node["name"])
if self.graph_type == "dense":
identify_dense_edges(self.trace)
return super().__exit__(*args, **kwargs)

def __call__(self, fn):
def __call__(self, fn: Callable) -> "TraceHandler": # type: ignore[override]
"""
TODO docs
"""
return TraceHandler(self, fn)

def get_trace(self):
def get_trace(self) -> Trace:
"""
:returns: data structure
:rtype: pyro.poutine.Trace
Expand All @@ -112,7 +121,7 @@ def get_trace(self):
"""
return self.trace.copy()

def _reset(self):
def _reset(self) -> None:
tr = Trace(graph_type=self.graph_type)
if "_INPUT" in self.trace.nodes:
tr.add_node(
Expand All @@ -125,16 +134,19 @@ def _reset(self):
self.trace = tr
super()._reset()

def _pyro_post_sample(self, msg):
def _pyro_post_sample(self, msg: Message) -> None:
if self.param_only:
return
assert msg["name"] is not None
assert msg["infer"] is not None
if msg["infer"].get("_do_not_trace"):
assert msg["infer"].get("is_auxiliary")
assert not msg["is_observed"]
return
self.trace.add_node(msg["name"], **msg.copy())

def _pyro_post_param(self, msg):
def _pyro_post_param(self, msg: Message) -> None:
assert msg["name"] is not None
self.trace.add_node(msg["name"], **msg.copy())


Expand All @@ -150,11 +162,11 @@ class TraceHandler:
We can also use this for visualization.
"""

def __init__(self, msngr, fn):
def __init__(self, msngr: TraceMessenger, fn: Callable):
self.fn = fn
self.msngr = msngr

def __call__(self, *args, **kwargs):
def __call__(self, *args, **kwargs) -> Any:
"""
Runs the stochastic function stored in this poutine,
with additional side effects.
Expand All @@ -175,6 +187,7 @@ def __call__(self, *args, **kwargs):
except (ValueError, RuntimeError) as e:
exc_type, exc_value, traceback = sys.exc_info()
shapes = self.msngr.trace.format_shapes()
assert exc_type is not None
exc = exc_type("{}\n{}".format(exc_value, shapes))
exc = exc.with_traceback(traceback)
raise exc from e
Expand All @@ -184,10 +197,10 @@ def __call__(self, *args, **kwargs):
return ret

@property
def trace(self):
def trace(self) -> Trace:
return self.msngr.trace

def get_trace(self, *args, **kwargs):
def get_trace(self, *args, **kwargs) -> Trace:
"""
:returns: data structure
:rtype: pyro.poutine.Trace
Expand Down

0 comments on commit 7233cf9

Please sign in to comment.