Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy committed Mar 9, 2024
1 parent 03a48db commit f83c6c0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 37 deletions.
58 changes: 24 additions & 34 deletions pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import pyro
import pyro.params.param_store
from pyro.ops.provenance import detach_provenance
from pyro.params.param_store import StateDict
from pyro.poutine.runtime import _PYRO_PARAM_STORE

_MODULE_LOCAL_PARAMS: bool = False
Expand All @@ -51,14 +50,15 @@

if TYPE_CHECKING:
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.params.param_store import StateDict


@pyro.settings.register("module_local_params", __name__, "_MODULE_LOCAL_PARAMS")
def _validate_module_local_params(value: bool) -> None:
assert isinstance(value, bool)


def _is_module_local_param_enabled() -> Optional[bool]:
def _is_module_local_param_enabled() -> bool:
return pyro.settings.get("module_local_params") # type: ignore[no-any-return]


Expand Down Expand Up @@ -108,12 +108,7 @@ def forward(self):
dims and no subsampling will be performed.
"""

init_value: Optional[
Union[
torch.Tensor,
Callable[[], torch.Tensor],
]
] = None
init_value: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None
constraint: constraints.Constraint = constraints.real
event_dim: Optional[int] = None

Expand All @@ -125,20 +120,19 @@ def __get__(
if obj is None:
return self

assert not isinstance(self.init_value, torch.Tensor)
assert self.init_value is not None
name = self.init_value.__name__
name = self.init_value.__name__ # type: ignore[union-attr]
if name not in obj.__dict__["_pyro_params"]:
init_value, constraint, event_dim = self
assert not isinstance(init_value, torch.Tensor)
assert init_value is not None
init_value = functools.partial(init_value, obj) # bind method's self arg
# bind method's self arg
init_value = functools.partial(init_value, obj) # type: ignore[arg-type]
setattr(obj, name, PyroParam(init_value, constraint, event_dim))
value: PyroParam = obj.__getattr__(name)
return value

# Support decoration with optional kwargs, e.g. @PyroParam(event_dim=0).
def __call__(self, init_value: torch.Tensor) -> "PyroParam":
def __call__(
self, init_value: Union[torch.Tensor, Callable[[], torch.Tensor]]
) -> "PyroParam":
assert self.init_value is None
return PyroParam(init_value, self.constraint, self.event_dim)

Expand Down Expand Up @@ -179,26 +173,21 @@ def forward(self):
"TorchDistributionMixin", Callable[["PyroModule"], "TorchDistributionMixin"]
]

def __init__(
self,
prior: Union[
"TorchDistributionMixin", Callable[["PyroModule"], "TorchDistributionMixin"]
],
) -> None:
def __post_init__(self) -> None:
super().__init__()
if not hasattr(prior, "sample"): # if not a distribution
if not hasattr(self.prior, "sample"): # if not a distribution
assert 1 == sum(
1
for p in inspect.signature(prior).parameters.values()
for p in inspect.signature(self.prior).parameters.values()
if p.default is inspect.Parameter.empty
), "prior should take the single argument 'self'"
self.name = getattr(prior, "__name__", None)
self.name: Optional[str] = getattr(self.prior, "__name__", None)
if self.name is not None:
# Ensure decorated function is accessible for pickling.
prior.__name__ = "_pyro_prior_" + prior.__name__
qualname = prior.__qualname__.rsplit(".", 1)
qualname[-1] = prior.__name__
prior.__qualname__ = ".".join(qualname)
self.prior.__name__ = "_pyro_prior_" + self.prior.__name__
qualname = self.prior.__qualname__.rsplit(".", 1)
qualname[-1] = self.prior.__name__
self.prior.__qualname__ = ".".join(qualname)

# Support use as a decorator.
def __get__(
Expand All @@ -217,13 +206,12 @@ def __get__(
setattr(obj_type, self.prior.__name__, self.prior) # for pickling

obj.__dict__["_pyro_samples"].setdefault(self.name, self.prior)
if TYPE_CHECKING:
assert isinstance(self.name, str)
assert self.name is not None
value: PyroSample = obj.__getattr__(self.name)
return value


def _make_name(prefix: Optional[str], name: str) -> str:
def _make_name(prefix: str, name: str) -> str:
return "{}.{}".format(prefix, name) if prefix else name


Expand All @@ -248,7 +236,7 @@ def __init__(self) -> None:
self.cache: Dict[str, torch.Tensor] = {}
self.used = False
if _is_module_local_param_enabled():
self.param_state: StateDict = {"params": {}, "constraints": {}}
self.param_state: "StateDict" = {"params": {}, "constraints": {}}

def __enter__(self) -> None:
if not self.active and _is_module_local_param_enabled():
Expand Down Expand Up @@ -870,7 +858,8 @@ def to_pyro_module_(m: torch.nn.Module, recurse: bool = True) -> None:
if isinstance(m, PyroModule):
if recurse:
for name, module in list(m._modules.items()):
assert module is not None
if TYPE_CHECKING:
assert module is not None
to_pyro_module_(module)
setattr(m, name, module)
return
Expand All @@ -888,7 +877,8 @@ def to_pyro_module_(m: torch.nn.Module, recurse: bool = True) -> None:
setattr(m, name, param)
for name, module in list(m._modules.items()):
if recurse:
assert module is not None
if TYPE_CHECKING:
assert module is not None
to_pyro_module_(module)
setattr(m, name, module)

Expand Down
4 changes: 1 addition & 3 deletions pyro/poutine/reparam_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ def _pyro_sample(self, msg: "Message") -> None:
# ReplayMessenger we would need to ensure those messengers can
# similarly be safely applied twice, with the second application
# avoiding overwriting the original application.
_get_init_messengers_iter = _get_init_messengers()
assert _get_init_messengers_iter is not None
for m in _get_init_messengers_iter:
for m in _get_init_messengers():
m._process_message(msg)

# Pass args_kwargs to the reparam via a side channel.
Expand Down

0 comments on commit f83c6c0

Please sign in to comment.