In [1]:
import numpy as np
import pandas as pd
import mne
import random
import torch
import torch.nn as nn
import torch.optim as optim
from glob import glob
from mne import EpochsArray, create_info
from mne.time_frequency import psd_array_multitaper
from scipy.stats import pearsonr

In [3]:
sfreq = 125
# Some works mentioned alpha (8-12 Hz) and beta (12-30 Hz) bands have been associated with attentional processes and working memory.
fmin, fmax = 8, 30
segment_duration = 1
sample_duration = 10
ch_names = ["ch1_LF5 - FpZ", "ch2_OTE_L-FpZ", "ch4_RF6-FpZ", "ch5_OTE_R-FpZ"]
stage_names = ["Eyes closed", "Eyes open", "Concentration", "Immersion"]
stages = {
    "Eyes closed": {
        "tmin": 0,
        "tmax": 120,
        "label": 25,
    },
    "Eyes open": {
        "tmin": 120,
        "tmax": 240,
        "label": 50,
    },
    "Concentration": {
        "tmin": 240,
        "tmax": 360,
        "label": 75,
    },
    "Immersion": {
        "tmin": 360,
        "tmax": 500000,
        "label": 100,
    },
}

In [4]:
def preprocessing(raw_array):
    # Band-pass filter
    raw_array.filter(fmin, fmax, picks='eeg')
    # Remove artifacts, may not be necessary for this data
    ica = mne.preprocessing.ICA(n_components=len(ch_names), random_state=0)
    ica.fit(raw_array, picks='eeg')
    ica.apply(raw_array)
    return raw_array

