Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Schaefer-Strimmer shrinkage covariance estimator #59

Merged
merged 55 commits into from Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
9bd4d93
rewriting base file to get ride of numpy matrices.
Apr 13, 2016
93f3d4b
last modification to rewrite into full numpy integration
Apr 13, 2016
678006e
rewriting mean with dotted notation and getting rid of numpy matrices
Apr 13, 2016
70142c7
continuing to rewrite the geodesic
Apr 13, 2016
f33acd0
rewriting of the spatial filters
Apr 13, 2016
c9c7ffd
small style correction in channelselection
Apr 13, 2016
c78b666
updating the use of numpy for all higher packages
Apr 14, 2016
94fe960
last rewriting of pyriemann file
Apr 15, 2016
5dbd0a3
adding slack notification
Apr 18, 2016
9fbc462
finishing to update numpy ref and removing numpy matrices
Apr 18, 2016
4ef2c81
Merge branch 'nonumpymat'
Apr 18, 2016
6764d76
covering all functions of classification: TSclassifier, kNN
Apr 20, 2016
e3c9d66
PEP8 conformation for test clustering
Apr 20, 2016
9a81789
PEP 8 normalisation
Apr 21, 2016
c1e97bc
fixing last bugs for testing
May 2, 2016
d707ba4
update test
May 14, 2016
387c6ae
adding test for Potato
May 15, 2016
e3497c2
end of modification of test units and coverage improvment.
May 17, 2016
2609082
modifying and adapting travis file.
May 17, 2016
979dc21
updating travis file and correcting error in testing wasserstein mean
May 17, 2016
efaec49
adding pandas in the requierement for Travis
May 17, 2016
cb5ac48
ajout de la generation de matrice SPD
May 20, 2016
0ef2d90
Correcting the Karcher mean implementation if the gradient descent is…
May 20, 2016
0a7d003
adding ALM to mean computation
Aug 22, 2016
01b06af
adding ref for the parametrized geodesic
Aug 22, 2016
32ee871
update avant merge
May 16, 2017
c1081cb
adding computation and test of Schaefer-Strimmer covariance estimator
Aug 18, 2017
b4de85b
replacing loops with algebra
Aug 24, 2017
55d91d4
merge with upstream
Jun 11, 2018
b58e8c4
adding an example
Jun 13, 2018
b6a73a3
correcting plots
Jun 13, 2018
6665145
uncomplete merge
Jun 29, 2018
0b061e1
Merge remote-tracking branch 'upstream/master'
Apr 10, 2019
cf15591
Merge remote-tracking branch 'upstream/master'
Dec 11, 2019
0e24641
using upstream version
Aug 11, 2021
e122092
merge with new repo
Aug 11, 2021
f330ec8
pep8
Aug 11, 2021
866f3aa
update test for tangent space
Aug 14, 2021
fe725fd
correct docstring and bug
Aug 14, 2021
7fab3ff
converting notebook into py file
Aug 14, 2021
9de9338
correct style
Aug 14, 2021
85c51b4
rewrite plots, reduce computation time
Aug 15, 2021
6ef7068
update link and rename variables
Aug 15, 2021
52df50c
correct style
Aug 15, 2021
5b248a5
correct plot layout
Aug 15, 2021
7919ca8
adding Schaefer-Strimmer in whatsnew
Aug 15, 2021
e0d54f6
update test, docstring and docs from suggestions
Aug 17, 2021
d53b052
improve tests for covariance
qbarthelemy Aug 18, 2021
f25aaad
keep the same color for each estimator across figures
qbarthelemy Aug 18, 2021
900b8bf
correct example text
Aug 19, 2021
f20ab23
raise error for covariance_EP and initialize RandomState in example
Aug 19, 2021
195cb8b
correct import order
Aug 19, 2021
852ae79
Apply suggestions from code review
Sep 2, 2021
28e2469
sort import, correct spd
Sep 2, 2021
0916352
switch zeros to np.empty
Sep 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whatsnew.rst
Expand Up @@ -14,6 +14,8 @@ v0.2.8.dev

- Correct spectral estimation in :func:`pyriemann.utils.covariance.cross_spectrum` to obtain equivalence with SciPy

- Add Schaefer-Strimmer covariance estimator in :func:`pyriemann.utils.covariance.covariances`, and an example to compare estimators

v0.2.7 (June 2021)
------------------

Expand Down
4 changes: 4 additions & 0 deletions examples/signal/README.txt
@@ -0,0 +1,4 @@
Preprocessing
-------------------

