Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multilevel Bayesian quadrature #750

Merged
merged 46 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
7e6167c
Add Kaiyu's original MLBQ implementation in probnum
Nov 29, 2022
349b749
Some edits to MLBQ
Dec 1, 2022
15e4e44
Write a proper version of MLBQ
Dec 2, 2022
bdd2753
Add some input handling a write docstring
Dec 5, 2022
4a1cea9
Add multilevel example for docstring
Dec 5, 2022
b5f8c1a
Put the example to the correct place
Dec 5, 2022
c3ea2a2
Add two basic multilevel BQ tests
Dec 5, 2022
06d6730
Add input handling test for multilevel BQ
Dec 6, 2022
7a83c1d
Documentation for multilevel BQ
Dec 8, 2022
945aaa5
Some linting
Dec 8, 2022
17d0047
Fix formatting of a ref
Dec 9, 2022
bb14d45
Merge branch 'main' into multilevel
tskarvone Dec 13, 2022
1a8e1cb
Merge branch 'main' into multilevel
tskarvone Dec 14, 2022
942cb00
Merge branch 'main' into multilevel
tskarvone Feb 10, 2023
214c97c
Merge branch 'main' into multilevel
tskarvone Feb 11, 2023
7e17072
Update src/probnum/quad/_bayesquad.py
tskarvone Feb 11, 2023
c495965
Update src/probnum/quad/_bayesquad.py
tskarvone Feb 11, 2023
288cde9
Update src/probnum/quad/_bayesquad.py
tskarvone Feb 11, 2023
02153ca
Update src/probnum/quad/_bayesquad.py
tskarvone Feb 11, 2023
763e4bc
Fix things pointed out in review
Feb 12, 2023
9c0d921
Try to fix some failed checks
Feb 12, 2023
e703267
Update src/probnum/quad/_bayesquad.py
tskarvone Feb 15, 2023
c65ff73
Tuple length documentation
Feb 15, 2023
c060fce
Update src/probnum/quad/_bayesquad.py
tskarvone Feb 15, 2023
f788737
Update src/probnum/quad/_bayesquad.py
tskarvone Feb 15, 2023
47b4a31
merge
Feb 15, 2023
b2bc28c
Merge
Feb 15, 2023
31ce405
Update src/probnum/quad/_bayesquad.py
tskarvone Feb 15, 2023
9597e80
Update src/probnum/quad/_bayesquad.py
tskarvone Feb 15, 2023
4ea2ccf
Update tests/test_quad/test_bayesquad/test_bq.py
tskarvone Feb 15, 2023
c9d405a
Update tests/test_quad/test_bayesquad/test_bq.py
tskarvone Feb 15, 2023
ed1a1bc
merge
Feb 15, 2023
5cb3a7e
Update tests/test_quad/test_bayesquad/test_bq.py
tskarvone Feb 15, 2023
366972c
Update tests/test_quad/test_bayesquad/test_bq.py
tskarvone Feb 15, 2023
24856c8
Update tests/test_quad/test_bayesquad/test_bq.py
tskarvone Feb 15, 2023
45f9c52
Update tests/test_quad/test_bayesquad/test_bq.py
tskarvone Feb 15, 2023
c56604d
Update tests/test_quad/test_bayesquad/test_bq.py
tskarvone Feb 15, 2023
85020ac
merge
Feb 15, 2023
b7b4e51
Kernel copying in MLBQ test
Feb 15, 2023
39645f9
Update tests/test_quad/test_bayesquad/test_bq.py
tskarvone Feb 15, 2023
905f0c8
Remove redundant test
Feb 15, 2023
a67506b
Update src/probnum/quad/_bayesquad.py
tskarvone Feb 15, 2023
9b6a237
Multilevel BQ test refactoring
Feb 15, 2023
cc49832
Merge branch 'main' into multilevel
tskarvone Feb 15, 2023
7258d32
Try to make pylint happy
Feb 15, 2023
be1fa10
some more minor changes
mmahsereci Feb 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/probnum/quad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
"""

from . import integration_measures, kernel_embeddings, solvers
from ._bayesquad import bayesquad, bayesquad_from_data
from ._bayesquad import bayesquad, bayesquad_from_data, multilevel_bayesquad_from_data

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"bayesquad",
"bayesquad_from_data",
"multilevel_bayesquad_from_data",
]

# Set correct module paths. Corrects links and module paths in documentation.
bayesquad.__module__ = "probnum.quad"
bayesquad_from_data.__module__ = "probnum.quad"
multilevel_bayesquad_from_data.__module__ = "probnum.quad"
143 changes: 140 additions & 3 deletions src/probnum/quad/_bayesquad.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def bayesquad(
References
----------
.. [1] Briol, F.-X., et al., Probabilistic integration: A role in statistical
computation?, *Statistical Science 34.1*, 2019, 1-22, 2019
computation?, *Statistical Science 34.1*, 2019, 1-22.
.. [2] Rasmussen, C. E., and Z. Ghahramani, Bayesian Monte Carlo, *Advances in
Neural Information Processing Systems*, 2003, 505-512.
.. [3] Mckay et al., A Comparison of Three Methods for Selecting Values of Input
Expand All @@ -168,7 +168,6 @@ def bayesquad(
Examples
--------
>>> import numpy as np

>>> input_dim = 1
>>> domain = (0, 1)
>>> def fun(x):
Expand Down Expand Up @@ -299,12 +298,150 @@ def bayesquad_from_data(
return integral_belief, info


def multilevel_bayesquad_from_data(
    nodes: Tuple[np.ndarray, ...],
    fun_diff_evals: Tuple[np.ndarray, ...],
    kernels: Optional[Tuple[Kernel, ...]] = None,
    measure: Optional[IntegrationMeasure] = None,
    domain: Optional[DomainLike] = None,
    options: Optional[dict] = None,
) -> Tuple[Normal, Tuple[BQIterInfo, ...]]:
    r"""Infer the value of an integral from given sets of nodes and function
    evaluations using a multilevel method.

    In multilevel Bayesian quadrature, the integral :math:`\int_\Omega f(x) d \mu(x)`
    is (approximately) decomposed as a telescoping sum over :math:`L+1` levels:

    .. math:: \int_\Omega f(x) d \mu(x) \approx \int_\Omega f_0(x) d
        \mu(x) + \sum_{l=1}^L \int_\Omega [f_l(x) - f_{l-1}(x)] d \mu(x),

    where :math:`f_l` is an increasingly accurate but also increasingly expensive
    approximation to :math:`f`. It is not necessary that the highest level approximation
    :math:`f_L` be equal to :math:`f`.

    Bayesian quadrature is subsequently applied to independently infer each of the
    :math:`L+1` integrals and the outputs are summed to infer
    :math:`\int_\Omega f(x) d \mu(x)`. [1]_

    Parameters
    ----------
    nodes
        Tuple of length :math:`L+1` containing the locations for each level at which
        the function evaluations are available as ``fun_diff_evals``. Each element
        must be a shape=(n_eval, input_dim) ``np.ndarray``. If a tuple containing only
        one element is provided, it is inferred that the same nodes ``nodes[0]`` are
        used on every level.
    fun_diff_evals
        Tuple of length :math:`L+1` containing the evaluations of :math:`f_l - f_{l-1}`
        for each level at the nodes provided in ``nodes``. Each element must be a
        shape=(n_eval,) ``np.ndarray``. The zeroth element contains the evaluations of
        :math:`f_0`.
    kernels
        Tuple of length :math:`L+1` containing the kernels used for the GP model at each
        level. See **Notes** for further details. Defaults to the ``ExpQuad`` kernel for
        each level.
    measure
        The integration measure. Defaults to the Lebesgue measure.
    domain
        The integration domain. Contains lower and upper bound as scalar or
        ``np.ndarray``. Obsolete if ``measure`` is given.
    options
        A dictionary with the following optional solver settings

            scale_estimation : Optional[str]
                Estimation method to use to compute the scale parameter. Used
                independently on each level. Defaults to 'mle'. Options are

                ==============================  =======
                 Maximum likelihood estimation  ``mle``
                ==============================  =======

            jitter : Optional[FloatLike]
                Non-negative jitter to numerically stabilise kernel matrix
                inversion. Same jitter is used on each level. Defaults to 1e-8.

    Returns
    -------
    integral :
        The integral belief subject to the provided measure or domain.
    infos :
        Information on the performance of the method for each level.

    Raises
    ------
    ValueError
        If ``nodes``, ``fun_diff_evals`` or ``kernels`` have different lengths.

    Warns
    -----
    UserWarning
        When ``domain`` is given but not used.

    Notes
    -----
    The tuple of kernels provided by the ``kernels`` parameter must contain distinct
    kernel instances, i.e., ``kernels[i] is kernel[j]`` must return ``False`` for any
    :math:`i\neq j`.

    References
    ----------
    .. [1] Li, K., et al., Multilevel Bayesian quadrature, AISTATS, 2023.

    Examples
    --------
    >>> import numpy as np
    >>> input_dim = 1
    >>> domain = (0, 1)
    >>> n_level = 6
    >>> def fun(x, l):
    ...     return x.reshape(-1, ) / (l + 1.0)
    >>> nodes = ()
    >>> fun_diff_evals = ()
    >>> for l in range(n_level):
    ...     n_l = 2*l + 1
    ...     nodes += (np.reshape(np.linspace(0, 1, n_l), (n_l, input_dim)),)
    ...     fun_diff_evals += (np.reshape(fun(nodes[l], l), (n_l,)),)
    >>> F, infos = multilevel_bayesquad_from_data(nodes=nodes,
    ...                                           fun_diff_evals=fun_diff_evals,
    ...                                           domain=domain)
    >>> print(np.round(F.mean, 4))
    0.7252
    """

    n_level = len(fun_diff_evals)
    # No kernels given: let each level fall back to bayesquad_from_data's default.
    if kernels is None:
        kernels = n_level * (None,)
    # A single node set is shared across all levels.
    if len(nodes) == 1:
        nodes = n_level * (nodes[0],)
    if not len(nodes) == len(fun_diff_evals) == len(kernels):
        raise ValueError(
            f"You must provide an equal number of kernels ({len(kernels)}), "
            f"vectors of function evaluations ({len(fun_diff_evals)}) "
            f"and sets of nodes ({len(nodes)})."
        )

    # The level posteriors are independent, so the belief over the full integral
    # is the sum of the per-level beliefs. Start from a deterministic zero.
    integral_belief = Normal(mean=0.0, cov=0.0)
    infos = ()
    for level in range(n_level):
        integral_belief_l, info_l = bayesquad_from_data(
            nodes=nodes[level],
            fun_evals=fun_diff_evals[level],
            kernel=kernels[level],
            measure=measure,
            domain=domain,
            options=options,
        )
        integral_belief += integral_belief_l
        infos += (info_l,)

    return integral_belief, infos


def _check_domain_measure_compatibility(
input_dim: IntLike,
domain: Optional[DomainLike],
measure: Optional[IntegrationMeasure],
) -> Tuple[int, Optional[DomainType], IntegrationMeasure]:

input_dim = int(input_dim)

# Neither domain nor measure given
Expand Down
148 changes: 146 additions & 2 deletions tests/test_quad/test_bayesquad/test_bq.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Test cases for Bayesian quadrature."""
import copy

