In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from eegsc.ml.gru import GRUNet, train_gru, predict_gru
from eegsc.preprocessing.filters import BandPassFilter
from eegsc.preprocessing.spectrum import SpectrumTransformer
from eegsc.utils.io import read_raw
from eegsc.utils.path import get_data_path
from eegsc.utils.experiments import create_spectrum_dataset

In [2]:
# Load the raw recordings from `1st_Day.mat`, keeping only the condition pair
# listed in DATA_TYPES. Other pairs read_raw has been used with here:
#   ['left_im1', 'right_im1'], ['left_im2', 'right_im2'],
#   ['left_quasi', 'right_quasi'], or all eight names together.
DATA_TYPES = ['left_real', 'right_real']
RAW_FILE = os.path.join(get_data_path(), 'raw', '1st_Day.mat')
data = read_raw(RAW_FILE, data_types=DATA_TYPES)

In [4]:
# Preprocessing pipeline: a default band-pass filter followed by a
# periodogram-based spectrum transform, both handed to create_spectrum_dataset
# together with its keyword options.
bandpass_filter = BandPassFilter()
spectrum_transformer = SpectrumTransformer(psd_method='periodogram')
spectrum_options = {'compute_stat': False, 'save': False}
seq_data, labels = create_spectrum_dataset(
    data, bandpass_filter, spectrum_transformer, **spectrum_options)
# Display both shapes: the first axis of seq_data should line up with labels.
seq_data.shape, labels.shape

((569, 192, 10095), (569,))

In [5]:
# Hold out 20% of the samples for testing. Stratify on the labels so both
# splits keep the same class proportions (the dataset is small: 569 samples),
# and fix random_state so the split is reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    seq_data, labels, test_size=0.2, random_state=0, stratify=labels)
x_train.shape, x_test.shape

((455, 192, 10095), (114, 192, 10095))

In [7]:
# Seed the numpy and torch global RNGs so the GRU weight initialisation is
# reproducible across Restart & Run All (and, if train_gru shuffles batches
# with either library, the batch order too). Re-run this cell before the
# training cell to start from a fresh, identically-initialised model.
np.random.seed(0)
torch.manual_seed(0)

# NOTE(review): input_size is taken from axis 1 of x_train (192 in the
# recorded run, with 10095 on the last axis). If GRUNet treats the last axis
# as the feature axis, this should be x_train.shape[2] or the data should be
# transposed — confirm against GRUNet.
model = GRUNet(input_size=x_train.shape[1],
               hidden_size=32,
               n_layers=1,
               n_classes=len(np.unique(labels)))

In [8]:
# Training hyperparameters, named so they can be tuned in one place.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-3
N_EPOCHS = 20

# NLLLoss expects log-probabilities, so this assumes GRUNet's output ends in
# log_softmax — TODO confirm.
# In the recorded run the logged loss stayed at ~315.8 ≈ 455 * ln(2), i.e.
# presumably a per-sample NLL summed over the 455 training samples with the
# model emitting near-uniform 2-class probabilities (chance level). Check
# input scaling / the input_size axis before training longer.
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE,
                             weight_decay=WEIGHT_DECAY)
# If train_gru updates `model` in place, re-running only this cell continues
# from the current weights; re-create `model` first for a fresh start.
model = train_gru(model, x_train, y_train, criterion, optimizer,
                  n_epochs=N_EPOCHS)

Epoch 1 | loss = 317.2737897336483, accuracy = 0.46153846153846156
Epoch 2 | loss = 316.4888388514519, accuracy = 0.45054945054945056
Epoch 3 | loss = 316.1366795897484, accuracy = 0.44395604395604393
Epoch 4 | loss = 315.9732453227043, accuracy = 0.45274725274725275
Epoch 5 | loss = 315.83587038517, accuracy = 0.45494505494505494
Epoch 6 | loss = 315.8270699977875, accuracy = 0.44175824175824174
Epoch 7 | loss = 315.85170233249664, accuracy = 0.45274725274725275
Epoch 8 | loss = 315.74256205558777, accuracy = 0.4461538461538462
Epoch 9 | loss = 315.79726642370224, accuracy = 0.45054945054945056
Epoch 10 | loss = 315.8255853652954, accuracy = 0.46153846153846156


KeyboardInterrupt: 

In [6]:
# Save the train/test split as <data path>/train_test/{x,y}_{train,test}.npy.
# np.save does not create missing parent directories (it would raise
# FileNotFoundError), so make sure the output directory exists first.
train_test_dir = os.path.join(get_data_path(), 'train_test')
os.makedirs(train_test_dir, exist_ok=True)
for split_name, split_array in (('x_train', x_train), ('x_test', x_test),
                                ('y_train', y_train), ('y_test', y_test)):
    np.save(os.path.join(train_test_dir, split_name + '.npy'), split_array)