<a href="https://colab.research.google.com/github/sonydata/EEG_Epilepsy_Classification/blob/main/EEG_Features_extraction_for_supervised_ML.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Feature extraction methodology



### 1. Frequency-Domain Features
**Power Spectral Density (PSD)**

PSD is commonly estimated with Welch's method, which averages periodograms computed over overlapping segments of the signal. The spectrum is then summarized in the standard EEG frequency bands. For epilepsy detection, shifts in spectral power across these bands can help identify abnormal activity.

* Delta: 0.5–4 Hz

* Theta: 4–8 Hz

* Alpha: 8–12 Hz

* Beta: 12–30 Hz

* Gamma: 30–45 (or 50) Hz

**Band Power**

Relative band power divides each band’s power by the total power across all frequencies (`relative=True` in the code below).

This normalization gives more robust measures that are less sensitive to overall amplitude scaling between recordings.
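
To make the normalization concrete, here is a minimal sketch that assumes SciPy's `welch` and a synthetic 10 Hz test signal. The signal, the 256 Hz sampling rate, the 2-second segment length, and the helper name `relative_band_powers` are illustrative assumptions rather than part of the pipeline. It sums PSD bins per band instead of integrating, which yields the same ratio on a uniform frequency grid.

```python
import numpy as np
from scipy.signal import welch

def relative_band_powers(signal, sfreq, bands):
    """Return each band's share of the total Welch PSD power."""
    # 2-second segments (an assumption) give 0.5 Hz frequency resolution
    nperseg = min(len(signal), int(2 * sfreq))
    freqs, psd = welch(signal, fs=sfreq, nperseg=nperseg)
    total = psd.sum()
    shares = {}
    for name, (low, high) in bands.items():
        # Half-open interval so a boundary bin is counted in one band only
        mask = (freqs >= low) & (freqs < high)
        shares[name] = psd[mask].sum() / total
    return shares

bands = {"delta": (0.5, 4), "theta": (4, 8), "alpha": (8, 12),
         "beta": (12, 30), "gamma": (30, 45)}

# Synthetic signal: a 10 Hz sine (alpha range) plus weak white noise
sfreq = 256
t = np.arange(0, 8, 1 / sfreq)
rng = np.random.default_rng(0)
demo = np.sin(2 * np.pi * 10 * t) + 0.1 * rng.standard_normal(t.size)

rel = relative_band_powers(demo, sfreq, bands)
rel_scaled = relative_band_powers(10 * demo, sfreq, bands)  # same signal, 10x amplitude
print(rel)
print(rel_scaled)
```

Almost all of the power lands in the alpha band, and scaling the signal by a constant leaves the relative powers unchanged, which is the robustness property described above.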

### 2. Time-Domain Features
Common time-domain features:

* Mean
* Variance
* Skewness
* Kurtosis
* Zero-Crossing Rate
* Teager-Kaiser Energy Operator (TKEO)
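
The two less familiar features in this list can be sketched in a few lines of NumPy. This is an illustrative version on synthetic sine waves (the 5 Hz and 40 Hz test signals and the function names are assumptions for the example). The zero-crossing rate counts sign changes between consecutive samples, and the TKEO is defined as `psi[n] = x[n]**2 - x[n-1] * x[n+1]`.

```python
import numpy as np

def zero_crossing_rate(x):
    """Number of sign changes between consecutive samples, divided by signal length."""
    return np.count_nonzero(np.diff(np.sign(x))) / len(x)

def mean_tkeo(x):
    """Mean Teager-Kaiser energy: psi[n] = x[n]^2 - x[n-1] * x[n+1]."""
    x = np.asarray(x, dtype=float)
    return float(np.mean(x[1:-1] ** 2 - x[:-2] * x[2:]))

# Synthetic test signals (assumed for illustration): same amplitude, different frequency
sfreq = 256
t = np.arange(0, 1, 1 / sfreq)
slow = np.sin(2 * np.pi * 5 * t)
fast = np.sin(2 * np.pi * 40 * t)

print(zero_crossing_rate(slow), zero_crossing_rate(fast))
print(mean_tkeo(slow), mean_tkeo(fast))
```

For a pure sinusoid of amplitude `A` and frequency `f`, the TKEO is constant and equal to `A**2 * sin(2*pi*f/sfreq)**2`, so it grows with both amplitude and frequency. That is why it responds to the fast, high-amplitude activity typical of seizure onset.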

### 3. Entropy and Complexity Measures
Entropy metrics are popular in seizure detection because they capture signal complexity. Two widely used measures:

* Sample Entropy (SampEn): Measures complexity based on the regularity of a time series; more regular (predictable) signals give lower values.

* Permutation Entropy (PermEn): Measures the diversity of ordinal patterns, i.e. the relative ordering of neighboring signal values.
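
To illustrate what "order patterns" means, here is a from-scratch NumPy sketch of permutation entropy. It is written for explanation only and is not the `antropy` implementation used in the code section; the function name and test signals are assumptions. Each run of `order` samples is reduced to the permutation that sorts it, and the entropy of the resulting pattern distribution is computed.

```python
import math
import numpy as np

def permutation_entropy(x, order=3, delay=1, normalize=True):
    """Shannon entropy (bits) of the ordinal-pattern distribution of x."""
    x = np.asarray(x, dtype=float)
    n_vectors = len(x) - (order - 1) * delay
    # Embed: each row holds `order` samples spaced `delay` apart
    embedded = np.array(
        [x[i * delay : i * delay + n_vectors] for i in range(order)]
    ).T
    # Ordinal pattern of a row = the permutation of indices that sorts it
    patterns = np.argsort(embedded, axis=1, kind="stable")
    _, counts = np.unique(patterns, axis=0, return_counts=True)
    p = counts / counts.sum()
    h = -np.sum(p * np.log2(p))
    if normalize:
        # Maximum entropy is log2(order!), reached when all patterns are equally likely
        h /= math.log2(math.factorial(order))
    return float(h)

ramp = np.arange(100.0)                                   # perfectly regular
noise = np.random.default_rng(0).standard_normal(5000)    # irregular

print(permutation_entropy(ramp))
print(permutation_entropy(noise))
```

A monotonic ramp produces a single pattern and therefore zero entropy, while white noise uses all `order!` patterns about equally and approaches the normalized maximum of 1.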

### 4. Wavelet Transform (Discrete Wavelet Transform – DWT)
* Captures transient, non-stationary seizure events well.

* Commonly used in many epilepsy-detection studies; wavelet coefficients (and wavelet-based entropy) often provide strong discriminative power.

**Implementation Details**

* Perform multi-level DWT (e.g., 4–6 levels), then extract features like energy, variance, or entropy from each level.

* Keep the number of features under control to limit overfitting (e.g., by keeping only the sub-bands that overlap the bands of interest: delta, theta, alpha, beta, gamma).
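
When choosing the number of levels, it helps to know which frequencies each sub-band nominally covers. Each DWT level halves the remaining bandwidth, so the ranges follow directly from the sampling rate. The sketch below is plain arithmetic; the 256 Hz sampling rate and the helper name `dwt_level_bands` are assumptions for illustration, and real wavelet filters overlap somewhat at the band edges.

```python
def dwt_level_bands(sfreq, level):
    """Nominal frequency range in Hz of each DWT sub-band."""
    nyquist = sfreq / 2
    bands = {}
    for j in range(1, level + 1):
        # Detail level j nominally spans [nyquist / 2**j, nyquist / 2**(j - 1)]
        bands[f"D{j}"] = (nyquist / 2 ** j, nyquist / 2 ** (j - 1))
    # The final approximation holds everything below the last detail band
    bands[f"A{level}"] = (0.0, nyquist / 2 ** level)
    return bands

# Assumed example: 256 Hz recording, 4-level decomposition
for name, (low, high) in dwt_level_bands(sfreq=256, level=4).items():
    print(f"{name}: {low:g}-{high:g} Hz")
# D1: 64-128 Hz
# D2: 32-64 Hz
# D3: 16-32 Hz
# D4: 8-16 Hz
# A4: 0-8 Hz
```

In this example, four levels leave delta and theta together in `A4`; at this sampling rate, separating them would take one more level.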

## Code

Install and import libraries

In [None]:
! pip install mne
! pip install antropy
! pip install PyWavelets
import os
import numpy as np
import pandas as pd
import mne
import pywt
from scipy.signal import welch
from scipy.stats import skew, kurtosis
from antropy import sample_entropy, perm_entropy

Create an `EEGFeatureExtractor` class and a `process_dataset` helper for:
* Reading EEG recordings and splitting them into overlapping windows
* Extracting the features described in the methodology
* Linking extracted features with patient metadata for subsequent ML analysis
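
The extractor below slides fixed-length, overlapping windows over each channel. The window arithmetic can be sketched on its own, assuming a hypothetical 10-second recording sampled at 250 Hz; the helper name `window_bounds` and the `max(1, ...)` guard against a zero step are additions made for this sketch.

```python
def window_bounds(n_samples, sfreq, window_size=4, overlap=0.5):
    """Return (start, end) sample indices of each full window."""
    window_samples = int(window_size * sfreq)
    # Guard (added for this sketch): a zero step would never advance
    step = max(1, int(window_samples * (1 - overlap)))
    bounds = []
    start = 0
    while start + window_samples <= n_samples:
        bounds.append((start, start + window_samples))
        start += step
    return bounds

# Assumed example: 10 s at 250 Hz, 4 s windows, 50% overlap
sfreq = 250
bounds = window_bounds(n_samples=10 * sfreq, sfreq=sfreq)
for start, end in bounds:
    print(f"{start / sfreq:g}-{end / sfreq:g} s")
# 0-4 s
# 2-6 s
# 4-8 s
# 6-10 s
```

Trailing samples that do not fill a complete window are dropped, so a recording shorter than one window yields no rows.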

In [None]:
# Standard EEG frequency bands relevant for epilepsy detection
# Each band correlates with specific neurological states
BANDS = {
    'delta': (0.5, 4),    # Deep sleep, but also pathological in epilepsy
    'theta': (4, 8),      # Drowsiness, increased in some epileptic states
    'alpha': (8, 12),     # Relaxed wakefulness, often suppressed during seizures
    'beta': (12, 30),     # Active thinking, shows patterns in epilepsy
    'gamma': (30, 45)     # High frequency activity indicative of seizures
}

class EEGFeatureExtractor:
    """Extracts features from EEG signals for epilepsy detection."""

    def __init__(self, window_size=4, overlap=0.5, wavelet='db4', dwt_level=4):
        """Initialize feature extractor with window parameters and wavelet settings.

        Defaults to 4-second windows with 50% overlap, a common choice in EEG analysis.
        window_size is in seconds; overlap is a fraction in [0, 1).
        Daubechies 4 ('db4') wavelet chosen for its similarity to EEG waveforms.
        """
        self.window_size = window_size
        self.overlap = overlap
        self.wavelet = wavelet
        self.dwt_level = dwt_level

    def extract_from_file(self, edf_file_path, selected_channels=None):
        """Extract features from an EDF file, optionally filtering to specific channels."""
        # Load EDF file using MNE which is specialized for EEG data
        raw = mne.io.read_raw_edf(edf_file_path, preload=True)
        file_id = os.path.basename(edf_file_path).split('.')[0]

        # Get data and metadata (MNE returns EEG channels in volts)
        data = raw.get_data()
        sfreq = raw.info['sfreq']
        ch_names = raw.info['ch_names']

        # Filter channels if specified
        # Allows focusing on channels known to be relevant for epilepsy
        if selected_channels:
            ch_indices = [ch_names.index(ch) for ch in selected_channels if ch in ch_names]
            data = data[ch_indices]
            ch_names = [ch_names[i] for i in ch_indices]

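        # Set window parameters - window_size is in seconds, so the window
        # duration stays the same regardless of sampling frequency
        window_samples = int(self.window_size * sfreq)
        # Guard against a zero step (overlap close to 1), which would loop forever
        step = max(1, int(window_samples * (1 - self.overlap)))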

        # Extract features from windows
        feature_rows = []

        # Process each channel separately
        for ch_idx, ch_name in enumerate(ch_names):
            ch_signal = data[ch_idx]
            start = 0
            window_count = 0

            # Slide window through signal with overlap
            while start + window_samples <= len(ch_signal):
                window = ch_signal[start:start + window_samples]

                # Create feature dictionary with metadata
                features = {
                    'file_id': file_id,
                    'channel': ch_name,
                    'window_idx': window_count,
                    'start_time': start / sfreq,
                    'end_time': (start + window_samples) / sfreq
                }

                # Extract feature sets that capture different signal aspects
                # Each type of feature reflects different neurophysiological information
                features.update(self._compute_time_domain(window))         # Statistical properties
                features.update(self._compute_frequency_domain(window, sfreq))  # Oscillatory components
                features.update(self._compute_entropy(window))             # Signal complexity
                features.update(self._compute_wavelet(window))             # Time-frequency characteristics

                feature_rows.append(features)
                start += step
                window_count += 1

        return pd.DataFrame(feature_rows)

    def _compute_time_domain(self, signal):
        """Calculate time-domain statistical features that characterize signal properties."""
        features = {
            'mean': np.mean(signal),                # Central tendency
            'variance': np.var(signal),             # Signal variability
            'skewness': skew(signal),               # Asymmetry - altered during seizures
            'kurtosis': kurtosis(signal, fisher=False),  # Tail heaviness (Pearson: normal = 3) - reflects sharp activities
            'zcr': np.count_nonzero(np.diff(np.sign(signal))) / len(signal)  # Rate of sign changes
        }

        # Teager-Kaiser Energy Operator - highlights rapid amplitude/frequency changes
        # Effective at detecting abrupt changes characteristic of seizure onset
        if len(signal) > 2:
            tkeo = np.mean(signal[1:-1]**2 - signal[:-2] * signal[2:])
            features['tkeo'] = tkeo
        else:
            features['tkeo'] = 0

        return features

    def _compute_frequency_domain(self, signal, sf):
        """Calculate relative power in standard frequency bands.
        Using relative power normalizes across recordings with different amplitudes.
        """
        features = {}

        for band_name, (low, high) in BANDS.items():
            bp = self._bandpower(signal, sf, (low, high), relative=True)
            features[f'{band_name}_power'] = bp

        return features

    def _compute_entropy(self, signal):
        """Calculate entropy-based complexity measures.
        Entropy changes often correlate with pathological brain states.
        """
        # Sample entropy - measures time series regularity, often lower during seizures
        # antropy has no `r` argument; its default tolerance is 0.2 * std(signal)
        samp_en = sample_entropy(signal, order=2)

        # Permutation entropy - sensitive to dynamic changes during seizures
        perm_en = perm_entropy(signal, order=3, normalize=True)

        return {
            'sample_entropy': samp_en,
            'perm_entropy': perm_en
        }

    def _compute_wavelet(self, signal):
        """Calculate wavelet-based time-frequency features.
        Effective at capturing transient events like seizures across frequency scales.
        """
        features = {}

        # Decompose signal using Discrete Wavelet Transform
        coeffs = pywt.wavedec(signal, wavelet=self.wavelet, level=self.dwt_level)

        for i, coef in enumerate(coeffs):
            # Name bands according to standard wavelet notation
            band_name = f"A{self.dwt_level}" if i == 0 else f"D{self.dwt_level - i + 1}"

            # Calculate coefficient features
            energy = np.sum(coef**2)              # Power in this frequency band
            variance = np.var(coef)               # Variability in this band

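            # Shannon entropy measures information content in this frequency band
            p = coef**2
            total = np.sum(p)
            # Normalize to a distribution; a fixed epsilon in the denominator would
            # bias volt-scale data, so only the all-zero case is guarded
            p = p / total if total > 0 else np.zeros_like(p)
            shannon_entropy = -np.sum(p * np.log2(p + 1e-12))  # epsilon avoids log(0)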

            features[f'{band_name}_energy'] = energy
            features[f'{band_name}_variance'] = variance
            features[f'{band_name}_entropy'] = shannon_entropy

        return features

    def _bandpower(self, data, sf, band, window_sec=None, relative=False):
        """Compute power in a specific frequency band using Welch's method.
        Integrates the power spectral density within the frequency band.
        """
        low, high = band

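        # Set segment length - affects frequency vs. time resolution tradeoff.
        # By default the segment spans the whole input, so Welch reduces to a
        # single Hann-windowed periodogram (resolution = 1 / input duration).
        if window_sec is None:
            window_sec = len(data) / sf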

        nperseg = min(int(window_sec * sf), len(data))
        freqs, psd = welch(data, sf, nperseg=nperseg)

        # Extract band of interest
        idx_band = np.logical_and(freqs >= low, freqs <= high)
        freq_res = freqs[1] - freqs[0]

        # Integrate power using trapezoidal rule
        bp = np.trapz(psd[idx_band], dx=freq_res)

        # Calculate relative power if requested - helps normalize across recordings
        if relative:
            total_power = np.trapz(psd, dx=freq_res)
            bp /= total_power if total_power > 0 else 1

        return bp

def process_dataset(metadata_df, edf_dir, output_file, channels=None):
    """Process multiple EEG recordings and extract features based on metadata.
    Links extracted features with patient metadata for subsequent ML analysis.
    """
    extractor = EEGFeatureExtractor()
    all_features = []

    for idx, row in metadata_df.iterrows():
        file_name = row['file_path']
        file_path = os.path.join(edf_dir, file_name)

        if not os.path.exists(file_path):
            print(f"File not found: {file_path}")
            continue

        features_df = extractor.extract_from_file(file_path, channels)

        if not features_df.empty:
            # Add metadata columns - important for tracking patient info, diagnosis, etc.
            for col in metadata_df.columns:
                if col != 'file_path':
                    features_df[col] = row[col]

            all_features.append(features_df)
            print(f"Processed {file_name} - extracted {len(features_df)} windows")

    if all_features:
        final_df = pd.concat(all_features, ignore_index=True)
        final_df.to_csv(output_file, index=False)
        print(f"Features saved to {output_file}")
        return final_df
    else:
        print("No features extracted")
        return pd.DataFrame()


Test


In [None]:
# Example usage
if __name__ == "__main__":
    # Simple example metadata with binary classification (epileptic vs. non-epileptic)
    # In practice, load this table from a CSV file with pd.read_csv
    metadata = pd.DataFrame({
        'file_path': ['sample_data/aaaaaanr_s001_t001.edf'],
        'patient_id': ['aaaaaanr'],
        'label': [0]  # 0 for non-epileptic, 1 for epileptic
    })

    # Process dataset with selected channels relevant to epilepsy detection
    features = process_dataset(
        metadata_df=metadata,
        edf_dir='.',
        output_file='eeg_features.csv',
        # Placeholder: replace with channel names that exist in the EDF file,
        # or pass None to use all channels. Names not found in the file are
        # skipped, so this placeholder as written yields no features.
        channels=['add channels here']
    )

    if not features.empty:
        print(f"\nExtracted {len(features)} feature windows")
        print(f"Features per window: {len(features.columns)}")
        print("\nFeature columns:")
        print(features.columns.tolist())