In [None]:
import numpy as np
import mne
import matplotlib.pyplot as plt
import pandas as pd
import torch

In [None]:
# ---
# EDF Raw Data Visualizer / Checker
# Inputs:
# - edf_1 = path to the first .edf EEG file
# - edf_2 = path to a second .edf EEG file
# - edf_3 = path to a third .edf EEG file
# - n_axes = num of total axes. Default = 3
# ---

# The three recordings drawn side by side by the figure code further down.
edf_1, edf_2, edf_3 = (
    './physionet.org/files/eegmmidb/1.0.0/S001/S001R03.edf',
    './physionet.org/files/eegmmidb/1.0.0/S001/S001R07.edf',
    './physionet.org/files/eegmmidb/1.0.0/S001/S001R11.edf',
)
# Number of subplot columns requested from plt.subplots.
n_axes = 3


def plot_edf_segment(path, ax, seconds=10, max_channels=8, channel_spacing=0.000584 * 1.2):
    """Draw the opening seconds of an EDF recording as vertically stacked traces.

    Parameters
    ----------
    path : str or path-like
        Location of the .edf file; opened with ``mne.io.read_raw_edf``.
    ax : matplotlib.axes.Axes
        Axes that receives the traces, tick labels, axis labels and title.
    seconds : float, default 10
        Length of the plotted window, measured from the first sample.
    max_channels : int, default 8
        Upper bound on the number of channels drawn (the first ones in file
        order). Capped at the number of channels the file actually contains.
    channel_spacing : float, default 0.000584 * 1.2
        Vertical distance between neighbouring traces, expressed in the same
        units as the samples returned by ``raw.get_data``. Because the spacing
        does not depend on the data, panels that draw the same number of
        channels get identical tick positions (useful with ``sharey=True``).
        The 0.000584 factor was derived from commit e007fb1.

    Returns
    -------
    None
        The function only draws on ``ax``.
    """
    raw = mne.io.read_raw_edf(path, preload=False, verbose='ERROR')
    sfreq = raw.info['sfreq']

    # Safety net for max_channels > nchan: never pick more channels than exist.
    pick_list = np.arange(min(max_channels, raw.info['nchan']))
    data = raw.get_data(picks=pick_list, start=0, stop=int(sfreq * seconds))
    times = np.arange(data.shape[1]) / sfreq

    # Shift each channel upward by a multiple of the spacing so traces do not overlap.
    offsets = np.arange(data.shape[0]) * channel_spacing
    for trace, offset in zip(data, offsets):
        ax.plot(times, trace + offset, linewidth=0.6)

    ax.set(
        yticks=offsets,
        yticklabels=[raw.ch_names[p] for p in pick_list],
        xlabel='Time (s)',
        ylabel='Region',
        # Two characters right before the 4-character extension,
        # e.g. '03' for '.../S001R03.edf'.
        title=str(path)[-6:-4],
    )


# squeeze=False makes plt.subplots return a 2-D array of Axes for every value
# of n_axes (including 1); ravel() flattens it into a left-to-right sequence.
fig, axes = plt.subplots(1, n_axes, figsize=(12, 6), sharey=True, squeeze=False)
axes = axes.ravel()

# One recording per panel. zip() stops at the shorter of the two sequences, so
# lowering n_axes plots only the first recordings instead of raising an
# IndexError, and raising it simply leaves the extra panels empty.
for edf_path, ax in zip((edf_1, edf_2, edf_3), axes):
    plot_edf_segment(edf_path, ax=ax)

# fig.tight_layout()
# plt.show()


In [None]:
# ---
# EDF matrix generator for task 1 (reference: https://physionet.org/content/eegmmidb/1.0.0/)
# Scans each EDF file with MNE in triplets (corr. to each trial per task 1) for each subject in range.
# @ Returns a matrix of MNE Raw objects (opened with preload=False, so samples are not yet loaded into memory) to be handled by further operations
#
# Task 1 (open and close left or right fist)
# Metadata with physionet data:
# - T0 corresponds to rest
# - T1 corresponds to onset of motion (real or imagined) of the left fist (runs 3, 7, 11)
# - T2 corresponds to onset of motion (real or imagined) of the right fist (runs 3, 7, 11)
#
# MNE interpretation: {np.str_('T0'): 1, np.str_('T1'): 2, np.str_('T2'): 3}
# ---


