In [None]:
%pip install mne-icalabel

In [1]:
# Imports
import os
from ipywidgets import *
import numpy as np
import mne
from mne.preprocessing import ICA
from mne_icalabel import label_components
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

import utils

# Specify graph rendering method
# %matplotlib widget
plt.switch_backend("TkAgg")

In [2]:
DATASET_PATH = "./dataset"
FILENAME_TEMPLATE = "TMS-EEG-H_02_S1b_{}_{}.vhdr"

spTEP_pre_raw = mne.io.read_raw_brainvision(
    os.path.join(DATASET_PATH, FILENAME_TEMPLATE.format("spTEP", "pre")), preload=True
)
sampling_rate = spTEP_pre_raw.info["sfreq"]
events, event_dict = mne.events_from_annotations(spTEP_pre_raw)
tms_indices = [event[0] for event in events if event[2] == 1]

rsEEG_pre_raw = mne.io.read_raw_brainvision(
    os.path.join(DATASET_PATH, FILENAME_TEMPLATE.format("rsEEG", "pre")), preload=True
)

Extracting parameters from ./dataset/TMS-EEG-H_02_S1b_spTEP_pre.vhdr...
Setting channel info structure...
Reading 0 ... 2696199  =      0.000 ...   539.240 secs...
Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Extracting parameters from ./dataset/TMS-EEG-H_02_S1b_rsEEG_pre.vhdr...
Setting channel info structure...
Reading 0 ... 3984899  =      0.000 ...   796.980 secs...


In [3]:
# Plotting utilities
def plot_single_response(eeg_data, channel="Pz", tmin=-0.005, tmax=0.01):
    events, event_dict = mne.events_from_annotations(eeg_data)
    event_id = event_dict["Stimulus/S  1"]
    epochs = mne.Epochs(
        eeg_data,
        events,
        event_id=event_id,
        tmin=tmin,
        tmax=tmax,
        baseline=None,
        preload=True,
        picks=channel,
    )

    epochs.plot(picks=channel, n_epochs=1, show=True, scalings={"eeg": 50e-4})


def plot_average_epoch(epochs, start=-0.05, end=0.25):
    data = epochs.get_data()
    mean_responses = np.mean(data, axis=0)
    time_points = np.linspace(-1, 1, data.shape[2])
    selected_indices = np.where((time_points >= start) & (time_points <= end))
    for i, mean_response in enumerate(mean_responses):
        selected_data = mean_response[selected_indices]
        selected_time_points = time_points[selected_indices]
        plt.plot(selected_time_points, selected_data, label=f"Channel {i+1}")
    plt.xlabel("Time points")
    plt.ylabel("Mean response")
    plt.show()


def plot_response(eeg):
    utils.plot_average_response(eeg, tmin=-0.05, tmax=0.25)  # Check full response
    utils.plot_single_response(
        eeg, channel="Pz", tmin=-0.05, tmax=0.05
    )  # Check for TMS pulse

# Cleaning - spTEP

The paper of Bertazzoli et al. (2021) compares 4 pipelines: ARTIST, TMSEEG, TESA and SOUND-SSP-SIR, all of which work decently well in varying degrees. There are common steps, but TESA will be the one that will be most closely followed. The current steps are as follows:

1. Remove EOG
2. Remove TMS pulse
3. Downsample
4. **ICA - 1**
5. Bandpass - Notch filters
6. **ICA - 2**
7. Rereference

Currently, there is no demeaning or bad channel rejection present as in TESA. Demeaning is done before the TMS-pulse interpolation, and baseline correction should be done as last step after rereferencing.


In [4]:
plot_response(spTEP_pre_raw)

Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Not setting metadata
151 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 151 events and 1501 original time points ...
1 bad epochs dropped
Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Not setting metadata
150 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 150 events and 501 original time points ...
0 bad epochs dropped
Using matplotlib as 2D backend.


In [5]:
spTEP_copy = spTEP_pre_raw.copy()

## EOG removal


In [6]:
def remove_EOG(eeg_data):
    eeg_data.drop_channels(["HEOG", "VEOG"])

## TMS pulse removal


In [7]:
def calculate_range_indices(tms_index, start, end, sampling_rate):
    """
    start and end are positive in seconds
    sampling rate in Hz
    """
    samples_before = int(start * sampling_rate)
    samples_after = int(end * sampling_rate)

    start_index = max(0, tms_index - samples_before)
    end_index = tms_index + samples_after

    return start_index, end_index

