 # Stroke Rehab EEG Analysis Pipeline



 Pipeline to convert .mat files into MNE Raw and Epochs objects and store them in a structured DataFrame.

 ## 🧰 Setups and Imports

In [1]:
import os
import json
from joblib import cpu_count, Memory
import re
import mne
from mne_features.feature_extraction import FeatureExtractor
import numpy as np
import pandas as pd
from scipy.io import loadmat

# Set MNE logging level to WARNING to reduce output verbosity
mne.set_log_level("WARNING")


 ## ⚙️ Constants Definition

In [2]:
DATA_DIR = '/Dev/stroke-rehab-data-analysis/data/stroke-rehab'
FILE_REGEX = r'(?P<subject>P\d+)_(?P<stage>pre|post)_(?P<split>training|test)\.mat'
CHANNEL_NAMES = ['FC3','FCz','FC4','C5','C3','C1','Cz','C2','C4','C6', 'CP3','CP1','CPz','CP2','CP4','Pz']
CHANNEL_TYPES = ['eeg'] * len(CHANNEL_NAMES)
MONTAGE = 'standard_1020'
EVENT_ID={'left': 1, 'right': 2}
N_CORES = 8
output_csv_path='laterality_results.csv'
# Cache directory to speed up computations
cache_path = "/Dev/stroke-rehab-data-analysis/cache"
memory = Memory(location=cache_path, verbose=0)


In [3]:
INVERSE_EVENT_ID = {v: k for k, v in EVENT_ID.items()}
INVERSE_EVENT_ID

{1: 'left', 2: 'right'}

 ## 📂 Data File Paths Parsing

In [4]:
file_entries = []

for fname in os.listdir(DATA_DIR):
    match = re.match(FILE_REGEX, fname)
    if match:
        file_entries.append({
            'filepath': os.path.join(DATA_DIR, fname),
            'subject': match.group('subject'),
            'stage': match.group('stage'),
            'split': match.group('split'),
        })

df = pd.DataFrame(file_entries)
df.head(10)


Unnamed: 0,filepath,subject,stage,split
0,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P2,post,training
1,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P2,post,test
2,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P2,pre,training
3,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P3,pre,training
4,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P1,post,test
5,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P3,post,training
6,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P1,post,training
7,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P3,post,test
8,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P1,pre,test
9,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P2,pre,test


 ## 🧠 MNE Raw Objects Generation

In [5]:
def make_info(subject: str, stage: str, split: str, fs: float) -> mne.Info:
    """
    Create an MNE Info object with metadata.

    Parameters:
    - subject (str): Subject identifier (e.g., 'P1').
    - stage (str): Stage of the experiment (e.g., 'pre' or 'post').
    - split (str): Data split type (e.g., 'training' or 'test').
    - fs (float): Sampling frequency of the data.

    Returns:
    - mne.Info: MNE Info object containing channel information, montage, and metadata.
    """
    info = mne.create_info(
        ch_names=CHANNEL_NAMES,
        ch_types=CHANNEL_TYPES,
        sfreq=fs
    )
    info.set_montage(MONTAGE)

    # Add metadata
    info['subject_info'] = {'his_id': subject}
    info['description'] = json.dumps({'stage': stage, 'split': split})

    return info

In [6]:
def make_annotations(triggers: np.ndarray, fs: float) -> mne.Annotations:
    """
    Create MNE Annotations for the raw data based on trigger events.

    Parameters:
    - triggers (np.ndarray): Array of trigger values indicating event types.
    - fs (float): Sampling frequency of the data.

    Returns:
    - mne.Annotations: Annotations object containing event onsets, durations, and descriptions.
    """
    # Pad triggers to detect changes at the boundaries
    padded = np.r_[0, triggers, 0]
    diffs = np.diff(padded)
    idx = np.where(diffs != 0)[0]
    onsets, offsets = idx[::2], idx[1::2]
    values = triggers[onsets]

    # Calculate onset times and durations
    onset_times = onsets / fs
    annot_durations = (offsets - onsets) / fs
    annot_descriptions = ['left' if val == 1 else 'right' for val in values]

    # Create and return the Annotations object
    annot = mne.Annotations(
        onset=onset_times,
        duration=annot_durations,
        description=annot_descriptions
    )
    return annot

