# SEED Dataset Preprocessing

This Notebook can be used in order to preprocess the SEED Dataset in the following way: <br/>
1. Collect all the Data from the mat files
2. Remove the baseline from the EEG Signals
3. Filter the EEG Signals, either with a bandpass filter, a lowpass filter or a highpass filter
4. Downsample the EEG Signals to a certain sampling rate
5. Select only certain EEG Channels and rearrange them in a given order
6. Cut the EEG Signals into Windows of given length with given overlap
7. Save the generated Dataset
8. Generate some decent plots to compare the old signal with the new signal

## Description, Dataset Access and Citation

<b>For a full description and access to the Dataset please see:</b> <br/>
url: http://bcmi.sjtu.edu.cn/home/seed/seed.html <br/>
[1] Duan, R.-N.; Zhu, J.-Y.; Lu, B.-L.: Differential Entropy Feature for EEG-based Emotion Classification. In: 6th International IEEE/EMBS Conference on Neural Engineering (NER), p. 81–84, IEEE, 2013 <br/>
[2] Zheng, W.-L.; Lu, B.-L.: Investigating Critical Frequency Bands and Channels for EEG-based Emotion Recognition with Deep Neural Networks. IEEE Transactions on Autonomous Mental Development 7 (2015) 3, p. 162–175. <br/><br/>
<b>If you find this code helpful please cite:</b><br/>
tbd

## Hyperparameters

In [None]:
# --- Input / output locations ---
# Alternative, repository-relative locations:
#   data_dir = '../Datasets/raw/SEED_copy/raw/Preprocessed_EEG'
#   output_dir = '../Datasets/full_preprocessing/'
data_dir = 'E:/Databases/SEED/SEED/Preprocessed_EEG/'
output_dir = 'E:/Databases/DataPre/SEED/full_preprocessing/'

# --- Signal conditioning ---
baseline_removal_window = 3    # seconds used to estimate the baseline, 0 disables it
cutoff_frequencies = [4, 40]   # [low, high] filter cutoffs, None disables one side
seconds_to_use = 185           # keep only the last n seconds, None keeps everything
downsampling_rate = 128        # target rate, 0 keeps the original rate

# --- Channel selection (order defines the channel order of the output) ---
channels_to_use = [
    'AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1',
    'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4',
]

# --- Windowing (both values in seconds) ---
window_size = 2
window_overlap = 0

# --- Plotting ---
save_plots_to_file = False     # False: only show the plots in the notebook

In [None]:
import os
# Create <output_dir>/figures (including any missing parent folders) up front;
# exist_ok=True makes re-running this cell harmless.
os.makedirs(os.path.join(output_dir, "figures"), exist_ok=True)

<b>data_dir (String)</b> - the path to the raw SEED Dataset, e.g. "SEED/raw/Preprocessed_EEG/" <br/>
<b>output_dir (String)</b> - the path where you want to save the preprocessed Dataset, e.g. "myDatasets/SEED/" <br/>
<b>baseline_removal_window (float)</b> - the time window in seconds, taken from the start of each time series, from which a baseline (average) is calculated. This value gets subtracted from the whole time series. If you don't want to use baseline removal, set it to 0 <br/>
<b>cutoff_frequencies (tuple)</b> - the cutoff frequencies for the filter. If the first value is set to None, a lowpass filter is used; if the second value is set to None, a highpass filter is used; if both values are set, a bandpass filter is used <br/>
<b>seconds_to_use (int)</b> - the window in seconds to use from the time series. If, for example, seconds_to_use is set to 45, only the last 45 seconds will be used. If you want to use the whole time series, set it to None <br/>
<b>downsampling_rate (int)</b> - the frequency to which the EEG signals should be downsampled. If you don't want to downsample the signal, set it to 0 <br/>
<b>channels_to_use (list)</b> - a list of EEG channels you want to use. If you want to use all channels from the dataset, set it to None <br/>
<b>window_size (int)</b> - the length of the time window in seconds into which the dataset should be cut <br/>
<b>window_overlap (int)</b> - overlap of the windows in seconds. If set to 0, the windows don't overlap <br/>
<b>save_plots_to_file (boolean)</b> - whether or not you want to save the generated plots. If you choose False, the plots will only be shown in this notebook

## necessary imports

In [None]:
import numpy as np
import scipy.io as sio
from tqdm.notebook import tqdm, trange
import os
import warnings

## Data Collection

In [None]:
# Sampling rate (Hz) assumed for the loaded signals.
# NOTE(review): presumably the rate of the Preprocessed_EEG files - confirm.
sampling_rate = 200

