# Sleep Spindle Study

## Feature extraction

The VDM applied in the preprocessing phase can also be used to extract features from the raw EEG.


In [None]:
import numpy as np
from scipy.signal import welch
import pywt


### Power Spectral Density (PSD) Features

The `extract_psd_features` function computes the average power of the EEG signal in a specified frequency band using Welch's method. Power Spectral Density represents the power distribution over different frequencies and is commonly used in EEG analysis to study the energy of brain waves within different frequency ranges.

- `epochs`: The MNE `Epochs` object containing the EEG data.
- `fmin`: The lower edge of the frequency band, in Hz.
- `fmax`: The upper edge of the frequency band, in Hz.

For each epoch, the PSD is computed, restricted to the frequency bins between `fmin` and `fmax`, and averaged to give a single feature value. Because it summarizes the activity of one rhythm per epoch, this feature is useful for comparing brain states such as rest, cognitive tasks, or sleep stages.


In [None]:
def extract_psd_features(epochs, fmin, fmax):
    # Collect one band-power value per epoch
    features = []
    # Retrieve the sampling frequency from the epochs info structure
    sfreq = epochs.info['sfreq']
    # get_data() returns an array of shape (n_epochs, n_channels, n_times)
    for epoch in epochs.get_data():
        # Welch PSD with 2 s segments, capped at the epoch length to avoid a scipy warning
        nperseg = min(int(sfreq * 2), epoch.shape[-1])
        f, psd = welch(epoch, sfreq, nperseg=nperseg)  # psd: (n_channels, n_freqs)
        # Keep only the frequency bins inside the band of interest
        idx_band = np.logical_and(f >= fmin, f <= fmax)
        # Average over the band (and over channels, if there are several)
        features.append(psd[..., idx_band].mean())
    return np.array(features)
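To sanity-check the band selection without building an MNE object, the sketch below reproduces the same Welch band-power computation on a plain NumPy array. `band_power` is a hypothetical helper written for this check; it indexes the frequency axis with `psd[..., mask]`, so it also works when each epoch has several channels. On a synthetic 13 Hz sine in noise, the 11-15 Hz band should carry far more power than a 4-8 Hz comparison band.

```python
import numpy as np
from scipy.signal import welch

def band_power(data, sfreq, fmin, fmax):
    """Mean Welch PSD within [fmin, fmax] for each epoch.

    `data` has shape (n_epochs, n_channels, n_times), like Epochs.get_data().
    Averaging runs over channels and frequency bins, giving one value per epoch.
    """
    nperseg = min(int(2 * sfreq), data.shape[-1])  # 2 s windows, capped at epoch length
    f, psd = welch(data, fs=sfreq, nperseg=nperseg, axis=-1)  # psd: (n_epochs, n_channels, n_freqs)
    mask = (f >= fmin) & (f <= fmax)
    return psd[..., mask].mean(axis=(-2, -1))

# Synthetic check: 4 epochs x 2 channels x 10 s at 256 Hz, 13 Hz sine plus white noise
rng = np.random.default_rng(0)
sfreq = 256.0
t = np.arange(0, 10, 1 / sfreq)
data = np.sin(2 * np.pi * 13 * t) + 0.5 * rng.standard_normal((4, 2, t.size))

sigma_power = band_power(data, sfreq, 11, 15)  # spindle band
theta_power = band_power(data, sfreq, 4, 8)    # comparison band
print(sigma_power.shape)  # (4,)
```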



### Band Power Ratio Feature

The `extract_bp_feature` function calculates the ratio of mean signal power between two frequency bands. Because it is a relative measure, it is insensitive to overall amplitude differences (for example between subjects or electrodes) and instead highlights shifts of activity from one band to the other, which can help differentiate physiological or cognitive states.

- `epochs`: The MNE `Epochs` object containing the EEG data.
- `band1`: A `(low, high)` tuple in Hz giving the numerator band.
- `band2`: A `(low, high)` tuple in Hz giving the denominator band.

For each epoch, the mean PSD within each band is computed, and the feature is the `band1` power divided by the `band2` power.


In [None]:
def extract_bp_feature(epochs, band1, band2):
    # One power ratio per epoch
    power_ratios = []
    sfreq = epochs.info['sfreq']

    for epoch in epochs.get_data():
        # Welch PSD with 2 s segments, capped at the epoch length
        nperseg = min(int(sfreq * 2), epoch.shape[-1])
        f, psd = welch(epoch, sfreq, nperseg=nperseg)  # psd: (n_channels, n_freqs)
        # Mean power in each band (over frequency bins and channels)
        band1_power = psd[..., (f >= band1[0]) & (f <= band1[1])].mean()
        band2_power = psd[..., (f >= band2[0]) & (f <= band2[1])].mean()
        power_ratios.append(band1_power / band2_power)
    return np.array(power_ratios)
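A standalone sketch of the ratio on synthetic data, assuming the slow (11-13 Hz) versus fast (13-15 Hz) split used later in this notebook. `band_ratio` is a hypothetical helper; unlike `extract_bp_feature`, it uses half-open bands `[low, high)` so that the shared 13 Hz bin is not counted in both bands. A ratio above 1 means the first (slow) band dominates.

```python
import numpy as np
from scipy.signal import welch

def band_ratio(data, sfreq, band1, band2):
    """Mean Welch PSD in band1 divided by mean Welch PSD in band2, one value per epoch."""
    nperseg = min(int(2 * sfreq), data.shape[-1])
    f, psd = welch(data, fs=sfreq, nperseg=nperseg, axis=-1)
    # Half-open bands [low, high) so a shared edge such as 13 Hz is counted only once
    p1 = psd[..., (f >= band1[0]) & (f < band1[1])].mean(axis=(-2, -1))
    p2 = psd[..., (f >= band2[0]) & (f < band2[1])].mean(axis=(-2, -1))
    return p1 / p2

rng = np.random.default_rng(1)
sfreq = 256.0
t = np.arange(0, 10, 1 / sfreq)
slow = np.sin(2 * np.pi * 12 * t)  # slow-spindle-like oscillation
fast = np.sin(2 * np.pi * 14 * t)  # fast-spindle-like oscillation
noise = 0.1 * rng.standard_normal((2, 1, t.size))
data = np.stack([slow, fast])[:, np.newaxis, :] + noise  # (2 epochs, 1 channel, n_times)

ratios = band_ratio(data, sfreq, (11, 13), (13, 15))
# Epoch 0 (12 Hz) gives a ratio well above 1, epoch 1 (14 Hz) well below 1
```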



### Continuous Wavelet Transform (CWT) Features

The Continuous Wavelet Transform is a time-frequency analysis tool that allows examining the EEG signal at different frequencies and times. Unlike the Fourier Transform, the wavelet transform maintains temporal information, making it especially useful for non-stationary signals like EEG.

- `epochs`: The MNE `Epochs` object containing the EEG data.
- `fmin`, `fmax`: The lower and upper edges of the frequency band of interest, in Hz.
- `wavelet`: The name of the mother wavelet passed to PyWavelets (default `'morl'`, the real Morlet wavelet).

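The `extract_cwt_features` function computes the CWT of the first channel of each epoch at a set of scales. The `calculate_scales` function chooses those scales so that they cover the desired frequency range, using the relation `f = fc * fs / scale` between a scale, the sampling rate `fs`, and the wavelet's center frequency `fc`.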
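The scale-to-frequency relation is easy to get wrong because `fc` depends on the wavelet. The sketch below works through it in plain NumPy, assuming `fc = 0.8125` for PyWavelets' `'morl'` (the value `pywt.central_frequency('morl')` is expected to return; it is hard-coded here so the example runs without `pywt`). `scales_for_band` and `scale_to_hz` are hypothetical helpers. The second part shows that if the constant 6 (a value commonly used as the Morlet ω0 parameter in other conventions) is used as `fc` with `'morl'`, the resulting scales target roughly 1.5-2 Hz instead of 11-15 Hz.

```python
import numpy as np

# Assumption: center frequency of PyWavelets' real Morlet ('morl'),
# as returned by pywt.central_frequency('morl')
MORL_CENTER_FREQ = 0.8125

def scales_for_band(fmin, fmax, fs, center_freq=MORL_CENTER_FREQ, step=1.0):
    """Scales whose pseudo-frequencies are fmin, fmin + step, ..., fmax (Hz)."""
    freqs = np.arange(fmin, fmax + step / 2, step)
    return center_freq * fs / freqs

def scale_to_hz(scales, fs, center_freq=MORL_CENTER_FREQ):
    """Pseudo-frequency in Hz of each scale: f = fc * fs / scale."""
    return center_freq * fs / np.asarray(scales)

fs = 256.0
scales = scales_for_band(11, 15, fs)
print(np.round(scale_to_hz(scales, fs), 2))  # [11. 12. 13. 14. 15.]

# Scales built with 6 in place of the 'morl' center frequency
wrong_scales = 6 * fs / np.array([11.0, 15.0])
print(np.round(scale_to_hz(wrong_scales, fs), 2))  # [1.49 2.03]
```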


In [None]:
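def calculate_scales(freq_min, freq_max, fs, wavelet='morl', freq_step=1.0):
    # PyWavelets maps scale s to the pseudo-frequency f = fc * fs / s (Hz),
    # where fc is the wavelet's center frequency (about 0.8125 for 'morl')
    center_frequency = pywt.central_frequency(wavelet)
    # Target frequencies from freq_min to freq_max inclusive, freq_step apart
    freqs = np.arange(freq_min, freq_max + freq_step / 2, freq_step)
    scales = center_frequency * fs / freqs
    return scales

def extract_cwt_features(epochs, fmin, fmax, wavelet='morl'):
    sfreq = epochs.info['sfreq']
    scales = calculate_scales(fmin, fmax, sfreq, wavelet=wavelet)
    eeg_data = epochs.get_data()
    num_epochs, _, signal_length = eeg_data.shape
    num_features = len(scales)
    features = np.zeros((num_epochs, num_features, signal_length))

    for epoch_idx in range(num_epochs):
        # Only the first channel is transformed (single-channel recordings assumed)
        signal = eeg_data[epoch_idx, 0, :]
        # coefficients: (n_scales, n_times); real-valued for 'morl'
        coefficients, _ = pywt.cwt(signal, scales, wavelet, sampling_period=1 / sfreq)
        features[epoch_idx, :, :] = coefficients

    return features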
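The claim that wavelets keep temporal information can be made concrete with a single-scale example. The sketch below swaps in a complex Morlet wavelet written directly in NumPy (the helper `morlet_power` and the choice `n_cycles=7` are illustrative assumptions, not PyWavelets' `'morl'`) and evaluates it at one frequency, which corresponds to one row of a CWT. On a synthetic signal with a 13 Hz burst between 2 s and 3 s, the power is zero far from the burst and peaks inside it.

```python
import numpy as np

def morlet_power(signal, sfreq, freq, n_cycles=7.0):
    """Power of a complex Morlet wavelet transform at a single frequency (Hz)."""
    sigma_t = n_cycles / (2 * np.pi * freq)  # Gaussian width in seconds
    t = np.arange(-4 * sigma_t, 4 * sigma_t, 1 / sfreq)
    wavelet = np.exp(2j * np.pi * freq * t) * np.exp(-t**2 / (2 * sigma_t**2))
    wavelet /= np.sqrt(np.sum(np.abs(wavelet) ** 2))  # unit energy
    coeffs = np.convolve(signal, wavelet, mode='same')
    return np.abs(coeffs) ** 2

sfreq = 256.0
t = np.arange(0, 5, 1 / sfreq)
# Synthetic "spindle": a 13 Hz burst between 2 s and 3 s, silence elsewhere
signal = np.where((t >= 2) & (t < 3), np.sin(2 * np.pi * 13 * t), 0.0)

power = morlet_power(signal, sfreq, 13.0)
peak_time = t[np.argmax(power)]  # falls inside the 2-3 s burst
```

PyWavelets' real `'morl'` coefficients oscillate at the signal frequency rather than forming a smooth envelope, so taking the magnitude of a complex wavelet transform, as here, is one way to obtain a power trace; whether that suits the model is a design choice.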


### Combining Features

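The function `combined_raw_features` merges the raw EEG channels with per-epoch scalar features. Each feature (one value per epoch) is repeated across all time points and appended as an extra channel, so the output has the model's input shape `(n_samples, n_features, n_timestamps)`, here `(n_epochs, n_channels + n_features, n_times)`.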

In [1]:
def combined_raw_features(epochs, features):
    raw_data = epochs.get_data()
    n_epochs, n_channels, n_times = raw_data.shape

    # Stack features into shape (n_features, n_epochs); np.atleast_2d also
    # turns a single 1D feature array into one row
    features = np.atleast_2d(np.asarray(features, dtype=float))
    if features.ndim != 2 or features.shape[1] != n_epochs:
        msg = "Each feature must be a 1D array with length equal to the number of epochs."
        raise ValueError(msg)

    total_channels = n_channels + len(features)
    combined_data = np.zeros((n_epochs, total_channels, n_times))
    combined_data[:, :n_channels, :] = raw_data

    # Add each feature as an extra channel, constant across all time points
    for i, feature in enumerate(features):
        combined_data[:, n_channels + i, :] = feature[:, np.newaxis]

    return combined_data
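A toy NumPy example of the layout this produces: each scalar-per-epoch feature becomes an extra channel that is constant in time. Shapes and values are made up for illustration, and `np.broadcast_to` stands in for the loop to show the same broadcasting in one step.

```python
import numpy as np

n_epochs, n_channels, n_times = 3, 2, 5
raw = np.zeros((n_epochs, n_channels, n_times))  # stand-in for epochs.get_data()
psd = np.array([1.0, 2.0, 3.0])                  # one scalar per epoch
ratio = np.array([0.5, 0.6, 0.7])

features = np.atleast_2d([psd, ratio])           # (n_features, n_epochs)
# Each feature row becomes a constant channel: (n_epochs, n_features, n_times)
feature_channels = np.broadcast_to(
    features.T[:, :, np.newaxis], (n_epochs, len(features), n_times)
)
combined = np.concatenate([raw, feature_channels], axis=1)
print(combined.shape)  # (3, 4, 5)
```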

The function `get_raw_feature_all` combines the raw **EEG** with all of the features above for every subject and stacks the results along the epoch axis.

Because sleep spindles occur mainly between 11 and 15 Hz, all features are computed within this band:

- PSD: mean power over 11-15 Hz
- Band power ratio: 11-13 Hz over 13-15 Hz. Sleep spindles have been characterised as slow spindles and fast spindles in these two sub-bands, respectively
- CWT: scales whose frequencies span 11-15 Hz

In [2]:
def get_raw_feature_all(epochs, fmin, fmax):
    # `epochs` is a list with one MNE Epochs object per subject
    per_subject = []
    for subject in epochs:
        psd = extract_psd_features(subject, fmin, fmax)
        # Slow (fmin-13 Hz) versus fast (13 Hz-fmax) spindle power ratio
        bp_ratio = extract_bp_feature(subject, (fmin, 13), (13, fmax))
        cwt = extract_cwt_features(subject, fmin, fmax)
        # Raw channels plus broadcast scalar features, then the CWT coefficients
        combined = combined_raw_features(subject, [psd, bp_ratio])
        per_subject.append(np.concatenate([combined, cwt], axis=1))
    # Concatenate once along the epoch axis instead of growing an array in the loop
    return np.concatenate(per_subject, axis=0)
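The per-subject arrays can only be stacked if every subject yields the same number of channels, CWT scales, and samples per epoch; only the epoch counts may differ. A shape-only sketch with random stand-in arrays, assuming 1 raw channel, the 2 scalar features, 5 CWT scales (1 Hz steps over 11-15 Hz), and an arbitrary 512 samples per epoch:

```python
import numpy as np

rng = np.random.default_rng(2)
n_channels, n_scalar_features, n_scales, n_times = 1, 2, 5, 512
# Hypothetical per-subject outputs with different epoch counts
per_subject = [
    rng.standard_normal((n_epochs, n_channels + n_scalar_features + n_scales, n_times))
    for n_epochs in (4, 6)
]
X = np.concatenate(per_subject, axis=0)
print(X.shape)  # (10, 8, 512)
```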