Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example on ERP #144

Merged
merged 11 commits into from Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 13 additions & 5 deletions doc/api.rst
Expand Up @@ -108,7 +108,7 @@ Channel selection
:template: class.rst

ElectrodeSelection
FlatChannelRemover
FlatChannelRemover

Stats
------------------
Expand Down Expand Up @@ -156,7 +156,6 @@ Covariance preprocessing
normalize
get_nondiag_weight


Distances
~~~~~~~~~~~~~~~~~~~~~~
.. _distance_api:
Expand All @@ -174,7 +173,6 @@ Distances
distance_kullback_sym
distance_wasserstein


Mean
~~~~~~~~~~~~~~~~~~~~~~
.. _mean_api:
Expand All @@ -194,7 +192,6 @@ Mean
mean_harmonic
mean_kullback_sym


Geodesic
~~~~~~~~~~~~~~~~~~~~~~
.. _geodesic_api:
Expand All @@ -208,7 +205,6 @@ Geodesic
geodesic_euclid
geodesic_logeuclid


Tangent Space
~~~~~~~~~~~~~~~~~~~~~~
.. _ts_base_api:
Expand Down Expand Up @@ -246,3 +242,15 @@ Aproximate Joint Diagonalization
ajd_pham
uwedge

Visualization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. _viz_api:
.. currentmodule:: pyriemann.utils.viz

.. autosummary::
:toctree: generated/

plot_confusion_matrix
plot_embedding
plot_cospectra
plot_waveforms
2 changes: 1 addition & 1 deletion doc/requirements.txt
@@ -1,7 +1,7 @@
sphinx-gallery
sphinx-bootstrap_theme
numpydoc
mne
mne==0.23.4
scikit-learn
seaborn
pandas
Expand Down
2 changes: 2 additions & 0 deletions doc/whatsnew.rst
Expand Up @@ -30,6 +30,8 @@ v0.2.8.dev

- Add new function :func:`pyriemann.datasets.make_gaussian_blobs` for generating random datasets with SPD matrices

- Add module ``pyriemann.utils.viz`` in API, add :func:`pyriemann.utils.viz.plot_waveforms`, and add an example on ERP visualization

v0.2.7 (June 2021)
------------------

Expand Down
102 changes: 102 additions & 0 deletions examples/ERP/plot_ERP.py
@@ -0,0 +1,102 @@
"""
===============================================================================
Display ERP
===============================================================================

Different ways to display a multichannel event-related potential (ERP).

"""
# Authors: Quentin Barthélemy
#
# License: BSD (3-clause)

import numpy as np
import mne
from matplotlib import pyplot as plt
from pyriemann.utils.viz import plot_waveforms


###############################################################################
# Load EEG data
# -------------

# Set filenames
data_path = mne.datasets.sample.data_path()
raw_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw.fif"
event_fname = data_path + "/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif"

# Read raw data, select occipital channels and high-pass filter signal
raw = mne.io.Raw(raw_fname, preload=True, verbose=False)
raw.pick_channels(['EEG 057', 'EEG 058', 'EEG 059'], ordered=True)
raw.rename_channels({'EEG 057': 'O1', 'EEG 058': 'Oz', 'EEG 059': 'O2'})
n_channels = len(raw.ch_names)
raw.filter(1.0, None, method="iir")

# Read epochs and get responses to left visual field stimulus
tmin, tmax = -0.1, 0.8
epochs = mne.Epochs(
raw, mne.read_events(event_fname), {'vis_l': 3}, tmin, tmax, proj=False,
baseline=None, preload=True, verbose=False)
X = 5e5 * epochs.get_data()
print('Number of trials:', X.shape[0])
times = np.linspace(tmin, tmax, num=X.shape[2])

plt.rcParams["figure.figsize"] = (7, 12)
ylims = []


###############################################################################
# Plot all trials
# ---------------
#
# This kind of plot is a little bit messy.