import numpy as np
import pytest
from scipy.integrate import quad as scipyquad

from probnum.quad import bayesquad, bayesquad_from_data
from probnum.quad.integration_measures import LebesgueMeasure
from probnum.quad import bayesquad, bayesquad_from_data, multilevel_bayesquad_from_data
from probnum.quad.integration_measures import GaussianMeasure, LebesgueMeasure
from probnum.quad.kernel_embeddings import KernelEmbedding
from probnum.randvars import Normal

Expand Down Expand Up @@ -219,3 +220,146 @@ def test_zero_function_gives_zero_variance_with_mle(rng):
)
assert bq_integral1.var == 0.0
assert bq_integral2.var == 0.0


def test_multilevel_bayesquad_from_data_output_types_and_shapes(kernel, measure, rng):
    """Test correct output for different inputs to multilevel BQ."""

    # Case A: a distinct set of nodes on every level.
    sizes_a = (3, 7, 2)
    num_levels_a = len(sizes_a)
    evals_a = tuple(np.zeros(n) for n in sizes_a)
    nodes_a = tuple(measure.sample(n, rng=rng) for n in sizes_a)

    # Run once with the default kernel and once with an explicit kernel per level.
    kernels_a = tuple(copy.deepcopy(kernel) for _ in range(num_levels_a))
    for kernel_tuple in (None, kernels_a):
        res, infos = multilevel_bayesquad_from_data(
            nodes=nodes_a,
            fun_diff_evals=evals_a,
            kernels=kernel_tuple,
            measure=measure,
        )
        assert isinstance(res, Normal)
        assert len(infos) == num_levels_a

    # Case B: a single node set shared by all levels.
    num_levels_b = 3
    sizes_b = num_levels_b * (7,)
    evals_b = tuple(np.zeros(n) for n in sizes_b)
    nodes_b = (measure.sample(n_sample=sizes_b[0], rng=rng),)

    kernels_b = tuple(copy.deepcopy(kernel) for _ in range(num_levels_b))
    for kernel_tuple in (None, kernels_b):
        res, infos = multilevel_bayesquad_from_data(
            nodes=nodes_b,
            fun_diff_evals=evals_b,
            kernels=kernel_tuple,
            measure=measure,
        )
        assert isinstance(res, Normal)
        assert len(infos) == num_levels_b


