 # Stroke Rehab EEG Analysis Pipeline



 Pipeline to convert .mat files into MNE Raw and Epochs objects and store them in a structured DataFrame.

 ## 🧰 Setups and Imports

In [113]:
import os
import json
from joblib import cpu_count, Memory
import re
import mne
from mne_features.feature_extraction import FeatureExtractor
import numpy as np
import pandas as pd
from scipy.io import loadmat

# Set MNE logging level to WARNING to reduce output verbosity
mne.set_log_level("WARNING")


 ## ⚙️ Constants Definition

In [114]:
DATA_DIR = '/Dev/stroke-rehab-data-analysis/data/stroke-rehab'
FILE_REGEX = r'(?P<subject>P\d+)_(?P<stage>pre|post)_(?P<split>training|test)\.mat'
CHANNEL_NAMES = ['FC3','FCz','FC4','C5','C3','C1','Cz','C2','C4','C6', 'CP3','CP1','CPz','CP2','CP4','Pz']
CHANNEL_TYPES = ['eeg'] * len(CHANNEL_NAMES)
MONTAGE = 'standard_1020'
EVENT_ID={'left': 1, 'right': 2}
N_CORES = 8
output_csv_path='laterality_results.csv'
# Cache directory to speed up computations
cache_path = "/Dev/stroke-rehab-data-analysis/cache"
memory = Memory(location=cache_path, verbose=0)


In [115]:
INVERSE_EVENT_ID = {v: k for k, v in EVENT_ID.items()}
INVERSE_EVENT_ID

{1: 'left', 2: 'right'}

 ## 📂 Data File Paths Parsing

In [116]:
file_entries = []

for fname in os.listdir(DATA_DIR):
    match = re.match(FILE_REGEX, fname)
    if match:
        file_entries.append({
            'filepath': os.path.join(DATA_DIR, fname),
            'subject': match.group('subject'),
            'stage': match.group('stage'),
            'split': match.group('split'),
        })

df = pd.DataFrame(file_entries)
df.head(10)


Unnamed: 0,filepath,subject,stage,split
0,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P2,post,training
1,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P2,post,test
2,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P2,pre,training
3,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P3,pre,training
4,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P1,post,test
5,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P3,post,training
6,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P1,post,training
7,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P3,post,test
8,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P1,pre,test
9,/Dev/stroke-rehab-data-analysis/data/stroke-re...,P2,pre,test


 ## 🧠 MNE Raw Objects Generation

In [117]:
def make_info(subject: str, stage: str, split: str, fs: float) -> mne.Info:
    """
    Create an MNE Info object with metadata.

    Parameters:
    - subject (str): Subject identifier (e.g., 'P1').
    - stage (str): Stage of the experiment (e.g., 'pre' or 'post').
    - split (str): Data split type (e.g., 'training' or 'test').
    - fs (float): Sampling frequency of the data.

    Returns:
    - mne.Info: MNE Info object containing channel information, montage, and metadata.
    """
    info = mne.create_info(
        ch_names=CHANNEL_NAMES,
        ch_types=CHANNEL_TYPES,
        sfreq=fs
    )
    info.set_montage(MONTAGE)

    # Add metadata
    info['subject_info'] = {'his_id': subject}
    info['description'] = json.dumps({'stage': stage, 'split': split})

    return info

In [118]:
def make_annotations(triggers: np.ndarray, fs: float) -> mne.Annotations:
    """
    Create MNE Annotations for the raw data based on trigger events.

    Parameters:
    - triggers (np.ndarray): Array of trigger values indicating event types.
    - fs (float): Sampling frequency of the data.

    Returns:
    - mne.Annotations: Annotations object containing event onsets, durations, and descriptions.
    """
    # Pad triggers to detect changes at the boundaries
    padded = np.r_[0, triggers, 0]
    diffs = np.diff(padded)
    idx = np.where(diffs != 0)[0]
    onsets, offsets = idx[::2], idx[1::2]
    values = triggers[onsets]

    # Calculate onset times and durations
    onset_times = onsets / fs
    annot_durations = (offsets - onsets) / fs
    annot_descriptions = ['left' if val == 1 else 'right' for val in values]

    # Create and return the Annotations object
    annot = mne.Annotations(
        onset=onset_times,
        duration=annot_durations,
        description=annot_descriptions
    )
    return annot

