Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fixed point covariance estimator and add **kwds arguments in Covariances #220

Merged
merged 8 commits into from Jan 28, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whatsnew.rst
Expand Up @@ -35,6 +35,8 @@ v0.3.1.dev

- Add kernel matrices representation :class:`pyriemann.estimation.Kernels` and complete example comparing estimators. :pr:`217` by :user:`qbarthelemy`

- Add a new covariance estimator, robust fixed point covariance, and add **kwds arguments for all covariance based functions and classes. :pr:`220` by :user:`qbarthelemy`
sylvchev marked this conversation as resolved.
Show resolved Hide resolved

v0.3 (July 2022)
----------------

Expand Down
48 changes: 37 additions & 11 deletions pyriemann/estimation.py
Expand Up @@ -24,9 +24,11 @@ class Covariances(BaseEstimator, TransformerMixin):

Parameters
----------
estimator : string, default=scm'
estimator : string, default='scm'
Covariance matrix estimator, see
:func:`pyriemann.utils.covariance.covariances`.
**kwds : optional keyword parameters
Any further parameters are passed directly to the covariance estimator.

See Also
--------
Expand All @@ -36,9 +38,10 @@ class Covariances(BaseEstimator, TransformerMixin):
HankelCovariances
"""

def __init__(self, estimator='scm'):
def __init__(self, estimator='scm', **kwds):
"""Init."""
self.estimator = estimator
self.kwds = kwds

def fit(self, X, y=None):
"""Fit.
Expand Down Expand Up @@ -72,7 +75,7 @@ def transform(self, X):
covmats : ndarray, shape (n_matrices, n_channels, n_channels)
Covariance matrices.
"""
covmats = covariances(X, estimator=self.estimator)
covmats = covariances(X, estimator=self.estimator, **self.kwds)
return covmats


Expand Down Expand Up @@ -108,6 +111,8 @@ class ERPCovariances(BaseEstimator, TransformerMixin):
svd : int | None, default=None
If not None, number of components of SVD used to reduce prototype
responses.
**kwds : optional keyword parameters
Any further parameters are passed directly to the covariance estimator.

Attributes
----------
Expand Down Expand Up @@ -139,11 +144,12 @@ class ERPCovariances(BaseEstimator, TransformerMixin):
GRETSI, 2013.
"""

def __init__(self, classes=None, estimator='scm', svd=None):
def __init__(self, classes=None, estimator='scm', svd=None, **kwds):
"""Init."""
self.classes = classes
self.estimator = estimator
self.svd = svd
self.kwds = kwds

def fit(self, X, y):
"""Fit.
Expand Down Expand Up @@ -201,7 +207,12 @@ def transform(self, X):
is None, and to `n_channels + n_classes x min(svd, n_channels)`
otherwise.
"""
covmats = covariances_EP(X, self.P_, estimator=self.estimator)
covmats = covariances_EP(
X,
self.P_,
estimator=self.estimator,
**self.kwds
)
return covmats


Expand Down Expand Up @@ -241,6 +252,8 @@ class XdawnCovariances(BaseEstimator, TransformerMixin):
baseline_cov : array, shape (n_channels, n_channels) | None, default=None
Baseline covariance for `Xdawn` spatial filtering,
see :class:`pyriemann.spatialfilters.Xdawn`.
**kwds : optional keyword parameters
Any further parameters are passed directly to the covariance estimator.

Attributes
----------
Expand All @@ -267,14 +280,16 @@ def __init__(self,
classes=None,
estimator='scm',
xdawn_estimator='scm',
baseline_cov=None):
baseline_cov=None,
**kwds):
"""Init."""
self.applyfilters = applyfilters
self.estimator = estimator
self.xdawn_estimator = xdawn_estimator
self.classes = classes
self.nfilter = nfilter
self.baseline_cov = baseline_cov
self.kwds = kwds

def fit(self, X, y):
"""Fit.
Expand Down Expand Up @@ -322,7 +337,12 @@ def transform(self, X):
if self.applyfilters:
X = self.Xd_.transform(X)

covmats = covariances_EP(X, self.P_, estimator=self.estimator)
covmats = covariances_EP(
X,
self.P_,
estimator=self.estimator,
**self.kwds
)
return covmats


Expand All @@ -346,6 +366,8 @@ class BlockCovariances(BaseEstimator, TransformerMixin):
estimator : string, default='scm'
Covariance matrix estimator, see
:func:`pyriemann.utils.covariance.covariances`.
**kwds : optional keyword parameters
Any further parameters are passed directly to the covariance estimator.

Notes
-----
Expand All @@ -356,10 +378,11 @@ class BlockCovariances(BaseEstimator, TransformerMixin):
Covariances
"""

def __init__(self, block_size, estimator='scm'):
def __init__(self, block_size, estimator='scm', **kwds):
"""Init."""
self.estimator = estimator
self.block_size = block_size
self.kwds = kwds

def fit(self, X, y=None):
"""Fit.
Expand Down Expand Up @@ -405,7 +428,7 @@ def transform(self, X):
else:
raise ValueError("Parameter block_size must be int or list.")

return block_covariances(X, blocks, self.estimator)
return block_covariances(X, blocks, self.estimator, **self.kwds)


###############################################################################
Expand Down Expand Up @@ -630,6 +653,8 @@ class HankelCovariances(BaseEstimator, TransformerMixin):
estimator : string, default='scm'
Covariance matrix estimator, see
:func:`pyriemann.utils.covariance.covariances`.
**kwds : optional keyword parameters
Any further parameters are passed directly to the covariance estimator.

See Also
--------
Expand All @@ -647,10 +672,11 @@ class HankelCovariances(BaseEstimator, TransformerMixin):
Biomedical Engineering 52(9), 1541-1548, 2005.
"""

def __init__(self, delays=4, estimator='scm'):
def __init__(self, delays=4, estimator='scm', **kwds):
"""Init."""
self.delays = delays
self.estimator = estimator
self.kwds = kwds

def fit(self, X, y=None):
"""Fit.
Expand Down Expand Up @@ -701,7 +727,7 @@ def transform(self, X):
tmp = np.r_[tmp, np.roll(x, d, axis=-1)]
X2.append(tmp)
X2 = np.array(X2)
covmats = covariances(X2, estimator=self.estimator)
covmats = covariances(X2, estimator=self.estimator, **self.kwds)
return covmats


Expand Down