In [8]:
def interpolate_TMS_pulse(eeg_data_raw, tms_indices, start, end, sampling_rate):
    eeg_data = eeg_data_raw.get_data()
    num_electrodes = eeg_data.shape[0]
    for tms_index in tms_indices:
        start_index, end_index = calculate_range_indices(
            tms_index, start, end, sampling_rate
        )
        for i in range(num_electrodes):
            x = [start_index - 2, start_index - 1, end_index + 1, end_index + 2]
            y = [
                eeg_data[i, start_index - 2],
                eeg_data[i, start_index - 1],
                eeg_data[i, end_index + 1],
                eeg_data[i, end_index + 2],
            ]
            x_new = np.arange(start_index, end_index + 1)

            interp_func = interp1d(x, y, kind="cubic")
            eeg_data[i, start_index : end_index + 1] = interp_func(x_new)

    eeg_data_raw._data = eeg_data

## Downsampling

The original data was captured with a sampling frequency of 5000 Hz. 1000 Hz is chosen as the frequency to be downsampled to, as this means that, following Nyquists theorem, the highest frequency that will be accurately recorded is 500 Hz, which should be more than enough for further analysis, as the gamma band is often referred to as 30-100 Hz.


In [9]:
def downsample(eeg_data, sample_rate=1000):
    eeg_data.resample(sample_rate, npad="auto")

## Epoching


In [10]:
def epoching(eeg_data):
    events, event_dict = mne.events_from_annotations(eeg_data)
    event_id = event_dict["Stimulus/S  1"]
    epochs = mne.Epochs(
        eeg_data,
        events,
        event_id=event_id,
        tmin=-1,
        tmax=1,
        baseline=None,
        preload=True,
    )
    return epochs

## Demeaning/detrending

Demeaning is achieved by subtracting each value from each electrode with the average value of the corresponding electrode, essentially bringing the means from all electrodes to 0.

> TODO: check if other way of demeaning on complete electrode is possible to move value near 0 or better yet on 0


In [11]:
def demean_epochs(epochs):
    data = epochs.get_data()
    demeaned_data = data - np.mean(data, axis=2, keepdims=True)
    epochs = mne.EpochsArray(
        demeaned_data, epochs.info, events=epochs.events, event_id=epochs.event_id
    )
    return epochs

## ICA - 1

The first ICA filter is mainly to remove the primary large artifacts such as muscle and electrical charge. If demeaning were applied now, a graph as below is the result.

This is implemented by first fitting ICA to the signal, and then applying the threshold formula used by the TESA software to each component to either keep or remove each ICA component.


In [12]:
def ICA_1(epoch_data, T=3.5, b1=0.011, b2=0.030, n_components=20):
    ica = ICA(n_components=n_components, random_state=97)
    ica.fit(epoch_data)

    # Credits to Arne Callaert for the following code
    sources = ica.get_sources(epoch_data)
    averaged_sources = sources.get_data().mean(axis=0)
    times = sources.times
    sfreq = sources.info["sfreq"]
    indices = np.where((times >= (b1 / 1000)) & (times <= (b2 / 1000)))
    print("indices:", indices)
    components_to_remove = []

    for i, component in enumerate(averaged_sources):
        base = len(times) / 2
        b1_index = int(base + (b1 * sfreq))
        b2_index = int(base + (b2 * sfreq))
        x = np.mean(np.abs(component[b1_index:b2_index]))
        y = np.mean(np.abs(component))
        if x / y > T:
            print("FOUND:", x / y)
            components_to_remove.append(i)

    ica.exclude = components_to_remove

    epoch_data = ica.apply(epoch_data)

## Bandpass - Notch


In [13]:
def bandpass_notch(epoch_data, low_freq=1, high_freq=100, notch_freqs=[50]):
    # Bandpass
    epoch_data.filter(low_freq, high_freq)

    # Notch (only directly available on raw object, not on epochs)
    data = epoch_data.get_data()
    notch_filtered = mne.filter.notch_filter(data, epochs.info["sfreq"], notch_freqs)
    filtered_epochs = mne.epochs.EpochsArray(
        notch_filtered, epochs.info, events=epochs.events, tmin=epochs.tmin
    )

    return filtered_epochs

## Rereference


In [14]:
def rereference(epochs):
    mne.set_eeg_reference(epochs, ref_channels="average")

## ICA - 2

Li et al., (2022). MNE-ICALabel: Automatically annotating ICA components with ICLabel in Python. Journal of Open Source Software, 7(76), 4484, https://doi.org/10.21105/joss.04484