In [7]:
def load_raw_from_mat(filepath: str, subject: str, stage: str, split: str) -> mne.io.Raw:
    """
    Load raw EEG data from a .mat file and return an MNE Raw object.

    Parameters:
    - filepath (str): Path to the .mat file.
    - subject (str): Subject identifier (e.g., 'P1').
    - stage (str): Experiment stage (e.g., 'pre' or 'post').
    - split (str): Data split type (e.g., 'training' or 'test').

    Returns:
    - mne.io.Raw: MNE Raw object containing EEG data and annotations.
    """
    mat: dict = loadmat(filepath)
    data: np.ndarray = mat['y'].T
    triggers: np.ndarray = mat['trig'].ravel()
    fs: float = float(mat['fs'].squeeze())
    
    info: mne.Info = make_info(subject, stage, split, fs)
    raw: mne.io.Raw = mne.io.RawArray(data, info)

    annot: mne.Annotations = make_annotations(triggers, fs)
    raw.set_annotations(annot)

    return raw


In [8]:
df['raw'] = df.apply(
    lambda row: load_raw_from_mat(row['filepath'], row['subject'], row['stage'], row['split']),
    axis=1
)

In [9]:
# Select the simple string columns
meta = df[["subject", "stage", "split"]]
# Create a new DataFrame with the types of the objects
types = df[["raw"]].map(lambda x: type(x).__name__)
# Concatenate both for display
pd.concat([meta, types], axis=1)

Unnamed: 0,subject,stage,split,raw
0,P2,post,training,RawArray
1,P2,post,test,RawArray
2,P2,pre,training,RawArray
3,P3,pre,training,RawArray
4,P1,post,test,RawArray
5,P3,post,training,RawArray
6,P1,post,training,RawArray
7,P3,post,test,RawArray
8,P1,pre,test,RawArray
9,P2,pre,test,RawArray


 ## ✂️ MNE Epochs Objects Generation

In [10]:
def create_epochs_from_raw(raw: mne.io.Raw) -> mne.Epochs:
    """
    Create MNE Epochs from a Raw object.

    Parameters:
    - raw (mne.io.Raw): The MNE Raw object containing EEG data and annotations.

    Returns:
    - mne.Epochs: The MNE Epochs object created from the raw data.
    """
    fs: float = raw.info['sfreq']
    events, event_id = mne.events_from_annotations(raw, event_id=EVENT_ID)
    events[:, 0] += int(2 * fs)  # Shift events forward by 2 seconds as per task description

    metadata = [json.loads(raw.info['description'])]*events.shape[0]
    metadata_df = pd.DataFrame(metadata)
    
    epochs: mne.Epochs = mne.Epochs(raw, events, tmin=-1.5, tmax=6.0, event_id=event_id, metadata=metadata_df, baseline=(-1.5, 0), preload=True)
    
    return epochs


In [11]:
df['epochs'] = df['raw'].apply(create_epochs_from_raw)


 ## 🧾 Final DataFrame Structure

In [12]:
# Select the simple string columns
meta = df[["subject", "stage", "split"]]
# Create a new DataFrame with the types of the objects
types = df[["raw","epochs"]].map(lambda x: type(x).__name__)
# Concatenate both for display
pd.concat([meta, types], axis=1)

Unnamed: 0,subject,stage,split,raw,epochs
0,P2,post,training,RawArray,Epochs
1,P2,post,test,RawArray,Epochs
2,P2,pre,training,RawArray,Epochs
3,P3,pre,training,RawArray,Epochs
4,P1,post,test,RawArray,Epochs
5,P3,post,training,RawArray,Epochs
6,P1,post,training,RawArray,Epochs
7,P3,post,test,RawArray,Epochs
8,P1,pre,test,RawArray,Epochs
9,P2,pre,test,RawArray,Epochs


# Laterality Coefficient

In [13]:
import numpy as np
import pandas as pd
import mne
from mne_features.feature_extraction import FeatureExtractor

