In [1]:
from model.cct import CCT
from torchinfo import summary

import warnings
warnings.filterwarnings("ignore")

In [2]:
import argparse
import os
import math
import glob
import random
import datetime
import time
import sys
import scipy.io

import torch 
import itertools
import pandas as pd 
import pickle 
import numpy as np
import mne
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch import Tensor

from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce

from torch.backends import cudnn
cudnn.benchmark = False
cudnn.deterministic = True

In [3]:
model = CCT(kernel_sizes=[(22, 1), (1, 24), (1, 24)], stride=(1, 1), padding=(0, 0),
            pooling_kernel_size=(3, 3), pooling_stride=(1, 1), pooling_padding=(0, 0),
            n_conv_layers=3, n_input_channels=1,
            in_planes=64, activation=None, # ReLU
            max_pool=False, conv_bias=False,
            dim=64, num_layers=3,
            num_heads=4, num_classes=2, 
            attn_dropout=0.1, dropout=0.1, 
            mlp_size=64, positional_emb="learnable")

In [4]:
summary(model=model,
        input_size=(64, 1, 22, 321),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                            Input Shape          Output Shape         Param #              Trainable
CCT (CCT)                                          [64, 1, 22, 321]     [64, 2]              --                   True
├─Tokenizer (tokenizer)                            [64, 1, 22, 321]     [64, 275, 64]        --                   True
│    └─Sequential (conv_layers)                    [64, 1, 22, 321]     [64, 64, 1, 275]     --                   True
│    │    └─Sequential (0)                         [64, 1, 22, 321]     [64, 64, 1, 321]     1,408                True
│    │    └─Sequential (1)                         [64, 64, 1, 321]     [64, 64, 1, 298]     98,304               True
│    │    └─Sequential (2)                         [64, 64, 1, 298]     [64, 64, 1, 275]     98,304               True
│    └─Flatten (flattener)                         [64, 64, 1, 275]     [64, 64, 275]        --                   --
├─Transformer (transformer)                  

## Dataset Loading

In [5]:
datasets = ['datasets/aBNCI2014001R.pickle', 'datasets/aBNCI2014004R.pickle']

In [6]:
import torch 
import itertools
import pandas as pd 
import pickle 
import numpy as np
import mne

In [7]:
def load_data(filename):
    with open(filename, 'rb') as handle:
        data = pickle.load(handle)
    return data

data = load_data(datasets[0])

In [8]:
class_name = ['left_hand', 'right_hand']
subject = 0
s1 = data[subject]
s1.get_data().shape

(288, 22, 321)

In [9]:
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if dev.type == 'cuda':
    print('Your GPU device name :', torch.cuda.get_device_name()) 

Your GPU device name : NVIDIA GeForce RTX 3060 Laptop GPU


In [10]:
from sklearn.model_selection import train_test_split
from torch.nn.functional import softmax

In [11]:
class EEGCCT():
    def __init__(self, nsub, n_subj=9):
        super(ExP, self).__init__()
        self.batch_size = 36
        self.n_epochs = 25  #2000
        self.c_dim = 4
        self.lr = 3e-5
        self.b1 = 0.9
        self.b2 = 0.999
        self.dimension = (190, 50)
        self.nSub = nsub
        self.n_subjects = 8 # total?
        self.start_epoch = 0

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor
        self.FloatTensor = torch.cuda.FloatTensor

        self.criterion_l1 = torch.nn.L1Loss().cuda()
        self.criterion_l2 = torch.nn.MSELoss().cuda()
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.model = model.cuda()
        if torch.cuda.is_available():
            self.model = self.model.to("cuda")
            
        self.total_params = sum(param.numel() for param in self.model.parameters())
        print("Number of parameters: ", self.total_params)

        #self.model = self.model.cuda()
        # summary(self.model, (1, 22, 1000))


    # Segmentation and Reconstruction (S&R) data augmentation
    def interaug(self, timg, label):  
        aug_data = []
        aug_label = []
        for cls4aug in range(2):
            cls_idx = np.where(label == cls4aug + 1)
            tmp_data = timg[cls_idx]
            tmp_label = label[cls_idx]

            tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, 22, 321))
            for ri in range(int(self.batch_size / 2)):
                for rj in range(3):
                    rand_idx = np.random.randint(0, tmp_data.shape[0], 3)
                    tmp_aug_data[ri, :, :, rj * 107:(rj + 1) * 107] = tmp_data[rand_idx[rj], :, :,
                                                                      rj * 107:(rj + 1) * 107]

            aug_data.append(tmp_aug_data)
            aug_label.append(tmp_label[:int(self.batch_size / 2)])
        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle, :, :]
        aug_label = aug_label[aug_shuffle]

        aug_data = torch.from_numpy(aug_data).cuda()
        aug_data = aug_data.float()
        aug_label = torch.from_numpy(aug_label-1).cuda()
        aug_label = aug_label.long()
        return aug_data, aug_label

    def get_source_data(self):
        
        self.test_subject = self.nSub

        # Get the data from the epochs object
        self.data = load_data(datasets[0])
        print('Dataset: ', datasets[0])

        self.train_subjects = [i for i in range(self.n_subjects) if i != self.test_subject]

        # Prepare test data
        self.X_test = self.data[self.test_subject].get_data()
        self.y_test = self.data[self.test_subject].events[:, -1]

        # Prepare training data
        self.X_train = np.concatenate([self.data[i].get_data() for i in self.train_subjects], axis=0)
        self.y_train = np.concatenate([self.data[i].events[:, -1] for i in self.train_subjects], axis=0)

        # train and val data
        self.train_data, self.val_data, self.train_label, self.val_label = train_test_split(self.X_train, self.y_train, test_size=0.1, random_state=42)
        
        self.allData = np.expand_dims(self.train_data, axis=1)
        self.allLabel = self.train_label
        
        self.valData = np.expand_dims(self.val_data, axis=1)
        self.valLabel = self.val_label

        shuffle_num = np.random.permutation(len(self.allData))
        self.allData = self.allData[shuffle_num, :, :, :]
        self.allLabel = self.allLabel[shuffle_num]

        # test data  
        self.testData = np.expand_dims(self.X_test, axis=1)
        self.testLabel = self.y_test
        
        # standardize
        target_mean = np.mean(self.allData)
        target_std = np.std(self.allData)
        self.allData = (self.allData - target_mean) / target_std
        self.testData = (self.testData - target_mean) / target_std
        self.valData = (self.valData - target_mean) / target_std

        # data shape: (trial, conv channel, electrode channel, time samples)
        return self.allData, self.allLabel, self.valData, self.valLabel, self.testData, self.testLabel

    def train(self):
        train_accuracies = []
        val_accuracies = []
        train_losses = []
        val_losses = []

        #img, label, test_data, test_label = self.get_source_data()
        img, label, val_data, val_label, test_data, test_label = self.get_source_data()

        img = torch.from_numpy(img)
        label = torch.from_numpy(label - 1)
        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        val_data = torch.from_numpy(val_data)
        val_label = torch.from_numpy(val_label - 1)
        val_dataset = torch.utils.data.TensorDataset(val_data, val_label)
        self.val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizers
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        test_data = Variable(test_data.type(self.Tensor))
        test_label = Variable(test_label.type(self.LongTensor))
        
        val_data = Variable(val_data.type(self.Tensor))
        val_label = Variable(val_label.type(self.LongTensor))
        
        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        # Train the cnn model
        total_step = len(self.dataloader)
        curr_lr = self.lr

        for e in range(self.n_epochs):
            # in_epoch = time.time()
            self.model.train()
            for i, (img, label) in enumerate(self.dataloader):

                img = Variable(img.cuda().type(self.Tensor))
                label = Variable(label.cuda().type(self.LongTensor)) #FloatTensor

                # data augmentation
                aug_data, aug_label = self.interaug(self.allData, self.allLabel)
                img = torch.cat((img, aug_data))
                label = torch.cat((label, aug_label))

                outputs = self.model(img)

                loss = self.criterion_cls(outputs, label) 

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # test process
            if (e + 1) % 1 == 0:
                self.model.eval()
                Cls = self.model(test_data)
                probs = softmax(Cls, dim=1).cpu().detach().numpy()
                loss_test = self.criterion_cls(Cls, test_label)
                y_pred = torch.max(Cls, 1)[1]
                acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))

                #self.model.eval()
                ValCls = self.model(val_data)
                loss_val = self.criterion_cls(ValCls, val_label)
                val_pred = torch.max(ValCls, 1)[1]
                val_acc = float((val_pred == val_label).cpu().numpy().astype(int).sum()) / float(val_label.size(0))
                
                train_pred = torch.max(outputs, 1)[1]
                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
                
                print('Epoch:', e,
                      '  Train loss: %.4f' % loss.detach().cpu().numpy(),
                      '  Val loss: %.4f' % loss_val.detach().cpu().numpy(),
                      '  Test loss: %.4f' % loss_test.detach().cpu().numpy(),
                      '  Train acc: %.4f' % train_acc,
                      '  Val acc: %.4f' % val_acc,
                      '  Test acc: %.4f' % acc)

                num = num + 1
                averAcc = averAcc + acc
                if acc > bestAcc:
                    bestAcc = acc
                    Y_true = test_label
                    Y_pred = y_pred
            
            train_accuracies.append(train_acc)
            val_accuracies.append(val_acc)
            train_losses.append(loss.detach().cpu().numpy())
            val_losses.append(loss_val.detach().cpu().numpy())

        #torch.save(self.model.module.state_dict(), 'model.pth')
        averAcc = averAcc / num
        print('The average accuracy is:', averAcc)
        print('The best accuracy is:', bestAcc)
        
        return bestAcc, averAcc, Y_true, Y_pred

