Skip to content

Commit

Permalink
use numpy eigh
Browse files Browse the repository at this point in the history
  • Loading branch information
qbarthelemy committed Jul 8, 2022
1 parent d5d370b commit d34def6
Showing 1 changed file with 34 additions and 24 deletions.
58 changes: 34 additions & 24 deletions pyriemann/utils/base.py
@@ -1,36 +1,46 @@
"""Base functions for SPD matrices."""

import numpy as np
from scipy.linalg import eigh

from numpy.core.numerictypes import typecodes


def _diag(X):
"""Construct ndarray from diagonal terms.
Parameters
----------
X : ndarray, shape (..., n)
Diagonal terms of ndarray.
Returns
-------
Y : ndarray, shape (..., n, n)
Diagonal ndarray.
"""
dims = X.shape
n = dims[-1]
Y = np.zeros((*dims, n), X.dtype)
Y[..., np.arange(n), np.arange(n)] = X
return Y


def _matrix_operator(C, operator):
"""Matrix function."""
if not isinstance(C, np.ndarray) or C.ndim < 2:
raise ValueError('Input must be at least a 2D ndarray')
if C.dtype.char in typecodes['AllFloat'] and not np.isfinite(C).all():
raise ValueError(
"Matrices must be positive definite. Add "
"regularization to avoid this error.")
eigvals, eigvects = eigh(C, check_finite=False)
D = (eigvects * operator(eigvals)[np.newaxis, :]) @ eigvects.T
eigvals, eigvecs = np.linalg.eigh(C)
D = eigvecs @ _diag(operator(eigvals)) @ np.swapaxes(eigvecs, -2, -1)
return D


def _matrices_operator(C, operator):
"""Recursive matrix function."""
if not isinstance(C, np.ndarray) or C.ndim < 2:
raise ValueError('Input must be at least a 2D ndarray')
elif C.ndim == 2:
return _matrix_operator(C, operator)
else:
return np.asarray([_matrices_operator(c, operator) for c in C])


def sqrtm(C):
r"""Square root of SPD matrices.
The matrix square root of a SPD matrix is defined by:
The matrix square root of a SPD matrix C is defined by:
.. math::
\mathbf{D} =
Expand All @@ -49,13 +59,13 @@ def sqrtm(C):
D : ndarray, shape (..., n, n)
Matrix square root of C.
"""
return _matrices_operator(C, np.sqrt)
return _matrix_operator(C, np.sqrt)


def logm(C):
r"""Logarithm of SPD matrices.
The matrix logarithm of a SPD matrix is defined by:
The matrix logarithm of a SPD matrix C is defined by:
.. math::
\mathbf{D} = \mathbf{V} \log{(\mathbf{\Lambda})} \mathbf{V}^\top
Expand All @@ -73,13 +83,13 @@ def logm(C):
D : ndarray, shape (..., n, n)
Matrix logarithm of C.
"""
return _matrices_operator(C, np.log)
return _matrix_operator(C, np.log)


def expm(C):
r"""Exponential of SPD matrices.
The matrix exponential of a SPD matrix is defined by:
The matrix exponential of a SPD matrix C is defined by:
.. math::
\mathbf{D} = \mathbf{V} \exp{(\mathbf{\Lambda})} \mathbf{V}^\top
Expand All @@ -97,13 +107,13 @@ def expm(C):
D : ndarray, shape (..., n, n)
Matrix exponential of C.
"""
return _matrices_operator(C, np.exp)
return _matrix_operator(C, np.exp)


def invsqrtm(C):
r"""Inverse square root of SPD matrices.
The matrix inverse square root of a SPD matrix is defined by:
The matrix inverse square root of a SPD matrix C is defined by:
.. math::
\mathbf{D} =
Expand All @@ -123,13 +133,13 @@ def invsqrtm(C):
Matrix inverse square root of C.
"""
def isqrt(x): return 1. / np.sqrt(x)
return _matrices_operator(C, isqrt)
return _matrix_operator(C, isqrt)


def powm(C, alpha):
r"""Power of SPD matrices.
The matrix power :math:`\alpha` of a SPD matrix is defined by:
The matrix power :math:`\alpha` of a SPD matrix C is defined by:
.. math::
\mathbf{D} =
Expand All @@ -151,4 +161,4 @@ def powm(C, alpha):
Matrix power of C.
"""
def power(x): return x**alpha
return _matrices_operator(C, power)
return _matrix_operator(C, power)

0 comments on commit d34def6

Please sign in to comment.