In [119]:
def load_raw_from_mat(filepath: str, subject: str, stage: str, split: str) -> mne.io.Raw:
    """
    Load raw EEG data from a .mat file and return an MNE Raw object.

    Parameters:
    - filepath (str): Path to the .mat file.
    - subject (str): Subject identifier (e.g., 'P1').
    - stage (str): Experiment stage (e.g., 'pre' or 'post').
    - split (str): Data split type (e.g., 'training' or 'test').

    Returns:
    - mne.io.Raw: MNE Raw object containing EEG data and annotations.
    """
    mat: dict = loadmat(filepath)
    data: np.ndarray = mat['y'].T
    triggers: np.ndarray = mat['trig'].ravel()
    fs: float = float(mat['fs'].squeeze())
    
    info: mne.Info = make_info(subject, stage, split, fs)
    raw: mne.io.Raw = mne.io.RawArray(data, info)

    annot: mne.Annotations = make_annotations(triggers, fs)
    raw.set_annotations(annot)
    #raw = raw.copy().filter(1., 40.)
    return raw


In [120]:
df['raw'] = df.apply(
    lambda row: load_raw_from_mat(row['filepath'], row['subject'], row['stage'], row['split']),
    axis=1
)

In [121]:
# Select the simple string columns
meta = df[["subject", "stage", "split"]]
# Create a new DataFrame with the types of the objects
types = df[["raw"]].map(lambda x: type(x).__name__)
# Concatenate both for display
pd.concat([meta, types], axis=1)

Unnamed: 0,subject,stage,split,raw
0,P2,post,training,RawArray
1,P2,post,test,RawArray
2,P2,pre,training,RawArray
3,P3,pre,training,RawArray
4,P1,post,test,RawArray
5,P3,post,training,RawArray
6,P1,post,training,RawArray
7,P3,post,test,RawArray
8,P1,pre,test,RawArray
9,P2,pre,test,RawArray


 ## ✂️ MNE Epochs Objects Generation

In [122]:
def create_epochs_from_raw(raw: mne.io.Raw) -> mne.Epochs:
    """
    Create MNE Epochs from a Raw object.

    Parameters:
    - raw (mne.io.Raw): The MNE Raw object containing EEG data and annotations.

    Returns:
    - mne.Epochs: The MNE Epochs object created from the raw data.
    """
    fs: float = raw.info['sfreq']
    events, event_id = mne.events_from_annotations(raw, event_id=EVENT_ID)
    events[:, 0] += int(2 * fs)  # Shift events forward by 2 seconds as per task description

    metadata = [json.loads(raw.info['description'])]*events.shape[0]
    metadata_df = pd.DataFrame(metadata)
    
    epochs: mne.Epochs = mne.Epochs(raw, events, tmin=-1.5, tmax=6.0, event_id=event_id, metadata=metadata_df, baseline=(-1.5, 0), preload=True)
    
    return epochs


In [123]:
df['epochs'] = df['raw'].apply(create_epochs_from_raw)


 ## 🧾 Final DataFrame Structure

In [124]:
# Select the simple string columns
meta = df[["subject", "stage", "split"]]
# Create a new DataFrame with the types of the objects
types = df[["raw","epochs"]].map(lambda x: type(x).__name__)
# Concatenate both for display
pd.concat([meta, types], axis=1)

Unnamed: 0,subject,stage,split,raw,epochs
0,P2,post,training,RawArray,Epochs
1,P2,post,test,RawArray,Epochs
2,P2,pre,training,RawArray,Epochs
3,P3,pre,training,RawArray,Epochs
4,P1,post,test,RawArray,Epochs
5,P3,post,training,RawArray,Epochs
6,P1,post,training,RawArray,Epochs
7,P3,post,test,RawArray,Epochs
8,P1,pre,test,RawArray,Epochs
9,P2,pre,test,RawArray,Epochs


# Laterality Coefficient

In [125]:
import numpy as np
import pandas as pd
from mne.decoding import CSP
from mne_features.feature_extraction import FeatureExtractor

