## Data gathering and separation

In [72]:
from glob import glob
import os
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [73]:
# Get the data set path file
all_file_path = glob('../dataverse_files/*.edf')
print(len(all_file_path))
print(all_file_path[0])

28
../dataverse_files\h01.edf


In [74]:
# Separate healthy and schizofrenic patients eeg data
healthy_file_path = [i for i in all_file_path if 'h' in i.split('\\')[1]]
patient_file_path = [i for i in all_file_path if 's' in i.split('\\')[1]]

print(healthy_file_path[0])
print(patient_file_path[0])

../dataverse_files\h01.edf
../dataverse_files\s01.edf


In [75]:
# Function to read the data from .edf files using mne

def read_data(file_path, low_freq=0.5, high_freq=45, duration=5, overlap=1):
    data = mne.io.read_raw_edf(file_path, preload=True)
    data.set_eeg_reference()
    data.filter(l_freq=low_freq, h_freq=high_freq)
    epochs = mne.make_fixed_length_epochs(data, duration=duration, overlap=overlap)
    array = epochs.get_data()
    return array

In [76]:
sample_data = read_data(healthy_file_path[0])
sample_data.shape   # Nb of epochs (time windows speficied by duration in the function above), channels (nb of electrodes in the eeg), lenght of the signal (nb of total points)

Extracting EDF parameters from c:\Projects\EEG Classification\dataverse_files\h01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 231249  =      0.000 ...   924.996 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1651 samples (6.604 s)

Not setting metadata
231 matching events found
No baseline correction applied
0 projection items

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


(231, 19, 1250)

In [77]:
%%capture

control_epochs_array = [read_data(i) for i in healthy_file_path]
patient_epochs_array = [read_data(i) for i in patient_file_path]

# each array will have the data of 14 subjects

In [78]:
# labeling healthy data as "0" and schizofrenic patient data as "1" for each epoch on each subject

control_epochs_labels = [len(i)*[0] for i in control_epochs_array]
patient_epochs_labels = [len(i)*[1] for i in patient_epochs_array]

len(control_epochs_labels), len(patient_epochs_labels)

(14, 14)

In [79]:
data_list = control_epochs_array + patient_epochs_array    # all data
label_list = control_epochs_labels + patient_epochs_labels  # all labels

In [80]:
# Identifying data per subject
group_list = [[i]*len(j) for i,j in enumerate(data_list)]
print(len(group_list))

28


In [81]:
# Transforms into array of eeg signals stacked

data_array = np.vstack(data_list)
label_array = np.hstack(label_list)
group_array = np.hstack(group_list)
print(data_array.shape, label_array.shape, group_array.shape)   # (epochs, electrodes, total points), (labels - 0 for healthy, 1 for patient), (subjects)

(7201, 19, 1250) (7201,) (7201,)


## Features extraction 

In [82]:
def extract_features_time_domain(signal):
    """
    This function extracts features in the time domain from a raw EEG signal

    ## Returns:

    This function returns a **ndarray** which contain, respectively, the following features:\n

    - **Mean**: The average of the signal.

    - **Variance**: Indicates how much the signal amplitude fluctuates from the mean.

    - **Standard Deviation**: The square root of the variance, showing the average amount of deviation from the mean.

    - **Root-Mean-Square (RMS)**: Measures the magnitude of the EEG signal by taking the square root of the average of the squared values.

    - **Absolute Mean Difference (AMD)**: This feature calculates the average absolute difference between consecutive values.

    - **Skewness**: Measures the asymmetry of the signal around the mean. Positive skew indicates a tail on the right side,
    while negative skew shows a tail on the left. It helps detect irregular patterns.

    - **Kurtosis**: Measures the "tailedness" of the signal distribution, indicating whether data points are close to the mean or more spread out.
    High kurtosis may indicate spikes or irregularities in the signal.

    ### Hjorth Parameters: A set of metrics specifically for EEG data:\n

    - **Mobility**: Indicates the frequency, defined as the square root of the variance of the first derivative divided by the variance of the signal.

    - **Complexity**: The ratio of the mobility of the first derivative to the mobility of the signal itself.

    """

    mean = np.mean(signal, axis=-1)
    variance = np.var(signal, axis=-1)
    std_dev = np.std(signal, axis=-1)
    rms_value = np.sqrt(np.mean(signal**2, axis=-1))
    amd_value = np.sum(np.abs(np.diff(signal, axis=-1)), axis=-1)

    from scipy.stats import skew, kurtosis
    skewness = skew(signal, axis=-1)
    kurtosis_value = kurtosis(signal, axis=-1)

    first_derivative = np.diff(signal, axis=-1)
    second_derivative = np.diff(first_derivative, axis=-1)
    mobility = np.sqrt(np.var(first_derivative)/variance)
    complexity = np.sqrt(np.var(second_derivative) / np.var(first_derivative)) / mobility

    return np.concatenate((mean, variance, std_dev, rms_value, amd_value, skewness, kurtosis_value, mobility, complexity), axis=-1)

In [83]:
all_features = []
for data in data_array:
    all_features.append(extract_features_time_domain(data))

In [84]:
all_features_array = np.array(all_features)
all_features_array.shape          # (Nb of epochs, all features extracted from each electrode: 9x19) 

(7201, 171)

## Linear Regression Classification

In [85]:
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold, GridSearchCV

**Logistic Regression**: Logistic regression is a supervised learning algorithm used for binary or multiclass classification. This code creates a machine learning pipeline to train a logistic regression classifier with hyperparameter tuning using grouped cross-validation.

*GroupKFold Cross-Validation*:
- GroupKFold ensures that samples from the same group (e.g., signals from the same subject) are not mixed between training and test sets. This prevents data leakage and ensures that the model generalizes better to unseen data, improving its predictive performance.
- This approach helps when the data has inherent group structure (such as multiple signals from the same individual or session).

*GridSearchCV with Hyperparameter Tuning*:
- GridSearchCV systematically evaluates different values of the hyperparameters (in this case, the regularization strength **C** in logistic regression) and identifies the best configuration for the model.
- The C parameter controls the tradeoff between fitting the model too closely to the training data (overfitting) and underfitting the data.

*Standardization*:
- The pipeline includes a StandardScaler, which standardizes the features (i.e., transforms them to have zero mean and unit variance). This is particularly helpful for models like logistic regression, which can be sensitive to the scale of features.

In [93]:
# The model itself
clf = LogisticRegression()

# Ensures that the samples within the same group remain in the same fold during training and testing
gkf = GroupKFold(5)

# Standardizes features and puts the classifier after the scaler so the data is standardized before training.
pipeline = Pipeline([('scalar', StandardScaler()), ('clf', clf)])

# Specifies the values to test for the logistic regression C parameter. Each value in this list corresponds to a different regularization strength.
param_grid = {'clf__C':np.linspace(0.05, 1, 30)}

# GridSearchCV performs a grid search over the param_grid to find the best C parameter for the logistic regression model, 
# using cross-validation with GroupKFold. The n_jobs parameter allows parallel processing, speeding up the computation by using X CPU cores
gscv = GridSearchCV(pipeline, param_grid, cv=gkf, n_jobs=12)

# trains the GridSearchCV object using the feature array (all_features_array), labels (label_array), and groups (group_array). 
# After running this, gscv will contain the best logistic regression model based on the cross-validated performance.
gscv.fit(all_features_array, label_array, groups=group_array)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [94]:
print("Best Score Prediction: {:.2f} %".format(gscv.best_score_*100))

print("Best parameter (C):", gscv.best_params_)

Best Score Prediction: 70.52 %
Best parameter (C): {'clf__C': 0.1482758620689655}
