In [6]:
# Write the active conda environment specification to environment.yaml.
%conda env export > environment.yaml
# Counterpart command that recreates the environment from that file (kept disabled):
# %conda env create -f environment.yaml


Note: you may need to restart the kernel to use updated packages.


In [29]:
# Install the third-party packages imported below into the kernel's environment.
# NOTE(review): versions are unpinned — consider pinning them so the notebook stays reproducible.
%pip install mne scikit-learn mne-icalabel torch

Collecting torch
  Downloading torch-2.2.1-cp312-cp312-manylinux1_x86_64.whl.metadata (26 kB)
Collecting filelock (from torch)
  Downloading filelock-3.13.1-py3-none-any.whl.metadata (2.8 kB)
Collecting sympy (from torch)
  Downloading sympy-1.12-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.2.1-py3-none-any.whl.metadata (5.2 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2024.2.0-py3-none-any.whl.metadata (6.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Downloading 

# TMS pipeline
This notebook contains the full pipeline of the individual scripts. The pipeline goes as follows:
- EDA
- Preprocessing
- Model training
- Validation

In [1]:
# Imports
import os
import pandas as pd
from ipywidgets import *
import numpy as np
import mne
from mne.preprocessing import ICA
import torch
from mne_icalabel import label_components

import matplotlib.pyplot as plt
from mne.preprocessing import ICA
import scipy

import preprocessing
import utils

# Specify graph rendering method
# %matplotlib widget
# Use matplotlib's TkAgg backend for every figure created from here on.
# NOTE(review): presumably needs a Tk installation and a display — confirm before running headless.
plt.switch_backend('TkAgg')

## Data loading
Currently, only one session from one patient gets used for testing purposes. File names are as follows: "**TMS-EEG-H_02_S1b_X_Y.Z**", where:
- X = **rsEEG** (resting state EEG) or **spTEP** (single pulse TMS Evoked Potential)
- Y = **pre** or **post** (before or after the rTMS procedure)
- Z = **vhdr**, **vmrk**, **eeg** or **mat** (files for the BrainVision format, and a MATLAB file)

In [2]:
# Single-recording setup for now; this is meant to grow into a pipeline over all patients.

DATASET_PATH = './dataset'
FILENAME_TEMPLATE = "TMS-EEG-H_02_S1b_{}_{}.vhdr"

# Only the pre-rTMS single-pulse recording is loaded at the moment. The remaining
# recordings follow the same pattern and are kept here, disabled, for reference:
# rsEEG_pre_raw = mne.io.read_raw_brainvision(os.path.join(DATASET_PATH, FILENAME_TEMPLATE.format("rsEEG", "pre")), preload=True)
# rsEEG_post_raw = mne.io.read_raw_brainvision(os.path.join(DATASET_PATH, FILENAME_TEMPLATE.format("rsEEG", "post")), preload=True)
# spTEP_post_raw = mne.io.read_raw_brainvision(os.path.join(DATASET_PATH, FILENAME_TEMPLATE.format("spTEP", "post")), preload=True)
# spTEP_post_raw.drop_channels(['HEOG', 'VEOG'])

spTEP_pre_file = FILENAME_TEMPLATE.format("spTEP", "pre")
spTEP_pre_raw = mne.io.read_raw_brainvision(
    os.path.join(DATASET_PATH, spTEP_pre_file),
    preload=True,
)

# Remove the two EOG channels from the recording. Being the last expression of the
# cell, the call's return value is what the notebook renders as the cell output.
spTEP_pre_raw.drop_channels(['HEOG', 'VEOG'])


Extracting parameters from ./dataset/TMS-EEG-H_02_S1b_spTEP_pre.vhdr...
Setting channel info structure...
Reading 0 ... 2696199  =      0.000 ...   539.240 secs...


0,1
Measurement date,"August 23, 2017 15:29:56 GMT"
Experimenter,Unknown
Participant,Unknown

0,1
Digitized points,67 points
Good channels,62 EEG
Bad channels,
EOG channels,Not available
ECG channels,Not available

0,1
Sampling frequency,5000.00 Hz
Highpass,0.00 Hz
Lowpass,1000.00 Hz
Filenames,TMS-EEG-H_02_S1b_spTEP_pre.eeg
Duration,00:08:60 (HH:MM:SS)


## EDA
### EEG Visualization
Visualizing the EEG traces and electrode layout for all files gives a first impression of the data and reveals preprocessing steps that can be applied immediately.

spTEP_pre: TP9 is a bad channel.

The Fp2 and VEOG channels overlap perfectly, likely because Fp2 was also used as the reference for VEOG. This overlap results in errors, so it is better to remove the VEOG channel and keep this reference in mind.

In [4]:
# Browse ten seconds of the raw recording starting at t = 60 s, twenty channels per page.
spTEP_pre_raw.plot(
    start=60,
    duration=10,
    n_channels=20,
    scalings=dict(eeg=50e-6),
)

Using matplotlib as 2D backend.


<MNEBrowseFigure size 1920x1274 with 4 Axes>

Channels marked as bad:
none


In [5]:
# Plot the response on channel Pz using the project helper from `utils`
# (its implementation lives outside this notebook).
utils.plot_single_response(spTEP_pre_raw, channel="Pz")

Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Not setting metadata
150 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 150 events and 1251 original time points ...
0 bad epochs dropped


Dropped 0 epochs: 
The following epochs were marked as bad and are dropped:
[]
Channels marked as bad:
none


In [27]:
# Print the measurement info of the raw recording (channel names, filter settings, sampling rate, ...).
print(spTEP_pre_raw.info)

<Info | 8 non-empty values
 bads: []
 ch_names: Iz, O2, Oz, O1, PO8, PO4, POz, PO3, PO7, P8, P6, P4, P2, Pz, P1, ...
 chs: 62 EEG
 custom_ref_applied: False
 dig: 67 items (3 Cardinal, 64 EEG)
 highpass: 0.0 Hz
 lowpass: 1000.0 Hz
 meas_date: 2017-08-23 15:29:56 UTC
 nchan: 62
 projs: []
 sfreq: 5000.0 Hz
>


<Figure size 640x640 with 1 Axes>

In [6]:
# Plot the average response of the raw recording with the helper's default time window
# (helper defined in `utils`, outside this notebook).
utils.plot_average_response(spTEP_pre_raw)

Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Not setting metadata
151 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 151 events and 1251 original time points ...


1 bad epochs dropped


In [76]:
# The power topography clearly indicates where the coils were placed.
raw_psd = spTEP_pre_raw.compute_psd()
raw_psd.plot_topomap()

Effective window size : 0.410 (s)


<Figure size 2249x427 with 10 Axes>

## Preprocessing

In [3]:
# Preprocessing with interpolation: flag TP9 as bad and interpolate it on the raw recording.
spTEP_pre_raw.info['bads'] = ['TP9']
spTEP_pre_raw.interpolate_bads(reset_bads=True)

# Hand a copy to the project preprocessing routine and keep whatever it returns.
spTEP_pre = preprocessing.preprocess(spTEP_pre_raw.copy())

# Same step for the post-rTMS recording (disabled while that recording is not loaded):
# spTEP_post = preprocessing.preprocess(spTEP_post_raw.copy())

Setting channel interpolation method to {'eeg': 'spline'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 95.0 mm
Computing interpolation matrix from 61 sensor positions
Interpolating 1 sensors
Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Not setting metadata
151 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 151 events and 251 original time points ...
139 bad epochs dropped
Fitting ICA to data using 62 channels (please be patient, this may take a while)
Selecting by number: 20 components


  ica.fit(epochs)


Fitting ICA took 0.2s.
Effective window size : 2.048 (s)
Applying ICA to Raw instance
    Transforming to ICA space (20 components)
    Zeroing out 14 ICA components
    Projecting back using 62 PCA components
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 90 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 90.00 Hz
- Upper transition bandwidth: 22.50 Hz (-6 dB cutoff frequency: 101.25 Hz)
- Filter length: 3301 samples (3.301 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter from 49 - 51 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 49.38
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 49.12 Hz)
- Upper passband edge: 50.62 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 50.88 Hz)
- Filter length: 6601 samples (6.601 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s


Not setting metadata
151 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 151 events and 251 original time points ...
139 bad epochs dropped
Fitting ICA to data using 62 channels (please be patient, this may take a while)
Selecting by number: 20 components
Fitting ICA took 1.4s.
Effective window size : 2.048 (s)




Applying ICA to Raw instance
    Transforming to ICA space (20 components)
    Zeroing out 17 ICA components
    Projecting back using 62 PCA components
EEG channel type selected for re-referencing
Applying a custom ('EEG',) reference.


In [44]:
# Average response of the raw recording over an explicit time window.
# NOTE(review): tmin is -0.005 here while the other cells use -0.05 — confirm which is intended.
utils.plot_average_response(spTEP_pre_raw, tmin=-0.005, tmax=0.2)

Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Not setting metadata
151 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 151 events and 1026 original time points ...
1 bad epochs dropped


In [4]:
# Full pipeline
# Average response of the preprocessed recording; tmin/tmax are presumably seconds
# relative to the stimulus event — confirm against utils.plot_average_response.
utils.plot_average_response(spTEP_pre, tmin=-0.05, tmax=0.2)
# utils.plot_average_response(spTEP_post, tmin=-0.05, tmax=0.2)

Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Not setting metadata
151 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 151 events and 251 original time points ...
1 bad epochs dropped


In [7]:
# Persist the preprocessed recording so later stages can reload it from disk.
FILENAME_TEMPLATE = "TMS-EEG-H_{:02d}_S{}_{}_{}"
filename = FILENAME_TEMPLATE.format(2, "1b", "spTEP", "pre")

# Create the output folder on demand so the cell also works on a fresh checkout,
# where ./cleaned does not exist yet.
cleaned_dir = os.path.join(".", "cleaned")
os.makedirs(cleaned_dir, exist_ok=True)

# NOTE(review): the recorded output shows a warning raised by this call — presumably about
# the file name not following MNE's raw-file naming convention. Confirm before renaming,
# since other code may load this exact path.
spTEP_pre.save(os.path.join(cleaned_dir, filename + ".fif"), overwrite=True)

Writing /home/tomasgalle/UGent/thesis/tms-research/cleaned/TMS-EEG-H_02_S1b_spTEP_pre.fif


  spTEP_pre.save(os.path.join(".", "cleaned", filename + ".fif"), overwrite=True)


Closing /home/tomasgalle/UGent/thesis/tms-research/cleaned/TMS-EEG-H_02_S1b_spTEP_pre.fif
[done]


In [None]:
# Apply preprocessing to all available data and save for later use

DATASET_PATH_FULL = '../neuroaa/datasets/raw/uz_gent'
FILENAME_TEMPLATE = "TMS-EEG-H_{:02d}_S{}_{}_{}"
CLEANED_DATA_PATH = os.path.join(".", "cleaned_data")

# Make sure the output folder exists before the first save.
os.makedirs(CLEANED_DATA_PATH, exist_ok=True)

# Patient/session/type/trial combinations without a header file are collected here
# instead of aborting the whole batch on the first missing recording.
missing_recordings = []

# NOTE(review): the single-recording cells use the session label "1b", whereas this loop
# formats plain integers 0-3 (and patient ids 0-18) — confirm the naming scheme on disk.
for patient in range(19):
    for session in range(4):
        # Named `recording_type` rather than `type` so the builtin is not shadowed.
        for recording_type in ["spTEP", "rsEEG"]:
            for trial in ["pre", "post"]:
                filename = FILENAME_TEMPLATE.format(patient, session, recording_type, trial)

                # Read from the full dataset folder defined above; the single-patient
                # DATASET_PATH was used here before, leaving DATASET_PATH_FULL unused.
                header_path = os.path.join(DATASET_PATH_FULL, filename + ".vhdr")
                if not os.path.isfile(header_path):
                    missing_recordings.append(filename)
                    continue

                eeg_raw = mne.io.read_raw_brainvision(header_path, preload=True)
                # Preprocess a copy, as in the original flow, and keep the returned object.
                eeg_clean = preprocessing.preprocess(eeg_raw.copy())
                eeg_clean.save(os.path.join(CLEANED_DATA_PATH, filename + ".fif"), overwrite=True)

print(f"Skipped {len(missing_recordings)} recordings without a .vhdr file")

## Feature extraction
Features are per pulse (spTEP) or epoch (rsEEG)

### Time domain

In [42]:
# Peak amplitude, peak latency, and area under curve (one feature row per epoch)
events, event_dict = mne.events_from_annotations(spTEP_pre)
epochs = mne.Epochs(spTEP_pre, events, event_id=event_dict, tmin=-0.05, tmax=0.2, baseline=None, preload=True)

features = []

# Each `epoch` is a 2-D array; per the MNE documentation for Epochs.get_data its axes
# are (channels, time samples).
for epoch in epochs.get_data():
    # np.argmax returns an index into the *flattened* array, so it has to be unravelled
    # into (channel, sample) before it can be read as a position in time. Dividing the
    # flat index by the sampling rate would mix the channel offset into the latency.
    peak_channel, peak_sample = np.unravel_index(np.argmax(epoch), epoch.shape)
    peak_amplitude = epoch[peak_channel, peak_sample]

    # epochs.times holds the time of every sample relative to the event, so the latency
    # is measured from the stimulus and accounts for the negative tmin.
    peak_latency = epochs.times[peak_sample]

    # Trapezoidal integral over the last axis with unit sample spacing, averaged over rows.
    area_under_curve = np.mean(np.trapz(epoch))

    features.append([peak_amplitude, peak_latency, area_under_curve])

len(features)

Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Not setting metadata
151 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 151 events and 251 original time points ...
1 bad epochs dropped


  for epoch in epochs.get_data():


150

In [4]:
# Plot difference of features between pre and post
# NOTE(review): this cell needs `spTEP_post`; the lines that load and preprocess it are
# commented out earlier in this notebook — confirm it exists before running.

# --- Epoch both recordings around the stimulus events -------------------------------
events, event_dict = mne.events_from_annotations(spTEP_pre)
epochs = mne.Epochs(
    spTEP_pre, events, event_id=event_dict,
    tmin=-0.05, tmax=0.2, baseline=None, preload=True,
)
data = epochs.get_data()
average_response = epochs.average().data

events_post, event_dict_post = mne.events_from_annotations(spTEP_post)
epochs_post = mne.Epochs(
    spTEP_post, events_post, event_id=event_dict_post,
    tmin=-0.05, tmax=0.2, baseline=None, preload=True,
)
data_post = epochs_post.get_data()
average_response_post = epochs_post.average().data

# --- Summary statistics, reduced over axes 0 and 2 of the epoch data ----------------
reduce_axes = (0, 2)

average_max = average_response.max(axis=1)
mean = data.mean(axis=reduce_axes)
std_dev = data.std(axis=reduce_axes)
skewness = scipy.stats.skew(data, axis=reduce_axes)
kurtosis = scipy.stats.kurtosis(data, axis=reduce_axes)

average_max_post = average_response_post.max(axis=1)
mean_post = data_post.mean(axis=reduce_axes)
std_dev_post = data_post.std(axis=reduce_axes)
skewness_post = scipy.stats.skew(data_post, axis=reduce_axes)
kurtosis_post = scipy.stats.kurtosis(data_post, axis=reduce_axes)

# --- One scatter panel per statistic: pre on x, post on y ---------------------------
feature_pairs = [
    ('Mean', mean, mean_post),
    ('Standard Deviation', std_dev, std_dev_post),
    ('Skewness', skewness, skewness_post),
    ('Kurtosis', kurtosis, kurtosis_post),
    ('Max Average Response', average_max, average_max_post),
]

fig, axs = plt.subplots(3, 2)
# axs.flat walks the grid row by row; zip stops after the five statistics.
for ax, (label, pre_values, post_values) in zip(axs.flat, feature_pairs):
    ax.scatter(pre_values, post_values)
    ax.set_xlabel(f'{label} (pre)')
    ax.set_ylabel(f'{label} (post)')

# The sixth panel of the 3x2 grid has nothing to show.
fig.delaxes(axs[2, 1])

plt.tight_layout()
plt.show()

Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Not setting metadata
151 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 151 events and 251 original time points ...
1 bad epochs dropped
Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Not setting metadata
150 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 150 events and 251 original time points ...
1 bad epochs dropped


  data = epochs.get_data()
  data_post = epochs_post.get_data()


### Frequency domain

In [5]:
# Frequency bands as (lower edge, upper edge); the cells below compare these edges
# against PSD/TFR frequency values.
bands = dict(
    delta=(0.5, 4),
    theta=(4, 8),
    alpha=(8, 12),
    beta=(12, 30),
    gamma=(30, 50),
)

In [None]:
# Average power per frequency band
psd = spTEP_pre.compute_psd(fmin=1.0, fmax=100.0)
psd.plot(average=True)

freqs = psd.freqs
psd_values = psd.get_data()

# For every band, keep the frequency bins inside [lower, upper] (both edges inclusive)
# and average everything that remains into a single number.
avg_power = {
    band: psd_values[:, (freqs >= low) & (freqs <= high)].mean()
    for band, (low, high) in bands.items()
}

print(avg_power)

In [40]:
# Difference in average power
# NOTE(review): relies on `epochs` and `epochs_post` from the pre/post feature cell; the
# recorded output of this cell is a NameError for `epochs_post`, so run that cell first.
psd_pre = epochs.compute_psd(fmin=1.0, fmax=100.0)
psd_post = epochs_post.compute_psd(fmin=1.0, fmax=100.0)

# Average each spectrum over its epochs before comparing.
psd_pre_avg = psd_pre.average()
psd_post_avg = psd_post.average()

psd_diff = psd_post_avg.get_data() - psd_pre_avg.get_data()

# Only one row of the difference is drawn. Presumably rows are channels once the epochs
# are averaged — confirm. The channel is named in the title so the figure states what
# it shows instead of silently plotting row 0.
channel_index = 0
channel_name = psd_pre_avg.ch_names[channel_index]

fig, ax = plt.subplots()

# Plot the difference PSD
ax.plot(psd_pre.freqs, psd_diff[channel_index], label='Post - Pre')

ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('PSD difference')
ax.set_title(f'Change in power spectral density ({channel_name})')

# Add a legend
ax.legend()

# Show the plot
plt.show()

    Using multitaper spectrum estimation with 7 DPSS windows


NameError: name 'epochs_post' is not defined

### Time-frequency domain

In [6]:
# Compute the wavelet transform of the data
events, event_dict = mne.events_from_annotations(spTEP_pre)

# Re-epoch with a window from the event (tmin=0) up to tmax=4; this rebinds `epochs`.
epochs = mne.Epochs(
    spTEP_pre,
    events,
    event_id=event_dict,
    tmin=0,
    tmax=4,
    baseline=None,
    preload=True,
)

# Integer frequencies 1 through 99.
frequencies = np.arange(1, 100)

# The plotting cell reads the power from the first element of this result (`wavelets[0]`).
wavelets = mne.time_frequency.tfr_morlet(
    epochs,
    freqs=frequencies,
    n_cycles=2,
)

Used Annotations descriptions: ['New Segment/', 'Stimulus/S  1']
Not setting metadata
151 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 151 events and 4001 original time points ...
0 bad epochs dropped


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:  1.6min


In [7]:
def _centered_extent(x_values, y_values):
    """Return an imshow `extent` that puts pixel centres on the given coordinates.

    Both coordinate vectors are assumed to be evenly spaced; a vector with a single
    entry gets a half-width of 0.5 so the pixel still has a visible size.
    """
    def _half_step(values):
        return (values[1] - values[0]) / 2 if len(values) > 1 else 0.5

    half_x, half_y = _half_step(x_values), _half_step(y_values)
    return [x_values[0] - half_x, x_values[-1] + half_x,
            y_values[0] - half_y, y_values[-1] + half_y]


# Number of leading time samples shown in the per-band panels.
TIME_WINDOW_SAMPLES = 200

tfr_power = wavelets[0]
power = tfr_power.data      # indexed below as (electrode, frequency, time)
freqs = tfr_power.freqs
times = tfr_power.times

# INDIVIDUAL PLOTS

# Average over electrodes (axis 0), leaving a frequency x time map.
avg_power = np.mean(power, axis=0)

fig, ax = plt.subplots()

# Create the heatmap; `extent` makes the ticks real seconds / hertz instead of array indices.
cax = ax.imshow(avg_power, aspect='auto', cmap='hot', origin='lower',
                extent=_centered_extent(times, freqs))

# Add a colorbar
fig.colorbar(cax)

# Set the labels for the x and y axes and the title
ax.set_xlabel('Time (s)')
ax.set_ylabel('Frequency (Hz)')
ax.set_title('Average Power')

# Show the plot
plt.show()

# GROUPED PLOTS
n_rows = len(bands)
n_cols = 1

# Create a figure with multiple subplots
fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 20))

# Time at which the per-band panels are cut off (the same sample index as before,
# now expressed in seconds because the x axis is in seconds).
time_limit = times[min(TIME_WINDOW_SAMPLES, len(times) - 1)]

# np.atleast_1d keeps the loop working when `bands` holds a single entry.
for ax, (band, (fmin, fmax)) in zip(np.atleast_1d(axs), bands.items()):
    # Find the indices that correspond to this frequency band
    band_indices = np.where((freqs >= fmin) & (freqs <= fmax))[0]
    if band_indices.size == 0:
        # Nothing was computed inside this band; leave the panel empty but labelled.
        ax.set_title(f'Average Power ({band} band)')
        continue

    # Slice the power data to include only these frequencies
    band_power = power[:, band_indices, :]

    # Compute the average power across electrodes
    avg_power = np.mean(band_power, axis=0)

    # Create the heatmap with the band's own frequencies on the y axis
    cax = ax.imshow(avg_power, aspect='auto', cmap='hot', origin='lower',
                    extent=_centered_extent(times, freqs[band_indices]))

    ax.set_xlim(right=time_limit)

    # Add a colorbar
    fig.colorbar(cax, ax=ax)

    # Set the labels for the x and y axes and the title
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    ax.set_title(f'Average Power ({band} band)')

# Show the plot
plt.tight_layout()
plt.show()

## Clustering