def calculate_batch_laterality_coefficients(epochs, active_time=(2.0, 6.0), baseline_time=(-1.5, 0)):
    """
    Calculate laterality coefficients for multiple epochs and events efficiently.
    
    Parameters
    ----------
    epochs : mne.Epochs
        The epochs object containing EEG data.
    freq_band : tuple, optional
        The frequency band of interest (default: mu rhythm 8-13 Hz).
    active_time : tuple, optional
        Time window for calculating ERD/ERS (default: 2 to 6 seconds post-stimulus).
    baseline_time : tuple, optional
        Time window for baseline (default: -1.5 to 0 seconds pre-stimulus).
    
    Returns
    -------
    results_df : pandas.DataFrame
        DataFrame containing the laterality coefficients and associated metadata.
    """
    
    reverse_mapping = {v: k for k, v in epochs.event_id.items()}
    events_list = [reverse_mapping[num] for num in epochs.events[:, 2]]
    
    # Ensure we have C3 and C4 channels
    if 'C3' not in epochs.ch_names or 'C4' not in epochs.ch_names:
        raise ValueError("Channels C3 and C4 must be present in the data")
    
    # Extract baseline epochs
    baseline_epochs = epochs.copy().crop(tmin=baseline_time[0], tmax=baseline_time[1])
    
    # Extract active epochs
    active_epochs = epochs.copy().crop(tmin=active_time[0], tmax=active_time[1])
    
    freq_bands = {f'mu_band': (8.0, 13.0)}
    selected_funcs = ['pow_freq_bands']
    feature_extractor = FeatureExtractor(
        sfreq=epochs.info['sfreq'], 
        selected_funcs=selected_funcs, 
        params={'pow_freq_bands__freq_bands': freq_bands},
        n_jobs=N_CORES,
        memory=cache_path
    )
    
    # Extract features
    baseline_features = feature_extractor.fit_transform(baseline_epochs.pick(('C3','C4')).get_data())
    active_features = feature_extractor.fit_transform(active_epochs.pick(('C3','C4')).get_data())
    c3_idx, c4_idx = 0, 1
    
    # Calculate ERD/ERS for each epoch
    # ERD/ERS = (active - baseline) / baseline
    erd_ers_c3 = (active_features[:, c3_idx] - baseline_features[:, c3_idx]) / baseline_features[:, c3_idx]
    erd_ers_c4 = (active_features[:, c4_idx] - baseline_features[:, c4_idx]) / baseline_features[:, c4_idx]
    
    # Initialize results list
    results = []
    
    # Calculate laterality coefficient for each epoch
    for i, event in enumerate(events_list):
        # Determine contralateral and ipsilateral hemispheres based on hand movement
        if event == 'right':
            # Left hemisphere (C3) is contralateral to right hand
            contralateral_value = erd_ers_c3[i]
            ipsilateral_value = erd_ers_c4[i]
        elif event == 'left':
            # Right hemisphere (C4) is contralateral to left hand
            contralateral_value = erd_ers_c4[i]
            ipsilateral_value = erd_ers_c3[i]
        
        # Calculate laterality coefficient
        # Handle potential division by zero or NaN values
        try:
            lc = (contralateral_value - ipsilateral_value) / (contralateral_value + ipsilateral_value)
            # Check if result is valid
            if np.isnan(lc) or np.isinf(lc):
                lc = np.nan
        except:
            lc = np.nan
        
        # Add result to list
        results.append({
            'event': event,
            'LC': lc
        })
    
    # Convert to DataFrame
    results_df = pd.DataFrame(results)
    
    return results_df


In [14]:
import os

csv_path = 'laterality_results.csv'
results = []

# Check if the CSV file already exists
file_exists = os.path.exists(csv_path)

for _, row in df.iterrows():
    subject_id = row['subject']
    stage = row['stage']
    split = row['split']
    epochs = row['epochs']
    
    # Compute LC DataFrame
    lc_df = calculate_batch_laterality_coefficients(epochs)
    
    # Add metadata
    lc_df['subject'] = subject_id
    lc_df['stage'] = stage
    lc_df['split'] = split
    
    # Reorder columns
    lc_df = lc_df[['subject', 'stage', 'split', 'event', 'LC']]
    
    # Append to CSV
    lc_df.to_csv(csv_path, mode='a', index=False, header=not file_exists)
    
    # After first write, set file_exists to True
    file_exists = True


________________________________________________________________________________
[Memory] Calling mne_features.feature_extraction.extract_features...
extract_features(array([[[  5.241694, ..., -27.677387],
        [ -6.65168 , ...,  -0.343422]],

       ...,

       [[ 33.314821, ..., -15.910857],
        [ 56.859787, ..., -11.434887]]], shape=(80, 2, 385)), 
256.0, ['pow_freq_bands'], funcs_params={'pow_freq_bands__freq_bands': {'mu_band': (8.0, 13.0)}}, n_jobs=8)
_________________________________________________extract_features - 8.1s, 0.1min
________________________________________________________________________________
[Memory] Calling mne_features.feature_extraction.extract_features...
extract_features(array([[[-69.153068, ..., -78.659858],
        [-49.978622, ..., -48.436928]],

       ...,

       [[-27.377867, ...,  39.387838],
        [-40.901349, ...,  36.419025]]], shape=(80, 2, 1025)), 
256.0, ['pow_freq_bands'], funcs_params={'pow_freq_bands__freq_bands': {'mu_band': (8.