In [2]:
import os
import glob
import re
from scipy.io import loadmat

# Labels for each session (each has 24 labels)
# Session 1: one label per EEG trial signal, listed in trial-index order
# (24 entries, values 0-3).
session1_label = [
    1, 2, 3, 0, 2, 0, 0, 1,
    0, 1, 2, 1, 1, 1, 2, 3,
    2, 2, 3, 3, 0, 3, 0, 3,
]
# Session 2: one label per EEG trial signal, listed in trial-index order
# (24 entries, values 0-3).
session2_label = [
    2, 1, 3, 0, 0, 2, 0, 2,
    3, 3, 2, 3, 2, 0, 1, 1,
    2, 1, 0, 3, 0, 1, 3, 1,
]
# Session 3: one label per EEG trial signal, listed in trial-index order
# (24 entries, values 0-3).
session3_label = [
    1, 2, 2, 1, 3, 3, 3, 1,
    1, 2, 1, 0, 2, 3, 3, 0,
    2, 3, 0, 0, 2, 0, 1, 0,
]

# Root folder of the raw EEG recordings, relative to the working directory.
# Built with os.path.join so the separator matches the host OS; a literal
# backslash is only a path separator on Windows, and elsewhere it would make
# this a single (non-existent) directory name.
base_path = os.path.join("seed-iv", "eeg_raw_data")

# One sub-folder per recording session, named "1", "2" and "3".
session_paths = [os.path.join(base_path, str(number)) for number in (1, 2, 3)]

# Final per-session record lists; each entry is replaced by the loading loop
# once the matching session folder has been read.
all_sessions_data = {"session1": [], "session2": [], "session3": []}

# Load every session folder and pair each EEG trial signal with its label.

# Matches MATLAB variable names of the form "<prefix>_eeg<number>",
# e.g. "mz_eeg7" -> prefix "mz", index 7. Compiled once, outside the loops.
trial_key_pattern = re.compile(r"(.*)_eeg(\d+)$")

# Label list for each session, in the same order as session_paths.
session_label_sets = (session1_label, session2_label, session3_label)

for session_number, (session_path, labels) in enumerate(
    zip(session_paths, session_label_sets), start=1
):
    session_name = f"session{session_number}"

    # Sorted so the record order is reproducible between runs.
    mat_files = sorted(glob.glob(os.path.join(session_path, "*.mat")))
    if not mat_files:
        # An empty or missing folder would otherwise go unnoticed.
        print(f"Warning: no .mat files found in '{session_path}'.")

    session_data = []

    for file_path in mat_files:
        # The subject number is the part of the file name before the first
        # underscore (e.g. "7_20150715.mat" -> 7); None if it is not numeric.
        file_name = os.path.basename(file_path)
        try:
            subject_number = int(file_name.split("_")[0])
        except ValueError:
            subject_number = None

        data_dict = loadmat(file_path)

        # Collect (key, index, prefix) for every "<prefix>_eeg<number>" entry.
        valid_keys = []
        for key in data_dict:
            # Skip loadmat's metadata entries ("__header__", "__version__", ...).
            if key.startswith("__"):
                continue
            match = trial_key_pattern.match(key)
            if match:
                valid_keys.append((key, int(match.group(2)), match.group(1)))

        # Keep records in trial-index order (1..24).
        valid_keys.sort(key=lambda item: item[1])

        if len(valid_keys) != len(labels):
            print(
                f"Warning: In file '{file_name}', found {len(valid_keys)} EEG signals but expected {len(labels)}."
            )
            print(
                "Labels are assigned by trial index, so the remaining signals stay aligned."
            )

        for key, idx, prefix in valid_keys:
            # Look the label up by the trial index taken from the key rather
            # than by position: a missing or duplicated trial must not shift
            # the labels of every signal that follows it.
            if not 1 <= idx <= len(labels):
                print(
                    f"Warning: In file '{file_name}', signal '{key}' has index {idx} "
                    f"outside 1..{len(labels)}; skipping."
                )
                continue

            session_data.append(
                {
                    "file_path": file_path,
                    "subject_number": subject_number,
                    "signal_name": key,  # e.g. 'mz_eeg1'
                    "signal_index": idx,  # e.g. 1
                    "prefix": prefix,  # e.g. 'mz'
                    "data": data_dict[key],  # the array loadmat stored under this key
                    "label": labels[idx - 1],  # label for trial `idx` (1-based)
                }
            )

    # Save the accumulated data for the current session
    all_sessions_data[session_name] = session_data

print("All sessions loaded successfully.")

{'__header__': b'MATLAB 5.0 MAT-file, Platform: PCWIN64, Created on: Thu Nov 15 15:30:38 2018',
 '__version__': '1.0',
 '__globals__': [],
 'mz_eeg1': array([[134.11045074, 135.57076454, 128.0605793 , ...,   5.78165054,
          -1.28149986, -13.38124275],
        [ 86.54594421,  85.62207222,  86.57574654, ...,   4.02331352,
          -6.82473183,   5.453825  ],
        [102.04315186,  98.28805923, 104.15911674, ...,  10.04338264,
           3.21865082,  11.62290573],
        ...,
        [-14.69254494, -19.66953278, -19.28210258, ...,   8.13603401,
          12.72559166,   9.86456871],
        [-10.7884407 , -13.26203346, -14.90116119, ...,  -0.14901161,
          10.22219658,   5.15580177],
        [-12.48717308, -16.12305641, -16.12305641, ...,  -0.32782555,
          12.04013824,   3.24845314]], shape=(62, 33601)),
 'mz_eeg2': array([[  1.01327896,  -1.69873238,   9.56654549, ...,  14.93096352,
          25.83861351,  21.93450928],
        [  2.95042992,   3.42726707,   5.1856041 