In [2]:
import numpy as np
import mne
import os
import matplotlib.pyplot as plt
import seaborn as sns
import glob
from scipy.signal import welch
from sklearn.model_selection import cross_val_score
from tqdm import tqdm 
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

In [3]:
#============= SETTINGS =============
# Root folder with one sub-directory of EDF runs per subject: <RAW_DATA_DIR>/<subject_id>/<subject_id><run_id>.edf
RAW_DATA_DIR = "../data/raw"

# Paradigm of interest: "left_right_hand" or "hands_feet" (see run_type_to_task).
# NOTE(review): not referenced by the visible cells; the experiments pass paradigms explicitly.
TASK_PARADIGM = "left_right_hand"
# Backward-compatible alias for the original (misspelled) name, in case other cells use it.
TASK_PARAGIDM = TASK_PARADIGM

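# Utilities

Run metadata, channel selection, and helpers that turn a subject's EDF runs into labelled MNE epochs.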

In [4]:
# Run id -> task metadata for runs R01-R14.
# R01/R02 are label-free baselines (no "paradigm" key, labels=None).
# R03-R14 repeat the same four-task cycle three times (R03-R06, R07-R10, R11-R14);
# each task run records its paradigm and what annotations T1/T2 mean.
def _build_run_type_to_task():
    """Assemble the run-id -> metadata table, with fresh nested dicts for every run."""
    baseline_names = ("Baseline - Eyes Open", "Baseline - Eyes Closed")
    task_cycle = (
        # (name, task_type, paradigm, T1 label, T2 label)
        ("Task 1 - Real Left/Right Fist",    "motor_execution", "left_right_hand", "left_fist",  "right_fist"),
        ("Task 2 - Imagine Left/Right Fist", "motor_imagery",   "left_right_hand", "left_fist",  "right_fist"),
        ("Task 3 - Real Fists/Feet",         "motor_execution", "hands_feet",      "both_fists", "both_feet"),
        ("Task 4 - Imagine Fists/Feet",      "motor_imagery",   "hands_feet",      "both_fists", "both_feet"),
    )

    table = {}
    for run_number, baseline_name in enumerate(baseline_names, start=1):
        table[f"R{run_number:02d}"] = {
            "name": baseline_name,
            "task_type": "baseline",
            "labels": None,
        }

    for run_number in range(3, 15):
        name, task_type, paradigm, t1_label, t2_label = task_cycle[(run_number - 3) % len(task_cycle)]
        table[f"R{run_number:02d}"] = {
            "name": name,
            "task_type": task_type,
            "paradigm": paradigm,
            "labels": {"T1": t1_label, "T2": t2_label},
        }
    return table


run_type_to_task = _build_run_type_to_task()

# Channels kept for decoding. Trailing dots pad each name to 4 characters,
# presumably to match the raw EDF channel labels (TODO confirm against raw.ch_names).
# Order matters: it fixes the channel axis of the epoch arrays handed to CSP.
MOTOR_CHANNELS = [
    # Primary motor strip: left hemisphere, midline (feet), right hemisphere
    'C3..', 'Cz..', 'C4..',
    # Premotor (frontal-central), left / right
    'Fc3.', 'Fc4.',
    # Sensorimotor (central-parietal), left / right
    'Cp3.', 'Cp4.',
    # Remaining central row: left lateral, left medial, right medial, right lateral
    'C5..', 'C1..', 'C2..', 'C6..',
    # Remaining frontal-central row: medial pair, then lateral pair
    'Fc1.', 'Fc2.', 'Fc5.', 'Fc6.',
    # Remaining central-parietal row: medial pair, then lateral pair
    'Cp1.', 'Cp2.', 'Cp5.', 'Cp6.',
]

def extract_task_id(filepath):
    """
    Return the run identifier (e.g. ``"R03"``) encoded in an EEG file name.

    Only the file name is inspected. An uppercase ``R`` elsewhere in the directory
    path, or after the run token, therefore cannot corrupt the result. The previous
    lambda split the whole path on ``'R'``, so ``"S001R03_RAW.edf"`` yielded ``"RAW"``.

    Parameters
    ----------
    filepath : str
        Path such as ``"../data/raw/S001/S001R03.edf"``.

    Returns
    -------
    str
        ``"R"`` followed by the run digits, e.g. ``"R03"``.

    Raises
    ------
    ValueError
        If the file name contains no ``R<digits>`` token.
    """
    import re  # local import: the notebook's import cell does not import `re`

    run_tokens = re.findall(r"R\d+", os.path.basename(filepath))
    if not run_tokens:
        raise ValueError(f"No run id (R<digits>) found in file name: {filepath!r}")
    # Last token wins, mirroring the original "text after the last R" intent
    return run_tokens[-1]

def rename_annotations(raw, run_type):
    """
    Swap a run's generic T0/T1/T2 annotation codes for readable labels, in place.

    For runs whose ``run_type_to_task[run_type]['labels']`` is not None, the codes
    are renamed as follows: T0 becomes 'rest', and T1/T2 become that run's labels.
    Runs with ``labels=None`` (baselines) keep their original descriptions.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose ``annotations`` are renamed in place.
    run_type : str
        Run id such as "R04"; must be a key of ``run_type_to_task``.

    Returns
    -------
    dict
        Description -> integer event code mapping from ``mne.events_from_annotations``
        on the (possibly renamed) annotations.
    """
    labels = run_type_to_task[run_type]['labels']

    if labels is not None:
        raw.annotations.rename({
            'T0': 'rest',
            'T1': labels['T1'],
            'T2': labels['T2'],
        })

    # Only the mapping is needed here; the events array is discarded
    _, event_dict = mne.events_from_annotations(raw)
    return event_dict

def select_eeg_files_for_subject_by_paradigm(subject_id, paradigm, task_type="motor_imagery"):
    """
    Build the EDF paths of one subject's runs for a paradigm and task type.

    Parameters
    ----------
    subject_id : str
        Subject folder name and file prefix, e.g. "S001".
    paradigm : str
        Required value of a run's "paradigm" entry in ``run_type_to_task``
        ("left_right_hand" or "hands_feet"). Baseline runs have no paradigm and never match.
    task_type : str, default "motor_imagery"
        Required value of a run's "task_type" entry ("motor_imagery" or "motor_execution").

    Returns
    -------
    list of str
        ``RAW_DATA_DIR/<subject_id>/<subject_id><run_id>.edf`` for each matching run,
        ordered by run id. The paths are only built, not checked for existence.
    """
    wanted_runs = sorted(
        run_id
        for run_id, info in run_type_to_task.items()
        if info.get('paradigm') == paradigm and info['task_type'] == task_type
    )

    edf_paths = []
    for run_id in wanted_runs:
        edf_paths.append(os.path.join(RAW_DATA_DIR, subject_id, f"{subject_id}{run_id}.edf"))
    return edf_paths


def load_and_concatenate_to_epochs(files, tmin=-0.5, tmax=3.5, picks=MOTOR_CHANNELS,
                                   l_freq=None, h_freq=None):
    """
    Load EDF runs, epoch their labelled task trials, and concatenate into one Epochs object.

    For each file:
      1. The run id is parsed from the path.
      2. The T0/T1/T2 annotations are renamed via ``rename_annotations``.
      3. Every resulting event description except ``'rest'`` becomes an epoch class.

    Parameters
    ----------
    files : iterable of str
        EDF paths; each file name must contain a run id present in ``run_type_to_task``.
    tmin, tmax : float
        Epoch window in seconds relative to each annotation onset.
    picks : list of str or None
        Channels passed to ``mne.Epochs`` (default ``MOTOR_CHANNELS``, a shared list
        that this function does not modify).
    l_freq, h_freq : float or None
        Optional band-pass edges (Hz) applied to the *continuous* recording of each
        run before epoching. With both left at ``None`` (the default), no filtering
        happens, which matches the previous behaviour. Filtering before cutting avoids
        the edge artefacts of filtering short epochs individually.

    Returns
    -------
    mne.Epochs
        Preloaded, baseline-free epochs of all runs, concatenated in ``files`` order.

    Raises
    ------
    ValueError
        If ``files`` yields no paths. Previously this surfaced as an opaque error
        from ``mne.concatenate_epochs``.
    """
    epochs_list = []

    for filepath in files:
        run_id = extract_task_id(filepath)
        raw = mne.io.read_raw_edf(filepath, preload=True, verbose=False)

        if l_freq is not None or h_freq is not None:
            # Data is preloaded, so this filters in place; restrict to the channels we epoch
            raw.filter(l_freq=l_freq, h_freq=h_freq, picks=picks, verbose=False)

        # Rename T0/T1/T2 in place and get the resulting description -> code mapping
        event_dict = rename_annotations(raw, run_id)
        events, _ = mne.events_from_annotations(raw)
        # Rest periods are not a class to decode
        task_event_id = {label: code for label, code in event_dict.items() if label != 'rest'}

        epochs_list.append(mne.Epochs(
            raw, events, event_id=task_event_id,
            tmin=tmin, tmax=tmax,
            baseline=None,
            picks=picks,
            preload=True,
            verbose=False,
        ))

    if not epochs_list:
        raise ValueError("load_and_concatenate_to_epochs() needs at least one EDF file path")

    return mne.concatenate_epochs(epochs_list)



# Classification Pipeline

## Pipeline

In [14]:
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from tqdm import tqdm

mne.set_log_level('WARNING')

# ---------------- Hyper-parameters ----------------
# Evaluation protocol
TEST_SIZE = 0.20      # fraction of each subject's trials held out for the test score
K_FOLDS = 5           # cross-validation folds on the training split
RANDOM_STATE = 42     # seed for the train/test split

# Model
N_COMPONENTS = 6      # CSP spatial filters fed to LDA

# Signal band and epoch window
L_FREQ = 8            # band-pass low edge (Hz)
H_FREQ = 30           # band-pass high edge (Hz)
TMIN = 0.0            # epoch start (s) relative to each annotation onset
TMAX = 2.0            # epoch end (s) relative to each annotation onset

# NOTE(review): not referenced by the visible cells; experiments pass task_type explicitly.
TASK_TYPE = "motor_execution"

def get_epochs_for_subject(subject_id, paradigm, task_type):
    """
    Return one subject's band-pass-filtered task epochs for a paradigm and task type.

    The epoch window is TMIN..TMAX seconds and the channels are MOTOR_CHANNELS.
    Filtering uses L_FREQ..H_FREQ Hz and is applied in place on the Epochs object.

    NOTE(review): the filter runs on the already-cut (TMAX - TMIN second) epochs. Filtering
    the continuous Raw before epoching usually gives fewer edge artefacts; consider it.
    """
    run_files = select_eeg_files_for_subject_by_paradigm(subject_id, paradigm, task_type)

    subject_epochs = load_and_concatenate_to_epochs(
        run_files,
        tmin=TMIN,
        tmax=TMAX,
        picks=MOTOR_CHANNELS,
    )
    subject_epochs.filter(l_freq=L_FREQ, h_freq=H_FREQ, verbose=False)

    return subject_epochs


def get_pipeline_scores_for_subject(subject_id, paradigm, task_type):
    """
    Score a CSP + LDA classifier on one subject's trials for one experiment.

    Procedure:
      1. Split the trials once into train/test (TEST_SIZE, stratified, seeded with RANDOM_STATE).
      2. Run K_FOLDS-fold cross-validation on the training split only.
      3. Refit the pipeline on the whole training split and score it on the held-out split.

    Returns
    -------
    dict
        'cv_score'   : mean cross-validation score on the training split.
        'cv_std'     : standard deviation of the fold scores.
        'test_score' : score on the held-out test split.
    """
    subject_epochs = get_epochs_for_subject(subject_id, paradigm, task_type)

    trials = subject_epochs.get_data()      # (n_trials, n_channels, n_times)
    # Integer event codes assigned by mne.events_from_annotations to the renamed
    # task labels (not the original T1/T2 numbering)
    labels = subject_epochs.events[:, -1]

    X_train, X_test, y_train, y_test = train_test_split(
        trials, labels,
        test_size=TEST_SIZE,
        random_state=RANDOM_STATE,
        stratify=labels,
    )

    csp_lda = Pipeline([
        ('csp', CSP(n_components=N_COMPONENTS, reg=None, log=True)),
        ('lda', LinearDiscriminantAnalysis()),
    ])

    # cross_val_score fits clones, so csp_lda is still unfitted afterwards
    fold_scores = cross_val_score(csp_lda, X_train, y_train, cv=K_FOLDS)

    held_out_score = csp_lda.fit(X_train, y_train).score(X_test, y_test)

    return {
        'cv_score': fold_scores.mean(),
        'cv_std': fold_scores.std(),
        'test_score': held_out_score,
    }



# Zero-padded subject ids "S001" ... "S109"
subject_ids = ["S" + str(number).zfill(3) for number in range(1, 110)]


In [15]:
def evaluate_pipeline_on_experiment(task_paradigm, task_type, limit=None):
    """
    Run the per-subject CSP + LDA evaluation and average the scores over subjects.

    Parameters
    ----------
    task_paradigm : str
        Paradigm passed to ``get_pipeline_scores_for_subject`` ("left_right_hand" / "hands_feet").
    task_type : str
        "motor_execution" or "motor_imagery".
    limit : int or None, default None
        Evaluate only the first ``limit`` entries of ``subject_ids``, e.g. for quick iterations.
        ``None`` evaluates every subject. This replaces the former hard-coded ``LIMIT=109``,
        which silently had to match ``len(subject_ids)``.

    Returns
    -------
    dict
        ``task_paradigm``, ``task_type``, and the subject-mean ``cv_score`` and ``test_score``.

    Raises
    ------
    ValueError
        If ``limit`` is not positive or no subjects are available.
    """
    if limit is not None and limit <= 0:
        raise ValueError(f"limit must be a positive integer or None, got {limit!r}")
    selected_subjects = subject_ids if limit is None else subject_ids[:limit]
    if not selected_subjects:
        raise ValueError("No subjects to evaluate: subject_ids is empty")

    subject_results = [
        get_pipeline_scores_for_subject(subject_id, task_paradigm, task_type)
        for subject_id in tqdm(selected_subjects, desc=f"{task_paradigm}/{task_type}")
    ]

    results_df = pd.DataFrame(subject_results)
    return {
        "task_paradigm": task_paradigm,
        "task_type": task_type,
        "cv_score": float(results_df["cv_score"].mean()),
        "test_score": float(results_df["test_score"].mean()),
    }


    

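## Evaluate on all experiments

Each experiment pairs a paradigm (left vs. right fist, or both fists vs. both feet) with a task type (executed or imagined movement). The reported accuracy is the held-out test accuracy averaged over all subjects.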


In [16]:
# Run every paradigm x task-type experiment, printing each score as it completes.
EXPERIMENT_CONFIGS = [
    ('left_right_hand', 'motor_execution'),
    ('left_right_hand', 'motor_imagery'),
    ('hands_feet', 'motor_execution'),
    ('hands_feet', 'motor_imagery'),
]

experiments = []
for exp_index, (exp_paradigm, exp_task_type) in enumerate(EXPERIMENT_CONFIGS):
    exp_result = evaluate_pipeline_on_experiment(exp_paradigm, exp_task_type)
    print(f"experiment {exp_index}: accuracy = {exp_result['test_score']:.4f}")
    experiments.append(exp_result)

# Keep the per-experiment names the original cell defined, in case later cells use them.
exp0, exp1, exp2, exp3 = experiments

# Final mean over the experiments actually run (no hard-coded divisor)
mean_accuracy = sum(e['test_score'] for e in experiments) / len(experiments)
print(f"\nMean accuracy of {len(experiments)} experiments: {mean_accuracy:.4f}")

100%|██████████████████████████████████████████████████████████████████████████████████████| 109/109 [01:13<00:00,  1.49it/s]


experiment 0: accuracy = 0.5928


100%|██████████████████████████████████████████████████████████████████████████████████████| 109/109 [01:15<00:00,  1.44it/s]


experiment 1: accuracy = 0.5703


  7%|██████▍                                                                                 | 8/109 [00:06<01:21,  1.24it/s]


KeyboardInterrupt: 