In [2]:
import os
import re
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

In [3]:
class AudiobookEEGDataset(Dataset):
    """Dataset of (EEG recording, stimulus envelope) pairs stored as ``.npy`` files.

    ``eeg_root`` is scanned for subject folders. For every subject the first
    session sub-folder (in sorted order) is used, and each ``.npy`` file in it
    is paired with an envelope file in ``stim_root`` whose name is derived from
    the ``audiobook_*`` / ``podcast_*`` identifier embedded in the EEG file name.

    Args:
        eeg_root: Directory containing one folder per subject.
        stim_root: Directory containing the ``*_envelope.npy`` /
            ``*_shifted_envelope.npy`` stimulus files.

    Attributes set by ``__init__``:
        data_pairs: list of ``(eeg_path, stim_path)`` tuples, one per item.
    """

    # Tried in order; the first pattern that matches anywhere in the file name
    # wins, so the most specific identifier form has to come first.
    _STIMULUS_ID_PATTERNS = (
        re.compile(r'audiobook_\d+_\d+'),
        re.compile(r'audiobook_\d+'),
        re.compile(r'podcast_\d+'),
    )
    # Candidate envelope file suffixes, in order of preference.
    _ENVELOPE_SUFFIXES = ('_envelope.npy', '_shifted_envelope.npy')

    def __init__(self, eeg_root, stim_root):
        self.data_pairs = []

        for subject in sorted(os.listdir(eeg_root)):
            if subject.endswith(".json"):
                continue
            ses_path = os.path.join(eeg_root, subject)
            # Stray files next to the subject folders (e.g. a .tsv or README)
            # are skipped instead of crashing os.listdir() below.
            if not os.path.isdir(ses_path):
                continue

            # os.listdir() returns entries in arbitrary order, so sort and keep
            # only directories to make the chosen session reproducible.
            session_dirs = [
                os.path.join(ses_path, entry)
                for entry in sorted(os.listdir(ses_path))
                if os.path.isdir(os.path.join(ses_path, entry))
            ]
            if not session_dirs:
                print(f"Warning: No session directory found in {ses_path}")
                continue
            subject_path = session_dirs[0]

            for file in sorted(os.listdir(subject_path)):
                if not file.endswith(".npy"):
                    continue
                eeg_path = os.path.join(subject_path, file)

                audiobook_id = self._find_stimulus_id(file)
                if audiobook_id is None:
                    print(f"Warning: No audiobook ID found in {file}")
                    continue

                stim_path = self._find_stimulus_path(stim_root, audiobook_id)
                if stim_path is None:
                    print(f"Warning: Stimulus file not found for {file}")
                    continue

                self.data_pairs.append((eeg_path, stim_path))

        print(f"Loaded {len(self.data_pairs)} EEG-stimulus pairs")

    @classmethod
    def _find_stimulus_id(cls, file_name):
        """Return the stimulus identifier embedded in ``file_name``, or ``None``."""
        for pattern in cls._STIMULUS_ID_PATTERNS:
            match = pattern.search(file_name)
            if match:
                return match.group()
        return None

    @classmethod
    def _find_stimulus_path(cls, stim_root, stim_id):
        """Return the first existing envelope file for ``stim_id``, or ``None``."""
        for suffix in cls._ENVELOPE_SUFFIXES:
            candidate = os.path.join(stim_root, stim_id + suffix)
            if os.path.exists(candidate):
                return candidate
        return None

    def __len__(self):
        """Number of EEG/stimulus pairs found at construction time."""
        return len(self.data_pairs)

    def __getitem__(self, idx):
        """Load pair ``idx`` from disk.

        Returns:
            ``[eeg_tensor, stim_tensor]`` as float32 tensors, both trimmed to
            the shorter of the two lengths along the first axis.
        """
        eeg_path, stim_path = self.data_pairs[idx]

        eeg = np.load(eeg_path)    # expected layout: (64, T)
        stim = np.load(stim_path)  # expected layout: (T, 1)

        # Put time on the first axis when the file is stored channels-first.
        if eeg.shape[0] == 64:
            eeg = eeg.T  # (T, 64)

        # Drop singleton axes; atleast_1d keeps a one-sample envelope indexable
        # (squeeze alone would turn it into a 0-d array).
        stim = np.atleast_1d(stim.squeeze())

        # Align lengths by trimming both to the shorter one.
        min_len = min(eeg.shape[0], stim.shape[0])
        eeg = eeg[:min_len]
        stim = stim[:min_len]

        eeg_tensor = torch.tensor(eeg, dtype=torch.float32)
        stim_tensor = torch.tensor(stim, dtype=torch.float32)

        return [eeg_tensor, stim_tensor]

