# EEG Decoding/MVPA Tutorial

The goal of this notebook is to demonstrate EEG decoding/MVPA techniques using some pilot data collected in our EEG Lab. Some of the content is based on [this notebook](https://github.com/mne-tools/mne-workshops/blob/master/2018_06_Amsterdam/mne_notebook_3_mvpa.ipynb).

In [1]:
%matplotlib inline

import numpy as np                # data array manipulation
import pandas as pd               # dataframes
import matplotlib.pyplot as plt   # data visualization

import mne                        # eeg toolbox

## Experimental Design

First, let's take a look at the stimuli and their properties.

In [2]:
df_stim = pd.read_csv('stimuli.csv').set_index('id')   # one row per stimulus, indexed by stimulus ID
df_stim

| id | size | animacy | label |
|---:|---|---|---|
| 1 | big | animate | big/animate/object001 |
| 2 | big | animate | big/animate/object002 |
| 3 | big | animate | big/animate/object003 |
| 4 | big | animate | big/animate/object004 |
| 5 | big | animate | big/animate/object005 |
| 6 | big | animate | big/animate/object006 |
| 7 | big | animate | big/animate/object007 |
| 8 | big | animate | big/animate/object008 |
| 9 | big | animate | big/animate/object009 |
| 10 | big | animate | big/animate/object010 |


In [None]:
# count the number of stimuli
num_stim = df_stim.shape[0]
num_stim

In [None]:
# IDs of the animate stimuli
animate_stim_ids = df_stim[df_stim['animacy'] == 'animate'].index
animate_stim_ids

## From Raw to Epochs

In this section, we will epoch the raw EEG data so that the epoched data can be used for decoding.


### load raw EEG data
In this step, we will load the raw EEG data.

In [None]:
# raw eeg file
fname = './original-run1-raw.fif.gz'

# load eeg data
raw = mne.io.read_raw_fif(fname, preload=True)
raw

### extract event information
In this step, we'll make `events` and `event_id`, which are required inputs for epoching, by extracting annotations from the raw EEG data.

In [None]:
# extract annotations from raw
triggers, _ = mne.events_from_annotations(raw)

# the first 10 triggers
triggers[:10] 

Because not all triggers are related to the experimental design, we keep only the events whose codes are stimulus IDs (1 to `num_stim`).

In [None]:
events = triggers[triggers[:, -1] <= num_stim]

# the first 10 events
events[:10]
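MNE stores events as an integer array of shape `(n_events, 3)`: the sample index, the value of the trigger channel just before the event (usually 0), and the event code. The filter above is plain NumPy boolean indexing on the last column. A toy sketch, assuming stimulus IDs run from 1 to `num_stim`; the sample indices and the extra codes (99, 12) are made up to stand in for non-stimulus triggers:

```python
import numpy as np

num_stim = 10  # hypothetical: stimulus IDs are 1..num_stim

# toy event array, one row per trigger: [sample index, previous value, event code]
triggers = np.array([
    [100, 0, 3],
    [350, 0, 99],    # hypothetical non-stimulus trigger (e.g., a button press)
    [600, 0, 7],
    [850, 0, 12],    # another code outside the stimulus range
    [1100, 0, 1],
])

# boolean mask on the last column keeps only the stimulus events
events = triggers[triggers[:, -1] <= num_stim]
print(events)
```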

In [None]:
event_id = dict(zip(df_stim['label'], df_stim.index))
event_id
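Because each label is a `/`-separated set of tags, MNE's `Epochs` supports tag-based selection: `epochs['animate']` should return the epochs of every label that contains the tag `animate` (matching is on whole tags, so `inanimate` is not selected), and tag order does not matter. The matching logic can be sketched in plain Python; `select_by_tags` is a hypothetical helper, not MNE code, and the toy labels are made up in the same format as the stimulus labels:

```python
def select_by_tags(event_id, query):
    """Hypothetical helper: return the event names whose '/'-separated tags
    include every tag in `query` (tag order does not matter)."""
    wanted = set(query.split('/'))
    return [name for name in event_id if wanted <= set(name.split('/'))]

# toy event_id in the 'size/animacy/object' format
event_id = {
    'big/animate/object001': 1,
    'big/inanimate/object002': 2,
    'small/animate/object003': 3,
    'small/inanimate/object004': 4,
}

print(select_by_tags(event_id, 'animate'))        # both animate objects
print(select_by_tags(event_id, 'small/animate'))  # only object003
```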

### make epochs

In [None]:
# epoch window in seconds, relative to stimulus onset
tmin, tmax = -0.1, 0.9

# Channels to include (only eeg channels)
picks = mne.pick_types(raw.info, meg=False, eeg=True, eog=False)

epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks, decim=4)
epochs
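Conceptually, `mne.Epochs` cuts a fixed window around each event, applies baseline correction (by default, subtracting the mean of the interval from the start of the epoch up to time 0), and, with `decim=4`, keeps every fourth sample without low-pass filtering. A rough NumPy sketch of those three steps on synthetic data; `epoch_data` is a hypothetical helper, and the 1000 Hz sampling rate and event onsets are assumptions, not the lab's recording settings. MNE additionally handles rejection, annotations and metadata, which are omitted here.

```python
import numpy as np

def epoch_data(data, event_samples, sfreq, tmin, tmax, decim=1):
    """Hypothetical helper: cut windows around events, baseline-correct, decimate.

    data: (n_channels, n_samples) continuous signal
    event_samples: sample index of each event onset
    Returns an (n_epochs, n_channels, n_times) array and its time axis.
    """
    start = int(round(tmin * sfreq))   # negative: samples before onset
    stop = int(round(tmax * sfreq))    # samples after onset (inclusive)
    times = np.arange(start, stop + 1) / sfreq
    epochs = np.stack([data[:, s + start:s + stop + 1] for s in event_samples])
    # baseline correction: subtract the mean of the pre-stimulus interval (t <= 0)
    baseline = epochs[:, :, times <= 0].mean(axis=-1, keepdims=True)
    epochs = epochs - baseline
    # decimation: keep every `decim`-th sample (no anti-aliasing filter here)
    return epochs[:, :, ::decim], times[::decim]

rng = np.random.default_rng(0)
sfreq = 1000.0                              # hypothetical sampling rate (Hz)
data = rng.standard_normal((32, 10_000))    # 32 channels, 10 s of noise
event_samples = [1000, 3000, 5000]          # made-up event onsets
X, times = epoch_data(data, event_samples, sfreq, tmin=-0.1, tmax=0.9, decim=4)
print(X.shape, times[0], times[-1])
```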

***

## Animacy Decoding

Let's start by predicting the stimulus category (animate vs. inanimate) from the EEG activity.

In [None]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis # linear classifier
from sklearn.pipeline import make_pipeline                           # preprocessing pipeline constructor
from sklearn.preprocessing import StandardScaler                     # standardize features

from mne.decoding import SlidingEstimator, cross_val_multiscore      # one model per time point; cross-validated scores

### prepare data

We will load `original-epo.fif.gz`, which contains the epoched data from all runs. I've imported the recorded BrainVision file into MNE with `mne.io.read_raw_brainvision()`.

In [None]:
epochs = mne.read_epochs('original-epo.fif.gz')
times = epochs.times

X = epochs.get_data()
y = np.isin(epochs.events[:, -1], animate_stim_ids)

(X.shape, y.shape)

Dimensions of `X`: (`n_epochs`, `n_channels`, `n_times`). `y` holds one boolean label per epoch (`True` = animate).

### category decoding
We want to find out at which time points the EEG activity carries information about the stimulus category.

In [None]:
clf = make_pipeline(StandardScaler(),
                    LinearDiscriminantAnalysis())

# fit and score one classifier per time point, scored with ROC AUC (chance = 0.5)
sl = SlidingEstimator(clf, scoring='roc_auc')

In [None]:
# 5-fold cross-validation over epochs: epochs of the same exemplar
# can end up in both the training and the test folds
scores = cross_val_multiscore(sl, X, y, cv=5)

In [None]:
scores.shape

In [None]:
# set plot parameters
plot_params = dict(
    xlabel='Time (s)',
    ylabel='AUC',
    title='Animacy Decoding',
    xlim=(tmin, tmax),
)

In [None]:
fig, ax = plt.subplots()
ax.plot(epochs.times, scores.T)
ax.hlines(0.5, tmin, tmax, linestyle=':')
ax.set(**plot_params)
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(epochs.times, scores.mean(0))
ax.hlines(0.5, tmin, tmax, linestyle=':')
ax.set(**plot_params)
plt.show()
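The dotted line at 0.5 marks chance. AUC (area under the ROC curve) equals the probability that the classifier scores a randomly chosen animate epoch higher than a randomly chosen inanimate one, so 1.0 means perfect separation and 0.5 means no category information, regardless of class balance. A small NumPy sketch of that definition (the normalized Mann-Whitney U statistic; not scikit-learn's implementation, which computes the same quantity from the ROC curve):

```python
import numpy as np

def auc(scores_pos, scores_neg):
    """AUC as P(random positive outranks random negative), ties counting as half."""
    diff = np.asarray(scores_pos)[:, None] - np.asarray(scores_neg)[None, :]
    return (diff > 0).mean() + 0.5 * (diff == 0).mean()

print(auc([0.9, 0.8, 0.7], [0.3, 0.2, 0.1]))   # perfect separation: 1.0
print(auc([0.5, 0.5], [0.5, 0.5]))             # uninformative scores: 0.5

rng = np.random.default_rng(0)
print(auc(rng.standard_normal(2000), rng.standard_normal(2000)))  # close to 0.5
```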

### category decoding with generalization to novel exemplars

For this experiment, we want to make sure that category decoding generalizes to novel exemplars. Therefore, we will use independent-exemplar [cross-validation](https://scikit-learn.org/stable/auto_examples/model_selection/plot_cv_indices.html#sphx-glr-auto-examples-model-selection-plot-cv-indices-py): all epochs of a given stimulus are assigned to the same fold, so the classifier is always tested on exemplars it was not trained on.

In [None]:
from sklearn.model_selection import GroupKFold


# X and y are the same as above; groups holds the stimulus ID of each epoch,
# so all epochs of an exemplar end up in the same fold
groups = epochs.events[:, -1]

gkf = GroupKFold(n_splits=5)
sl = SlidingEstimator(clf, scoring='roc_auc')

# independent exemplar cross validation
scores_with_generalization = cross_val_multiscore(sl, X, y, groups=groups, cv=gkf)
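`GroupKFold` assigns whole groups to folds, so every epoch of a given exemplar lands in the same test fold and the classifier is always evaluated on objects it never saw during training. A toy check of that property; the 10 exemplars, 4 repetitions each, and the animate/inanimate split are made up. Printing the classes present in each test fold is also worth doing on the real data, because AUC is undefined for a test fold that contains only one class.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# toy design (made up): 10 exemplars, each shown 4 times
groups = np.repeat(np.arange(1, 11), 4)
y = (groups <= 5).astype(int)       # hypothetical: exemplars 1-5 are animate
X = np.zeros((groups.size, 1))      # features do not matter for the split itself

test_folds = []
for fold, (train, test) in enumerate(GroupKFold(n_splits=5).split(X, y, groups=groups)):
    train_ids, test_ids = set(groups[train].tolist()), set(groups[test].tolist())
    assert not train_ids & test_ids   # no exemplar is in both training and test
    test_folds.append(sorted(test_ids))
    print(f'fold {fold}: test exemplars {sorted(test_ids)}, '
          f'classes in test fold {np.unique(y[test]).tolist()}')
```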

In [None]:
fig, ax = plt.subplots()
ax.plot(epochs.times, scores.mean(0), label='without generalization')
ax.plot(epochs.times, scores_with_generalization.mean(0), label='with generalization')
ax.hlines(0.5, tmin, tmax, linestyle=':')       # add chance level 
ax.set(**plot_params)
ax.legend()
plt.show()

## Temporal Generalization Decoding
We can also test whether a classifier trained at one time point generalizes to other time points (temporal generalization).

In [None]:
from mne.decoding import GeneralizingEstimator

gkf = GroupKFold(n_splits=2)
gen = GeneralizingEstimator(clf, scoring='roc_auc')
scores_gen = cross_val_multiscore(gen, X, y, groups=groups, cv=gkf)
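`GeneralizingEstimator` trains one classifier per time point, like `SlidingEstimator`, but then tests each of them at every time point, giving a training-time by testing-time matrix. A sustained neural code shows up as a square block of above-chance scores, while a code that changes over time stays close to the diagonal. A minimal sketch of the matrix computation for a single train/test split with scikit-learn, on synthetic data with a made-up sustained effect from time point 5 onward:

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import roc_auc_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n_epochs, n_channels, n_times = 120, 8, 20
X = rng.standard_normal((n_epochs, n_channels, n_times))
y = np.tile([0, 1], n_epochs // 2)
# made-up sustained effect: the same 3 channels differ between classes from t=5 on
X[y == 1, :3, 5:] += 1.0

train, test = np.arange(n_epochs // 2), np.arange(n_epochs // 2, n_epochs)
clf = make_pipeline(StandardScaler(), LinearDiscriminantAnalysis())

# scores[i, j]: classifier trained at time i, tested at time j
scores = np.zeros((n_times, n_times))
for i in range(n_times):
    clf.fit(X[train, :, i], y[train])
    for j in range(n_times):
        decision = clf.decision_function(X[test, :, j])
        scores[i, j] = roc_auc_score(y[test], decision)

print(scores.round(2))
```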

In [None]:
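data = scores_gen.mean(0)   # (n_train_times, n_test_times), averaged over folds
vmax = np.abs(data).max()   # color limits symmetric around chance (0.5)

fig, ax = plt.subplots()
im = ax.imshow(
    data,
    origin="lower", cmap="RdBu_r",
    extent=(tmin, tmax, tmin, tmax),
    vmax=vmax, vmin=1 - vmax)
ax.set(xlabel='Testing time (s)', ylabel='Training time (s)',
       title='Temporal generalization')

plt.colorbar(im, ax=ax)
plt.show()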

## Representational Similarity Analysis

### make RDM

In [None]:
epochs['big/animate/object001']

In [None]:
# average the epochs of each stimulus: patterns has shape (n_stim, n_channels, n_times)
patterns = np.array([epochs[event].get_data().mean(0) for event in event_id])
patterns.shape

In [None]:
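from scipy.spatial.distance import pdist, squareform

# one RDM per time point: correlation distance (1 - Pearson r) between the
# stimulus patterns across channels; rdms has shape (n_times, n_stim, n_stim)
rdms = [squareform(pdist(patterns[:, :, i], metric='correlation'))
        for i in range(len(epochs.times))]
rdms = np.array(rdms)
rdms.shape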
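The correlation distance used by `pdist(..., metric='correlation')` is 1 minus the Pearson correlation between two stimulus patterns across channels, so it ranges from 0 (identical spatial patterns) through 1 (uncorrelated) to 2 (perfectly anti-correlated). A small sketch on random patterns (sizes made up) checking this against `np.corrcoef`, plus a vectorized variant that builds all time points at once with `np.einsum`:

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

rng = np.random.default_rng(0)
n_stim, n_channels, n_times = 6, 32, 5
patterns = rng.standard_normal((n_stim, n_channels, n_times))

# scipy: correlation distance between stimulus patterns at time index t
t = 2
rdm_scipy = squareform(pdist(patterns[:, :, t], metric='correlation'))

# same thing by hand: 1 - Pearson correlation across channels
rdm_manual = 1 - np.corrcoef(patterns[:, :, t])

# all time points at once: z-score each pattern across channels,
# then the mean of products over channels is Pearson r
z = (patterns - patterns.mean(axis=1, keepdims=True)) / patterns.std(axis=1, keepdims=True)
rdms = 1 - np.einsum('ict,jct->tij', z, z) / n_channels   # (n_times, n_stim, n_stim)

print(np.allclose(rdm_scipy, rdm_manual), np.allclose(rdms[t], rdm_scipy))
```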

### visualize RDMs at different time points

In [None]:
# time-point indices (not seconds) at which to inspect the RDMs
tid1, tid2, tid3 = 10, 28, 70
time1, time2, time3 = epochs.times[tid1], epochs.times[tid2], epochs.times[tid3]

(time1, time2, time3)

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

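# same color limits for all panels so that the RDMs are comparable
im1 = axes[0].imshow(rdms[tid1], vmax=1.5, vmin=0)
im2 = axes[1].imshow(rdms[tid2], vmax=1.5, vmin=0)
im3 = axes[2].imshow(rdms[tid3], vmax=1.5, vmin=0)

axes[0].set_title(f'RDM at {time1:.3f} s')
axes[1].set_title(f'RDM at {time2:.3f} s')
axes[2].set_title(f'RDM at {time3:.3f} s')
fig.colorbar(im3, ax=axes.ravel().tolist(), label='correlation distance')
plt.show()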


### visualize MDS plots at different time points

In [None]:
from sklearn.manifold import MDS

def plot_mds(rdm, ax=None, colors=None, time=None):
    """Embed a precomputed RDM in 2-D with metric MDS and scatter-plot it."""
    model = MDS(n_components=2, dissimilarity='precomputed', random_state=0)
    out = model.fit_transform(rdm)   # (n_stim, 2) coordinates
    if ax is None:
        _, ax = plt.subplots()
    ax.scatter(out[:, 0], out[:, 1], color=colors)
    if time is not None:
        ax.set_title(f'MDS at {time:.3f} s')
    return ax
    

In [None]:
# set colors for animate and inanimate objects
color_mapping = dict(animate='purple', inanimate='pink')
colors = df_stim['animacy'].map(color_mapping).values

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
plot_mds(rdms[tid1], colors=colors, time=time1, ax=axes[0])
plot_mds(rdms[tid2], colors=colors, time=time2, ax=axes[1])
plot_mds(rdms[tid3], colors=colors, time=time3, ax=axes[2])
plt.show()
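`sklearn.manifold.MDS` fits the 2-D layout iteratively (SMACOF), so the picture can depend on `random_state`. For comparison, classical (Torgerson) MDS, a different, closed-form technique, embeds a distance matrix through an eigendecomposition and gives a deterministic layout up to rotation and reflection. A NumPy sketch; `classical_mds` is a hypothetical helper, and because correlation distances are not Euclidean, negative eigenvalues are clipped to zero. The toy check uses made-up 2-D points, whose pairwise distances should be recovered exactly:

```python
import numpy as np

def classical_mds(rdm, n_components=2):
    """Hypothetical helper: classical (Torgerson) MDS of a distance matrix."""
    d2 = np.asarray(rdm, dtype=float) ** 2
    n = d2.shape[0]
    J = np.eye(n) - np.ones((n, n)) / n            # centering matrix
    B = -0.5 * J @ d2 @ J                          # double-centered Gram matrix
    eigvals, eigvecs = np.linalg.eigh(B)           # eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1][:n_components]
    scale = np.sqrt(np.clip(eigvals[order], 0, None))
    return eigvecs[:, order] * scale               # (n, n_components) coordinates

# toy check: distances between known 2-D points are recovered
rng = np.random.default_rng(0)
points = rng.standard_normal((8, 2))
dist = np.linalg.norm(points[:, None] - points[None, :], axis=-1)
coords = classical_mds(dist)
recovered = np.linalg.norm(coords[:, None] - coords[None, :], axis=-1)
print(np.allclose(dist, recovered))
```

On the real data, `classical_mds(rdms[tid1])` could replace the `MDS` call inside `plot_mds` if a deterministic layout is preferred.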
