From 56cf65333f67ba7dbb62c8250bb6611098d7bbda Mon Sep 17 00:00:00 2001 From: Maria Sayu Yamamoto <119514285+MSYamamoto@users.noreply.github.com> Date: Tue, 13 Dec 2022 23:04:25 +0100 Subject: [PATCH] Add class-distinctiveness function and its example code (#215) * Add class-distinctiveness function and its example code * Make class_dis function support more than two classes, apply flake8 on scripts, move the function to preprocessing.py and delete featureselection.py * Add the test of class_dis function, fix the class_dis formula, move the function to classification, and incorporate class_dis example into plot_toy_classification * update doc, change a bit figure example + factorize tests * flake8 * update whatsnew + add argument 'metric' * update test * improve doc * add private func + add parameter p * change the parameter name to exponent * improve doc * fix equation display and add last modifs Co-authored-by: Alexandre Gramfort Co-authored-by: qbarthelemy --- doc/api.rst | 6 + doc/whatsnew.rst | 3 + examples/simulated/plot_toy_classification.py | 77 ++++++---- pyriemann/classification.py | 137 +++++++++++++++++- tests/test_classification.py | 33 ++++- 5 files changed, 224 insertions(+), 32 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index f7bdde9c..f0b992ed 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -54,6 +54,12 @@ Classification SVC MeanField +.. autosummary:: + :toctree: generated/ + :template: function.rst + + class_distinctiveness + Regression -------------- .. _regression_api: diff --git a/doc/whatsnew.rst b/doc/whatsnew.rst index 20fa3ff1..7e3bd871 100644 --- a/doc/whatsnew.rst +++ b/doc/whatsnew.rst @@ -28,6 +28,9 @@ v0.3.1.dev - Add Transfer Learning module and examples, including RPA and MDWM. :pr:`189` by :user:`plcrodrigues`, :user:`qbarthelemy` and :user:`sylvchev` +- Add class distinctiveness function to measure the distinctiveness between classes on the manifold, + :func:`pyriemann.classification.class_distinctiveness`, and complete an example in gallery to show how it works on synthetic datasets. :pr:`215` by :user:`MSYamamoto` + v0.3 (July 2022) ---------------- diff --git a/examples/simulated/plot_toy_classification.py b/examples/simulated/plot_toy_classification.py index 4edafd8c..7cbeaae1 100644 --- a/examples/simulated/plot_toy_classification.py +++ b/examples/simulated/plot_toy_classification.py @@ -1,34 +1,32 @@ """ -===================================================================== -Illustrate classification accuracy versus class separability -===================================================================== +====================================================================== +Classification accuracy vs class distinctiveness vs class separability +====================================================================== Generate several datasets containing data points from two-classes. Each class is generated with a Riemannian Gaussian distribution centered at the class mean and with the same dispersion sigma. The distance between the class means is parametrized by Delta, which we make vary between zero and 5*sigma. We -illustrate how the accuracy of the MDM classifier varies when Delta increases. +illustrate how the accuracy of the MDM classifier and the value of the class +distinctiveness [1]_ vary when Delta increases. """ # Authors: Pedro Rodrigues +# Maria Sayu Yamamoto # # License: BSD (3-clause) import numpy as np import matplotlib.pyplot as plt -from sklearn.model_selection import cross_val_score +from sklearn.model_selection import cross_val_score, StratifiedKFold from pyriemann.classification import MDM from pyriemann.datasets import make_gaussian_blobs - - -print(__doc__) - +from pyriemann.classification import class_distinctiveness ############################################################################### # Set general parameters for the illustrations - n_matrices = 100 # how many matrices to sample on each class n_dim = 4 # dimensionality of the data points sigma = 1.0 # dispersion of the Gaussian distributions @@ -37,6 +35,7 @@ ############################################################################### # Loop over different levels of separability between the classes scores_array = [] +class_dis_array = [] deltas_array = np.linspace(0, 3*sigma, 10) for delta in deltas_array: @@ -48,28 +47,56 @@ random_state=random_state, n_jobs=4) - # which classifier to consider + # measure class distinctiveness of training data for each split + skf = StratifiedKFold(n_splits=5) + all_class_dis = [] + for train_ind, _ in skf.split(X, y): + class_dis = class_distinctiveness(X[train_ind], y[train_ind], + exponent=1, metric='riemann', + return_num_denom=False) + all_class_dis.append(class_dis) + + # average class distinctiveness across splits + mean_class_dis = np.mean(all_class_dis) + class_dis_array.append(mean_class_dis) + + # Now let's train a MDM classifier and measure its performance clf = MDM() # get the classification score for this setup scores_array.append( - cross_val_score(clf, X, y, cv=5, scoring='roc_auc').mean()) + cross_val_score(clf, X, y, cv=skf, scoring='roc_auc').mean()) scores_array = np.array(scores_array) +class_dis_array = np.array(class_dis_array) ############################################################################### # Plot the results -fig, ax = plt.subplots(figsize=(7.5, 5.9)) -ax.plot(deltas_array, scores_array, lw=3.0, label=sigma) -ax.set_xticks([0, 1, 2, 3]) -ax.set_xticklabels([0, 1, 2, 3], fontsize=12) -ax.set_yticks([0.6, 0.7, 0.8, 0.9, 1.0]) -ax.set_yticklabels([0.6, 0.7, 0.8, 0.9, 1.0], fontsize=12) -ax.set_xlabel(r'$\Delta/\sigma$', fontsize=14) -ax.set_ylabel(r'score', fontsize=12) -ax.set_title(r'Classification score Vs class separability ($n_{dim} = 4$)', - fontsize=12) -ax.grid(True) -ax.legend(loc='lower right', title=r'$\sigma$') - +fig, (ax1, ax2) = plt.subplots(sharex=True, nrows=2) + +ax1.plot(deltas_array, scores_array, lw=3.0, label=r'ROC AUC score') +ax2.plot(deltas_array, class_dis_array, lw=3.0, color='g', + label='Class Distinctiveness') + +ax2.set_xlabel(r'$\Delta/\sigma$', fontsize=14) +ax1.set_ylabel(r'ROC AUC score', fontsize=12) +ax2.set_ylabel(r'class distinctiveness', fontsize=12) +ax1.set_title('Classification score and class distinctiveness value\n' + r'vs. class separability ($n_{dim} = 4$)', + fontsize=12) + +ax1.grid(True) +ax2.grid(True) +fig.tight_layout() plt.show() + +############################################################################### +# References +# ---------- +# .. [1] `Class-distinctiveness-based frequency band selection on the +# Riemannian manifold for oscillatory activity-based BCIs: preliminary +# results +# `_ +# M. S. Yamamoto, F. Lotte, F. Yger, and S. Chevallier. +# 44th Annual International Conference of the IEEE Engineering +# in Medicine & Biology Society (EMBC2022), 2022. diff --git a/pyriemann/classification.py b/pyriemann/classification.py index 0e857257..ec3ac92c 100644 --- a/pyriemann/classification.py +++ b/pyriemann/classification.py @@ -17,7 +17,6 @@ def _check_metric(metric): - if isinstance(metric, str): metric_mean = metric metric_dist = metric @@ -195,7 +194,7 @@ def predict_proba(self, X): prob : ndarray, shape (n_matrices, n_classes) Probabilities for each class. """ - return softmax(-self._predict_distances(X)**2) + return softmax(-self._predict_distances(X) ** 2) class FgMDM(BaseEstimator, ClassifierMixin, TransformerMixin): @@ -547,7 +546,7 @@ def predict_proba(self, X): idx = np.argsort(dist) dist_sorted = np.take_along_axis(dist, idx, axis=1) neighbors_classes = self.classmeans_[idx] - probas = softmax(-dist_sorted[:, 0:self.n_neighbors]**2) + probas = softmax(-dist_sorted[:, 0:self.n_neighbors] ** 2) prob = np.zeros((n_matrices, len(self.classes_))) for m in range(n_matrices): @@ -819,7 +818,7 @@ def _get_label(self, x, labs_unique): for ip, p in enumerate(self.power_list): for ill, ll in enumerate(labs_unique): m[ip, ill] = distance( - x, self.covmeans_[p][ll], metric=self.metric)**2 + x, self.covmeans_[p][ll], metric=self.metric) ** 2 if self.method_label == 'sum_means': ipmin = np.argmin(np.sum(m, axis=1)) @@ -863,7 +862,7 @@ def _predict_distances(self, X): m[p].append( distance( x, self.covmeans_[p][ll], metric=self.metric - )**2 + ) ** 2 ) pmin = min(m.items(), key=lambda x: np.sum(x[1]))[0] dist.append(np.array(m[pmin])) @@ -903,4 +902,130 @@ def predict_proba(self, X): prob : ndarray, shape (n_matrices, n_classes) Probabilities for each class. """ - return softmax(-self._predict_distances(X)**2) + return softmax(-self._predict_distances(X) ** 2) + + +def class_distinctiveness(X, y, exponent=1, metric='riemann', + return_num_denom=False): + r"""Measure class distinctiveness between classes of SPD matrices. + + For two class problem, the class distinctiveness between class A + and B on the manifold of SPD matrices is quantified as [1]_: + + .. math:: + \mathrm{classDis}(A, B, p) = \frac{d \left(\bar{X}^{A}, + \bar{X}^{B}\right)^p} + {\frac{1}{2} \left( \sigma_{X^{A}}^p + \sigma_{X^{B}}^p \right)} + + where :math:`\bar{X}^{K}` is the center of class K, ie the mean of matrices + from class K (see :func:`pyriemann.utils.mean.mean_covariance`) and + :math:`\sigma_{X^{K}}` is the class dispersion, ie the mean of distances + between matrices from class K and their center of class + :math:`\bar{X}^{K}`: + + .. math:: + \sigma_{X^{K}}^p = \frac{1}{m} \sum_{i=1}^m d + \left(X_i, \bar{X}^{K}\right)^p + + For more than two classes, it is quantified as: + + .. math:: + \mathrm{classDis}\left(\left\{K_{j}\right\}, p\right) = + \frac{\sum_{j=1}^{c} d\left(\bar{X}^{K_{j}}, \tilde{X}\right)^p} + {\sum_{j=1}^{c} \sigma_{X^{K_{j}}}^p} + + where :math:`\tilde{X}` is the mean of centers of class of all :math:`c` + classes and :math:`p` is the exponentiation of the distance measure + named exponent at the input of this function. + + Parameters + ---------- + X : ndarray, shape (n_matrices, n_channels, n_channels) + Set of SPD matrices. + y : ndarray, shape (n_matrices,) + Labels for each matrix. + exponent : int, default=1 + Parameter for exponentiation of distances, corresponding to p in the + above equations: + + - exponent = 1 gives the formula originally defined in [1]_; + - exponent = 2 gives the Fisher criterion generalized on the manifold, + ie the ratio of the variance between the classes to the variance + within the classes. + metric : string | dict, default='riemann' + The type of metric used for centroid and distance estimation. + See `mean_covariance` for the list of supported metric. + The metric could be a dict with two keys, `mean` and `distance` in + order to pass different metrics for the centroid estimation and the + distance estimation. The original equation of class distinctiveness + in [1]_ uses 'riemann' for both the centroid estimation and the + distance estimation but you can customize other metrics with your + interests. + return_num_denom : bool, default=False + Whether to return numerator and denominator of class_dis. + + Returns + ------- + class_dis : float + Class distinctiveness value. + num : float + Numerator value of class_dis. Returned only if return_num_denom is + True. + denom : float + Denominator value of class_dis. Returned only if return_num_denom is + True. + + Notes + ----- + .. versionadded:: 0.3.1 + + References + ---------- + .. [1] `Defining and quantifying users’ mental imagery-based + BCI skills: a first step + `_ + F. Lotte, and C. Jeunet. Journal of neural engineering, + 15(4), 046030, 2018. + """ + + metric_mean, metric_dist = _check_metric(metric) + classes = np.unique(y) + if len(classes) <= 1: + raise ValueError('X must contain at least two classes') + + means = np.array([ + mean_covariance(X[y == ll], metric=metric_mean) for ll in classes + ]) + + if len(classes) == 2: + num = distance(means[0], means[1], metric=metric_dist) ** exponent + denom = 0.5 * _get_within(X, y, means, classes, exponent, metric_dist) + + else: + mean_all = mean_covariance(means, metric=metric_mean) + dists_between = [ + distance(m, mean_all, metric=metric_dist) ** exponent + for m in means + ] + num = np.sum(dists_between) + denom = _get_within(X, y, means, classes, exponent, metric_dist) + + class_dis = num / denom + + if return_num_denom: + return class_dis, num, denom + else: + return class_dis + + +def _get_within(X, y, means, classes, exponent, metric_dist): + """Private function to compute within dispersion.""" + sigmas = [] + for ii, ll in enumerate(classes): + dists_within = [ + distance(x, means[ii], metric=metric_dist) ** exponent + for x in X[y == ll] + ] + sigmas.append(np.mean(dists_within)) + sum_sigmas = np.sum(sigmas) + return sum_sigmas diff --git a/tests/test_classification.py b/tests/test_classification.py index c65439fd..67c796ce 100644 --- a/tests/test_classification.py +++ b/tests/test_classification.py @@ -16,7 +16,8 @@ KNearestNeighbor, TSclassifier, SVC, - MeanField + MeanField, + class_distinctiveness, ) rclf = [MDM, FgMDM, KNearestNeighbor, TSclassifier, SVC, MeanField] @@ -340,3 +341,33 @@ def test_meanfield(get_covmats, get_labels, method_label): assert proba.shape == (n_matrices, n_classes) transf = mf.transform(covmats) assert transf.shape == (n_matrices, n_classes) + + +@pytest.mark.parametrize("n_classes", [1, 2, 3]) +@pytest.mark.parametrize("metric_mean", get_means()) +@pytest.mark.parametrize("metric_dist", get_distances()) +@pytest.mark.parametrize("exponent", [1, 2]) +def test_class_distinctiveness(get_covmats, get_labels, + n_classes, metric_mean, metric_dist, exponent): + """Test function for class distinctiveness measure for two class problem""" + n_matrices, n_channels = 6, 3 + covmats = get_covmats(n_matrices, n_channels) + labels = get_labels(n_matrices, n_classes) + if n_classes == 1: + with pytest.raises(ValueError): + class_distinctiveness(covmats, labels) + return + + class_dis, num, denom = class_distinctiveness( + covmats, + labels, + exponent, + metric={"mean": metric_mean, "distance": metric_dist}, + return_num_denom=True + ) + assert class_dis >= 0 # negative class_dis value + assert num >= 0 # negative numerator value + assert denom >= 0 # negative denominator value + assert isinstance(class_dis, float), "Unexpected object of class_dis" + assert isinstance(num, float), "Unexpected object of num" + assert isinstance(denom, float), "Unexpected object of denum"