In [4]:
# Locations of the preprocessed EEG recordings and the stimulus envelopes.
eeg_root = '../data/derivatives/preprocessed_eeg'
stim_root = '../data/derivatives/preprocessed_stimuli'

# Index every EEG/stimulus pair on disk; arrays are only read on item access.
dataset = AudiobookEEGDataset(eeg_root=eeg_root, stim_root=stim_root)

# NOTE(review): the default collate function stacks tensors, which presumably
# requires all recordings in a batch to have the same length — confirm before
# iterating over this loader.
dataloader = DataLoader(dataset, shuffle=True, batch_size=4)

Loaded 588 EEG-stimulus pairs


In [22]:
def segment_trials(d, window_length=5 * 64, overlap=1 * 64):
    """Cut every (EEG, stimulus) trial into fixed-length, overlapping windows.

    Args:
        d: Iterable of ``(eeg, stim)`` pairs. Both elements must expose
            ``.shape`` and support slicing along their first axis, and must
            have the same length along that axis.
        window_length: Number of samples per window. The default is
            ``5 * 64 = 320`` (presumably 5 s at a 64 Hz sampling rate —
            TODO confirm).
        overlap: Number of samples shared by two consecutive windows. The
            default is ``1 * 64 = 64``. Must satisfy
            ``0 <= overlap < window_length``.

    Returns:
        A list of ``(eeg_seg, stim_seg)`` tuples. A trailing remainder shorter
        than ``window_length`` is dropped, so trials shorter than one window
        contribute nothing.

    Raises:
        ValueError: If ``window_length`` is not positive or ``overlap`` is
            outside ``[0, window_length)``.
        AssertionError: If a trial's EEG and stimulus lengths differ.
    """
    if window_length <= 0:
        raise ValueError("window_length must be positive")
    # A non-positive stride would never advance the window start.
    if not 0 <= overlap < window_length:
        raise ValueError("overlap must satisfy 0 <= overlap < window_length")

    stride = window_length - overlap
    data = []

    for eeg, stim in d:
        assert eeg.shape[0] == stim.shape[0], "EEG and stim must be same length"
        T = eeg.shape[0]

        # The last valid start is T - window_length; the range is empty when
        # the trial is shorter than one window.
        for start in range(0, T - window_length + 1, stride):
            stop = start + window_length
            data.append((eeg[start:stop], stim[start:stop]))
    return data


In [24]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the recordings for testing, then 10% of the remainder for
# validation. Both splits use a fixed seed so the partition is repeatable.
# NOTE(review): the split is made per recording, so recordings of the same
# subject or the same stimulus can end up in different partitions — confirm
# this is intended.
train_val, test = train_test_split(
    dataset,
    test_size=0.2,
    random_state=42,
)
train, val = train_test_split(
    train_val,
    test_size=0.1,
    random_state=42,
)


In [25]:
# Window each partition separately, in train / val / test order.
train_data, val_data, test_data = (
    segment_trials(partition) for partition in (train, val, test)
)

In [31]:
# Persist the windowed training partition to the working directory.
torch.save(obj=train_data, f='train.pt')

In [40]:
# Persist the windowed test and validation partitions to the working directory.
torch.save(obj=test_data, f='test.pt')
torch.save(obj=val_data, f='val.pt')