In [None]:
# One-time environment setup: create the conda environment described in
# environment.yaml. The commented-out line is the command that exports the
# current environment to that file.
# %conda env export > environment.yaml
%conda env create -f environment.yaml

In [None]:
# Extra packages installed into the kernel's environment with %pip.
# The commented-out line lists the analysis dependencies; only ipywidgets is
# installed here.
# NOTE(review): no versions are pinned — consider pinning for reproducibility.
# %pip install mne pandas numpy scikit-learn mne-icalabel
%pip install ipywidgets

# TMS pipeline

This notebook contains the full pipeline of the individual scripts. The pipeline goes as follows:

- EDA
- Preprocessing
- Model training
- Validation


In [None]:
# Imports
import os

import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import scipy
import scipy.stats  # a bare `import scipy` is not guaranteed to load the stats submodule
from ipywidgets import *
from mne.preprocessing import ICA
# import torch
from mne_icalabel import label_components

import preprocessing
import utils

# Specify graph rendering method
# Figures are drawn with Matplotlib's TkAgg backend; the ipympl widget backend
# is kept below as a commented-out alternative.
# NOTE(review): TkAgg presumably needs a desktop display — confirm it works in
# the environment where this notebook is run.
# %matplotlib widget
plt.switch_backend('TkAgg')

## Data loading

Currently, only one session from one patient gets used for testing purposes. File names are as follows: "**TMS-EEG-H_02_S1b_X_Y.Z**", where:

- X = **rsEEG** (resting state EEG) or **spTEP** (single pulse TMS Evoked Potential)
- Y = **pre** or **post** (before or after the rTMS procedure)
- Z = **vhdr**, **vmrk**, **eeg** or **mat** (files for the BrainVision format, and a MATLAB file)


In [None]:
# Single-recording setup (subject 02, session 1b) used while developing the
# pipeline; it will be generalized to every patient later.

DATASET_PATH = './dataset'
FILENAME_TEMPLATE = "TMS-EEG-H_02_S1b_{}_{}.vhdr"

# Recordings that are available but not loaded at the moment:
# rsEEG_pre_raw = mne.io.read_raw_brainvision(os.path.join(DATASET_PATH, FILENAME_TEMPLATE.format("rsEEG", "pre")), preload=True)
# rsEEG_post_raw = mne.io.read_raw_brainvision(os.path.join(DATASET_PATH, FILENAME_TEMPLATE.format("rsEEG", "post")), preload=True)
# spTEP_post_raw = mne.io.read_raw_brainvision(os.path.join(DATASET_PATH, FILENAME_TEMPLATE.format("spTEP", "post")), preload=True)
# spTEP_post_raw.drop_channels(['HEOG', 'VEOG'])

# Pre-rTMS single-pulse recording, read fully into memory.
spTEP_pre_raw = mne.io.read_raw_brainvision(
    os.path.join(DATASET_PATH, FILENAME_TEMPLATE.format("spTEP", "pre")),
    preload=True,
)

# The two EOG channels are removed right after loading.
spTEP_pre_raw.drop_channels(['HEOG', 'VEOG'])


## EDA

### EEG Visualization

Visualizing the EEG graphs and electrodes for all the files gives a first impression of the data, and possible immediate preprocessing changes that can take place.

spTEP_pre: TP9 is a bad channel.

Perfect overlap on the Fp2 and VEOG channels, likely because the Fp2 was also used as a reference for VEOG. This overlap does result in errors, so it's better to remove the VEOG channel and keep this reference in mind.


In [None]:
# Work on a copy so the loaded raw recording itself is left untouched.
clone = spTEP_pre_raw.copy()

In [None]:
# Run the project's preprocessing routine on the copy.
# NOTE(review): `clone` is rebound to its own preprocessed version, so
# re-running this cell alone preprocesses already-preprocessed data — re-run
# the cell that creates `clone` first.
clone = preprocessing.preprocess(clone)

In [None]:
# Average response of the preprocessed copy, using the helper's default window.
utils.plot_average_response(clone)

