Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance 'lwf', 'mcd', 'oas' and 'sch' covariance estimator to process complex inputs #274

Merged
merged 13 commits into from Dec 15, 2023
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -10,3 +10,4 @@
*.egg-info/
.vscode/
sandbox/
.idea/
gabelstein marked this conversation as resolved.
Show resolved Hide resolved
117 changes: 97 additions & 20 deletions pyriemann/utils/covariance.py
Expand Up @@ -11,41 +11,69 @@

def _lwf(X, **kwds):
"""Wrapper for sklearn ledoit wolf covariance estimator"""
if not is_real_type(X):
raise ValueError("Input must be real-valued.")
iscomplex = np.iscomplexobj(X)
if iscomplex:
X = np.concatenate((X.real, X.imag), axis=0)
C, _ = ledoit_wolf(X.T, **kwds)
if iscomplex:
C = _make_complex_covariance(C)
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved
return C


def _mcd(X, **kwds):
"""Wrapper for sklearn mcd covariance estimator"""
if not is_real_type(X):
raise ValueError("Input must be real-valued.")
iscomplex = np.iscomplexobj(X)
if iscomplex:
X = np.concatenate((X.real, X.imag), axis=0)
_, C, _, _ = fast_mcd(X.T, **kwds)
if iscomplex:
C = _make_complex_covariance(C)
return C


def _oas(X, **kwds):
"""Wrapper for sklearn oas covariance estimator"""
if not is_real_type(X):
raise ValueError("Input must be real-valued.")
iscomplex = np.iscomplexobj(X)
if iscomplex:
X = np.concatenate((X.real, X.imag), axis=0)
C, _ = oas(X.T, **kwds)
if iscomplex:
C = _make_complex_covariance(C)
return C


def _hub(X, **kwds):
"""Wrapper for Huber's M-estimator"""
return covariance_mest(X, 'hub', **kwds)
iscomplex = np.iscomplexobj(X)
if iscomplex:
X = np.concatenate((X.real, X.imag), axis=0)
C = covariance_mest(X, 'hub', **kwds)
if iscomplex:
C = _make_complex_covariance(C)
return C


def _stu(X, **kwds):
"""Wrapper for Student-t's M-estimator"""
return covariance_mest(X, 'stu', **kwds)
iscomplex = np.iscomplexobj(X)
if iscomplex:
X = np.concatenate((X.real, X.imag), axis=0)
C = covariance_mest(X, 'stu', **kwds)
if iscomplex:
C = _make_complex_covariance(C)

return C


def _tyl(X, **kwds):
"""Wrapper for Tyler's M-estimator"""
return covariance_mest(X, 'tyl', **kwds)
iscomplex = np.iscomplexobj(X)
if iscomplex:
X = np.concatenate((X.real, X.imag), axis=0)
C = covariance_mest(X, 'tyl', **kwds)
if iscomplex:
C = _make_complex_covariance(C)
return C
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved


def covariance_mest(X, m_estimator, *, init=None, tol=10e-3, n_iter_max=50,
Expand Down Expand Up @@ -218,8 +246,10 @@ def covariance_sch(X):
J. Schafer, and K. Strimmer. Statistical Applications in Genetics and
Molecular Biology, Volume 4, Issue 1, 2005.
"""
if not is_real_type(X):
raise ValueError("Input must be real-valued.")
iscomplex = np.iscomplexobj(X)
if iscomplex:
X = np.concatenate((X.real, X.imag), axis=0)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sylvchev , do you think that sch estimator could be easily generalized to complex-valued input?

_, n_times = X.shape
X_c = X - X.mean(axis=1, keepdims=True)
C_scm = X_c @ X_c.T / n_times
Expand All @@ -237,7 +267,10 @@ def covariance_sch(X):

sigma = (1. - gamma) * (n_times / (n_times - 1.)) * C_scm
shrinkage = gamma * (n_times / (n_times - 1.)) * np.diag(np.diag(C_scm))
return sigma + shrinkage
C = sigma + shrinkage
if iscomplex:
C = _make_complex_covariance(C)
return C


def covariance_scm(X, *, assume_centered=False):
Expand Down Expand Up @@ -311,6 +344,10 @@ def _check_cov_est_function(est):

def covariances(X, estimator='cov', **kwds):
"""Estimation of covariance matrix.

Estimates covariance matrices from multi-channel time series according to
a covariance estimator. Supports real and complex-valued data. For complex
data, the covariance matrices are estimated according to Section 3 in [1]_.
gabelstein marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
Expand Down Expand Up @@ -357,12 +394,16 @@ def covariances(X, estimator='cov', **kwds):
.. [oas] https://scikit-learn.org/stable/modules/generated/oas-function.html
.. [sch] :func:`pyriemann.utils.covariance.covariance_sch`
.. [scm] :func:`pyriemann.utils.covariance.covariance_scm`
.. [1] `A shrinkage approach to large-scale covariance estimation and
implications for functional genomics
gabelstein marked this conversation as resolved.
Show resolved Hide resolved
<https://doi.org/10.1109/ICASSP.2007.366399>`_
R. Abrahamsson, Y. Selen and P. Stoica. 2007 IEEE International
Conference on Acoustics, Speech and Signal Processing, Volume 2, 2007.
""" # noqa
est = _check_cov_est_function(estimator)
n_matrices, n_channels, n_times = X.shape
covmats = np.empty((n_matrices, n_channels, n_channels), dtype=X.dtype)
for i in range(n_matrices):
covmats[i] = est(X[i], **kwds)
covmats = np.asarray([est(X[i], **kwds)
for i in range(n_matrices)])
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved
return covmats


Expand Down Expand Up @@ -393,10 +434,9 @@ def covariances_EP(X, P, estimator='cov', **kwds):
if n_times_p != n_times:
raise ValueError(
f"X and P do not have the same n_times: {n_times} and {n_times_p}")
covmats = np.empty((n_matrices, n_channels + n_channels_proto,
n_channels + n_channels_proto))
for i in range(n_matrices):
covmats[i] = est(np.concatenate((P, X[i]), axis=0), **kwds)
covmats = np.asarray([est(np.concatenate((P, X[i]), axis=0), **kwds)
for i in range(n_matrices)])
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved

return covmats


Expand Down Expand Up @@ -492,7 +532,6 @@ def block_covariances(X, blocks, estimator='cov', **kwds):
blockcov.append(est(X[i, idx_start:idx_start+j, :], **kwds))
idx_start += j
covmats[i] = block_diag(*tuple(blockcov))

return covmats


Expand Down Expand Up @@ -796,3 +835,41 @@ def get_nondiag_weight(X):
num = np.sum(X2, axis=(-2, -1)) - denom
weights = (1.0 / (X.shape[-1] - 1)) * (num / denom)
return weights


def _make_complex_covariance(covmats):
gabelstein marked this conversation as resolved.
Show resolved Hide resolved
"""Convert real-valued covariance matrices to complex-valued.

Converts the stacked real-valued covariance matrices to complex-valued
covariance matrices, following Section 3 in [1]_.

Parameters
----------
covmats : ndarray, shape (n_matrices, n_channels, n_channels)
Covariance matrices, real-valued.

Returns
-------
complex_covmats : ndarray, shape (n_matrices, n_channels/2, n_channels/2)
Covariance matrices, complex-valued.

Notes
-----
.. versionadded:: 0.6

References
----------
.. [1] `A shrinkage approach to large-scale covariance estimation and
implications for functional genomics
gabelstein marked this conversation as resolved.
Show resolved Hide resolved
<https://doi.org/10.1109/ICASSP.2007.366399>`_
R. Abrahamsson, Y. Selen and P. Stoica. 2007 IEEE International
Conference on Acoustics, Speech and Signal Processing, Volume 2, 2007.
"""

n_channels, n_channels = covmats.shape
complex_covmats = covmats[:n_channels // 2, :n_channels // 2] \
+ covmats[n_channels // 2:, n_channels // 2:] \
+ 1j * (covmats[n_channels // 2:, :n_channels // 2]
- covmats[:n_channels // 2, n_channels // 2:])
gabelstein marked this conversation as resolved.
Show resolved Hide resolved

return complex_covmats
29 changes: 20 additions & 9 deletions tests/test_utils_covariance.py
Expand Up @@ -71,14 +71,9 @@ def test_covariances_complex(estimator, rndstate):
x = rndstate.randn(n_matrices, n_channels, n_times) \
+ 1j * rndstate.randn(n_matrices, n_channels, n_times)

if estimator in ['lwf', 'mcd', 'oas', 'sch']:
with pytest.raises(ValueError):
covariances(x, estimator=estimator)

else:
cov = covariances(x, estimator=estimator)
assert cov.shape == (n_matrices, n_channels, n_channels)
assert is_herm_pos_def(cov)
cov = covariances(x, estimator=estimator)
assert cov.shape == (n_matrices, n_channels, n_channels)
assert is_herm_pos_def(cov)


@pytest.mark.parametrize(
Expand All @@ -89,7 +84,24 @@ def test_covariances_EP(estimator, rndstate):
n_matrices, n_channels_x, n_channels_p, n_times = 2, 3, 3, 100
x = rndstate.randn(n_matrices, n_channels_x, n_times)
p = rndstate.randn(n_channels_p, n_times)
if estimator is None:
cov = covariances_EP(x, p)
else:
cov = covariances_EP(x, p, estimator=estimator)
n_dim_cov = n_channels_x + n_channels_p
assert cov.shape == (n_matrices, n_dim_cov, n_dim_cov)


@pytest.mark.parametrize(
'estimator', estimators + [None]
)
def test_covariances_EP_complex(estimator, rndstate):
"""Test covariance_EP for complex input"""
n_matrices, n_channels_x, n_channels_p, n_times = 2, 3, 3, 100
x = rndstate.randn(n_matrices, n_channels_x, n_times) \
+ 1j * rndstate.randn(n_matrices, n_channels_x, n_times)
p = rndstate.randn(n_channels_p, n_times) \
+ 1j * rndstate.randn(n_channels_p, n_times)
if estimator is None:
cov = covariances_EP(x, p)
else:
Expand All @@ -105,7 +117,6 @@ def test_covariances_X(estimator, rndstate):
"""Test covariance_X for multiple estimators"""
n_matrices, n_channels, n_times = 3, 5, 15
x = rndstate.randn(n_matrices, n_channels, n_times)

if estimator is None:
cov = covariances_X(x, alpha=5.)
else:
Expand Down