fig = plot_waveforms(X, 'all', times=times, alpha=0.3)
fig.suptitle('Plot all trials', fontsize=16)
for i_channel in range(n_channels):
fig.axes[i_channel].set(ylabel=raw.ch_names[i_channel])
fig.axes[i_channel].set_xlim(tmin, tmax)
ylims.append(fig.axes[i_channel].get_ylim())
fig.axes[n_channels - 1].set(xlabel='Time')
plt.show()


###############################################################################
# Plot central tendency and dispersion of trials
# ----------------------------------------------
#
# This kind of plot is well-spread, but mean and standard deviation can be
# contaminated by artifacts, and they make a symmetric assumption on amplitude
# distribution.

fig = plot_waveforms(X, 'mean+/-std', times=times)
fig.suptitle('Plot mean+/-std of trials', fontsize=16)
for i_channel in range(n_channels):
fig.axes[i_channel].set(ylabel=raw.ch_names[i_channel])
fig.axes[i_channel].set_xlim(tmin, tmax)
fig.axes[i_channel].set_ylim(ylims[i_channel])
fig.axes[n_channels - 1].set(xlabel='Time')
plt.show()


###############################################################################
# Plot histogram of trials
# ------------------------
#
# This plot estimates a 2D histogram of trials [1]_.

fig = plot_waveforms(X, 'hist', times=times, n_bins=25, cmap=plt.cm.Greys)
fig.suptitle('Plot histogram of trials', fontsize=16)
for i_channel in range(n_channels):
fig.axes[i_channel].set(ylabel=raw.ch_names[i_channel])
fig.axes[i_channel].set_ylim(ylims[i_channel])
fig.axes[n_channels - 1].set(xlabel='Time')
plt.show()


###############################################################################
# References
# ----------
# .. [1] A. Souloumiac and B. Rivet, "Improved estimation of EEG evoked
# potentials by jitter compensation and enhancing spatial filters", ICASSP,
# 2013.
@@ -1,6 +1,6 @@
"""
=====================================================================
Embedding ERP EEG data in 2D Euclidean space with Laplacian Eigenmaps
Embedding ERP MEG data in 2D Euclidean space with Laplacian Eigenmaps
=====================================================================

Spectral embedding via Laplacian Eigenmaps of a set of ERP data.
Expand Down
112 changes: 107 additions & 5 deletions pyriemann/utils/viz.py
Expand Up @@ -2,8 +2,8 @@
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix
from pyriemann.embedding import Embedding
from pyriemann.utils import deprecated
from ..embedding import Embedding
from . import deprecated


@deprecated(
Expand Down Expand Up @@ -63,8 +63,8 @@ def plot_embedding(
return fig


def plot_cospectra(cosp, freqs, ylabels=None, title="Cospectra"):
"""Plot cospectral matrices
def plot_cospectra(cosp, freqs, *, ylabels=None, title="Cospectra"):
"""Plot cospectral matrices.

Parameters
----------
Expand All @@ -77,14 +77,25 @@ def plot_cospectra(cosp, freqs, ylabels=None, title="Cospectra"):
-------
fig : matplotlib figure
Figure of cospectra.

Notes
-----
.. versionadded:: 0.2.7
"""
try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError("Install matplotlib to plot cospectra")

if cosp.ndim != 3:
raise Exception('Input cosp has not 3 dimensions')
n_freqs, n_channels, _ = cosp.shape
if freqs.shape != (n_freqs,):
raise Exception(
'Input freqs has not the same number of frequencies as cosp')

