Skip to content

Commit

Permalink
Add param_store.py type hints (#3271)
Browse files Browse the repository at this point in the history
* add types

* StateDict

* fixes

* revert optim changes

* fix tensor type

* fix dtype

* change default to cpu

* address comment
  • Loading branch information
ordabayevy committed Oct 4, 2023
1 parent fa73d9c commit c00bcc3
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 40 deletions.
106 changes: 70 additions & 36 deletions pyro/params/param_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,26 @@
import warnings
import weakref
from contextlib import contextmanager
from typing import (
Callable,
Dict,
ItemsView,
Iterator,
KeysView,
Optional,
Tuple,
Union,
)

import torch
from torch.distributions import constraints, transform_to
from torch.serialization import MAP_LOCATION
from typing_extensions import TypedDict


class StateDict(TypedDict):
    """Schema of the snapshot produced by ``ParamStoreDict.get_state()`` and
    consumed by ``set_state()``."""

    # param name -> unconstrained parameter tensor (see ParamStoreDict._params)
    params: Dict[str, torch.Tensor]
    # param name -> constraint object (see ParamStoreDict._constraints)
    constraints: Dict[str, constraints.Constraint]


class ParamStoreDict:
Expand Down Expand Up @@ -39,80 +56,86 @@ class ParamStoreDict:
# -------------------------------------------------------------------------------
# New dict-like interface

def __init__(self):
def __init__(self) -> None:
"""
initialize ParamStore data structures
"""
self._params = {} # dictionary from param name to param
self._param_to_name = {} # dictionary from unconstrained param to param name
self._constraints = {} # dictionary from param name to constraint object
self._params: Dict[
str, torch.Tensor
] = {} # dictionary from param name to param
self._param_to_name: Dict[
torch.Tensor, str
] = {} # dictionary from unconstrained param to param name
self._constraints: Dict[
str, constraints.Constraint
] = {} # dictionary from param name to constraint object

def clear(self):
def clear(self) -> None:
"""
Clear the ParamStore
"""
self._params = {}
self._param_to_name = {}
self._constraints = {}

def items(self):
def items(self) -> Iterator[Tuple[str, torch.Tensor]]:
"""
Iterate over ``(name, constrained_param)`` pairs. Note that `constrained_param` is
in the constrained (i.e. user-facing) space.
"""
for name in self._params:
yield name, self[name]

def keys(self):
def keys(self) -> KeysView[str]:
"""
Iterate over param names.
"""
return self._params.keys()

def values(self):
def values(self) -> Iterator[torch.Tensor]:
"""
Iterate over constrained parameter values.
"""
for name, constrained_param in self.items():
yield constrained_param

def __bool__(self):
def __bool__(self) -> bool:
return bool(self._params)

def __len__(self):
def __len__(self) -> int:
return len(self._params)

def __contains__(self, name):
def __contains__(self, name: str) -> bool:
return name in self._params

def __iter__(self):
def __iter__(self) -> Iterator[str]:
"""
Iterate over param names.
"""
return iter(self.keys())

def __delitem__(self, name):
def __delitem__(self, name) -> None:
"""
Remove a parameter from the param store.
"""
unconstrained_value = self._params.pop(name)
self._param_to_name.pop(unconstrained_value)
self._constraints.pop(name)

def __getitem__(self, name):
def __getitem__(self, name: str) -> torch.Tensor:
"""
Get the *constrained* value of a named parameter.
"""
unconstrained_value = self._params[name]

# compute the constrained value
constraint = self._constraints[name]
constrained_value = transform_to(constraint)(unconstrained_value)
constrained_value.unconstrained = weakref.ref(unconstrained_value)
constrained_value: torch.Tensor = transform_to(constraint)(unconstrained_value)
constrained_value.unconstrained = weakref.ref(unconstrained_value) # type: ignore[attr-defined]

return constrained_value

def __setitem__(self, name, new_constrained_value):
def __setitem__(self, name: str, new_constrained_value: torch.Tensor) -> None:
"""
Set the constrained value of an existing parameter, or the value of a
new *unconstrained* parameter. To declare a new parameter with
Expand All @@ -132,7 +155,12 @@ def __setitem__(self, name, new_constrained_value):
self._params[name] = unconstrained_value
self._param_to_name[unconstrained_value] = name

def setdefault(self, name, init_constrained_value, constraint=constraints.real):
def setdefault(
self,
name: str,
init_constrained_value: Union[torch.Tensor, Callable[[], torch.Tensor]],
constraint: constraints.Constraint = constraints.real,
) -> torch.Tensor:
"""
Retrieve a *constrained* parameter value from the ParamStore if it exists, otherwise
set the initial value. Note that this is a little fancier than
Expand Down Expand Up @@ -170,32 +198,38 @@ def setdefault(self, name, init_constrained_value, constraint=constraints.real):
# -------------------------------------------------------------------------------
# Old non-dict interface

def named_parameters(self):
def named_parameters(self) -> ItemsView[str, torch.Tensor]:
"""
Returns an iterator over ``(name, unconstrained_value)`` tuples for
each parameter in the ParamStore. Note that, in the event the parameter is constrained,
`unconstrained_value` is in the unconstrained space implicitly used by the constraint.
"""
return self._params.items()

def get_all_param_names(self):
def get_all_param_names(self) -> KeysView[str]:
warnings.warn(
"ParamStore.get_all_param_names() is deprecated; use .keys() instead.",
DeprecationWarning,
)
return self.keys()

def replace_param(self, param_name, new_param, old_param):
def replace_param(
self, param_name: str, new_param: torch.Tensor, old_param: torch.Tensor
) -> None:
warnings.warn(
"ParamStore.replace_param() is deprecated; use .__setitem__() instead.",
DeprecationWarning,
)
assert self._params[param_name] is old_param.unconstrained()
assert self._params[param_name] is old_param.unconstrained() # type: ignore[attr-defined]
self[param_name] = new_param

def get_param(
self, name, init_tensor=None, constraint=constraints.real, event_dim=None
):
self,
name: str,
init_tensor: Optional[torch.Tensor] = None,
constraint: constraints.Constraint = constraints.real,
event_dim: Optional[int] = None,
) -> torch.Tensor:
"""
Get parameter from its name. If it does not yet exist in the
ParamStore, it will be created and stored.
Expand All @@ -216,7 +250,7 @@ def get_param(
else:
return self.setdefault(name, init_tensor, constraint)

def match(self, name):
def match(self, name: str) -> Dict[str, torch.Tensor]:
"""
Get all parameters that match regex. The parameter must exist.
Expand All @@ -227,7 +261,7 @@ def match(self, name):
pattern = re.compile(name)
return {name: self[name] for name in self if pattern.match(name)}