In [12]:
def main():
    best = 0
    aver = 0

    for i in range(9):
        starttime = datetime.datetime.now()


        seed_n = np.random.randint(2021)
        print('seed is ' + str(seed_n))
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)


        print('Subject %d' % (i+1))
        exp = EEGCCT(i)

        bestAcc, averAcc, Y_true, Y_pred = exp.train()
        print('THE BEST ACCURACY IS ' + str(bestAcc))

        endtime = datetime.datetime.now()
        print('subject %d duration: '%(i+1) + str(endtime - starttime))
        best = best + bestAcc
        aver = aver + averAcc
        if i == 0:
            yt = Y_true
            yp = Y_pred
        else:
            yt = torch.cat((yt, Y_true))
            yp = torch.cat((yp, Y_pred))


    best = best / 9
    aver = aver / 9
    
    print(f"Mean of best is {best}")
    print(f"Mean of average is {aver}")

In [13]:
main()

seed is 1136
Subject 1
Number of parameters:  290627
Dataset:  datasets/aBNCI2014001R.pickle
Epoch: 0   Train loss: 0.6863   Val loss: 0.6850   Test loss: 0.6868   Train acc: 0.6000   Val acc: 0.5198   Test acc: 0.5382
Epoch: 1   Train loss: 0.6318   Val loss: 0.6358   Test loss: 0.6390   Train acc: 0.7400   Val acc: 0.6584   Test acc: 0.6597
Epoch: 2   Train loss: 0.5974   Val loss: 0.5799   Test loss: 0.5494   Train acc: 0.6600   Val acc: 0.6931   Test acc: 0.7118
Epoch: 3   Train loss: 0.4945   Val loss: 0.5556   Test loss: 0.5290   Train acc: 0.7800   Val acc: 0.6931   Test acc: 0.6840
Epoch: 4   Train loss: 0.4559   Val loss: 0.5494   Test loss: 0.4858   Train acc: 0.8600   Val acc: 0.7327   Test acc: 0.7465
Epoch: 5   Train loss: 0.5249   Val loss: 0.5515   Test loss: 0.4798   Train acc: 0.7000   Val acc: 0.7228   Test acc: 0.7465
Epoch: 6   Train loss: 0.4858   Val loss: 0.5513   Test loss: 0.4601   Train acc: 0.7400   Val acc: 0.7228   Test acc: 0.7569
Epoch: 7   Train loss: 0.