Using pyRiemann for signal processing and covariance estimation
208 changes: 208 additions & 0 deletions examples/signal/plot_covariance_estimation.py
@@ -0,0 +1,208 @@
"""
===============================================================================
Estimate covariance with different time windows
===============================================================================

Covariance estimators comparison for different EEG signal lengths and their
impact on classification [1]_.
"""
# Author: Sylvain Chevallier
#
# License: BSD (3-clause)

from matplotlib import pyplot as plt
from mne import Epochs, pick_types, events_from_annotations
from mne.io import concatenate_raws
from mne.io.edf import read_raw_edf
from mne.datasets import eegbci
import numpy as np
import pandas as pd
from pyriemann.estimation import Covariances
from pyriemann.utils.distance import distance
from pyriemann.classification import MDM
import seaborn as sns
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.pipeline import make_pipeline

sylvchev marked this conversation as resolved.
Show resolved Hide resolved

rs = np.random.RandomState(42)

###############################################################################
# Estimating covariance on synthetic data
# ----------------------------------------
#
# Generate synthetic data, sampled from a distribution considered as the
# groundtruth.

n_trials, n_channels, n_times = 10, 5, 1000
var = 2.0 + 0.1 * rs.randn(n_trials, n_channels)
A = 2 * rs.rand(n_channels, n_channels) - 1
A /= np.linalg.norm(A, axis=1)[:, np.newaxis]
true_cov = np.empty(shape=(n_trials, n_channels, n_channels))
X = np.empty(shape=(n_trials, n_channels, n_times))
for i in range(n_trials):
true_cov[i] = A @ np.diag(var[i]) @ A.T
X[i] = rs.multivariate_normal(
np.array([0.0] * n_channels), true_cov[i], size=n_times
).T

###############################################################################
# Covariances() object offers several estimators: SCM, Ledoit-Wolf (LWF),
# Schaefer-Strimmer (SCH), oracle approximating shrunk covariance (OAS),
# minimum covariance determinant (MCD) and others. We will compare the
# distance of LWF, OAS and SCH estimators with the groundtruth, while
# increasing epoch length.

estimators = ["lwf", "oas", "sch"]
w_len = np.linspace(10, n_times, 20, dtype=int)
dfd = list()
for est in estimators:
for wl in w_len:
cov_est = Covariances(estimator=est).transform(X[:, :, :wl])
for k in range(n_trials):
dist = distance(cov_est[k], true_cov[k], metric="riemann")
dfd.append(dict(estimator=est, wlen=wl, dist=dist))
dfd = pd.DataFrame(dfd)

###############################################################################

fig, ax = plt.subplots(figsize=(6, 4))
ax.set(xscale="log")
sns.lineplot(data=dfd, x="wlen", y="dist", hue="estimator", ax=ax)
ax.set_title("Distance to groundtruth covariance matrix")
ax.set_xlabel("Number of samples")
ax.set_ylabel(r"$\delta(\Sigma, \hat{\Sigma})$")
plt.tight_layout()

###############################################################################
# Choice of estimator for motor imagery data
# -------------------------------------------
# Loading data from PhysioNet MI dataset, for subject 1.

event_id = dict(hands=2, feet=3)
subject = 1
runs = [6, 10, 14] # motor imagery: hands vs feet
raw_files = [
read_raw_edf(f, preload=True, stim_channel="auto")
for f in eegbci.load_data(subject, runs)
]
raw = concatenate_raws(raw_files)
picks = pick_types(raw.info, eeg=True, exclude="bads")

# subsample elecs
picks = picks[::2]
# Apply band-pass filter
raw.filter(7.0, 35.0, method="iir", picks=picks, skip_by_annotation="edge")
events, _ = events_from_annotations(raw, event_id=dict(T1=2, T2=3))
event_ids = dict(hands=2, feet=3)

###############################################################################
# Influence of shrinkage to estimate covariance
# -----------------------------------------------
# Sample covariance matrix (SCM) estimation could lead to ill-conditionned
# matrices depending on the quality and quantity of EEG data available.
# Matrix condition number is the ratio between the highest and lowest
# eigenvalues: high values indicates ill-conditionned matrices that are not
# suitable for classification.
# A common approach to mitigate this issue is to regularize covariance matrices
# by shrinkage, like in Ledoit-Wolf, Schaefer-Strimmer or oracle estimators.

estimators = ["lwf", "oas", "sch", "scm"]
tmin = -0.2
w_len = np.linspace(0.2, 2, 10)
n_trials = 45
dfc = list()

for wl in w_len:
epochs = Epochs(
raw,
events,
event_ids,
tmin,
tmin + wl,
picks=picks,
preload=True,
verbose=False,
)
for est in estimators:
cov = Covariances(estimator=est).transform(epochs.get_data())
for k in range(len(cov)):
ev, _ = np.linalg.eigh(cov[k, :, :])
dfc.append(dict(estimator=est, wlen=wl, cond=ev[-1] / ev[0]))
dfc = pd.DataFrame(dfc)