In [None]:
# Browse the raw traces: 10 s starting at t = 60 s, 20 channels per page,
# with the EEG scaling set to 50e-6.
spTEP_pre_raw.plot(start=60, duration=10, n_channels=20, scalings={'eeg': 50e-6})

In [None]:
# Single response on channel Pz for tmin=-0.05 .. tmax=0.05
# (presumably seconds relative to the TMS pulse — confirm in utils).
utils.plot_single_response(spTEP_pre_raw, channel="Pz", tmin=-0.05, tmax=0.05)

In [None]:
# Recording metadata. Left as the cell's last expression so the notebook
# renders it with its rich display instead of the plain-text print().
spTEP_pre_raw.info

In [None]:
# Average response of the raw (not yet preprocessed) recording for
# tmin=-0.01 .. tmax=0.05 (presumably seconds around the pulse — confirm in utils).
utils.plot_average_response(spTEP_pre_raw, tmin=-0.01, tmax=0.05)

In [None]:
# Topographic maps of the power spectral density of the raw recording.
# Clearly indicates where the coils were placed
spTEP_pre_raw.compute_psd().plot_topomap()

## Preprocessing


In [None]:
# Preprocessing with interpolation
# TP9 was identified as a bad channel during EDA: mark it and interpolate it.
# This modifies spTEP_pre_raw in place; reset_bads=True clears the bad-channel
# list afterwards.
spTEP_pre_raw.info['bads'] = ['TP9']
spTEP_pre_raw.interpolate_bads(reset_bads=True)

# Preprocess a copy so the (interpolated) raw recording stays available.
spTEP_pre = preprocessing.preprocess(spTEP_pre_raw.copy())

# The pre/post comparison cells further down need `spTEP_post`. It was never
# created (its load and preprocess lines were commented out), so Run All failed
# with a NameError. Load it here the same way as the pre recording.
# NOTE(review): no bad channels have been identified for the post recording
# yet — inspect it and interpolate if needed.
spTEP_post_raw = mne.io.read_raw_brainvision(
    os.path.join(DATASET_PATH, "TMS-EEG-H_02_S1b_spTEP_post.vhdr"),
    preload=True,
)
spTEP_post_raw.drop_channels(['HEOG', 'VEOG'])
spTEP_post = preprocessing.preprocess(spTEP_post_raw.copy())

In [None]:
# Average response of spTEP_pre_raw (the raw object, not the preprocessed
# spTEP_pre) for tmin=-0.005 .. tmax=0.2. When cells run in order, TP9 has
# already been interpolated in place on this object.
utils.plot_average_response(spTEP_pre_raw, tmin=-0.005, tmax=0.2)

In [None]:
# Full pipeline
# Average response of the fully preprocessed pre-rTMS recording,
# tmin=-0.05 .. tmax=0.2. The matching post-rTMS plot is commented out.
utils.plot_average_response(spTEP_pre, tmin=-0.05, tmax=0.2)
# utils.plot_average_response(spTEP_post, tmin=-0.05, tmax=0.2)

In [None]:
# Save the preprocessed pre-rTMS recording as FIF.
# Template fields: patient number (zero-padded to 2 digits), session label,
# recording type, trial (pre/post).
FILENAME_TEMPLATE = "TMS-EEG-H_{:02d}_S{}_{}_{}"
filename = FILENAME_TEMPLATE.format(2, "1b", "spTEP", "pre")

# The output folder is not created anywhere else, so make sure it exists
# before saving.
cleaned_dir = os.path.join(".", "cleaned")
os.makedirs(cleaned_dir, exist_ok=True)
spTEP_pre.save(os.path.join(cleaned_dir, filename + ".fif"), overwrite=True)

In [None]:
# Apply preprocessing to all available data and save for later use

DATASET_PATH_FULL = '../neuroaa/datasets/raw/uz_gent'
FILENAME_TEMPLATE = "TMS-EEG-H_{:02d}_S{}_{}_{}"
CLEANED_DATA_DIR = os.path.join(".", "cleaned_data")

# The output folder must exist before the first save.
os.makedirs(CLEANED_DATA_DIR, exist_ok=True)

