In [1]:
from glob import glob
import mne
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from scipy.signal import welch
from scipy.stats import skew, kurtosis
from mne.io import read_raw_edf
from mne import make_fixed_length_epochs
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold, GridSearchCV
from sklearn.metrics import ConfusionMatrixDisplay
import pywt
from scipy.stats import entropy  
from scipy.signal import coherence 

In [2]:
# glob() yields matches in arbitrary, filesystem-dependent order. Sorting makes
# the subject order reproducible across machines; the per-subject group ids
# built later from list position depend on it.
all_file_path=sorted(glob('dataverse_files/*.edf'))
all_file_path[0]

'dataverse_files\\h01.edf'

In [3]:
    # Split recordings by the first letter of the file's base name:
    # 'h*' -> healthy controls, 's*' -> patients.
    # os.path.basename works with both '\\' and '/' separators, whereas splitting
    # on '\\' raises IndexError for paths that do not contain a backslash.
    healthy_file_path=[i for i in all_file_path if os.path.basename(i).startswith('h')]
    patient_file_path=[i for i in all_file_path if os.path.basename(i).startswith('s')]

In [4]:
def read_data(file_path, l_freq=1, h_freq=45, epoch_duration=25):
    """Load one EDF recording and return it as fixed-length epochs.

    The recording is read fully into memory, re-referenced with MNE's default
    EEG reference (the average of all channels), band-pass filtered, and cut
    into consecutive non-overlapping windows.

    Parameters
    ----------
    file_path : str
        Path of the EDF file to read.
    l_freq : float, optional
        Lower pass-band edge in Hz (default 1).
    h_freq : float, optional
        Upper pass-band edge in Hz (default 45).
    epoch_duration : float, optional
        Length of each epoch in seconds (default 25).

    Returns
    -------
    numpy.ndarray
        Epoch data laid out as (n_epochs, n_channels, n_samples_per_epoch).
    """
    raw = mne.io.read_raw_edf(file_path, preload=True)
    raw.set_eeg_reference()  # default reference: average of all channels
    raw.filter(l_freq=l_freq, h_freq=h_freq)
    # Break the continuous signal into smaller fixed-length segments (epochs).
    epochs = mne.make_fixed_length_epochs(raw, duration=epoch_duration, overlap=0)
    return epochs.get_data()

# Sanity check: load the first healthy-control recording and display the shape
# of the resulting epoch array.
data=read_data(healthy_file_path[0])
data.shape

Extracting EDF parameters from C:\Users\test\PR_EEG\dataverse_files\h01.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 231249  =      0.000 ...   924.996 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 825 samples (3.300 s)

Not setting metadata
37 matching events found
No baseline correction applied
0 projection items activated
Usi

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s


(37, 19, 6250)

In [5]:
%%capture
# Load and epoch every recording; each list holds one epoch array per subject.
control_epochs_array=list(map(read_data, healthy_file_path))
patients_epochs_array=list(map(read_data, patient_file_path))

In [6]:
print(control_epochs_array[0].shape)
# One label per epoch, kept grouped per subject: 0 for controls, 1 for patients.
control_epochs_labels=[[0]*len(subject_epochs) for subject_epochs in control_epochs_array]
patients_epochs_labels=[[1]*len(subject_epochs) for subject_epochs in patients_epochs_array]
print(len(control_epochs_labels),len(patients_epochs_labels))

(37, 19, 6250)
14 14


In [7]:
# Controls first, then patients; the label list follows the same subject order.
data_list=[*control_epochs_array, *patients_epochs_array]
label_list=[*control_epochs_labels, *patients_epochs_labels]
print(len(data_list),len(label_list))

28 28


In [8]:
# Tag every epoch with the index of the subject it came from, so that
# cross-validation can keep all epochs of one subject in the same fold.
groups_list=[
    [subject_idx]*len(subject_epochs)
    for subject_idx, subject_epochs in enumerate(data_list)
]

In [9]:
# Flatten the per-subject lists into single arrays with one row per epoch.
data_array=np.concatenate(data_list, axis=0)
label_array=np.concatenate(label_list)
group_array=np.concatenate(groups_list)
print(data_array.shape,label_array.shape,group_array.shape)

(1142, 19, 6250) (1142,) (1142,)


In [10]:
from hurst import compute_Hc

def mean(data):
    """Return the arithmetic mean of ``data`` taken over its last axis."""
    averaged = np.mean(data, axis=-1)
    return averaged

def std(data):
    """Return the standard deviation of ``data`` taken over its last axis."""
    spread = np.std(data, axis=-1)
    return spread

def ptp(data):
    """Return the peak-to-peak range (max minus min) over the last axis."""
    value_range = np.ptp(data, axis=-1)
    return value_range

def var(data):
    """Return the variance of ``data`` taken over its last axis."""
    variance = np.var(data, axis=-1)
    return variance

def minim(data):
    """Return the smallest value found along the last axis of ``data``."""
    lowest = np.min(data, axis=-1)
    return lowest

def maxim(data):
    """Return the largest value found along the last axis of ``data``."""
    highest = np.max(data, axis=-1)
    return highest