def test_multilevel_bayesquad_from_data_wrong_inputs(kernel, measure, rng):
    """Tests that a wrong number of inputs to multilevel BQ throws errors."""
    sizes = (3, 7, 11)
    evals = tuple(np.zeros(n) for n in sizes)

    # Fewer node sets (but more than one) than vectors of function evaluations.
    too_few_nodes = tuple(measure.sample(sizes[l], rng=rng) for l in range(2))
    with pytest.raises(ValueError):
        multilevel_bayesquad_from_data(
            nodes=too_few_nodes,
            fun_diff_evals=evals,
            measure=measure,
        )

    # Fewer kernels than vectors of function evaluations.
    too_few_kernels = tuple(copy.deepcopy(kernel) for _ in range(2))
    shared_nodes = (measure.sample(n_sample=sizes[0], rng=rng),)
    with pytest.raises(ValueError):
        multilevel_bayesquad_from_data(
            nodes=shared_nodes,
            fun_diff_evals=evals,
            kernels=too_few_kernels,
            measure=measure,
        )


def test_multilevel_bayesquad_from_data_equals_bq_with_trivial_data_1d():
    """Test that multilevel BQ equals BQ when all but one level are given non-zero
    function evaluations for 1D data."""
    num_levels = 5
    domain = (0, 3.3)
    nodes = tuple(np.linspace(0, 1, 2 * l + 1)[:, None] for l in range(num_levels))
    for active in range(num_levels):
        jitter = 1e-5 * (active + 1.0)
        # Only level ``active`` receives non-trivial data; all others get zeros,
        # so the multilevel telescoping sum collapses to a single-level BQ.
        active_evals = nodes[active][:, 0] ** (2 + 0.3 * active) + 1.2
        diff_evals = [np.zeros(shape=(len(xs),)) for xs in nodes]
        diff_evals[active] = active_evals
        mlbq_integral, _ = multilevel_bayesquad_from_data(
            nodes=nodes,
            fun_diff_evals=tuple(diff_evals),
            domain=domain,
            options=dict(jitter=jitter),
        )
        bq_integral, _ = bayesquad_from_data(
            nodes=nodes[active],
            fun_evals=active_evals,
            domain=domain,
            options=dict(jitter=jitter),
        )
        assert mlbq_integral.mean == bq_integral.mean
        assert mlbq_integral.cov == bq_integral.cov


