## M. Jin, H. Chen, Z. Li, and J. Li, “Eeg-based emotion recognition using graph convolutional network with learnable electrode relations,” in 2021 43rd Annual International Conference of the IEEE Engineering in Medicine & Biology Society (EMBC). IEEE, 2021, pp. 5953–5957.

In [None]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

In [None]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
path = '/content/drive/My Drive/PROJECT/EEG/emotions.csv'

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_size = 0.7
num_epochs = 5
batch_size = 16
learning_rate = 1e-3
num_classes = 3

# RNN model
input_size = 1
hidden_size = 256
num_layers = 2

In [None]:
df = pd.read_csv(path)
df.head()

Unnamed: 0,# mean_0_a,mean_1_a,mean_2_a,mean_3_a,mean_4_a,mean_d_0_a,mean_d_1_a,mean_d_2_a,mean_d_3_a,mean_d_4_a,...,fft_741_b,fft_742_b,fft_743_b,fft_744_b,fft_745_b,fft_746_b,fft_747_b,fft_748_b,fft_749_b,label
0,4.62,30.3,-356.0,15.6,26.3,1.07,0.411,-15.7,2.06,3.15,...,23.5,20.3,20.3,23.5,-215.0,280.0,-162.0,-162.0,280.0,NEGATIVE
1,28.8,33.1,32.0,25.8,22.8,6.55,1.68,2.88,3.83,-4.82,...,-23.3,-21.8,-21.8,-23.3,182.0,2.57,-31.6,-31.6,2.57,NEUTRAL
2,8.9,29.4,-416.0,16.7,23.7,79.9,3.36,90.2,89.9,2.03,...,462.0,-233.0,-233.0,462.0,-267.0,281.0,-148.0,-148.0,281.0,POSITIVE
3,14.9,31.6,-143.0,19.8,24.3,-0.584,-0.284,8.82,2.3,-1.97,...,299.0,-243.0,-243.0,299.0,132.0,-12.4,9.53,9.53,-12.4,POSITIVE
4,28.3,31.3,45.2,27.3,24.5,34.8,-5.79,3.06,41.4,5.52,...,12.0,38.1,38.1,12.0,119.0,-17.6,23.9,23.9,-17.6,NEUTRAL


In [None]:
# Spliting X and y to train and test data
def split_data(X, y, train_size):
    train_size = int(len(X) * train_size)
    X_train = X[:train_size]
    y_train = y[:train_size]

    X_test = X[train_size:]
    y_test = y[train_size:]

    return X_train, X_test, y_train, y_test

In [None]:
# Creating X and y and replacing labels
def preprcess_data(df, train_size=0.7):
    df = df.copy()

    y = df['label'].copy()
    y = y.replace(labels)

    X = df.drop('label', axis=1).copy()


    X_train, X_test, y_train, y_test = split_data(X, y, train_size)
    return X_train, X_test, y_train, y_test

In [None]:
X_train, X_test, y_train, y_test = preprcess_data(df, train_size)

  y = y.replace(labels)


In [None]:
# Viewing X
X_train

Unnamed: 0,# mean_0_a,mean_1_a,mean_2_a,mean_3_a,mean_4_a,mean_d_0_a,mean_d_1_a,mean_d_2_a,mean_d_3_a,mean_d_4_a,...,fft_740_b,fft_741_b,fft_742_b,fft_743_b,fft_744_b,fft_745_b,fft_746_b,fft_747_b,fft_748_b,fft_749_b
0,4.6200,30.3,-356.0,15.60,26.3,1.070,0.411,-15.7000,2.060,3.150,...,74.30,23.5,20.3,20.3,23.5,-215.0,280.00,-162.00,-162.00,280.00
1,28.8000,33.1,32.0,25.80,22.8,6.550,1.680,2.8800,3.830,-4.820,...,130.00,-23.3,-21.8,-21.8,-23.3,182.0,2.57,-31.60,-31.60,2.57
2,8.9000,29.4,-416.0,16.70,23.7,79.900,3.360,90.2000,89.900,2.030,...,-534.00,462.0,-233.0,-233.0,462.0,-267.0,281.00,-148.00,-148.00,281.00
3,14.9000,31.6,-143.0,19.80,24.3,-0.584,-0.284,8.8200,2.300,-1.970,...,-183.00,299.0,-243.0,-243.0,299.0,132.0,-12.40,9.53,9.53,-12.40
4,28.3000,31.3,45.2,27.30,24.5,34.800,-5.790,3.0600,41.400,5.520,...,114.00,12.0,38.1,38.1,12.0,119.0,-17.60,23.90,23.90,-17.60
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1487,9.1300,26.5,-231.0,6.53,26.4,6.230,3.190,9.7900,0.352,4.480,...,-652.00,508.0,-244.0,-244.0,508.0,-62.5,128.00,-51.60,-51.60,128.00
1488,26.1000,32.3,28.4,24.90,28.4,3.020,0.444,3.7100,2.720,-2.440,...,9.23,-41.0,-56.7,-56.7,-41.0,-10.6,-9.68,-138.00,-138.00,-9.68
1489,13.5000,31.1,-481.0,8.86,25.2,-1.050,-0.428,25.5000,2.030,0.315,...,-533.00,506.0,-252.0,-252.0,506.0,-444.0,461.00,-221.00,-221.00,461.00
1490,13.4000,18.3,-361.0,2.57,26.0,3.170,4.710,-0.0477,-0.202,-4.410,...,-289.00,284.0,-52.1,-52.1,284.0,-229.0,209.00,-61.50,-61.50,209.00