# NOTE(review): patients are numbered 0-18 and sessions 0-3 here, while the
# single-recording example uses patient 2 and session label "1b" — confirm
# the real numbering/labels of the full dataset.
n_processed = 0
missing_files = []

for patient in range(19):
    for session in range(4):
        # `recording_type` instead of `type`: a top-level loop variable named
        # `type` would shadow the builtin for the rest of the kernel session.
        for recording_type in ["spTEP", "rsEEG"]:
            for trial in ["pre", "post"]:
                filename = FILENAME_TEMPLATE.format(patient, session, recording_type, trial)
                # Read from the full dataset folder (the original used the
                # single-patient DATASET_PATH by mistake).
                vhdr_path = os.path.join(DATASET_PATH_FULL, filename + ".vhdr")

                # Not every combination necessarily exists; skip missing
                # ones instead of aborting the whole batch.
                if not os.path.isfile(vhdr_path):
                    missing_files.append(vhdr_path)
                    continue

                eeg_raw = mne.io.read_raw_brainvision(vhdr_path, preload=True)
                eeg_clean = preprocessing.preprocess(eeg_raw.copy())
                eeg_clean.save(os.path.join(CLEANED_DATA_DIR, filename + ".fif"), overwrite=True)
                n_processed += 1

print(f"Processed {n_processed} recording(s); {len(missing_files)} expected file(s) not found.")
if n_processed == 0:
    # Nothing matched at all: almost certainly a wrong path or naming scheme.
    raise FileNotFoundError(
        f"No recordings found in {DATASET_PATH_FULL!r} matching {FILENAME_TEMPLATE!r}"
    )

## Feature extraction

Features are per pulse (spTEP) or epoch (rsEEG)


### Time domain


In [None]:
# Peak amplitude, peak latency, and area under curve
# One feature row per epoch (i.e. per TMS pulse).
events, event_dict = mne.events_from_annotations(spTEP_pre)
epochs = mne.Epochs(spTEP_pre, events, event_id=event_dict, tmin=-0.05, tmax=0.2, baseline=None, preload=True)

sampling_rate = epochs.info['sfreq']
epoch_times = epochs.times  # seconds relative to the event, starts at tmin

# np.trapz is deprecated in NumPy 2 in favour of np.trapezoid; use whichever exists.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

features = []

for epoch in epochs.get_data():
    # `epoch` is 2-D (channels x time samples). np.argmax on it returns an
    # index into the flattened array, so it has to be unravelled before it can
    # be used as a time index; dividing the flat index by the sampling rate
    # gave wrong latencies for every channel but the first.
    _, peak_sample = np.unravel_index(np.argmax(epoch), epoch.shape)

    # Largest value over all channels and samples.
    # NOTE(review): negative deflections are ignored — use np.abs if the
    # strongest deflection of either sign is wanted.
    peak_amplitude = epoch[_, peak_sample]
    # Latency in seconds relative to the event (negative = before the pulse).
    peak_latency = epoch_times[peak_sample]
    # Per-channel integral over time (dx in seconds, so the value does not
    # depend on the sampling rate), averaged across channels.
    area_under_curve = np.mean(_trapezoid(epoch, dx=1.0 / sampling_rate, axis=-1))

    features.append([peak_amplitude, peak_latency, area_under_curve])

len(features)

In [None]:
# Plot difference of features between pre and post
# Requires both `spTEP_pre` and `spTEP_post` (preprocessed recordings) to exist.

def _flatten_per_channel(epoch_data):
    """Reshape (n_epochs, n_channels, n_times) data to (n_channels, n_epochs * n_times).

    Lets every per-channel statistic be computed along one ordinary axis
    instead of relying on tuple-axis support, which scipy.stats.skew and
    scipy.stats.kurtosis do not document.
    """
    return np.moveaxis(epoch_data, 1, 0).reshape(epoch_data.shape[1], -1)


# --- pre-rTMS ---
events, event_dict = mne.events_from_annotations(spTEP_pre)
epochs = mne.Epochs(spTEP_pre, events, event_id=event_dict, tmin=-0.05, tmax=0.2, baseline=None, preload=True)
data = epochs.get_data()
average_response = epochs.average().data