def argminim(data):
    """Return the index of the minimum along the last axis of ``data``."""
    position = np.argmin(data, axis=-1)
    return position

def argmaxim(data):
    """Return the index of the maximum along the last axis of ``data``."""
    position = np.argmax(data, axis=-1)
    return position

def mean_square(data):
    """Return the mean of the squared values over the last axis of ``data``."""
    squared = data**2
    return np.mean(squared, axis=-1)

def rms(data):
    """Return the root-mean-square of ``data`` over its last axis."""
    mean_of_squares = np.mean(data**2, axis=-1)
    return np.sqrt(mean_of_squares)

def abs_diffs_signal(data):
    """Return the summed absolute sample-to-sample change over the last axis."""
    steps = np.diff(data, axis=-1)
    step_sizes = np.abs(steps)
    return np.sum(step_sizes, axis=-1)

def skewness(data):
    """Return the sample skewness (scipy.stats.skew) over the last axis."""
    asymmetry = skew(data, axis=-1)
    return asymmetry

def kurtosis_custom(data):
    """Compute kurtosis (scipy.stats.kurtosis defaults) along the last axis.

    Works for input of any dimensionality: the result has the shape of
    ``data`` with the last axis removed (a scalar for 1-D input).

    The previous implementation looped in Python for 2-D/3-D input and, for
    anything with more than three dimensions, fell through to
    ``kurtosis(data)``, which reduces along axis 0 rather than the last axis.
    Passing ``axis=-1`` handles every case consistently in one call.
    """
    return kurtosis(data, axis=-1)


def concatenate_features(data):
    """Apply every basic time-domain statistic and join the results.

    Each extractor reduces the last axis of ``data``; the outputs are joined
    along the (new) last axis in the fixed order listed below.
    """
    extractors = (
        mean, std, ptp, var, minim,
        maxim, argminim, argmaxim,
        mean_square, rms, abs_diffs_signal,
        skewness, kurtosis_custom,
    )
    per_statistic = [extract(data) for extract in extractors]
    return np.concatenate(per_statistic, axis=-1)


# Frequency band power
def band_power(data, sf, band):
    """
    Calculate the power of a specific frequency band for multi-channel EEG data.

    The power is the sum of the Welch PSD bins whose frequency lies inside
    the band (both edges inclusive).

    Parameters:
    - data: np.ndarray, shape (n_channels, n_samples); any leading dimensions
      are accepted, the PSD is always estimated along the last axis
    - sf: Sampling frequency
    - band: tuple, (fmin, fmax)

    Returns:
    - np.ndarray, shape (n_channels,) Band power for each channel
      (in general: the shape of `data` without its last axis)
    """
    fmin, fmax = band
    # scipy.signal.welch returns (frequencies, PSD) -- in that order. The
    # earlier code unpacked them swapped, so it built the band mask from PSD
    # values and summed frequency values instead of power.
    freqs, psd = welch(data, sf, nperseg=256, axis=-1)
    idx_band = np.logical_and(freqs >= fmin, freqs <= fmax)
    # The frequency grid is shared by all channels, so one mask selects the
    # band for every channel at once.
    return np.sum(psd[..., idx_band], axis=-1)




def extract_more_features(data, sf):
    """Compute band power for five fixed frequency bands and join them.

    The bands are evaluated in the order delta (1-4), theta (4-8),
    alpha (8-13), beta (13-30) and gamma (30-45), and their per-channel
    results are concatenated along the last axis in that same order.
    """
    band_edges = ((1, 4), (4, 8), (8, 13), (13, 30), (30, 45))
    per_band = [band_power(data, sf, edges) for edges in band_edges]
    return np.concatenate(per_band, axis=-1)

# --- Advanced Feature Extraction ---
def wavelet_features(data):
    """Summarise a 4-level 'db4' wavelet decomposition of ``data``.

    The decomposition is taken along the last axis. For every coefficient
    array it returns, the mean and then the standard deviation over the last
    axis are collected, and all of them are concatenated along the last axis.
    """
    decomposition = pywt.wavedec(data, wavelet='db4', level=4, axis=-1)
    summaries = [
        reduce_last_axis(coefficients, axis=-1)
        for coefficients in decomposition
        for reduce_last_axis in (np.mean, np.std)
    ]
    return np.concatenate(summaries, axis=-1)

def hurst_exponent(data):
    """Run ``compute_Hc`` on every row of ``data`` and collect the results.

    Rows are taken along the first axis (one signal/channel per row). Only
    the first element of each ``compute_Hc`` result is kept; the values are
    returned as a 1-D array with one entry per row.
    """
    n_signals = data.shape[0]
    exponents = []
    for row in range(n_signals):
        estimate = compute_Hc(data[row])
        exponents.append(estimate[0])
    return np.array(exponents)


