In [2]:
from pathlib import Path
import mne
import numpy as np
import h5py
mne.set_log_level(verbose=False)

In [3]:
NUM_SUBJECTS=8
ANNOTATION_LABELS = {
    '01': [0, 0, 0],
    '02': [0, 0, 0],
    '03': [0, 1, 2],
    '04': [0, 1, 2],
    '05': [0, 3, 4],
    '06': [0, 3, 4],
    '07': [0, 1, 2],
    '08': [0, 1, 2],
    '09': [0, 3, 4],
    '10': [0, 3, 4],
    '11': [0, 1, 2],
    '12': [0, 1, 2],
    '13': [0, 3, 4],
    '14': [0, 3, 4],
}

dsets = {
    'train': (0, 2000),
    'test': (2000, 1000),
}
columns = None
data_dir = Path('../data')
edf_files = np.array(list(data_dir.glob('**/*R02.edf')))
num_files = len(edf_files)
file_num = 0
seen_subjects = {}

with h5py.File(data_dir / 'eeg-s.h5', 'w-') as f:
    for file in edf_files:
        file_num += 1
        if file_num%10 == 0:
            print('===================')
            print(f"    {file_num} of {num_files}")
            print('===================')

        subject, task = file.stem.split('R')
        subject = int(subject.strip('S'))
        if subject not in seen_subjects:
            if len(seen_subjects) >= NUM_SUBJECTS:
                print('Max subjects reached. Skipping {file.stem}...')
                continue
            seen_subjects[subject] = {}
        seen_subjects[subject][task] = {}
        print(f'Processing {file.stem}...')
        
        with mne.io.read_raw_edf(file.resolve()) as edf:
            if columns is None:
                columns = edf.ch_names
            elif edf.ch_names != columns:
                print('Nope')
                break
            data = edf.get_data().T
            events, _ = mne.events_from_annotations(edf)
        events = events[:, [0,-1]] + [0, -1]

        for dset, (start, size) in dsets.items():
            if dset not in f:
                num_columns = data.shape[1] + 1
                f.create_dataset(
                    dset,
                    shape=(0, num_columns),
                    dtype=data.dtype,
                    maxshape=(None, num_columns),
                    compression='gzip',
                    compression_opts=9,
                )
            seen_subjects[subject][task][dset] = size
            
            dset = f[dset]
            existing = dset.shape[0]
            dset.resize(existing + size, axis=0)
            dset[existing:, :-1] = data[start:start+size]
            dset[existing:, -1] = subject
            
#         for i in range(len(events)):
#             start, annotation = events[i]
#             end = data.shape[0] if i == len(events) - 1 else events[i+1][0]
#             dset[existing+start:existing+end, -1] = ANNOTATION_LABELS[task][annotation]

print(f'Processed {file_num} files!')
print(seen_subjects)

Processing S070R02...
Processing S062R02...
Processing S079R02...
Processing S018R02...
Processing S087R02...
Processing S004R02...
Processing S052R02...
Processing S035R02...
Max subjects reached. Skipping {file.stem}...
    10 of 109
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
    20 of 109
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects reached. Skipping {file.stem}...
Max subjects r