# Channel label for every row of a loaded trial array, in storage order
# (62 entries, matching the second axis of X allocated during data collection).
channel_names = [
    'FP1', 'FPZ', 'FP2',
    'AF3', 'AF4',
    'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8',
    'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8',
    'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8',
    'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8',
    'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8',
    'PO7', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'PO8',
    'CB1', 'O1', 'OZ', 'O2', 'CB2',
]

In [None]:
from glob import glob

# def get_session_id(subject_id,filename):
#     #print(os.path.join(data_dir,subject_id,'_********.mat'))
#     files = glob(os.path.join(data_dir,subject_id+'_********.mat'))
#     return sorted(files).index(os.path.join(data_dir,filename))

def get_session_id(subject_id,filename):
    """Return the zero-based session index of ``filename`` for one subject.

    Every file in ``data_dir`` whose name matches ``<subject_id>_*.mat`` is
    collected and sorted by path; the position of ``filename`` in that order is
    the session index. Backslashes are replaced by forward slashes on both
    sides of the comparison so Windows-style separators do not break the lookup.

    Raises:
        ValueError: (from ``list.index``) if ``filename`` is not among the
            matched files.
    """
    def _forward_slashes(path):
        return path.replace('\\', '/')

    search_pattern = _forward_slashes(os.path.join(data_dir, subject_id + '_********.mat'))
    candidates = sorted(_forward_slashes(match) for match in glob(search_pattern))
    wanted = _forward_slashes(os.path.join(data_dir, filename))
    return candidates.index(wanted)

In [None]:
def get_keynames(dict_keys):
    """Collect the keys that look like EEG trial arrays.

    A key is emitted once for every trial number ``i`` in 1..15 whose marker
    ``'_eeg<i>'`` occurs in it as a substring; keys without any marker are
    dropped and the input order is preserved.

    NOTE: the test is a substring test, so a key containing ``'_eeg10'`` ..
    ``'_eeg15'`` also matches ``'_eeg1'`` and is therefore returned twice.
    Positions in the returned list do not correspond to trial numbers.
    """
    markers = ['_eeg' + str(number) for number in range(1, 16)]
    return [key for key in list(dict_keys) for marker in markers if marker in key]

In [None]:
# Load the label of every trial; label[trial_id-1] is later assigned to the
# recordings of trial `trial_id`. os.path.join is used (as for the trial files)
# so that data_dir works with or without a trailing path separator.
label = sio.loadmat(os.path.join(data_dir, 'label.mat'))['label'][0]
print(label)
print(len(label))

In [None]:
# Pre-allocate one slot per recording (3 sessions x 15 subjects x 15 trials = 675),
# 62 channels and up to 53001 samples each. Shorter trials stay zero-padded at
# the end; their true length is kept in X_len.
X = np.zeros((675,62,53001))
X_len = np.zeros(675,dtype=int)
Y = np.zeros((675,))
session = np.zeros((675,))
subject = np.zeros((675,))
trial = np.zeros((675,))
for file in tqdm(os.listdir(data_dir)):
    # Only the per-subject recordings are of interest, not label.mat.
    if not file.endswith(".mat") or file.endswith("label.mat"):
        continue
    subject_id = file.split('_')[0]
    session_id = get_session_id(subject_id,file)
    X_temp = sio.loadmat(os.path.join(data_dir,file))
    # Map trial number -> key by parsing the digits after the last '_eeg'.
    # Indexing the key list by position is not reliable: it is 0-based and
    # get_keynames may list the keys of trials 10-15 more than once, which
    # would pair signals with the wrong trial and label.
    trial_keys = {}
    for key in get_keynames(X_temp.keys()):
        trial_number = key.rsplit('_eeg', 1)[-1]
        if trial_number.isdigit():
            trial_keys[int(trial_number)] = key
    for trial_id in range(1,16):
        # Flat index: sessions are the outermost, subjects the middle and
        # trials the innermost dimension.
        experiment_id = session_id*15*15 + (int(subject_id)-1)*15 + trial_id-1
        subject[experiment_id] = subject_id
        session[experiment_id] = session_id
        trial[experiment_id] = trial_id
        Y[experiment_id] = label[trial_id-1]
        # A KeyError here means the file has no array for this trial number.
        X_temp_exp = X_temp[trial_keys[trial_id]]
        X_len[experiment_id] = X_temp_exp.shape[1]
        X[experiment_id,:,:X_temp_exp.shape[1]] = X_temp_exp