# --- post-rTMS ---
events_post, event_dict_post = mne.events_from_annotations(spTEP_post)
epochs_post = mne.Epochs(spTEP_post, events_post, event_id=event_dict_post, tmin=-0.05, tmax=0.2, baseline=None, preload=True)
data_post = epochs_post.get_data()
average_response_post = epochs_post.average().data

# Per-channel statistics pooled over all epochs and time samples.
flat_pre = _flatten_per_channel(data)
average_max = average_response.max(axis=1)
mean = flat_pre.mean(axis=1)
std_dev = flat_pre.std(axis=1)
skewness = scipy.stats.skew(flat_pre, axis=1)
kurtosis = scipy.stats.kurtosis(flat_pre, axis=1)

flat_post = _flatten_per_channel(data_post)
average_max_post = average_response_post.max(axis=1)
mean_post = flat_post.mean(axis=1)
std_dev_post = flat_post.std(axis=1)
skewness_post = scipy.stats.skew(flat_post, axis=1)
kurtosis_post = scipy.stats.kurtosis(flat_post, axis=1)

# One scatter panel per statistic: each point is a channel, x = pre, y = post.
panels = [
    ('Mean', mean, mean_post),
    ('Standard Deviation', std_dev, std_dev_post),
    ('Skewness', skewness, skewness_post),
    ('Kurtosis', kurtosis, kurtosis_post),
    ('Max Average Response', average_max, average_max_post),
]

fig, axs = plt.subplots(3, 2)
for ax, (label, pre_values, post_values) in zip(axs.flat, panels):
    ax.scatter(pre_values, post_values)
    ax.set_xlabel(f'{label} (pre)')
    ax.set_ylabel(f'{label} (post)')

# Remove the unused subplot
fig.delaxes(axs[2,1])

plt.tight_layout()
plt.show()

### Frequency domain


In [None]:
# EEG frequency bands as (lower edge, upper edge) pairs in Hz, ordered from
# slowest to fastest.
bands = dict(
    delta=(0.5, 4),
    theta=(4, 8),
    alpha=(8, 12),
    beta=(12, 30),
    gamma=(30, 50),
)

In [None]:
# Average power per frequency band
psd = spTEP_pre.compute_psd(fmin=1.0, fmax=100.0)
psd.plot(average=True)

freqs = psd.freqs
# Fetch the PSD array once instead of on every loop iteration.
psd_data = psd.get_data()

# Average power for each band, pooled over all channels and all frequency
# bins that fall inside the band.
avg_power = {}

for band, (fmin, fmax) in bands.items():
    # Half-open interval [fmin, fmax): with both edges inclusive, a bin lying
    # exactly on a shared edge (4, 8, 12, 30 Hz) was counted in two bands.
    # The PSD itself starts at 1 Hz, so 'delta' effectively covers 1-4 Hz.
    in_band = (freqs >= fmin) & (freqs < fmax)
    if not in_band.any():
        # No frequency bin falls inside this band.
        avg_power[band] = float('nan')
        continue
    avg_power[band] = float(np.mean(psd_data[:, in_band]))

print(avg_power)

In [None]:
# Difference in average power
# Uses the `epochs` / `epochs_post` objects created by an earlier cell.
psd_pre = epochs.compute_psd(fmin=1.0, fmax=100.0)
psd_post = epochs_post.compute_psd(fmin=1.0, fmax=100.0)

# `epochs` is rebound with a different time window by the time-frequency cell;
# if this cell is re-run afterwards the two spectra no longer share a
# frequency grid and the subtraction below would be meaningless.
if not np.array_equal(psd_pre.freqs, psd_post.freqs):
    raise ValueError(
        "Pre and post spectra have different frequency bins; "
        "re-create `epochs` and `epochs_post` with the same time window."
    )

# Average over epochs, then subtract channel by channel.
psd_pre_avg = psd_pre.average()
psd_post_avg = psd_post.average()

psd_diff = psd_post_avg.get_data() - psd_pre_avg.get_data()

fig, ax = plt.subplots()

