In [0]:
# configuration
import os
import pandas as pd
import numpy as np
import mne

raw_data_path = dbutils.widgets.get("raw_data_path")
processed_features_path = dbutils.widgets.get("processed_features_path")

# An empty path widget would make os.path.join fall back to a relative path in
# the driver's working directory, so stop early instead of reading/writing there.
if not raw_data_path.strip():
    dbutils.notebook.exit("raw_data_path widget cannot be empty.")
if not processed_features_path.strip():
    dbutils.notebook.exit("processed_features_path widget cannot be empty.")

# Strip surrounding whitespace so a stray space neither slips past the emptiness
# check nor ends up inside the file name / output folder name.
subject_id_to_process = dbutils.widgets.get("subject_id").strip()
if not subject_id_to_process:
    dbutils.notebook.exit("Subject ID widget cannot be empty.")

raw_file_name = f"subject_{subject_id_to_process}_eeg_stim.parquet"
# Reuse raw_file_name so the naming pattern is defined in exactly one place.
raw_eeg_file_path = os.path.join(raw_data_path, raw_file_name)


# Settings read by the preprocessing / epoching / feature-extraction helpers.
PREPROCESSING_CONFIG = dict(
    sfreq=512,                              # sampling rate handed to mne.create_info (Hz)
    filter_l_freq=1.0,                      # band-pass lower edge (Hz)
    filter_h_freq=20.0,                     # band-pass upper edge (Hz)
    # 50, 100, ..., 250 Hz. NOTE: all of these lie above filter_h_freq, so the
    # in-band check in preprocess_raw_eeg discards them and no notch is applied.
    notch_freqs=list(range(50, 251, 50)),
    epoch_tmin=-0.1,                        # epoch start relative to each event (s)
    epoch_tmax=0.6,                         # epoch end relative to each event (s)
    baseline_correction=(-0.1, 0.0),        # baseline window for mne.Epochs (s)
    reject_criteria_eeg=100e-6,             # EEG peak-to-peak rejection threshold (V)
    event_id={'NonTarget': 1, 'Target': 2},
    stim_channel_name="stim",
)
FEATURE_EXTRACTION_CONFIG = dict(
    resample_sfreq=100,                     # target rate for epoch resampling (Hz)
)


# helper functions
def create_mne_raw_from_df(df, sfreq, stim_channel_name="stim"):
    """Build an ``mne.io.RawArray`` from a wide, one-row-per-sample DataFrame.

    Every column whose (string) name starts with ``'eeg_ch'`` becomes an EEG
    channel, in DataFrame column order; ``stim_channel_name`` becomes a single
    stim channel appended after them.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw recording. EEG values are multiplied by 1e-6 — presumably they are
        stored in microvolts and MNE expects volts (TODO confirm upstream units).
    sfreq : float
        Sampling frequency passed to ``mne.create_info``.
    stim_channel_name : str
        Name of the column holding the trigger codes; cast to int64.

    Returns
    -------
    mne.io.RawArray

    Raises
    ------
    ValueError
        If ``df`` has no ``eeg_ch*`` columns.
    KeyError
        If ``stim_channel_name`` is not a column of ``df``.
    """
    # isinstance guard: non-string column labels (e.g. ints) have no startswith.
    eeg_channels = [col for col in df.columns
                    if isinstance(col, str) and col.startswith('eeg_ch')]
    if not eeg_channels:
        # Fail here rather than later, when picks='eeg' finds nothing to epoch.
        raise ValueError("no 'eeg_ch*' columns found in the input DataFrame")
    if stim_channel_name not in df.columns:
        raise KeyError(
            f"stim channel column {stim_channel_name!r} not found in the input DataFrame"
        )

    # (n_channels, n_samples) layout, as required by RawArray.
    eeg_data = df[eeg_channels].to_numpy(dtype=np.float64).T * 1e-6
    stim_data = df[stim_channel_name].to_numpy().astype(np.int64)[np.newaxis, :]
    data_for_mne = np.vstack([eeg_data, stim_data])

    info = mne.create_info(
        ch_names=eeg_channels + [stim_channel_name],
        sfreq=sfreq,
        ch_types=['eeg'] * len(eeg_channels) + ['stim'],
        verbose=False,
    )
    return mne.io.RawArray(data_for_mne, info, verbose=False)

def preprocess_raw_eeg(mne_raw_object, config):
    """Filter a Raw object in place, optionally notch-filter it, and find its events.

    Parameters
    ----------
    mne_raw_object : mne.io.Raw
        Continuous recording; modified in place by ``filter``/``notch_filter``.
    config : dict
        Reads ``'filter_l_freq'``, ``'filter_h_freq'``, ``'stim_channel_name'``
        and the optional ``'notch_freqs'``. Either filter edge may be ``None``
        (MNE's convention for a low-pass-only / high-pass-only filter).

    Returns
    -------
    tuple
        ``(mne_raw_object, events)`` — the same, now filtered, Raw object and
        the array returned by ``mne.find_events``.
    """
    l_freq = config['filter_l_freq']
    h_freq = config['filter_h_freq']
    mne_raw_object.filter(l_freq, h_freq, fir_design='firwin', verbose=False)

    # Only notch frequencies that survive the band-pass are worth removing, and
    # frequencies at/above Nyquist cannot be notched at all. A None edge means
    # that side of the band is open.
    nyquist = mne_raw_object.info['sfreq'] / 2.0
    lower = 0.0 if l_freq is None else l_freq
    upper = nyquist if h_freq is None else min(h_freq, nyquist)
    valid_notch_freqs = [f for f in (config.get('notch_freqs') or []) if lower < f < upper]
    if valid_notch_freqs:
        mne_raw_object.notch_filter(freqs=valid_notch_freqs, fir_design='firwin', verbose=False)

    events = mne.find_events(mne_raw_object, stim_channel=config['stim_channel_name'],
                             shortest_event=1, verbose=False)
    return mne_raw_object, events

