Skip to content

Commit

Permalink
[PLS Refactor] Belief Updates (#546)
Browse files Browse the repository at this point in the history
* belief update and some naming inconsistencies

* linear system belief update skeleton

* added interface for belief updates and changed problinsolve to solve

* general test for linear system belief updates added

* belief update classes for matrix and solution based view

* added belief update interfaces for solution-based and matrix-based views

* solution based belief update

* solution-based docstring added

* solutionbased tests pass

* pylint fixes

* matrix-based belief update

* added some more tests for (a)symmetric matrix-based inference

* symmetric matrix-based inference update

* test matrix based against naive implementation

* test naive implementation

* nullify belief over solution

* improved documentation for the solution based perspective

* better docstrings for matrix-based updates

* minor doc fix

* removed print statement

* Update src/probnum/linalg/solvers/belief_updates/_symmetric_matrix_based_linear_belief_update.py

Co-authored-by: Marvin Pförtner <marvin.pfoertner@icloud.com>

* Update src/probnum/linalg/solvers/belief_updates/_matrix_based_linear_belief_update.py

Co-authored-by: Marvin Pförtner <marvin.pfoertner@icloud.com>

* naming improvements

* more docstring improvements

* changed package structure separate matrix- and solutionbased belief updates

* missing tests

* Update docs/source/api/linalg/solvers.belief_updates.matrix_based.rst

Co-authored-by: Marvin Pförtner <marvin.pfoertner@icloud.com>

* Update docs/source/api/linalg/solvers.belief_updates.solution_based.rst

Co-authored-by: Marvin Pförtner <marvin.pfoertner@icloud.com>

* docfix

Co-authored-by: Marvin Pförtner <marvin.pfoertner@icloud.com>
  • Loading branch information
JonathanWenger and marvinpfoertner committed Nov 4, 2021
1 parent e5260ee commit 3662bf5
Show file tree
Hide file tree
Showing 48 changed files with 861 additions and 138 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Belief Updates for Matrix-based Inference
-----------------------------------------
.. automodapi:: probnum.linalg.solvers.belief_updates.matrix_based
:no-heading:
:headings: "*"
16 changes: 16 additions & 0 deletions docs/source/api/linalg/solvers.belief_updates.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
Belief Updates
--------------
.. automodapi:: probnum.linalg.solvers.belief_updates
:no-heading:
:headings: "*"


.. toctree::
:hidden:

solvers.belief_updates.solution_based

.. toctree::
:hidden:

solvers.belief_updates.matrix_based
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Belief Updates for Solution-based Inference
-------------------------------------------
.. automodapi:: probnum.linalg.solvers.belief_updates.solution_based
:no-heading:
:headings: "*"
5 changes: 5 additions & 0 deletions docs/source/api/linalg/solvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ probnum.linalg.solvers

solvers.information_ops

.. toctree::
:hidden:

solvers.belief_updates

.. toctree::
:hidden:

Expand Down
6 changes: 3 additions & 3 deletions src/probnum/linalg/solvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
SymmetricMatrixBasedSolver,
)

from . import beliefs, information_ops, policies
from . import belief_updates, beliefs, information_ops, policies, stopping_criteria
from ._probabilistic_linear_solver import ProbabilisticLinearSolver
from ._state import ProbabilisticLinearSolverState
from ._state import LinearSolverState

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"ProbabilisticLinearSolver",
"MatrixBasedSolver",
"SymmetricMatrixBasedSolver",
"ProbabilisticLinearSolverState",
"LinearSolverState",
]

# Set correct module paths. Corrects links and module paths in documentation.
Expand Down
12 changes: 5 additions & 7 deletions src/probnum/linalg/solvers/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


