Type annotate poutines (#3306)
* wip

* fixes

* revert markov

* pass doctest

* address comments

* Make CondIndepStackFrame.full_size immutable

* Preserve comment

---------

Co-authored-by: Fritz Obermeyer <fritz.obermeyer@gmail.com>
ordabayevy and fritzo committed Dec 18, 2023
1 parent 9c4f932 commit 834ff63
Showing 12 changed files with 95 additions and 59 deletions.
2 changes: 1 addition & 1 deletion pyro/infer/autoguide/effect.py
@@ -68,7 +68,7 @@ def _adjust_plates(self, value: torch.Tensor, event_dim: int) -> torch.Tensor:
Adjusts plates for generating initial values of parameters.
"""
for f in get_plates():
- full_size = getattr(f, "full_size", f.size)
+ full_size = f.full_size or f.size
dim = f.dim - event_dim
if f in self._outer_plates or f.name in self.amortized_plates:
if -value.dim() <= dim:
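A note on the recurring change in the autoguide files: `getattr(f, "full_size", f.size)` was needed while `full_size` was an ad-hoc attribute; once `CondIndepStackFrame` declares `full_size` as an optional field (see pyro/poutine/indep_messenger.py below), plain attribute access suffices. A minimal sketch of the idiom, using a hypothetical stand-in class rather than code from this diff:

from typing import NamedTuple, Optional

class Frame(NamedTuple):
    # Hypothetical stand-in for CondIndepStackFrame, just to show the idiom.
    size: int
    full_size: Optional[int] = None  # None when the plate is not subsampled

def effective_size(f: Frame) -> int:
    # With a declared field, the getattr fallback is unnecessary;
    # `or` falls back to size when full_size is None.
    return f.full_size or f.size

assert effective_size(Frame(size=10)) == 10
assert effective_size(Frame(size=5, full_size=100)) == 100

The `or` also treats a `full_size` of 0 like None; that is harmless here because zero-sized plates are rejected in `IndepMessenger.__init__`.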
6 changes: 3 additions & 3 deletions pyro/infer/autoguide/guides.py
@@ -138,7 +138,7 @@ def _create_plates(self, *args, **kwargs):
self.plates = {p.name: p for p in plates}
for name, frame in sorted(self._prototype_frames.items()):
if name not in self.plates:
- full_size = getattr(frame, "full_size", frame.size)
+ full_size = frame.full_size or frame.size
self.plates[name] = pyro.plate(
name, full_size, dim=frame.dim, subsample_size=frame.size
)
@@ -363,7 +363,7 @@ def _setup_prototype(self, *args, **kwargs):

# If subsampling, repeat init_value to full size.
for frame in site["cond_indep_stack"]:
- full_size = getattr(frame, "full_size", frame.size)
+ full_size = frame.full_size or frame.size
if full_size != frame.size:
dim = frame.dim - event_dim
value = periodic_repeat(value, full_size, dim).contiguous()
@@ -475,7 +475,7 @@ def _setup_prototype(self, *args, **kwargs):

# If subsampling, repeat init_value to full size.
for frame in site["cond_indep_stack"]:
- full_size = getattr(frame, "full_size", frame.size)
+ full_size = frame.full_size or frame.size
if full_size != frame.size:
dim = frame.dim - event_dim
init_loc = periodic_repeat(init_loc, full_size, dim).contiguous()
2 changes: 2 additions & 0 deletions pyro/poutine/enum_messenger.py
@@ -23,6 +23,7 @@ def _tmc_mixture_sample(msg: Message) -> torch.Tensor:
batch_shape = [1] * len(dist.batch_shape)
for f in msg["cond_indep_stack"]:
if f.vectorized:
+ assert f.dim is not None
batch_shape[f.dim] = f.size if f.size > 0 else dist.batch_shape[f.dim]
batch_shape_tuple = tuple(batch_shape)

@@ -72,6 +73,7 @@ def _tmc_diagonal_sample(msg: Message) -> torch.Tensor:
batch_shape = [1] * len(dist.batch_shape)
for f in msg["cond_indep_stack"]:
if f.vectorized:
+ assert f.dim is not None
batch_shape[f.dim] = f.size if f.size > 0 else dist.batch_shape[f.dim]
batch_shape_tuple = tuple(batch_shape)

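The added `assert f.dim is not None` lines are standard type narrowing: `dim` is `Optional[int]` on the typed frame, but inside the `if f.vectorized:` branch it is always set, and the assert makes that visible to mypy before indexing. An illustrative sketch of the pattern (not code from this diff):

from typing import List, Optional

def set_batch_dim(batch_shape: List[int], dim: Optional[int], size: int) -> None:
    # Without this assert, mypy rejects batch_shape[dim] because dim may be None.
    assert dim is not None
    batch_shape[dim] = size

shape = [1, 1, 1]
set_batch_dim(shape, -2, 4)
assert shape == [1, 4, 1]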
11 changes: 6 additions & 5 deletions pyro/poutine/escape_messenger.py
@@ -1,16 +1,18 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

- from .messenger import Messenger
- from .runtime import NonlocalExit
+ from typing import Callable
+
+ from pyro.poutine.messenger import Messenger
+ from pyro.poutine.runtime import Message, NonlocalExit


class EscapeMessenger(Messenger):
"""
Messenger that does a nonlocal exit by raising a util.NonlocalExit exception
"""

- def __init__(self, escape_fn):
+ def __init__(self, escape_fn: Callable[[Message], bool]) -> None:
"""
:param escape_fn: function that takes a msg as input and returns True
if the poutine should perform a nonlocal exit at that site.
@@ -20,7 +22,7 @@ def __init__(self, escape_fn):
super().__init__()
self.escape_fn = escape_fn

- def _pyro_sample(self, msg):
+ def _pyro_sample(self, msg: Message) -> None:
"""
:param msg: current message at a trace site
:returns: a sample from the stochastic function at the site.
@@ -38,4 +40,3 @@ def cont(m):
raise NonlocalExit(m)

msg["continuation"] = cont
- return None
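With the new annotation, `escape_fn` is any `Callable[[Message], bool]`. A rough usage sketch (the model and site name are made up; in practice this messenger is usually driven indirectly by Pyro's tracing/queueing utilities rather than called like this):

import pyro
import pyro.distributions as dist
from pyro.poutine.escape_messenger import EscapeMessenger
from pyro.poutine.runtime import NonlocalExit

def stop_at_z(msg) -> bool:
    # Perform a nonlocal exit as soon as sample site "z" is reached.
    return msg["name"] == "z"

def model():
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    return pyro.sample("z", dist.Normal(x, 1.0))

try:
    with EscapeMessenger(stop_at_z):
        model()
except NonlocalExit as exc:
    print("escaped at site:", exc.site["name"])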
60 changes: 37 additions & 23 deletions pyro/poutine/indep_messenger.py
@@ -2,40 +2,46 @@
# SPDX-License-Identifier: Apache-2.0

import numbers
- from collections import namedtuple
+ from typing import Iterator, NamedTuple, Optional, Tuple

import torch
+ from typing_extensions import Self

+ from pyro.poutine.messenger import Messenger
+ from pyro.poutine.runtime import _DIM_ALLOCATOR, Message
from pyro.util import ignore_jit_warnings

- from .messenger import Messenger
- from .runtime import _DIM_ALLOCATOR

+ class CondIndepStackFrame(NamedTuple):
+ name: str
+ dim: Optional[int]
+ size: int
+ counter: int
+ full_size: Optional[int] = None

- class CondIndepStackFrame(
- namedtuple("CondIndepStackFrame", ["name", "dim", "size", "counter"])
- ):
@property
- def vectorized(self):
+ def vectorized(self) -> bool:
return self.dim is not None

- def _key(self):
+ def _key(self) -> Tuple[str, Optional[int], int, int]:
with ignore_jit_warnings(["Converting a tensor to a Python number"]):
size = (
- self.size.item() if isinstance(self.size, torch.Tensor) else self.size
+ self.size.item() if isinstance(self.size, torch.Tensor) else self.size  # type: ignore[attr-defined]
)
return self.name, self.dim, size, self.counter

- def __eq__(self, other):
- return type(self) == type(other) and self._key() == other._key()
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, CondIndepStackFrame):
+ return False
+ return self._key() == other._key()

- def __ne__(self, other):
+ def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

- def __hash__(self):
+ def __hash__(self) -> int:
return hash(self._key())

- def __str__(self):
+ def __str__(self) -> str:
return self.name


@@ -59,7 +65,13 @@ class IndepMessenger(Messenger):
"""

- def __init__(self, name=None, size=None, dim=None, device=None):
+ def __init__(
+ self,
+ name: str,
+ size: int,
+ dim: Optional[int] = None,
+ device: Optional[str] = None,
+ ):
if not torch._C._get_tracing_state() and size == 0:
raise ZeroDivisionError("size cannot be zero")

@@ -68,20 +80,20 @@ def __init__(self, name=None, size=None, dim=None, device=None):
if dim is not None:
self._vectorized = True

- self._indices = None
+ self._indices: Optional[torch.Tensor] = None
self.name = name
self.dim = dim
self.size = size
self.device = device
self.counter = 0

- def next_context(self):
+ def next_context(self) -> None:
"""
Increments the counter.
"""
self.counter += 1

- def __enter__(self):
+ def __enter__(self) -> Self:
if self._vectorized is not False:
self._vectorized = True

@@ -90,12 +102,13 @@ def __enter__(self):

return super().__enter__()

- def __exit__(self, *args):
+ def __exit__(self, *args) -> None:
if self._vectorized is True:
+ assert self.dim is not None
_DIM_ALLOCATOR.free(self.name, self.dim)
return super().__exit__(*args)

- def __iter__(self):
+ def __iter__(self) -> Iterator[int]:
if self._vectorized is True or self.dim is not None:
raise ValueError(
"cannot use plate {} as both vectorized and non-vectorized"
@@ -110,18 +123,19 @@ def __iter__(self):
with self:
yield i if isinstance(i, numbers.Number) else i.item()

- def _reset(self):
+ def _reset(self) -> None:
if self._vectorized:
+ assert self.dim is not None
_DIM_ALLOCATOR.free(self.name, self.dim)
self._vectorized = None
self.counter = 0

@property
- def indices(self):
+ def indices(self) -> torch.Tensor:
if self._indices is None:
self._indices = torch.arange(self.size, dtype=torch.long).to(self.device)
return self._indices

- def _process_message(self, msg):
+ def _process_message(self, msg: Message) -> None:
frame = CondIndepStackFrame(self.name, self.dim, self.size, self.counter)
msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
15 changes: 9 additions & 6 deletions pyro/poutine/infer_config_messenger.py
@@ -1,7 +1,10 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

- from .messenger import Messenger
+ from typing import Callable
+
+ from pyro.poutine.messenger import Messenger
+ from pyro.poutine.runtime import InferDict, Message


class InferConfigMessenger(Messenger):
@@ -15,7 +18,7 @@ class InferConfigMessenger(Messenger):
:returns: stochastic function decorated with :class:`~pyro.poutine.infer_config_messenger.InferConfigMessenger`
"""

- def __init__(self, config_fn):
+ def __init__(self, config_fn: Callable[[Message], InferDict]):
"""
:param config_fn: a callable taking a site and returning an infer dict
@@ -25,7 +28,7 @@ def __init__(self, config_fn):
super().__init__()
self.config_fn = config_fn

- def _pyro_sample(self, msg):
+ def _pyro_sample(self, msg: Message) -> None:
"""
:param msg: current message at a trace site.
@@ -35,10 +38,10 @@ def _pyro_sample(self, msg):
Otherwise, implements default sampling behavior
with no additional effects.
"""
+ assert msg["infer"] is not None
msg["infer"].update(self.config_fn(msg))
- return None

- def _pyro_param(self, msg):
+ def _pyro_param(self, msg: Message) -> None:
"""
:param msg: current message at a trace site.
@@ -48,5 +51,5 @@ def _pyro_param(self, msg):
Otherwise, implements default param behavior
with no additional effects.
"""
+ assert msg["infer"] is not None
msg["infer"].update(self.config_fn(msg))
- return None
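`config_fn` now has the explicit signature `Callable[[Message], InferDict]`, and the added asserts narrow `msg["infer"]` from `Optional[InferDict]` before calling `.update`. A hedged sketch of a conforming `config_fn` (the site name is made up):

from pyro.poutine.infer_config_messenger import InferConfigMessenger
from pyro.poutine.runtime import InferDict, Message

def config_fn(msg: Message) -> InferDict:
    # Request more samples at one particular site; leave the rest untouched.
    if msg["name"] == "local_latent":
        return {"num_samples": 10}
    return {}

# Applied like any other messenger, e.g. `with InferConfigMessenger(config_fn): model()`.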
27 changes: 17 additions & 10 deletions pyro/poutine/lift_messenger.py
@@ -2,13 +2,16 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
+ from typing import Callable, Dict, Set, Union

+ from typing_extensions import Self

from pyro import params
+ from pyro.distributions.distribution import Distribution
+ from pyro.poutine.messenger import Messenger
+ from pyro.poutine.runtime import Message
from pyro.poutine.util import is_validation_enabled

- from .messenger import Messenger


class LiftMessenger(Messenger):
"""
@@ -40,7 +43,10 @@ class LiftMessenger(Messenger):
:returns: ``fn`` decorated with a :class:`~pyro.poutine.lift_messenger.LiftMessenger`
"""

- def __init__(self, prior):
+ def __init__(
+ self,
+ prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]],
+ ) -> None:
"""
:param prior: prior used to lift parameters. Prior can be of type
dict, pyro.distributions, or a python stochastic fn
@@ -49,16 +55,16 @@ def __init__(self, prior):
"""
super().__init__()
self.prior = prior
- self._samples_cache = {}
+ self._samples_cache: Dict[str, Message] = {}

- def __enter__(self):
+ def __enter__(self) -> Self:
self._samples_cache = {}
if is_validation_enabled() and isinstance(self.prior, dict):
- self._param_hits = set()
- self._param_misses = set()
+ self._param_hits: Set[str] = set()
+ self._param_misses: Set[str] = set()
return super().__enter__()

- def __exit__(self, *args, **kwargs):
+ def __exit__(self, *args, **kwargs) -> None:
self._samples_cache = {}
if is_validation_enabled() and isinstance(self.prior, dict):
extra = set(self.prior) - self._param_hits
@@ -71,17 +77,18 @@ def __exit__(self, *args, **kwargs):
)
return super().__exit__(*args, **kwargs)

- def _pyro_sample(self, msg):
+ def _pyro_sample(self, msg: Message) -> None:
return None

- def _pyro_param(self, msg):
+ def _pyro_param(self, msg: Message) -> None:
"""
Overrides the `pyro.param` call with samples sampled from the
distribution specified in the prior. The prior can be a
pyro.distributions object or a dict of distributions keyed
on the param names. If the param name does not match the
name the keys in the prior, that param name is unchanged.
"""
+ assert msg["name"] is not None
name = msg["name"]
param_name = params.user_param_name(name)
if isinstance(self.prior, dict):
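The `prior` argument is annotated as a single distribution/callable or a dict keyed by parameter name, and the `Set[str]` attributes track which dict keys were actually hit. A usage sketch, assuming the usual `pyro.poutine.lift` handler wrapper around this messenger (the model is made up):

import torch
import pyro
import pyro.distributions as dist

def model():
    w = pyro.param("w", torch.zeros(2))
    return pyro.sample("y", dist.Normal(w.sum(), 1.0))

# Replace the param statement with a draw from the given prior.
prior = {"w": dist.Normal(0.0, 1.0).expand([2]).to_event(1)}
lifted_model = pyro.poutine.lift(model, prior=prior)
lifted_model()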
10 changes: 6 additions & 4 deletions pyro/poutine/mask_messenger.py
@@ -1,9 +1,12 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

+ from typing import Union

import torch

- from .messenger import Messenger
+ from pyro.poutine.messenger import Messenger
+ from pyro.poutine.runtime import Message


class MaskMessenger(Messenger):
@@ -17,7 +20,7 @@ class MaskMessenger(Messenger):
:returns: stochastic function decorated with a :class:`~pyro.poutine.scale_messenger.MaskMessenger`
"""

- def __init__(self, mask):
+ def __init__(self, mask: Union[bool, torch.BoolTensor]) -> None:
if isinstance(mask, torch.Tensor):
if mask.dtype != torch.bool:
raise ValueError(
@@ -31,6 +34,5 @@ def __init__(self, mask):
super().__init__()
self.mask = mask

- def _process_message(self, msg):
+ def _process_message(self, msg: Message) -> None:
msg["mask"] = self.mask if msg["mask"] is None else msg["mask"] & self.mask
- return None
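`mask` is now typed as `Union[bool, torch.BoolTensor]`, and `_process_message` combines nested masks with `&`. A small sketch of direct use (made-up model; in user code this is normally reached via the `pyro.poutine.mask` handler):

import torch
import pyro
import pyro.distributions as dist
from pyro.poutine.mask_messenger import MaskMessenger

keep = torch.tensor([True, False, True])

def model(data):
    with pyro.plate("data", 3):
        with MaskMessenger(keep):
            # The middle observation's log-probability is masked out.
            pyro.sample("obs", dist.Normal(0.0, 1.0), obs=data)

model(torch.zeros(3))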
2 changes: 2 additions & 0 deletions pyro/poutine/runtime.py
@@ -50,8 +50,10 @@ class InferDict(TypedDict, total=False):
is_auxiliary: bool
is_observed: bool
num_samples: int
+ obs: Optional[torch.Tensor]
+ prior: TorchDistributionMixin
tmc: Literal["diagonal", "mixture"]
was_observed: bool
_deterministic: bool
_dim_to_symbol: Dict[int, str]
_do_not_trace: bool
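The new `obs` and `prior` keys extend `InferDict` to cover more of what the poutines stash in `msg["infer"]`. Because the TypedDict is declared with `total=False`, any subset of keys type-checks; a brief sketch (values are illustrative only):

import pyro
import pyro.distributions as dist
from pyro.poutine.runtime import InferDict

# total=False: any subset of the declared keys is a valid InferDict.
infer: InferDict = {"num_samples": 7, "tmc": "diagonal"}

def model():
    return pyro.sample("z", dist.Normal(0.0, 1.0), infer=infer)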