###############################################################################

fig, ax = plt.subplots(figsize=(6, 4))
ax.set(yscale="log")
sns.lineplot(data=dfc, x="wlen", y="cond", hue="estimator", ax=ax)
ax.set_title("Condition number of estimated covariance matrices")
ax.set_xlabel("Epoch length (s)")
ax.set_ylabel(r"$\lambda_{\max}$/$\lambda_{\min}$")
plt.tight_layout()

###############################################################################
# Picking a good estimator for classification
# -----------------------------------------------
# The choice of covariance estimator have an impact on classification,
# especially when the covariances are estimated on short time windows.

estimators = ["lwf", "oas", "sch", "scm"]
tmin = 0.0
w_len = np.linspace(0.2, 2.0, 5)
n_trials, n_splits = 45, 3
dfa = list()
sc = "balanced_accuracy"

cv = StratifiedKFold(n_splits=n_splits, shuffle=True)
for wl in w_len:
epochs = Epochs(
raw,
events,
event_ids,
tmin,
tmin + wl,
proj=True,
picks=picks,
preload=True,
baseline=None,
verbose=False,
)
X = epochs.get_data()
y = np.array([0 if ev == 2 else 1 for ev in epochs.events[:, -1]])
for est in estimators:
clf = make_pipeline(Covariances(estimator=est), MDM())
try:
score = cross_val_score(clf, X, y, cv=cv, scoring=sc)
dfa += [dict(estimator=est, wlen=wl, accuracy=sc) for sc in score]
except ValueError:
print(f"{est}: {wl} is not sufficent to estimate a PSD matrix")
sylvchev marked this conversation as resolved.
Show resolved Hide resolved
dfa += [dict(estimator=est, wlen=wl, accuracy=np.nan)] * n_splits
dfa = pd.DataFrame(dfa)

###############################################################################

fig, ax = plt.subplots(figsize=(6, 4))
sns.lineplot(
data=dfa,
x="wlen",
y="accuracy",
hue="estimator",
style="estimator",
ax=ax,
ci=None,
sylvchev marked this conversation as resolved.
Show resolved Hide resolved
markers=True,
dashes=False,
)
ax.set_title("Accuracy for different estimators and epoch lengths")
ax.set_xlabel("Epoch length (s)")
ax.set_ylabel("Classification accuracy")
plt.tight_layout()

###############################################################################
# References
# ----------
# .. [1] S. Chevallier, E. Kalunga, Q. Barthélemy, F. Yger. "Riemannian
# classification for SSVEP based BCI: offline versus online implementations"
# Brain–Computer Interfaces Handbook: Technological and Theoretical Advances
# , 2018.
73 changes: 63 additions & 10 deletions pyriemann/utils/covariance.py
Expand Up @@ -30,6 +30,52 @@ def _mcd(X):
return C


def _sch(X):
"""Schaefer-Strimmer covariance estimator

Shrinkage estimator using method from [1]_:
.. math::
\hat{\Sigma} = (1 - \gamma)\Sigma_{scm} + \gamma T

where :math:`T` is the diagonal target matrix:
.. math::
T_{i,j} = \{ \Sigma_{scm}^{ii} \text{if} i = j, 0 \text{otherwise} \}
Note that the optimal :math:`\gamma` is estimated by the authors' method.

:param X: Signal matrix, (n_channels, n_times)

:returns: Schaefer-Strimmer shrinkage covariance matrix, (n_channels, n_channels)

sylvchev marked this conversation as resolved.
Show resolved Hide resolved
Notes
-----
.. versionadded:: 0.2.8.dev

References
----------
.. [1] Schafer, J., and K. Strimmer. 2005. A shrinkage approach to
large-scale covariance estimation and implications for functional
genomics. Statist. Appl. Genet. Mol. Biol. 4:32.
""" # noqa
n_times = X.shape[1]
X_c = (X.T - X.T.mean(axis=0)).T
C_scm = 1. / n_times * X_c @ X_c.T

# Compute optimal gamma, the weigthing between SCM and srinkage estimator
R = (n_times / ((n_times - 1.) * np.outer(X.std(axis=1), X.std(axis=1))))
R *= C_scm
var_R = (X_c ** 2) @ (X_c ** 2).T - 2 * C_scm * (X_c @ X_c.T)
var_R += n_times * C_scm ** 2
Xvar = np.outer(X.var(axis=1), X.var(axis=1))
var_R *= n_times / ((n_times - 1) ** 3 * Xvar)
R -= np.diag(np.diag(R))
var_R -= np.diag(np.diag(var_R))
gamma = max(0, min(1, var_R.sum() / (R ** 2).sum()))

