In [14]:
import pandas as pd
import numpy as np
from scipy.interpolate import CubicSpline
from tqdm import tqdm
from scipy.stats import zscore
import torch

In [5]:
# Load the raw recordings table and pull out the two columns used by later cells.
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
raw_df = pd.read_pickle('EEG_raw.pkl')
raw_eeg = raw_df['channels']
wn_ids = raw_df['wn_id']

### Generate Classification Labels

In [39]:
# One-hot encode each row's `wn_id`:
# labels[i, j] is True iff row i of `wn_ids` equals unique_ids[j].
unique_ids = wn_ids.unique().tolist()

# Build the id -> column lookup once, instead of a linear list.index() scan per row.
column_of = {wn_id: col for col, wn_id in enumerate(unique_ids)}

labels = np.zeros((len(raw_df), len(unique_ids)), dtype=bool)
for i, wn_id in enumerate(wn_ids):
    labels[i, column_of[wn_id]] = True

np.save('classification_labels.npy', labels)

### EEG Preprocessing

In [4]:
def interpolate(channels, duration=3, target_rate=250):
    """Resample every row of ``channels`` to ``duration * target_rate`` samples.

    Each row is treated as a signal sampled uniformly over ``[0, duration]``
    and is re-evaluated on a new uniform grid over the same interval with a
    cubic spline.

    Parameters
    ----------
    channels : array-like, shape (n_channels, n_samples)
        One signal per row.
    duration : number, default 3
        Extent of the time axis shared by the input and output grids
        (presumably seconds -- confirm against the recording setup).
    target_rate : number, default 250
        Output samples per unit of ``duration`` (presumably Hz).

    Returns
    -------
    numpy.ndarray, shape (n_channels, round(duration * target_rate))
        The output has exactly one row per input row, so no row is ever left
        as uninitialised memory.

    Raises
    ------
    AssertionError
        If the interpolated signal contains NaN.
    """
    channels = np.asarray(channels)
    n_out = int(round(duration * target_rate))
    t = np.linspace(0, duration, channels.shape[1])
    t_new = np.linspace(0, duration, n_out)
    # axis=1: fit one spline per row in a single call instead of a Python loop.
    interpolated_sig = CubicSpline(t, channels, axis=1)(t_new)
    assert not np.any(np.isnan(interpolated_sig))
    return interpolated_sig

def preprocess_eeg(channels, n_channels=3, pad_samples=375):
    """Turn one trial into a fixed-size ``(n_channels, 750 + pad_samples)`` array.

    Steps: keep the first ``n_channels`` rows, resample each to 750 samples
    with :func:`interpolate`, then lengthen the signal by appending a copy of
    its last ``pad_samples`` samples.

    Parameters
    ----------
    channels : array-like, shape (>= n_channels, n_samples)
        One trial, one signal per row.
    n_channels : int, default 3
        Number of leading rows to keep.
    pad_samples : int, default 375
        How many trailing samples to repeat at the end.

    Returns
    -------
    numpy.ndarray, shape (n_channels, 750 + pad_samples)
        ``(3, 1125)`` with the defaults.

    Raises
    ------
    ValueError
        If ``channels`` is not 2-D or has fewer than ``n_channels`` rows.
    AssertionError
        If the result contains NaN.
    """
    channels = np.asarray(channels)
    if channels.ndim != 2 or channels.shape[0] < n_channels:
        raise ValueError(
            f"expected a 2-D array with at least {n_channels} channels, "
            f"got shape {channels.shape}"
        )
    # Slice end is exclusive: [:3] keeps rows 0, 1 and 2.
    selected = channels[:n_channels, :]
    resampled = interpolate(selected)
    if not 0 <= pad_samples <= resampled.shape[1]:
        raise ValueError(
            f"pad_samples must be between 0 and {resampled.shape[1]}, got {pad_samples}"
        )
    # Explicit start index so pad_samples=0 appends nothing ([-0:] would copy everything).
    tail = resampled[:, resampled.shape[1] - pad_samples:]
    padded = np.hstack((resampled, tail))
    assert not np.any(np.isnan(padded))
    return padded

In [None]:
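# Disabled alternative, kept for reference: preprocess each trial with
# Series.apply and pickle the resulting Series.
# raw_eeg = raw_eeg.apply(preprocess_eeg)
# raw_eeg.to_pickle('EEGNet_input.pkl')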

In [58]:
# Preprocess every trial into one (n_trials, 3, 1125) float array.
all_eeg = np.empty((len(raw_df),3,1125))
# Wrap the Series itself (not the enumerate object) so tqdm can read its
# length and display a total / ETA instead of a bare iteration counter.
for i,trial_data in enumerate(tqdm(raw_eeg)):
    all_eeg[i,:,:] = preprocess_eeg(trial_data)

# NOTE(review): axis=1 standardises across the 3 channels at each time sample,
# not over time within each channel (that would be axis=2) -- confirm this is intended.
eeg_ndarray = zscore(all_eeg,axis=1)
assert not np.any(np.isnan(eeg_ndarray))
# Convert to a float32 tensor and persist it to disk.
eeg_tensor = torch.from_numpy(eeg_ndarray).float()
assert not eeg_tensor.isnan().any()
torch.save(eeg_tensor, 'eeg_tensor.pt')

7203it [00:08, 880.73it/s]
