-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[PLS Refactor] Belief Updates (#546)
* belief update and some naming inconsistencies * linear system belief update skeleton * added interface for belief updates and changed problinsolve to solve * general test for linear system belief updates added * belief update classes for matrix and solution based view * added belief update interfaces for solution-based and matrix-based views * solution based belief update * solution-based docstring added * solutionbased tests pass * pylint fixes * matrix-based belief update * added some more tests for (a)symmetric matrix-based inference * symmetric matrix-based inference update * test matrix based against naive implementation * test naive implementation * nullify belief over solution * improved documentation for the solution based perspective * better docstrings for matrix-based updates * minor doc fix * removed print statement * Update src/probnum/linalg/solvers/belief_updates/_symmetric_matrix_based_linear_belief_update.py Co-authored-by: Marvin Pförtner <marvin.pfoertner@icloud.com> * Update src/probnum/linalg/solvers/belief_updates/_matrix_based_linear_belief_update.py Co-authored-by: Marvin Pförtner <marvin.pfoertner@icloud.com> * naming improvements * more docstring improvements * changed package structure separate matrix- and solutionbased belief updates * missing tests * Update docs/source/api/linalg/solvers.belief_updates.matrix_based.rst Co-authored-by: Marvin Pförtner <marvin.pfoertner@icloud.com> * Update docs/source/api/linalg/solvers.belief_updates.solution_based.rst Co-authored-by: Marvin Pförtner <marvin.pfoertner@icloud.com> * docfix Co-authored-by: Marvin Pförtner <marvin.pfoertner@icloud.com>
- Loading branch information
1 parent
e5260ee
commit 3662bf5
Showing
48 changed files
with
861 additions
and
138 deletions.
There are no files selected for viewing
5 changes: 5 additions & 0 deletions
5
docs/source/api/linalg/solvers.belief_updates.matrix_based.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
Belief Updates for Matrix-based Inference | ||
----------------------------------------- | ||
.. automodapi:: probnum.linalg.solvers.belief_updates.matrix_based | ||
:no-heading: | ||
:headings: "*" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
Belief Updates | ||
-------------- | ||
.. automodapi:: probnum.linalg.solvers.belief_updates | ||
:no-heading: | ||
:headings: "*" | ||
|
||
|
||
.. toctree:: | ||
:hidden: | ||
|
||
solvers.belief_updates.solution_based | ||
|
||
.. toctree:: | ||
:hidden: | ||
|
||
solvers.belief_updates.matrix_based |
5 changes: 5 additions & 0 deletions
5
docs/source/api/linalg/solvers.belief_updates.solution_based.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
Belief Updates for Solution-based Inference | ||
------------------------------------------- | ||
.. automodapi:: probnum.linalg.solvers.belief_updates.solution_based | ||
:no-heading: | ||
:headings: "*" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
"""Belief updates for the quantities of interest of a linear system.""" | ||
|
||
from . import matrix_based, solution_based | ||
from ._linear_solver_belief_update import LinearSolverBeliefUpdate | ||
|
||
# Public classes and functions. Order is reflected in documentation. | ||
__all__ = [ | ||
"LinearSolverBeliefUpdate", | ||
] | ||
|
||
# Set correct module paths. Corrects links and module paths in documentation. | ||
LinearSolverBeliefUpdate.__module__ = "probnum.linalg.solvers.belief_updates" |
29 changes: 29 additions & 0 deletions
29
src/probnum/linalg/solvers/belief_updates/_linear_solver_belief_update.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
"""Linear system belief updates. | ||
Class defining how a belief over quantities of interest of a linear system is updated | ||
given information about the problem. | ||
""" | ||
import abc | ||
|
||
import probnum # pylint: disable="unused-import" | ||
from probnum.linalg.solvers.beliefs import LinearSystemBelief | ||
|
||
|
||
class LinearSolverBeliefUpdate(abc.ABC): | ||
r"""Belief update for the quantities of interest of a linear system. | ||
Given a solver state containing information about the linear system collected in the current step, update the belief about the quantities of interest. | ||
""" | ||
|
||
@abc.abstractmethod | ||
def __call__( | ||
self, solver_state: "probnum.linalg.solvers.LinearSolverState" | ||
) -> LinearSystemBelief: | ||
r"""Update the belief about the quantities of interest of a linear system. | ||
Parameters | ||
---------- | ||
solver_state : | ||
Current state of the linear solver. | ||
""" | ||
raise NotImplementedError |
21 changes: 21 additions & 0 deletions
21
src/probnum/linalg/solvers/belief_updates/matrix_based/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""Matrix-based belief updates for the quantities of interest of a linear system.""" | ||
|
||
from ._matrix_based_linear_belief_update import MatrixBasedLinearBeliefUpdate | ||
from ._symmetric_matrix_based_linear_belief_update import ( | ||
SymmetricMatrixBasedLinearBeliefUpdate, | ||
) | ||
|
||
# Public classes and functions. Order is reflected in documentation. | ||
__all__ = [ | ||
"MatrixBasedLinearBeliefUpdate", | ||
"SymmetricMatrixBasedLinearBeliefUpdate", | ||
] | ||
|
||
# Set correct module paths. Corrects links and module paths in documentation. | ||
MatrixBasedLinearBeliefUpdate.__module__ = ( | ||
"probnum.linalg.solvers.belief_updates.matrix_based" | ||
) | ||
|
||
SymmetricMatrixBasedLinearBeliefUpdate.__module__ = ( | ||
"probnum.linalg.solvers.belief_updates.matrix_based" | ||
) |
76 changes: 76 additions & 0 deletions
76
src/probnum/linalg/solvers/belief_updates/matrix_based/_matrix_based_linear_belief_update.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
"""Belief update in a matrix-based inference view where the information is given by | ||
matrix-vector multiplication.""" | ||
import numpy as np | ||
|
||
import probnum # pylint: disable="unused-import" | ||
from probnum import linops, randvars | ||
from probnum.linalg.solvers.beliefs import LinearSystemBelief | ||
|
||
from .._linear_solver_belief_update import LinearSolverBeliefUpdate | ||
|
||
|
||
class MatrixBasedLinearBeliefUpdate(LinearSolverBeliefUpdate): | ||
r"""Gaussian belief update in a matrix-based inference framework assuming linear information. | ||
Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given matrix-variate Gaussian beliefs with Kronecker covariance structure and linear observations :math:`y=As`. The belief update computes :math:`p(M \mid y) = \mathcal{N}(M; M_{i+1}, V \otimes W_{i+1})`, [1]_ [2]_ such that | ||
.. math :: | ||
\begin{align} | ||
M_{i+1} &= M_i + (y - M_i s) (s^\top W_i s)^\dagger s^\top W_i,\\ | ||
W_{i+1} &= W_i - W_i s (s^\top W_i s)^\dagger s^\top W_i. | ||
\end{align} | ||
References | ||
---------- | ||
.. [1] Hennig, P., Probabilistic Interpretation of Linear Solvers, *SIAM Journal on | ||
Optimization*, 2015, 25, 234-260 | ||
.. [2] Wenger, J. and Hennig, P., Probabilistic Linear Solvers for Machine Learning, | ||
*Advances in Neural Information Processing Systems (NeurIPS)*, 2020 | ||
""" | ||
|
||
def __call__( | ||
self, solver_state: "probnum.linalg.solvers.LinearSolverState" | ||
) -> LinearSystemBelief: | ||
|
||
# Inference for A | ||
A = self._matrix_based_update( | ||
matrix=solver_state.belief.A, | ||
action=solver_state.action, | ||
observ=solver_state.observation, | ||
) | ||
|
||
# Inference for Ainv (interpret action and observation as swapped) | ||
Ainv = self._matrix_based_update( | ||
matrix=solver_state.belief.Ainv, | ||
action=solver_state.observation, | ||
observ=solver_state.action, | ||
) | ||
return LinearSystemBelief(A=A, Ainv=Ainv, x=None, b=solver_state.belief.b) | ||
|
||
def _matrix_based_update( | ||
self, matrix: randvars.Normal, action: np.ndarray, observ: np.ndarray | ||
) -> randvars.Normal: | ||
"""Matrix-based inference update for linear information.""" | ||
if not isinstance(matrix.cov, linops.Kronecker): | ||
raise ValueError( | ||
f"Covariance must have Kronecker structure, but is '{type(matrix.cov).__name__}'." | ||
) | ||
|
||
pred = matrix.mean @ action | ||
resid = observ - pred | ||
covfactor_Ms = matrix.cov.B @ action | ||
gram = action.T @ covfactor_Ms | ||
gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 | ||
gain = covfactor_Ms * gram_pinv | ||
covfactor_update = linops.aslinop(gain[:, None]) @ linops.aslinop( | ||
covfactor_Ms[None, :] | ||
) | ||
resid_gain = linops.aslinop(resid[:, None]) @ linops.aslinop( | ||
gain[None, :] | ||
) # residual and gain are flipped due to matrix vectorization | ||
|
||
return randvars.Normal( | ||
mean=matrix.mean + resid_gain, | ||
cov=linops.Kronecker(A=matrix.cov.A, B=matrix.cov.B - covfactor_update), | ||
) |
77 changes: 77 additions & 0 deletions
77
...inalg/solvers/belief_updates/matrix_based/_symmetric_matrix_based_linear_belief_update.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
"""Belief update in a matrix-based inference view assuming symmetry where the | ||
information is given by matrix-vector multiplication.""" | ||
import numpy as np | ||
|
||
import probnum # pylint: disable="unused-import" | ||
from probnum import linops, randvars | ||
from probnum.linalg.solvers.beliefs import LinearSystemBelief | ||
|
||
from .._linear_solver_belief_update import LinearSolverBeliefUpdate | ||
|
||
|
||
class SymmetricMatrixBasedLinearBeliefUpdate(LinearSolverBeliefUpdate): | ||
r"""Symmetric Gaussian belief update in a matrix-based inference framework assuming linear information. | ||
Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given symmetric matrix-variate Gaussian beliefs with symmetric Kronecker covariance structure and linear observations. The belief update computes :math:`p(M \mid y) = \mathcal{N}(M; M_{i+1}, W_{i+1} \otimes_s W_{i+1})`, [1]_ [2]_ such that | ||
.. math :: | ||
\begin{align} | ||
M_{i+1} &= M_i + (y - M_i s) u^\top + u (y - M_i s)^\top - u s^\top(y - M_i s)u^\top,\\ | ||
W_{i+1} &= W_i - W_i s (s^\top W_i s)^\dagger s^\top W_i. | ||
\end{align} | ||
where :math:`u = W_i s (s^\top W s)^\dagger`. | ||
References | ||
---------- | ||
.. [1] Hennig, P., Probabilistic Interpretation of Linear Solvers, *SIAM Journal on | ||
Optimization*, 2015, 25, 234-260 | ||
.. [2] Wenger, J. and Hennig, P., Probabilistic Linear Solvers for Machine Learning, | ||
*Advances in Neural Information Processing Systems (NeurIPS)*, 2020 | ||
""" | ||
|
||
def __call__( | ||
self, solver_state: "probnum.linalg.solvers.LinearSolverState" | ||
) -> LinearSystemBelief: | ||
|
||
# Inference for A | ||
A = self._symmetric_matrix_based_update( | ||
matrix=solver_state.belief.A, | ||
action=solver_state.action, | ||
observ=solver_state.observation, | ||
) | ||
|
||
# Inference for Ainv (interpret action and observation as swapped) | ||
Ainv = self._symmetric_matrix_based_update( | ||
matrix=solver_state.belief.Ainv, | ||
action=solver_state.observation, | ||
observ=solver_state.action, | ||
) | ||
return LinearSystemBelief(A=A, Ainv=Ainv, x=None, b=solver_state.belief.b) | ||
|
||
def _symmetric_matrix_based_update( | ||
self, matrix: randvars.Normal, action: np.ndarray, observ: np.ndarray | ||
) -> randvars.Normal: | ||
"""Symmetric matrix-based inference update for linear information.""" | ||
if not isinstance(matrix.cov, linops.SymmetricKronecker): | ||
raise ValueError( | ||
f"Covariance must have symmetric Kronecker structure, but is '{type(matrix.cov).__name__}'." | ||
) | ||
|
||
pred = matrix.mean @ action | ||
resid = observ - pred | ||
covfactor_Ms = matrix.cov.A @ action | ||
gram = action.T @ covfactor_Ms | ||
gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 | ||
gain = covfactor_Ms * gram_pinv | ||
covfactor_update = gain @ covfactor_Ms.T | ||
resid_gain = linops.aslinop(resid[:, None]) @ linops.aslinop(gain[None, :]) | ||
|
||
return randvars.Normal( | ||
mean=matrix.mean | ||
+ resid_gain | ||
+ resid_gain.T | ||
- linops.aslinop(gain[:, None]) | ||
@ linops.aslinop((action.T @ resid_gain)[None, :]), | ||
cov=linops.SymmetricKronecker(A=matrix.cov.A - covfactor_update), | ||
) |
15 changes: 15 additions & 0 deletions
15
src/probnum/linalg/solvers/belief_updates/solution_based/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
"""Solution-based belief updates for the quantities of interest of a linear system.""" | ||
|
||
from ._solution_based_proj_rhs_belief_update import ( | ||
SolutionBasedProjectedRHSBeliefUpdate, | ||
) | ||
|
||
# Public classes and functions. Order is reflected in documentation. | ||
__all__ = [ | ||
"SolutionBasedProjectedRHSBeliefUpdate", | ||
] | ||
|
||
# Set correct module paths. Corrects links and module paths in documentation. | ||
SolutionBasedProjectedRHSBeliefUpdate.__module__ = ( | ||
"probnum.linalg.solvers.belief_updates.solution_based" | ||
) |
62 changes: 62 additions & 0 deletions
62
...um/linalg/solvers/belief_updates/solution_based/_solution_based_proj_rhs_belief_update.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
"""Belief update in a solution-based inference view where the information is given by | ||
projecting the current residual to a subspace.""" | ||
import probnum # pylint: disable="unused-import" | ||
from probnum import randvars | ||
from probnum.linalg.solvers.beliefs import LinearSystemBelief | ||
from probnum.typing import FloatArgType | ||
|
||
from .._linear_solver_belief_update import LinearSolverBeliefUpdate | ||
|
||
|
||
class SolutionBasedProjectedRHSBeliefUpdate(LinearSolverBeliefUpdate): | ||
r"""Gaussian belief update in a solution-based inference framework assuming projected right-hand-side information. | ||
Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given a Gaussian belief over the solution :math:`x` and information of the form :math:`y = s\^top b=s^\top Ax`. The belief update computes the posterior belief about the solution, given by :math:`p(x \mid y) = \mathcal{N}(x; x_{i+1}, \Sigma_{i+1})`, [1]_ such that | ||
.. math :: | ||
\begin{align} | ||
x_{i+1} &= x_i + \Sigma_i A^\top s (s^\top A \Sigma_i A^\top s + \lambda)^\dagger s^\top (b - Ax_i),\\ | ||
\Sigma_{i+1} &= \Sigma_i - \Sigma_i A^\top s (s^\top A \Sigma_i A s + \lambda)^\dagger s^\top A \Sigma_i, | ||
\end{align} | ||
where :math:`\lambda` is the noise variance. | ||
Parameters | ||
---------- | ||
noise_var : | ||
Variance of the scalar observation noise. | ||
References | ||
---------- | ||
.. [1] Cockayne, J. et al., A Bayesian Conjugate Gradient Method, *Bayesian | ||
Analysis*, 2019, 14, 937-1012 | ||
""" | ||
|
||
def __init__(self, noise_var: FloatArgType = 0.0) -> None: | ||
if noise_var < 0.0: | ||
raise ValueError(f"Noise variance {noise_var} must be non-negative.") | ||
self._noise_var = noise_var | ||
|
||
def __call__( | ||
self, solver_state: "probnum.linalg.solvers.LinearSolverState" | ||
) -> LinearSystemBelief: | ||
|
||
action_A = solver_state.action @ solver_state.problem.A | ||
pred = action_A @ solver_state.belief.x.mean | ||
proj_resid = solver_state.observation - pred | ||
cov_xy = solver_state.belief.x.cov @ action_A.T | ||
gram = action_A @ cov_xy + self._noise_var | ||
gram_pinv = 1.0 / gram if gram > 0.0 else 0.0 | ||
gain = cov_xy * gram_pinv | ||
cov_update = gain @ cov_xy.T | ||
|
||
x = randvars.Normal( | ||
mean=solver_state.belief.x.mean + gain * proj_resid, | ||
cov=solver_state.belief.x.cov - cov_update, | ||
) | ||
Ainv = solver_state.belief.Ainv + cov_update | ||
|
||
return LinearSystemBelief( | ||
x=x, A=solver_state.belief.A, Ainv=Ainv, b=solver_state.belief.b | ||
) |
Oops, something went wrong.