@dataclasses.dataclass
class ProbabilisticLinearSolverState:
class LinearSolverState:
"""State of a probabilistic linear solver.
The solver state separates the state of a probabilistic linear solver from the algorithm itself, making the solver stateless. The state contains the problem to be solved, the current belief over the quantities of interest and any miscellaneous quantities computed during an iteration of a probabilistic linear solver. The solver state is passed between the different components of the solver and may be used internally to cache quantities which are used more than once.
Expand Down Expand Up @@ -51,8 +51,7 @@ def __repr__(self) -> str:
def action(self) -> Optional[np.ndarray]:
"""Action of the solver for the current step.
Is ``None`` at the beginning of a step and will be set by the
policy.
Is ``None`` at the beginning of a step and will be set by the policy.
"""
return self._actions[self.step]

Expand All @@ -65,8 +64,8 @@ def action(self, value: np.ndarray) -> None:
def observation(self) -> Optional[Any]:
"""Observation of the solver for the current step.
Is ``None`` at the beginning of a step, will be set by the
observation model for a given action.
Is ``None`` at the beginning of a step, will be set by the observation model for
a given action.
"""
return self._observations[self.step]

Expand Down Expand Up @@ -102,8 +101,7 @@ def residuals(self) -> Tuple[np.ndarray, ...]:
def next_step(self) -> None:
"""Advance the solver state to the next solver step.
Called after a completed step / iteration of the probabilistic
linear solver.
Called after a completed step / iteration of the probabilistic linear solver.
"""
self._actions.append(None)
self._observations.append(None)
Expand Down
12 changes: 12 additions & 0 deletions src/probnum/linalg/solvers/belief_updates/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Belief updates for the quantities of interest of a linear system."""

from . import matrix_based, solution_based
from ._linear_solver_belief_update import LinearSolverBeliefUpdate

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"LinearSolverBeliefUpdate",
]

# Set correct module paths. Corrects links and module paths in documentation.
LinearSolverBeliefUpdate.__module__ = "probnum.linalg.solvers.belief_updates"
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Linear system belief updates.
Class defining how a belief over quantities of interest of a linear system is updated
given information about the problem.
"""
import abc

import probnum # pylint: disable="unused-import"
from probnum.linalg.solvers.beliefs import LinearSystemBelief


class LinearSolverBeliefUpdate(abc.ABC):
r"""Belief update for the quantities of interest of a linear system.
Given a solver state containing information about the linear system collected in the current step, update the belief about the quantities of interest.
"""

@abc.abstractmethod
def __call__(
self, solver_state: "probnum.linalg.solvers.LinearSolverState"
) -> LinearSystemBelief:
r"""Update the belief about the quantities of interest of a linear system.
Parameters
----------
solver_state :
Current state of the linear solver.
"""
raise NotImplementedError
21 changes: 21 additions & 0 deletions src/probnum/linalg/solvers/belief_updates/matrix_based/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Matrix-based belief updates for the quantities of interest of a linear system."""

from ._matrix_based_linear_belief_update import MatrixBasedLinearBeliefUpdate
from ._symmetric_matrix_based_linear_belief_update import (
SymmetricMatrixBasedLinearBeliefUpdate,
)

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"MatrixBasedLinearBeliefUpdate",
"SymmetricMatrixBasedLinearBeliefUpdate",
]

# Set correct module paths. Corrects links and module paths in documentation.
MatrixBasedLinearBeliefUpdate.__module__ = (
"probnum.linalg.solvers.belief_updates.matrix_based"
)

