Skip to content

Commit

Permalink
Some more acquisition functions for quad (#758)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmahsereci committed Feb 10, 2023
1 parent f657a87 commit 128ce48
Show file tree
Hide file tree
Showing 12 changed files with 402 additions and 127 deletions.
29 changes: 17 additions & 12 deletions src/probnum/quad/_bayesquad.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,16 @@ def bayesquad(
policy
Type of acquisition strategy to use. Defaults to 'bmc'. Options are
============================================ ===========
Bayesian Monte Carlo [2]_ ``bmc``
Van Der Corput points ``vdc``
Uncertainty Sampling with random candidates ``us_rand``
Uncertainty Sampling with optimizer ``us``
============================================ ===========
==================================================== ===========
Bayesian Monte Carlo [2]_ ``bmc``
Van Der Corput points ``vdc``
Uncertainty Sampling with random candidates ``us_rand``
Uncertainty Sampling with optimizer ``us``
Mutual information with random candidates ``mi_rand``
Mutual information with optimizer ``mi``
Integral variance reduction with random candidates ``ivr_rand``
Integral variance reduction with optimizer ``ivr``
==================================================== ===========
initial_design
The type of initial design to use. If ``None`` is given, no initial design is
Expand Down Expand Up @@ -115,13 +119,14 @@ def bayesquad(
n_initial_design_nodes : Optional[IntLike]
The number of nodes created by the initial design. Defaults to
``input_dim * 5`` if an initial design is given.
us_rand_n_candidates : Optional[IntLike]
The number of candidate nodes used by the policy 'us_rand'. Defaults
to 1e2.
us_n_restarts : Optional[IntLike]
n_candidates : Optional[IntLike]
The number of candidate nodes used by the policies that maximize an
acquisition function by drawing random candidates. Defaults to 1e2.
Applicable to policies 'us_rand', 'mi_rand' and 'ivr_rand'.
n_restarts : Optional[IntLike]
The number of restarts that the acquisition optimizer performs in
order to find the maximizer when policy 'us' is used. Defaults
to 10.
order to find the maximizer. Defaults to 10. Applicable to policies
'us', 'mi' and 'ivr'.
Returns
-------
Expand Down
72 changes: 39 additions & 33 deletions src/probnum/quad/solvers/_bayesian_quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from probnum.quad.integration_measures import IntegrationMeasure, LebesgueMeasure
from probnum.quad.kernel_embeddings import KernelEmbedding
from probnum.quad.solvers._bq_state import BQIterInfo, BQState
from probnum.quad.solvers.acquisition_functions import WeightedPredictiveVariance
from probnum.quad.solvers.acquisition_functions import (
IntegralVarianceReduction,
MutualInformation,
WeightedPredictiveVariance,
)
from probnum.quad.solvers.belief_updates import BQBeliefUpdate, BQStandardBeliefUpdate
from probnum.quad.solvers.initial_designs import InitialDesign, LatinDesign, MCDesign
from probnum.quad.solvers.policies import (
Expand Down Expand Up @@ -96,7 +100,7 @@ def __init__(
self.stopping_criterion = stopping_criterion
self.initial_design = initial_design

# pylint: disable=too-many-statements
# pylint: disable=too-many-statements, too-many-locals
@classmethod
def from_problem(
cls,
Expand Down Expand Up @@ -147,13 +151,14 @@ def from_problem(
n_initial_design_nodes : Optional[IntLike]
The number of nodes created by the initial design. Defaults to
``input_dim * 5`` if an initial design is given.
us_rand_n_candidates : Optional[IntLike]
The number of candidate nodes used by the policy 'us_rand'. Defaults
to 1e2.
us_n_restarts : Optional[IntLike]
n_candidates : Optional[IntLike]
The number of candidate nodes used by the policies that maximize an
acquisition function by drawing random candidates. Defaults to 1e2.
Applicable to policies 'us_rand', 'mi_rand' and 'ivr_rand'.
n_restarts : Optional[IntLike]
The number of restarts that the acquisition optimizer performs in
order to find the maximizer when policy 'us' is used. Defaults
to 10.
order to find the maximizer. Defaults to 10. Applicable to policies
'us', 'mi' and 'ivr'.
Returns
-------
Expand Down Expand Up @@ -190,8 +195,8 @@ def from_problem(
n_initial_design_nodes = options.get(
"n_initial_design_nodes", int(5 * input_dim)
)
us_rand_n_candidates = options.get("us_rand_n_candidates", int(1e2))
us_n_restarts = options.get("us_n_restarts", int(10))
n_candidates = options.get("n_candidates", int(1e2))
n_restarts = options.get("n_restarts", int(10))

# Set up integration measure
if domain is None and measure is None:
Expand All @@ -206,6 +211,11 @@ def from_problem(
kernel = ExpQuad(input_shape=(input_dim,))

# Select policy
acquisition_dict = dict(
mi=MutualInformation,
ivr=IntegralVarianceReduction,
us=WeightedPredictiveVariance,
)
if policy is None:
# If policy is None, this implies that the integration problem is defined
# through a fixed set of nodes and function evaluations which will not
Expand All @@ -215,17 +225,20 @@ def from_problem(
policy = RandomPolicy(batch_size, measure.sample)
elif policy == "vdc":
policy = VanDerCorputPolicy(batch_size, measure)
elif policy == "us_rand":
# all random max acquisition policies (all must contain suffix '_rand')
elif policy in ["us_rand", "mi_rand", "ivr_rand"]:
assert policy[-5:] == "_rand"
policy = RandomMaxAcquisitionPolicy(
batch_size=1,
acquisition_func=WeightedPredictiveVariance(),
n_candidates=us_rand_n_candidates,
acquisition_func=acquisition_dict[policy[:-5]],
n_candidates=n_candidates,
)
elif policy == "us":
# all max acquisition policies with optimizer
elif policy in ["us", "mi", "ivr"]:
policy = MaxAcquisitionPolicy(
batch_size=1,
acquisition_func=WeightedPredictiveVariance(),
n_restarts=us_n_restarts,
acquisition_func=acquisition_dict[policy],
n_restarts=n_restarts,
)
else:
raise NotImplementedError(f"The given policy ({policy}) is unknown.")
Expand All @@ -242,30 +255,23 @@ def _stopcrit_or(sc1, sc2):
return sc2
return sc1 | sc2

_stopping_criterion = None

_stop_crit = None
if max_evals is not None:
_stopping_criterion = _stopcrit_or(
_stopping_criterion, MaxNevals(max_evals)
)
_stop_crit = _stopcrit_or(_stop_crit, MaxNevals(max_evals))
if var_tol is not None:
_stopping_criterion = _stopcrit_or(
_stopping_criterion, IntegralVarianceTolerance(var_tol)
)
_stop_crit = _stopcrit_or(_stop_crit, IntegralVarianceTolerance(var_tol))
if rel_tol is not None:
_stopping_criterion = _stopcrit_or(
_stopping_criterion, RelativeMeanChange(rel_tol)
)
_stop_crit = _stopcrit_or(_stop_crit, RelativeMeanChange(rel_tol))

# If no stopping criteria are given, use some default values.
if _stopping_criterion is None:
_stopping_criterion = IntegralVarianceTolerance(var_tol=1e-6) | MaxNevals(
max_nevals=input_dim * 25 # 25 is an arbitrary value
)
if _stop_crit is None:
_stop_crit = IntegralVarianceTolerance(var_tol=1e-6) | MaxNevals(
max_nevals=input_dim * 25
) # 25 is an arbitrary value

# If no policy is given, then the iteration must terminate immediately.
if policy is None:
_stopping_criterion = ImmediateStop()
_stop_crit = ImmediateStop()

# Select initial design
if initial_design is None:
Expand All @@ -284,7 +290,7 @@ def _stopcrit_or(sc1, sc2):
measure=measure,
policy=policy,
belief_update=belief_update,
stopping_criterion=_stopping_criterion,
stopping_criterion=_stop_crit,
initial_design=initial_design,
)

Expand Down
7 changes: 7 additions & 0 deletions src/probnum/quad/solvers/acquisition_functions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
"""Acquisition functions for Bayesian quadrature."""

from ._acquisition_function import AcquisitionFunction
from ._integral_variance_reduction import IntegralVarianceReduction
from ._mutual_information import MutualInformation
from ._predictive_variance import WeightedPredictiveVariance

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"AcquisitionFunction",
"IntegralVarianceReduction",
"MutualInformation",
"WeightedPredictiveVariance",
]

# Set correct module paths. Corrects links and module paths in documentation.
AcquisitionFunction.__module__ = "probnum.quad.solvers.acquisition_functions"
IntegralVarianceReduction.__module__ = "probnum.quad.solvers.acquisition_functions"
MutualInformation.__module__ = "probnum.quad.solvers.acquisition_functions"
WeightedPredictiveVariance.__module__ = "probnum.quad.solvers.acquisition_functions"
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Integral variance reduction acquisition function for Bayesian quadrature."""

from __future__ import annotations

from typing import Optional, Tuple

import numpy as np

from probnum.quad.solvers._bq_state import BQState
from probnum.quad.solvers.belief_updates import BQStandardBeliefUpdate

from ._acquisition_function import AcquisitionFunction

# pylint: disable=too-few-public-methods


class IntegralVarianceReduction(AcquisitionFunction):
r"""The normalized reduction of the integral variance.
The acquisition function is
.. math::
a(x) &= \mathfrak{v}^{-1}(\mathfrak{v} - \mathfrak{v}(x))\\
&= \frac{(\int \bar{k}(x', x)p(x')\mathrm{d}x')^2}{\mathfrak{v} v(x)}\\
&= \rho^2(x)
where :math:`\mathfrak{v}` is the current integral variance, :math:`\mathfrak{v}(x)`
is the integral variance including a hypothetical observation at
:math:`x`, :math:`v(x)` is the predictive variance for :math:`f(x)` and
:math:`\bar{k}(x', x)` is the posterior kernel function.
The value :math:`a(x)` is equal to the squared correlation :math:`\rho^2(x)` between
the hypothetical observation at :math:`x` and the integral value. [1]_
The normalization constant :math:`\mathfrak{v}^{-1}` ensures that
:math:`a(x)\in[0, 1]`.
References
----------
.. [1] Gessner et al. Active Multi-Information Source Bayesian Quadrature,
*UAI*, 2019
"""

@property
def has_gradients(self) -> bool:
# Todo (#581): this needs to return True, once gradients are available
return False

def __call__(
self,
x: np.ndarray,
bq_state: BQState,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:

_, y_predictive_var = BQStandardBeliefUpdate.predict_integrand(x, bq_state)

# if observation noise is added to BQ, it needs to be retrieved here.
observation_noise_var = 0.0 # dummy placeholder
y_predictive_var += observation_noise_var

predictive_embedding = bq_state.kernel_embedding.kernel_mean(x)

# posterior if observations are available
if bq_state.fun_evals.shape[0] > 0:

weights = BQStandardBeliefUpdate.gram_cho_solve(
bq_state.gram_cho_factor, bq_state.kernel.matrix(bq_state.nodes, x)
)
predictive_embedding -= np.dot(bq_state.kernel_means, weights)

values = (bq_state.scale_sq * predictive_embedding) ** 2 / (
bq_state.integral_belief.cov * y_predictive_var
)
return values, None
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Mutual information acquisition function for Bayesian quadrature."""

from __future__ import annotations

from typing import Optional, Tuple

import numpy as np

from probnum.quad.solvers._bq_state import BQState

from ._acquisition_function import AcquisitionFunction
from ._integral_variance_reduction import IntegralVarianceReduction

# pylint: disable=too-few-public-methods


class MutualInformation(AcquisitionFunction):
r"""The mutual information between a hypothetical integrand observation and the
integral value.
The acquisition function is
.. math::
a(x) = -0.5 \log(1-\rho^2(x))
where :math:`\rho^2(x)` is the squared correlation between a hypothetical integrand
observations at :math:`x` and the integral value. [1]_
The mutual information is non-negative and unbounded for a 'perfect' observation
and :math:`\rho^2(x) = 1.`
References
----------
.. [1] Gessner et al. Active Multi-Information Source Bayesian Quadrature,
*UAI*, 2019
"""

@property
def has_gradients(self) -> bool:
# Todo (#581): this needs to return True, once gradients are available
return False

def __call__(
self,
x: np.ndarray,
bq_state: BQState,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
ivr = IntegralVarianceReduction()
rho2, _ = ivr(x, bq_state)
values = -0.5 * np.log(1 - rho2)
return values, None
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Uncertainty sampling for Bayesian Monte Carlo."""
"""Uncertainty sampling for Bayesian quadrature."""

from __future__ import annotations

Expand All @@ -25,6 +25,10 @@ class WeightedPredictiveVariance(AcquisitionFunction):
where :math:`\operatorname{Var}(f(x))` is the predictive variance of the model and
:math:`p(x)` is the density of the integration measure :math:`\mu`.
Notes
-----
The implementation scales :math:`a(x)` with the inverse of the squared kernel
scale for numerical stability.
"""

@property
Expand All @@ -37,12 +41,9 @@ def __call__(
x: np.ndarray,
bq_state: BQState,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
predictive_variance = bq_state.kernel(x, x)
if bq_state.fun_evals.shape != (0,):
kXx = bq_state.kernel.matrix(bq_state.nodes, x)
regression_weights = BQStandardBeliefUpdate.gram_cho_solve(
bq_state.gram_cho_factor, kXx
)
predictive_variance -= np.sum(regression_weights * kXx, axis=0)
values = bq_state.scale_sq * predictive_variance * bq_state.measure(x) ** 2

_, predictive_variance = BQStandardBeliefUpdate.predict_integrand(x, bq_state)
predictive_variance *= 1 / bq_state.scale_sq # for numerical stability

values = predictive_variance * bq_state.measure(x) ** 2
return values, None
24 changes: 24 additions & 0 deletions src/probnum/quad/solvers/belief_updates/_belief_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,27 @@ def gram_cho_solve(
"""
return cho_solve(gram_cho_factor, z)

@staticmethod
@abc.abstractmethod
def predict_integrand(
x: np.ndarray, bq_state: BQState
) -> Tuple[np.ndarray, np.ndarray]:
"""Predictive mean and variances of the integrand at given nodes.
Parameters
----------
x
*shape=(n_nodes, input_dim)* -- The nodes where to predict.
bq_state
The BQ state.
Returns
-------
mean_prediction :
*shape=(n_nodes,)* -- The means of the predictions.
var_predictions :
*shape=(n_nodes,)* -- The variances of the predictions.
"""
raise NotImplementedError

0 comments on commit 128ce48

Please sign in to comment.