Skip to content

Commit

Permalink
Type annotate messengers (#3308)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy committed Dec 24, 2023
1 parent 579163a commit 5d920aa
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 74 deletions.
2 changes: 1 addition & 1 deletion pyro/poutine/escape_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _pyro_sample(self, msg: Message) -> None:
msg["done"] = True
msg["stop"] = True

def cont(m):
def cont(m: Message) -> None:
raise NonlocalExit(m)

msg["continuation"] = cont
2 changes: 1 addition & 1 deletion pyro/poutine/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class Message(TypedDict, total=False):
args: Tuple
kwargs: Dict
value: Optional[torch.Tensor]
scale: float
scale: Union[torch.Tensor, float]
mask: Union[bool, torch.Tensor, None]
cond_indep_stack: Tuple[CondIndepStackFrame, ...]
done: bool
Expand Down
11 changes: 6 additions & 5 deletions pyro/poutine/scale_messenger.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Union

import torch

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message
from pyro.poutine.util import is_validation_enabled

from .messenger import Messenger


class ScaleMessenger(Messenger):
"""
Expand All @@ -33,7 +35,7 @@ class ScaleMessenger(Messenger):
:returns: stochastic function decorated with a :class:`~pyro.poutine.scale_messenger.ScaleMessenger`
"""

def __init__(self, scale):
def __init__(self, scale: Union[float, torch.Tensor]) -> None:
if isinstance(scale, torch.Tensor):
if is_validation_enabled() and not (scale > 0).all():
raise ValueError(
Expand All @@ -45,6 +47,5 @@ def __init__(self, scale):
super().__init__()
self.scale = scale

def _process_message(self, msg):
def _process_message(self, msg: Message) -> None:
msg["scale"] = self.scale * msg["scale"]
return None
17 changes: 12 additions & 5 deletions pyro/poutine/seed_messenger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from pyro.util import get_rng_state, set_rng_seed, set_rng_state
from types import TracebackType
from typing import Optional, Type

from .messenger import Messenger
from pyro.poutine.messenger import Messenger
from pyro.util import get_rng_state, set_rng_seed, set_rng_state


class SeedMessenger(Messenger):
Expand All @@ -18,14 +20,19 @@ class SeedMessenger(Messenger):
:param int rng_seed: rng seed.
"""

def __init__(self, rng_seed):
def __init__(self, rng_seed: int) -> None:
assert isinstance(rng_seed, int)
self.rng_seed = rng_seed
super().__init__()

def __enter__(self):
def __enter__(self) -> None: # type: ignore[override]
self.old_state = get_rng_state()
set_rng_seed(self.rng_seed)

def __exit__(self, type, value, traceback):
def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
traceback: Optional[TracebackType],
) -> None:
set_rng_state(self.old_state)
115 changes: 63 additions & 52 deletions pyro/poutine/subsample_messenger.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch

from pyro.distributions.distribution import Distribution
from pyro.poutine.indep_messenger import CondIndepStackFrame, IndepMessenger
from pyro.poutine.runtime import Message, apply_stack
from pyro.poutine.util import is_validation_enabled
from pyro.util import ignore_jit_warnings

from .indep_messenger import CondIndepStackFrame, IndepMessenger
from .runtime import apply_stack


class _Subsample(Distribution):
"""
Expand All @@ -18,7 +19,13 @@ class _Subsample(Distribution):
Internal use only. This should only be used by `plate`.
"""

def __init__(self, size, subsample_size, use_cuda=None, device=None):
def __init__(
self,
size: int,
subsample_size: Optional[int],
use_cuda: Optional[bool] = None,
device: Optional[str] = None,
) -> None:
"""
:param int size: the size of the range to subsample from
:param int subsample_size: the size of the returned subsample
Expand All @@ -38,10 +45,10 @@ def __init__(self, size, subsample_size, use_cuda=None, device=None):
)
)
with ignore_jit_warnings(["torch.Tensor results are registered as constants"]):
self.device = torch.Tensor().device if not device else device
self.device = device or torch.Tensor().device