# Dataset root and the run numbers of the three task-1 recordings per subject.
DATA_ROOT = "./physionet.org/files/eegmmidb/1.0.0"
TASK1_RUNS = (3, 7, 11)

# One entry per subject; each entry is a list with one Raw object per run in TASK1_RUNS.
task1_subject_matrix = []
# Overwritten inside the loop with the events / event-id mapping of the first
# run of the most recently processed subject; None until a subject is read.
events_general, event_id_general = None, None
# Integer event codes selected when epoching (2 and 3 are the codes MNE reports
# for T1 and T2 in the mapping printed below).
event_map = {"left": 2, "right": 3}

for subjectNum in range(1, 2):
    # Subject folders use a zero-padded, three-digit id (1 -> "001", 42 -> "042").
    pathing_input = f"{subjectNum:03d}"
    subject_dir = f"{DATA_ROOT}/S{pathing_input}"

    # preload=False: the samples are not loaded into memory at this point.
    subjectEntry = [
        mne.io.read_raw_edf(f"{subject_dir}/S{pathing_input}R{run:02d}.edf", preload=False)
        for run in TASK1_RUNS
    ]

    # Report the annotation -> integer mapping MNE derives from the first run.
    events_general, event_id_general = mne.events_from_annotations(subjectEntry[0])
    print(event_id_general)

    task1_subject_matrix.append(subjectEntry)

# task1_subject_matrix


In [None]:
# ---
# @ Returns a 3D array (X_full) that holds the epochs of every subject across trials, plus the matching labels (Y_full).
# local_X.ndim = 3, with axes ordered (C, T, M): C = channels, T = time samples, M = trials.
# all_X_task1 is a list of per-run epoch arrays (one (C, T, M) array per EDF file); they are concatenated along the trial axis at the end.
# ---

# Explicit annotation -> integer code mapping. Handing it to
# events_from_annotations pins the codes that event_map refers to, instead of
# relying on the codes MNE would assign automatically from whichever
# annotations happen to be present in a given file.
annotation_codes = {"T0": 1, "T1": 2, "T2": 3}

all_X_task1 = []  # X = trials / one (C, T, M) array per run (EDF file)
all_Y_task1 = []  # Y = labels / one (M, 1) column per run, aligned with all_X_task1

for subject_entry in task1_subject_matrix:
    for raw in subject_entry:
        events, _ = mne.events_from_annotations(raw, event_id=annotation_codes)
        epochs = mne.Epochs(
            raw,
            events,
            event_id=event_map,
            tmin=0.0,
            tmax=4.0,
            baseline=None,
            preload=True
        )

        data = epochs.get_data()  # (M, C, T)
        local_X = np.transpose(data, (1, 2, 0))  # (C, T, M)

        # The last column of the events array holds the event code:
        # "left" becomes -1, every other selected event ("right") becomes +1.
        labels = epochs.events[:, -1]
        local_Y = np.where(labels == event_map["left"], -1, 1).reshape(-1, 1)

        all_X_task1.append(local_X)
        all_Y_task1.append(local_Y)

if not all_X_task1:
    raise ValueError("task1_subject_matrix is empty: no runs available to build X_full / Y_full")

X_full = np.concatenate(all_X_task1, axis=2)  # stack trials
Y_full = np.concatenate(all_Y_task1, axis=0)  # stack labels

# Every trial must have exactly one label.
assert X_full.shape[2] == Y_full.shape[0], "trial / label count mismatch"

# print(X_full.shape, X_full.dtype)
# print(Y_full.shape, Y_full.dtype)


In [None]:
# ---
# @ Returns X, brainwave voltage data, in both (M, C, T) and (C, T, M) layouts. Y, the labels, is unaffected by the axis order.
# Moves the tensors to the MPS device when it is available (CPU otherwise) for later processing
# ---

# Prefer Apple's Metal Performance Shaders backend (mac GPU); otherwise stay on the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Cast to float32 while still on the host, then move to the selected device.
X_CTM = torch.tensor(X_full, dtype=torch.float32).to(device=device)
# Same data with the trial axis first: (C, T, M) -> (M, C, T).
X_MCT = torch.permute(X_CTM, (2, 0, 1))
Y = torch.tensor(Y_full, dtype=torch.float32).to(device=device)