Skip to content

Commit

Permalink
Type annotate broadcast, condition, do, enum messengers (#3295)
Browse files Browse the repository at this point in the history
* Type annotate broadcast, condition, do, enum messengers

* InferDict

* TorchDistributionMixin

* simplify

* fix do messenger

* batch_shape & event_shape

* torch.Size
  • Loading branch information
ordabayevy committed Nov 27, 2023
1 parent e3281ca commit c1ebdf3
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 95 deletions.
4 changes: 3 additions & 1 deletion pyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import inspect
from abc import ABCMeta, abstractmethod

import torch

from pyro.distributions.score_parts import ScoreParts

COERCIONS = []
Expand Down Expand Up @@ -122,7 +124,7 @@ def score_parts(self, x, *args, **kwargs):
log_prob=log_prob, score_function=log_prob, entropy_term=0
)

def enumerate_support(self, expand=True):
def enumerate_support(self, expand: bool = True) -> torch.Tensor:
"""
Returns a representation of the parametrized distribution's support,
along the first dimension. This is implemented only by discrete
Expand Down
16 changes: 16 additions & 0 deletions pyro/distributions/torch_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
else self.sample(sample_shape)
)

@property
def batch_shape(self) -> torch.Size:
"""
:return: The shape over which parameters are batched.
:rtype: torch.Size
"""
raise NotImplementedError

@property
def event_shape(self) -> torch.Size:
"""
:return: The shape of a single sample from the distribution (without batching).
:rtype: torch.Size
"""
raise NotImplementedError

@property
def event_dim(self) -> int:
"""
Expand Down
85 changes: 46 additions & 39 deletions pyro/poutine/broadcast_messenger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from pyro.util import ignore_jit_warnings
from typing import List, Optional

from .messenger import Messenger
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message
from pyro.util import ignore_jit_warnings


class BroadcastMessenger(Messenger):
Expand Down Expand Up @@ -38,47 +41,51 @@ class BroadcastMessenger(Messenger):

@staticmethod
@ignore_jit_warnings(["Converting a tensor to a Python boolean"])
def _pyro_sample(msg):
def _pyro_sample(msg: Message) -> None:
"""
:param msg: current message at a trace site.
"""
if msg["done"] or msg["type"] != "sample":
if (
msg["done"]
or msg["type"] != "sample"
or not isinstance(msg["fn"], TorchDistributionMixin)
):
return

dist = msg["fn"]
actual_batch_shape = getattr(dist, "batch_shape", None)
if actual_batch_shape is not None:
target_batch_shape = [
None if size == 1 else size for size in actual_batch_shape
]
for f in msg["cond_indep_stack"]:
if f.dim is None or f.size == -1:
continue
assert f.dim < 0
target_batch_shape = [None] * (
-f.dim - len(target_batch_shape)
) + target_batch_shape
if (
target_batch_shape[f.dim] is not None
and target_batch_shape[f.dim] != f.size
):
raise ValueError(
"Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
f.name,
msg["name"],
f.dim,
f.size,
target_batch_shape[f.dim],
)
)
target_batch_shape[f.dim] = f.size
# Starting from the right, if expected size is None at an index,
# set it to the actual size if it exists, else 1.
for i in range(-len(target_batch_shape) + 1, 1):
if target_batch_shape[i] is None:
target_batch_shape[i] = (
actual_batch_shape[i] if len(actual_batch_shape) >= -i else 1
actual_batch_shape = dist.batch_shape
target_batch_shape = [
None if size == 1 else size for size in actual_batch_shape
]
for f in msg["cond_indep_stack"]:
if f.dim is None or f.size == -1:
continue
assert f.dim < 0
prefix_batch_shape: List[Optional[int]] = [None] * (
-f.dim - len(target_batch_shape)
)
target_batch_shape = prefix_batch_shape + target_batch_shape
if (
target_batch_shape[f.dim] is not None
and target_batch_shape[f.dim] != f.size
):
raise ValueError(
"Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
f.name,
msg["name"],
f.dim,
f.size,
target_batch_shape[f.dim],
)
msg["fn"] = dist.expand(target_batch_shape)
if msg["fn"].has_rsample != dist.has_rsample:
msg["fn"].has_rsample = dist.has_rsample # copy custom attribute
)
target_batch_shape[f.dim] = f.size
# Starting from the right, if expected size is None at an index,
# set it to the actual size if it exists, else 1.
for i in range(-len(target_batch_shape) + 1, 1):
if target_batch_shape[i] is None:
target_batch_shape[i] = (
actual_batch_shape[i] if len(actual_batch_shape) >= -i else 1
)
msg["fn"] = dist.expand(target_batch_shape)
if msg["fn"].has_rsample != dist.has_rsample:
msg["fn"].has_rsample = dist.has_rsample # copy custom attribute
15 changes: 10 additions & 5 deletions pyro/poutine/condition_messenger.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from .messenger import Messenger
from .trace_struct import Trace
from typing import Dict, Union