def epoch_data(mne_raw_object, events, config):
    """Cut the continuous recording into event-locked, preloaded EEG epochs.

    Uses ``config['event_id']``, ``'epoch_tmin'``, ``'epoch_tmax'`` and
    ``'baseline_correction'`` to build ``mne.Epochs`` over the EEG channels only.
    When ``config['reject_criteria_eeg']`` is set and truthy, epochs whose EEG
    peak-to-peak amplitude exceeds it are dropped afterwards via ``drop_bad``.
    """
    epoch_kwargs = dict(
        event_id=config['event_id'],
        tmin=config['epoch_tmin'],
        tmax=config['epoch_tmax'],
        baseline=config['baseline_correction'],
        picks='eeg',
        preload=True,
        reject_by_annotation=False,
        verbose=False,
    )
    epochs = mne.Epochs(mne_raw_object, events, **epoch_kwargs)

    eeg_threshold = config.get('reject_criteria_eeg')
    if not eeg_threshold:
        return epochs
    epochs.drop_bad(reject=dict(eeg=eeg_threshold), verbose=False)
    return epochs

def extract_features(mne_epochs_object, config):
    """Resample epochs and flatten each one into a feature row.

    A copy of the epochs is resampled to ``config['resample_sfreq']``, so the
    caller's object is not modified. Each epoch's (channels x times) array is
    then flattened into a single row.

    Returns
    -------
    tuple
        ``(features, labels)``: ``features`` has shape
        ``(n_epochs, n_channels * n_times_resampled)`` and ``labels`` is the last
        column of the resampled epochs' events array (MNE's event-code column).
        With zero epochs, both are empty 1-D arrays.
    """
    if len(mne_epochs_object) == 0:
        return np.array([]), np.array([])

    resampled = mne_epochs_object.copy().resample(
        sfreq=config['resample_sfreq'], npad='auto', verbose=False
    )
    epoch_cube = resampled.get_data()
    labels = resampled.events[:, -1]
    # -1 lets numpy infer n_channels * n_times from the remaining axes.
    features = epoch_cube.reshape(epoch_cube.shape[0], -1)
    return features, labels


# execute preprocessing pipeline
def _skip_subject(message):
    """Report why this subject is skipped, then end the notebook run.

    display() alone does not stop execution; dbutils.notebook.exit does, so
    nothing after a failed check (epoching, feature extraction, writing
    output files) runs for this subject.
    """
    display(f"  {message}")
    dbutils.notebook.exit(message)


df_subject_raw = pd.read_parquet(raw_eeg_file_path, engine='pyarrow')
# Pass the configured stim channel name so the Raw object's stim channel has the
# same name that preprocess_raw_eeg later hands to mne.find_events.
raw_mne = create_mne_raw_from_df(
    df_subject_raw,
    PREPROCESSING_CONFIG['sfreq'],
    stim_channel_name=PREPROCESSING_CONFIG['stim_channel_name'],
)

raw_mne_processed, events = preprocess_raw_eeg(raw_mne, PREPROCESSING_CONFIG)
if events is None or len(events) == 0:
    _skip_subject(f"no events found for {subject_id_to_process}. skipping.")

epochs_subject = epoch_data(raw_mne_processed, events, PREPROCESSING_CONFIG)
if epochs_subject is None or len(epochs_subject) == 0:
    _skip_subject(f"no epochs remaining for {subject_id_to_process} after artifact rejection. skipping.")

x_subject_features, y_subject_labels = extract_features(epochs_subject, FEATURE_EXTRACTION_CONFIG)
if x_subject_features.size == 0:
    _skip_subject(f"no features extracted for {subject_id_to_process}. skipping.")


# write this subject's features and labels into their own output folder
output_subject_path = os.path.join(processed_features_path, subject_id_to_process)
os.makedirs(output_subject_path, exist_ok=True)

# parquet needs string column names, so positional labels 0..N-1 become "0".."N-1"
df_features = pd.DataFrame(x_subject_features)
df_features.columns = df_features.columns.astype(str)

# the single column is already named with a string, so no conversion is needed
df_labels = pd.DataFrame(y_subject_labels, columns=['label'])

for frame, parquet_name in ((df_features, "features.parquet"), (df_labels, "labels.parquet")):
    frame.to_parquet(os.path.join(output_subject_path, parquet_name))