Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up handling of global settings #3152

Merged
merged 6 commits into from
Nov 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Pyro Documentation
optimization
poutine
ops
settings
testing

.. toctree::
Expand Down
6 changes: 6 additions & 0 deletions docs/source/settings.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Settings
--------

.. automodule:: pyro.settings
:members:
:member-order: bysource
3 changes: 3 additions & 0 deletions pyro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
)
from pyro.util import set_rng_seed

from . import settings

# After changing this, run scripts/update_version.py
version_prefix = "1.8.2"

Expand Down Expand Up @@ -58,6 +60,7 @@
"render_model",
"sample",
"set_rng_seed",
"settings",
"subsample",
"validation_enabled",
]
17 changes: 17 additions & 0 deletions pyro/distributions/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pyro.distributions.util import broadcast_shape, sum_rightmost
from pyro.ops.special import log_binomial

from .. import settings
from . import constraints


Expand Down Expand Up @@ -98,6 +99,22 @@ def log_prob(self, value):
)


@settings.register(
"binomial_approx_sample_thresh", __name__, "Binomial.approx_sample_thresh"
)
def _validate_thresh(thresh):
assert isinstance(thresh, float)
assert 0 < thresh


@settings.register(
"binomial_approx_log_prob_tol", __name__, "Binomial.approx_log_prob_tol"
)
def _validate_tol(tol):
assert isinstance(tol, float)
assert 0 <= tol


# This overloads .log_prob() and .enumerate_support() to speed up evaluating
# log_prob on the support of this variable: we can completely avoid tensor ops
# and merely reshape the self.logits tensor. This is especially important for
Expand Down
9 changes: 9 additions & 0 deletions pyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,18 @@

from pyro.util import ignore_jit_warnings

from .. import settings

_VALIDATION_ENABLED = __debug__
torch_dist.Distribution.set_default_validate_args(__debug__)

settings.register("validate_distributions_pyro", __name__, "_VALIDATION_ENABLED")
settings.register(
"validate_distributions_torch",
"torch.distributions.distribution",
"Distribution._validate_args",
)

log_sum_exp = logsumexp # DEPRECATED


Expand Down
4 changes: 4 additions & 0 deletions pyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
from pyro.ops.rings import MarginalRing
from pyro.poutine.util import site_is_subsample

from .. import settings

_VALIDATION_ENABLED = __debug__
settings.register("validate_infer", __name__, "_VALIDATION_ENABLED")

LAST_CACHE_SIZE = [Counter()] # for profiling


Expand Down
8 changes: 8 additions & 0 deletions pyro/ops/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@
import torch
from torch.fft import irfft, rfft

from .. import settings

_ROOT_TWO_INVERSE = 1.0 / math.sqrt(2.0)
CHOLESKY_RELATIVE_JITTER = 4.0 # in units of finfo.eps


@settings.register("cholesky_relative_jitter", __name__, "CHOLESKY_RELATIVE_JITTER")
def _validate_jitter(value):
assert isinstance(value, (float, int))
assert 0 <= value


def as_complex(x):
"""
Similar to :func:`torch.view_as_complex` but copies data in case strides
Expand Down
3 changes: 3 additions & 0 deletions pyro/poutine/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from .. import settings

_VALIDATION_ENABLED = __debug__
settings.register("validate_poutine", __name__, "_VALIDATION_ENABLED")


def enable_validation(is_validate):
Expand Down
163 changes: 163 additions & 0 deletions pyro/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Example usage::

# Simple getting and setting.
print(pyro.settings.get()) # print all settings
print(pyro.settings.get("cholesky_relative_jitter")) # print one
pyro.settings.set(cholesky_relative_jitter=0.5) # set one
pyro.settings.set(**my_settings) # set many

# Use as a contextmanager.
with pyro.settings.context(cholesky_relative_jitter=0.5):
my_function()

# Use as a decorator.
fn = pyro.settings.context(cholesky_relative_jitter=0.5)(my_function)
fn()

# Register a new setting.
pyro.settings.register(
"binomial_approx_sample_thresh", # alias
"pyro.distributions.torch", # module
"Binomial.approx_sample_thresh", # deep name
)

# Register a new setting on a user-provided validator.
@pyro.settings.register(
"binomial_approx_sample_thresh", # alias
"pyro.distributions.torch", # module
"Binomial.approx_sample_thresh", # deep name
)
def validate_thresh(thresh): # called each time setting is set
assert isinstance(thresh, float)
assert thresh > 0

Default Settings
----------------

{defaults}

Settings Interface
------------------
"""