In [None]:
class EEGBrainWavePreTrain(Dataset):
    def __init__(self, X):
        # x.shape: (N, 2548)
        self.X = torch.from_numpy(X.to_numpy()[:, :-1].astype(np.float32)) # (N, 2547)
        self.y = torch.from_numpy(X.to_numpy()[:, -1].astype(np.float32)) # (N)
        self.y = self.y.view(-1, 1) # (N, 1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
class EEGBrainWaveFineTune(Dataset):
    def __init__(self, X, y):
        self.X = torch.from_numpy(X.to_numpy().astype(np.float32))
        self.y = torch.from_numpy(y.to_numpy())

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [None]:
train_data = EEGBrainWavePreTrain(X_train)
test_data =  EEGBrainWavePreTrain(X_test)
train_data_fn = EEGBrainWaveFineTune(X_train, y_train)
test_data_fn =  EEGBrainWaveFineTune(X_test, y_test)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=1)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=1)

In [None]:
train_loader_fn = DataLoader(train_data_fn, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=1)
test_loader_fn = DataLoader(test_data_fn, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=1)
X_train.shape

(1492, 2548)

In [None]:
# Resize Layer
resize_and_rescale = tf.keras.Sequential([
  layers.experimental.preprocessing.Resizing(224,224),
  layers.experimental.preprocessing.Rescaling(1./255),
  layers.experimental.preprocessing.RandomFlip("horizontal"),
  layers.experimental.preprocessing.RandomRotation(0.1),
  layers.experimental.preprocessing.RandomZoom(0.1),
  layers.experimental.preprocessing.RandomContrast(0.1),
])


In [None]:
# Load the pretained model
pretrained_model = tf.keras.applications.efficientnet.EfficientNetB0(
    input_shape=(224, 224, 3),
    include_top=False,
    weights='imagenet',
    pooling='max'
)

pretrained_model.trainable = False

In [None]:
# Create checkpoint callback
checkpoint_path = "grape_disease_classification_model_checkpoint"
checkpoint_callback = ModelCheckpoint(checkpoint_path,
                                      save_weights_only=True,
                                      monitor="val_accuracy",
                                      save_best_only=True)

In [None]:
# Setup EarlyStopping callback to stop training if model's val_loss doesn't improve for 3 epochs
early_stopping = EarlyStopping(monitor = "val_loss", # watch the val loss metric
                               patience = 5,
                               restore_best_weights = True) # if val loss decreases for 3 epochs in a row, stop training

In [None]:
inputs = pretrained_model.input
x = resize_and_rescale(inputs)

x = Dense(128, activation='relu')(pretrained_model.output)
x = Dropout(0.45)(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.45)(x)


outputs = Dense(4, activation='softmax')(x)

model = Model(inputs=inputs, outputs=outputs)

model.compile(
    optimizer=Adam(0.00001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

history = model.fit(
    train_images,
    steps_per_epoch=len(train_images),
    validation_data=val_images,
    validation_steps=len(val_images),
    epochs=100,
    callbacks=[
        early_stopping,
        create_tensorboard_callback("training_logs",
                                    "grape_classification"),
        checkpoint_callback,
    ]
)

In [1]:
model.evaluate(test_images, verbose=0)

Epoch 1/16
60/60 ━━━━━━━━━━━━━━━━━━━━ 140s 2s/step - accuracy: 0.4213 - loss: 0.4779 - val_accuracy: 0.4006 - val_loss: 0.4791
Epoch 2/16
60/60 ━━━━━━━━━━━━━━━━━━━━ 140s 2s/step - accuracy: 0.4513 - loss: 0.4779 - val_accuracy: 0.4506 - val_loss: 0.4791
Epoch 3/16
60/60 ━━━━━━━━━━━━━━━━━━━━ 141s 2s/step - accuracy: 0.4833 - loss: 0.4724 - val_accuracy: 0.4406 - val_loss: 0.4960
Epoch 4/16
60/60 ━━━━━━━━━━━━━━━━━━━━ 162s 2s/step - accuracy: 0.5093 - loss: 0.4703 - val_accuracy: 0.4806 - val_loss: 0.4681
Epoch 5/16
60/60 ━━━━━━━━━━━━━━━━━━━━ 183s 2s/step - accuracy: 0.5239 - loss: 0.4607 - val_accuracy: 0.5006 - val_loss: 0.4683
Epoch 6/16
60/60 ━━━━━━━━━━━━━━━━━━━━ 142s 2s/step - accuracy: 0.5404 - loss: 0.4954 - val_accuracy: 0.5206 - val_loss: 0.4995
Epoch 7/16
60/60 ━━━━━━━━━━━━━━━━━━━━ 129s 2s/step - accuracy: 0.6041 - loss: 0.4488 - val_accuracy: 0.5506 - val_loss: 0.5774
Epoch 8/16
60/60 ━━━━━━━━━━━━━━━━━━━━ 143s 2s/step - accuracy: 0.6229 - loss: 0.4863 - val_accuracy: 0.5806 - v