In [15]:
def ICA_2(epoch_data):
    ica = mne.preprocessing.ICA(n_components=20, random_state=42)
    ica.fit(epoch_data)
    ic_labels = label_components(epoch_data, ica, method="iclabel")

    print(ic_labels["labels"])

    labels = ic_labels["labels"]
    exclude_idx = [
        idx for idx, label in enumerate(labels) if label not in ["brain", "other"]
    ]
    print(f"Excluding these {len(exclude_idx)} ICA components: {exclude_idx}")

    ica.apply(epoch_data, exclude=exclude_idx)

## Baseline correction


In [17]:
def baseline(epoch_data, baseline=(-0.5, -0.005)):
    epoch_data.apply_baseline((1 - baseline[0], 1 - baseline[1]))

## Final result


In [18]:
# function to apply all steps
def preprocess(eeg_data):
    remove_EOG(eeg_data)
    interpolate_TMS_pulse(eeg_data, tms_indices, 0.005, 0.015, sampling_rate)
    downsample(eeg_data)
    epochs = epoching(eeg_data)
    epochs = demean_epochs(epochs)
    ICA_1(epochs)
    bandpass_notch(epochs)
    ICA_2(epochs)
    rereference(epochs)
    baseline(epochs)

    return epochs

In [19]:
def plot_gmfp(epochs):
    psds, freqs = mne.time_frequency.psd_multitaper(epochs, fmin=0.1, fmax=100.0, n_jobs=-1)

    # Compute mean power across all epochs
    mean_power = np.mean(psds, axis=0)

    # Compute Global Mean Field Power (GMFP)
    gmfp = np.sqrt(np.mean(mean_power, axis=0))

    # Plot GMFP
    plt.figure()
    plt.plot(freqs, gmfp)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Global Mean Field Power (GMFP)')
    plt.title('Global Mean Field Power (GMFP)')
    plt.show()


In [20]:
epochs_cleaned = preprocess(spTEP_copy)

plot_gmfp(epochs_cleaned)

Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Not setting metadata
150 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 150 events and 2001 original time points ...
0 bad epochs dropped
Not setting metadata
150 matching events found
No baseline correction applied
0 projection items activated
Fitting ICA to data using 62 channels (please be patient, this may take a while)


  data = epochs.get_data()
  ica.fit(epoch_data)


Selecting by number: 20 components
Fitting ICA took 9.6s.
indices: (array([], dtype=int64),)
FOUND: 5.256511308711403
FOUND: 3.9395458463802986
Applying ICA to Epochs instance
    Transforming to ICA space (20 components)
    Zeroing out 2 ICA components
    Projecting back using 62 PCA components


  averaged_sources = sources.get_data().mean(axis=0)


Setting up band-pass filter from 1 - 1e+02 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 100.00 Hz
- Upper transition bandwidth: 25.00 Hz (-6 dB cutoff frequency: 112.50 Hz)
- Filter length: 3301 samples (3.301 s)



  epoch_data.filter(low_freq, high_freq)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done  71 tasks      | elapsed:    0.0s
[Parallel(n_jobs=1)]: Done 161 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 287 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 449 tasks      | elapsed:    0.1s
[Parallel(n_jobs=1)]: Done 647 tasks      | elapsed:    0.2s
[Parallel(n_jobs=1)]: Done 881 tasks      | elapsed:    0.2s
[Parallel(n_jobs=1)]: Done 1151 tasks      | elapsed:    0.3s
[Parallel(n_jobs=1)]: Done 1457 tasks      | elapsed:    0.4s
[Parallel(n_jobs=1)]: Done 1799 tasks      | elapsed:    0.4s
[Parallel(n_jobs=1)]: Done 2177 tasks      | elapsed:    0.5s
[Parallel(n_jobs=1)]: Done 2591 tasks      | elapsed:    0.6s
[Parallel(n_jobs=1)]: Done 3041 tasks      | elapsed:    0.7s
[Parallel(n_jobs=1)]: Done 3527 tasks      | elapsed:    0.8s
[Parallel(n_jobs=1)]: Done 4049 tasks      | elapsed:    1.0s
[Parallel(n_jobs=1)]: Done 4607 task

NameError: name 'epochs' is not defined

## TODO

Current biggest things to find out:

- is there a way to further improve the filtering that ICA 1 is supposed to achieve? (Filtering out the residue of the TMS pulse)
- how can time ranges be plot on the scalp topography? like in the comparative paper