From 3662bf5de6d6d639c0fd5f1d1a9edaff51539bd5 Mon Sep 17 00:00:00 2001 From: Jonathan Wenger Date: Thu, 4 Nov 2021 14:52:27 +0100 Subject: [PATCH] [PLS Refactor] Belief Updates (#546) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * belief update and some naming inconsistencies * linear system belief update skeleton * added interface for belief updates and changed problinsolve to solve * general test for linear system belief updates added * belief update classes for matrix and solution based view * added belief update interfaces for solution-based and matrix-based views * solution based belief update * solution-based docstring added * solutionbased tests pass * pylint fixes * matrix-based belief update * added some more tests for (a)symmetric matrix-based inference * symmetric matrix-based inference update * test matrix based against naive implementation * test naive implementation * nullify belief over solution * improved documentation for the solution based perspective * better docstrings for matrix-based updates * minor doc fix * removed print statement * Update src/probnum/linalg/solvers/belief_updates/_symmetric_matrix_based_linear_belief_update.py Co-authored-by: Marvin Pförtner * Update src/probnum/linalg/solvers/belief_updates/_matrix_based_linear_belief_update.py Co-authored-by: Marvin Pförtner * naming improvements * more docstring improvements * changed package structure separate matrix- and solutionbased belief updates * missing tests * Update docs/source/api/linalg/solvers.belief_updates.matrix_based.rst Co-authored-by: Marvin Pförtner * Update docs/source/api/linalg/solvers.belief_updates.solution_based.rst Co-authored-by: Marvin Pförtner * docfix Co-authored-by: Marvin Pförtner --- .../solvers.belief_updates.matrix_based.rst | 5 + .../api/linalg/solvers.belief_updates.rst | 16 +++ .../solvers.belief_updates.solution_based.rst | 5 + docs/source/api/linalg/solvers.rst | 5 + src/probnum/linalg/solvers/__init__.py | 6 +- src/probnum/linalg/solvers/_state.py | 12 +- .../linalg/solvers/belief_updates/__init__.py | 12 ++ .../_linear_solver_belief_update.py | 29 +++++ .../belief_updates/matrix_based/__init__.py | 21 ++++ .../_matrix_based_linear_belief_update.py | 76 ++++++++++++ ...etric_matrix_based_linear_belief_update.py | 77 ++++++++++++ .../belief_updates/solution_based/__init__.py | 15 +++ .../_solution_based_proj_rhs_belief_update.py | 62 ++++++++++ .../solvers/information_ops/__init__.py | 20 +-- ...op.py => _linear_solver_information_op.py} | 8 +- .../linalg/solvers/information_ops/_matvec.py | 6 +- .../solvers/information_ops/_proj_residual.py | 27 ---- .../solvers/information_ops/_projected_rhs.py | 25 ++++ .../solvers/policies/_conjugate_gradient.py | 2 +- .../solvers/policies/_linear_solver_policy.py | 2 +- .../solvers/policies/_random_unit_vector.py | 2 +- .../solvers/stopping_criteria/__init__.py | 26 ++-- .../_linear_solver_stopping_criterion.py | 12 +- .../solvers/stopping_criteria/_maxiter.py | 6 +- .../_posterior_contraction.py | 6 +- .../stopping_criteria/_residual_norm.py | 6 +- .../test_solvers/cases/belief_updates.py | 18 +++ .../test_solvers/cases/information_ops.py | 8 +- .../test_linalg/test_solvers/cases/states.py | 70 +++++++++-- .../test_solvers/cases/stopping_criteria.py | 6 +- .../test_belief_updates/__init__.py | 0 .../test_linear_system_belief_update.py | 7 ++ .../test_matrix_based/__init__.py | 0 .../test_matrix_based_linear_belief_update.py | 109 +++++++++++++++++ ...etric_matrix_based_linear_belief_update.py | 115 ++++++++++++++++++ .../test_solution_based/__init__.py | 0 ...t_solution_based_proj_rhs_belief_update.py | 103 ++++++++++++++++ .../test_linear_solver_info_op.py | 4 +- .../test_information_ops/test_matvec.py | 4 +- ...test_residual.py => test_projected_rhs.py} | 12 +- .../test_policies/test_conjugate_gradient.py | 6 +- .../test_linear_solver_policy.py | 12 +- .../test_policies/test_random_unit_vector.py | 4 +- tests/test_linalg/test_solvers/test_state.py | 8 +- .../test_linear_solver_stopping_criterion.py | 6 +- .../test_stopping_criteria/test_maxiter.py | 6 +- .../test_posterior_contraction.py | 6 +- .../test_residual_norm.py | 6 +- 48 files changed, 861 insertions(+), 138 deletions(-) create mode 100644 docs/source/api/linalg/solvers.belief_updates.matrix_based.rst create mode 100644 docs/source/api/linalg/solvers.belief_updates.rst create mode 100644 docs/source/api/linalg/solvers.belief_updates.solution_based.rst create mode 100644 src/probnum/linalg/solvers/belief_updates/__init__.py create mode 100644 src/probnum/linalg/solvers/belief_updates/_linear_solver_belief_update.py create mode 100644 src/probnum/linalg/solvers/belief_updates/matrix_based/__init__.py create mode 100644 src/probnum/linalg/solvers/belief_updates/matrix_based/_matrix_based_linear_belief_update.py create mode 100644 src/probnum/linalg/solvers/belief_updates/matrix_based/_symmetric_matrix_based_linear_belief_update.py create mode 100644 src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py create mode 100644 src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py rename src/probnum/linalg/solvers/information_ops/{_linear_solver_info_op.py => _linear_solver_information_op.py} (69%) delete mode 100644 src/probnum/linalg/solvers/information_ops/_proj_residual.py create mode 100644 src/probnum/linalg/solvers/information_ops/_projected_rhs.py create mode 100644 tests/test_linalg/test_solvers/cases/belief_updates.py create mode 100644 tests/test_linalg/test_solvers/test_belief_updates/__init__.py create mode 100644 tests/test_linalg/test_solvers/test_belief_updates/test_linear_system_belief_update.py create mode 100644 tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/__init__.py create mode 100644 tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_matrix_based_linear_belief_update.py create mode 100644 tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_symmetric_matrix_based_linear_belief_update.py create mode 100644 tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/__init__.py create mode 100644 tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_solution_based_proj_rhs_belief_update.py rename tests/test_linalg/test_solvers/test_information_ops/{test_residual.py => test_projected_rhs.py} (62%) diff --git a/docs/source/api/linalg/solvers.belief_updates.matrix_based.rst b/docs/source/api/linalg/solvers.belief_updates.matrix_based.rst new file mode 100644 index 000000000..e333c28c2 --- /dev/null +++ b/docs/source/api/linalg/solvers.belief_updates.matrix_based.rst @@ -0,0 +1,5 @@ +Belief Updates for Matrix-based Inference +----------------------------------------- +.. automodapi:: probnum.linalg.solvers.belief_updates.matrix_based + :no-heading: + :headings: "*" diff --git a/docs/source/api/linalg/solvers.belief_updates.rst b/docs/source/api/linalg/solvers.belief_updates.rst new file mode 100644 index 000000000..9fbad00c3 --- /dev/null +++ b/docs/source/api/linalg/solvers.belief_updates.rst @@ -0,0 +1,16 @@ +Belief Updates +-------------- +.. automodapi:: probnum.linalg.solvers.belief_updates + :no-heading: + :headings: "*" + + +.. toctree:: + :hidden: + + solvers.belief_updates.solution_based + +.. toctree:: + :hidden: + + solvers.belief_updates.matrix_based diff --git a/docs/source/api/linalg/solvers.belief_updates.solution_based.rst b/docs/source/api/linalg/solvers.belief_updates.solution_based.rst new file mode 100644 index 000000000..5a037564a --- /dev/null +++ b/docs/source/api/linalg/solvers.belief_updates.solution_based.rst @@ -0,0 +1,5 @@ +Belief Updates for Solution-based Inference +------------------------------------------- +.. automodapi:: probnum.linalg.solvers.belief_updates.solution_based + :no-heading: + :headings: "*" diff --git a/docs/source/api/linalg/solvers.rst b/docs/source/api/linalg/solvers.rst index 6ebe74056..5e478a11f 100644 --- a/docs/source/api/linalg/solvers.rst +++ b/docs/source/api/linalg/solvers.rst @@ -20,6 +20,11 @@ probnum.linalg.solvers solvers.information_ops +.. toctree:: + :hidden: + + solvers.belief_updates + .. toctree:: :hidden: diff --git a/src/probnum/linalg/solvers/__init__.py b/src/probnum/linalg/solvers/__init__.py index 4e1eeca10..aab40c938 100644 --- a/src/probnum/linalg/solvers/__init__.py +++ b/src/probnum/linalg/solvers/__init__.py @@ -10,16 +10,16 @@ SymmetricMatrixBasedSolver, ) -from . import beliefs, information_ops, policies +from . import belief_updates, beliefs, information_ops, policies, stopping_criteria from ._probabilistic_linear_solver import ProbabilisticLinearSolver -from ._state import ProbabilisticLinearSolverState +from ._state import LinearSolverState # Public classes and functions. Order is reflected in documentation. __all__ = [ "ProbabilisticLinearSolver", "MatrixBasedSolver", "SymmetricMatrixBasedSolver", - "ProbabilisticLinearSolverState", + "LinearSolverState", ] # Set correct module paths. Corrects links and module paths in documentation. diff --git a/src/probnum/linalg/solvers/_state.py b/src/probnum/linalg/solvers/_state.py index 7c65f95cd..a883991ea 100644 --- a/src/probnum/linalg/solvers/_state.py +++ b/src/probnum/linalg/solvers/_state.py @@ -10,7 +10,7 @@ @dataclasses.dataclass -class ProbabilisticLinearSolverState: +class LinearSolverState: """State of a probabilistic linear solver. The solver state separates the state of a probabilistic linear solver from the algorithm itself, making the solver stateless. The state contains the problem to be solved, the current belief over the quantities of interest and any miscellaneous quantities computed during an iteration of a probabilistic linear solver. The solver state is passed between the different components of the solver and may be used internally to cache quantities which are used more than once. @@ -51,8 +51,7 @@ def __repr__(self) -> str: def action(self) -> Optional[np.ndarray]: """Action of the solver for the current step. - Is ``None`` at the beginning of a step and will be set by the - policy. + Is ``None`` at the beginning of a step and will be set by the policy. """ return self._actions[self.step] @@ -65,8 +64,8 @@ def action(self, value: np.ndarray) -> None: def observation(self) -> Optional[Any]: """Observation of the solver for the current step. - Is ``None`` at the beginning of a step, will be set by the - observation model for a given action. + Is ``None`` at the beginning of a step, will be set by the observation model for + a given action. """ return self._observations[self.step] @@ -102,8 +101,7 @@ def residuals(self) -> Tuple[np.ndarray, ...]: def next_step(self) -> None: """Advance the solver state to the next solver step. - Called after a completed step / iteration of the probabilistic - linear solver. + Called after a completed step / iteration of the probabilistic linear solver. """ self._actions.append(None) self._observations.append(None) diff --git a/src/probnum/linalg/solvers/belief_updates/__init__.py b/src/probnum/linalg/solvers/belief_updates/__init__.py new file mode 100644 index 000000000..dba7df2c7 --- /dev/null +++ b/src/probnum/linalg/solvers/belief_updates/__init__.py @@ -0,0 +1,12 @@ +"""Belief updates for the quantities of interest of a linear system.""" + +from . import matrix_based, solution_based +from ._linear_solver_belief_update import LinearSolverBeliefUpdate + +# Public classes and functions. Order is reflected in documentation. +__all__ = [ + "LinearSolverBeliefUpdate", +] + +# Set correct module paths. Corrects links and module paths in documentation. +LinearSolverBeliefUpdate.__module__ = "probnum.linalg.solvers.belief_updates" diff --git a/src/probnum/linalg/solvers/belief_updates/_linear_solver_belief_update.py b/src/probnum/linalg/solvers/belief_updates/_linear_solver_belief_update.py new file mode 100644 index 000000000..47784e0e9 --- /dev/null +++ b/src/probnum/linalg/solvers/belief_updates/_linear_solver_belief_update.py @@ -0,0 +1,29 @@ +"""Linear system belief updates. + +Class defining how a belief over quantities of interest of a linear system is updated +given information about the problem. +""" +import abc + +import probnum # pylint: disable="unused-import" +from probnum.linalg.solvers.beliefs import LinearSystemBelief + + +class LinearSolverBeliefUpdate(abc.ABC): + r"""Belief update for the quantities of interest of a linear system. + + Given a solver state containing information about the linear system collected in the current step, update the belief about the quantities of interest. + """ + + @abc.abstractmethod + def __call__( + self, solver_state: "probnum.linalg.solvers.LinearSolverState" + ) -> LinearSystemBelief: + r"""Update the belief about the quantities of interest of a linear system. + + Parameters + ---------- + solver_state : + Current state of the linear solver. + """ + raise NotImplementedError diff --git a/src/probnum/linalg/solvers/belief_updates/matrix_based/__init__.py b/src/probnum/linalg/solvers/belief_updates/matrix_based/__init__.py new file mode 100644 index 000000000..6c500bf13 --- /dev/null +++ b/src/probnum/linalg/solvers/belief_updates/matrix_based/__init__.py @@ -0,0 +1,21 @@ +"""Matrix-based belief updates for the quantities of interest of a linear system.""" + +from ._matrix_based_linear_belief_update import MatrixBasedLinearBeliefUpdate +from ._symmetric_matrix_based_linear_belief_update import ( + SymmetricMatrixBasedLinearBeliefUpdate, +) + +# Public classes and functions. Order is reflected in documentation. +__all__ = [ + "MatrixBasedLinearBeliefUpdate", + "SymmetricMatrixBasedLinearBeliefUpdate", +] + +# Set correct module paths. Corrects links and module paths in documentation. +MatrixBasedLinearBeliefUpdate.__module__ = ( + "probnum.linalg.solvers.belief_updates.matrix_based" +) + +SymmetricMatrixBasedLinearBeliefUpdate.__module__ = ( + "probnum.linalg.solvers.belief_updates.matrix_based" +) diff --git a/src/probnum/linalg/solvers/belief_updates/matrix_based/_matrix_based_linear_belief_update.py b/src/probnum/linalg/solvers/belief_updates/matrix_based/_matrix_based_linear_belief_update.py new file mode 100644 index 000000000..9f1b07c50 --- /dev/null +++ b/src/probnum/linalg/solvers/belief_updates/matrix_based/_matrix_based_linear_belief_update.py @@ -0,0 +1,76 @@ +"""Belief update in a matrix-based inference view where the information is given by +matrix-vector multiplication.""" +import numpy as np + +import probnum # pylint: disable="unused-import" +from probnum import linops, randvars +from probnum.linalg.solvers.beliefs import LinearSystemBelief + +from .._linear_solver_belief_update import LinearSolverBeliefUpdate + + +class MatrixBasedLinearBeliefUpdate(LinearSolverBeliefUpdate): + r"""Gaussian belief update in a matrix-based inference framework assuming linear information. + + Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given matrix-variate Gaussian beliefs with Kronecker covariance structure and linear observations :math:`y=As`. The belief update computes :math:`p(M \mid y) = \mathcal{N}(M; M_{i+1}, V \otimes W_{i+1})`, [1]_ [2]_ such that + + .. math :: + \begin{align} + M_{i+1} &= M_i + (y - M_i s) (s^\top W_i s)^\dagger s^\top W_i,\\ + W_{i+1} &= W_i - W_i s (s^\top W_i s)^\dagger s^\top W_i. + \end{align} + + + References + ---------- + .. [1] Hennig, P., Probabilistic Interpretation of Linear Solvers, *SIAM Journal on + Optimization*, 2015, 25, 234-260 + .. [2] Wenger, J. and Hennig, P., Probabilistic Linear Solvers for Machine Learning, + *Advances in Neural Information Processing Systems (NeurIPS)*, 2020 + """ + + def __call__( + self, solver_state: "probnum.linalg.solvers.LinearSolverState" + ) -> LinearSystemBelief: + + # Inference for A + A = self._matrix_based_update( + matrix=solver_state.belief.A, + action=solver_state.action, + observ=solver_state.observation, + ) + + # Inference for Ainv (interpret action and observation as swapped) + Ainv = self._matrix_based_update( + matrix=solver_state.belief.Ainv, + action=solver_state.observation, + observ=solver_state.action, + ) + return LinearSystemBelief(A=A, Ainv=Ainv, x=None, b=solver_state.belief.b) + + def _matrix_based_update( + self, matrix: randvars.Normal, action: np.ndarray, observ: np.ndarray + ) -> randvars.Normal: + """Matrix-based inference update for linear information.""" + if not isinstance(matrix.cov, linops.Kronecker): + raise ValueError( + f"Covariance must have Kronecker structure, but is '{type(matrix.cov).__name__}'." + ) + + pred = matrix.mean @ action + resid = observ - pred + covfactor_Ms = matrix.cov.B @ action + gram = action.T @ covfactor_Ms + gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 + gain = covfactor_Ms * gram_pinv + covfactor_update = linops.aslinop(gain[:, None]) @ linops.aslinop( + covfactor_Ms[None, :] + ) + resid_gain = linops.aslinop(resid[:, None]) @ linops.aslinop( + gain[None, :] + ) # residual and gain are flipped due to matrix vectorization + + return randvars.Normal( + mean=matrix.mean + resid_gain, + cov=linops.Kronecker(A=matrix.cov.A, B=matrix.cov.B - covfactor_update), + ) diff --git a/src/probnum/linalg/solvers/belief_updates/matrix_based/_symmetric_matrix_based_linear_belief_update.py b/src/probnum/linalg/solvers/belief_updates/matrix_based/_symmetric_matrix_based_linear_belief_update.py new file mode 100644 index 000000000..649dc9891 --- /dev/null +++ b/src/probnum/linalg/solvers/belief_updates/matrix_based/_symmetric_matrix_based_linear_belief_update.py @@ -0,0 +1,77 @@ +"""Belief update in a matrix-based inference view assuming symmetry where the +information is given by matrix-vector multiplication.""" +import numpy as np + +import probnum # pylint: disable="unused-import" +from probnum import linops, randvars +from probnum.linalg.solvers.beliefs import LinearSystemBelief + +from .._linear_solver_belief_update import LinearSolverBeliefUpdate + + +class SymmetricMatrixBasedLinearBeliefUpdate(LinearSolverBeliefUpdate): + r"""Symmetric Gaussian belief update in a matrix-based inference framework assuming linear information. + + Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given symmetric matrix-variate Gaussian beliefs with symmetric Kronecker covariance structure and linear observations. The belief update computes :math:`p(M \mid y) = \mathcal{N}(M; M_{i+1}, W_{i+1} \otimes_s W_{i+1})`, [1]_ [2]_ such that + + .. math :: + \begin{align} + M_{i+1} &= M_i + (y - M_i s) u^\top + u (y - M_i s)^\top - u s^\top(y - M_i s)u^\top,\\ + W_{i+1} &= W_i - W_i s (s^\top W_i s)^\dagger s^\top W_i. + \end{align} + + where :math:`u = W_i s (s^\top W s)^\dagger`. + + References + ---------- + .. [1] Hennig, P., Probabilistic Interpretation of Linear Solvers, *SIAM Journal on + Optimization*, 2015, 25, 234-260 + .. [2] Wenger, J. and Hennig, P., Probabilistic Linear Solvers for Machine Learning, + *Advances in Neural Information Processing Systems (NeurIPS)*, 2020 + """ + + def __call__( + self, solver_state: "probnum.linalg.solvers.LinearSolverState" + ) -> LinearSystemBelief: + + # Inference for A + A = self._symmetric_matrix_based_update( + matrix=solver_state.belief.A, + action=solver_state.action, + observ=solver_state.observation, + ) + + # Inference for Ainv (interpret action and observation as swapped) + Ainv = self._symmetric_matrix_based_update( + matrix=solver_state.belief.Ainv, + action=solver_state.observation, + observ=solver_state.action, + ) + return LinearSystemBelief(A=A, Ainv=Ainv, x=None, b=solver_state.belief.b) + + def _symmetric_matrix_based_update( + self, matrix: randvars.Normal, action: np.ndarray, observ: np.ndarray + ) -> randvars.Normal: + """Symmetric matrix-based inference update for linear information.""" + if not isinstance(matrix.cov, linops.SymmetricKronecker): + raise ValueError( + f"Covariance must have symmetric Kronecker structure, but is '{type(matrix.cov).__name__}'." + ) + + pred = matrix.mean @ action + resid = observ - pred + covfactor_Ms = matrix.cov.A @ action + gram = action.T @ covfactor_Ms + gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 + gain = covfactor_Ms * gram_pinv + covfactor_update = gain @ covfactor_Ms.T + resid_gain = linops.aslinop(resid[:, None]) @ linops.aslinop(gain[None, :]) + + return randvars.Normal( + mean=matrix.mean + + resid_gain + + resid_gain.T + - linops.aslinop(gain[:, None]) + @ linops.aslinop((action.T @ resid_gain)[None, :]), + cov=linops.SymmetricKronecker(A=matrix.cov.A - covfactor_update), + ) diff --git a/src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py b/src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py new file mode 100644 index 000000000..fec2b1a72 --- /dev/null +++ b/src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py @@ -0,0 +1,15 @@ +"""Solution-based belief updates for the quantities of interest of a linear system.""" + +from ._solution_based_proj_rhs_belief_update import ( + SolutionBasedProjectedRHSBeliefUpdate, +) + +# Public classes and functions. Order is reflected in documentation. +__all__ = [ + "SolutionBasedProjectedRHSBeliefUpdate", +] + +# Set correct module paths. Corrects links and module paths in documentation. +SolutionBasedProjectedRHSBeliefUpdate.__module__ = ( + "probnum.linalg.solvers.belief_updates.solution_based" +) diff --git a/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py b/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py new file mode 100644 index 000000000..68607b90f --- /dev/null +++ b/src/probnum/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py @@ -0,0 +1,62 @@ +"""Belief update in a solution-based inference view where the information is given by +projecting the current residual to a subspace.""" +import probnum # pylint: disable="unused-import" +from probnum import randvars +from probnum.linalg.solvers.beliefs import LinearSystemBelief +from probnum.typing import FloatArgType + +from .._linear_solver_belief_update import LinearSolverBeliefUpdate + + +class SolutionBasedProjectedRHSBeliefUpdate(LinearSolverBeliefUpdate): + r"""Gaussian belief update in a solution-based inference framework assuming projected right-hand-side information. + + Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given a Gaussian belief over the solution :math:`x` and information of the form :math:`y = s\^top b=s^\top Ax`. The belief update computes the posterior belief about the solution, given by :math:`p(x \mid y) = \mathcal{N}(x; x_{i+1}, \Sigma_{i+1})`, [1]_ such that + + .. math :: + \begin{align} + x_{i+1} &= x_i + \Sigma_i A^\top s (s^\top A \Sigma_i A^\top s + \lambda)^\dagger s^\top (b - Ax_i),\\ + \Sigma_{i+1} &= \Sigma_i - \Sigma_i A^\top s (s^\top A \Sigma_i A s + \lambda)^\dagger s^\top A \Sigma_i, + \end{align} + + where :math:`\lambda` is the noise variance. + + + Parameters + ---------- + noise_var : + Variance of the scalar observation noise. + + References + ---------- + .. [1] Cockayne, J. et al., A Bayesian Conjugate Gradient Method, *Bayesian + Analysis*, 2019, 14, 937-1012 + """ + + def __init__(self, noise_var: FloatArgType = 0.0) -> None: + if noise_var < 0.0: + raise ValueError(f"Noise variance {noise_var} must be non-negative.") + self._noise_var = noise_var + + def __call__( + self, solver_state: "probnum.linalg.solvers.LinearSolverState" + ) -> LinearSystemBelief: + + action_A = solver_state.action @ solver_state.problem.A + pred = action_A @ solver_state.belief.x.mean + proj_resid = solver_state.observation - pred + cov_xy = solver_state.belief.x.cov @ action_A.T + gram = action_A @ cov_xy + self._noise_var + gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 + gain = cov_xy * gram_pinv + cov_update = gain @ cov_xy.T + + x = randvars.Normal( + mean=solver_state.belief.x.mean + gain * proj_resid, + cov=solver_state.belief.x.cov - cov_update, + ) + Ainv = solver_state.belief.Ainv + cov_update + + return LinearSystemBelief( + x=x, A=solver_state.belief.A, Ainv=Ainv, b=solver_state.belief.b + ) diff --git a/src/probnum/linalg/solvers/information_ops/__init__.py b/src/probnum/linalg/solvers/information_ops/__init__.py index 54068c27f..6c3b14e73 100644 --- a/src/probnum/linalg/solvers/information_ops/__init__.py +++ b/src/probnum/linalg/solvers/information_ops/__init__.py @@ -4,21 +4,21 @@ by observing the numerical problem to be solved given an action. When solving linear systems, the information operator takes an action vector and observes the tuple :math:`(A, b)`, returning an observation vector. For example, one might -observe the projected residual :math:`y = s^\top (A x_i - b)` with the action :math:`s`. +observe the right hand side :math:`y = b^\top s = (Ax)^\top s` with the action :math:`s`. """ -from ._linear_solver_info_op import LinearSolverInfoOp -from ._matvec import MatVecInfoOp -from ._proj_residual import ProjResidualInfoOp +from ._linear_solver_information_op import LinearSolverInformationOp +from ._matvec import MatVecInformationOp +from ._projected_rhs import ProjectedRHSInformationOp # Public classes and functions. Order is reflected in documentation. __all__ = [ - "LinearSolverInfoOp", - "MatVecInfoOp", - "ProjResidualInfoOp", + "LinearSolverInformationOp", + "MatVecInformationOp", + "ProjectedRHSInformationOp", ] # Set correct module paths. Corrects links and module paths in documentation. -LinearSolverInfoOp.__module__ = "probnum.linalg.solvers.information_ops" -MatVecInfoOp.__module__ = "probnum.linalg.solvers.information_ops" -ProjResidualInfoOp.__module__ = "probnum.linalg.solvers.information_ops" +LinearSolverInformationOp.__module__ = "probnum.linalg.solvers.information_ops" +MatVecInformationOp.__module__ = "probnum.linalg.solvers.information_ops" +ProjectedRHSInformationOp.__module__ = "probnum.linalg.solvers.information_ops" diff --git a/src/probnum/linalg/solvers/information_ops/_linear_solver_info_op.py b/src/probnum/linalg/solvers/information_ops/_linear_solver_information_op.py similarity index 69% rename from src/probnum/linalg/solvers/information_ops/_linear_solver_info_op.py rename to src/probnum/linalg/solvers/information_ops/_linear_solver_information_op.py index 18711bf31..41aafcce2 100644 --- a/src/probnum/linalg/solvers/information_ops/_linear_solver_info_op.py +++ b/src/probnum/linalg/solvers/information_ops/_linear_solver_information_op.py @@ -6,20 +6,20 @@ import probnum # pylint: disable="unused-import" -class LinearSolverInfoOp(abc.ABC): +class LinearSolverInformationOp(abc.ABC): r"""Information operator of a (probabilistic) linear solver. For a given action, the information operator collects information about the linear system to be solved. See Also -------- - MatVecInfoOp: Collect information via matrix-vector multiplication. - ProjResidualInfoOp: Collect information via a projection of the current residual. + MatVecInformationOp: Collect information via matrix-vector multiplication. + ProjectedRHSInformationOp: Collect information via a projection of the current residual. """ @abc.abstractmethod def __call__( - self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" + self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> np.ndarray: """Return information about the linear system for a given solver state. diff --git a/src/probnum/linalg/solvers/information_ops/_matvec.py b/src/probnum/linalg/solvers/information_ops/_matvec.py index f65e80056..430f21889 100644 --- a/src/probnum/linalg/solvers/information_ops/_matvec.py +++ b/src/probnum/linalg/solvers/information_ops/_matvec.py @@ -3,10 +3,10 @@ import probnum # pylint: disable="unused-import" -from ._linear_solver_info_op import LinearSolverInfoOp +from ._linear_solver_information_op import LinearSolverInformationOp -class MatVecInfoOp(LinearSolverInfoOp): +class MatVecInformationOp(LinearSolverInformationOp): r"""Matrix-vector product :math:`s_i \mapsto A s_i` with the system matrix. Obtain information about a linear system by multiplying an action :math:`s_i` @@ -14,7 +14,7 @@ class MatVecInfoOp(LinearSolverInfoOp): """ def __call__( - self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" + self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> np.ndarray: r"""Matrix vector product with the system matrix :math:`A`. diff --git a/src/probnum/linalg/solvers/information_ops/_proj_residual.py b/src/probnum/linalg/solvers/information_ops/_proj_residual.py deleted file mode 100644 index 166141fb8..000000000 --- a/src/probnum/linalg/solvers/information_ops/_proj_residual.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Information operator returning a projection of the residual.""" -import numpy as np - -import probnum # pylint: disable="unused-import" - -from ._linear_solver_info_op import LinearSolverInfoOp - - -class ProjResidualInfoOp(LinearSolverInfoOp): - r"""Projected residual :math:`s_i \mapsto s_i^\top (A x_i-b)` of the linear system. - - Obtain information about a linear system by projecting the current - residual :math:`r_i = A x_i - b` onto a given action :math:`s_i` resulting - in :math:`y_i = s_i^\top r_i`. - """ - - def __call__( - self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" - ) -> np.ndarray: - r"""Projected residual :math:`s_i^\top (A x_i - b)` of the linear system. - - Parameters - ---------- - solver_state : - Current state of the linear solver. - """ - return solver_state.action @ solver_state.residual diff --git a/src/probnum/linalg/solvers/information_ops/_projected_rhs.py b/src/probnum/linalg/solvers/information_ops/_projected_rhs.py new file mode 100644 index 000000000..f9d7fcdbf --- /dev/null +++ b/src/probnum/linalg/solvers/information_ops/_projected_rhs.py @@ -0,0 +1,25 @@ +"""Information operator returning a projection of the residual.""" +import numpy as np + +import probnum # pylint: disable="unused-import" + +from ._linear_solver_information_op import LinearSolverInformationOp + + +class ProjectedRHSInformationOp(LinearSolverInformationOp): + r"""Projected right hand side :math:`s_i \mapsto b^\top s_i = (Ax)^\top s_i` of the linear system. + + Obtain information about a linear system by projecting the right hand side :math:`b=Ax` onto a given action :math:`s_i` resulting in :math:`y_i = s_i^\top b`. + """ + + def __call__( + self, solver_state: "probnum.linalg.solvers.LinearSolverState" + ) -> np.ndarray: + r"""Projected right hand side :math:`s_i^\top b = s_i^\top Ax` of the linear system. + + Parameters + ---------- + solver_state : + Current state of the linear solver. + """ + return solver_state.action @ solver_state.problem.b diff --git a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py index 5ef6cde77..3069f9b3c 100644 --- a/src/probnum/linalg/solvers/policies/_conjugate_gradient.py +++ b/src/probnum/linalg/solvers/policies/_conjugate_gradient.py @@ -14,7 +14,7 @@ class ConjugateGradientPolicy(_linear_solver_policy.LinearSolverPolicy): """ def __call__( - self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" + self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> np.ndarray: action = -solver_state.residual.copy() diff --git a/src/probnum/linalg/solvers/policies/_linear_solver_policy.py b/src/probnum/linalg/solvers/policies/_linear_solver_policy.py index 7c8c10e4e..220da66b5 100644 --- a/src/probnum/linalg/solvers/policies/_linear_solver_policy.py +++ b/src/probnum/linalg/solvers/policies/_linear_solver_policy.py @@ -22,7 +22,7 @@ class LinearSolverPolicy(abc.ABC): @abc.abstractmethod def __call__( - self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" + self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> np.ndarray: """Return an action for a given solver state. diff --git a/src/probnum/linalg/solvers/policies/_random_unit_vector.py b/src/probnum/linalg/solvers/policies/_random_unit_vector.py index 229b3e984..ca19ea119 100644 --- a/src/probnum/linalg/solvers/policies/_random_unit_vector.py +++ b/src/probnum/linalg/solvers/policies/_random_unit_vector.py @@ -14,7 +14,7 @@ class RandomUnitVectorPolicy(_linear_solver_policy.LinearSolverPolicy): """ def __call__( - self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" + self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> np.ndarray: n = solver_state.problem.A.shape[1] diff --git a/src/probnum/linalg/solvers/stopping_criteria/__init__.py b/src/probnum/linalg/solvers/stopping_criteria/__init__.py index cbbcb7a60..cbb5e6423 100644 --- a/src/probnum/linalg/solvers/stopping_criteria/__init__.py +++ b/src/probnum/linalg/solvers/stopping_criteria/__init__.py @@ -1,20 +1,22 @@ """Stopping criteria for probabilistic linear solvers.""" -from ._linear_solver_stopping_criterion import LinearSolverStopCrit -from ._maxiter import MaxIterationsStopCrit -from ._posterior_contraction import PosteriorContractionStopCrit -from ._residual_norm import ResidualNormStopCrit +from ._linear_solver_stopping_criterion import LinearSolverStoppingCriterion +from ._maxiter import MaxIterationsStoppingCriterion +from ._posterior_contraction import PosteriorContractionStoppingCriterion +from ._residual_norm import ResidualNormStoppingCriterion # Public classes and functions. Order is reflected in documentation. __all__ = [ - "LinearSolverStopCrit", - "MaxIterationsStopCrit", - "ResidualNormStopCrit", - "PosteriorContractionStopCrit", + "LinearSolverStoppingCriterion", + "MaxIterationsStoppingCriterion", + "ResidualNormStoppingCriterion", + "PosteriorContractionStoppingCriterion", ] # Set correct module paths. Corrects links and module paths in documentation. -LinearSolverStopCrit.__module__ = "probnum.linalg.solvers.stopping_criteria" -MaxIterationsStopCrit.__module__ = "probnum.linalg.solvers.stopping_criteria" -ResidualNormStopCrit.__module__ = "probnum.linalg.solvers.stopping_criteria" -PosteriorContractionStopCrit.__module__ = "probnum.linalg.solvers.stopping_criteria" +LinearSolverStoppingCriterion.__module__ = "probnum.linalg.solvers.stopping_criteria" +MaxIterationsStoppingCriterion.__module__ = "probnum.linalg.solvers.stopping_criteria" +ResidualNormStoppingCriterion.__module__ = "probnum.linalg.solvers.stopping_criteria" +PosteriorContractionStoppingCriterion.__module__ = ( + "probnum.linalg.solvers.stopping_criteria" +) diff --git a/src/probnum/linalg/solvers/stopping_criteria/_linear_solver_stopping_criterion.py b/src/probnum/linalg/solvers/stopping_criteria/_linear_solver_stopping_criterion.py index 8310be3c5..26147586e 100644 --- a/src/probnum/linalg/solvers/stopping_criteria/_linear_solver_stopping_criterion.py +++ b/src/probnum/linalg/solvers/stopping_criteria/_linear_solver_stopping_criterion.py @@ -5,21 +5,21 @@ import probnum # pylint: disable="unused-import" -class LinearSolverStopCrit(abc.ABC): +class LinearSolverStoppingCriterion(abc.ABC): r"""Stopping criterion of a (probabilistic) linear solver. - Checks whether quantities tracked by the :class:`~probnum.linalg.solvers.ProbabilisticLinearSolverState` meet a desired terminal condition. + Checks whether quantities tracked by the :class:`~probnum.linalg.solvers.LinearSolverState` meet a desired terminal condition. See Also -------- - ResidualNormStopCrit : Stop based on the norm of the residual. - PosteriorContractionStopCrit : Stop based on the uncertainty about the quantity of interest. - MaxIterationsStopCrit : Stop after a maximum number of iterations. + ResidualNormStoppingCriterion : Stop based on the norm of the residual. + PosteriorContractionStoppingCriterion : Stop based on the uncertainty about the quantity of interest. + MaxIterationsStoppingCriterion : Stop after a maximum number of iterations. """ @abc.abstractmethod def __call__( - self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" + self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> bool: """Check whether tracked quantities meet a desired terminal condition. diff --git a/src/probnum/linalg/solvers/stopping_criteria/_maxiter.py b/src/probnum/linalg/solvers/stopping_criteria/_maxiter.py index c69e7bf92..730ff050e 100644 --- a/src/probnum/linalg/solvers/stopping_criteria/_maxiter.py +++ b/src/probnum/linalg/solvers/stopping_criteria/_maxiter.py @@ -3,10 +3,10 @@ import probnum # pylint: disable="unused-import" -from ._linear_solver_stopping_criterion import LinearSolverStopCrit +from ._linear_solver_stopping_criterion import LinearSolverStoppingCriterion -class MaxIterationsStopCrit(LinearSolverStopCrit): +class MaxIterationsStoppingCriterion(LinearSolverStoppingCriterion): r"""Stop after a maximum number of iterations. Stop when the solver has taken a maximum number of steps. If ``None`` is @@ -23,7 +23,7 @@ def __init__(self, maxiter: Optional[int] = None): self.maxiter = maxiter def __call__( - self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" + self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> bool: """Check whether the maximum number of iterations has been reached. diff --git a/src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py b/src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py index 4d9c3ab38..cc5172224 100644 --- a/src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py +++ b/src/probnum/linalg/solvers/stopping_criteria/_posterior_contraction.py @@ -5,10 +5,10 @@ import probnum # pylint: disable="unused-import" from probnum.typing import ScalarArgType -from ._linear_solver_stopping_criterion import LinearSolverStopCrit +from ._linear_solver_stopping_criterion import LinearSolverStoppingCriterion -class PosteriorContractionStopCrit(LinearSolverStopCrit): +class PosteriorContractionStoppingCriterion(LinearSolverStoppingCriterion): r"""Posterior contraction stopping criterion. Terminate when the uncertainty about the quantity of interest :math:`q` is @@ -37,7 +37,7 @@ def __init__( self.rtol = probnum.utils.as_numpy_scalar(rtol) def __call__( - self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" + self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> bool: """Check whether the uncertainty about the quantity of interest is smaller than the specified tolerance. diff --git a/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py b/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py index a883196d7..57026417d 100644 --- a/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py +++ b/src/probnum/linalg/solvers/stopping_criteria/_residual_norm.py @@ -5,10 +5,10 @@ import probnum from probnum.typing import ScalarArgType -from ._linear_solver_stopping_criterion import LinearSolverStopCrit +from ._linear_solver_stopping_criterion import LinearSolverStoppingCriterion -class ResidualNormStopCrit(LinearSolverStopCrit): +class ResidualNormStoppingCriterion(LinearSolverStoppingCriterion): r"""Residual stopping criterion. Terminate when the euclidean norm of the residual :math:`r_{i} = A x_{i} - b` is @@ -32,7 +32,7 @@ def __init__( self.rtol = probnum.utils.as_numpy_scalar(rtol) def __call__( - self, solver_state: "probnum.linalg.solvers.ProbabilisticLinearSolverState" + self, solver_state: "probnum.linalg.solvers.LinearSolverState" ) -> bool: """Check whether the residual norm is smaller than the specified tolerance. diff --git a/tests/test_linalg/test_solvers/cases/belief_updates.py b/tests/test_linalg/test_solvers/cases/belief_updates.py new file mode 100644 index 000000000..628a4fe12 --- /dev/null +++ b/tests/test_linalg/test_solvers/cases/belief_updates.py @@ -0,0 +1,18 @@ +"""Test cases describing different belief updates over quantities of interest of a +linear system.""" +from pytest_cases import parametrize + +from probnum.linalg.solvers.belief_updates import matrix_based, solution_based + + +@parametrize(noise_var=[0.0, 0.001, 1.0]) +def case_solution_based_projected_rhs_belief_update(noise_var: float): + return solution_based.SolutionBasedProjectedRHSBeliefUpdate(noise_var=noise_var) + + +def case_matrix_based_linear_belief_update(): + return matrix_based.MatrixBasedLinearBeliefUpdate() + + +def case_symmetric_matrix_based_linear_belief_update(): + return matrix_based.SymmetricMatrixBasedLinearBeliefUpdate() diff --git a/tests/test_linalg/test_solvers/cases/information_ops.py b/tests/test_linalg/test_solvers/cases/information_ops.py index e2034c480..b200bb610 100644 --- a/tests/test_linalg/test_solvers/cases/information_ops.py +++ b/tests/test_linalg/test_solvers/cases/information_ops.py @@ -1,13 +1,11 @@ """Test cases defined by information operators.""" -from pytest_cases import case - from probnum.linalg.solvers import information_ops def case_matvec(): - return information_ops.MatVecInfoOp() + return information_ops.MatVecInformationOp() -def case_proj_residual(): - return information_ops.ProjResidualInfoOp() +def case_projected_rhs(): + return information_ops.ProjectedRHSInformationOp() diff --git a/tests/test_linalg/test_solvers/cases/states.py b/tests/test_linalg/test_solvers/cases/states.py index cceee4246..44f1406c6 100644 --- a/tests/test_linalg/test_solvers/cases/states.py +++ b/tests/test_linalg/test_solvers/cases/states.py @@ -32,9 +32,7 @@ def case_initial_state( rng: np.random.Generator, ): """Initial state of a linear solver.""" - return linalg.solvers.ProbabilisticLinearSolverState( - problem=linsys, prior=prior, rng=rng - ) + return linalg.solvers.LinearSolverState(problem=linsys, prior=prior, rng=rng) @case(tags=["has_action"]) @@ -42,10 +40,70 @@ def case_state( rng: np.random.Generator, ): """State of a linear solver.""" - initial_state = linalg.solvers.ProbabilisticLinearSolverState( + state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior, rng=rng) + state.action = rng.standard_normal(size=state.problem.A.shape[1]) + + return state + + +@case(tags=["has_action", "has_observation", "matrix_based"]) +def case_state_matrix_based( + rng: np.random.Generator, +): + """State of a matrix-based linear solver.""" + prior = linalg.solvers.beliefs.LinearSystemBelief( + A=randvars.Normal( + mean=linops.Matrix(linsys.A), + cov=linops.Kronecker(A=linops.Identity(n), B=linops.Identity(n)), + ), + x=(Ainv @ b[:, None]).reshape((n,)), + Ainv=randvars.Normal( + mean=linops.Identity(n), + cov=linops.Kronecker(A=linops.Identity(n), B=linops.Identity(n)), + ), + b=b, + ) + state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior, rng=rng) + state.action = rng.standard_normal(size=state.problem.A.shape[1]) + state.observation = rng.standard_normal(size=state.problem.A.shape[1]) + + return state + + +@case(tags=["has_action", "has_observation", "symmetric_matrix_based"]) +def case_state_symmetric_matrix_based( + rng: np.random.Generator, +): + """State of a symmetric matrix-based linear solver.""" + prior = linalg.solvers.beliefs.LinearSystemBelief( + A=randvars.Normal( + mean=linops.Matrix(linsys.A), + cov=linops.SymmetricKronecker(A=linops.Identity(n)), + ), + x=(Ainv @ b[:, None]).reshape((n,)), + Ainv=randvars.Normal( + mean=linops.Identity(n), + cov=linops.SymmetricKronecker(A=linops.Identity(n)), + ), + b=b, + ) + state = linalg.solvers.LinearSolverState(problem=linsys, prior=prior, rng=rng) + state.action = rng.standard_normal(size=state.problem.A.shape[1]) + state.observation = rng.standard_normal(size=state.problem.A.shape[1]) + + return state + + +@case(tags=["has_action", "has_observation", "solution_based"]) +def case_state_solution_based( + rng: np.random.Generator, +): + """State of a solution-based linear solver.""" + initial_state = linalg.solvers.LinearSolverState( problem=linsys, prior=prior, rng=rng ) initial_state.action = rng.standard_normal(size=initial_state.problem.A.shape[1]) + initial_state.observation = rng.standard_normal() return initial_state @@ -60,7 +118,5 @@ def case_state_converged( x=randvars.Constant(linsys.solution), b=randvars.Constant(linsys.b), ) - state = linalg.solvers.ProbabilisticLinearSolverState( - problem=linsys, prior=belief, rng=rng - ) + state = linalg.solvers.LinearSolverState(problem=linsys, prior=belief, rng=rng) return state diff --git a/tests/test_linalg/test_solvers/cases/stopping_criteria.py b/tests/test_linalg/test_solvers/cases/stopping_criteria.py index 47bae9b22..826255d70 100644 --- a/tests/test_linalg/test_solvers/cases/stopping_criteria.py +++ b/tests/test_linalg/test_solvers/cases/stopping_criteria.py @@ -6,13 +6,13 @@ def case_maxiter(): - return stopping_criteria.MaxIterationsStopCrit() + return stopping_criteria.MaxIterationsStoppingCriterion() def case_residual_norm(): - return stopping_criteria.ResidualNormStopCrit() + return stopping_criteria.ResidualNormStoppingCriterion() @parametrize("qoi", ["x", "Ainv", "A"]) def case_posterior_contraction(qoi: str): - return stopping_criteria.PosteriorContractionStopCrit(qoi=qoi) + return stopping_criteria.PosteriorContractionStoppingCriterion(qoi=qoi) diff --git a/tests/test_linalg/test_solvers/test_belief_updates/__init__.py b/tests/test_linalg/test_solvers/test_belief_updates/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_linalg/test_solvers/test_belief_updates/test_linear_system_belief_update.py b/tests/test_linalg/test_solvers/test_belief_updates/test_linear_system_belief_update.py new file mode 100644 index 000000000..81608daa6 --- /dev/null +++ b/tests/test_linalg/test_solvers/test_belief_updates/test_linear_system_belief_update.py @@ -0,0 +1,7 @@ +"""Tests for belief updates about quantities of interest of a linear system.""" + +import pathlib + +case_modules = (pathlib.Path(__file__).parent / "cases").stem +cases_belief_updates = case_modules + ".belief_updates" +cases_states = case_modules + ".states" diff --git a/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/__init__.py b/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_matrix_based_linear_belief_update.py b/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_matrix_based_linear_belief_update.py new file mode 100644 index 000000000..ca1a6661d --- /dev/null +++ b/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_matrix_based_linear_belief_update.py @@ -0,0 +1,109 @@ +"""Tests for the matrix-based belief update for linear information.""" + +import pathlib + +import numpy as np +import pytest +from pytest_cases import parametrize_with_cases + +from probnum import linops, randvars +from probnum.linalg.solvers import LinearSolverState, belief_updates, beliefs + +case_modules = (pathlib.Path(__file__).parent.parent / "cases").stem +cases_belief_updates = case_modules + ".belief_updates" +cases_states = case_modules + ".states" + + +@parametrize_with_cases( + "belief_update", cases=cases_belief_updates, glob="matrix_based_linear*" +) +@parametrize_with_cases( + "state", + cases=cases_states, + has_tag=["has_action", "has_observation", "matrix_based"], +) +def test_returns_linear_system_belief( + belief_update: belief_updates.matrix_based.MatrixBasedLinearBeliefUpdate, + state: LinearSolverState, +): + belief = belief_update(solver_state=state) + assert isinstance(belief, beliefs.LinearSystemBelief) + + +@parametrize_with_cases( + "belief_update", cases=cases_belief_updates, glob="matrix_based_linear*" +) +@parametrize_with_cases( + "state", + cases=cases_states, + has_tag=["has_action", "has_observation", "symmetric_matrix_based"], +) +def test_raises_error_for_non_Kronecker_structured_covariances( + belief_update: belief_updates.matrix_based.MatrixBasedLinearBeliefUpdate, + state: LinearSolverState, +): + with pytest.raises(ValueError): + belief_update(solver_state=state) + + +@parametrize_with_cases( + "belief_update", cases=cases_belief_updates, glob="matrix_based_linear*" +) +@parametrize_with_cases( + "state", + cases=cases_states, + has_tag=["has_action", "has_observation", "matrix_based"], +) +def test_against_naive_implementation( + belief_update: belief_updates.matrix_based.MatrixBasedLinearBeliefUpdate, + state: LinearSolverState, +): + """Compare the updated belief to a naive implementation.""" + + def dense_matrix_based_update( + matrix: randvars.Normal, action: np.ndarray, observ: np.ndarray + ): + pred = matrix.mean @ action + resid = observ - pred + covfactor_Ms = matrix.cov.B @ action + gram = action.T @ covfactor_Ms + gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 + gain = covfactor_Ms * gram_pinv + covfactor_update = np.outer(gain, covfactor_Ms) + + return randvars.Normal( + mean=matrix.mean + np.outer(resid, gain), + cov=linops.Kronecker(A=matrix.cov.A, B=matrix.cov.B - covfactor_update), + ) + + updated_belief = belief_update(solver_state=state) + A_naive = dense_matrix_based_update( + matrix=state.belief.A, action=state.action, observ=state.observation + ) + Ainv_naive = dense_matrix_based_update( + matrix=state.belief.Ainv, action=state.observation, observ=state.action + ) + + # System matrix + np.testing.assert_allclose( + updated_belief.A.mean.todense(), + A_naive.mean.todense(), + err_msg="Mean of system matrix estimate does not match naive implementation.", + ) + np.testing.assert_allclose( + updated_belief.A.cov.todense(), + A_naive.cov.todense(), + err_msg="Covariance of system matrix estimate does not match naive implementation.", + ) + + # Inverse + np.testing.assert_allclose( + updated_belief.Ainv.mean.todense(), + Ainv_naive.mean.todense(), + err_msg="Mean of matrix inverse estimate does not match naive implementation.", + ) + np.testing.assert_allclose( + updated_belief.Ainv.cov.todense(), + Ainv_naive.cov.todense(), + err_msg="Covariance of matrix inverse estimate does not match naive implementation.", + ) diff --git a/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_symmetric_matrix_based_linear_belief_update.py b/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_symmetric_matrix_based_linear_belief_update.py new file mode 100644 index 000000000..cbf2817d1 --- /dev/null +++ b/tests/test_linalg/test_solvers/test_belief_updates/test_matrix_based/test_symmetric_matrix_based_linear_belief_update.py @@ -0,0 +1,115 @@ +"""Tests for the symmetric matrix-based belief update for linear information.""" + +import pathlib + +import numpy as np +import pytest +from pytest_cases import parametrize_with_cases + +from probnum import linops, randvars +from probnum.linalg.solvers import LinearSolverState, belief_updates, beliefs + +case_modules = (pathlib.Path(__file__).parent.parent / "cases").stem +cases_belief_updates = case_modules + ".belief_updates" +cases_states = case_modules + ".states" + + +@parametrize_with_cases( + "belief_update", + cases=cases_belief_updates, + glob="symmetric_matrix_based_linear*", +) +@parametrize_with_cases( + "state", + cases=cases_states, + has_tag=["has_action", "has_observation", "symmetric_matrix_based"], +) +def test_returns_linear_system_belief( + belief_update: belief_updates.matrix_based.SymmetricMatrixBasedLinearBeliefUpdate, + state: LinearSolverState, +): + belief = belief_update(solver_state=state) + assert isinstance(belief, beliefs.LinearSystemBelief) + + +@parametrize_with_cases( + "belief_update", cases=cases_belief_updates, glob="symmetric_matrix_based_linear*" +) +@parametrize_with_cases( + "state", + cases=cases_states, + has_tag=["has_action", "has_observation", "matrix_based"], +) +def test_raises_error_for_non_symmetric_Kronecker_structured_covariances( + belief_update: belief_updates.matrix_based.SymmetricMatrixBasedLinearBeliefUpdate, + state: LinearSolverState, +): + with pytest.raises(ValueError): + belief_update(solver_state=state) + + +@parametrize_with_cases( + "belief_update", cases=cases_belief_updates, glob="symmetric_matrix_based_linear*" +) +@parametrize_with_cases( + "state", + cases=cases_states, + has_tag=["has_action", "has_observation", "symmetric_matrix_based"], +) +def test_against_naive_implementation( + belief_update: belief_updates.matrix_based.MatrixBasedLinearBeliefUpdate, + state: LinearSolverState, +): + """Compare the updated belief to a naive implementation.""" + + def dense_matrix_based_update( + matrix: randvars.Normal, action: np.ndarray, observ: np.ndarray + ): + pred = matrix.mean @ action + resid = observ - pred + covfactor_Ms = matrix.cov.A @ action + gram = action.T @ covfactor_Ms + gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 + gain = covfactor_Ms * gram_pinv + covfactor_update = gain @ covfactor_Ms.T + resid_gain = np.outer(resid, gain) + + return randvars.Normal( + mean=matrix.mean + + resid_gain + + resid_gain.T + - np.outer(gain, action.T @ resid_gain), + cov=linops.SymmetricKronecker(A=matrix.cov.A - covfactor_update), + ) + + updated_belief = belief_update(solver_state=state) + A_naive = dense_matrix_based_update( + matrix=state.belief.A, action=state.action, observ=state.observation + ) + Ainv_naive = dense_matrix_based_update( + matrix=state.belief.Ainv, action=state.observation, observ=state.action + ) + + # System matrix + np.testing.assert_allclose( + updated_belief.A.mean.todense(), + A_naive.mean.todense(), + err_msg="Mean of system matrix estimate does not match naive implementation.", + ) + np.testing.assert_allclose( + updated_belief.A.cov.todense(), + A_naive.cov.todense(), + err_msg="Covariance of system matrix estimate does not match naive implementation.", + ) + + # Inverse + np.testing.assert_allclose( + updated_belief.Ainv.mean.todense(), + Ainv_naive.mean.todense(), + err_msg="Mean of matrix inverse estimate does not match naive implementation.", + ) + np.testing.assert_allclose( + updated_belief.Ainv.cov.todense(), + Ainv_naive.cov.todense(), + err_msg="Covariance of matrix inverse estimate does not match naive implementation.", + ) diff --git a/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/__init__.py b/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_solution_based_proj_rhs_belief_update.py b/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_solution_based_proj_rhs_belief_update.py new file mode 100644 index 000000000..56b954a21 --- /dev/null +++ b/tests/test_linalg/test_solvers/test_belief_updates/test_solution_based/test_solution_based_proj_rhs_belief_update.py @@ -0,0 +1,103 @@ +"""Tests for the solution-based belief update for projected right hand side +information.""" + +import pathlib + +import numpy as np +import pytest +from pytest_cases import parametrize_with_cases + +from probnum import randvars +from probnum.linalg.solvers import LinearSolverState, belief_updates, beliefs + +case_modules = (pathlib.Path(__file__).parent.parent / "cases").stem +cases_belief_updates = case_modules + ".belief_updates" +cases_states = case_modules + ".states" + + +@parametrize_with_cases( + "belief_update", cases=cases_belief_updates, glob="*solution_based_projected_rhs*" +) +@parametrize_with_cases( + "state", + cases=cases_states, + has_tag=["has_action", "has_observation", "solution_based"], +) +def test_returns_linear_system_belief( + belief_update: belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate, + state: LinearSolverState, +): + belief = belief_update(solver_state=state) + assert isinstance(belief, beliefs.LinearSystemBelief) + + +def test_negative_noise_variance_raises_error(): + with pytest.raises(ValueError): + belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate( + noise_var=-1.0 + ) + + +@parametrize_with_cases( + "belief_update", cases=cases_belief_updates, glob="*solution_based_projected_rhs*" +) +@parametrize_with_cases( + "state", + cases=cases_states, + has_tag=["has_action", "has_observation", "solution_based"], +) +def test_beliefs_against_naive_implementation( + belief_update: belief_updates.solution_based.SolutionBasedProjectedRHSBeliefUpdate, + state: LinearSolverState, +): + """Compare the updated belief to a naive implementation.""" + # Belief update + updated_belief = belief_update(solver_state=state) + + # Naive implementation + belief = state.belief + action = state.action + observ = state.observation + noise_var = belief_update._noise_var + + A_action = state.problem.A @ action + gram = A_action.T @ belief.x.cov @ A_action + noise_var + gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 + + x = randvars.Normal( + mean=belief.x.mean + + belief.x.cov @ A_action * (observ - A_action.T @ belief.x.mean) * gram_pinv, + cov=belief.x.cov + - (belief.x.cov @ A_action) @ (belief.x.cov @ A_action).T * gram_pinv, + ) + Ainv = ( + belief.Ainv + + (belief.x.cov @ A_action) @ (belief.x.cov @ A_action).T * gram_pinv + ) + + naive_belief = beliefs.LinearSystemBelief(x=x, Ainv=Ainv) + + # Compare means and covariances + np.testing.assert_allclose( + updated_belief.x.mean, + naive_belief.x.mean, + err_msg="Mean of solution belief does not match naive implementation.", + atol=1e-12, + rtol=1e-12, + ) + + np.testing.assert_allclose( + updated_belief.x.cov, + naive_belief.x.cov, + err_msg="Covariance of solution belief does not match naive implementation.", + atol=1e-12, + rtol=1e-12, + ) + + np.testing.assert_allclose( + updated_belief.Ainv.mean.todense(), + naive_belief.Ainv.mean.todense(), + err_msg="Belief about the inverse does not match naive implementation.", + atol=1e-12, + rtol=1e-12, + ) diff --git a/tests/test_linalg/test_solvers/test_information_ops/test_linear_solver_info_op.py b/tests/test_linalg/test_solvers/test_information_ops/test_linear_solver_info_op.py index 2f73ff033..9577ea7ba 100644 --- a/tests/test_linalg/test_solvers/test_information_ops/test_linear_solver_info_op.py +++ b/tests/test_linalg/test_solvers/test_information_ops/test_linear_solver_info_op.py @@ -5,7 +5,7 @@ import numpy as np from pytest_cases import parametrize_with_cases -from probnum.linalg.solvers import ProbabilisticLinearSolverState, information_ops +from probnum.linalg.solvers import LinearSolverState, information_ops case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_information_ops = case_modules + ".information_ops" @@ -15,7 +15,7 @@ @parametrize_with_cases("info_op", cases=cases_information_ops) @parametrize_with_cases("state", cases=cases_states, has_tag=["has_action"]) def test_returns_ndarray_or_scalar( - info_op: information_ops.LinearSolverInfoOp, state: ProbabilisticLinearSolverState + info_op: information_ops.LinearSolverInformationOp, state: LinearSolverState ): observation = info_op(state) assert isinstance(observation, np.ndarray) or np.isscalar(observation) diff --git a/tests/test_linalg/test_solvers/test_information_ops/test_matvec.py b/tests/test_linalg/test_solvers/test_information_ops/test_matvec.py index 816a0d852..3433377dc 100644 --- a/tests/test_linalg/test_solvers/test_information_ops/test_matvec.py +++ b/tests/test_linalg/test_solvers/test_information_ops/test_matvec.py @@ -5,7 +5,7 @@ import numpy as np from pytest_cases import parametrize_with_cases -from probnum.linalg.solvers import ProbabilisticLinearSolverState, information_ops +from probnum.linalg.solvers import LinearSolverState, information_ops case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_information_ops = case_modules + ".information_ops" @@ -15,7 +15,7 @@ @parametrize_with_cases("info_op", cases=cases_information_ops, glob="*matvec") @parametrize_with_cases("state", cases=cases_states, has_tag=["has_action"]) def test_is_A_matvec( - info_op: information_ops.LinearSolverInfoOp, state: ProbabilisticLinearSolverState + info_op: information_ops.LinearSolverInformationOp, state: LinearSolverState ): observation = info_op(state) np.testing.assert_equal(observation, state.problem.A @ state.action) diff --git a/tests/test_linalg/test_solvers/test_information_ops/test_residual.py b/tests/test_linalg/test_solvers/test_information_ops/test_projected_rhs.py similarity index 62% rename from tests/test_linalg/test_solvers/test_information_ops/test_residual.py rename to tests/test_linalg/test_solvers/test_information_ops/test_projected_rhs.py index 01d64b0f4..6915030c6 100644 --- a/tests/test_linalg/test_solvers/test_information_ops/test_residual.py +++ b/tests/test_linalg/test_solvers/test_information_ops/test_projected_rhs.py @@ -1,21 +1,21 @@ -"""Tests for the projected residual information operator.""" +"""Tests for the projected right hand side information operator.""" import pathlib import numpy as np from pytest_cases import parametrize_with_cases -from probnum.linalg.solvers import ProbabilisticLinearSolverState, information_ops +from probnum.linalg.solvers import LinearSolverState, information_ops case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_information_ops = case_modules + ".information_ops" cases_states = case_modules + ".states" -@parametrize_with_cases("info_op", cases=cases_information_ops, glob="*proj_residual") +@parametrize_with_cases("info_op", cases=cases_information_ops, glob="*projected_rhs") @parametrize_with_cases("state", cases=cases_states, has_tag=["has_action"]) -def test_is_projected_residual( - info_op: information_ops.LinearSolverInfoOp, state: ProbabilisticLinearSolverState +def test_is_projected_rhs( + info_op: information_ops.LinearSolverInformationOp, state: LinearSolverState ): observation = info_op(state) - np.testing.assert_equal(observation, state.action.T @ state.residual) + np.testing.assert_equal(observation, state.action.T @ state.problem.b) diff --git a/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py b/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py index 6f03623ad..4a39495ad 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py +++ b/tests/test_linalg/test_solvers/test_policies/test_conjugate_gradient.py @@ -5,7 +5,7 @@ from pytest_cases import parametrize_with_cases from probnum import randvars -from probnum.linalg.solvers import ProbabilisticLinearSolverState, policies +from probnum.linalg.solvers import LinearSolverState, policies case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_policies = case_modules + ".policies" @@ -15,7 +15,7 @@ @parametrize_with_cases("policy", cases=cases_policies, glob="*conjugate_gradient") @parametrize_with_cases("state", cases=cases_states) def test_initial_action_is_negative_gradient( - policy: policies.ConjugateGradientPolicy, state: ProbabilisticLinearSolverState + policy: policies.ConjugateGradientPolicy, state: LinearSolverState ): assert state.step == 0 action = policy(state) @@ -25,7 +25,7 @@ def test_initial_action_is_negative_gradient( @parametrize_with_cases("policy", cases=cases_policies, glob="*conjugate_*") @parametrize_with_cases("state", cases=cases_states, has_tag=["initial"]) def test_conjugate_actions( - policy: policies.ConjugateGradientPolicy, state: ProbabilisticLinearSolverState + policy: policies.ConjugateGradientPolicy, state: LinearSolverState ): """Tests whether actions generated by the policy are A-conjugate via a naive CG implementation.""" diff --git a/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py b/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py index 18eb912de..e2b0a4f11 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py +++ b/tests/test_linalg/test_solvers/test_policies/test_linear_solver_policy.py @@ -4,7 +4,7 @@ import numpy as np from pytest_cases import parametrize_with_cases -from probnum.linalg.solvers import ProbabilisticLinearSolverState, policies +from probnum.linalg.solvers import LinearSolverState, policies case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_policies = case_modules + ".policies" @@ -13,18 +13,14 @@ @parametrize_with_cases("policy", cases=cases_policies) @parametrize_with_cases("state", cases=cases_states) -def test_returns_ndarray( - policy: policies.LinearSolverPolicy, state: ProbabilisticLinearSolverState -): +def test_returns_ndarray(policy: policies.LinearSolverPolicy, state: LinearSolverState): action = policy(state) assert isinstance(action, np.ndarray) @parametrize_with_cases("policy", cases=cases_policies) @parametrize_with_cases("state", cases=cases_states) -def test_shape( - policy: policies.LinearSolverPolicy, state: ProbabilisticLinearSolverState -): +def test_shape(policy: policies.LinearSolverPolicy, state: LinearSolverState): action = policy(state) assert action.shape[0] == state.problem.A.shape[1] @@ -32,7 +28,7 @@ def test_shape( @parametrize_with_cases("policy", cases=cases_policies, has_tag="random") @parametrize_with_cases("state", cases=cases_states) def test_uses_solver_state_random_number_generator( - policy: policies.LinearSolverPolicy, state: ProbabilisticLinearSolverState + policy: policies.LinearSolverPolicy, state: LinearSolverState ): """Test whether randomized policies make use of the random number generator stored in the linear solver state.""" diff --git a/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py b/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py index 3fec79d20..06265f0d8 100644 --- a/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py +++ b/tests/test_linalg/test_solvers/test_policies/test_random_unit_vector.py @@ -5,7 +5,7 @@ import pytest from pytest_cases import parametrize_with_cases -from probnum.linalg.solvers import ProbabilisticLinearSolverState, policies +from probnum.linalg.solvers import LinearSolverState, policies case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_policies = case_modules + ".policies" @@ -15,7 +15,7 @@ @parametrize_with_cases("policy", cases=cases_policies, glob="*unit_vector") @parametrize_with_cases("state", cases=cases_states) def test_returns_unit_vector( - policy: policies.LinearSolverPolicy, state: ProbabilisticLinearSolverState + policy: policies.LinearSolverPolicy, state: LinearSolverState ): action = policy(state) assert np.linalg.norm(action) == pytest.approx(1.0) diff --git a/tests/test_linalg/test_solvers/test_state.py b/tests/test_linalg/test_solvers/test_state.py index 53d6bd026..cd5cd2476 100644 --- a/tests/test_linalg/test_solvers/test_state.py +++ b/tests/test_linalg/test_solvers/test_state.py @@ -3,13 +3,13 @@ import numpy as np from pytest_cases import parametrize, parametrize_with_cases -from probnum.linalg.solvers import ProbabilisticLinearSolverState +from probnum.linalg.solvers import LinearSolverState cases_states = "cases.states" @parametrize_with_cases("state", cases=cases_states) -def test_residual(state: ProbabilisticLinearSolverState): +def test_residual(state: LinearSolverState): """Test whether the state computes the residual correctly.""" linsys = state.problem residual = linsys.A @ state.belief.x.mean - linsys.b @@ -17,7 +17,7 @@ def test_residual(state: ProbabilisticLinearSolverState): @parametrize_with_cases("state", cases=cases_states) -def test_next_step(state: ProbabilisticLinearSolverState): +def test_next_step(state: LinearSolverState): """Test whether advancing a state to the next step updates all state attributes correctly.""" initial_step = state.step @@ -31,7 +31,7 @@ def test_next_step(state: ProbabilisticLinearSolverState): @parametrize_with_cases("state", cases=cases_states) @parametrize("attr_name", ["action", "observation", "residual"]) -def test_current_iter_attribute(state: ProbabilisticLinearSolverState, attr_name: str): +def test_current_iter_attribute(state: LinearSolverState, attr_name: str): """Test whether the current iteration attribute if set returns the last element of the attribute lists.""" assert np.all( diff --git a/tests/test_linalg/test_solvers/test_stopping_criteria/test_linear_solver_stopping_criterion.py b/tests/test_linalg/test_solvers/test_stopping_criteria/test_linear_solver_stopping_criterion.py index e63ce31de..44344c5cb 100644 --- a/tests/test_linalg/test_solvers/test_stopping_criteria/test_linear_solver_stopping_criterion.py +++ b/tests/test_linalg/test_solvers/test_stopping_criteria/test_linear_solver_stopping_criterion.py @@ -4,7 +4,7 @@ from pytest_cases import parametrize_with_cases -from probnum.linalg.solvers import ProbabilisticLinearSolverState, stopping_criteria +from probnum.linalg.solvers import LinearSolverState, stopping_criteria case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_stopping_criteria = case_modules + ".stopping_criteria" @@ -14,7 +14,7 @@ @parametrize_with_cases("stop_crit", cases=cases_stopping_criteria) @parametrize_with_cases("state", cases=cases_states) def test_returns_bool( - stop_crit: stopping_criteria.LinearSolverStopCrit, - state: ProbabilisticLinearSolverState, + stop_crit: stopping_criteria.LinearSolverStoppingCriterion, + state: LinearSolverState, ): assert stop_crit(solver_state=state) in [True, False] diff --git a/tests/test_linalg/test_solvers/test_stopping_criteria/test_maxiter.py b/tests/test_linalg/test_solvers/test_stopping_criteria/test_maxiter.py index f3c40f76f..97915a7d2 100644 --- a/tests/test_linalg/test_solvers/test_stopping_criteria/test_maxiter.py +++ b/tests/test_linalg/test_solvers/test_stopping_criteria/test_maxiter.py @@ -4,7 +4,7 @@ from pytest_cases import parametrize_with_cases -from probnum.linalg.solvers import ProbabilisticLinearSolverState, stopping_criteria +from probnum.linalg.solvers import LinearSolverState, stopping_criteria case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_stopping_criteria = case_modules + ".stopping_criteria" @@ -12,10 +12,10 @@ @parametrize_with_cases("state", cases=cases_states, glob="*initial_state") -def test_maxiter_None(state: ProbabilisticLinearSolverState): +def test_maxiter_None(state: LinearSolverState): """Test whether if ``maxiter=None``, the maximum number of iterations is set to :math:`10n`, where :math:`n` is the dimension of the linear system.""" - stop_crit = stopping_criteria.MaxIterationsStopCrit() + stop_crit = stopping_criteria.MaxIterationsStoppingCriterion() for _ in range(10 * state.problem.A.shape[1]): assert not stop_crit(state) diff --git a/tests/test_linalg/test_solvers/test_stopping_criteria/test_posterior_contraction.py b/tests/test_linalg/test_solvers/test_stopping_criteria/test_posterior_contraction.py index 69a2ebb01..22f5e7467 100644 --- a/tests/test_linalg/test_solvers/test_stopping_criteria/test_posterior_contraction.py +++ b/tests/test_linalg/test_solvers/test_stopping_criteria/test_posterior_contraction.py @@ -4,7 +4,7 @@ from pytest_cases import parametrize_with_cases -from probnum.linalg.solvers import ProbabilisticLinearSolverState, stopping_criteria +from probnum.linalg.solvers import LinearSolverState, stopping_criteria case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_stopping_criteria = case_modules + ".stopping_criteria" @@ -16,7 +16,7 @@ ) @parametrize_with_cases("state", cases=cases_states, glob="*converged") def test_has_converged( - stop_crit: stopping_criteria.LinearSolverStopCrit, - state: ProbabilisticLinearSolverState, + stop_crit: stopping_criteria.LinearSolverStoppingCriterion, + state: LinearSolverState, ): assert stop_crit(solver_state=state) diff --git a/tests/test_linalg/test_solvers/test_stopping_criteria/test_residual_norm.py b/tests/test_linalg/test_solvers/test_stopping_criteria/test_residual_norm.py index 63b1215cb..013c10257 100644 --- a/tests/test_linalg/test_solvers/test_stopping_criteria/test_residual_norm.py +++ b/tests/test_linalg/test_solvers/test_stopping_criteria/test_residual_norm.py @@ -4,7 +4,7 @@ from pytest_cases import parametrize_with_cases -from probnum.linalg.solvers import ProbabilisticLinearSolverState, stopping_criteria +from probnum.linalg.solvers import LinearSolverState, stopping_criteria case_modules = (pathlib.Path(__file__).parent / "cases").stem cases_stopping_criteria = case_modules + ".stopping_criteria" @@ -16,7 +16,7 @@ ) @parametrize_with_cases("state", cases=cases_states, glob="*converged") def test_has_converged( - stop_crit: stopping_criteria.LinearSolverStopCrit, - state: ProbabilisticLinearSolverState, + stop_crit: stopping_criteria.LinearSolverStoppingCriterion, + state: LinearSolverState, ): assert stop_crit(solver_state=state)