In [None]:
# Quick sanity report on the collected arrays: one caption/value pair per line.
report = [
    ("Shape of the Time Series Array X: ", X.shape),
    ("Unique Label Y Indices: ", np.unique(Y)),
    ("Unique Session Indices: ", np.unique(session)),
    ("Unique Subject Indices: ", np.unique(subject)),
    ("Unique Trial Indices: ", np.unique(trial)),
    ("Minimum length of Timeseries: ", min(X_len)),
    ("Maximum length of Timeseries: ", max(X_len)),
]
for caption, value in report:
    print(caption + str(value))

In [None]:
# Keep an untouched copy of the collected signals: the plotting cells at the end
# of the notebook read X_raw to compare raw and preprocessed data. If you don't
# want to see the plots, you can comment this line out (the plotting cells will
# then no longer run).
X_raw = X.copy()

## Baseline-Removal

In [None]:
# Number of leading samples that form the baseline. The window is documented as
# a float number of seconds, so it is rounded to an integer sample count
# (slice bounds must be integers). A window of 0 disables baseline removal.
baseline_datapoints = int(round((baseline_removal_window or 0) * sampling_rate))
if baseline_datapoints > 0:
    # Mean of the first baseline_datapoints samples per (experiment, channel).
    baseline = X[:,:,:baseline_datapoints].sum(2) / baseline_datapoints
    # Subtract it from every time step at once; broadcasting over the time
    # axis replaces the per-sample Python loop and works in place.
    X -= baseline[:, :, np.newaxis]

## Filtering (Bandpass or Highpass or Lowpass)

In [None]:
from scipy.signal import butter, sosfilt, sosfreqz

def butter_bandpass(lowcut, highcut, fs, btype='band', order=5):
    """Design a digital Butterworth filter as second-order sections.

    Args:
        lowcut: Lower cutoff frequency, in the same unit as ``fs``. Ignored
            for ``btype='lowpass'``.
        highcut: Upper cutoff frequency, in the same unit as ``fs``. Ignored
            for ``btype='highpass'``.
        fs: Sampling rate of the signal the filter will be applied to.
        btype: ``'bandpass'`` (or its alias ``'band'``, the default),
            ``'highpass'`` or ``'lowpass'``.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The filter coefficients in ``sos`` format.

    Raises:
        ValueError: If ``btype`` is none of the supported filter types.
    """
    # scipy expects the critical frequencies normalised to the Nyquist rate.
    nyq = 0.5 * fs
    if btype in ('band', 'bandpass'):
        # 'band' is the default of this function; treat it like 'bandpass'
        # instead of falling through all branches without a result.
        critical = [lowcut / nyq, highcut / nyq]
        kind = 'bandpass'
    elif btype == 'highpass':
        critical = lowcut / nyq
        kind = 'highpass'
    elif btype == 'lowpass':
        critical = highcut / nyq
        kind = 'lowpass'
    else:
        raise ValueError(
            "btype must be 'band', 'bandpass', 'highpass' or 'lowpass', got " + repr(btype)
        )
    return butter(order, critical, analog=False, btype=kind, output='sos')

def butter_bandpass_filter(X, lowcut, highcut, fs, btype='bandpass', order=5):
    """Filter ``X`` with a Butterworth filter designed by ``butter_bandpass``.

    The filter is built from ``lowcut``, ``highcut``, ``fs``, ``btype`` and
    ``order`` and applied to ``X`` with ``scipy.signal.sosfilt`` using its
    default axis. The filtered data is returned.
    """
    coefficients = butter_bandpass(lowcut, highcut, fs, btype=btype, order=order)
    return sosfilt(coefficients, X)

In [None]:
# Choose the filter type from the cutoffs that are set:
#   both set -> bandpass, only the lower one -> highpass,
#   only the upper one -> lowpass, none -> no filtering at all.
low_cut = cutoff_frequencies[0]
high_cut = cutoff_frequencies[1]
if low_cut is not None and high_cut is not None:
    btype = 'bandpass'
elif low_cut is not None:
    btype = 'highpass'
elif high_cut is not None:
    btype = 'lowpass'
else:
    btype = None

if btype is not None:
    for experiment_id in trange(X.shape[0]):
        # butter_bandpass_filter applies sosfilt with its default axis (the
        # last one), so all channels of one experiment are filtered in one call.
        X[experiment_id] = butter_bandpass_filter(
            X[experiment_id],
            low_cut,
            high_cut,
            sampling_rate,
            btype=btype,
            order=5)

## Use only the last n seconds of the timeseries
(As determined with the hyperparameter seconds_to_use)

