Refactor LinearOperator for method overriding #735

Merged
32 changes: 20 additions & 12 deletions docs/source/tutorials/linops/linear_operators_quickstart.ipynb
@@ -249,7 +249,7 @@
{
"data": {
"text/plain": [
"<SumLinearOperator with shape=(5, 5) and dtype=float64>"
"<Matrix with shape=(5, 5) and dtype=float64>"
]
},
"execution_count": 9,
@@ -269,7 +269,10 @@
{
"data": {
"text/plain": [
"<ProductLinearOperator with shape=(5, 5) and dtype=float64>"
"ProductLinearOperator [\n",
"\t<Matrix with shape=(5, 5) and dtype=float64>, \n",
"\t<Matrix with shape=(5, 5) and dtype=float64>, \n",
"]"
]
},
"execution_count": 10,
@@ -311,10 +314,10 @@
{
"data": {
"text/plain": [
"array([-1.39282086, -2.09807924, -1.01469708, -0.74204673, -3.26963901,\n",
" -0.92439367, -0.65638407, 0.43823505, 0.66964627, -0.316306 ,\n",
" 5.7153326 , 0.43495681, 0.46390134, -2.66045433, 0.62615866,\n",
" 0.00715237, -0.83637837, -0.95389845, -0.41350942, -1.23499484])"
"array([ 1.49421769, -1.35451937, 1.05551543, -0.41823967, 0.42934955,\n",
" -0.82155968, -1.93141743, -4.31860989, -1.70475714, 4.36385187,\n",
" 2.36850628, -2.94034717, 0.39821307, -1.08656905, 0.36490375,\n",
" -0.86441656, -0.44778464, -0.44155178, 0.55687361, 0.17178464])"
]
},
"execution_count": 11,
@@ -361,7 +364,7 @@
{
"data": {
"text/plain": [
"<LinearOperator with shape=(5, 5) and dtype=float64>"
"<LambdaLinearOperator with shape=(5, 5) and dtype=float64>"
]
},
"execution_count": 12,
@@ -370,14 +373,14 @@
}
],
"source": [
"from probnum.linops import LinearOperator\n",
"from probnum.linops import LinearOperator, LambdaLinearOperator\n",
"\n",
"@LinearOperator.broadcast_matvec\n",
"def mv(v):\n",
" return np.roll(v, 1)\n",
"\n",
"n = 5\n",
"P_op = LinearOperator(shape=(n, n), dtype=np.float_, matmul=mv)\n",
"P_op = LambdaLinearOperator(shape=(n, n), dtype=np.float_, matmul=mv)\n",
"x = np.arange(0., n, 1)\n",
"\n",
"P_op"
@@ -509,7 +512,7 @@
"def mv(v):\n",
" return v[:n-1]\n",
"\n",
"Pr = LinearOperator(shape=(n-1, n), dtype=np.float_, matmul=mv)\n",
"Pr = LambdaLinearOperator(shape=(n-1, n), dtype=np.float_, matmul=mv)\n",
"\n",
"# Apply the operator to the 3D normal random variable\n",
"rv_projected = Pr @ rv"
@@ -602,7 +605,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.10.6 (conda)",
"language": "python",
"name": "python3"
},
@@ -616,7 +619,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
"version": "3.10.6"
},
"vscode": {
"interpreter": {
"hash": "0457b12441837086dec1b475e0008c28e5fc37f4ffe0e5ee9f2b481cc28bc3c9"
}
}
},
"nbformat": 4,
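For reference, the updated notebook cell above amounts to the following self-contained sketch of the renamed constructor (np.float64 stands in for the notebook's np.float_, which newer NumPy releases remove):

import numpy as np
from probnum.linops import LinearOperator, LambdaLinearOperator

# Matrix-free matrix-vector product: cyclically shift the entries of v by one.
@LinearOperator.broadcast_matvec
def mv(v):
    return np.roll(v, 1)

n = 5
# The callable-based constructor now lives in LambdaLinearOperator;
# LinearOperator itself is meant to be subclassed with overridden methods.
P_op = LambdaLinearOperator(shape=(n, n), dtype=np.float64, matmul=mv)

x = np.arange(0.0, n, 1)
print(P_op @ x)  # [4. 0. 1. 2. 3.]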
18 changes: 9 additions & 9 deletions src/probnum/linalg/solvers/matrixbased.py
@@ -109,14 +109,14 @@ def _matmul(M):

Ainv0_mean = linops.Scaling(
alpha, shape=(self.n, self.n)
) + 2 / bx0 * linops.LinearOperator(
) + 2 / bx0 * linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(x0.dtype, alpha.dtype, b.dtype),
matmul=_matmul,
)
A0_mean = linops.Scaling(1 / alpha, shape=(self.n, self.n)) - 1 / (
alpha * np.squeeze((x0 - alpha * b).T @ x0)
) * linops.LinearOperator(
) * linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(x0.dtype, alpha.dtype, b.dtype),
matmul=_matmul,
@@ -632,7 +632,7 @@ def null_space_proj(x):

# Compute calibration term in the A view as a linear operator with scaling from
# degrees of freedom
calibration_term_A = linops.LinearOperator(
calibration_term_A = linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=S.dtype,
matmul=linops.LinearOperator.broadcast_matvec(
@@ -642,7 +642,7 @@ def null_space_proj(x):

# Compute calibration term in the Ainv view as a linear operator with scaling
# from degrees of freedom
calibration_term_Ainv = linops.LinearOperator(
calibration_term_Ainv = linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=S.dtype,
matmul=linops.LinearOperator.broadcast_matvec(
@@ -669,7 +669,7 @@ def _matmul(x):
# First term of calibration covariance class: AS(S'AS)^{-1}S'A
return (Y * sy**-1) @ (Y.T @ x.ravel())

_A_covfactor0 = linops.LinearOperator(
_A_covfactor0 = linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(Y, sy),
matmul=_matmul,
@@ -686,7 +686,7 @@ def _matmul(x):
)
return self.Ainv_mean0 @ (Y @ YAinv0Y_inv_YAinv0x)

_Ainv_covfactor0 = linops.LinearOperator(
_Ainv_covfactor0 = linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(Y, self.Ainv_mean0),
matmul=_matmul,
@@ -733,7 +733,7 @@ def _matmul(x):
def _matmul(x):
return 0.5 * (bWb * _Ainv_covfactor @ x + Wb @ (Wb.T @ x))

cov_op = linops.LinearOperator(
cov_op = linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(Wb.dtype, bWb.dtype),
matmul=_matmul,
@@ -755,7 +755,7 @@ def _mean_update(self, u, v):
def _matmul(x):
return u @ (v.T @ x) + v @ (u.T @ x)

return linops.LinearOperator(
return linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(u.dtype, v.dtype),
matmul=_matmul,
@@ -768,7 +768,7 @@ def _covariance_update(self, u, Ws):
def _matmul(x):
return Ws @ (u.T @ x)

return linops.LinearOperator(
return linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(u.dtype, Ws.dtype),
matmul=_matmul,
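The pattern throughout this file is the same: a closure implementing the matrix-vector product is handed to LambdaLinearOperator instead of LinearOperator. A standalone sketch of the symmetric rank-two update from _mean_update above (rank_two_update is a hypothetical helper, not part of probnum; u and v are assumed to be (n, 1) column vectors):

import numpy as np
from probnum import linops

def rank_two_update(u: np.ndarray, v: np.ndarray) -> linops.LinearOperator:
    # Matrix-free representation of u v^T + v u^T, mirroring _mean_update.
    n = u.shape[0]

    def _matmul(x):
        return u @ (v.T @ x) + v @ (u.T @ x)

    return linops.LambdaLinearOperator(
        shape=(n, n),
        dtype=np.result_type(u.dtype, v.dtype),
        matmul=_matmul,
    )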
11 changes: 10 additions & 1 deletion src/probnum/linops/__init__.py
@@ -13,7 +13,14 @@
"""

from ._kronecker import IdentityKronecker, Kronecker, SymmetricKronecker, Symmetrize
from ._linear_operator import Embedding, Identity, LinearOperator, Matrix, Selection
from ._linear_operator import (
Embedding,
Identity,
LambdaLinearOperator,
LinearOperator,
Matrix,
Selection,
)
from ._scaling import Scaling, Zero
from ._utils import LinearOperatorLike, aslinop

@@ -22,6 +29,7 @@
"aslinop",
"Embedding",
"LinearOperator",
"LambdaLinearOperator",
"Matrix",
"Identity",
"IdentityKronecker",
@@ -35,6 +43,7 @@

# Set correct module paths. Corrects links and module paths in documentation.
LinearOperator.__module__ = "probnum.linops"
LambdaLinearOperator.__module__ = "probnum.linops"
Embedding.__module__ = "probnum.linops"
Matrix.__module__ = "probnum.linops"
Identity.__module__ = "probnum.linops"
8 changes: 4 additions & 4 deletions src/probnum/linops/_arithmetic_fallbacks.py
@@ -10,14 +10,14 @@
from probnum.typing import NotImplementedType, ScalarLike
import probnum.utils

from ._linear_operator import BinaryOperandType, LinearOperator
from ._linear_operator import BinaryOperandType, LambdaLinearOperator, LinearOperator

########################################################################################
# Generic Linear Operator Arithmetic (Fallbacks)
########################################################################################


class ScaledLinearOperator(LinearOperator):
class ScaledLinearOperator(LambdaLinearOperator):
"""Linear operator scaled with a scalar."""

def __init__(self, linop: LinearOperator, scalar: ScalarLike):
@@ -81,7 +81,7 @@ def __repr__(self) -> str:
return f"-{self._linop}"


class SumLinearOperator(LinearOperator):
class SumLinearOperator(LambdaLinearOperator):
"""Sum of linear operators."""

def __init__(self, *summands: LinearOperator):
@@ -166,7 +166,7 @@ def _mul_fallback(
return res


class ProductLinearOperator(LinearOperator):
class ProductLinearOperator(LambdaLinearOperator):
"""(Operator) Product of linear operators."""

def __init__(self, *factors: LinearOperator):
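These fallback classes are what operator arithmetic returns when no specialized rule applies; now that they derive from LambdaLinearOperator, they keep passing their matmul callables to the parent constructor. A small sketch (the operands are illustrative and not taken from the hidden notebook cells):

import numpy as np
from probnum.linops import Matrix

rng = np.random.default_rng(0)
A = Matrix(rng.normal(size=(5, 5)))
B = Matrix(rng.normal(size=(5, 5)))

prod_op = A @ B              # falls back to ProductLinearOperator; its repr lists the factors
print(prod_op)
print(prod_op @ np.ones(5))  # applies B first, then A, without forming the dense product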
64 changes: 7 additions & 57 deletions src/probnum/linops/_kronecker.py
@@ -10,7 +10,7 @@
from . import _linear_operator, _utils


class Symmetrize(_linear_operator.LinearOperator):
class Symmetrize(_linear_operator.LambdaLinearOperator):
r"""Symmetrizes a vector in its matrix representation.

Given a vector :math:`x=\operatorname{vec}(X)`
@@ -65,7 +65,7 @@ def _matmul(self, x: np.ndarray) -> np.ndarray:
)


class Kronecker(_linear_operator.LinearOperator):
class Kronecker(_linear_operator.LambdaLinearOperator):
"""Kronecker product of two linear operators.

The Kronecker product [1]_ :math:`A \\otimes B` of two linear operators :math:`A`
@@ -136,7 +136,6 @@ def __init__(self, A: LinearOperatorLike, B: LinearOperatorLike):
self.A.shape[1] * self.B.shape[1],
),
matmul=lambda x: _kronecker_matmul(self.A, self.B, x),
rmatmul=lambda x: _kronecker_rmatmul(self.A, self.B, x),
todense=lambda: np.kron(
self.A.todense(cache=False), self.B.todense(cache=False)
),
@@ -260,29 +259,7 @@ def _kronecker_matmul(
return y


def _kronecker_rmatmul(
A: _linear_operator.LinearOperator,
B: _linear_operator.LinearOperator,
x: np.ndarray,
) -> np.ndarray:
# Reshape into stack of matrices
y = x

if not y.flags.c_contiguous:
y = y.copy(order="C")

y = y.reshape(y.shape[:-1] + (A.shape[0], B.shape[0]))

# ((A.T) @ X) @ (B.T).T
y = (A.T @ y) @ B

# Revert to stack of vectorized matrices
y = y.reshape(y.shape[:-2] + (-1,))

return y


class SymmetricKronecker(_linear_operator.LinearOperator):
class SymmetricKronecker(_linear_operator.LambdaLinearOperator):
"""Symmetric Kronecker product of two linear operators.

The symmetric Kronecker product [1]_ :math:`A \\otimes_{s} B` of two square linear
@@ -337,7 +314,6 @@ def __init__(

dtype = self.A.dtype
matmul = lambda x: _kronecker_matmul(self.A, self.A, x)
rmatmul = lambda x: _kronecker_rmatmul(self.A, self.A, x)
todense = self._todense_identical_factors
# (A (x)_s A)^T = A^T (x)_s A^T
transpose = lambda: SymmetricKronecker(A=self.A.T)
@@ -357,7 +333,6 @@ def __init__(

dtype = np.result_type(self.A.dtype, self.B.dtype, 0.5)
matmul = self._matmul_different_factors
rmatmul = self._rmatmul_different_factors
todense = self._todense_different_factors
# (A (x)_s B)^T = A^T (x)_s B^T
transpose = lambda: SymmetricKronecker(A=self.A.T, B=self.B.T)
@@ -371,7 +346,6 @@ def __init__(
dtype=dtype,
shape=2 * (self._n**2,),
matmul=matmul,
rmatmul=rmatmul,
todense=todense,
transpose=transpose,
inverse=inverse,
@@ -441,29 +415,6 @@ def _matmul_different_factors(self, x: np.ndarray) -> np.ndarray:

return y

def _rmatmul_different_factors(self, x: np.ndarray) -> np.ndarray:
# Reshape into stack of matrices
y = x

if not y.flags.c_contiguous:
y = y.copy(order="C")

y = y.reshape(y.shape[:-1] + (self._n, self._n))

# (A.T) @ X @ (B.T).T
y1 = (self.A.T @ y) @ self.B

# (B.T) @ X @ (A.T).T
y2 = (self.B.T @ y) @ self.A

# 1/2 ((A^T)X(B^T)^T + (B^T)X(A^T)^T)
y = 0.5 * (y1 + y2)

# Revert to stack of vectorized matrices
y = y.reshape(y.shape[:-2] + (-1,))

return y

def _todense_identical_factors(self) -> np.ndarray:
"""Dense representation of the symmetric Kronecker product."""
# 1/2 (A (x) B + B (x) A)
@@ -498,7 +449,7 @@ def _symmetrize(self) -> SymmetricKronecker:
return SymmetricKronecker(A=self.A.symmetrize(), B=self.B.symmetrize())


class IdentityKronecker(_linear_operator.LinearOperator):
class IdentityKronecker(_linear_operator.LambdaLinearOperator):
"""Block-diagonal linear operator.

Parameters
@@ -533,9 +484,6 @@ def __init__(self, num_blocks: int, B: LinearOperatorLike):
matmul=lambda x: _kronecker_matmul(
self.A, self.B, x
), # TODO: can be implemented more efficiently
rmatmul=lambda x: _kronecker_rmatmul(
self.A, self.B, x
), # TODO: can be implemented more efficiently
todense=lambda: np.kron(
self.A.todense(cache=False), self.B.todense(cache=False)
),
@@ -589,7 +537,9 @@ def _sub_idkronecker(

return NotImplemented

def _cond(self, p) -> np.inexact:
def _cond(
self, p: Optional[Union[None, int, str, np.floating]] = None
) -> np.number:
if p is None or p in (2, 1, np.inf, "fro", -2, -1, -np.inf):
return self.A.cond(p=p) * self.B.cond(p=p)

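For context, both the retained _kronecker_matmul and the removed _kronecker_rmatmul rest on the vectorization identity :math:`(A \otimes B)\,\operatorname{vec}(X) = \operatorname{vec}(A X B^\mathsf{T})`, stated here for the row-major (C-ordered) vectorization that the reshape calls assume. The right-multiplication then reduces to :math:`\operatorname{vec}(A^\mathsf{T} X B)`, which is exactly the `(A.T @ y) @ B` computation in the deleted helper; after this refactor the same result is presumably recovered through the operator's transpose, since :math:`(A \otimes B)^\mathsf{T} = A^\mathsf{T} \otimes B^\mathsf{T}`.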