-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Some more acquisition functions for
quad
(#758)
- Loading branch information
1 parent
f657a87
commit 128ce48
Showing
12 changed files
with
402 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,20 @@ | ||
"""Acquisition functions for Bayesian quadrature.""" | ||
|
||
from ._acquisition_function import AcquisitionFunction | ||
from ._integral_variance_reduction import IntegralVarianceReduction | ||
from ._mutual_information import MutualInformation | ||
from ._predictive_variance import WeightedPredictiveVariance | ||
|
||
# Public classes and functions. Order is reflected in documentation. | ||
__all__ = [ | ||
"AcquisitionFunction", | ||
"IntegralVarianceReduction", | ||
"MutualInformation", | ||
"WeightedPredictiveVariance", | ||
] | ||
|
||
# Set correct module paths. Corrects links and module paths in documentation. | ||
AcquisitionFunction.__module__ = "probnum.quad.solvers.acquisition_functions" | ||
IntegralVarianceReduction.__module__ = "probnum.quad.solvers.acquisition_functions" | ||
MutualInformation.__module__ = "probnum.quad.solvers.acquisition_functions" | ||
WeightedPredictiveVariance.__module__ = "probnum.quad.solvers.acquisition_functions" |
75 changes: 75 additions & 0 deletions
75
src/probnum/quad/solvers/acquisition_functions/_integral_variance_reduction.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""Integral variance reduction acquisition function for Bayesian quadrature.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Optional, Tuple | ||
|
||
import numpy as np | ||
|
||
from probnum.quad.solvers._bq_state import BQState | ||
from probnum.quad.solvers.belief_updates import BQStandardBeliefUpdate | ||
|
||
from ._acquisition_function import AcquisitionFunction | ||
|
||
# pylint: disable=too-few-public-methods | ||
|
||
|
||
class IntegralVarianceReduction(AcquisitionFunction): | ||
r"""The normalized reduction of the integral variance. | ||
The acquisition function is | ||
.. math:: | ||
a(x) &= \mathfrak{v}^{-1}(\mathfrak{v} - \mathfrak{v}(x))\\ | ||
&= \frac{(\int \bar{k}(x', x)p(x')\mathrm{d}x')^2}{\mathfrak{v} v(x)}\\ | ||
&= \rho^2(x) | ||
where :math:`\mathfrak{v}` is the current integral variance, :math:`\mathfrak{v}(x)` | ||
is the integral variance including a hypothetical observation at | ||
:math:`x`, :math:`v(x)` is the predictive variance for :math:`f(x)` and | ||
:math:`\bar{k}(x', x)` is the posterior kernel function. | ||
The value :math:`a(x)` is equal to the squared correlation :math:`\rho^2(x)` between | ||
the hypothetical observation at :math:`x` and the integral value. [1]_ | ||
The normalization constant :math:`\mathfrak{v}^{-1}` ensures that | ||
:math:`a(x)\in[0, 1]`. | ||
References | ||
---------- | ||
.. [1] Gessner et al. Active Multi-Information Source Bayesian Quadrature, | ||
*UAI*, 2019 | ||
""" | ||
|
||
@property | ||
def has_gradients(self) -> bool: | ||
# Todo (#581): this needs to return True, once gradients are available | ||
return False | ||
|
||
def __call__( | ||
self, | ||
x: np.ndarray, | ||
bq_state: BQState, | ||
) -> Tuple[np.ndarray, Optional[np.ndarray]]: | ||
|
||
_, y_predictive_var = BQStandardBeliefUpdate.predict_integrand(x, bq_state) | ||
|
||
# if observation noise is added to BQ, it needs to be retrieved here. | ||
observation_noise_var = 0.0 # dummy placeholder | ||
y_predictive_var += observation_noise_var | ||
|
||
predictive_embedding = bq_state.kernel_embedding.kernel_mean(x) | ||
|
||
# posterior if observations are available | ||
if bq_state.fun_evals.shape[0] > 0: | ||
|
||
weights = BQStandardBeliefUpdate.gram_cho_solve( | ||
bq_state.gram_cho_factor, bq_state.kernel.matrix(bq_state.nodes, x) | ||
) | ||
predictive_embedding -= np.dot(bq_state.kernel_means, weights) | ||
|
||
values = (bq_state.scale_sq * predictive_embedding) ** 2 / ( | ||
bq_state.integral_belief.cov * y_predictive_var | ||
) | ||
return values, None |
52 changes: 52 additions & 0 deletions
52
src/probnum/quad/solvers/acquisition_functions/_mutual_information.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
"""Mutual information acquisition function for Bayesian quadrature.""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Optional, Tuple | ||
|
||
import numpy as np | ||
|
||
from probnum.quad.solvers._bq_state import BQState | ||
|
||
from ._acquisition_function import AcquisitionFunction | ||
from ._integral_variance_reduction import IntegralVarianceReduction | ||
|
||
# pylint: disable=too-few-public-methods | ||
|
||
|
||
class MutualInformation(AcquisitionFunction): | ||
r"""The mutual information between a hypothetical integrand observation and the | ||
integral value. | ||
The acquisition function is | ||
.. math:: | ||
a(x) = -0.5 \log(1-\rho^2(x)) | ||
where :math:`\rho^2(x)` is the squared correlation between a hypothetical integrand | ||
observations at :math:`x` and the integral value. [1]_ | ||
The mutual information is non-negative and unbounded for a 'perfect' observation | ||
and :math:`\rho^2(x) = 1.` | ||
References | ||
---------- | ||
.. [1] Gessner et al. Active Multi-Information Source Bayesian Quadrature, | ||
*UAI*, 2019 | ||
""" | ||
|
||
@property | ||
def has_gradients(self) -> bool: | ||
# Todo (#581): this needs to return True, once gradients are available | ||
return False | ||
|
||
def __call__( | ||
self, | ||
x: np.ndarray, | ||
bq_state: BQState, | ||
) -> Tuple[np.ndarray, Optional[np.ndarray]]: | ||
ivr = IntegralVarianceReduction() | ||
rho2, _ = ivr(x, bq_state) | ||
values = -0.5 * np.log(1 - rho2) | ||
return values, None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.