In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from eegsc.ml.gru import GRUNet, train_gru, predict_gru
from eegsc.preprocessing.filters import BandPassFilter
from eegsc.utils.io import read_raw
from eegsc.utils.path import get_data_path
from eegsc.utils.experiments import create_sequence_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data = read_raw(os.path.join(get_data_path(), 'raw', '1st_Day.mat'),
                # data_types=['left_real', 'right_real',
                #             'left_im1', 'right_im1',
                #             'left_im2', 'right_im2',
                #             'left_quasi', 'right_quasi'])
                # data_types=['left_quasi', 'right_quasi'])
                # data_types=['left_im2', 'right_im2'])
                # data_types=['left_im1', 'right_im1'])
                data_types=['left_real', 'right_real'])

In [3]:
start_time = 5
signal_duration = 10 - start_time
bandpass_filter = BandPassFilter(signal_duration=signal_duration)
seq_data, labels = create_sequence_dataset(data, bandpass_filter, start_time=start_time)
seq_data.shape, labels.shape

((569, 32, 5095), (569,))

In [4]:
x_train, x_test, y_train, y_test = train_test_split(
    seq_data, labels, test_size=0.2, random_state=0)
x_train.shape, x_test.shape

((455, 32, 5095), (114, 32, 5095))

In [5]:
model = GRUNet(input_size=x_train.shape[1],
               hidden_size=32,
               n_layers=1,
               n_classes=np.unique(labels).shape[0])

In [6]:
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=.001, weight_decay=1e-3)
model = train_gru(model, x_train, y_train, criterion, optimizer, n_epochs=20)

Epoch 1 | loss = 3828.6737653017044, accuracy = 0.13554708764289602
Epoch 2 | loss = 7653.845572352409, accuracy = 0.13554708764289602
Epoch 3 | loss = 11477.851648449898, accuracy = 0.13572854291417166
Epoch 4 | loss = 15301.264512777328, accuracy = 0.13554708764289602


KeyboardInterrupt: 

In [None]:
accuracy_score(y_test, predict_gru(model, x_test))

0.4824561403508772

In [5]:
np.save(os.path.join(get_data_path(), 'train_test', 'x_train.npy'), x_train)
np.save(os.path.join(get_data_path(), 'train_test', 'x_test.npy'), x_test)
np.save(os.path.join(get_data_path(), 'train_test', 'y_train.npy'), y_train)
np.save(os.path.join(get_data_path(), 'train_test', 'y_test.npy'), y_test)