In [None]:
if seconds_to_use is not None:
    # Array shapes and slice bounds need an integer sample count.
    num_sample_points_to_use = int(seconds_to_use * sampling_rate)
    # Every recording must hold at least that many valid samples; otherwise
    # the start index below would become negative and select the wrong data.
    shortest_recording = int(X_len.min())
    if num_sample_points_to_use > shortest_recording:
        raise ValueError(
            "seconds_to_use needs " + str(num_sample_points_to_use)
            + " samples, but the shortest time series has only "
            + str(shortest_recording)
        )
    X_selected = np.zeros((X.shape[0], X.shape[1], num_sample_points_to_use))
    for exp_id in trange(len(X_len)):
        # X is zero-padded behind X_len[exp_id]; cut the window so that it
        # ends at the last valid sample of this recording.
        last_sample = X_len[exp_id]
        X_selected[exp_id,:,:] = X[exp_id,:,last_sample-num_sample_points_to_use:last_sample]
    X = X_selected

## Downsampling

In [None]:
from scipy.signal import resample

In [None]:
# Resample only when a target rate is requested (non-zero) and it differs from
# the current sampling rate.
if downsampling_rate not in (0, sampling_rate):
    # Same duration, expressed in samples at the new rate.
    new_length = int(X.shape[2] / sampling_rate * downsampling_rate)
    X_downsampled = np.zeros(X.shape[:2] + (new_length,))
    for experiment_id in trange(X.shape[0]):
        for channel_id, channel_signal in enumerate(X[experiment_id]):
            X_downsampled[experiment_id, channel_id, :] = resample(channel_signal, new_length)
    X = X_downsampled

## Select certain channels

In [None]:
# None means "keep every channel of the recording". Identity comparison is used
# because `== None` dispatches to __eq__, which is element-wise for arrays.
if channels_to_use is None:
    channels_to_use = channel_names

In [None]:
# Translate the requested channel names into row indices of the loaded data.
# Unknown names are reported and dropped from channels_to_use as well, so that
# the name list stays aligned with channel_index_list (and therefore with the
# selected data and the channel names saved alongside it).
channel_index_list = []
known_channels = []
for channel_name in channels_to_use:
    if channel_name in channel_names:
        channel_index_list.append(channel_names.index(channel_name))
        known_channels.append(channel_name)
    else:
        warnings.warn(' Channel ' + channel_name +' could not be found in the list of actual channels')
channels_to_use = known_channels

In [None]:
# Pick the requested channels in the requested order. Indexing with the index
# list yields exactly one row per resolved channel, so no all-zero rows are
# left over when a requested channel could not be resolved.
X_selected_channels = X[:, channel_index_list, :]
X = X_selected_channels

## Cut into windows

In [None]:
# window_size and window_overlap are taken from the Hyperparameters section;
# they are deliberately not re-assigned here, which would silently override
# the values configured there. Only their consistency is checked.
if window_size <= 0:
    raise ValueError("window_size must be greater than 0")
if not 0 <= window_overlap < window_size:
    raise ValueError("window_overlap must be at least 0 and smaller than window_size")

In [None]:
# Sampling rate of X at this point: downsampling_rate == 0 means the signal was
# not resampled and still has the original rate.
effective_rate = downsampling_rate if downsampling_rate else sampling_rate
num_points_per_window = int(window_size * effective_rate)
num_points_overlap = int(window_overlap * effective_rate)
stride = num_points_per_window - num_points_overlap
if num_points_per_window <= 0 or stride <= 0:
    raise ValueError("window_size must be positive and larger than window_overlap")
if num_points_per_window > X.shape[2]:
    raise ValueError("window_size is longer than the available time series")

# Number of complete windows that fit into the series. A window that ends
# exactly on the last sample is included.
num_windows_per_exp = (X.shape[2] - num_points_per_window) // stride + 1
start_index = [window_id * stride for window_id in range(num_windows_per_exp)]
end_index = [start + num_points_per_window for start in start_index]

# Output layout: all windows of experiment 0 first, then those of experiment 1,
# and so on, i.e. row = exp_id * num_windows_per_exp + window_id.
X_cut = np.zeros((num_windows_per_exp*X.shape[0],X.shape[1], num_points_per_window))
for window_id in trange(num_windows_per_exp):
    # Rows window_id, window_id + num_windows_per_exp, ... hold this window
    # for experiment 0, 1, ...
    X_cut[window_id::num_windows_per_exp] = X[:,:,start_index[window_id]:end_index[window_id]]

# Every window inherits the metadata of the experiment it was cut from.
Y_cut = np.repeat(Y, num_windows_per_exp)
session_cut = np.repeat(session, num_windows_per_exp)
subject_cut = np.repeat(subject, num_windows_per_exp)
trial_cut = np.repeat(trial, num_windows_per_exp)

X = X_cut
Y = Y_cut
session = session_cut
subject = subject_cut
trial = trial_cut

