Skip to content

Commit

Permalink
Remove broadcast_rmat{vec,mul}
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Nov 8, 2022
1 parent 9beb719 commit 4b7a49e
Showing 1 changed file with 1 addition and 28 deletions.
29 changes: 1 addition & 28 deletions src/probnum/linops/_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ def __rmatmul__(
return matmul(other, self)

####################################################################################
# Automatic `(r)mat{vec,mat}`` to `(r)matmul` Broadcasting
# Automatic `mat{vec,mat}`` to `matmul` Broadcasting
####################################################################################

@classmethod
Expand Down Expand Up @@ -1060,30 +1060,6 @@ def _matmul(x: np.ndarray) -> np.ndarray:

return _matmul

@classmethod
def broadcast_rmatvec(
cls, rmatvec: Callable[[np.ndarray], np.ndarray]
) -> Callable[[np.ndarray], np.ndarray]:
def _rmatmul(x: np.ndarray) -> np.ndarray:
if x.ndim == 2 and x.shape[0] == 1:
return rmatvec(x[0, :])[np.newaxis, :]

return np.apply_along_axis(rmatvec, -1, x)

return _rmatmul

@classmethod
def broadcast_rmatmat(
cls, rmatmat: Callable[[np.ndarray], np.ndarray]
) -> Callable[[np.ndarray], np.ndarray]:
def _rmatmul(x: np.ndarray) -> np.ndarray:
if x.ndim == 2:
return rmatmat(x)

return _apply_to_matrix_stack(rmatmat, x)

return _rmatmul

@property
def _inexact_dtype(self) -> np.dtype:
if np.issubdtype(self.dtype, np.inexact):
Expand Down Expand Up @@ -1495,7 +1471,6 @@ def __init__(
dtype = self.A.dtype

matmul = LinearOperator.broadcast_matmat(lambda x: self.A @ x)
rmatmul = LinearOperator.broadcast_rmatmat(lambda x: x @ self.A)
todense = self.A.toarray
inverse = self._sparse_inv
trace = lambda: self.A.diagonal().sum()
Expand All @@ -1506,7 +1481,6 @@ def __init__(
dtype = self.A.dtype

matmul = lambda x: self.A @ x
rmatmul = lambda x: x @ self.A
todense = lambda: self.A
inverse = None
trace = lambda: np.trace(self.A)
Expand All @@ -1517,7 +1491,6 @@ def __init__(
shape,
dtype,
matmul=matmul,
rmatmul=rmatmul,
todense=todense,
transpose=transpose,
inverse=inverse,
Expand Down

0 comments on commit 4b7a49e

Please sign in to comment.