def extract_features(raw_array):
    # Extract psd features
    events = mne.make_fixed_length_events(raw_array, duration=segment_duration)
    epochs = mne.Epochs(raw_array, events, tmin=0, tmax=segment_duration, baseline=None, preload=True)
    psd_features, _ = psd_array_multitaper(epochs.get_data(), sfreq=sfreq, fmin=fmin, fmax=fmax, n_jobs=1)
    # Concat features of 4 channels for each segment
    n_segments, n_channels, n_features = psd_features.shape
    n_segments_per_sample = sample_duration / segment_duration
    n_segments = int((n_segments // n_segments_per_sample) * n_segments_per_sample)
    psd_features = psd_features[:n_segments, :, :]
    psd_features = psd_features.reshape((int(n_segments // n_segments_per_sample), int(n_segments_per_sample), n_channels * n_features))
    return psd_features

def extract_data_from_csv(csv_path):
    # Create raw MNE object from data
    df = pd.read_csv(csv_path, usecols=[0, 1, 3, 4])
    data = np.array(df.values).T
    print(data)
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types='eeg')
    raw = mne.io.RawArray(data, info)
    
    # Preprocessing step
    raw = preprocessing(raw)
    
    # Extract the features step, features shape (n_sample, n_segments_per_sample, n_channels * n_features)
    features = extract_features(raw)
    
    # Making scores groundtruth
    labels = np.zeros((features.shape[0],))
    for stage_name, stage_value in stages.items():
        tmin, tmax = stage_value['tmin'], stage_value['tmax']
        labels[tmin // sample_duration: tmax // sample_duration] = stage_names.index(stage_name)
    
    return features, labels

In [5]:
files = glob("Focus_data_for testing/*.csv")
files

['Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S13.csv',
 'Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S10.csv',
 'Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S6.csv',
 'Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S7.csv',
 'Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S4.csv',
 'Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S1.csv',
 'Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S3.csv',
 'Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S2.csv',
 'Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S9.csv',
 'Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S8.csv']

In [13]:
# Split into training and test sets with ratio 90:10
random.shuffle(files)

n_train = int(len(files) * 0.9)

train_file_paths = files[:n_train]
test_file_paths = files[n_train:]

print("Train set:\n\t", '\n\t'.join(train_file_paths))
print("Test set:\n\t", '\n\t'.join(test_file_paths))

Train set:
	 Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S2.csv
	Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S9.csv
	Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S13.csv
	Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S7.csv
	Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S1.csv
	Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S10.csv
	Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S3.csv
	Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S6.csv
	Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S8.csv
Test set:
	 Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S4.csv


In [14]:
X_train, y_train = [], []
for file in train_file_paths:
    samples, scores = extract_data_from_csv(file)
    X_train.extend(samples)
    y_train.extend(scores)
    
X_train = np.array(X_train)
y_train = np.array(y_train)

[[ 45227  45221  45224 ...  28718  28719  28720]
 [225459 225503 225471 ... 250466 250454 250462]
 [  2129   2126   2129 ...    906    904    900]
 [247006 247044 247017 ... 266979 266968 266984]]
Creating RawArray with float64 data, n_channels=4, n_times=84850
    Range : 0 ... 84849 =      0.000 ...   678.792 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 207 samples (1.656 sec)

Fitting ICA to data using 4 channels (please be patient, this may take a while)
Selecting by number: 4 components

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s finished


Fitting ICA took 0.3s.
Applying ICA to Raw instance
    Transforming to ICA space (4 components)
    Zeroing out 0 ICA components
    Projecting back using 4 PCA components
Not setting metadata
678 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 678 events and 126 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
[[ 21687  21691  21701 ...  21543  21553  21542]
 [271336 271273 271295 ... 316109 316136 316097]
 [ 13748  13740  13750 ...  46872  46874  46871]
 [280848 280799 280887 ... 327257 327174 327220]]
Creating RawArray with float64 data, n_channels=4, n_times=83758
    Range : 0 ... 83757 =      0.000 ...   670.056 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) met

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s finished


[[-713847 -713819 -713854 ... -712320 -712315 -712312]
 [-619124 -619154 -619137 ... -619544 -619678 -619615]
 [-704357 -704330 -704365 ... -701356 -701351 -701349]
 [-618986 -618965 -618993 ... -619339 -619336 -619336]]
Creating RawArray with float64 data, n_channels=4, n_times=83746
    Range : 0 ... 83745 =      0.000 ...   669.960 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 207 samples (1.656 sec)

Fitting ICA to data using 4 channels (please be patient, this may take a while)
Selecting

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s finished
  ica.fit(raw_array, picks='eeg')


[[  3292   3295   3299 ...  -2511  -2517  -2514]
 [321910 322126 321854 ... 321565 321503 321553]
 [ 30766  30763  30758 ...  32678  32672  32667]
 [354126 354190 354128 ... 349739 349706 349725]]
Creating RawArray with float64 data, n_channels=4, n_times=83192
    Range : 0 ... 83191 =      0.000 ...   665.528 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 207 samples (1.656 sec)

Fitting ICA to data using 4 channels (please be patient, this may take a while)
Selecting by number: 4 components

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s finished


[[ 26948  26945  26946 ...  32710  32709  32711]
 [490905 490899 490901 ... 482544 482563 482521]
 [ 64774  64775  64773 ...  85740  85742  85742]
 [475718 475718 475735 ... 469715 469718 469676]]
Creating RawArray with float64 data, n_channels=4, n_times=85057
    Range : 0 ... 85056 =      0.000 ...   680.448 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 207 samples (1.656 sec)

Fitting ICA to data using 4 channels (please be patient, this may take a while)
Selecting by number: 4 components

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s finished


[[ 41793  41809  41821 ...  23628  23635  23672]
 [319562 323270 322219 ... 303898 306789 305322]
 [ 25925  25918  25923 ...   3666   3651   3669]
 [340254 341793 341072 ... 323353 327476 326396]]
Creating RawArray with float64 data, n_channels=4, n_times=73289
    Range : 0 ... 73288 =      0.000 ...   586.304 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 207 samples (1.656 sec)

Fitting ICA to data using 4 channels (please be patient, this may take a while)
Selecting by number: 4 components

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s finished


[[   440    441    442 ...  -8165  -8144  -8148]
 [265566 265493 265522 ... 261730 261829 261852]
 [ 32839  32841  32847 ...  22267  22259  22266]
 [255297 255272 255290 ... 255293 255435 255410]]
Creating RawArray with float64 data, n_channels=4, n_times=77889
    Range : 0 ... 77888 =      0.000 ...   623.104 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 207 samples (1.656 sec)

Fitting ICA to data using 4 channels (please be patient, this may take a while)
Selecting by number: 4 components

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s finished


[[  17699   17704   17702 ... -565656 -560247 -554532]
 [ 292520  293178  291859 ... -565432 -560025 -554313]
 [  -4668   -4659   -4661 ...  749999  749999  749999]
 [ 324828  324852  324816 ... -566229 -560815 -555095]]
Creating RawArray with float64 data, n_channels=4, n_times=84800
    Range : 0 ... 84799 =      0.000 ...   678.392 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 207 samples (1.656 sec)

Fitting ICA to data using 4 channels (please be patient, this may take a while)
Selecting

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s finished


Selecting by number: 4 components
Fitting ICA took 0.1s.
Applying ICA to Raw instance
    Transforming to ICA space (4 components)
    Zeroing out 0 ICA components
    Projecting back using 4 PCA components
Not setting metadata
629 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 629 events and 126 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows


In [15]:
y_train.shape

(583,)

### Modeling and training

In [58]:
batch_size = 36
input_size = X_train.shape[-1]
output_size = 4
hidden_size = 36
num_epochs = 50
learning_rate = 0.0005
lstm_num_layers = 2

In [59]:
class EEGFocusLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(EEGFocusLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc1 = nn.Linear(hidden_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, output_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # Set initial hidden and cell states 
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))

        # Apply fully connected layers with dropout and ReLU activation
        out = self.dropout(self.relu(self.fc1(out[:, -1, :])))
        out = self.dropout(self.relu(self.fc2(out)))
        out = self.fc3(out)

        return out

In [60]:
# create PyTorch dataset and dataloader
train_data = torch.tensor(X_train, dtype=torch.float32)
dataset = torch.utils.data.TensorDataset(train_data, torch.from_numpy(y_train).long())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [61]:
model = EEGFocusLSTM(input_size, hidden_size, lstm_num_layers, output_size)

# Define loss function and optimizer
# criterion = nn.MSELoss()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [62]:
# Training loop
for epoch in range(num_epochs):
    train_loss = 0.0
    train_acc = 0.0
    model.train()
    for i, (inputs, targets) in enumerate(dataloader):
        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        train_acc += torch.sum(preds == targets).item()
    train_loss /= len(dataloader.dataset)
    train_acc /= len(dataloader.dataset)
    print('Epoch [{}/{}], Loss: {:.4f}, train_acc: {}'.format(epoch+1, num_epochs, train_loss, train_acc))


Epoch [1/50], Loss: 0.0388, train_acc: 0.411663807890223
Epoch [2/50], Loss: 0.0381, train_acc: 0.444253859348199
Epoch [3/50], Loss: 0.0372, train_acc: 0.444253859348199
Epoch [4/50], Loss: 0.0372, train_acc: 0.444253859348199
Epoch [5/50], Loss: 0.0357, train_acc: 0.4751286449399657
Epoch [6/50], Loss: 0.0344, train_acc: 0.5265866209262435
Epoch [7/50], Loss: 0.0329, train_acc: 0.548885077186964
Epoch [8/50], Loss: 0.0295, train_acc: 0.5900514579759862
Epoch [9/50], Loss: 0.0309, train_acc: 0.5643224699828473
Epoch [10/50], Loss: 0.0293, train_acc: 0.5763293310463122
Epoch [11/50], Loss: 0.0290, train_acc: 0.5866209262435678
Epoch [12/50], Loss: 0.0299, train_acc: 0.5694682675814752
Epoch [13/50], Loss: 0.0271, train_acc: 0.5969125214408233
Epoch [14/50], Loss: 0.0267, train_acc: 0.6243567753001715
Epoch [15/50], Loss: 0.0269, train_acc: 0.62778730703259
Epoch [16/50], Loss: 0.0259, train_acc: 0.6157804459691252
Epoch [17/50], Loss: 0.0264, train_acc: 0.6397941680960549
Epoch [18/50]

### Evaluation

In [63]:
X_test, y_test = [], []
for file in test_file_paths:
    samples, scores = extract_data_from_csv(file)
    X_test.append(samples)
    y_test.append(scores)

[[  3931   3928   3933 ...   1442   1452   1449]
 [287429 287614 287570 ... 297867 297499 297555]
 [ 54069  54065  54072 ...  52651  52659  52653]
 [319177 319210 319210 ... 328467 328475 328459]]
Creating RawArray with float64 data, n_channels=4, n_times=78829
    Range : 0 ... 78828 =      0.000 ...   630.624 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 207 samples (1.656 sec)

Fitting ICA to data using 4 channels (please be patient, this may take a while)
Selecting by number: 4 components

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s finished


0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows


<b>Pearson correlation coefficient</b> can captures the degree to which two variables move in the same direction. If two variables are positively correlated, then an increase in one variable is associated with an increase in the other variable, while a decrease in one variable is associated with a decrease in the other variable. This indicates a trend of movement in the same direction.<br>
So in this case, <b>to measure the trend increasing from eye closed to immersion</b>, I'm going to use Pearson correlation coefficient.

In [64]:
model.eval()
for idx in range(len(X_test)):
    print("File: ", test_file_paths[idx])
    output = model(torch.tensor(X_test[idx], dtype=torch.float32))
    _, preds = torch.max(output, 1)
    corr_coef, p_value = pearsonr(preds, y_test[idx])

    print("Pearson correlation coefficient:", corr_coef)
    print("p-value:", p_value)

File:  Focus_data_for testing/Copy of eeg_focus_6sessions_cut_S4.csv
Pearson correlation coefficient: 0.6472323475262727
p-value: 9.92578343668047e-09