# This library must have no dependencies on other pyro modules.
import functools
from contextlib import contextmanager
from importlib import import_module
from typing import Any, Callable, Dict, Iterator, Optional, Tuple

# Docs are updated by register().
_doc_template = __doc__

# Global registry mapping alias:str to (modulename, deepname, validator)
# triples where deepname may have dots to indicate e.g. class variables.
_REGISTRY: Dict[str, Tuple[str, str, Optional[Callable]]] = {}


def get(alias: Optional[str] = None) -> Any:
"""
Gets one or all global settings.

:param str alias: The name of a registered setting.
:returns: The currently set value.
"""
if alias is None:
# Return dict of all settings.
return {alias: get(alias) for alias in sorted(_REGISTRY)}
# Get a single setting.
module, deepname, validator = _REGISTRY[alias]
value = import_module(module)
for name in deepname.split("."):
value = getattr(value, name)
return value


def set(**kwargs) -> None:
r"""
Sets one or more settings.

:param \*\*kwargs: alias=value pairs.
"""
for alias, value in kwargs.items():
module, deepname, validator = _REGISTRY[alias]
if validator is not None:
validator(value)
destin = import_module(module)
names = deepname.split(".")
for name in names[:-1]:
destin = getattr(destin, name)
setattr(destin, names[-1], value)


@contextmanager
def context(**kwargs) -> Iterator[None]:
r"""
Context manager to temporarily override one or more settings. This also
works as a decorator.

:param \*\*kwargs: alias=value pairs.
"""
old = {alias: get(alias) for alias in kwargs}
try:
set(**kwargs)
yield
finally:
set(**old)


def register(
alias: str,
modulename: str,
deepname: str,
validator: Optional[Callable] = None,
) -> Callable:
"""
Register a global settings.

This should be declared in the module where the setting is defined.

This can be used either as a declaration::

settings.register("my_setting", __name__, "MY_SETTING")

or as a decorator on a user-defined validator function::

@settings.register("my_setting", __name__, "MY_SETTING")
def _validate_my_setting(value):
assert isinstance(value, float)
assert 0 < value

:param str alias: A valid python identifier serving as a settings alias.
Lower snake case preferred, e.g. ``my_setting``.
:param str modulename: The module name where the setting is declared,
typically ``__name__``.
:param str deepname: A ``.``-separated string of names. E.g. for a module
constant, use ``MY_CONSTANT``. For a class attributue, use
``MyClass.my_attribute``.
:param callable validator: Optional validator that inputs a value,
possibly raises validation errors, and returns None.
"""
global __doc__
assert isinstance(alias, str)
assert alias.isidentifier()
assert isinstance(modulename, str)
assert isinstance(deepname, str)
_REGISTRY[alias] = modulename, deepname, validator

# Add default value to module docstring.
__doc__ = _doc_template.format(
defaults="\n".join(f"- {a} = {get(a)}" for a in sorted(_REGISTRY))
)

# Support use as a decorator on an optional user-provided validator.
if validator is None:
# Return a decorator, but its fine if user discards this.
return functools.partial(register, alias, modulename, deepname)
else:
# Test current value passes validation.
validator(get(alias))
return validator
50 changes: 50 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pytest

from pyro import settings

_TEST_SETTING: float = 0.1

pytestmark = pytest.mark.stage("unit")


def test_settings():
v0 = settings.get()
assert isinstance(v0, dict)
assert all(isinstance(alias, str) for alias in v0)
assert settings.get("validate_distributions_pyro") is True
assert settings.get("validate_distributions_torch") is True
assert settings.get("validate_poutine") is True
assert settings.get("validate_infer") is True


def test_register():
with pytest.raises(KeyError):
settings.get("test_setting")

@settings.register("test_setting", "tests.test_settings", "_TEST_SETTING")
def _validate(value):
assert isinstance(value, float)
assert 0 < value

# Test simple get and set.
assert settings.get("test_setting") == 0.1
settings.set(test_setting=0.2)
assert settings.get("test_setting") == 0.2
with pytest.raises(AssertionError):
settings.set(test_setting=-0.1)

# Test context manager.
with settings.context(test_setting=0.3):
assert settings.get("test_setting") == 0.3
assert settings.get("test_setting") == 0.2

# Test decorator.
@settings.context(test_setting=0.4)
def fn():
assert settings.get("test_setting") == 0.4

fn()
assert settings.get("test_setting") == 0.2