def test_multilevel_bayesquad_from_data_equals_bq_with_trivial_data_2d():
    """Test that multilevel BQ equals BQ when all but one level are given non-zero
    function evaluations for 2D data."""
    dim = 2
    num_levels = 5
    measure = GaussianMeasure(np.full((dim,), 0.2), cov=0.6 * np.eye(dim))
    nodes = tuple(
        gauss_hermite_tensor(l + 1, dim, measure.mean, measure.cov)[0]
        for l in range(num_levels)
    )
    for active in range(num_levels):
        jitter = 1e-5 * (active + 1.0)
        # Only level ``active`` receives non-trivial data; all others get zeros,
        # so the multilevel telescoping sum collapses to a single-level BQ.
        active_evals = np.sin(nodes[active][:, 0] * active) + (active + 1.0) * np.cos(
            nodes[active][:, 1]
        )
        diff_evals = [np.zeros(shape=(len(xs),)) for xs in nodes]
        diff_evals[active] = active_evals
        mlbq_integral, _ = multilevel_bayesquad_from_data(
            nodes=nodes,
            fun_diff_evals=tuple(diff_evals),
            measure=measure,
            options=dict(jitter=jitter),
        )
        bq_integral, _ = bayesquad_from_data(
            nodes=nodes[active],
            fun_evals=active_evals,
            measure=measure,
            options=dict(jitter=jitter),
        )
        assert mlbq_integral.mean == bq_integral.mean
        assert mlbq_integral.cov == bq_integral.cov