def param_name(self, p):
def param_name(self, p: torch.Tensor) -> Optional[str]:
"""
Get parameter name from parameter
Expand All @@ -239,18 +273,18 @@ def param_name(self, p):
# -------------------------------------------------------------------------------
# Persistence interface

def get_state(self) -> dict:
def get_state(self) -> StateDict:
"""
Get the ParamStore state.
"""
params = self._params.copy()
# Remove weakrefs in preparation for pickling.
for param in params.values():
param.__dict__.pop("unconstrained", None)
state = {"params": params, "constraints": self._constraints.copy()}
state: StateDict = {"params": params, "constraints": self._constraints.copy()}
return state

def set_state(self, state: dict):
def set_state(self, state: StateDict) -> None:
"""
Set the ParamStore state using state from a previous :meth:`get_state` call
"""
Expand All @@ -269,7 +303,7 @@ def set_state(self, state: dict):
constraint = constraints.real
self._constraints[param_name] = constraint

def save(self, filename):
def save(self, filename: str) -> None:
"""
Save parameters to file
Expand All @@ -279,7 +313,7 @@ def save(self, filename):
with open(filename, "wb") as output_file:
torch.save(self.get_state(), output_file)

def load(self, filename, map_location=None):
def load(self, filename: str, map_location: MAP_LOCATION = None) -> None:
"""
Loads parameters from file
Expand All @@ -301,7 +335,7 @@ def load(self, filename, map_location=None):
self.set_state(state)

@contextmanager
def scope(self, state=None) -> dict:
def scope(self, state: Optional[StateDict] = None) -> Iterator[StateDict]:
"""
Context manager for using multiple parameter stores within the same process.
Expand Down Expand Up @@ -343,19 +377,19 @@ def scope(self, state=None) -> dict:
# Divider separating the module namespace from the user-level param name.
_MODULE_NAMESPACE_DIVIDER = "$$$"


def param_with_module_name(pyro_name: str, param_name: str) -> str:
    """Join a module name and a param name into a single namespaced name."""
    return f"{pyro_name}{_MODULE_NAMESPACE_DIVIDER}{param_name}"


def module_from_param_with_module_name(param_name: str) -> str:
    """Return the module portion of a namespaced param name."""
    head, _, _ = param_name.partition(_MODULE_NAMESPACE_DIVIDER)
    return head


def user_param_name(param_name: str) -> str:
    """Strip the module namespace from a param name, if one is present."""
    parts = param_name.split(_MODULE_NAMESPACE_DIVIDER)
    return parts[1] if len(parts) > 1 else param_name


def normalize_param_name(name: str) -> str:
    """Replace namespace dividers with dots, for display purposes."""
    return name.replace(_MODULE_NAMESPACE_DIVIDER, ".")
4 changes: 0 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,6 @@ warn_unused_ignores = True
[mypy-pyro.optim.*]
warn_unused_ignores = True

[mypy-pyro.params.*]
ignore_errors = True
warn_unused_ignores = True

[mypy-pyro.poutine.*]
ignore_errors = True

Expand Down

0 comments on commit c00bcc3

Please sign in to comment.