Skip to content

Commit

Permalink
Adding Riemannian Gaussian to pyRiemann (#140)
Browse files Browse the repository at this point in the history
* adding the sampling.py feature and two examples related to it

* cleaning up

* correcting pep8 issues

* getting rid of the tqdm dependency

* Update examples/sampling/plot_riemannian_gaussian.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update examples/sampling/plot_toy_classification.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Remove ignored files

* Update examples/sampling/plot_toy_classification.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update examples/sampling/plot_toy_classification.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update examples/sampling/plot_toy_classification.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update pyriemann/datasets.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update examples/sampling/plot_riemannian_gaussian.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update pyriemann/sampling.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* taking into account agramfort and qbarthelemy comments

* more updates

* adding make_outliers function

* putting make_* functions into simulated.py for consistency

* including corrections and comments

* ...

* further corrections

* ...

* flake8

* adding a few tests

* flake8

* including reviews; adding a new test

* ...

* adding warnings

* trying to edit API.rst

* Add Riemannian Potato Field (#142)

* add PotatoField and improve Potato

* add example on PotatoField and clean example on Potato

* add tests and complete api

* minor corrections

* minor modifs

* commit suggestions

* complete doc

* complete doc again

* Apply suggestions from code review

Co-authored-by: Pedro L. C. Rodrigues <pedro.rodrigues01@gmail.com>

* complete doc

Co-authored-by: Pedro L. C. Rodrigues <pedro.rodrigues01@gmail.com>

* remove class templates in api.rst for new functionns

* adding things to whatsnew.rst

* ...

* adding the sampling.py feature and two examples related to it

* cleaning up

* correcting pep8 issues

* getting rid of the tqdm dependency

* Update examples/sampling/plot_riemannian_gaussian.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update examples/sampling/plot_toy_classification.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Remove ignored files

* Update examples/sampling/plot_toy_classification.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update examples/sampling/plot_toy_classification.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update examples/sampling/plot_toy_classification.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update pyriemann/datasets.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update examples/sampling/plot_riemannian_gaussian.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update pyriemann/sampling.py

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* taking into account agramfort and qbarthelemy comments

* more updates

* adding make_outliers function

* putting make_* functions into simulated.py for consistency

* including corrections and comments

* ...

* further corrections

* ...

* flake8

* adding a few tests

* flake8

* including reviews; adding a new test

* ...

* adding warnings

* trying to edit API.rst

* remove class templates in api.rst for new functionns

* adding things to whatsnew.rst

* ...

* adding ValueError tests

* solving problem coming from conftest.py

* ...

* ...

* trying to make tests pass?

* flake8

* re-add init file

* move is_* functions into pyriemann to be reused by pyriemann methods

* remove pytest import from utils.test

* re-add tests for raise errors

* refactor test

* revert changes on doc/conf.py

* complete tests

* updating api.rst

* fix whatsnew.rst

* Update doc/whatsnew.rst

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>

Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
  • Loading branch information
3 people committed Oct 19, 2021
1 parent 36b1ad0 commit 608fdfe
Show file tree
Hide file tree
Showing 15 changed files with 960 additions and 130 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Expand Up @@ -7,3 +7,6 @@
/doc/auto_examples
/dist
*-e
*.egg-info/
.vscode/
sandbox/
17 changes: 15 additions & 2 deletions doc/api.rst
Expand Up @@ -60,7 +60,6 @@ Clustering
Potato
PotatoField


Tangent Space
------------------
.. _tangentspace_api:
Expand Down Expand Up @@ -109,7 +108,7 @@ Channel selection
:template: class.rst

ElectrodeSelection
FlatChannelRemover
FlatChannelRemover

Stats
------------------
Expand All @@ -123,6 +122,19 @@ Stats
PermutationDistance
PermutationModel

Datasets
------------------
.. _datasets_api:
.. currentmodule:: pyriemann.datasets

.. autosummary::
:toctree: generated/

make_gaussian_blobs
make_outliers
make_covariances
sample_gaussian_spd
generate_random_spd_matrix

Utils function
--------------
Expand Down Expand Up @@ -233,3 +245,4 @@ Aproximate Joint Diagonalization
rjd
ajd_pham
uwedge

4 changes: 3 additions & 1 deletion doc/whatsnew.rst
Expand Up @@ -24,7 +24,9 @@ v0.2.8.dev

- Refactor tests + fix refit of :class:`pyriemann.tangentspace.TangentSpace`

- Add :class:`pyriemann.clustering.PotatoField`, and an example on artifact detection
- Add sampling SPD matrices from a Riemannian Gaussian distribution in :func:`pyriemann.datasets.sample_gaussian_spd`

- Add new function :func:`pyriemann.datasets.make_gaussian_blobs` for generating random datasets with SPD matrices

v0.2.7 (June 2021)
------------------
Expand Down
4 changes: 4 additions & 0 deletions examples/simulated/README.txt
@@ -0,0 +1,4 @@
Simulated data
---------------

Examples using datasets sampled from known probability distributions
77 changes: 77 additions & 0 deletions examples/simulated/plot_riemannian_gaussian.py
@@ -0,0 +1,77 @@
"""
=====================================================================
Sample from the Riemannian Gaussian distribution in the SPD manifold
=====================================================================
Spectral embedding of samples from the Riemannian Gaussian distribution
with different centerings and dispersions.
"""
# Authors: Pedro Rodrigues <pedro.rodrigues@melix.org>
#
# License: BSD (3-clause)

import numpy as np
import matplotlib.pyplot as plt

from pyriemann.embedding import Embedding
from pyriemann.datasets import sample_gaussian_spd, generate_random_spd_matrix


print(__doc__)

###############################################################################
# Set parameters for sampling from the Riemannian Gaussian distribution
n_matrices = 100 # how many SPD matrices to generate
n_dim = 4 # number of dimensions of the SPD matrices
sigma = 1.0 # dispersion of the Gaussian distribution
epsilon = 4.0 # parameter for controlling the distance between centers
random_state = 42 # ensure reproducibility

# Generate the samples on three different conditions
mean = generate_random_spd_matrix(n_dim) # random reference point

samples_1 = sample_gaussian_spd(n_matrices=n_matrices,
mean=mean,
sigma=sigma,
random_state=random_state)
samples_2 = sample_gaussian_spd(n_matrices=n_matrices,
mean=mean,
sigma=sigma/2,
random_state=random_state)
samples_3 = sample_gaussian_spd(n_matrices=n_matrices,
mean=epsilon*mean,
sigma=sigma,
random_state=random_state)

# Stack all of the samples into one data array for the embedding
samples = np.concatenate([samples_1, samples_2, samples_3])
labels = np.array(n_matrices*[1] + n_matrices*[2] + n_matrices*[3])

###############################################################################
# Apply the spectral embedding over the SPD matrices
lapl = Embedding(metric='riemann', n_components=2)
embd = lapl.fit_transform(X=samples)

###############################################################################
# Plot the results

fig, ax = plt.subplots(figsize=(8, 6))

colors = {1: 'C0', 2: 'C1', 3: 'C2'}
for i in range(len(samples)):
ax.scatter(embd[i, 0], embd[i, 1], c=colors[labels[i]], s=50)
ax.scatter([], [], c='C0', s=50, label=r'$\varepsilon = 1.00, \sigma = 1.00$')
ax.scatter([], [], c='C1', s=50, label=r'$\varepsilon = 1.00, \sigma = 0.50$')
ax.scatter([], [], c='C2', s=50, label=r'$\varepsilon = 4.00, \sigma = 1.00$')
ax.set_xticks([-1, -0.5, 0, 0.5, 1.0])
ax.set_xticklabels([-1, -0.5, 0, 0.5, 1.0], fontsize=12)
ax.set_yticks([-1, -0.5, 0, 0.5, 1.0])
ax.set_yticklabels([-1, -0.5, 0, 0.5, 1.0], fontsize=12)
ax.set_title(r'Spectral embedding of data points (fixed $n_{dim} = 4$)',
fontsize=14)
ax.set_xlabel(r'$\phi_1$', fontsize=14)
ax.set_ylabel(r'$\phi_2$', fontsize=14)
ax.legend()

plt.show()
74 changes: 74 additions & 0 deletions examples/simulated/plot_toy_classification.py
@@ -0,0 +1,74 @@
"""
=====================================================================
Illustrate classification accuracy versus class separability
=====================================================================
Generate several datasets containing data points from two-classes. Each class
is generated with a Riemannian Gaussian distribution centered at the class mean
and with the same dispersion sigma. The distance between the class means is
parametrized by Delta, which we make vary between zero and 5*sigma. We
illustrate how the accuracy of the MDM classifier varies when Delta increases.
"""
# Authors: Pedro Rodrigues <pedro.rodrigues@melix.org>
#
# License: BSD (3-clause)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_score

from pyriemann.classification import MDM
from pyriemann.datasets import make_gaussian_blobs


print(__doc__)


###############################################################################
# Set general parameters for the illustrations


n_matrices = 100 # how many matrices to sample on each class
n_dim = 4 # dimensionality of the data points
sigma = 1.0 # dispersion of the Gaussian distributions
random_state = 42 # ensure reproducibility

###############################################################################
# Loop over different levels of separability between the classes
scores_array = []
deltas_array = np.linspace(0, 5*sigma, 5)

for delta in deltas_array:
# generate data points for a classification problem
X, y = make_gaussian_blobs(n_matrices=n_matrices,
n_dim=n_dim,
class_sep=delta,
class_disp=sigma,
random_state=random_state)

# which classifier to consider
clf = MDM()

# get the classification score for this setup
scores_array.append(
cross_val_score(clf, X, y, cv=5, scoring='roc_auc').mean())

scores_array = np.array(scores_array)

###############################################################################
# Plot the results
fig, ax = plt.subplots(figsize=(7.5, 5.9))
ax.plot(deltas_array, scores_array, lw=3.0, label=sigma)
ax.set_xticks([0, 1, 2, 3, 4, 5])
ax.set_xticklabels([0, 1, 2, 3, 4, 5], fontsize=12)
ax.set_yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
ax.set_yticklabels([0.5, 0.6, 0.7, 0.8, 0.9, 1.0], fontsize=12)
ax.set_xlabel(r'$\Delta/\sigma$', fontsize=14)
ax.set_ylabel(r'score', fontsize=12)
ax.set_title(r'Classification score Vs class separability ($n_{dim} = 4$)',
fontsize=12)
ax.grid(True)
ax.legend(loc='lower right', title=r'$\sigma$')

plt.show()
9 changes: 9 additions & 0 deletions pyriemann/datasets/__init__.py
@@ -0,0 +1,9 @@
from .sampling import sample_gaussian_spd, generate_random_spd_matrix
from .simulated import make_covariances, make_gaussian_blobs, make_outliers


__all__ = ["sample_gaussian_spd",
"generate_random_spd_matrix",
"make_covariances",
"make_gaussian_blobs",
"make_outliers"]

0 comments on commit 608fdfe

Please sign in to comment.