def calculate_batch_laterality_coefficients(epochs, active_time=(2.0, 6.0), baseline_time=(-1.5, 0)):
    """
    Calculate laterality coefficients using the FIRST CSP component
    for left vs. right hemispheres, rather than channels C3/C4.
    """
    # map integer events back to names
    reverse_mapping = {v: k for k, v in epochs.event_id.items()}
    events_list = [reverse_mapping[num] for num in epochs.events[:, 2]]

    # define your two hemi channel‑sets
    left_chs  = ['FC3', 'C5', 'C3', 'C1', 'CP3', 'CP1']
    right_chs = ['FC4', 'C6', 'C4', 'C2', 'CP4', 'CP2']

    # crop baseline & active windows
    baseline_epochs = epochs.copy().crop(tmin=baseline_time[0], tmax=baseline_time[1])
    active_epochs   = epochs.copy().crop(tmin=active_time[0], tmax=active_time[1])

    # --- 1) Fit CSP on the FULL trial data for each hemisphere ---
    labels = epochs.events[:, 2]

    # left CSP
    left_full = epochs.copy().pick(left_chs).get_data()
    csp_left = CSP(n_components=1, transform_into='csp_space')
    csp_left.fit(left_full, labels)

    # right CSP
    right_full = epochs.copy().pick(right_chs).get_data()
    csp_right = CSP(n_components=1, transform_into='csp_space')
    csp_right.fit(right_full, labels)

    # --- 2) Transform baseline & active into CSP space (shape: n_epochs × 1 × n_times) ---
    base_left  = csp_left.transform(baseline_epochs.copy().pick(left_chs).get_data())
    act_left   = csp_left.transform(active_epochs.copy().pick(left_chs).get_data())
    base_right = csp_right.transform(baseline_epochs.copy().pick(right_chs).get_data())
    act_right  = csp_right.transform(active_epochs.copy().pick(right_chs).get_data())

    # --- 3) Extract mu‑band power on those 1‑component time series ---
    freq_bands = {'mu_band': (7.0, 13.0)}
    fx = FeatureExtractor(
        sfreq=epochs.info['sfreq'],
        selected_funcs=['pow_freq_bands'],
        params={'pow_freq_bands__freq_bands': freq_bands},
        n_jobs=N_CORES,
        memory=cache_path
    )

    # each returns (n_epochs × n_features). Here n_features = 1 (the mu band power).
    base_feat_L = fx.fit_transform(base_left)[:, 0]
    act_feat_L  = fx.fit_transform(act_left)[:, 0]
    base_feat_R = fx.fit_transform(base_right)[:, 0]
    act_feat_R  = fx.fit_transform(act_right)[:, 0]

    # --- 4) Compute ERD/ERS per side ---
    erd_ers_L = (act_feat_L - base_feat_L) / base_feat_L
    erd_ers_R = (act_feat_R - base_feat_R) / base_feat_R

    # --- 5) Build Laterality Coefficient from left‑vs‑right CSP components ---
    results = []
    for i, ev in enumerate(events_list):
        if ev == 'right':
            # contralateral = left CSP; ipsilateral = right CSP
            contra = erd_ers_L[i]
            ipsi   = erd_ers_R[i]
        elif ev == 'left':
            # contralateral = right CSP; ipsilateral = left CSP
            contra = erd_ers_R[i]
            ipsi   = erd_ers_L[i]
        else:
            # skip unknown events
            continue

        # laterality coefficient
        lc = np.nan
        denom = (contra + ipsi)
        if denom != 0:
            lc = (contra - ipsi) / denom

        results.append({'event': ev, 'LC': lc})

    return pd.DataFrame(results)


In [126]:
import os

csv_path = 'laterality_results.csv'
results = []

# Check if the CSV file already exists
file_exists = os.path.exists(csv_path)

for _, row in df.iterrows():
    subject_id = row['subject']
    stage = row['stage']
    split = row['split']
    epochs = row['epochs']
    
    # Compute LC DataFrame
    lc_df = calculate_batch_laterality_coefficients(epochs)
    
    # Add metadata
    lc_df['subject'] = subject_id
    lc_df['stage'] = stage
    lc_df['split'] = split
    
    # Reorder columns
    lc_df = lc_df[['subject', 'stage', 'split', 'event', 'LC']]
    
    # Append to CSV
    lc_df.to_csv(csv_path, mode='a', index=False, header=not file_exists)
    
    # After first write, set file_exists to True
    file_exists = True


________________________________________________________________________________
[Memory] Calling mne_features.feature_extraction.extract_features...
extract_features(array([[[0.581217, ..., 0.083832]],

       ...,

       [[0.065169, ..., 0.195579]]], shape=(80, 1, 385)), 
256.0, ['pow_freq_bands'], funcs_params={'pow_freq_bands__freq_bands': {'mu_band': (7.0, 13.0)}}, n_jobs=8)
_________________________________________________extract_features - 8.7s, 0.1min
________________________________________________________________________________
[Memory] Calling mne_features.feature_extraction.extract_features...
extract_features(array([[[-0.157921, ...,  0.069508]],

       ...,

       [[ 0.138862, ..., -0.420336]]], shape=(80, 1, 1025)), 
256.0, ['pow_freq_bands'], funcs_params={'pow_freq_bands__freq_bands': {'mu_band': (7.0, 13.0)}}, n_jobs=8)
_________________________________________________extract_features - 0.0s, 0.0min
_________________________________________________________________