Skip to content

Commit

Permalink
Add example on ERP (#144)
Browse files Browse the repository at this point in the history
* add plot_erp and complete test

* add example

* rename example on embedding

* correct flake8

* remove erp from doc

* remove chax parameter

* remove numbers

* minor modif

* add module viz in api

* define mne version

* replace kwargs by parameters
  • Loading branch information
qbarthelemy committed Nov 5, 2021
1 parent e285170 commit f06a33e
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 12 deletions.
18 changes: 13 additions & 5 deletions doc/api.rst
Expand Up @@ -108,7 +108,7 @@ Channel selection
:template: class.rst

ElectrodeSelection
FlatChannelRemover
FlatChannelRemover

Stats
------------------
Expand Down Expand Up @@ -156,7 +156,6 @@ Covariance preprocessing
normalize
get_nondiag_weight


Distances
~~~~~~~~~~~~~~~~~~~~~~
.. _distance_api:
Expand All @@ -174,7 +173,6 @@ Distances
distance_kullback_sym
distance_wasserstein


Mean
~~~~~~~~~~~~~~~~~~~~~~
.. _mean_api:
Expand All @@ -194,7 +192,6 @@ Mean
mean_harmonic
mean_kullback_sym


Geodesic
~~~~~~~~~~~~~~~~~~~~~~
.. _geodesic_api:
Expand All @@ -208,7 +205,6 @@ Geodesic
geodesic_euclid
geodesic_logeuclid


Tangent Space
~~~~~~~~~~~~~~~~~~~~~~
.. _ts_base_api:
Expand Down Expand Up @@ -246,3 +242,15 @@ Aproximate Joint Diagonalization
ajd_pham
uwedge

Visualization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. _viz_api:
.. currentmodule:: pyriemann.utils.viz

.. autosummary::
:toctree: generated/

plot_confusion_matrix
plot_embedding
plot_cospectra
plot_waveforms
2 changes: 1 addition & 1 deletion doc/requirements.txt
@@ -1,7 +1,7 @@
sphinx-gallery
sphinx-bootstrap_theme
numpydoc
mne
mne==0.23.4
scikit-learn
seaborn
pandas
Expand Down
2 changes: 2 additions & 0 deletions doc/whatsnew.rst
Expand Up @@ -30,6 +30,8 @@ v0.2.8.dev

- Add new function :func:`pyriemann.datasets.make_gaussian_blobs` for generating random datasets with SPD matrices

- Add module ``pyriemann.utils.viz`` in API, add :func:`pyriemann.utils.viz.plot_waveforms`, and add an example on ERP visualization

v0.2.7 (June 2021)
------------------

Expand Down
102 changes: 102 additions & 0 deletions examples/ERP/plot_ERP.py
@@ -0,0 +1,102 @@
"""
===============================================================================
Display ERP
===============================================================================
Different ways to display a multichannel event-related potential (ERP).
"""
# Authors: Quentin Barthélemy
#
# License: BSD (3-clause)

import numpy as np
import mne
from matplotlib import pyplot as plt
from pyriemann.utils.viz import plot_waveforms


###############################################################################
# Load EEG data
# -------------

# Set filenames
data_path = mne.datasets.sample.data_path()
raw_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw.fif"
event_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif"

# Read raw data, select occipital channels and high-pass filter signal
raw = mne.io.Raw(raw_fname, preload=True, verbose=False)
raw.pick_channels(['EEG 057', 'EEG 058', 'EEG 059'], ordered=True)
raw.rename_channels({'EEG 057': 'O1', 'EEG 058': 'Oz', 'EEG 059': 'O2'})
n_channels = len(raw.ch_names)
raw.filter(1.0, None, method="iir")

# Read epochs and get responses to left visual field stimulus
tmin, tmax = -0.1, 0.8
epochs = mne.Epochs(
raw, mne.read_events(event_fname), {'vis_l': 3}, tmin, tmax, proj=False,
baseline=None, preload=True, verbose=False)
X = 5e5 * epochs.get_data()
print('Number of trials:', X.shape[0])
times = np.linspace(tmin, tmax, num=X.shape[2])

plt.rcParams["figure.figsize"] = (7, 12)
ylims = []


###############################################################################
# Plot all trials
# ---------------
#
# This kind of plot is a little bit messy.

fig = plot_waveforms(X, 'all', times=times, alpha=0.3)
fig.suptitle('Plot all trials', fontsize=16)
for i_channel in range(n_channels):
fig.axes[i_channel].set(ylabel=raw.ch_names[i_channel])
fig.axes[i_channel].set_xlim(tmin, tmax)
ylims.append(fig.axes[i_channel].get_ylim())
fig.axes[n_channels - 1].set(xlabel='Time')
plt.show()


###############################################################################
# Plot central tendency and dispersion of trials
# ----------------------------------------------
#
# This kind of plot is well-spread, but mean and standard deviation can be
# contaminated by artifacts, and they make a symmetric assumption on amplitude
# distribution.

fig = plot_waveforms(X, 'mean+/-std', times=times)
fig.suptitle('Plot mean+/-std of trials', fontsize=16)
for i_channel in range(n_channels):
fig.axes[i_channel].set(ylabel=raw.ch_names[i_channel])
fig.axes[i_channel].set_xlim(tmin, tmax)
fig.axes[i_channel].set_ylim(ylims[i_channel])
fig.axes[n_channels - 1].set(xlabel='Time')
plt.show()


###############################################################################
# Plot histogram of trials
# ------------------------
#
# This plot estimates a 2D histogram of trials [1]_.

fig = plot_waveforms(X, 'hist', times=times, n_bins=25, cmap=plt.cm.Greys)
fig.suptitle('Plot histogram of trials', fontsize=16)
for i_channel in range(n_channels):
fig.axes[i_channel].set(ylabel=raw.ch_names[i_channel])
fig.axes[i_channel].set_ylim(ylims[i_channel])
fig.axes[n_channels - 1].set(xlabel='Time')
plt.show()


###############################################################################
# References
# ----------
# .. [1] A. Souloumiac and B. Rivet, "Improved estimation of EEG evoked
# potentials by jitter compensation and enhancing spatial filters", ICASSP,
# 2013.
@@ -1,6 +1,6 @@
"""
=====================================================================
Embedding ERP EEG data in 2D Euclidean space with Laplacian Eigenmaps
Embedding ERP MEG data in 2D Euclidean space with Laplacian Eigenmaps
=====================================================================
Spectral embedding via Laplacian Eigenmaps of a set of ERP data.
Expand Down
112 changes: 107 additions & 5 deletions pyriemann/utils/viz.py
Expand Up @@ -2,8 +2,8 @@
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix
from pyriemann.embedding import Embedding
from pyriemann.utils import deprecated
from ..embedding import Embedding
from . import deprecated


@deprecated(
Expand Down Expand Up @@ -63,8 +63,8 @@ def plot_embedding(
return fig


def plot_cospectra(cosp, freqs, ylabels=None, title="Cospectra"):
"""Plot cospectral matrices
def plot_cospectra(cosp, freqs, *, ylabels=None, title="Cospectra"):
"""Plot cospectral matrices.
Parameters
----------
Expand All @@ -77,14 +77,25 @@ def plot_cospectra(cosp, freqs, ylabels=None, title="Cospectra"):
-------
fig : matplotlib figure
Figure of cospectra.
Notes
-----
.. versionadded:: 0.2.7
"""
try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError("Install matplotlib to plot cospectra")

if cosp.ndim != 3:
raise Exception('Input cosp has not 3 dimensions')
n_freqs, n_channels, _ = cosp.shape
if freqs.shape != (n_freqs,):
raise Exception(
'Input freqs has not the same number of frequencies as cosp')

fig = plt.figure(figsize=(12, 7))
fig.suptitle(title)
n_freqs = min(cosp.shape[0], freqs.shape[0])
for f in range(n_freqs):
ax = plt.subplot((n_freqs - 1) // 8 + 1, 8, f + 1)
plt.imshow(cosp[f], cmap=plt.get_cmap("Reds"))
Expand All @@ -100,3 +111,94 @@ def plot_cospectra(cosp, freqs, ylabels=None, title="Cospectra"):
plt.yticks([])

return fig


def plot_waveforms(X, display, *, times=None, color='gray', alpha=0.5,
linewidth=1.5, color_mean='k', color_std='gray', n_bins=50,
cmap=None):
''' Display repetitions of a multichannel waveform.
Parameters
----------
X : ndarray, shape (n_reps, n_channels, n_times)
Repetitions of the multichannel waveform.
display : {'all', 'mean', 'mean+/-std', 'hist'}
Type of display:
* 'all' for all the repetitions;
* 'mean' for the mean of the repetitions;
* 'mean+/-std' for the mean +/- standard deviation of the repetitions;
* 'hist' for the 2D histogram of the repetitions.
time : None | ndarray, shape (n_times,) (default None)
Values to display on x-axis.
color : matplotlib color, optional
Color of the lines, when ``display=all``.
alpha : float, optional
Alpha value used to cumulate repetitions, when ``display=all``.
linewidth : float, optional
Line width in points, when ``display=mean``.
color_mean : matplotlib color, optional
Color of the mean line, when ``display=mean``.
color_std : matplotlib color, optional
Color of the standard deviation area, when ``display=mean+/-std``.
n_bins : int, optional
Number of vertical bins for the 2D histogram, when ``display=hist``.
cmap : Colormap or str, optional
Color map for the histogram, when ``display=hist``.
Returns
-------
fig : matplotlib figure
Figure of waveform (one subplot by channel).
Notes
-----
.. versionadded:: 0.2.8
'''
try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError("Install matplotlib to plot waveforms")

if X.ndim != 3:
raise Exception('Input X has not 3 dimensions')
n_reps, n_channels, n_times = X.shape
if times is None:
times = np.arange(n_times)
elif times.shape != (n_times,):
raise Exception(
'Parameter times has not the same number of values as X')

fig, axes = plt.subplots(nrows=n_channels, ncols=1)
if n_channels == 1:
axes = [axes]
channels = np.arange(n_channels)

if display == 'all':
for (channel, ax) in zip(channels, axes):
for i_rep in range(n_reps):
ax.plot(times, X[i_rep, channel], c=color, alpha=alpha)

elif display in ['mean', 'mean+/-std']:
mean = np.mean(X, axis=0)
for (channel, ax) in zip(channels, axes):
ax.plot(times, mean[channel], c=color_mean, lw=linewidth)
if display == 'mean+/-std':
std = np.std(X, axis=0)
for (channel, ax) in zip(channels, axes):
ax.fill_between(times, mean[channel] - std[channel],
mean[channel] + std[channel], color=color_std)

elif display == 'hist':
times_rep = np.repeat(times[np.newaxis, :], n_reps, axis=0)
for (channel, ax) in zip(channels, axes):
ax.hist2d(times_rep.ravel(), X[:, channel, :].ravel(),
bins=(n_times, n_bins), cmap=cmap)

else:
raise Exception('Parameter display unknown %s' % display)

if n_channels > 1:
for ax in axes[:-1]:
ax.set_xticklabels([]) # remove xticklabels
return fig
14 changes: 14 additions & 0 deletions tests/test_viz.py
Expand Up @@ -6,6 +6,7 @@
plot_confusion_matrix,
plot_embedding,
plot_cospectra,
plot_waveforms
)


Expand Down Expand Up @@ -36,3 +37,16 @@ def test_cospectra():
cosp = np.random.randn(n_freqs, n_channels, n_channels)
freqs = np.random.randn(n_freqs)
plot_cospectra(cosp, freqs)


@requires_matplotlib
@pytest.mark.parametrize("display", ["all", "mean", "mean+/-std", "hist"])
def test_plot_waveforms(display):
"""Test plot_waveforms"""
n_matrices, n_channels, n_times = 16, 3, 50
X = np.random.randn(n_matrices, n_channels, n_times)
plot_waveforms(X, display)
plot_waveforms(X, display, times=np.arange(n_times))

X = np.random.randn(n_matrices, 1, n_times)
plot_waveforms(X, display)

0 comments on commit f06a33e

Please sign in to comment.