diff --git a/docs/source/api/quad.rst b/docs/source/api/quad.rst index c33a11907..7d321e78e 100644 --- a/docs/source/api/quad.rst +++ b/docs/source/api/quad.rst @@ -5,3 +5,23 @@ probnum.quad .. automodapi:: probnum.quad :no-heading: :headings: "=" + +.. toctree:: + :hidden: + + quad/integration_measures + +.. toctree:: + :hidden: + + quad/kernel_embeddings + +.. toctree:: + :hidden: + + quad/solvers + +.. toctree:: + :hidden: + + quad/typing diff --git a/docs/source/api/quad/integration_measures.rst b/docs/source/api/quad/integration_measures.rst new file mode 100644 index 000000000..8a6dd9ff8 --- /dev/null +++ b/docs/source/api/quad/integration_measures.rst @@ -0,0 +1,6 @@ +probnum.quad.integration_measures +================================= + +.. automodapi:: probnum.quad.integration_measures + :no-heading: + :headings: "-" diff --git a/docs/source/api/quad/kernel_embeddings.rst b/docs/source/api/quad/kernel_embeddings.rst new file mode 100644 index 000000000..bc3e4a186 --- /dev/null +++ b/docs/source/api/quad/kernel_embeddings.rst @@ -0,0 +1,6 @@ +probnum.quad.kernel_embeddings +============================== + +.. automodapi:: probnum.quad.kernel_embeddings + :no-heading: + :headings: "-" diff --git a/docs/source/api/quad/solvers.belief_updates.rst b/docs/source/api/quad/solvers.belief_updates.rst new file mode 100644 index 000000000..220b915ea --- /dev/null +++ b/docs/source/api/quad/solvers.belief_updates.rst @@ -0,0 +1,5 @@ +Belief Updates +-------------- +.. automodapi:: probnum.quad.solvers.belief_updates + :no-heading: + :headings: "*" diff --git a/docs/source/api/quad/solvers.policies.rst b/docs/source/api/quad/solvers.policies.rst new file mode 100644 index 000000000..bdf5bd24f --- /dev/null +++ b/docs/source/api/quad/solvers.policies.rst @@ -0,0 +1,5 @@ +Policies +-------- +.. automodapi:: probnum.quad.solvers.policies + :no-heading: + :headings: "*" diff --git a/docs/source/api/quad/solvers.rst b/docs/source/api/quad/solvers.rst new file mode 100644 index 000000000..c5d6b314c --- /dev/null +++ b/docs/source/api/quad/solvers.rst @@ -0,0 +1,21 @@ +probnum.quad.solvers +==================== + +.. automodapi:: probnum.quad.solvers + :no-heading: + :headings: "-" + +.. toctree:: + :hidden: + + solvers.belief_updates + +.. toctree:: + :hidden: + + solvers.policies + +.. toctree:: + :hidden: + + solvers.stopping_criteria diff --git a/docs/source/api/quad/solvers.stopping_criteria.rst b/docs/source/api/quad/solvers.stopping_criteria.rst new file mode 100644 index 000000000..9ad138add --- /dev/null +++ b/docs/source/api/quad/solvers.stopping_criteria.rst @@ -0,0 +1,5 @@ +Stopping Criteria +----------------- +.. automodapi:: probnum.quad.solvers.stopping_criteria + :no-heading: + :headings: "*" diff --git a/docs/source/api/quad/typing.rst b/docs/source/api/quad/typing.rst new file mode 100644 index 000000000..7426458cc --- /dev/null +++ b/docs/source/api/quad/typing.rst @@ -0,0 +1,7 @@ +probnum.quad.typing +==================== + +.. automodapi:: probnum.quad.typing + :no-heading: + :headings: "-" + :include-all-objects: diff --git a/docs/source/conf.py b/docs/source/conf.py index 8d0dc2ed4..6086e44ea 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -56,7 +56,10 @@ autodoc_typehints_description_target = "all" autodoc_typehints_format = "short" autodoc_type_aliases = { - type_alias: f"typing.{type_alias}" for type_alias in probnum.typing.__all__ + **{type_alias: f"typing.{type_alias}" for type_alias in probnum.typing.__all__}, + **{ + type_alias: f"typing.{type_alias}" for type_alias in probnum.quad.typing.__all__ + }, } # Ensures type aliases are correctly displayed and linked in the documentation # Settings for napoleon diff --git a/src/probnum/quad/__init__.py b/src/probnum/quad/__init__.py index c404378d6..dc8d59280 100644 --- a/src/probnum/quad/__init__.py +++ b/src/probnum/quad/__init__.py @@ -6,50 +6,15 @@ choosing points to evaluate the integrand based on said model. """ -from probnum.quad.solvers.policies import Policy, RandomPolicy, VanDerCorputPolicy -from probnum.quad.solvers.stopping_criteria import ( - BQStoppingCriterion, - ImmediateStop, - IntegralVarianceTolerance, - MaxNevals, - RelativeMeanChange, -) - +from . import integration_measures, kernel_embeddings, solvers from ._bayesquad import bayesquad, bayesquad_from_data -from ._integration_measures import GaussianMeasure, IntegrationMeasure, LebesgueMeasure -from .kernel_embeddings import KernelEmbedding -from .solvers import ( - BayesianQuadrature, - BQBeliefUpdate, - BQIterInfo, - BQStandardBeliefUpdate, - BQState, -) # Public classes and functions. Order is reflected in documentation. __all__ = [ "bayesquad", "bayesquad_from_data", - "BayesianQuadrature", - "ImmediateStop", - "IntegrationMeasure", - "ImmediateStop", - "KernelEmbedding", - "GaussianMeasure", - "LebesgueMeasure", - "BQStoppingCriterion", - "IntegralVarianceTolerance", - "MaxNevals", - "RandomPolicy", - "VanDerCorputPolicy", - "RelativeMeanChange", ] # Set correct module paths. Corrects links and module paths in documentation. -BayesianQuadrature.__module__ = "probnum.quad" -BQStoppingCriterion.__module__ = "probnum.quad" -ImmediateStop.__module__ = "probnum.quad" -IntegrationMeasure.__module__ = "probnum.quad" -KernelEmbedding.__module__ = "probnum.quad" -GaussianMeasure.__module__ = "probnum.quad" -LebesgueMeasure.__module__ = "probnum.quad" +bayesquad.__module__ = "probnum.quad" +bayesquad_from_data.__module__ = "probnum.quad" diff --git a/src/probnum/quad/_bayesquad.py b/src/probnum/quad/_bayesquad.py index 09c05f327..ffb185dea 100644 --- a/src/probnum/quad/_bayesquad.py +++ b/src/probnum/quad/_bayesquad.py @@ -13,15 +13,13 @@ import numpy as np -from probnum.quad.solvers.bq_state import BQIterInfo +from probnum.quad.integration_measures import IntegrationMeasure, LebesgueMeasure +from probnum.quad.solvers import BayesianQuadrature, BQIterInfo +from probnum.quad.typing import DomainLike, DomainType from probnum.randprocs.kernels import Kernel from probnum.randvars import Normal from probnum.typing import FloatLike, IntLike -from ._integration_measures import IntegrationMeasure, LebesgueMeasure -from ._quad_typing import DomainLike, DomainType -from .solvers import BayesianQuadrature - def bayesquad( fun: Callable, diff --git a/src/probnum/quad/_utils.py b/src/probnum/quad/_utils.py index cd3754d20..4e7c5a1a5 100644 --- a/src/probnum/quad/_utils.py +++ b/src/probnum/quad/_utils.py @@ -1,12 +1,14 @@ """Helper functions for the quad package""" +from __future__ import annotations + from typing import Optional, Tuple import numpy as np from probnum.typing import IntLike -from ._quad_typing import DomainLike, DomainType +from .typing import DomainLike, DomainType def as_domain( diff --git a/src/probnum/quad/integration_measures/__init__.py b/src/probnum/quad/integration_measures/__init__.py new file mode 100644 index 000000000..a23a65c93 --- /dev/null +++ b/src/probnum/quad/integration_measures/__init__.py @@ -0,0 +1,15 @@ +"""Integration measures for Bayesian quadrature methods.""" + +from ._integration_measures import GaussianMeasure, IntegrationMeasure, LebesgueMeasure + +# Public classes and functions. Order is reflected in documentation. +__all__ = [ + "IntegrationMeasure", + "GaussianMeasure", + "LebesgueMeasure", +] + +# Set correct module paths. Corrects links and module paths in documentation. +IntegrationMeasure.__module__ = "probnum.quad.integration_measures" +GaussianMeasure.__module__ = "probnum.quad.integration_measures" +LebesgueMeasure.__module__ = "probnum.quad.integration_measures" diff --git a/src/probnum/quad/_integration_measures.py b/src/probnum/quad/integration_measures/_integration_measures.py similarity index 97% rename from src/probnum/quad/_integration_measures.py rename to src/probnum/quad/integration_measures/_integration_measures.py index 350f02385..6fb210c2b 100644 --- a/src/probnum/quad/_integration_measures.py +++ b/src/probnum/quad/integration_measures/_integration_measures.py @@ -7,12 +7,11 @@ import numpy as np import scipy.stats +from probnum.quad._utils import as_domain +from probnum.quad.typing import DomainLike from probnum.randvars import Normal from probnum.typing import FloatLike, IntLike -from ._quad_typing import DomainLike -from ._utils import as_domain - class IntegrationMeasure(abc.ABC): """An abstract class for a measure against which a target function is integrated. @@ -90,7 +89,7 @@ class LebesgueMeasure(IntegrationMeasure): input_dim Dimension of the integration domain. If not given, inferred from ``domain``. normalized - Boolean which controls whether or not the measure is normalized (i.e., + Boolean which controls whether the measure is normalized (i.e., integral over the domain is one). Defaults to ``False``. """ diff --git a/src/probnum/quad/kernel_embeddings/__init__.py b/src/probnum/quad/kernel_embeddings/__init__.py index aaed94244..e7564d650 100644 --- a/src/probnum/quad/kernel_embeddings/__init__.py +++ b/src/probnum/quad/kernel_embeddings/__init__.py @@ -1 +1,11 @@ +"""Kernel embeddings for Bayesian quadrature methods.""" + from ._kernel_embedding import KernelEmbedding + +# Public classes and functions. Order is reflected in documentation. +__all__ = [ + "KernelEmbedding", +] + +# Set correct module paths. Corrects links and module paths in documentation. +KernelEmbedding.__module__ = "probnum.quad.kernel_embeddings" diff --git a/src/probnum/quad/kernel_embeddings/_expquad_gauss.py b/src/probnum/quad/kernel_embeddings/_expquad_gauss.py index ac1088150..3e75ed6fb 100644 --- a/src/probnum/quad/kernel_embeddings/_expquad_gauss.py +++ b/src/probnum/quad/kernel_embeddings/_expquad_gauss.py @@ -4,7 +4,7 @@ import numpy as np import scipy.linalg as slinalg -from probnum.quad._integration_measures import GaussianMeasure +from probnum.quad.integration_measures import GaussianMeasure from probnum.randprocs.kernels import ExpQuad diff --git a/src/probnum/quad/kernel_embeddings/_expquad_lebesgue.py b/src/probnum/quad/kernel_embeddings/_expquad_lebesgue.py index a46b8e4fa..0f6392154 100644 --- a/src/probnum/quad/kernel_embeddings/_expquad_lebesgue.py +++ b/src/probnum/quad/kernel_embeddings/_expquad_lebesgue.py @@ -6,7 +6,7 @@ import numpy as np from scipy.special import erf -from probnum.quad._integration_measures import LebesgueMeasure +from probnum.quad.integration_measures import LebesgueMeasure from probnum.randprocs.kernels import ExpQuad diff --git a/src/probnum/quad/kernel_embeddings/_kernel_embedding.py b/src/probnum/quad/kernel_embeddings/_kernel_embedding.py index 85a57429e..be0f742ac 100644 --- a/src/probnum/quad/kernel_embeddings/_kernel_embedding.py +++ b/src/probnum/quad/kernel_embeddings/_kernel_embedding.py @@ -4,7 +4,7 @@ import numpy as np -from probnum.quad._integration_measures import ( +from probnum.quad.integration_measures import ( GaussianMeasure, IntegrationMeasure, LebesgueMeasure, diff --git a/src/probnum/quad/kernel_embeddings/_matern_lebesgue.py b/src/probnum/quad/kernel_embeddings/_matern_lebesgue.py index 727eadc55..74558c92e 100644 --- a/src/probnum/quad/kernel_embeddings/_matern_lebesgue.py +++ b/src/probnum/quad/kernel_embeddings/_matern_lebesgue.py @@ -5,7 +5,7 @@ import numpy as np -from probnum.quad._integration_measures import LebesgueMeasure +from probnum.quad.integration_measures import LebesgueMeasure from probnum.randprocs.kernels import Matern, ProductMatern diff --git a/src/probnum/quad/solvers/__init__.py b/src/probnum/quad/solvers/__init__.py index b3ae7d3ae..35e884f4e 100644 --- a/src/probnum/quad/solvers/__init__.py +++ b/src/probnum/quad/solvers/__init__.py @@ -1,3 +1,17 @@ -from .bayesian_quadrature import BayesianQuadrature -from .belief_updates import BQBeliefUpdate, BQStandardBeliefUpdate -from .bq_state import BQIterInfo, BQState +"""Bayesian quadrature methods and their components.""" + +from . import belief_updates, policies, stopping_criteria +from ._bayesian_quadrature import BayesianQuadrature +from ._bq_state import BQIterInfo, BQState + +# Public classes and functions. Order is reflected in documentation. +__all__ = [ + "BayesianQuadrature", + "BQState", + "BQIterInfo", +] + +# Set correct module paths. Corrects links and module paths in documentation. +BayesianQuadrature.__module__ = "probnum.quad.solvers" +BQState.__module__ = "probnum.quad.solvers" +BQIterInfo.__module__ = "probnum.quad.solvers" diff --git a/src/probnum/quad/solvers/bayesian_quadrature.py b/src/probnum/quad/solvers/_bayesian_quadrature.py similarity index 97% rename from src/probnum/quad/solvers/bayesian_quadrature.py rename to src/probnum/quad/solvers/_bayesian_quadrature.py index 36f13de6f..6b037ecba 100644 --- a/src/probnum/quad/solvers/bayesian_quadrature.py +++ b/src/probnum/quad/solvers/_bayesian_quadrature.py @@ -1,10 +1,16 @@ """Probabilistic numerical methods for solving integrals.""" +from __future__ import annotations + from typing import Callable, Optional, Tuple import warnings import numpy as np +from probnum.quad.integration_measures import IntegrationMeasure, LebesgueMeasure +from probnum.quad.kernel_embeddings import KernelEmbedding +from probnum.quad.solvers._bq_state import BQIterInfo, BQState +from probnum.quad.solvers.belief_updates import BQBeliefUpdate, BQStandardBeliefUpdate from probnum.quad.solvers.policies import Policy, RandomPolicy, VanDerCorputPolicy from probnum.quad.solvers.stopping_criteria import ( BQStoppingCriterion, @@ -13,16 +19,11 @@ MaxNevals, RelativeMeanChange, ) +from probnum.quad.typing import DomainLike from probnum.randprocs.kernels import ExpQuad, Kernel from probnum.randvars import Normal from probnum.typing import FloatLike, IntLike -from .._integration_measures import IntegrationMeasure, LebesgueMeasure -from .._quad_typing import DomainLike -from ..kernel_embeddings import KernelEmbedding -from .belief_updates import BQBeliefUpdate, BQStandardBeliefUpdate -from .bq_state import BQIterInfo, BQState - # pylint: disable=too-many-branches, too-complex diff --git a/src/probnum/quad/solvers/bq_state.py b/src/probnum/quad/solvers/_bq_state.py similarity index 98% rename from src/probnum/quad/solvers/bq_state.py rename to src/probnum/quad/solvers/_bq_state.py index 854e90178..337d5c571 100644 --- a/src/probnum/quad/solvers/bq_state.py +++ b/src/probnum/quad/solvers/_bq_state.py @@ -1,11 +1,13 @@ """State of a Bayesian quadrature method.""" +from __future__ import annotations + from dataclasses import dataclass from typing import Optional, Tuple import numpy as np -from probnum.quad._integration_measures import IntegrationMeasure +from probnum.quad.integration_measures import IntegrationMeasure from probnum.quad.kernel_embeddings import KernelEmbedding from probnum.randprocs.kernels import Kernel from probnum.randvars import Normal diff --git a/src/probnum/quad/solvers/belief_updates/__init__.py b/src/probnum/quad/solvers/belief_updates/__init__.py index adfbbbe29..422a5f8a5 100644 --- a/src/probnum/quad/solvers/belief_updates/__init__.py +++ b/src/probnum/quad/solvers/belief_updates/__init__.py @@ -1 +1,13 @@ +"""Belief updates for Bayesian quadrature.""" + from ._belief_update import BQBeliefUpdate, BQStandardBeliefUpdate + +# Public classes and functions. Order is reflected in documentation. +__all__ = [ + "BQBeliefUpdate", + "BQStandardBeliefUpdate", +] + +# Set correct module paths. Corrects links and module paths in documentation. +BQBeliefUpdate.__module__ = "probnum.quad.solvers.belief_updates" +BQStandardBeliefUpdate.__module__ = "probnum.quad.solvers.belief_updates" diff --git a/src/probnum/quad/solvers/belief_updates/_belief_update.py b/src/probnum/quad/solvers/belief_updates/_belief_update.py index 84a1d06f6..a04bb5356 100644 --- a/src/probnum/quad/solvers/belief_updates/_belief_update.py +++ b/src/probnum/quad/solvers/belief_updates/_belief_update.py @@ -1,5 +1,7 @@ """Belief updates for Bayesian quadrature.""" +from __future__ import annotations + import abc from typing import Optional, Tuple @@ -7,7 +9,7 @@ from scipy.linalg import cho_factor, cho_solve from probnum.quad.kernel_embeddings import KernelEmbedding -from probnum.quad.solvers.bq_state import BQState +from probnum.quad.solvers._bq_state import BQState from probnum.randprocs.kernels import Kernel from probnum.randvars import Normal from probnum.typing import FloatLike diff --git a/src/probnum/quad/solvers/policies/__init__.py b/src/probnum/quad/solvers/policies/__init__.py index 61ff04af1..e53c1b18c 100644 --- a/src/probnum/quad/solvers/policies/__init__.py +++ b/src/probnum/quad/solvers/policies/__init__.py @@ -3,3 +3,15 @@ from ._policy import Policy from ._random_policy import RandomPolicy from ._van_der_corput_policy import VanDerCorputPolicy + +# Public classes and functions. Order is reflected in documentation. +__all__ = [ + "Policy", + "RandomPolicy", + "VanDerCorputPolicy", +] + +# Set correct module paths. Corrects links and module paths in documentation. +Policy.__module__ = "probnum.quad.solvers.policies" +RandomPolicy.__module__ = "probnum.quad.solvers.policies" +VanDerCorputPolicy.__module__ = "probnum.quad.solvers.policies" diff --git a/src/probnum/quad/solvers/policies/_policy.py b/src/probnum/quad/solvers/policies/_policy.py index 000f87713..f57d3438e 100644 --- a/src/probnum/quad/solvers/policies/_policy.py +++ b/src/probnum/quad/solvers/policies/_policy.py @@ -4,7 +4,7 @@ import numpy as np -from probnum.quad.solvers.bq_state import BQState +from probnum.quad.solvers._bq_state import BQState # pylint: disable=too-few-public-methods, fixme diff --git a/src/probnum/quad/solvers/policies/_random_policy.py b/src/probnum/quad/solvers/policies/_random_policy.py index 58a35decf..6110ca465 100644 --- a/src/probnum/quad/solvers/policies/_random_policy.py +++ b/src/probnum/quad/solvers/policies/_random_policy.py @@ -4,7 +4,7 @@ import numpy as np -from probnum.quad.solvers.bq_state import BQState +from probnum.quad.solvers._bq_state import BQState from ._policy import Policy diff --git a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py index 9b72c007e..1e7edfe85 100644 --- a/src/probnum/quad/solvers/policies/_van_der_corput_policy.py +++ b/src/probnum/quad/solvers/policies/_van_der_corput_policy.py @@ -4,8 +4,8 @@ import numpy as np -from probnum.quad._integration_measures import IntegrationMeasure -from probnum.quad.solvers.bq_state import BQState +from probnum.quad.integration_measures import IntegrationMeasure +from probnum.quad.solvers._bq_state import BQState from ._policy import Policy diff --git a/src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py b/src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py index 3c18b9bd4..04e7ea595 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py +++ b/src/probnum/quad/solvers/stopping_criteria/_bq_stopping_criterion.py @@ -1,7 +1,7 @@ """Base class for Bayesian quadrature stopping criteria.""" from probnum import StoppingCriterion -from probnum.quad.solvers.bq_state import BQIterInfo, BQState +from probnum.quad.solvers._bq_state import BQIterInfo, BQState # pylint: disable=too-few-public-methods,arguments-differ fixme diff --git a/src/probnum/quad/solvers/stopping_criteria/_immediate_stop.py b/src/probnum/quad/solvers/stopping_criteria/_immediate_stop.py index e22fc91d2..846f287ca 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_immediate_stop.py +++ b/src/probnum/quad/solvers/stopping_criteria/_immediate_stop.py @@ -1,6 +1,6 @@ """Stopping criterion that stops immediately.""" -from probnum.quad.solvers.bq_state import BQIterInfo, BQState +from probnum.quad.solvers._bq_state import BQIterInfo, BQState from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion # pylint: disable=too-few-public-methods diff --git a/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py b/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py index 5c2dc1bc5..dbcfbbf52 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py +++ b/src/probnum/quad/solvers/stopping_criteria/_integral_variance_tol.py @@ -1,6 +1,8 @@ """Stopping criterion based on the absolute value of the integral variance""" -from probnum.quad.solvers.bq_state import BQIterInfo, BQState +from __future__ import annotations + +from probnum.quad.solvers._bq_state import BQIterInfo, BQState from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion from probnum.typing import FloatLike diff --git a/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py b/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py index 00c5ff53e..bf87dd252 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py +++ b/src/probnum/quad/solvers/stopping_criteria/_max_nevals.py @@ -1,6 +1,8 @@ """Stopping criterion based on a maximum number of integrand evaluations.""" -from probnum.quad.solvers.bq_state import BQIterInfo, BQState +from __future__ import annotations + +from probnum.quad.solvers._bq_state import BQIterInfo, BQState from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion from probnum.typing import IntLike diff --git a/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py b/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py index ecb097445..6e9144809 100644 --- a/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py +++ b/src/probnum/quad/solvers/stopping_criteria/_rel_mean_change.py @@ -1,9 +1,11 @@ """Stopping criterion based on the relative change of the successive integral estimators.""" +from __future__ import annotations + import numpy as np -from probnum.quad.solvers.bq_state import BQIterInfo, BQState +from probnum.quad.solvers._bq_state import BQIterInfo, BQState from probnum.quad.solvers.stopping_criteria import BQStoppingCriterion from probnum.typing import FloatLike diff --git a/src/probnum/quad/_quad_typing.py b/src/probnum/quad/typing.py similarity index 56% rename from src/probnum/quad/_quad_typing.py rename to src/probnum/quad/typing.py index db2c0a3f1..f8009394a 100644 --- a/src/probnum/quad/_quad_typing.py +++ b/src/probnum/quad/typing.py @@ -1,10 +1,17 @@ """Types specific to the quad package.""" +from __future__ import annotations + from typing import Tuple, Union import numpy as np from probnum.typing import FloatLike +__all__ = ["DomainLike", "DomainType"] + DomainType = Tuple[np.ndarray, np.ndarray] +"""Type defining an integration domain.""" + DomainLike = Union[Tuple[FloatLike, FloatLike], DomainType] +"""Object that can be converted to an integration domain.""" diff --git a/tests/test_quad/conftest.py b/tests/test_quad/conftest.py index 48cc4c508..961002356 100644 --- a/tests/test_quad/conftest.py +++ b/tests/test_quad/conftest.py @@ -5,8 +5,12 @@ import numpy as np import pytest -import probnum.quad._integration_measures as measures -from probnum.quad.kernel_embeddings._kernel_embedding import KernelEmbedding +from probnum.quad.integration_measures import ( + GaussianMeasure, + IntegrationMeasure, + LebesgueMeasure, +) +from probnum.quad.kernel_embeddings import KernelEmbedding from probnum.randprocs import kernels # pylint: disable=unnecessary-lambda @@ -107,14 +111,14 @@ def fixture_measure_params( @pytest.fixture(name="measure") -def fixture_measure(measure_params) -> measures.IntegrationMeasure: +def fixture_measure(measure_params) -> IntegrationMeasure: """Measure.""" name = measure_params.pop("name") if name == "gauss": - return measures.GaussianMeasure(**measure_params) + return GaussianMeasure(**measure_params) elif name == "lebesgue": - return measures.LebesgueMeasure(**measure_params) + return LebesgueMeasure(**measure_params) raise NotImplementedError @@ -136,7 +140,7 @@ def fixture_kernel(request, input_dim: int) -> kernels.Kernel: # Kernel Embeddings @pytest.fixture(name="kernel_embedding") def fixture_kernel_embedding( - request, kernel: kernels.Kernel, measure: measures.IntegrationMeasure + request, kernel: kernels.Kernel, measure: IntegrationMeasure ) -> KernelEmbedding: """Set up kernel embedding.""" return KernelEmbedding(kernel, measure) diff --git a/tests/test_quad/test_bayesian_quadrature.py b/tests/test_quad/test_bayesian_quadrature.py index 60b7373f1..9af6fd4a9 100644 --- a/tests/test_quad/test_bayesian_quadrature.py +++ b/tests/test_quad/test_bayesian_quadrature.py @@ -4,13 +4,10 @@ import pytest from probnum import LambdaStoppingCriterion -from probnum.quad import ( - BayesianQuadrature, - ImmediateStop, - LebesgueMeasure, - RandomPolicy, - VanDerCorputPolicy, -) +from probnum.quad.integration_measures import LebesgueMeasure +from probnum.quad.solvers import BayesianQuadrature +from probnum.quad.solvers.policies import RandomPolicy, VanDerCorputPolicy +from probnum.quad.solvers.stopping_criteria import ImmediateStop from probnum.randprocs.kernels import ExpQuad diff --git a/tests/test_quad/test_bayesquad/test_bq.py b/tests/test_quad/test_bayesquad/test_bq.py index 1ce644c73..704c62976 100644 --- a/tests/test_quad/test_bayesquad/test_bq.py +++ b/tests/test_quad/test_bayesquad/test_bq.py @@ -4,8 +4,8 @@ import pytest from scipy.integrate import quad as scipyquad -import probnum.quad from probnum.quad import bayesquad, bayesquad_from_data +from probnum.quad.integration_measures import LebesgueMeasure from probnum.quad.kernel_embeddings import KernelEmbedding from probnum.randvars import Normal @@ -52,7 +52,7 @@ def test_integral_values_1d( domains. """ - measure = probnum.quad.LebesgueMeasure(input_dim=input_dim, domain=domain) + measure = LebesgueMeasure(input_dim=input_dim, domain=domain) # numerical integral # pylint: disable=invalid-name def integrand(x): diff --git a/tests/test_quad/test_bq_state.py b/tests/test_quad/test_bq_state.py index 9294f8844..85123f06c 100644 --- a/tests/test_quad/test_bq_state.py +++ b/tests/test_quad/test_bq_state.py @@ -3,8 +3,9 @@ import numpy as np import pytest -from probnum.quad import IntegrationMeasure, KernelEmbedding, LebesgueMeasure -from probnum.quad.solvers.bq_state import BQIterInfo, BQState +from probnum.quad.integration_measures import IntegrationMeasure, LebesgueMeasure +from probnum.quad.kernel_embeddings import KernelEmbedding +from probnum.quad.solvers import BQIterInfo, BQState from probnum.randprocs.kernels import ExpQuad, Kernel from probnum.randvars import Normal diff --git a/tests/test_quad/test_integration_measure.py b/tests/test_quad/test_integration_measure.py index e8ece4f12..7596a2167 100644 --- a/tests/test_quad/test_integration_measure.py +++ b/tests/test_quad/test_integration_measure.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from probnum import quad +from probnum.quad.integration_measures import GaussianMeasure, LebesgueMeasure # Tests for Gaussian measure @@ -11,7 +11,7 @@ def test_gaussian_diagonal_covariance(input_dim: int): """Check that diagonal covariance matrices are recognised as diagonal.""" mean = np.full((input_dim,), 0.0) cov = np.eye(input_dim) - measure = quad.GaussianMeasure(mean, cov) + measure = GaussianMeasure(mean, cov) assert measure.diagonal_covariance @@ -21,7 +21,7 @@ def test_gaussian_non_diagonal_covariance(input_dim_non_diagonal): mean = np.full((input_dim_non_diagonal,), 0.0) cov = np.eye(input_dim_non_diagonal) cov[0, 1] = 1.5 - measure = quad.GaussianMeasure(mean, cov) + measure = GaussianMeasure(mean, cov) assert not measure.diagonal_covariance @@ -30,7 +30,7 @@ def test_gaussian_non_diagonal_covariance(input_dim_non_diagonal): def test_gaussian_mean_shape_1d(mean, cov): """Test that different types of one-dimensional means and covariances yield one- dimensional Gaussian measures when no dimension is given.""" - measure = quad.GaussianMeasure(mean=mean, cov=cov) + measure = GaussianMeasure(mean=mean, cov=cov) assert measure.input_dim == 1 assert measure.mean.size == 1 assert measure.cov.size == 1 @@ -40,13 +40,13 @@ def test_gaussian_mean_shape_1d(mean, cov): def test_gaussian_negative_dimension(neg_dim): """Make sure that a negative dimension raises ValueError.""" with pytest.raises(ValueError): - quad.GaussianMeasure(0, 1, neg_dim) + GaussianMeasure(0, 1, neg_dim) def test_gaussian_param_assignment(input_dim: int): """Check that diagonal mean and covariance for higher dimensions are extended correctly.""" - measure = quad.GaussianMeasure(0, 1, input_dim) + measure = GaussianMeasure(0, 1, input_dim) if input_dim == 1: assert measure.mean == 0.0 assert measure.cov == 1.0 @@ -59,7 +59,7 @@ def test_gaussian_param_assignment(input_dim: int): def test_gaussian_scalar(): """Check that the 1d Gaussian case works.""" - measure = quad.GaussianMeasure(0.5, 1.5) + measure = GaussianMeasure(0.5, 1.5) assert measure.mean == 0.5 assert measure.cov == 1.5 @@ -68,15 +68,15 @@ def test_gaussian_scalar(): def test_lebesgue_dim_correct(input_dim: int): """Check that dimensions are handled correctly.""" domain1 = (0.0, 1.87) - measure11 = quad.LebesgueMeasure(domain=domain1) - measure12 = quad.LebesgueMeasure(input_dim=input_dim, domain=domain1) + measure11 = LebesgueMeasure(domain=domain1) + measure12 = LebesgueMeasure(input_dim=input_dim, domain=domain1) assert measure11.input_dim == 1 assert measure12.input_dim == input_dim domain2 = (np.full((input_dim,), -0.1), np.full((input_dim,), 0.0)) - measure21 = quad.LebesgueMeasure(domain=domain2) - measure22 = quad.LebesgueMeasure(input_dim=input_dim, domain=domain2) + measure21 = LebesgueMeasure(domain=domain2) + measure22 = LebesgueMeasure(input_dim=input_dim, domain=domain2) assert measure21.input_dim == input_dim assert measure22.input_dim == input_dim @@ -89,13 +89,13 @@ def test_lebesgue_dim_incorrect(domain_a, domain_b, input_dim): """Check that ValueError is raised if domain limits have mismatching dimensions or dimension is not positive.""" with pytest.raises(ValueError): - quad.LebesgueMeasure(domain=(domain_a, domain_b), input_dim=input_dim) + LebesgueMeasure(domain=(domain_a, domain_b), input_dim=input_dim) def test_lebesgue_normalization(input_dim: int): """Check that normalization constants are handled properly when not equal to one.""" domain = (0, 2) - measure = quad.LebesgueMeasure(domain=domain, input_dim=input_dim, normalized=True) + measure = LebesgueMeasure(domain=domain, input_dim=input_dim, normalized=True) volume = 2**input_dim assert measure.normalization_constant == 1 / volume @@ -105,15 +105,13 @@ def test_lebesgue_normalization(input_dim: int): def test_lebesgue_normalization_raises(domain, input_dim: int): """Check that exception is raised when normalization is not possible.""" with pytest.raises(ValueError): - quad.LebesgueMeasure(domain=domain, input_dim=input_dim, normalized=True) + LebesgueMeasure(domain=domain, input_dim=input_dim, normalized=True) def test_lebesgue_unnormalized(input_dim: int): """Check that normalization constants are handled properly when equal to one.""" - measure1 = quad.LebesgueMeasure(domain=(0, 1), input_dim=input_dim, normalized=True) - measure2 = quad.LebesgueMeasure( - domain=(0, 1), input_dim=input_dim, normalized=False - ) + measure1 = LebesgueMeasure(domain=(0, 1), input_dim=input_dim, normalized=True) + measure2 = LebesgueMeasure(domain=(0, 1), input_dim=input_dim, normalized=False) assert measure1.normalization_constant == measure2.normalization_constant diff --git a/tests/test_quad/test_kernel_embeddings.py b/tests/test_quad/test_kernel_embeddings.py index ca5ce120b..e6b4fe137 100644 --- a/tests/test_quad/test_kernel_embeddings.py +++ b/tests/test_quad/test_kernel_embeddings.py @@ -4,7 +4,7 @@ import pytest from scipy.integrate import quad -from probnum.quad import KernelEmbedding +from probnum.quad.kernel_embeddings import KernelEmbedding from .util import gauss_hermite_tensor, gauss_legendre_tensor diff --git a/tests/test_quad/test_policy.py b/tests/test_quad/test_policy.py index c4e9b0181..fa848da73 100644 --- a/tests/test_quad/test_policy.py +++ b/tests/test_quad/test_policy.py @@ -3,7 +3,8 @@ import numpy as np import pytest -from probnum.quad import GaussianMeasure, LebesgueMeasure, VanDerCorputPolicy +from probnum.quad.integration_measures import GaussianMeasure, LebesgueMeasure +from probnum.quad.solvers.policies import VanDerCorputPolicy def test_van_der_corput_multi_d_error(): diff --git a/tests/test_quad/test_stopping_criterion.py b/tests/test_quad/test_stopping_criterion.py index a6a78ba90..ae91acd97 100644 --- a/tests/test_quad/test_stopping_criterion.py +++ b/tests/test_quad/test_stopping_criterion.py @@ -5,15 +5,15 @@ import numpy as np import pytest -from probnum.quad import ( +from probnum.quad.integration_measures import LebesgueMeasure +from probnum.quad.solvers import BQIterInfo, BQState +from probnum.quad.solvers.stopping_criteria import ( BQStoppingCriterion, ImmediateStop, IntegralVarianceTolerance, - LebesgueMeasure, MaxNevals, RelativeMeanChange, ) -from probnum.quad.solvers.bq_state import BQIterInfo, BQState from probnum.randprocs.kernels import ExpQuad from probnum.randvars import Normal