Skip to content

Commit

Permalink
ML estimation of the scale parameter for quad (#712)
Browse files Browse the repository at this point in the history
  • Loading branch information
tskarvone committed Nov 17, 2022
1 parent 6fadc00 commit 35f50e4
Show file tree
Hide file tree
Showing 15 changed files with 455 additions and 74 deletions.
3 changes: 2 additions & 1 deletion src/probnum/quad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
choosing points to evaluate the integrand based on said model.
"""

from probnum.quad.solvers.policies import Policy, RandomPolicy
from probnum.quad.solvers.policies import Policy, RandomPolicy, VanDerCorputPolicy
from probnum.quad.solvers.stopping_criteria import (
BQStoppingCriterion,
ImmediateStop,
Expand Down Expand Up @@ -41,6 +41,7 @@
"IntegralVarianceTolerance",
"MaxNevals",
"RandomPolicy",
"VanDerCorputPolicy",
"RelativeMeanChange",
]

Expand Down
32 changes: 30 additions & 2 deletions src/probnum/quad/_bayesquad.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ def bayesquad(
domain: Optional[DomainLike] = None,
measure: Optional[IntegrationMeasure] = None,
policy: Optional[str] = "bmc",
scale_estimation: Optional[str] = "mle",
max_evals: Optional[IntLike] = None,
var_tol: Optional[FloatLike] = None,
rel_tol: Optional[FloatLike] = None,
batch_size: Optional[IntLike] = 1,
batch_size: IntLike = 1,
rng: Optional[np.random.Generator] = np.random.default_rng(),
jitter: FloatLike = 1.0e-8,
) -> Tuple[Normal, BQIterInfo]:
r"""Infer the solution of the uni- or multivariate integral
:math:`\int_\Omega f(x) d \mu(x)`
Expand Down Expand Up @@ -78,17 +80,32 @@ def bayesquad(
Bayesian Monte Carlo [2]_ ``bmc``
========================== =======
========================== =======
van Der Corput points ``vdc``
========================== =======
scale_estimation
Estimation method to use to compute the scale parameter. Defaults to 'mle'.
Options are
============================== =======
Maximum likelihood estimation ``mle``
============================== =======
max_evals
Maximum number of function evaluations.
var_tol
Tolerance on the variance of the integral.
rel_tol
Tolerance on consecutive updates of the integral mean.
batch_size
Number of new observations at each update.
Number of new observations at each update. Defaults to 1.
rng
Random number generator. Used by Bayesian Monte Carlo other random sampling
policies. Optional. Default is `np.random.default_rng()`.
jitter
Non-negative jitter to numerically stabilise kernel matrix inversion.
Defaults to 1e-8.
Returns
-------
Expand Down Expand Up @@ -147,11 +164,13 @@ def bayesquad(
measure=measure,
domain=domain,
policy=policy,
scale_estimation=scale_estimation,
max_evals=max_evals,
var_tol=var_tol,
rel_tol=rel_tol,
batch_size=batch_size,
rng=rng,
jitter=jitter,
)

# Integrate
Expand All @@ -166,6 +185,8 @@ def bayesquad_from_data(
kernel: Optional[Kernel] = None,
domain: Optional[DomainLike] = None,
measure: Optional[IntegrationMeasure] = None,
scale_estimation: Optional[str] = "mle",
jitter: FloatLike = 1.0e-8,
) -> Tuple[Normal, BQIterInfo]:
r"""Infer the value of an integral from a given set of nodes and function
evaluations.
Expand All @@ -184,6 +205,11 @@ def bayesquad_from_data(
``np.ndarray``. Obsolete if ``measure`` is given.
measure
The integration measure. Defaults to the Lebesgue measure.
scale_estimation
Estimation method to use to compute the scale parameter. Defaults to 'mle'.
jitter
Non-negative jitter to numerically stabilise kernel matrix inversion.
Defaults to 1e-8.
Returns
-------
Expand Down Expand Up @@ -231,6 +257,8 @@ def bayesquad_from_data(
measure=measure,
domain=domain,
policy=None,
scale_estimation=scale_estimation,
jitter=jitter,
)

# Integrate
Expand Down
4 changes: 2 additions & 2 deletions src/probnum/quad/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def as_domain(
Parameters
----------
domain
The integration domain as supllied.
The integration domain as supplied.
input_dim
The input dimensionality as supplied.
Expand All @@ -44,7 +44,7 @@ def as_domain(
ValueError
If ``input_dim`` is not positive.
If domain has too many or too little elements.
If the bounds of the domain have differening sizes.
If the bounds of the domain have differing sizes.
If ``input_dim`` is incompatible with domain bounds.
If bounds have wrong shape.
If integration domain is empty.
Expand Down
24 changes: 17 additions & 7 deletions src/probnum/quad/solvers/bayesian_quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from probnum.quad.solvers.policies import Policy, RandomPolicy
from probnum.quad.solvers.policies import Policy, RandomPolicy, VanDerCorputPolicy
from probnum.quad.solvers.stopping_criteria import (
BQStoppingCriterion,
ImmediateStop,
Expand All @@ -23,6 +23,8 @@
from .belief_updates import BQBeliefUpdate, BQStandardBeliefUpdate
from .bq_state import BQIterInfo, BQState

# pylint: disable=too-many-branches, too-complex


class BayesianQuadrature:
r"""The Bayesian quadrature method.
Expand Down Expand Up @@ -75,11 +77,13 @@ def from_problem(
measure: Optional[IntegrationMeasure] = None,
domain: Optional[DomainLike] = None,
policy: Optional[str] = "bmc",
scale_estimation: Optional[str] = "mle",
max_evals: Optional[IntLike] = None,
var_tol: Optional[FloatLike] = None,
rel_tol: Optional[FloatLike] = None,
batch_size: IntLike = 1,
rng: np.random.Generator = None,
jitter: FloatLike = 1.0e-8,
) -> "BayesianQuadrature":

r"""Creates an instance of this class from a problem description.
Expand All @@ -97,16 +101,21 @@ def from_problem(
policy
The policy choosing nodes at which to evaluate the integrand.
Choose ``None`` if you want to integrate from a fixed dataset.
scale_estimation
Estimation method to use to compute the scale parameter. Defaults to 'mle'.
max_evals
Maximum number of evaluations as stopping criterion.
var_tol
Variance tolerance as stopping criterion.
rel_tol
Relative tolerance as stopping criterion.
batch_size
Batch size used in node acquisition.
Batch size used in node acquisition. Defaults to 1.
rng
The random number generator.
jitter
Non-negative jitter to numerically stabilise kernel matrix inversion.
Defaults to 1e-8.
Returns
-------
Expand Down Expand Up @@ -150,14 +159,15 @@ def from_problem(
)
raise ValueError(errormsg)
policy = RandomPolicy(measure.sample, batch_size=batch_size, rng=rng)

elif policy == "vdc":
policy = VanDerCorputPolicy(measure=measure, batch_size=batch_size)
else:
raise NotImplementedError(
"Policies other than random sampling are not available at the moment."
)
raise NotImplementedError(f"The given policy ({policy}) is unknown.")

# Select the belief updater
belief_update = BQStandardBeliefUpdate()
belief_update = BQStandardBeliefUpdate(
jitter=jitter, scale_estimation=scale_estimation
)

# Select stopping criterion: If multiple stopping criteria are given, BQ stops
# once any criterion is fulfilled (logical `or`).
Expand Down

0 comments on commit 35f50e4

Please sign in to comment.