## Save Dataset

Please note: This step can take up to several minutes

In [None]:
# Everything that goes into the archive: the windowed data, the per-window
# metadata and the hyperparameters that were used to produce it.
dataset = dict(
    X=X,
    Y=Y,
    session=session,
    subject=subject,
    trial=trial,
    downsampling_rate=downsampling_rate,
    channel_names=channels_to_use,
    window_size=window_size,
    window_overlap=window_overlap,
    cutoff_frequencies=cutoff_frequencies,
    baseline_removal_window=baseline_removal_window,
    seconds_to_use=seconds_to_use,
)
np.savez_compressed(output_dir +'/SEED.npz', **dataset)
print('Saved File')

## Plots

In [None]:
def perform_FFT(X, sampling_rate):
    """Return the one-sided, length-normalised DFT of ``X`` and its frequencies.

    Args:
        X: One-dimensional signal.
        sampling_rate: Sampling rate of ``X``.

    Returns:
        A tuple ``(X_FT, frq)``: the first ``len(X) // 2`` DFT coefficients
        divided by ``len(X)``, and the matching frequencies ``k / T`` with
        ``T = len(X) / sampling_rate``.
    """
    num_samples = len(X)
    half = num_samples // 2                 # size of the one-sided range
    duration = num_samples / sampling_rate  # signal length T

    frequencies = np.arange(half) / duration
    spectrum = np.fft.fft(X)[:half] / num_samples
    return spectrum, frequencies

In [None]:
# Position (within channels_to_use, i.e. the channel axis of the preprocessed X)
# of the channel shown in the comparison plots below.
channel_id_to_plot = 0

In [None]:
# Row of the plotted channel in the raw data (X_raw still holds all channels).
raw_channel_id = channel_names.index(channels_to_use[channel_id_to_plot])
# The first preprocessed window of experiment 0 starts where the retained
# "last seconds_to_use seconds" begin, measured from the end of the valid
# samples (X_len[0]); with seconds_to_use = None it starts at sample 0.
if seconds_to_use is None:
    raw_start = 0
else:
    raw_start = X_len[0] - int(seconds_to_use * sampling_rate)
# Take the same stretch of time from the raw signal so that both curves show
# the same part of the recording.
raw_ts = X_raw[0, raw_channel_id, raw_start:raw_start + int(window_size * sampling_rate)]
preprocessed_ts = X[0,channel_id_to_plot,:]

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Rate of the preprocessed signal: downsampling_rate == 0 means it was not
# resampled and still has the original sampling rate.
preprocessed_rate = downsampling_rate if downsampling_rate else sampling_rate
# Both curves share an x axis measured in raw samples; the preprocessed sample
# positions are rescaled to that axis.
plt.plot(np.arange(len(raw_ts)), raw_ts, label="Raw Signal")
plt.plot(np.arange(len(preprocessed_ts))/preprocessed_rate*sampling_rate, preprocessed_ts, label="Preprocessed Signal")
plt.title("Comparison of an Exemplary Time Series")
# One tick per full second of the window, labelled in seconds.
whole_seconds = list(range(int(window_size) + 1))
plt.xticks([second*sampling_rate for second in whole_seconds], whole_seconds)
plt.xlabel("Time t [s]")
plt.ylabel("Voltage U [mV]")
plt.legend()
if save_plots_to_file:
    # os.path.join works whether or not output_dir ends with a separator.
    plt.savefig(os.path.join(output_dir, 'figures', 'SEED_Timespace.png'), facecolor="white")
else:
    plt.show()

In [None]:
# The preprocessed signal only has the rate downsampling_rate if resampling was
# requested; downsampling_rate == 0 means it still has the original rate.
preprocessed_rate = downsampling_rate if downsampling_rate else sampling_rate
raw_fs, raw_frq = perform_FFT(raw_ts, sampling_rate)
preprocessed_fs, preprocessed_frq = perform_FFT(preprocessed_ts, preprocessed_rate)

In [None]:
# Magnitude spectra of the raw and the preprocessed example series.
plt.plot(raw_frq, abs(raw_fs), label="Raw Signal")
plt.plot(preprocessed_frq, abs(preprocessed_fs), label="Preprocessed Signal")
plt.title("Comparison of an Exemplary Spectrum")
plt.xlabel("Frequency f [Hz]")
plt.ylabel("Amplitude |X(f)|")
plt.legend()
if save_plots_to_file:
    # os.path.join works whether or not output_dir ends with a separator.
    plt.savefig(os.path.join(output_dir, 'figures', 'SEED_Frequencyspace.png'), facecolor="white")
else:
    plt.show()