Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partial_fit in Potato #133

Merged
merged 5 commits into from Aug 11, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/whatsnew.rst
Expand Up @@ -10,9 +10,11 @@ A catalog of new features, improvements, and bug-fixes in each release.
v0.2.8.dev
----------

- Correct spectral estimation in :func:`pyriemann.utils.covariance.cross_spectrum` to obtain equivalence with SciPy

- Add instantaneous, lagged and imaginary coherences in :func:`pyriemann.utils.covariance.coherence` and :class:`pyriemann.estimation.Coherences`

- Correct spectral estimation in :func:`pyriemann.utils.covariance.cross_spectrum` to obtain equivalence with SciPy
- Add `partial_fit` in :class:`pyriemann.clustering.Potato`, useful for an online update; and update example on artifact detection
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved

v0.2.7 (June 2021)
------------------
Expand Down
8 changes: 4 additions & 4 deletions examples/artifacts/plot_correct_ajdc_EEG.py
Expand Up @@ -4,7 +4,7 @@
===============================================================================

Blind source separation (BSS) based on approximate joint diagonalization of
Fourier cospectra (AJDC), applied to artifact correction of EEG [1].
Fourier cospectra (AJDC), applied to artifact correction of EEG [1]_.
"""
# Authors: Quentin Barthélemy & David Ojeda.
# EEG signal kindly shared by Marco Congedo.
Expand Down Expand Up @@ -175,6 +175,6 @@ def read_header(fname):
###############################################################################
# References
# ----------
# [1] Q. Barthélemy, L. Mayaud, Y. Renard, D. Kim, S.-W. Kang, J. Gunkelman and
# M. Congedo, "Online denoising of eye-blinks in electroencephalography",
# Neurophysiol Clin, 2017
# .. [1] Q. Barthélemy, L. Mayaud, Y. Renard, D. Kim, S.-W. Kang, J. Gunkelman
# and M. Congedo, "Online denoising of eye-blinks in electroencephalography"
# , Neurophysiol Clin, 2017.
37 changes: 28 additions & 9 deletions examples/artifacts/plot_detect_riemannian_potato_EEG.py
Expand Up @@ -104,17 +104,20 @@ def add_alpha(p_cols, alphas):


###############################################################################
# Calibration of Potato
# ---------------------
# Offline Calibration of Potato
# -----------------------------
#
# 2D projection of the z-score map of the Riemannian potato, for 2x2 covariance
# matrices (in blue if clean, in red if artifacted) and their reference matrix
# (in black). The colormap defines the z-score and a chosen isocontour defines
# the potato. It reproduces Fig 1 of reference [2]_.

z_th = 2.5 # z-score threshold
z_th = 2.0 # z-score threshold
t = 40 # nb of matrices to train the potato


###############################################################################

# Calibrate potato by unsupervised training on first matrices: compute a
# reference matrix, mean and standard deviation of distances to this reference.
train_set = range(t)
Expand Down Expand Up @@ -159,9 +162,17 @@ def add_alpha(p_cols, alphas):
#
# Detect artifacts/outliers on test set, with an animation to imitate an online
# acquisition, processing and artifact detection of EEG time-series.
# The potato is static: it is not updated when EEG is not artifacted, damaging
# its efficiency over time.
# Initialized with an offline calibration, the online potato can be [2]_:
#
# * static: it is never updated, damaging its efficiency over time,
# * semi-dynamic: it is updated when EEG is not artifacted.

is_static = False # static or semi-dynamic mode


###############################################################################

# Prepare data for online detection
test_covs_max = 300 # nb of matrices to visualize in this example
test_covs_visu = 30 # nb of matrices to display simultaneously
test_time_start = -2 # start time to display signal
Expand All @@ -177,7 +188,6 @@ def add_alpha(p_cols, alphas):
rp_colors, ep_colors = [], []
alphas = np.linspace(0, 1, test_covs_visu)

# Prepare animation for online detection
fig = plt.figure(figsize=(12, 10), constrained_layout=False)
fig.suptitle('Online artifact detection by potatoes', fontsize=16)
gs = fig.add_gridspec(nrows=4, ncols=40, top=0.90, hspace=0.3, wspace=1.0)
Expand All @@ -197,15 +207,21 @@ def add_alpha(p_cols, alphas):
p_ep = plot_potato_2D(ax_ep, cax_ep, X, Y, ep_zscores, ep_center, covs_visu,
ep_colors, 'Z-score of Euclidean distance to reference')

# Plot online detection (an interactive display is required)

###############################################################################

# Prepare animation for online detection
def online_update(self):
global t, time, sig, covs_visu

# Online artifact detection
rp_label = rpotato.predict(covs[t][np.newaxis, ...])
ep_label = epotato.predict(covs[t][np.newaxis, ...])
if not is_static:
if rp_label[0] == 1:
rpotato.partial_fit(covs[t][np.newaxis, ...], alpha=1 / t)
if ep_label[0] == 1:
epotato.partial_fit(covs[t][np.newaxis, ...], alpha=1 / t)

# Update data
time_start = t * interval + test_time_end
Expand Down Expand Up @@ -238,8 +254,11 @@ def online_update(self):
return pl_sig0, pl_sig1, p_rp, p_ep


# For a correct display, change the parameter 'interval_display'
interval_display = 1.0
###############################################################################

# Plot online detection (an interactive display is required).
interval_display = 1.0 # can be changed for a slower display

potato = FuncAnimation(fig, online_update, frames=test_covs_max,
interval=interval_display, blit=False, repeat=False)
plt.show()
Expand Down
89 changes: 73 additions & 16 deletions pyriemann/clustering.py
Expand Up @@ -25,6 +25,8 @@ def _init_centroids(X, n_clusters, init, random_state, x_squared_norms):
from joblib import Parallel, delayed

from .classification import MDM
from .utils.mean import mean_covariance
from .utils.geodesic import geodesic

#######################################################################

Expand Down Expand Up @@ -312,7 +314,7 @@ def __init__(self, metric='riemann', threshold=3, n_iter_max=100,
self.threshold = threshold
self.n_iter_max = n_iter_max
if pos_label == neg_label:
raise(ValueError("Positive and Negative labels must be different"))
raise(ValueError("Positive and negative labels must be different"))
self.pos_label = pos_label
self.neg_label = neg_label

Expand All @@ -333,22 +335,8 @@ def fit(self, X, y=None):
"""
self._mdm = MDM(metric=self.metric)