def spectral_entropy(data, sf):
    """Shannon entropy of the normalised Welch power spectrum, per signal.

    Parameters:
    - data: array-like; the PSD is estimated along the last axis
      (e.g. shape (n_channels, n_samples))
    - sf: Sampling frequency

    Returns:
    - np.ndarray with the shape of `data` minus its last axis: one entropy
      value per signal. Empty input and signals with zero total power yield 0.
    """
    data = np.asarray(data)
    if data.size == 0:
        print("Spectral entropy: Empty PSD!")
        # One value per signal, so the result matches the non-empty case.
        return np.zeros(data.shape[:-1])

    # scipy.signal.welch returns (frequencies, PSD) -- in that order. The
    # earlier code unpacked them swapped and therefore computed the entropy of
    # the frequency grid: a single constant that did not depend on the data.
    _, psd = welch(data, sf, nperseg=256, axis=-1)
    total_power = np.sum(psd, axis=-1, keepdims=True)

    # A signal with zero total power would give 0/0 here; compute quietly and
    # replace those entries below instead of propagating NaN.
    with np.errstate(divide='ignore', invalid='ignore'):
        psd_norm = psd / total_power  # Normalize to a probability distribution
        spectral_ent = entropy(psd_norm, axis=-1)
    return np.where(total_power[..., 0] > 0, spectral_ent, 0.0)


def connectivity_features(data, sf):
    """Mean coherence for every unordered pair of rows in ``data``.

    Rows are taken along the first axis. Pairs are visited as (0, 1), (0, 2),
    ..., (1, 2), ... and for each pair the coherence spectrum is averaged
    over frequency. Returns a 1-D array with one value per pair (empty when
    there are fewer than two rows).
    """
    n_channels = data.shape[0]
    pair_means = [
        np.mean(coherence(data[first], data[second], sf, nperseg=256)[1])
        for first in range(n_channels)
        for second in range(first + 1, n_channels)
    ]
    return np.array(pair_means)


def concatenate_features_with_new(data, sf):
    """Build the full feature vector for one block of signals.

    Joins, in this order: the basic statistics, the band powers, the wavelet
    summaries, the Hurst values, the spectral entropy and the pairwise
    coherence features, all concatenated along the last axis.
    """
    feature_groups = [
        concatenate_features(data),
        extract_more_features(data, sf),
    ]

    advanced_groups = (
        wavelet_features(data),
        hurst_exponent(data),
        spectral_entropy(data, sf),
        connectivity_features(data, sf),
    )
    for group in advanced_groups:
        # A 0-d (scalar) result cannot be concatenated; give it length 1.
        if group.ndim == 0:
            group = np.expand_dims(group, axis=-1)
        feature_groups.append(group)

    # Now concatenate them
    return np.concatenate(feature_groups, axis=-1)




# Sampling frequency handed to the spectral feature extractors
sf = 250

# Extract features: one feature vector per epoch, stacked into a 2-D array
expanded_features = []
for data in tqdm(data_array):
    expanded_features += [concatenate_features_with_new(data, sf)]
expanded_features = np.array(expanded_features)


  0%|          | 0/1142 [00:00<?, ?it/s]

In [11]:

import numpy as np
from tqdm import tqdm
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GroupKFold, GridSearchCV
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import cross_val_score

In [13]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
# NOTE(review): the scaler and PCA here are fitted on ALL epochs at once. That
# is fine for inspecting the projected data, but if `features_pca` is fed into
# cross-validation, the held-out folds have already influenced the transform
# (data leakage). For model selection, fit these steps inside the CV pipeline.
scaler = StandardScaler()
features_scaled = scaler.fit_transform(expanded_features)

# Apply PCA
n_components = 80  # Adjust based on desired explained variance
# random_state pins the result when scikit-learn picks a randomized SVD solver,
# so re-running the notebook reproduces the same components.
pca = PCA(n_components=n_components, random_state=42)
features_pca = pca.fit_transform(features_scaled)
print(f"Shape after PCA: {features_pca.shape}")

Shape after PCA: (1142, 80)


In [None]:
# Define the parameter grid
param_grid = {
    'random_forest__n_estimators': [50, 100, 200, 300, 500],
    'random_forest__max_depth': [None, 10, 20, 30, 40],
    'random_forest__min_samples_split': [2, 5, 10],
    'random_forest__min_samples_leaf': [1, 2, 5, 10],
    'random_forest__max_features': ['sqrt', 'log2', None]
}

# Create the pipeline.
# Scaling and PCA are pipeline steps so that, in every CV split, they are
# fitted on the training folds only. Fitting them once on the whole data set
# beforehand lets the validation epochs shape the transform and inflates the
# cross-validated score.
pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('pca', PCA(n_components=n_components, random_state=42)),
    ('random_forest', RandomForestClassifier(random_state=42))
])

# Initialize GroupKFold: all epochs of one subject stay in the same fold
gkf = GroupKFold(n_splits=5)

# Grid search with cross-validation
gscv = GridSearchCV(pipe, param_grid, cv=gkf, scoring='accuracy', n_jobs=-1, verbose=1)

# Fit the grid search on the raw (unscaled, unprojected) feature matrix; the
# pipeline applies scaling and PCA per fold.
gscv.fit(expanded_features, label_array, groups=group_array)

# Best parameters and accuracy
print("Best parameters:", gscv.best_params_)
print("Best cross-validated accuracy:", gscv.best_score_)

Fitting 5 folds for each of 900 candidates, totalling 4500 fits
