In [2]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import random
from scipy.io import loadmat
from scipy import signal
from google.colab import drive
import warnings
import string
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from scipy import signal
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import scale
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import TensorDataset
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import wandb
from pytorch_lightning.loggers import WandbLogger

## Load data

In [3]:
SUBJECT_SELECTED = "B"
subject_names = ["A", "B"]

# Check for errors in the settings
if SUBJECT_SELECTED not in subject_names:
    raise ValueError("SUBJECT_SELECTED value {} is invalid.\nPlease enter one of the following parameters {}".format(SUBJECT_SELECTED, subject_names))

# Google drive data paths
MODEL_LOCATIONS_FILE_PATH = '/content/' + SUBJECT_SELECTED
SUBJECT_TRAIN_FILE_PATH = '/content/Subject_' + SUBJECT_SELECTED + '_Train.mat'
SUBJECT_TEST_FILE_PATH = '/content/Subject_' + SUBJECT_SELECTED + '_Test.mat'
CHANNEL_LOCATIONS_FILE_PATH = '/content/channels.csv'
CHANNEL_COORD = '/content/coordinates.csv'

# Channel selection
CHANNELS = [i for i in range(64)]


In [4]:
data = loadmat(SUBJECT_TRAIN_FILE_PATH)

signals = data['Signal']
flashing = data['Flashing']
stimulus = data['StimulusType']
word = data['TargetChar']

SAMPLING_FREQUENCY = 240

REPETITIONS = 15

RECORDING_DURATION = (len(signals))*(len(signals[0]))/(SAMPLING_FREQUENCY*60)

TRIALS = len(word[0])

BALANCE_DATASET = False

print("Sampling Frequency: %d Hz [%.2f ms]" % (SAMPLING_FREQUENCY, (1000/SAMPLING_FREQUENCY)))
print("Session duration:   %.2f min" % RECORDING_DURATION)
print("Number of letters:  %d" % TRIALS)
print("Spelled word:       %s" % ''.join(word))

Sampling Frequency: 240 Hz [4.17 ms]
Session duration:   46.01 min
Number of letters:  85
Spelled word:       VGREAAH8TVRHBYN_UGCOLO4EUERDOOHCIFOMDNU6LQCPKEIREKOYRQIDJXPBKOJDWZEUEWWFOEBHXTQTTZUMO


## Processing the training set data
1. Add bandpassfilter (set to 0.1-20Hz)
2. Reduce sampling frequency to 120Hz
3. Obtain 650ms windows at the start of each blink (175ms)
4. Normalize the samples in each window
5. Reshape each window into a 3D tensor of size: **(N_SAMPLES, 78, 64)**
6. Obtain class ratio **(noP300 / P300)** to balance the dataset during training

In [5]:
# Add filter butterworth 
b, a = signal.butter(4, [0.1/SAMPLING_FREQUENCY, 20/SAMPLING_FREQUENCY], 'bandpass')
for trial in range(TRIALS):
    signals[trial, :, :] = signal.filtfilt(b, a, signals[trial, :, :], axis=0)

In [6]:
# Down sampling frequency to 120hz
DOWNSAMPLING_FREQUENCY = 120
SCALE_FACTOR = round(SAMPLING_FREQUENCY / DOWNSAMPLING_FREQUENCY)
SAMPLING_FREQUENCY = DOWNSAMPLING_FREQUENCY

print("# Samples of EEG signals before downsampling: {}".format(len(signals[0])))

signals = signals[:, 0:-1:SCALE_FACTOR, :]
flashing = flashing[:, 0:-1:SCALE_FACTOR]
stimulus = stimulus[:, 0:-1:SCALE_FACTOR]

print("# Samples of EEG signals after downsampling: {}".format(len(signals[0])))

# Samples of EEG signals before downsampling: 7794
# Samples of EEG signals after downsampling: 3897


In [7]:
# Numbers channels of EEG
N_CHANNELS = len(CHANNELS)

# Windown duration
WINDOW_DURATION = 650

#Numbers of window samples
WINDOW_SAMPLES = round(SAMPLING_FREQUENCY * (WINDOW_DURATION / 1000))