if y is not None:
if len(y) != len(X):
raise ValueError('y must be the same length of X')

classes = np.int32(np.unique(y))

if len(classes) > 2:
raise ValueError('number of classes must be maximum 2')

if self.pos_label not in classes:
raise ValueError('y must contain a positive class')
y_old = self._check_labels(X, y)

y_old = np.int32(np.array(y) == self.pos_label)
else:
y_old = np.ones(len(X))
# start loop
for n_iter in range(self.n_iter_max):
ix = (y_old == 1)
self._mdm.fit(X[ix], y_old[ix])
Expand All @@ -364,6 +352,54 @@ def fit(self, X, y=None):
y_old = y
return self

def partial_fit(self, X, y=None, alpha=0.1):
"""Update the potato from covariance matrices. Useful for dynamic or
semi-dymanic online potatoes.

Parameters
----------
X : ndarray, shape (n_trials, n_channels, n_channels)
ndarray of SPD matrices.
y : ndarray | None (default None)
Not used, here for compatibility with sklearn API.
alpha : float (default 0.1)
Update rate in [0, 1] for the centroid, and mean and standard
deviation of log-distances.

Returns
-------
self : Potato instance
The Potato instance.
"""
if not hasattr(self, '_mdm'):
raise ValueError('Partial fit can be called only on an already '
'fitted potato.')

n_trials, n_channels, _ = X.shape
if n_channels != self._mdm.covmeans_[0].shape[0]:
raise ValueError(
'X does not have the good number of channels. Should be %d but'
' got %d.' % (self._mdm.covmeans_[0].shape[0], n_channels))

y = self._check_labels(X, y)

if not 0 <= alpha <= 1:
raise ValueError('Parameter alpha must be in [0, 1]')