import torch

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message
from pyro.poutine.trace_struct import Trace


class ConditionMessenger(Messenger):
Expand Down Expand Up @@ -31,7 +36,7 @@ class ConditionMessenger(Messenger):
:returns: stochastic function decorated with a :class:`~pyro.poutine.condition_messenger.ConditionMessenger`
"""

def __init__(self, data):
def __init__(self, data: Union[Dict[str, torch.Tensor], Trace]) -> None:
"""
:param data: a dict or a Trace
Expand All @@ -41,7 +46,7 @@ def __init__(self, data):
super().__init__()
self.data = data

def _pyro_sample(self, msg):
def _pyro_sample(self, msg: Message) -> None:
"""
:param msg: current message at a trace site.
:returns: a sample from the stochastic function at the site.
Expand All @@ -53,6 +58,7 @@ def _pyro_sample(self, msg):
Otherwise, implements default sampling behavior
with no additional effects.
"""
assert isinstance(msg["name"], str)
name = msg["name"]

if name in self.data:
Expand All @@ -61,4 +67,3 @@ def _pyro_sample(self, msg):
else:
msg["value"] = self.data[name]
msg["is_observed"] = msg["value"] is not None
return None
22 changes: 13 additions & 9 deletions pyro/poutine/do_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

import numbers
import warnings
from typing import Dict, Union

import torch

from .messenger import Messenger
from .runtime import apply_stack
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message, apply_stack


class DoMessenger(Messenger):
Expand Down Expand Up @@ -48,17 +49,18 @@ class DoMessenger(Messenger):
:returns: stochastic function decorated with a :class:`~pyro.poutine.do_messenger.DoMessenger`
"""

def __init__(self, data):
def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]):
super().__init__()
self.data = data
self._intervener_id = str(id(self))

def _pyro_sample(self, msg):
def _pyro_sample(self, msg: Message) -> None:
assert isinstance(msg["name"], str)
if (
msg.get("_intervener_id", None) != self._intervener_id
msg.get("_intervener_id") != self._intervener_id
and self.data.get(msg["name"]) is not None
):
if msg.get("_intervener_id", None) is not None:
if msg.get("_intervener_id") is not None:
warnings.warn(
"Attempting to intervene on variable {} multiple times,"
"this is almost certainly incorrect behavior".format(msg["name"]),
Expand All @@ -76,7 +78,11 @@ def _pyro_sample(self, msg):
intervention = self.data[msg["name"]]
msg["name"] = msg["name"] + "__CF" # mangle old name

if isinstance(intervention, (numbers.Number, torch.Tensor)):
if isinstance(intervention, numbers.Number):
msg["value"] = torch.tensor(intervention)
msg["is_observed"] = True
msg["stop"] = True
elif isinstance(intervention, torch.Tensor):
msg["value"] = intervention
msg["is_observed"] = True
msg["stop"] = True
Expand All @@ -86,5 +92,3 @@ def _pyro_sample(self, msg):
type(intervention)
)
)

return None

0 comments on commit c1ebdf3

Please sign in to comment.