fig = plt.figure(figsize=(12, 7))
fig.suptitle(title)
n_freqs = min(cosp.shape[0], freqs.shape[0])
for f in range(n_freqs):
ax = plt.subplot((n_freqs - 1) // 8 + 1, 8, f + 1)
plt.imshow(cosp[f], cmap=plt.get_cmap("Reds"))
Expand All @@ -100,3 +111,94 @@ def plot_cospectra(cosp, freqs, ylabels=None, title="Cospectra"):
plt.yticks([])

return fig


def plot_waveforms(X, display, *, times=None, color='gray', alpha=0.5,
linewidth=1.5, color_mean='k', color_std='gray', n_bins=50,
cmap=None):
''' Display repetitions of a multichannel waveform.

Parameters
----------
X : ndarray, shape (n_reps, n_channels, n_times)
Repetitions of the multichannel waveform.
display : {'all', 'mean', 'mean+/-std', 'hist'}
Type of display:

* 'all' for all the repetitions;
* 'mean' for the mean of the repetitions;
* 'mean+/-std' for the mean +/- standard deviation of the repetitions;
* 'hist' for the 2D histogram of the repetitions.
time : None | ndarray, shape (n_times,) (default None)
Values to display on x-axis.
color : matplotlib color, optional
Color of the lines, when ``display=all``.
alpha : float, optional
Alpha value used to cumulate repetitions, when ``display=all``.
linewidth : float, optional
Line width in points, when ``display=mean``.
color_mean : matplotlib color, optional
Color of the mean line, when ``display=mean``.
color_std : matplotlib color, optional
Color of the standard deviation area, when ``display=mean+/-std``.
n_bins : int, optional
Number of vertical bins for the 2D histogram, when ``display=hist``.
cmap : Colormap or str, optional
Color map for the histogram, when ``display=hist``.

Returns
-------
fig : matplotlib figure
Figure of waveform (one subplot by channel).

Notes
-----
.. versionadded:: 0.2.8
'''
try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError("Install matplotlib to plot waveforms")

if X.ndim != 3:
raise Exception('Input X has not 3 dimensions')
n_reps, n_channels, n_times = X.shape
if times is None:
times = np.arange(n_times)
elif times.shape != (n_times,):
raise Exception(
'Parameter times has not the same number of values as X')

fig, axes = plt.subplots(nrows=n_channels, ncols=1)
if n_channels == 1:
axes = [axes]
channels = np.arange(n_channels)

if display == 'all':
for (channel, ax) in zip(channels, axes):
for i_rep in range(n_reps):
ax.plot(times, X[i_rep, channel], c=color, alpha=alpha)

elif display in ['mean', 'mean+/-std']:
mean = np.mean(X, axis=0)
for (channel, ax) in zip(channels, axes):
ax.plot(times, mean[channel], c=color_mean, lw=linewidth)
if display == 'mean+/-std':
std = np.std(X, axis=0)
for (channel, ax) in zip(channels, axes):
ax.fill_between(times, mean[channel] - std[channel],
mean[channel] + std[channel], color=color_std)

elif display == 'hist':
times_rep = np.repeat(times[np.newaxis, :], n_reps, axis=0)
for (channel, ax) in zip(channels, axes):
ax.hist2d(times_rep.ravel(), X[:, channel, :].ravel(),
bins=(n_times, n_bins), cmap=cmap)

else:
raise Exception('Parameter display unknown %s' % display)

if n_channels > 1:
for ax in axes[:-1]:
ax.set_xticklabels([]) # remove xticklabels
return fig
14 changes: 14 additions & 0 deletions tests/test_viz.py
Expand Up @@ -6,6 +6,7 @@
plot_confusion_matrix,
plot_embedding,
plot_cospectra,
plot_waveforms
)


Expand Down Expand Up @@ -36,3 +37,16 @@ def test_cospectra():
cosp = np.random.randn(n_freqs, n_channels, n_channels)
freqs = np.random.randn(n_freqs)
plot_cospectra(cosp, freqs)


@requires_matplotlib
@pytest.mark.parametrize("display", ["all", "mean", "mean+/-std", "hist"])
def test_plot_waveforms(display):
"""Test plot_waveforms"""
n_matrices, n_channels, n_times = 16, 3, 50
X = np.random.randn(n_matrices, n_channels, n_times)
plot_waveforms(X, display)
plot_waveforms(X, display, times=np.arange(n_times))

X = np.random.randn(n_matrices, 1, n_times)
plot_waveforms(X, display)
qbarthelemy marked this conversation as resolved.
Show resolved Hide resolved