### Import Libraries

In [126]:
%load_ext autoreload
%autoreload 2
import numpy as np
from pathlib import Path
import tensorflow as tf
from shared.models import SAT1Start, ShallowConvNet, EEGNet
from shared.generators import SAT1DataGenerator
from shared.utilities import earlyStopping_cb
from sklearn.metrics import classification_report

%env TF_FORCE_GPU_ALLOW_GROWTH=true
%env TF_GPU_ALLOCATOR=cuda_malloc_async

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
env: TF_FORCE_GPU_ALLOW_GROWTH=true
env: TF_GPU_ALLOCATOR=cuda_malloc_async


### Set up data

In [127]:
data_name = 'data21-25_AC.npz'
data_path = Path('data/sat1') / data_name

with np.load(data_path) as f_data:
    data = f_data['data']
    labels = f_data['labels']
    participants = f_data['participants']
    
data = data.reshape(-1, 30, 210, 1)

In [139]:
categories = sorted(list(set(labels.flatten())))
x_train = data[np.isin(participants, ['0021', '0022', '0023', '0024'])]
y_train = labels[np.isin(participants, ['0021', '0022', '0023', '0024'])]

x_test = data[np.isin(participants, ['0025'])]
y_test = labels[np.isin(participants, ['0025'])]

train_gen = SAT1DataGenerator(x_train, y_train)
test_gen = SAT1DataGenerator(x_test, y_test)

In [140]:
# # Run before re-training to clear up VRAM
import gc
gc.collect()
tf.keras.backend.clear_session()
del model

In [141]:
# Instantiate model for 210 sample length, 30 electrodes/features, and 4 classes
model = SAT1Start(30, 210, 4)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 30, 210, 1)]      0         
                                                                 
 conv2d (Conv2D)             (None, 30, 206, 16)       96        
                                                                 
 dropout (Dropout)           (None, 30, 206, 16)       0         
                                                                 
 conv2d_1 (Conv2D)           (None, 26, 202, 32)       12832     
                                                                 
 dropout_1 (Dropout)         (None, 26, 202, 32)       0         
                                                                 
 conv2d_2 (Conv2D)           (None, 22, 202, 64)       10304     
                                                                 
 batch_normalization (BatchN  (None, 22, 202, 64)      256   

In [142]:
# model.fit(train_gen,
#           epochs=20,
#           # TODO: Create val_gen (use 6th participant?)
#           validation_data=test_gen,
#           callbacks=[earlyStopping_cb])
model.fit(train_gen,
          epochs=20,
          # TODO: Create val_gen (use 6th participant?)
          validation_data=test_gen)

Epoch 1/20


2023-05-30 12:10:15.275740: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype int32
	 [[{{node Placeholder/_0}}]]
2023-05-30 12:10:15.722538: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inmodel/dropout/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/20
 1/60 [..............................] - ETA: 0s - loss: 1.4746 - accuracy: 0.2500

2023-05-30 12:10:16.962806: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype int32
	 [[{{node Placeholder/_0}}]]


Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x7f4cb40e18d0>

In [144]:
print('Testset results')
predicted_classes = np.argmax(model.predict(test_gen), axis=1)
print(classification_report(test_gen.labels_cat, predicted_classes))
print(test_gen.categories)

Testset results
 1/16 [>.............................] - ETA: 0s

2023-05-30 12:10:58.913476: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype int32
	 [[{{node Placeholder/_0}}]]


              precision    recall  f1-score   support

           0       0.58      0.36      0.44        81
           1       0.88      0.13      0.23        52
           2       0.44      0.95      0.60        39
           3       0.63      0.86      0.73        83

    accuracy                           0.56       255
   macro avg       0.63      0.57      0.50       255
weighted avg       0.64      0.56      0.52       255

['2', '3', 'motor', 'perception']


In [125]:
test_gen.labels

array(['motor', '3', '3', 'perception', 'perception', '3', 'perception',
       'motor', '2', '3', '2', '3', 'perception', '2', 'perception', '2',
       '2', 'motor', '3', 'perception', '2', 'motor', '2', '3', '2', '2',
       'motor', '3', '3', 'perception', '2', 'motor', 'perception',
       'perception', 'motor', '2', '2', '3', '3', '2', '3', '3', '2',
       'perception', '2', '3', 'perception', 'perception', '2', 'motor',
       'perception', '3', 'perception', 'motor', 'motor', 'perception',
       '2', '2', '2', 'perception', 'perception', '2', '2', 'perception',
       '2', '2', 'motor', '2', 'perception', '2', 'perception', '3', '2',
       'perception', '3', 'perception', 'perception', '2', '3', '2', '3',
       '2', '3', '2', 'perception', '2', 'perception', 'perception', '2',
       'motor', 'perception', '3', 'motor', 'motor', 'perception',
       'motor', 'perception', '3', 'motor', '3', 'perception', '3', '3',
       '2', 'motor', '3', '2', '2', 'perception', 'motor', '