Add class-distinctiveness function and its example code #215

Merged
merged 12 commits on Dec 13, 2022
6 changes: 6 additions & 0 deletions doc/api.rst
@@ -54,6 +54,12 @@ Classification
SVC
MeanField

.. autosummary::
:toctree: generated/
:template: function.rst

class_distinctiveness

Regression
--------------
.. _regression_api:
3 changes: 3 additions & 0 deletions doc/whatsnew.rst
@@ -28,6 +28,9 @@ v0.3.1.dev

- Add Transfer Learning module and examples, including RPA and MDWM. :pr:`189` by :user:`plcrodrigues`, :user:`qbarthelemy` and :user:`sylvchev`

- Add class distinctiveness function :func:`pyriemann.classification.class_distinctiveness` to measure the distinctiveness between classes on the manifold,
and add an example to the gallery showing how it works on synthetic datasets. :pr:`215` by :user:`MSYamamoto`
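
A minimal usage sketch, for illustration only (the arrays ``X`` and ``y`` are placeholders for a set of SPD matrices and their labels, not part of this release note):

from pyriemann.classification import class_distinctiveness

# X: ndarray of SPD matrices, shape (n_matrices, n_channels, n_channels)
# y: ndarray of labels, shape (n_matrices,)
class_dis = class_distinctiveness(X, y, exponent=1, metric='riemann')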

v0.3 (July 2022)
----------------

77 changes: 52 additions & 25 deletions examples/simulated/plot_toy_classification.py
@@ -1,34 +1,32 @@
"""
=====================================================================
Illustrate classification accuracy versus class separability
=====================================================================
========================================================================
Classification accuracy and class distinctiveness vs class separability
========================================================================

Generate several datasets containing data points from two classes. Each class
is generated with a Riemannian Gaussian distribution centered at the class mean
and with the same dispersion sigma. The distance between the class means is
parametrized by Delta, which we vary between zero and 3*sigma. We
illustrate how the accuracy of the MDM classifier varies when Delta increases.
illustrate how the accuracy of the MDM classifier and the class
distinctiveness value [1]_ vary when Delta increases.

"""
# Authors: Pedro Rodrigues <pedro.rodrigues@melix.org>
# Maria Sayu Yamamoto <maria-sayu.yamamoto@universite-paris-saclay.fr>
#
# License: BSD (3-clause)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_score, StratifiedKFold

from pyriemann.classification import MDM, class_distinctiveness
from pyriemann.datasets import make_gaussian_blobs


print(__doc__)

###############################################################################
# Set general parameters for the illustrations


n_matrices = 100 # how many matrices to sample on each class
n_dim = 4 # dimensionality of the data points
sigma = 1.0 # dispersion of the Gaussian distributions
@@ -37,6 +35,7 @@
###############################################################################
# Loop over different levels of separability between the classes
scores_array = []
class_dis_array = []
deltas_array = np.linspace(0, 3*sigma, 10)

for delta in deltas_array:
@@ -48,28 +47,56 @@
random_state=random_state,
n_jobs=4)

# which classifier to consider
# measure class distinctiveness of training data for each split
skf = StratifiedKFold(n_splits=5)
all_class_dis = []
for train_ind, _ in skf.split(X, y):
class_dis = class_distinctiveness(X[train_ind], y[train_ind],
exponent=1, metric='riemann',
return_num_denom=False)
all_class_dis.append(class_dis)

# average class distinctiveness across splits
mean_class_dis = np.mean(all_class_dis)
class_dis_array.append(mean_class_dis)

# Now let's train an MDM classifier and measure its performance
clf = MDM()

# get the classification score for this setup
scores_array.append(
cross_val_score(clf, X, y, cv=5, scoring='roc_auc').mean())
cross_val_score(clf, X, y, cv=skf, scoring='roc_auc').mean())

scores_array = np.array(scores_array)
class_dis_array = np.array(class_dis_array)

###############################################################################
# Plot the results
fig, ax = plt.subplots(figsize=(7.5, 5.9))
ax.plot(deltas_array, scores_array, lw=3.0, label=sigma)
ax.set_xticks([0, 1, 2, 3])
ax.set_xticklabels([0, 1, 2, 3], fontsize=12)
ax.set_yticks([0.6, 0.7, 0.8, 0.9, 1.0])
ax.set_yticklabels([0.6, 0.7, 0.8, 0.9, 1.0], fontsize=12)
ax.set_xlabel(r'$\Delta/\sigma$', fontsize=14)
ax.set_ylabel(r'score', fontsize=12)
ax.set_title(r'Classification score Vs class separability ($n_{dim} = 4$)',
fontsize=12)
ax.grid(True)
ax.legend(loc='lower right', title=r'$\sigma$')

fig, (ax1, ax2) = plt.subplots(sharex=True, nrows=2)

ax1.plot(deltas_array, scores_array, lw=3.0, label=r'ROC AUC score')
ax2.plot(deltas_array, class_dis_array, lw=3.0, color='g',
label='Class Distinctiveness')

ax2.set_xlabel(r'$\Delta/\sigma$', fontsize=14)
ax1.set_ylabel(r'ROC AUC score', fontsize=12)
ax2.set_ylabel(r'class distinctiveness', fontsize=12)
ax1.set_title('Classification score and class distinctiveness value\n'
r'vs. class separability ($n_{dim} = 4$)',
fontsize=12)

ax1.grid(True)
ax2.grid(True)
fig.tight_layout()
plt.show()

###############################################################################
# References
# ----------
# .. [1] `Class-distinctiveness-based frequency band selection on the
# Riemannian manifold for oscillatory activity-based BCIs: preliminary
# results
# <https://hal.archives-ouvertes.fr/hal-03641137/>`_
# M. S. Yamamoto, F. Lotte, F. Yger, and S. Chevallier.
# 44th Annual International Conference of the IEEE Engineering
# in Medicine & Biology Society (EMBC2022), 2022.
137 changes: 131 additions & 6 deletions pyriemann/classification.py
@@ -17,7 +17,6 @@


def _check_metric(metric):

if isinstance(metric, str):
metric_mean = metric
metric_dist = metric
@@ -195,7 +194,7 @@ def predict_proba(self, X):
prob : ndarray, shape (n_matrices, n_classes)
Probabilities for each class.
"""
return softmax(-self._predict_distances(X)**2)
return softmax(-self._predict_distances(X) ** 2)


class FgMDM(BaseEstimator, ClassifierMixin, TransformerMixin):
@@ -547,7 +546,7 @@ def predict_proba(self, X):
idx = np.argsort(dist)
dist_sorted = np.take_along_axis(dist, idx, axis=1)
neighbors_classes = self.classmeans_[idx]
probas = softmax(-dist_sorted[:, 0:self.n_neighbors]**2)
probas = softmax(-dist_sorted[:, 0:self.n_neighbors] ** 2)

prob = np.zeros((n_matrices, len(self.classes_)))
for m in range(n_matrices):
@@ -819,7 +818,7 @@ def _get_label(self, x, labs_unique):
for ip, p in enumerate(self.power_list):
for ill, ll in enumerate(labs_unique):
m[ip, ill] = distance(
x, self.covmeans_[p][ll], metric=self.metric)**2
x, self.covmeans_[p][ll], metric=self.metric) ** 2

if self.method_label == 'sum_means':
ipmin = np.argmin(np.sum(m, axis=1))
@@ -863,7 +862,7 @@ def _predict_distances(self, X):
m[p].append(
distance(
x, self.covmeans_[p][ll], metric=self.metric
)**2
) ** 2
)
pmin = min(m.items(), key=lambda x: np.sum(x[1]))[0]
dist.append(np.array(m[pmin]))
@@ -903,4 +902,130 @@ def predict_proba(self, X):
prob : ndarray, shape (n_matrices, n_classes)
Probabilities for each class.
"""
return softmax(-self._predict_distances(X)**2)
return softmax(-self._predict_distances(X) ** 2)


def class_distinctiveness(X, y, exponent=1, metric='riemann',
return_num_denom=False):
r"""Measure class distinctiveness between classes of SPD matrices.

For a two-class problem, the class distinctiveness between classes A
and B on the manifold of SPD matrices is quantified as [1]_:

.. math::
\mathrm{classDis}(A, B, p) = \frac{d \left(\bar{X}^{A},
\bar{X}^{B}\right)^p}
{\frac{1}{2} \left( \sigma_{X^{A}}^p + \sigma_{X^{B}}^p \right)}

where :math:`\bar{X}^{K}` is the center of class K, ie the mean of matrices
from class K (see :func:`pyriemann.utils.mean.mean_covariance`) and
:math:`\sigma_{X^{K}}` is the class dispersion, ie the mean of distances
between matrices from class K and their center of class
:math:`\bar{X}^{K}`:

.. math::
\sigma_{X^{K}}^p = \frac{1}{m} \sum_{i=1}^m d
\left(X_i, \bar{X}^{K}\right)^p

For more than two classes, it is quantified as:

.. math::
\mathrm{classDis}\left(\left\{K_{j}\right\}, p\right) =
\frac{\sum_{j=1}^{c} d\left(\bar{X}^{K_{j}}, \tilde{X}\right)^p}
{\sum_{j=1}^{c} \sigma_{X^{K_{j}}}^p}

where :math:`\tilde{X}` is the mean of the class centers of all :math:`c`
classes, and :math:`p` is the exponent applied to the distance measure,
given by the `exponent` parameter of this function.

Parameters
----------
X : ndarray, shape (n_matrices, n_channels, n_channels)
Set of SPD matrices.
y : ndarray, shape (n_matrices,)
Labels for each matrix.
exponent : int, default=1
Parameter for exponentiation of distances, corresponding to p in the
above equations:

- exponent = 1 gives the formula originally defined in [1]_;
- exponent = 2 gives the Fisher criterion generalized on the manifold,
ie the ratio of the variance between the classes to the variance
within the classes.
metric : string | dict, default='riemann'
The type of metric used for centroid and distance estimation.
See `mean_covariance` for the list of supported metrics.
The metric can be a dict with two keys, `mean` and `distance`, in
order to pass different metrics for the centroid estimation and the
distance estimation. The original definition of class distinctiveness
in [1]_ uses 'riemann' for both the centroid and the distance
estimation, but other metrics can be used if needed.
return_num_denom : bool, default=False
Whether to return numerator and denominator of class_dis.

Returns
-------
class_dis : float
Class distinctiveness value.
num : float
Numerator value of class_dis. Returned only if return_num_denom is
True.
denom : float
Denominator value of class_dis. Returned only if return_num_denom is
True.

Notes
-----
.. versionadded:: 0.3.1

References
----------
.. [1] `Defining and quantifying users’ mental imagery-based
BCI skills: a first step
<https://hal.archives-ouvertes.fr/hal-01846434/>`_
F. Lotte, and C. Jeunet. Journal of neural engineering,
15(4), 046030, 2018.
"""

metric_mean, metric_dist = _check_metric(metric)
classes = np.unique(y)
if len(classes) <= 1:
raise ValueError('X must contain at least two classes')

means = np.array([
mean_covariance(X[y == ll], metric=metric_mean) for ll in classes
])

if len(classes) == 2:
num = distance(means[0], means[1], metric=metric_dist) ** exponent
denom = 0.5 * _get_within(X, y, means, classes, exponent, metric_dist)

else:
mean_all = mean_covariance(means, metric=metric_mean)
dists_between = [
distance(m, mean_all, metric=metric_dist) ** exponent
for m in means
]
num = np.sum(dists_between)
denom = _get_within(X, y, means, classes, exponent, metric_dist)

class_dis = num / denom

if return_num_denom:
return class_dis, num, denom
else:
return class_dis


def _get_within(X, y, means, classes, exponent, metric_dist):
"""Private function to compute within dispersion."""
sigmas = []
for ii, ll in enumerate(classes):
dists_within = [
distance(x, means[ii], metric=metric_dist) ** exponent
for x in X[y == ll]
]
sigmas.append(np.mean(dists_within))
sum_sigmas = np.sum(sigmas)
return sum_sigmas
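
To make the two-class formula above concrete, here is a small, self-contained sketch that calls the new function and re-derives its numerator and denominator with `mean_covariance` and `distance`, the same utilities used inside `class_distinctiveness`. The `make_gaussian_blobs` arguments mirror the example script above and are only illustrative; the final check is an informal sanity check, not part of this PR.

import numpy as np

from pyriemann.classification import class_distinctiveness
from pyriemann.datasets import make_gaussian_blobs
from pyriemann.utils.distance import distance
from pyriemann.utils.mean import mean_covariance

# sample two Riemannian Gaussian classes, moderately separated
X, y = make_gaussian_blobs(n_matrices=50, n_dim=4, class_sep=2.0,
                           class_disp=1.0, random_state=42, n_jobs=1)

# class distinctiveness with the original definition (exponent=1)
class_dis, num, denom = class_distinctiveness(
    X, y, exponent=1, metric='riemann', return_num_denom=True)

# re-derive the two-class formula: the numerator is the distance between
# the class means, the denominator half the sum of the class dispersions
classes = np.unique(y)
means = [mean_covariance(X[y == c], metric='riemann') for c in classes]
num_manual = distance(means[0], means[1], metric='riemann')
disps = [np.mean([distance(x, means[i], metric='riemann')
                  for x in X[y == c]])
         for i, c in enumerate(classes)]
denom_manual = 0.5 * np.sum(disps)

# both values should agree up to numerical precision
print(class_dis, num_manual / denom_manual)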
33 changes: 32 additions & 1 deletion tests/test_classification.py
@@ -16,7 +16,8 @@
KNearestNeighbor,
TSclassifier,
SVC,
MeanField
MeanField,
class_distinctiveness,
)

rclf = [MDM, FgMDM, KNearestNeighbor, TSclassifier, SVC, MeanField]
@@ -340,3 +341,33 @@ def test_meanfield(get_covmats, get_labels, method_label):
assert proba.shape == (n_matrices, n_classes)
transf = mf.transform(covmats)
assert transf.shape == (n_matrices, n_classes)


@pytest.mark.parametrize("n_classes", [1, 2, 3])
@pytest.mark.parametrize("metric_mean", get_means())
@pytest.mark.parametrize("metric_dist", get_distances())
@pytest.mark.parametrize("exponent", [1, 2])
def test_class_distinctiveness(get_covmats, get_labels,
n_classes, metric_mean, metric_dist, exponent):
"""Test function for class distinctiveness measure for two class problem"""
n_matrices, n_channels = 6, 3
covmats = get_covmats(n_matrices, n_channels)
labels = get_labels(n_matrices, n_classes)
if n_classes == 1:
with pytest.raises(ValueError):
class_distinctiveness(covmats, labels)
return

class_dis, num, denom = class_distinctiveness(
covmats,
labels,
exponent,
metric={"mean": metric_mean, "distance": metric_dist},
return_num_denom=True
)
assert class_dis >= 0  # class_dis must be non-negative
assert num >= 0  # numerator must be non-negative
assert denom >= 0  # denominator must be non-negative
assert isinstance(class_dis, float), "Unexpected type of class_dis"
assert isinstance(num, float), "Unexpected type of num"
assert isinstance(denom, float), "Unexpected type of denom"