# Plot the difference PSD of the first channel only; name it in the legend so
# the figure can be read on its own.
first_channel = psd_pre_avg.ch_names[0]
ax.plot(psd_pre.freqs, psd_diff[0], label=f'Post - Pre ({first_channel})')

ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('PSD difference')
ax.set_title('Change in power spectral density after rTMS')

# Add a legend
ax.legend()

# Show the plot
plt.show()

### Time-frequency domain


In [None]:
# Morlet-wavelet time-frequency decomposition of the preprocessed pre-rTMS
# recording, on epochs running from each event to 4 s after it.
events, event_dict = mne.events_from_annotations(spTEP_pre)
epochs = mne.Epochs(
    spTEP_pre,
    events,
    event_id=event_dict,
    tmin=0,
    tmax=4,
    baseline=None,
    preload=True,
)

# Integer frequencies 1..99, two wavelet cycles at every frequency.
frequencies = np.arange(1, 100)

# NOTE(review): with its default arguments tfr_morlet presumably returns a
# (power, inter-trial coherence) pair — confirm against the installed MNE version.
wavelets = mne.time_frequency.tfr_morlet(epochs, freqs=frequencies, n_cycles=2)

In [None]:
# INDIVIDUAL PLOTS

def _heatmap_extent(times, freqs):
    """Return an imshow extent [left, right, bottom, top] in seconds / Hz.

    Edges are padded by half a bin so that pixel centres sit on the actual
    sample times and frequencies. Assumes both arrays are evenly spaced.
    """
    half_dt = (times[1] - times[0]) / 2 if len(times) > 1 else 0.5
    half_df = (freqs[1] - freqs[0]) / 2 if len(freqs) > 1 else 0.5
    return [times[0] - half_dt, times[-1] + half_dt,
            freqs[0] - half_df, freqs[-1] + half_df]


# Number of leading time samples shown in the per-band plots.
N_SAMPLES_SHOWN = 200

power_tfr = wavelets[0]
power = power_tfr.data  # indexed below as (electrode, frequency, time)
freqs = power_tfr.freqs
tfr_times = power_tfr.times

# Average over the first axis (electrodes).
avg_power = np.mean(power, axis=0)

fig, ax = plt.subplots()

# Create the heatmap. `extent` puts seconds and Hz on the axes; without it
# imshow labels the axes with array indices.
cax = ax.imshow(avg_power, aspect='auto', cmap='hot', origin='lower',
                extent=_heatmap_extent(tfr_times, freqs))

# Add a colorbar
fig.colorbar(cax)

# Set the labels for the x and y axes and the title
ax.set_xlabel('Time (s)')
ax.set_ylabel('Frequency (Hz)')
ax.set_title('Average Power')

# Show the plot
plt.show()

# GROUPED PLOTS
n_rows = len(bands)
n_cols = 1

# Create a figure with one subplot per frequency band
fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 20))
# plt.subplots returns a bare Axes when there is a single band.
axs = np.atleast_1d(axs)

# Right-hand x limit: the time of sample N_SAMPLES_SHOWN (the original limit
# of 200 was expressed in sample indices).
x_right = tfr_times[min(N_SAMPLES_SHOWN, len(tfr_times) - 1)]

for ax, (band, (fmin, fmax)) in zip(axs, bands.items()):
    # Find the indices that correspond to this frequency band
    band_indices = np.where((freqs >= fmin) & (freqs <= fmax))[0]
    if band_indices.size == 0:
        # Nothing to draw for a band outside the computed frequency range.
        ax.set_visible(False)
        continue

    # Slice the power data to include only these frequencies
    band_power = power[:, band_indices, :]

    # Compute the average power across electrodes
    avg_power = np.mean(band_power, axis=0)

    # Create the heatmap with physical units on both axes
    cax = ax.imshow(avg_power, aspect='auto', cmap='hot', origin='lower',
                    extent=_heatmap_extent(tfr_times, freqs[band_indices]))

    ax.set_xlim(None, x_right)

    # Add a colorbar
    fig.colorbar(cax, ax=ax)

    # Set the labels for the x and y axes and the title
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    ax.set_title(f'Average Power ({band} band)')

# Show the plot
plt.tight_layout()
plt.show()

## Clustering