@ignore_jit_warnings(["Converting a tensor to a Python boolean"])
def sample(self, sample_shape=torch.Size()):
def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
"""
:returns: a random subsample of `range(size)`
:rtype: torch.LongTensor
Expand All @@ -57,7 +64,7 @@ def sample(self, sample_shape=torch.Size()):
].clone()
return result.cuda() if self.use_cuda else result

def log_prob(self, x):
def log_prob(self, x: torch.Tensor) -> torch.Tensor:
# This is zero so that plate can provide an unbiased estimate of
# the non-subsampled log_prob.
result = torch.tensor(0.0, device=self.device)
Expand All @@ -71,33 +78,34 @@ class SubsampleMessenger(IndepMessenger):

def __init__(
self,
name,
size=None,
subsample_size=None,
subsample=None,
dim=None,
use_cuda=None,
device=None,
):
super().__init__(name, size, dim, device)
self.subsample_size = subsample_size
self._indices = subsample
self.use_cuda = use_cuda
self.device = device

self.size, self.subsample_size, self._indices = self._subsample(
self.name,
self.size,
self.subsample_size,
self._indices,
self.use_cuda,
self.device,
name: str,
size: Optional[int] = None,
subsample_size: Optional[int] = None,
subsample: Optional[torch.Tensor] = None,
dim: Optional[int] = None,
use_cuda: Optional[bool] = None,
device: Optional[str] = None,
) -> None:
full_size, self.subsample_size, subsample = self._subsample(
name,
size,
subsample_size,
subsample,
use_cuda,
device,
)
super().__init__(name, full_size, dim, device)
self._indices = subsample

@staticmethod
def _subsample(
name, size=None, subsample_size=None, subsample=None, use_cuda=None, device=None
):
name: str,
size: Optional[int] = None,
subsample_size: Optional[int] = None,
subsample: Optional[torch.Tensor] = None,
use_cuda: Optional[bool] = None,
device: Optional[str] = None,
) -> Tuple[int, int, Optional[torch.Tensor]]:
"""
Helper function for plate. See its docstrings for details.
"""
Expand All @@ -107,27 +115,28 @@ def _subsample(
size = -1 # This is PyTorch convention for "arbitrary size"
subsample_size = -1
else:
msg = {
"type": "sample",
"name": name,
"fn": _Subsample(size, subsample_size, use_cuda, device),
"is_observed": False,
"args": (),
"kwargs": {},
"value": subsample,
"infer": {},
"scale": 1.0,
"mask": None,
"cond_indep_stack": (),
"done": False,
"stop": False,
"continuation": None,
}
msg = Message(
type="sample",
name=name,
fn=_Subsample(size, subsample_size, use_cuda, device),
is_observed=False,
args=(),
kwargs={},
value=subsample,
infer={},
scale=1.0,
mask=None,
cond_indep_stack=(),
done=False,
stop=False,
continuation=None,
)
apply_stack(msg)
subsample = msg["value"]

with ignore_jit_warnings():
if subsample_size is None:
assert subsample is not None
subsample_size = (
subsample.size(0)
if isinstance(subsample, torch.Tensor)
Expand All @@ -143,11 +152,11 @@ def _subsample(

return size, subsample_size, subsample

def _reset(self):
def _reset(self) -> None:
self._indices = None
super()._reset()

def _process_message(self, msg):
def _process_message(self, msg: Message) -> None:
frame = CondIndepStackFrame(
name=self.name,
dim=self.dim,
Expand All @@ -164,12 +173,13 @@ def _process_message(self, msg):
msg["scale"] = torch.tensor(msg["scale"])
msg["scale"] = msg["scale"] * self.size / self.subsample_size

def _postprocess_message(self, msg):
def _postprocess_message(self, msg: Message) -> None:
if msg["type"] in ("param", "subsample") and self.dim is not None:
event_dim = msg["kwargs"].get("event_dim")
if event_dim is not None:
assert event_dim >= 0
dim = self.dim - event_dim
assert msg["value"] is not None
shape = msg["value"].shape
if len(shape) >= -dim and shape[dim] != 1:
if is_validation_enabled() and shape[dim] != self.size:
Expand All @@ -189,18 +199,19 @@ def _postprocess_message(self, msg):
# Subsample parameters with known batch semantics.
if self.subsample_size < self.size:
value = msg["value"]
assert self._indices is not None
new_value = value.index_select(
dim, self._indices.to(value.device)
)
if msg["type"] == "param":
if hasattr(value, "_pyro_unconstrained_param"):
param = value._pyro_unconstrained_param
param = value._pyro_unconstrained_param # type: ignore[attr-defined]
else:
param = value.unconstrained()
param = value.unconstrained() # type: ignore[attr-defined]

if not hasattr(param, "_pyro_subsample"):
param._pyro_subsample = {}

param._pyro_subsample[dim] = self._indices
new_value._pyro_unconstrained_param = param
new_value._pyro_unconstrained_param = param # type: ignore[attr-defined]
msg["value"] = new_value
26 changes: 16 additions & 10 deletions pyro/poutine/substitute_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Dict, Set

import torch
from typing_extensions import Self

from pyro import params
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message
from pyro.poutine.util import is_validation_enabled


Expand All @@ -20,30 +25,30 @@ class SubstituteMessenger(Messenger):
... a = pyro.param("a", torch.tensor(0.5))
... x = pyro.sample("x", dist.Bernoulli(probs=a))
... return x
>>> substituted_model = pyro.poutine.substitute(model, data={"a": 0.3})
>>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})
In this example, site `a` will now have value `0.3`.
In this example, site `a` will now have value `torch.tensor(0.3)`.
:param data: dictionary of values keyed by site names.
:returns: ``fn`` decorated with a :class:`~pyro.poutine.substitute_messenger.SubstituteMessenger`
"""

def __init__(self, data):
def __init__(self, data: Dict[str, torch.Tensor]) -> None:
"""
:param data: values for the parameters.
Constructor
"""
super().__init__()
self.data = data
self._data_cache = {}
self._data_cache: Dict[str, Message] = {}

def __enter__(self):
def __enter__(self) -> Self:
self._data_cache = {}
if is_validation_enabled() and isinstance(self.data, dict):
self._param_hits = set()
self._param_misses = set()
self._param_hits: Set[str] = set()
self._param_misses: Set[str] = set()
return super().__enter__()

def __exit__(self, *args, **kwargs):
def __exit__(self, *args, **kwargs) -> None:
self._data_cache = {}
if is_validation_enabled() and isinstance(self.data, dict):
extra = set(self.data) - self._param_hits
Expand All @@ -56,15 +61,16 @@ def __exit__(self, *args, **kwargs):
)
return super().__exit__(*args, **kwargs)

def _pyro_sample(self, msg):
def _pyro_sample(self, msg: Message) -> None:
return None

def _pyro_param(self, msg):
def _pyro_param(self, msg: Message) -> None:
"""
Overrides the `pyro.param` with substituted values.
If the param name does not match the name the keys in `data`,
that param value is unchanged.
"""
assert msg["name"] is not None
name = msg["name"]
param_name = params.user_param_name(name)

Expand Down

0 comments on commit 5d920aa

Please sign in to comment.