Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Base functions able to process ndarray of SPD matrices #186

Merged
merged 4 commits into from Jul 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whatsnew.rst
Expand Up @@ -68,6 +68,8 @@ v0.2.8.dev

- Fix dispersion when generating datasets: :func:`pyriemann.datasets.sample_gaussian_spd`

- Improve base functions to process ndarrays of SPD matrices

v0.2.7 (June 2021)
------------------

Expand Down
73 changes: 39 additions & 34 deletions pyriemann/utils/base.py
@@ -1,25 +1,30 @@
import numpy as np
from scipy.linalg import eigh
"""Base functions for SPD matrices."""

import numpy as np
from numpy.core.numerictypes import typecodes


def _matrix_operator(C, operator):
"""Matrix equivalent of an operator."""
if C.dtype.char in typecodes['AllFloat'] and not np.isfinite(C).all():
"""Matrix function."""
if not isinstance(C, np.ndarray) or C.ndim < 2:
raise ValueError('Input must be at least a 2D ndarray')
if C.dtype.char in typecodes['AllFloat'] and (
np.isinf(C).any() or np.isnan(C).any()):
raise ValueError(
"Covariance matrices must be positive definite. Add "
"Matrices must be positive definite. Add "
"regularization to avoid this error.")
eigvals, eigvects = eigh(C, check_finite=False)
eigvals = np.diag(operator(eigvals))
D = eigvects @ eigvals @ eigvects.T
eigvals, eigvecs = np.linalg.eigh(C)
eigvals = operator(eigvals)
if C.ndim >= 3:
eigvals = np.expand_dims(eigvals, -2)
D = (eigvecs * eigvals) @ np.swapaxes(eigvecs, -2, -1)
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved
return D


def sqrtm(C):
r""" Square root of SPD matrix.
r"""Square root of SPD matrices.

Return the matrix square root of a SPD matrix defined by:
The matrix square root of a SPD matrix C is defined by:

.. math::
\mathbf{D} =
Expand All @@ -30,21 +35,21 @@ def sqrtm(C):

Parameters
----------
C : ndarray, shape (n, n)
SPD matrix.
C : ndarray, shape (..., n, n)
SPD matrices, at least 2D ndarray.

Returns
-------
D : ndarray, shape (n, n)
D : ndarray, shape (..., n, n)
Matrix square root of C.
"""
return _matrix_operator(C, np.sqrt)


def logm(C):
r""" Logarithm of SPD matrix.
r"""Logarithm of SPD matrices.

Return the matrix logarithm of a SPD matrix defined by:
The matrix logarithm of a SPD matrix C is defined by:

.. math::
\mathbf{D} = \mathbf{V} \log{(\mathbf{\Lambda})} \mathbf{V}^\top
Expand All @@ -54,21 +59,21 @@ def logm(C):

Parameters
----------
C : ndarray, shape (n, n)
SPD matrix.
C : ndarray, shape (..., n, n)
SPD matrices, at least 2D ndarray.

Returns
-------
D : ndarray, shape (n, n)
D : ndarray, shape (..., n, n)
Matrix logarithm of C.
"""
return _matrix_operator(C, np.log)


def expm(C):
r""" Exponential of SPD matrix.
r"""Exponential of SPD matrices.

Return the matrix exponential of a SPD matrix defined by:
The matrix exponential of a SPD matrix C is defined by:

.. math::
\mathbf{D} = \mathbf{V} \exp{(\mathbf{\Lambda})} \mathbf{V}^\top
Expand All @@ -78,21 +83,21 @@ def expm(C):

Parameters
----------
C : ndarray, shape (n, n)
SPD matrix.
C : ndarray, shape (..., n, n)
SPD matrices, at least 2D ndarray.

Returns
-------
D : ndarray, shape (n, n)
D : ndarray, shape (..., n, n)
Matrix exponential of C.
"""
return _matrix_operator(C, np.exp)


def invsqrtm(C):
r""" Inverse square root of SPD matrix.
r"""Inverse square root of SPD matrices.

Return the inverse matrix square root of a SPD matrix defined by:
The matrix inverse square root of a SPD matrix C is defined by:

.. math::
\mathbf{D} =
Expand All @@ -103,22 +108,22 @@ def invsqrtm(C):

Parameters
----------
C : ndarray, shape (n, n)
SPD matrix.
C : ndarray, shape (..., n, n)
SPD matrices, at least 2D ndarray.

Returns
-------
D : ndarray, shape (n, n)
Inverse matrix square root of C.
D : ndarray, shape (..., n, n)
Matrix inverse square root of C.
"""
def isqrt(x): return 1. / np.sqrt(x)
return _matrix_operator(C, isqrt)


def powm(C, alpha):
r""" Power of SPD matrix.
r"""Power of SPD matrices.

Return the matrix power :math:`\alpha` of a SPD matrix defined by:
The matrix power :math:`\alpha` of a SPD matrix C is defined by:

.. math::
\mathbf{D} =
Expand All @@ -129,14 +134,14 @@ def powm(C, alpha):

Parameters
----------
C : ndarray, shape (n, n)
SPD matrix.
C : ndarray, shape (..., n, n)
SPD matrices, at least 2D ndarray.
alpha : float
The power to apply.

Returns
-------
D : ndarray, shape (n, n)
D : ndarray, shape (..., n, n)
Matrix power of C.
"""
def power(x): return x**alpha
Expand Down
27 changes: 27 additions & 0 deletions tests/test_utils_base.py
Expand Up @@ -49,3 +49,30 @@ def test_check_raise():
with pytest.warns(RuntimeWarning):
with pytest.raises(ValueError):
mean_riemann(C)


def test_funm_error():
with pytest.raises(ValueError):
sqrtm(np.ones(5))
with pytest.raises(ValueError):
invsqrtm(5.1)
with pytest.raises(ValueError):
logm([5.2])


@pytest.mark.parametrize("funm", [sqrtm, invsqrtm, logm, expm, powm])
def test_funm_ndarray(funm):
def test(funm, C):
if funm == powm:
D = funm(C, 0.2)
else:
D = funm(C)
assert C.shape == D.shape

n_matrices, n_channels = 6, 3
C_3d = np.asarray([np.eye(n_channels) for _ in range(n_matrices)])
test(funm, C_3d)

n_sets = 5
C_4d = np.asarray([C_3d for _ in range(n_sets)])
test(funm, C_4d)