sigma = (1. - gamma) * (n_times / (n_times - 1.)) * C_scm
shrinkage = gamma * (n_times / (n_times - 1.)) * np.diag(np.diag(C_scm))
return sigma + shrinkage


def _check_est(est):
"""Check if a given estimator is valid"""

Expand All @@ -40,6 +86,7 @@ def _check_est(est):
'lwf': _lwf,
'oas': _oas,
'mcd': _mcd,
'sch': _sch,
'corr': np.corrcoef
}

Expand All @@ -65,7 +112,7 @@ def covariances(X, estimator='cov'):
X : ndarray, shape (n_trials, n_channels, n_times)
ndarray of trials.

estimator : {'cov', 'scm', 'lwf', 'oas', 'mcd', 'corr'} (default: 'scm')
estimator : {'cov', 'scm', 'lwf', 'oas', 'mcd', 'sch', 'corr'} (default: 'scm')
covariance matrix estimator:

* 'cov' for numpy based covariance matrix,
Expand All @@ -78,6 +125,8 @@ def covariances(X, estimator='cov'):
https://scikit-learn.org/stable/modules/generated/sklearn.covariance.OAS.html
* 'mcd' for minimum covariance determinant matrix,
https://scikit-learn.org/stable/modules/generated/sklearn.covariance.MinCovDet.html
* 'sch' for Schaefer-Strimmer covariance,
http://doi.org/10.2202/1544-6115.1175,
* 'corr' for correlation coefficient matrix,
https://numpy.org/doc/stable/reference/generated/numpy.corrcoef.html

Expand All @@ -91,20 +140,24 @@ def covariances(X, estimator='cov'):
.. [1] https://scikit-learn.org/stable/modules/covariance.html
""" # noqa
est = _check_est(estimator)
Nt, Ne, Ns = X.shape
covmats = np.zeros((Nt, Ne, Ne))
for i in range(Nt):
n_trials, n_channels, n_times = X.shape
covmats = np.zeros((n_trials, n_channels, n_channels))
sylvchev marked this conversation as resolved.
Show resolved Hide resolved
for i in range(n_trials):
covmats[i, :, :] = est(X[i, :, :])
return covmats


def covariances_EP(X, P, estimator='cov'):
"""Special form covariance matrix."""
est = _check_est(estimator)
Nt, Ne, Ns = X.shape
Np, Ns = P.shape
covmats = np.zeros((Nt, Ne + Np, Ne + Np))
for i in range(Nt):
n_trials, n_channels, n_times = X.shape
n_proto, n_times_P = P.shape
if n_times_P != n_times:
raise ValueError(
f"X and P do not have the same n_times: {n_times} and {n_times_P}"
)
covmats = np.zeros((n_trials, n_channels + n_proto, n_channels + n_proto))
sylvchev marked this conversation as resolved.
Show resolved Hide resolved
for i in range(n_trials):
covmats[i, :, :] = est(np.concatenate((P, X[i, :, :]), axis=0))
return covmats

Expand All @@ -117,10 +170,10 @@ def eegtocov(sig, window=128, overlapp=0.5, padding=True, estimator='cov'):
padd = np.zeros((int(window / 2), sig.shape[1]))
sig = np.concatenate((padd, sig, padd), axis=0)

Ns, Ne = sig.shape
n_times, n_channels = sig.shape
jump = int(window * overlapp)
ix = 0
while (ix + window < Ns):
while (ix + window < n_times):
X.append(est(sig[ix:ix + window, :].T))
ix = ix + jump

Expand Down
14 changes: 14 additions & 0 deletions tests/test_tangentspace.py
@@ -1,6 +1,7 @@
"""Test tangent space functions."""
import numpy as np
from numpy.testing import assert_array_almost_equal
import pytest

from pyriemann.tangentspace import TangentSpace, FGDA

Expand Down Expand Up @@ -37,6 +38,19 @@ def test_TangentSpace_transform():
ts.transform(covset)


@pytest.mark.parametrize('shape', [(10, 9), (10, 9, 8), (10), (12, 8, 8)])
def test_TangentSpace_transform_dim(shape):
"""Test transform input shape, could be TS vector or covmat"""
n_trials, n_channels = 10, 3
covset = generate_cov(n_trials, n_channels)
ts = TangentSpace(metric='riemann')
ts.fit(covset)

X = np.zeros(shape=shape)
with pytest.raises(ValueError):
ts.transform(X)


def test_TangentSpace_transform_without_fit():
"""Test transform of Tangent Space without fit."""
covset = generate_cov(10, 3)
Expand Down