if alpha > 0:
if n_trials > 1: # mini-batch update
Xm = mean_covariance(X[(y == 1)], metric=self.metric)
else: # pure online update
Xm = X[0]
self._mdm.covmeans_[0] = geodesic(
self._mdm.covmeans_[0], Xm, alpha, metric=self.metric)
d = np.squeeze(np.log(self._mdm.transform(Xm[np.newaxis, ...])))
self._mean = (1 - alpha) * self._mean + alpha * d
self._std = np.sqrt(
(1 - alpha) * self._std**2 + alpha * (d - self._mean)**2)

return self

def transform(self, X):
"""return the normalized log-distance to the centroid (z-score).

Expand Down Expand Up @@ -422,6 +458,27 @@ def predict_proba(self, X):
proba = self._get_proba(z)
return proba

def _check_labels(self, X, y):
"""check validity of labels."""
if y is not None:
if len(y) != len(X):
raise ValueError('y must be the same length of X')

classes = np.int32(np.unique(y))

if len(classes) > 2:
raise ValueError('number of classes must be maximum 2')

if self.pos_label not in classes:
raise ValueError('y must contain a positive class')

y = np.int32(np.array(y) == self.pos_label)

else:
y = np.ones(len(X))

return y

def _get_z_score(self, d):
"""get z-score from distance."""
z = (d - self._mean) / self._std
Expand Down
46 changes: 31 additions & 15 deletions tests/test_clustering.py
Expand Up @@ -62,10 +62,14 @@ def test_KmeansPCT_init():

def test_Potato_init():
"""Test Potato"""
covset = generate_cov(20, 3)
labels = np.array([0, 1]).repeat(10)
n_trials, n_channels = 20, 3
covset = generate_cov(n_trials, n_channels)
cov = covset[0][np.newaxis, ...] # to test potato with a single trial
labels = np.array([0, 1]).repeat(n_trials // 2)

# init
with pytest.raises(ValueError): # positive and neg labels equal
Potato(pos_label=0)
pt = Potato()

# fit no labels
Expand All @@ -74,26 +78,43 @@ def test_Potato_init():
# fit with labels
with pytest.raises(ValueError):
pt.fit(covset, y=[1])

with pytest.raises(ValueError):
pt.fit(covset, y=[0] * 20)

with pytest.raises(ValueError):
pt.fit(covset, y=[0, 2, 3] + [1] * 17)

pt.fit(covset, labels)

# partial_fit
with pytest.raises(ValueError): # potato not fitted
Potato().partial_fit(covset)
with pytest.raises(ValueError): # unequal # of chans
pt.partial_fit(generate_cov(2, n_channels + 1))
with pytest.raises(ValueError): # alpha < 0
pt.partial_fit(covset, labels, alpha=-0.1)
with pytest.raises(ValueError): # alpha > 1
pt.partial_fit(covset, labels, alpha=1.1)
with pytest.raises(ValueError): # no positive labels
pt.partial_fit(covset, [0] * n_trials)
pt.partial_fit(covset, labels, alpha=0.6)
pt.partial_fit(cov, alpha=0.1)

# transform
pt.transform(covset)
pt.transform(covset[0][np.newaxis, ...]) # transform a single trial
pt.transform(cov)

# predict
pt.predict(covset)
pt.predict(covset[0][np.newaxis, ...]) # predict a single trial
pt.predict(cov)

# predict_proba
pt.predict_proba(covset)
pt.predict_proba(covset[0][np.newaxis, ...])
pt.predict_proba(cov)

# potato with a single channel
covset_1chan = generate_cov(n_trials, 1)
pt.fit_transform(covset_1chan)
pt.predict(covset_1chan)
pt.predict_proba(covset_1chan)

# lower threshold
pt = Potato(threshold=1)
Expand All @@ -103,10 +124,5 @@ def test_Potato_init():
pt = Potato(threshold=1, pos_label=2, neg_label=7)
pt.fit(covset)
assert_array_equal(np.unique(pt.predict(covset)), [2, 7])

# test with custom positive label
pt.fit(covset, y=[2]*20)

# different positive and neg label
with pytest.raises(ValueError):
Potato(pos_label=0)
# fit with custom positive label
pt.fit(covset, y=[2]*n_trials)