SymmetricMatrixBasedLinearBeliefUpdate.__module__ = (
"probnum.linalg.solvers.belief_updates.matrix_based"
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Belief update in a matrix-based inference view where the information is given by
matrix-vector multiplication."""
import numpy as np

import probnum # pylint: disable="unused-import"
from probnum import linops, randvars
from probnum.linalg.solvers.beliefs import LinearSystemBelief

from .._linear_solver_belief_update import LinearSolverBeliefUpdate


class MatrixBasedLinearBeliefUpdate(LinearSolverBeliefUpdate):
r"""Gaussian belief update in a matrix-based inference framework assuming linear information.
Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given matrix-variate Gaussian beliefs with Kronecker covariance structure and linear observations :math:`y=As`. The belief update computes :math:`p(M \mid y) = \mathcal{N}(M; M_{i+1}, V \otimes W_{i+1})`, [1]_ [2]_ such that
.. math ::
\begin{align}
M_{i+1} &= M_i + (y - M_i s) (s^\top W_i s)^\dagger s^\top W_i,\\
W_{i+1} &= W_i - W_i s (s^\top W_i s)^\dagger s^\top W_i.
\end{align}
References
----------
.. [1] Hennig, P., Probabilistic Interpretation of Linear Solvers, *SIAM Journal on
Optimization*, 2015, 25, 234-260
.. [2] Wenger, J. and Hennig, P., Probabilistic Linear Solvers for Machine Learning,
*Advances in Neural Information Processing Systems (NeurIPS)*, 2020
"""

def __call__(
self, solver_state: "probnum.linalg.solvers.LinearSolverState"
) -> LinearSystemBelief:

# Inference for A
A = self._matrix_based_update(
matrix=solver_state.belief.A,
action=solver_state.action,
observ=solver_state.observation,
)

# Inference for Ainv (interpret action and observation as swapped)
Ainv = self._matrix_based_update(
matrix=solver_state.belief.Ainv,
action=solver_state.observation,
observ=solver_state.action,
)
return LinearSystemBelief(A=A, Ainv=Ainv, x=None, b=solver_state.belief.b)

def _matrix_based_update(
self, matrix: randvars.Normal, action: np.ndarray, observ: np.ndarray
) -> randvars.Normal:
"""Matrix-based inference update for linear information."""
if not isinstance(matrix.cov, linops.Kronecker):
raise ValueError(
f"Covariance must have Kronecker structure, but is '{type(matrix.cov).__name__}'."
)

pred = matrix.mean @ action
resid = observ - pred
covfactor_Ms = matrix.cov.B @ action
gram = action.T @ covfactor_Ms
gram_pinv = 1.0 / gram if gram > 0.0 else 0.0
gain = covfactor_Ms * gram_pinv
covfactor_update = linops.aslinop(gain[:, None]) @ linops.aslinop(
covfactor_Ms[None, :]
)
resid_gain = linops.aslinop(resid[:, None]) @ linops.aslinop(
gain[None, :]
) # residual and gain are flipped due to matrix vectorization

return randvars.Normal(
mean=matrix.mean + resid_gain,
cov=linops.Kronecker(A=matrix.cov.A, B=matrix.cov.B - covfactor_update),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Belief update in a matrix-based inference view assuming symmetry where the
information is given by matrix-vector multiplication."""
import numpy as np

import probnum # pylint: disable="unused-import"
from probnum import linops, randvars
from probnum.linalg.solvers.beliefs import LinearSystemBelief

from .._linear_solver_belief_update import LinearSolverBeliefUpdate


class SymmetricMatrixBasedLinearBeliefUpdate(LinearSolverBeliefUpdate):
r"""Symmetric Gaussian belief update in a matrix-based inference framework assuming linear information.
Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given symmetric matrix-variate Gaussian beliefs with symmetric Kronecker covariance structure and linear observations. The belief update computes :math:`p(M \mid y) = \mathcal{N}(M; M_{i+1}, W_{i+1} \otimes_s W_{i+1})`, [1]_ [2]_ such that
.. math ::
\begin{align}
M_{i+1} &= M_i + (y - M_i s) u^\top + u (y - M_i s)^\top - u s^\top(y - M_i s)u^\top,\\
W_{i+1} &= W_i - W_i s (s^\top W_i s)^\dagger s^\top W_i.
\end{align}
where :math:`u = W_i s (s^\top W s)^\dagger`.
References
----------
.. [1] Hennig, P., Probabilistic Interpretation of Linear Solvers, *SIAM Journal on
Optimization*, 2015, 25, 234-260
.. [2] Wenger, J. and Hennig, P., Probabilistic Linear Solvers for Machine Learning,
*Advances in Neural Information Processing Systems (NeurIPS)*, 2020
"""

def __call__(
self, solver_state: "probnum.linalg.solvers.LinearSolverState"
) -> LinearSystemBelief:

# Inference for A
A = self._symmetric_matrix_based_update(
matrix=solver_state.belief.A,
action=solver_state.action,
observ=solver_state.observation,
)

# Inference for Ainv (interpret action and observation as swapped)
Ainv = self._symmetric_matrix_based_update(
matrix=solver_state.belief.Ainv,
action=solver_state.observation,
observ=solver_state.action,
)
return LinearSystemBelief(A=A, Ainv=Ainv, x=None, b=solver_state.belief.b)

def _symmetric_matrix_based_update(
self, matrix: randvars.Normal, action: np.ndarray, observ: np.ndarray
) -> randvars.Normal:
"""Symmetric matrix-based inference update for linear information."""
if not isinstance(matrix.cov, linops.SymmetricKronecker):
raise ValueError(
f"Covariance must have symmetric Kronecker structure, but is '{type(matrix.cov).__name__}'."
)

pred = matrix.mean @ action
resid = observ - pred
covfactor_Ms = matrix.cov.A @ action
gram = action.T @ covfactor_Ms
gram_pinv = 1.0 / gram if gram > 0.0 else 0.0
gain = covfactor_Ms * gram_pinv
covfactor_update = gain @ covfactor_Ms.T
resid_gain = linops.aslinop(resid[:, None]) @ linops.aslinop(gain[None, :])

return randvars.Normal(
mean=matrix.mean
+ resid_gain
+ resid_gain.T
- linops.aslinop(gain[:, None])
@ linops.aslinop((action.T @ resid_gain)[None, :]),
cov=linops.SymmetricKronecker(A=matrix.cov.A - covfactor_update),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Solution-based belief updates for the quantities of interest of a linear system."""

from ._solution_based_proj_rhs_belief_update import (
SolutionBasedProjectedRHSBeliefUpdate,
)

# Public classes and functions. Order is reflected in documentation.
__all__ = [
"SolutionBasedProjectedRHSBeliefUpdate",
]

# Set correct module paths. Corrects links and module paths in documentation.
SolutionBasedProjectedRHSBeliefUpdate.__module__ = (
"probnum.linalg.solvers.belief_updates.solution_based"
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Belief update in a solution-based inference view where the information is given by
projecting the current residual to a subspace."""
import probnum # pylint: disable="unused-import"
from probnum import randvars
from probnum.linalg.solvers.beliefs import LinearSystemBelief
from probnum.typing import FloatArgType

from .._linear_solver_belief_update import LinearSolverBeliefUpdate


class SolutionBasedProjectedRHSBeliefUpdate(LinearSolverBeliefUpdate):
r"""Gaussian belief update in a solution-based inference framework assuming projected right-hand-side information.
Updates the belief over the quantities of interest of a linear system :math:`Ax=b` given a Gaussian belief over the solution :math:`x` and information of the form :math:`y = s\^top b=s^\top Ax`. The belief update computes the posterior belief about the solution, given by :math:`p(x \mid y) = \mathcal{N}(x; x_{i+1}, \Sigma_{i+1})`, [1]_ such that
.. math ::
\begin{align}
x_{i+1} &= x_i + \Sigma_i A^\top s (s^\top A \Sigma_i A^\top s + \lambda)^\dagger s^\top (b - Ax_i),\\
\Sigma_{i+1} &= \Sigma_i - \Sigma_i A^\top s (s^\top A \Sigma_i A s + \lambda)^\dagger s^\top A \Sigma_i,
\end{align}
where :math:`\lambda` is the noise variance.
Parameters
----------
noise_var :
Variance of the scalar observation noise.
References
----------
.. [1] Cockayne, J. et al., A Bayesian Conjugate Gradient Method, *Bayesian
Analysis*, 2019, 14, 937-1012
"""

def __init__(self, noise_var: FloatArgType = 0.0) -> None:
if noise_var < 0.0:
raise ValueError(f"Noise variance {noise_var} must be non-negative.")
self._noise_var = noise_var

def __call__(
self, solver_state: "probnum.linalg.solvers.LinearSolverState"
) -> LinearSystemBelief:

action_A = solver_state.action @ solver_state.problem.A
pred = action_A @ solver_state.belief.x.mean
proj_resid = solver_state.observation - pred
cov_xy = solver_state.belief.x.cov @ action_A.T
gram = action_A @ cov_xy + self._noise_var
gram_pinv = 1.0 / gram if gram > 0.0 else 0.0
gain = cov_xy * gram_pinv
cov_update = gain @ cov_xy.T

x = randvars.Normal(
mean=solver_state.belief.x.mean + gain * proj_resid,
cov=solver_state.belief.x.cov - cov_update,
)
Ainv = solver_state.belief.Ainv + cov_update

return LinearSystemBelief(
x=x, A=solver_state.belief.A, Ainv=Ainv, b=solver_state.belief.b
)

0 comments on commit 3662bf5

Please sign in to comment.