In [2]:
import os
import numpy as np
import torch
import torch.nn as nn

from eegsc.ml.conv import ConvNet, train_conv, predict_conv
from eegsc.preprocessing.filters import BandPassFilter
from eegsc.preprocessing.spectrum import SpectrumTransformer
from eegsc.utils.io import read_raw
from eegsc.utils.path import get_data_path
from eegsc.utils.experiments import create_spectrum_dataset
from eegsc.utils.cross_val import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Which class labels to load from the '1st_Day.mat' recording.
# Other pairs used in earlier runs of this cell:
#   ['left_im1', 'right_im1'], ['left_im2', 'right_im2'],
#   ['left_quasi', 'right_quasi'], or all eight labels together.
DATA_TYPES = ['left_real', 'right_real']
data = read_raw(os.path.join(get_data_path(), 'raw', '1st_Day.mat'),
                data_types=DATA_TYPES)

In [4]:
# Analysis window: from start_time up to TRIAL_END_TIME.
# NOTE(review): TRIAL_END_TIME = 10 replaces the bare literal in the original
# `10 - start_time`; presumably the trial length in seconds — confirm units
# against BandPassFilter / SpectrumTransformer.
TRIAL_END_TIME = 10
start_time = 5
signal_duration = TRIAL_END_TIME - start_time
bandpass_filter = BandPassFilter(signal_duration=signal_duration)
spectrum_transformer = SpectrumTransformer(psd_method='periodogram',
                                           signal_duration=signal_duration)

In [6]:
# Build the spectrum dataset from the filtered raw data.
# NOTE(review): compute_stat=False presumably keeps the full spectra rather
# than summary statistics — confirm in eegsc.utils.experiments.
statistics, labels, person_idxs = create_spectrum_dataset(
    data=data,
    bandpass_filter=bandpass_filter,
    spectrum_transformer=spectrum_transformer,
    compute_stat=False,
    start_time=start_time,
)
# Every sample must have exactly one label.
assert statistics.shape[0] == labels.shape[0], (
    f"{statistics.shape[0]} samples but {labels.shape[0]} labels")
statistics.shape, labels.shape

((569, 192, 5095), (569,))

In [7]:
# Train/test split with test_size=0.2 and a fixed random_state.
# NOTE(review): person_idxs is passed to the project's train_test_split,
# presumably to group or stratify by subject — confirm in eegsc.utils.cross_val.
x_train, x_test, y_train, y_test = train_test_split(
    statistics, labels, person_idxs, test_size=0.2, random_state=0)
# The split must partition the dataset, with features and labels aligned.
assert len(x_train) + len(x_test) == len(statistics)
assert len(x_train) == len(y_train) and len(x_test) == len(y_test)
x_train.shape, x_test.shape

((455, 192, 5095), (114, 192, 5095))

In [8]:
# Save the train/test split as .npy files under <data path>/train_test.
# np.save does not create missing directories, so make sure it exists first.
# NOTE(review): x_train is large (455 x 192 x 5095 elements per the shape shown
# above); consider float32 or np.savez_compressed if disk space matters.
split_dir = os.path.join(get_data_path(), 'train_test')
os.makedirs(split_dir, exist_ok=True)
for split_name, split_array in (('x_train', x_train), ('x_test', x_test),
                                ('y_train', y_train), ('y_test', y_test)):
    np.save(os.path.join(split_dir, f'{split_name}.npy'), split_array)

In [9]:
# Global mean and standard deviation of the training features, ignoring NaNs.
# NOTE(review): the need for nan-aware reductions suggests x_train contains
# NaNs — worth checking where they originate before normalising with these.
train_mean = np.nanmean(x_train)
train_std = np.nanstd(x_train)
train_mean, train_std

(1.2321113877261198e-09, 3.854901604545972e-06)