# Numbers of samples per trial
SAMPLES_PER_TRIAL = len(signals[0])

# feature and labels
train_features = []
train_labels = []

count_positive = 0
count_negative = 0

for trial in range(TRIALS):
    for sample in (range(SAMPLES_PER_TRIAL)):
        if (sample == 0) or (flashing[trial, sample-1] == 0 and flashing[trial, sample] == 1):
            lower_sample = sample
            upper_sample = sample + WINDOW_SAMPLES
            window = signals[trial, lower_sample:upper_sample, :]
            train_features.append(window)
            if stimulus[trial, sample] == 1:
                count_positive += 1
                train_labels.append(1) 
            else:
                count_negative += 1
                train_labels.append(0) 

train_ratio = count_negative/count_positive

train_features = np.array(train_features)
train_labels = np.array(train_labels)

dim_train = train_features.shape
print("Features tensor shape: {}".format(dim_train))

Features tensor shape: (15300, 78, 64)


In [8]:
# Nomalization Zi = (Xi - mu) / sigma
for pattern in range(len(train_features)):
    train_features[pattern] = scale(train_features[pattern], axis=0)

[1;30;43mKết quả truyền trực tuyến bị cắt bớt đến 5000 dòng cuối.[0m


## Test set data processing
1. Apply bandpass filter **(0.1-20Hz)**;
2. Reduce the sample signal from 240Hz to 120Hz;
3. Obtain 650ms windows at the start of each blink (175ms)
4. Normalize the samples in each window
5. Reshape each window into a 3D tensor of size: **(N_SAMPLES, 78, 64)**
6. Calculate weight vector to balance the importance of samples and estimate accuracy

In [9]:
data_test = loadmat(SUBJECT_TEST_FILE_PATH)

signals_test = data_test['Signal']
flashing_test = data_test['Flashing']
word_test =  data_test['TargetChar']
stimulus_code_test = data_test['StimulusCode']

SAMPLING_FREQUENCY = 240
REPETITIONS = 15
RECORDING_DURATION_TEST = (len(signals_test))*(len(signals_test[0]))/(SAMPLING_FREQUENCY*60)
TRIALS_TEST = len(word_test[0])
SAMPLES_PER_TRIAL_TEST = len(signals_test[0])

print("Sampling Frequency: %d Hz [%.2f ms]" % (SAMPLING_FREQUENCY, (1000/SAMPLING_FREQUENCY)))
print("Session duration:   %.2f min" % RECORDING_DURATION_TEST)
print("Number of letters:  %d" % TRIALS_TEST)
print("Spelled word:       %s" % ''.join(word_test))

Sampling Frequency: 240 Hz [4.17 ms]
Session duration:   54.12 min
Number of letters:  100
Spelled word:       MERMIROOMUHJPXJOHUVLEORZP3GLOO7AUFDKEFTWEOOALZOP9ROCGZET1Y19EWX65QUYU7NAK_4YCJDVDNGQXODBEV2B5EFDIDNR


In [10]:
# Create characters matrix
char_matrix = [[0 for j in range(6)] for i in range(6)]
s = string.ascii_uppercase + '1' + '2' + '3' + '4' + '5' + '6' + '7' + '8' + '9' + '_'

# Append cols and rows in a list
list_matrix = []
for i in range(6):
    col = [s[j] for j in range(i, 36, 6)]
    list_matrix.append(col)
for i in range(6):
    row = [s[j] for j in range(i * 6, i * 6 + 6)]
    list_matrix.append(row)

# Create StimulusType array for the test set (missing from the given database)
stimulus_test = [[0 for j in range(SAMPLES_PER_TRIAL_TEST)] for i in range(TRIALS_TEST)]
stimulus_test = np.array(stimulus_test)

for trial in range(TRIALS_TEST):
    counter=0
    for sample in range(SAMPLES_PER_TRIAL_TEST):
        index = int(stimulus_code_test[trial, sample]) - 1
        if not index == -1:
            if word_test[0][trial] in list_matrix[index]:
                stimulus_test[trial, sample] = 1
            else:
                stimulus_test[trial, sample] = 0

In [11]:
b, a = signal.butter(4, [0.1/SAMPLING_FREQUENCY, 20/SAMPLING_FREQUENCY], 'bandpass')
for trial in range(TRIALS_TEST):
    signals_test[trial, :, :] = signal.filtfilt(b, a, signals_test[trial, :, :], axis=0)

In [12]:
DOWNSAMPLING_FREQUENCY = 120
SCALE_FACTOR = round(SAMPLING_FREQUENCY / DOWNSAMPLING_FREQUENCY)
SAMPLING_FREQUENCY = DOWNSAMPLING_FREQUENCY

print("# Samples of EEG signals before downsampling: {}".format(len(signals_test[0])))

signals_test = signals_test[:, 0:-1:SCALE_FACTOR, :]
flashing_test = flashing_test[:, 0:-1:SCALE_FACTOR]
stimulus_test = stimulus_test[:, 0:-1:SCALE_FACTOR]

print("# Samples of EEG signals after downsampling: {}".format(len(signals_test[0])))

# Samples of EEG signals before downsampling: 7794
# Samples of EEG signals after downsampling: 3897


In [13]:
N_CHANNELS = len(CHANNELS)
WINDOW_DURATION = 650
WINDOW_SAMPLES = round(SAMPLING_FREQUENCY * (WINDOW_DURATION / 1000))
SAMPLES_PER_TRIAL_TEST = len(signals[0])

test_features = []
test_labels = []
windowed_stimulus = []

count_positive = 0
count_negative = 0

for trial in range(TRIALS_TEST):
    for sample in (range(SAMPLES_PER_TRIAL_TEST)):
        if (sample == 0) or (flashing_test[trial, sample-1] == 0 and flashing_test[trial, sample] == 1):
            lower_sample = sample
            upper_sample = sample + WINDOW_SAMPLES
            window = signals_test[trial, lower_sample:upper_sample, :]
            # Extracting number of row/col in a window
            number_stimulus = int(stimulus_code_test[trial, sample])
            windowed_stimulus.append(number_stimulus)
            # Features extraction
            test_features.append(window)
            # Labels extraction
            if stimulus_test[trial, sample] == 1:
                count_positive += 1
                test_labels.append(1) # Class P300
            else:
                count_negative += 1
                test_labels.append(0) # Class no-P300

# Get test weights to take into account the number of classes
test_weights = []
for i in range(len(test_labels)):
    if test_labels[i] == 1:
        test_weights.append(len(test_labels)/count_positive)
    else:
        test_weights.append(len(test_labels)/count_negative)
test_weights = np.array(test_weights)

# Convert lists to numpy arrays
test_features = np.array(test_features)
test_labels = np.array(test_labels)

# 3D tensor (SAMPLES, 64, 78)
dim_test = test_features.shape
print("Features tensor shape: {}".format(dim_test))

Features tensor shape: (18000, 78, 64)


In [14]:
# Nomarlization Zi = (Xi - mu) / sigma
for pattern in range(len(test_features)):
    test_features[pattern] = scale(test_features[pattern], axis=0)

[1;30;43mKết quả truyền trực tuyến bị cắt bớt đến 5000 dòng cuối.[0m


## Models

In [15]:
def cecotti_normal(tensor):
    if tensor.dim() == 1:
        fan_in = tensor.size(0)
    elif tensor.dim() == 2:
        fan_in = tensor.size(1)
    else:
        receptive_field_size = 1
        for dim in tensor.size()[2:]:
            receptive_field_size *= dim
        fan_in = tensor.size(1) * receptive_field_size

    stddev = 1.0 / fan_in
    with torch.no_grad():
        return tensor.normal_(0.0, stddev)


In [16]:
def scaled_tanh(z):
    return 1.7159 * torch.tanh((2.0 / 3.0) * z)

In [17]:
class CNNModel(pl.LightningModule):
    def __init__(self, channels=64, filters=10):
        super(CNNModel, self).__init__()

        self.save_hyperparameters()

        self.conv1 = nn.Conv1d(in_channels=channels, out_channels=filters, kernel_size=1, padding='same')
        self.conv2 = nn.Conv1d(in_channels=filters, out_channels=50, kernel_size=13, stride=11)
        self.fc1 = nn.Linear(50 * 6, 100)  
        self.fc2 = nn.Linear(100, 1)

        self.sigmoid = nn.Sigmoid()

        self.loss_fn = nn.BCELoss()

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                cecotti_normal(m.weight)
                if m.bias is not None:
                    cecotti_normal(m.bias)

    def forward(self, x):
      x = x.permute(0, 2, 1)

      x = scaled_tanh(self.conv1(x))
      x = scaled_tanh(self.conv2(x))
      x = torch.flatten(x, start_dim=1)
      x = scaled_tanh(self.fc1(x))
      x = self.sigmoid(self.fc2(x))
      return x


    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y.unsqueeze(1)

        y_pred = self(x)

        loss = self.loss_fn(y_pred, y)

        acc = ((y_pred > 0.5).float() == y).float().mean()
        self.log("train_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log("train_acc", acc, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y.unsqueeze(1)
        y_pred = self(x)
        loss = self.loss_fn(y_pred, y)
        acc = ((y_pred > 0.5).float() == y).float().mean()
        self.log("val_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log("val_acc", acc, on_step=True, on_epoch=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y = y.unsqueeze(1)
        y_pred = self(x)
        acc = ((y_pred > 0.5).float() == y).float().mean()
        self.log("test_acc", acc)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)


In [18]:
BATCH_SIZE = 256
EPOCHS = 200
VALID_SPLIT = 0.05

In [19]:
from sklearn.model_selection import train_test_split

train_feature, val_feature, train_label, val_label = train_test_split(
    train_features, train_labels, test_size=VALID_SPLIT, random_state=42
)

train_feature = torch.tensor(train_feature, dtype=torch.float32)
train_label = torch.tensor(train_label, dtype=torch.float32)
val_feature = torch.tensor(val_feature, dtype=torch.float32)
val_label = torch.tensor(val_label, dtype=torch.float32)

train_dataset = TensorDataset(train_feature, train_label)
val_dataset = TensorDataset(val_feature, val_label)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Number sample in train: {len(train_feature)}")
print(f"Number sample in validation: {len(val_feature)}")

Number sample in train: 14535
Number sample in validation: 765


In [20]:
from google.colab import userdata
userdata.get('wandb-key')

'3f819d8e70eee8b09b5b6e9cfa39c86d883ad8f2'

In [21]:
checkpoint = ModelCheckpoint(
    dirpath=MODEL_LOCATIONS_FILE_PATH,
    filename="model-best",
    monitor="val_loss",
    mode="min",
    save_top_k=1
)

earlystop = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=20
)

wandb_logger = WandbLogger(project="Speller-classification", name="speller", config={
    "learning_rate": 0.001,
    "batch_size": BATCH_SIZE,
    "optim": "Adam",
    "max_epochs": EPOCHS
    }
)

trainer = pl.Trainer(
    max_epochs=EPOCHS,
    callbacks=[checkpoint, earlystop],
    logger=wandb_logger
)
wandb.finish()

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [22]:
model = CNNModel(channels=64, filters=10)
trainer.fit(model, train_loader, val_loader)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33m22022638[0m ([33mxuan_nguyen_ai_uet-vietnam-national-university-hanoi[0m). Use [1m`wandb login --relogin`[0m to force relogin


/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /content/B exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type    | Params | Mode 
--------------------------------------------
0 | conv1   | Conv1d  | 650    | train
1 | conv2   | Conv1d  | 6.5 K  | train
2 | fc1     | Linear  | 30.1 K | train
3 | fc2     | Linear  | 101    | train
4 | sigmoid | Sigmoid | 0      | train
5 | loss_fn | BCELoss | 0      | train
--------------------------------------------
37.4 K    Trainable params
0         Non-trainable params
37.4 K    Total params
0.150     Total estimated model params size (MB)
6         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [23]:
test_feature = torch.tensor(test_features, dtype=torch.float32)
test_label = torch.tensor(test_labels, dtype=torch.float32)

test_dataset = TensorDataset(test_feature, test_label)

test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

trainer.test(model, dataloaders=test_loader)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_acc': 0.8525555729866028}]