From 8dde541fc33acee4859a4e8db0c733bdd18a75a1 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci <42842079+mmahsereci@users.noreply.github.com> Date: Thu, 24 Nov 2022 18:30:19 +0100 Subject: [PATCH] Make `quad` policies stateless (#744) --- src/probnum/quad/_bayesquad.py | 15 +-- .../quad/solvers/_bayesian_quadrature.py | 43 +++---- src/probnum/quad/solvers/policies/_policy.py | 21 +++- .../quad/solvers/policies/_random_policy.py | 27 +++-- .../policies/_van_der_corput_policy.py | 17 ++- .../solvers/stopping_criteria/_max_nevals.py | 2 +- tests/test_quad/test_bayesian_quadrature.py | 55 +++++++-- tests/test_quad/test_bayesquad/test_bq.py | 23 ++-- tests/test_quad/test_policy.py | 107 +++++++++++++++++- 9 files changed, 242 insertions(+), 68 deletions(-) diff --git a/src/probnum/quad/_bayesquad.py b/src/probnum/quad/_bayesquad.py index ffb185dea..6ddf8364e 100644 --- a/src/probnum/quad/_bayesquad.py +++ b/src/probnum/quad/_bayesquad.py @@ -33,7 +33,7 @@ def bayesquad( var_tol: Optional[FloatLike] = None, rel_tol: Optional[FloatLike] = None, batch_size: IntLike = 1, - rng: Optional[np.random.Generator] = np.random.default_rng(), + rng: Optional[np.random.Generator] = None, jitter: FloatLike = 1.0e-8, ) -> Tuple[Normal, BQIterInfo]: r"""Infer the solution of the uni- or multivariate integral @@ -100,7 +100,7 @@ def bayesquad( Number of new observations at each update. Defaults to 1. rng Random number generator. Used by Bayesian Monte Carlo other random sampling - policies. Optional. Default is `np.random.default_rng()`. + policies. jitter Non-negative jitter to numerically stabilise kernel matrix inversion. Defaults to 1e-8. @@ -145,9 +145,9 @@ def bayesquad( >>> input_dim = 1 >>> domain = (0, 1) - >>> def f(x): + >>> def fun(x): ... return x.reshape(-1, ) - >>> F, info = bayesquad(fun=f, input_dim=input_dim, domain=domain) + >>> F, info = bayesquad(fun, input_dim, domain=domain, rng=np.random.default_rng(0)) >>> print(F.mean) 0.5 """ @@ -167,12 +167,13 @@ def bayesquad( var_tol=var_tol, rel_tol=rel_tol, batch_size=batch_size, - rng=rng, jitter=jitter, ) # Integrate - integral_belief, _, info = bq_method.integrate(fun=fun, nodes=None, fun_evals=None) + integral_belief, _, info = bq_method.integrate( + fun=fun, nodes=None, fun_evals=None, rng=rng + ) return integral_belief, info @@ -261,7 +262,7 @@ def bayesquad_from_data( # Integrate integral_belief, _, info = bq_method.integrate( - fun=None, nodes=nodes, fun_evals=fun_evals + fun=None, nodes=nodes, fun_evals=fun_evals, rng=None ) return integral_belief, info diff --git a/src/probnum/quad/solvers/_bayesian_quadrature.py b/src/probnum/quad/solvers/_bayesian_quadrature.py index 6b037ecba..691d36a80 100644 --- a/src/probnum/quad/solvers/_bayesian_quadrature.py +++ b/src/probnum/quad/solvers/_bayesian_quadrature.py @@ -83,7 +83,6 @@ def from_problem( var_tol: Optional[FloatLike] = None, rel_tol: Optional[FloatLike] = None, batch_size: IntLike = 1, - rng: np.random.Generator = None, jitter: FloatLike = 1.0e-8, ) -> "BayesianQuadrature": @@ -112,8 +111,6 @@ def from_problem( Relative tolerance as stopping criterion. batch_size Batch size used in node acquisition. Defaults to 1. - rng - The random number generator. jitter Non-negative jitter to numerically stabilise kernel matrix inversion. Defaults to 1e-8. @@ -127,9 +124,6 @@ def from_problem( ------ ValueError If neither a ``domain`` nor a ``measure`` are given. - ValueError - If Bayesian Monte Carlo ('bmc') is selected as ``policy`` and no random - number generator (``rng``) is given. NotImplementedError If an unknown ``policy`` is given. """ @@ -153,15 +147,9 @@ def from_problem( # require an acquisition loop. The error handling is done in ``integrate``. pass elif policy == "bmc": - if rng is None: - errormsg = ( - "Policy 'bmc' relies on random sampling, " - "thus requires a random number generator ('rng')." - ) - raise ValueError(errormsg) - policy = RandomPolicy(measure.sample, batch_size=batch_size, rng=rng) + policy = RandomPolicy(batch_size, measure.sample) elif policy == "vdc": - policy = VanDerCorputPolicy(measure=measure, batch_size=batch_size) + policy = VanDerCorputPolicy(batch_size, measure) else: raise NotImplementedError(f"The given policy ({policy}) is unknown.") @@ -215,6 +203,7 @@ def bq_iterator( bq_state: BQState, info: Optional[BQIterInfo], fun: Optional[Callable], + rng: Optional[np.random.Generator], ) -> Tuple[Normal, BQState, BQIterInfo]: """Generator that implements the iteration of the BQ method. @@ -231,6 +220,8 @@ def bq_iterator( fun Function to be integrated. It needs to accept a shape=(n_eval, input_dim) ``np.ndarray`` and return a shape=(n_eval,) ``np.ndarray``. + rng + The random number generator used for random methods. Yields ------ @@ -258,7 +249,7 @@ def bq_iterator( break # Select new nodes via policy - new_nodes = self.policy(bq_state=bq_state) + new_nodes = self.policy(bq_state, rng) # Evaluate the integrand at new nodes new_fun_evals = fun(new_nodes) @@ -278,6 +269,7 @@ def integrate( fun: Optional[Callable], nodes: Optional[np.ndarray], fun_evals: Optional[np.ndarray], + rng: Optional[np.random.Generator] = None, ) -> Tuple[Normal, BQState, BQIterInfo]: """Integrates the function ``fun``. @@ -297,6 +289,8 @@ def integrate( fun_evals *shape=(n_eval,)* -- Optional function evaluations at ``nodes`` available from the start. + rng + The random number generator used for random methods. Returns ------- @@ -308,14 +302,17 @@ def integrate( Raises ------ ValueError - If neither the integrand function (``fun``) nor integrand evaluations - (``fun_evals``) are given. + If neither the integrand function ``fun`` nor integrand evaluations + ``fun_evals`` are given. ValueError - If ``nodes`` are not given and no policy is present. + If neither ``nodes`` nor ``policy`` is given. ValueError If dimension of ``nodes`` or ``fun_evals`` is incorrect, or if their shapes do not match. + ValueError + If ``rng`` is not given but ``policy`` requires it. """ + # no policy given: Integrate on fixed dataset. if self.policy is None: # nodes must be provided if no policy is given. @@ -325,13 +322,19 @@ def integrate( # Use fun_evals and disregard fun if both are given if fun is not None and fun_evals is not None: warnings.warn( - "No policy available: 'fun_eval' are used instead of 'fun'." + "No policy available: 'fun_evals' are used instead of 'fun'." ) fun = None # override stopping condition as no policy is given. self.stopping_criterion = ImmediateStop() + elif self.policy.requires_rng and rng is None: + raise ValueError( + f"The policy '{self.policy.__class__.__name__}' requires a random " + f"number generator (rng) to be given." + ) + # Check if integrand function is provided if fun is None and fun_evals is None: raise ValueError( @@ -375,7 +378,7 @@ def integrate( ) info = None - for (_, bq_state, info) in self.bq_iterator(bq_state, info, fun): + for (_, bq_state, info) in self.bq_iterator(bq_state, info, fun, rng): pass return bq_state.integral_belief, bq_state, info diff --git a/src/probnum/quad/solvers/policies/_policy.py b/src/probnum/quad/solvers/policies/_policy.py index f57d3438e..8fb1a3be7 100644 --- a/src/probnum/quad/solvers/policies/_policy.py +++ b/src/probnum/quad/solvers/policies/_policy.py @@ -1,10 +1,14 @@ """Abstract base class for BQ policies.""" +from __future__ import annotations + import abc +from typing import Optional import numpy as np from probnum.quad.solvers._bq_state import BQState +from probnum.typing import IntLike # pylint: disable=too-few-public-methods, fixme @@ -18,17 +22,28 @@ class Policy(abc.ABC): Size of batch of nodes when calling the policy once. """ - def __init__(self, batch_size: int) -> None: - self.batch_size = batch_size + def __init__(self, batch_size: IntLike) -> None: + self.batch_size = int(batch_size) + @property @abc.abstractmethod - def __call__(self, bq_state: BQState) -> np.ndarray: + def requires_rng(self) -> bool: + """Whether the policy requires a random number generator when called.""" + raise NotImplementedError + + @abc.abstractmethod + def __call__( + self, bq_state: BQState, rng: Optional[np.random.Generator] + ) -> np.ndarray: """Find nodes according to the policy. Parameters ---------- bq_state State of the BQ belief. + rng + A random number generator. + Returns ------- nodes : diff --git a/src/probnum/quad/solvers/policies/_random_policy.py b/src/probnum/quad/solvers/policies/_random_policy.py index cae7f297e..d6f417a85 100644 --- a/src/probnum/quad/solvers/policies/_random_policy.py +++ b/src/probnum/quad/solvers/policies/_random_policy.py @@ -1,10 +1,13 @@ """Random policy for Bayesian Monte Carlo.""" -from typing import Callable +from __future__ import annotations + +from typing import Callable, Optional import numpy as np from probnum.quad.solvers._bq_state import BQState +from probnum.typing import IntLike from ._policy import Policy @@ -16,25 +19,27 @@ class RandomPolicy(Policy): Parameters ---------- + batch_size + Size of batch of nodes when calling the policy once. sample_func The sample function. Needs to have the following interface: `sample_func(batch_size: int, rng: np.random.Generator)` and return an array of - shape (batch_size, n_dim). - batch_size - Size of batch of nodes when calling the policy once. - rng - A random number generator. + shape (batch_size, input_dim). """ def __init__( self, + batch_size: IntLike, sample_func: Callable, - batch_size: int, - rng: np.random.Generator = np.random.default_rng(), ) -> None: super().__init__(batch_size=batch_size) self.sample_func = sample_func - self.rng = rng - def __call__(self, bq_state: BQState) -> np.ndarray: - return self.sample_func(self.batch_size, rng=self.rng) + @property + def requires_rng(self) -> bool: + return True + + def __call__( + self, bq_state: BQState, rng: Optional[np.random.Generator] + ) -> np.ndarray: + return self.sample_func(self.batch_size, rng=rng) diff --git a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py index 3d78bc0e9..c276f5946 100644 --- a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py +++ b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py @@ -1,11 +1,14 @@ """Van der Corput points for integration on 1D intervals.""" +from __future__ import annotations + from typing import Optional import numpy as np from probnum.quad.integration_measures import IntegrationMeasure from probnum.quad.solvers._bq_state import BQState +from probnum.typing import IntLike from ._policy import Policy @@ -22,17 +25,17 @@ class VanDerCorputPolicy(Policy): Parameters ---------- - measure - The integration measure with finite domain. batch_size Size of batch of nodes when calling the policy once. + measure + The integration measure with finite domain. References -------- .. [1] https://en.wikipedia.org/wiki/Van_der_Corput_sequence """ - def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None: + def __init__(self, batch_size: IntLike, measure: IntegrationMeasure) -> None: super().__init__(batch_size=batch_size) if int(measure.input_dim) > 1: @@ -46,7 +49,13 @@ def __init__(self, measure: IntegrationMeasure, batch_size: int) -> None: self.domain_a = domain_a self.domain_b = domain_b - def __call__(self, bq_state: BQState) -> np.ndarray: + @property + def requires_rng(self) -> bool: + return False + + def __call__( + self, bq_state: BQState, rng: Optional[np.random.Generator] + ) -> np.ndarray: n_nodes = bq_state.nodes.shape[0] vdc_seq = VanDerCorputPolicy.van_der_corput_sequence( n_nodes + 1, n_nodes + 1 + self.batch_size diff --git a/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py b/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py index bf87dd252..9d50d19c9 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py +++ b/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py @@ -19,7 +19,7 @@ class MaxNevals(BQStoppingCriterion): """ def __init__(self, max_nevals: IntLike): - self.max_nevals = max_nevals + self.max_nevals = int(max_nevals) def __call__(self, bq_state: BQState, info: BQIterInfo) -> bool: return info.nevals >= self.max_nevals diff --git a/tests/test_quad/test_bayesian_quadrature.py b/tests/test_quad/test_bayesian_quadrature.py index 9af6fd4a9..0f36caf4c 100644 --- a/tests/test_quad/test_bayesian_quadrature.py +++ b/tests/test_quad/test_bayesian_quadrature.py @@ -7,7 +7,12 @@ from probnum.quad.integration_measures import LebesgueMeasure from probnum.quad.solvers import BayesianQuadrature from probnum.quad.solvers.policies import RandomPolicy, VanDerCorputPolicy -from probnum.quad.solvers.stopping_criteria import ImmediateStop +from probnum.quad.solvers.stopping_criteria import ( + ImmediateStop, + IntegralVarianceTolerance, + MaxNevals, + RelativeMeanChange, +) from probnum.randprocs.kernels import ExpQuad @@ -31,7 +36,6 @@ def bq(input_dim): return BayesianQuadrature.from_problem( input_dim=input_dim, domain=(np.zeros(input_dim), np.ones(input_dim)), - rng=np.random.default_rng(), ) @@ -56,9 +60,7 @@ def test_bq_from_problem_wrong_inputs(input_dim): ) def test_bq_from_problem_policy_assignment(policy, policy_type): """Test if correct policy is assigned from string identifier.""" - bq = BayesianQuadrature.from_problem( - input_dim=1, domain=(0, 1), policy=policy, rng=np.random.default_rng() - ) + bq = BayesianQuadrature.from_problem(input_dim=1, domain=(0, 1), policy=policy) assert isinstance(bq.policy, policy_type) @@ -81,7 +83,32 @@ def test_bq_from_problem_defaults(bq_no_policy, bq): assert isinstance(bq.kernel, ExpQuad) +@pytest.mark.parametrize( + "max_evals, var_tol, rel_tol, t", + [ + (None, None, None, LambdaStoppingCriterion), + (1000, None, None, MaxNevals), + (None, 1e-5, None, IntegralVarianceTolerance), + (None, None, 1e-5, RelativeMeanChange), + (None, 1e-5, 1e-5, LambdaStoppingCriterion), + (1000, None, 1e-5, LambdaStoppingCriterion), + (1000, 1e-5, None, LambdaStoppingCriterion), + (1000, 1e-5, 1e-5, LambdaStoppingCriterion), + ], +) +def test_bq_from_problem_stopping_criterion_assignment(max_evals, var_tol, rel_tol, t): + bq = BayesianQuadrature.from_problem( + input_dim=2, + domain=(0, 1), + max_evals=max_evals, + var_tol=var_tol, + rel_tol=rel_tol, + ) + assert isinstance(bq.stopping_criterion, t) + + def test_integrate_no_policy_wrong_input(bq_no_policy, data): + # The combination of inputs below is important to trigger the correct exception. nodes, fun_evals, fun = data # no nodes provided @@ -93,30 +120,36 @@ def test_integrate_no_policy_wrong_input(bq_no_policy, data): bq_no_policy.integrate(fun=fun, nodes=nodes, fun_evals=fun_evals) -def test_integrate_wrong_input(bq, bq_no_policy, data): +def test_integrate_wrong_input(bq, bq_no_policy, data, rng): + # The combination of inputs below is important to trigger the correct exception. + nodes, fun_evals, fun = data # no integrand provided with pytest.raises(ValueError): - bq.integrate(fun=None, nodes=nodes, fun_evals=None) + bq.integrate(fun=None, nodes=nodes, fun_evals=None, rng=rng) with pytest.raises(ValueError): bq_no_policy.integrate(fun=None, nodes=nodes, fun_evals=None) # wrong fun_evals shape with pytest.raises(ValueError): - bq.integrate(fun=fun, nodes=nodes, fun_evals=fun_evals[:, None]) + bq.integrate(fun=fun, nodes=nodes, fun_evals=fun_evals[:, None], rng=rng) with pytest.raises(ValueError): bq_no_policy.integrate(fun=None, nodes=nodes, fun_evals=fun_evals[:, None]) # wrong nodes shape with pytest.raises(ValueError): - bq.integrate(fun=fun, nodes=nodes[:, None], fun_evals=None) + bq.integrate(fun=fun, nodes=nodes[:, None], fun_evals=fun_evals, rng=rng) with pytest.raises(ValueError): - bq_no_policy.integrate(fun=None, nodes=nodes[:, None], fun_evals=None) + bq_no_policy.integrate(fun=None, nodes=nodes[:, None], fun_evals=fun_evals) # number of points in nodes and fun_evals do not match wrong_nodes = np.vstack([nodes, np.ones([1, nodes.shape[1]])]) with pytest.raises(ValueError): - bq.integrate(fun=fun, nodes=wrong_nodes, fun_evals=fun_evals) + bq.integrate(fun=fun, nodes=wrong_nodes, fun_evals=fun_evals, rng=rng) with pytest.raises(ValueError): bq_no_policy.integrate(fun=None, nodes=wrong_nodes, fun_evals=fun_evals) + + # no rng provided but policy requires it + with pytest.raises(ValueError): + bq.integrate(fun=fun, nodes=nodes, fun_evals=fun_evals, rng=None) diff --git a/tests/test_quad/test_bayesquad/test_bq.py b/tests/test_quad/test_bayesquad/test_bq.py index 704c62976..9578c8893 100644 --- a/tests/test_quad/test_bayesquad/test_bq.py +++ b/tests/test_quad/test_bayesquad/test_bq.py @@ -18,10 +18,15 @@ def rng(): @pytest.mark.parametrize("input_dim", [1], ids=["dim1"]) -def test_type_1d(f1d, kernel, measure, input_dim): +def test_type_1d(f1d, kernel, measure, input_dim, rng): """Test that BQ outputs normal random variables for 1D integrands.""" integral, _ = bayesquad( - fun=f1d, input_dim=input_dim, kernel=kernel, measure=measure, max_evals=10 + fun=f1d, + input_dim=input_dim, + kernel=kernel, + measure=measure, + max_evals=10, + rng=rng, ) assert isinstance(integral, Normal) @@ -43,7 +48,7 @@ def test_type_1d(f1d, kernel, measure, input_dim): @pytest.mark.parametrize("scale_estimation", [None, "mle"]) @pytest.mark.parametrize("jitter", [1e-6, 1e-7]) def test_integral_values_1d( - f1d, kernel, domain, input_dim, scale_estimation, var_tol, rel_tol, jitter + f1d, kernel, domain, input_dim, scale_estimation, var_tol, rel_tol, jitter, rng ): """Test numerically that BQ computes 1D integrals correctly for a number of different parameters. @@ -70,6 +75,7 @@ def integrand(x): var_tol=var_tol, rel_tol=rel_tol, jitter=jitter, + rng=rng, ) domain = measure.domain num_integral, _ = scipyquad(integrand, domain[0], domain[1]) @@ -138,7 +144,7 @@ def test_integral_values_sin_lebesgue( @pytest.mark.parametrize("input_dim", [2, 3, 4]) @pytest.mark.parametrize("num_data", [1]) # pylint: disable=invalid-name -def test_integral_values_kernel_translate(kernel, measure, input_dim, x): +def test_integral_values_kernel_translate(kernel, measure, input_dim, x, rng): """Test numerical integration of kernel translates.""" kernel_embedding = KernelEmbedding(kernel, measure) # pylint: disable=cell-var-from-loop @@ -152,6 +158,7 @@ def test_integral_values_kernel_translate(kernel, measure, input_dim, x): var_tol=1e-8, max_evals=1000, batch_size=50, + rng=rng, ) true_integral = kernel_embedding.kernel_mean(np.atleast_2d(translate_point)) np.testing.assert_almost_equal(bq_integral.mean, true_integral, decimal=2) @@ -173,13 +180,13 @@ def test_no_domain_or_measure_raises_error(input_dim): @pytest.mark.parametrize("input_dim", [1]) @pytest.mark.parametrize("measure_name", ["lebesgue"]) -def test_domain_ignored_if_lebesgue(input_dim, measure): +def test_domain_ignored_if_lebesgue(input_dim, measure, rng): domain = (0, 1) fun = lambda x: np.reshape(x, (x.shape[0],)) # standard BQ bq_integral, _ = bayesquad( - fun=fun, input_dim=input_dim, domain=domain, measure=measure + fun=fun, input_dim=input_dim, domain=domain, measure=measure, rng=rng ) assert isinstance(bq_integral, Normal) @@ -193,7 +200,7 @@ def test_domain_ignored_if_lebesgue(input_dim, measure): assert isinstance(bq_integral, Normal) -def test_zero_function_gives_zero_variance_with_mle(): +def test_zero_function_gives_zero_variance_with_mle(rng): """Test that BQ variance is zero for zero function when MLE is used to set the scale parameter.""" input_dim = 1 @@ -203,7 +210,7 @@ def test_zero_function_gives_zero_variance_with_mle(): fun_evals = fun(nodes) bq_integral1, _ = bayesquad( - fun=fun, input_dim=input_dim, domain=domain, scale_estimation="mle" + fun=fun, input_dim=input_dim, domain=domain, scale_estimation="mle", rng=rng ) bq_integral2, _ = bayesquad_from_data( nodes=nodes, fun_evals=fun_evals, domain=domain, scale_estimation="mle" diff --git a/tests/test_quad/test_policy.py b/tests/test_quad/test_policy.py index fa848da73..86402dda4 100644 --- a/tests/test_quad/test_policy.py +++ b/tests/test_quad/test_policy.py @@ -1,10 +1,111 @@ """Basic tests for BQ policies.""" + +# New policies need to be added to the fixtures 'policy_name' and 'policy_params' +# and 'policy'. + + import numpy as np import pytest from probnum.quad.integration_measures import GaussianMeasure, LebesgueMeasure -from probnum.quad.solvers.policies import VanDerCorputPolicy +from probnum.quad.solvers import BQState +from probnum.quad.solvers.policies import RandomPolicy, VanDerCorputPolicy +from probnum.randprocs.kernels import ExpQuad + + +@pytest.fixture +def batch_size(): + return 3 + + +@pytest.fixture( + params=[ + pytest.param(name, id=name) for name in ["RandomPolicy", "VanDerCorputPolicy"] + ] +) +def policy_name(request): + return request.param + + +@pytest.fixture +def policy_params(policy_name, input_dim, batch_size, rng): + def _get_bq_states(ndim): + nevals = 5 + bq_state_no_data = BQState( + measure=LebesgueMeasure(input_dim=ndim, domain=(0, 1)), + kernel=ExpQuad(input_shape=(ndim,)), + ) + bq_state = BQState( + measure=LebesgueMeasure(input_dim=ndim, domain=(0, 1)), + kernel=ExpQuad(input_shape=(ndim,)), + nodes=np.zeros([nevals, ndim]), + fun_evals=np.ones(nevals), + ) + return bq_state, bq_state_no_data + + params = dict(name=policy_name, ndim=input_dim) + params["bq_state"], params["bq_state_no_data"] = _get_bq_states(input_dim) + + if policy_name == "RandomPolicy": + input_params = dict( + batch_size=batch_size, + sample_func=lambda batch_size, rng: np.ones([batch_size, input_dim]), + ) + params["requires_rng"] = True + elif policy_name == "VanDerCorputPolicy": + # Since VanDerCorputPolicy can only produce univariate nodes, this overrides + # input_dim = 1 for all tests. This is a bit cheap, but pytest parametrization + # is convoluted enough. + input_params = dict( + batch_size=batch_size, + measure=LebesgueMeasure(input_dim=1, domain=(0, 1)), + ) + params["bq_state"], params["bq_state_no_data"] = _get_bq_states(1) + params["ndim"] = 1 + params["requires_rng"] = False + else: + raise NotImplementedError + + params["input_params"] = input_params + + return params + + +@pytest.fixture() +def policy(policy_params): + name = policy_params.pop("name") + input_params = policy_params.pop("input_params") + + if name == "RandomPolicy": + return RandomPolicy(**input_params), policy_params + elif name == "VanDerCorputPolicy": + return VanDerCorputPolicy(**input_params), policy_params + else: + raise NotImplementedError + + +# Tests shared by all policies start here. + + +def test_policy_shapes(policy, batch_size, rng): + policy, params = policy + bq_state, bq_state_no_data = params["bq_state"], params["bq_state_no_data"] + ndim = params["ndim"] + + # bq state contains data + assert policy(bq_state, rng).shape == (batch_size, ndim) + + # bq state contains no data yet + assert policy(bq_state_no_data, rng).shape == (batch_size, ndim) + + +def test_policy_property_values(policy): + policy, params = policy + assert policy.requires_rng is params["requires_rng"] + + +# Tests specific to VanDerCorputPolicy start here def test_van_der_corput_multi_d_error(): @@ -12,7 +113,7 @@ def test_van_der_corput_multi_d_error(): wrong_dimension = 2 measure = GaussianMeasure(input_dim=wrong_dimension, mean=0.0, cov=1.0) with pytest.raises(ValueError): - VanDerCorputPolicy(measure, batch_size=1) + VanDerCorputPolicy(1, measure) @pytest.mark.parametrize("domain", [(-np.Inf, 0), (1, np.Inf), (-np.Inf, np.Inf)]) @@ -20,7 +121,7 @@ def test_van_der_corput_infinite_error(domain): """Check that van der Corput policy fails on infinite domains.""" measure = LebesgueMeasure(input_dim=1, domain=domain) with pytest.raises(ValueError): - VanDerCorputPolicy(measure, batch_size=1) + VanDerCorputPolicy(1, measure) @pytest.mark.parametrize("n", [4, 8, 16, 32, 64, 128, 256])