In [3]:
"""
CTNet: A Convolution-Transformer Network for EEG-Based Motor Imagery Classification

author: zhaowei701@163.com

"""

import os
gpus = [0]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import numpy as np
import pandas as pd
import random
import datetime
import time

from pandas import ExcelWriter
from torchsummary import summary
import torch
from torch.backends import cudnn
from utils import calMetrics
from utils import calculatePerClass
from utils import numberClassChannel
import math
import warnings
warnings.filterwarnings("ignore")
cudnn.benchmark = False
cudnn.deterministic = True



import torch
from torch import nn
from torch import Tensor
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange, reduce, repeat
import torch.nn.functional as F

from utils import numberClassChannel
from utils import load_data_evaluate

import numpy as np
import pandas as pd
from torch.autograd import Variable







class PatchEmbeddingCNN(nn.Module):
    def __init__(self, f1=16, kernel_size=64, D=2, pooling_size1=8, pooling_size2=8, dropout_rate=0.3, number_channel=22, emb_size=40):
        super().__init__()
        f2 = D*f1
        self.cnn_module = nn.Sequential(
            # temporal conv kernel size 64=0.25fs
            nn.Conv2d(1, f1, (1, kernel_size), (1, 1), padding='same', bias=False), # [batch, 22, 1000] 
            nn.BatchNorm2d(f1),
            # channel depth-wise conv
            nn.Conv2d(f1, f2, (number_channel, 1), (1, 1), groups=f1, padding='valid', bias=False), # 
            nn.BatchNorm2d(f2),
            nn.ELU(),
            # average pooling 1
            nn.AvgPool2d((1, pooling_size1)),  # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT
            nn.Dropout(dropout_rate),
            # spatial conv
            nn.Conv2d(f2, f2, (1, 16), padding='same', bias=False), 
            nn.BatchNorm2d(f2),
            nn.ELU(),

            # average pooling 2 to adjust the length of feature into transformer encoder
            nn.AvgPool2d((1, pooling_size2)),
            nn.Dropout(dropout_rate),  
                    
        )

        self.projection = nn.Sequential(
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        
        
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.cnn_module(x)
        x = self.projection(x)
        return x
    
    
  
    
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
    


# PointWise FFN
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion, drop_p):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )



class ClassificationHead(nn.Sequential):
    def __init__(self, flatten_number, n_classes):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(flatten_number, n_classes)
        )

    def forward(self, x):
        out = self.fc(x)
        
        return out


class ResidualAdd(nn.Module):
    def __init__(self, fn, emb_size, drop_p):
        super().__init__()
        self.fn = fn
        self.drop = nn.Dropout(drop_p)
        self.layernorm = nn.LayerNorm(emb_size)

    def forward(self, x, **kwargs):
        x_input = x
        res = self.fn(x, **kwargs)
        
        out = self.layernorm(self.drop(res)+x_input)
        return out

class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size,
                 num_heads=4,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        super().__init__(
            ResidualAdd(nn.Sequential(
                MultiHeadAttention(emb_size, num_heads, drop_p),
                ), emb_size, drop_p),
            ResidualAdd(nn.Sequential(
                FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                ), emb_size, drop_p)
            
            )    
        
        
class TransformerEncoder(nn.Sequential):
    def __init__(self, heads, depth, emb_size):
        super().__init__(*[TransformerEncoderBlock(emb_size, heads) for _ in range(depth)])




class BranchEEGNetTransformer(nn.Sequential):
    def __init__(self, heads=4, 
                 depth=6, 
                 emb_size=40, 
                 number_channel=22,
                 f1 = 20,
                 kernel_size = 64,
                 D = 2,
                 pooling_size1 = 8,
                 pooling_size2 = 8,
                 dropout_rate = 0.3,
                 **kwargs):
        super().__init__(
            PatchEmbeddingCNN(f1=f1, 
                                 kernel_size=kernel_size,
                                 D=D, 
                                 pooling_size1=pooling_size1, 
                                 pooling_size2=pooling_size2, 
                                 dropout_rate=dropout_rate,
                                 number_channel=number_channel,
                                 emb_size=emb_size),
#             TransformerEncoder(heads, depth, emb_size),
        )



    
        
        
class PositioinalEncoding(nn.Module):
    def __init__(self, embedding, length=100, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.encoding = nn.Parameter(torch.randn(1, length, embedding))
    def forward(self, x): # x-> [batch, embedding, length]
        x = x + self.encoding[:, :x.shape[1], :].cuda()
        return self.dropout(x)        
        
   
        
        
        
    
class EEGTransformer(nn.Module):
    def __init__(self, heads=4, 
                 emb_size=40,
                 depth=6, 
                 database_type='A', 
                 eeg1_f1 = 20,
                 eeg1_kernel_size = 64,
                 eeg1_D = 2,
                 eeg1_pooling_size1 = 8,
                 eeg1_pooling_size2 = 8,
                 eeg1_dropout_rate = 0.3,
                 eeg1_number_channel = 22,
                 flatten_eeg1 = 600,
                 **kwargs):
        super().__init__()
        self.number_class, self.number_channel = numberClassChannel(database_type)
        self.emb_size = emb_size
        self.flatten_eeg1 = flatten_eeg1
        self.flatten = nn.Flatten()
        # print('self.number_channel', self.number_channel)
        self.cnn = BranchEEGNetTransformer(heads, depth, emb_size, number_channel=self.number_channel,
                                              f1 = eeg1_f1,
                                              kernel_size = eeg1_kernel_size,
                                              D = eeg1_D,
                                              pooling_size1 = eeg1_pooling_size1,
                                              pooling_size2 = eeg1_pooling_size2,
                                              dropout_rate = eeg1_dropout_rate,
                                              )
        self.position = PositioinalEncoding(emb_size, dropout=0.1)
        self.trans = TransformerEncoder(heads, depth, emb_size)
#         self.drop = nn.Dropout(0.5)
        
        
        # self.cnn_module = Branchcnn_moduleTransformer(heads, depth, emb_size)
        self.flatten = nn.Flatten()
        self.classification = ClassificationHead(self.flatten_eeg1 , self.number_class) # FLATTEN_EEGNet + FLATTEN_cnn_module
    def forward(self, x):
        cnn = self.cnn(x)

        # add label 
        cnn = cnn * math.sqrt(self.emb_size)
        cnn = self.position(cnn)
        
        trans = self.trans(cnn)
        
        features = cnn+trans
        
        
        out = self.classification(self.flatten(features))
        return features, out


class ExP():
    def __init__(self, nsub, data_dir, result_name, 
                 epochs=2000, 
                 number_aug=2,
                 number_seg=8, 
                 gpus=[0], 
                 evaluate_mode = 'subject-dependent',
                 heads=4, 
                 emb_size=40,
                 depth=6, 
                 dataset_type='A',
                 eeg1_f1 = 20,
                 eeg1_kernel_size = 64,
                 eeg1_D = 2,
                 eeg1_pooling_size1 = 8,
                 eeg1_pooling_size2 = 8,
                 eeg1_dropout_rate = 0.3,
                 flatten_eeg1 = 600, 
                 validate_ratio = 0.2,
                 learning_rate = 0.001,
                 batch_size = 72,  
                 ):
        
        super(ExP, self).__init__()
        self.dataset_type = dataset_type
        self.batch_size = batch_size
        self.lr = learning_rate
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_epochs = epochs
        self.nSub = nsub
        self.number_augmentation = number_aug
        self.number_seg = number_seg
        self.root = data_dir
        self.heads=heads
        self.emb_size=emb_size
        self.depth=depth
        self.result_name = result_name
        self.evaluate_mode = evaluate_mode
        self.validate_ratio = validate_ratio

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.number_class, self.number_channel = numberClassChannel(self.dataset_type)
        self.model = EEGTransformer(
             heads=self.heads, 
             emb_size=self.emb_size,
             depth=self.depth, 
            database_type=self.dataset_type, 
            eeg1_f1=eeg1_f1, 
            eeg1_D=eeg1_D,
            eeg1_kernel_size=eeg1_kernel_size,
            eeg1_pooling_size1 = eeg1_pooling_size1,
            eeg1_pooling_size2 = eeg1_pooling_size2,
            eeg1_dropout_rate = eeg1_dropout_rate,
            eeg1_number_channel = self.number_channel,
            flatten_eeg1 = flatten_eeg1,  
            ).cuda()
        #self.model = nn.DataParallel(self.model, device_ids=gpus)
        self.model = self.model.cuda()
        self.model_filename = self.result_name + '/model_{}.pth'.format(self.nSub)

    # Segmentation and Reconstruction (S&R) data augmentation
    def interaug(self, timg, label):  
        aug_data = []
        aug_label = []
        number_records_by_augmentation = self.number_augmentation * int(self.batch_size / self.number_class)
        number_segmentation_points = 1000 // self.number_seg
        for clsAug in range(self.number_class):
            cls_idx = np.where(label == clsAug + 1)
            tmp_data = timg[cls_idx]
            tmp_label = label[cls_idx]
            
            tmp_aug_data = np.zeros((number_records_by_augmentation, 1, self.number_channel, 1000))
            for ri in range(number_records_by_augmentation):
                for rj in range(self.number_seg):
                    rand_idx = np.random.randint(0, tmp_data.shape[0], self.number_seg)
                    tmp_aug_data[ri, :, :, rj * number_segmentation_points:(rj + 1) * number_segmentation_points] = \
                        tmp_data[rand_idx[rj], :, :, rj * number_segmentation_points:(rj + 1) * number_segmentation_points]

            aug_data.append(tmp_aug_data)
            aug_label.append(tmp_label[:number_records_by_augmentation])
        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle, :, :]
        aug_label = aug_label[aug_shuffle]

        aug_data = torch.from_numpy(aug_data).cuda()
        aug_data = aug_data.float()
        aug_label = torch.from_numpy(aug_label-1).cuda()
        aug_label = aug_label.long()
        return aug_data, aug_label



    def get_source_data(self):
        (self.train_data,    # (batch, channel, length)
         self.train_label, 
         self.test_data, 
         self.test_label) = load_data_evaluate(self.root, self.dataset_type, self.nSub, mode_evaluate=self.evaluate_mode)

        self.train_data = np.expand_dims(self.train_data, axis=1)  # (288, 1, 22, 1000)
        self.train_label = np.transpose(self.train_label)  

        self.allData = self.train_data
        self.allLabel = self.train_label[0]  

        shuffle_num = np.random.permutation(len(self.allData))
        # print("len(self.allData):", len(self.allData))
        self.allData = self.allData[shuffle_num, :, :, :]  # (288, 1, 22, 1000)
        # print("shuffle_num", shuffle_num)
        # print("self.allLabel", self.allLabel)
        self.allLabel = self.allLabel[shuffle_num]


        print('-'*20, "train size：", self.train_data.shape, "test size：", self.test_data.shape)
        # self.test_data = np.transpose(self.test_data, (2, 1, 0))
        self.test_data = np.expand_dims(self.test_data, axis=1)
        self.test_label = np.transpose(self.test_label)

        self.testData = self.test_data
        self.testLabel = self.test_label[0]


        # standardize
        target_mean = np.mean(self.allData)
        target_std = np.std(self.allData)
        self.allData = (self.allData - target_mean) / target_std
        self.testData = (self.testData - target_mean) / target_std
        
        isSaveDataLabel = False #True
        if isSaveDataLabel:
            np.save("./gradm_data/train_data_{}.npy".format(self.nSub), self.allData)
            np.save("./gradm_data/train_lable_{}.npy".format(self.nSub), self.allLabel)
            np.save("./gradm_data/test_data_{}.npy".format(self.nSub), self.testData)
            np.save("./gradm_data/test_label_{}.npy".format(self.nSub), self.testLabel)

        
        # data shape: (trial, conv channel, electrode channel, time samples)
        return self.allData, self.allLabel, self.testData, self.testLabel


    def train(self):
        img, label, test_data, test_label = self.get_source_data()
        # print("label size:", label.shape)
        # print("label size:", label)
        
        img = torch.from_numpy(img)
        label = torch.from_numpy(label - 1)
        dataset = torch.utils.data.TensorDataset(img, label)
        

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=False)

        # Optimizers
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        test_data = Variable(test_data.type(self.Tensor))
        test_label = Variable(test_label.type(self.LongTensor))
        best_epoch = 0
        num = 0
        min_loss = 100
        # recording train_acc, train_loss, test_acc, test_loss
        result_process = []
        # Train the cnn model
        for e in range(self.n_epochs):
            self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
            epoch_process = {}
            epoch_process['epoch'] = e
            # in_epoch = time.time()
            self.model.train()
            outputs_list = []
            label_list = []
            # 验证集
            val_data_list = []
            val_label_list = []
            for i, (img, label) in enumerate(self.dataloader):
                number_sample = img.shape[0]
                number_validate = int(self.validate_ratio * number_sample)
                
                # split raw train dataset into real train dataset and validate dataset
                train_data = img[:-number_validate]
                train_label = label[:-number_validate]
                
                val_data_list.append(img[number_validate:])
                val_label_list.append(label[number_validate:])
                
                # real train dataset
                img = Variable(train_data.type(self.Tensor))
                label = Variable(train_label.type(self.LongTensor))
                
                # data augmentation
                aug_data, aug_label = self.interaug(self.allData, self.allLabel)
                # concat real train dataset and generate aritifical train dataset
                img = torch.cat((img, aug_data))
                label = torch.cat((label, aug_label))

                # training model
                features, outputs = self.model(img)
                outputs_list.append(outputs)
                label_list.append(label)
                # print("train outputs: ", outputs.shape, type(outputs))
                # print(features.size())
                loss = self.criterion_cls(outputs, label) 
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            del img
            torch.cuda.empty_cache()
            # out_epoch = time.time()
            # test process
            if (e + 1) % 1 == 0:
                self.model.eval()
                # validate model
                val_data = torch.cat(val_data_list).cuda()
                val_label = torch.cat(val_label_list).cuda()
                val_data = val_data.type(self.Tensor)
                val_label = val_label.type(self.LongTensor)            
                
                val_dataset = torch.utils.data.TensorDataset(val_data, val_label)
                self.val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=self.batch_size, shuffle=False)
                outputs_list = []
                with torch.no_grad():
                    for i, (img, _) in enumerate(self.val_dataloader):
                        # val model
                        img = img.type(self.Tensor).cuda()
                        _, Cls = self.model(img)
                        outputs_list.append(Cls)
                        del img, Cls
                        torch.cuda.empty_cache()
                    
                Cls = torch.cat(outputs_list)
                
                val_loss = self.criterion_cls(Cls, val_label)
                val_pred = torch.max(Cls, 1)[1]
                val_acc = float((val_pred == val_label).cpu().numpy().astype(int).sum()) / float(val_label.size(0))
                
                epoch_process['val_acc'] = val_acc                
                epoch_process['val_loss'] = val_loss.detach().cpu().numpy()  
                
                train_pred = torch.max(outputs, 1)[1]
                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
                epoch_process['train_acc'] = train_acc
                epoch_process['train_loss'] = loss.detach().cpu().numpy()

                num = num + 1

                # if min_loss>val_loss:                
                if min_loss>val_loss:
                    min_loss = val_loss
                    best_epoch = e
                    epoch_process['epoch'] = e
                    torch.save(self.model, self.model_filename)
                    print("{}_{} train_acc: {:.4f} train_loss: {:.6f}\tval_acc: {:.6f} val_loss: {:.7f}".format(self.nSub,
                                                                                           epoch_process['epoch'],
                                                                                           epoch_process['train_acc'],
                                                                                           epoch_process['train_loss'],
                                                                                           epoch_process['val_acc'],
                                                                                           epoch_process['val_loss'],
                                                                                        ))
            
                
            result_process.append(epoch_process)  

        
            del label, val_data, val_label
            torch.cuda.empty_cache()
        
        # load model for test
        self.model.eval()
        self.model = torch.load(self.model_filename).cuda()
        outputs_list = []
        with torch.no_grad():
            for i, (img, label) in enumerate(self.test_dataloader):
                img_test = Variable(img.type(self.Tensor)).cuda()
                # label_test = Variable(label.type(self.LongTensor))

                # test model
                features, outputs = self.model(img_test)
                val_pred = torch.max(outputs, 1)[1]
                outputs_list.append(outputs)
        outputs = torch.cat(outputs_list) 
        y_pred = torch.max(outputs, 1)[1]
        
        
        test_acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
        
        print("epoch: ", best_epoch, '\tThe test accuracy is:', test_acc)


        df_process = pd.DataFrame(result_process)

        return test_acc, test_label, y_pred, df_process, best_epoch
        # writer.close()
        







def main(dirs,                
         evaluate_mode = 'subject-dependent', # 评估模式：LOSO（跨个体）或其他（subject-dependent, subject-specific），
         heads=8,             # heads of MHA
         emb_size=48,         # token embding dim
         depth=3,             # Transformer encoder depth
         dataset_type='A',    # A->'BCI IV2a', B->'BCI IV2b'
         eeg1_f1=20,          # features of temporal conv
         eeg1_kernel_size=64, # kernel size of temporal conv
         eeg1_D=2,            # depth-wise conv 
         eeg1_pooling_size1=8,# p1
         eeg1_pooling_size2=8,# p2
         eeg1_dropout_rate=0.3,
         flatten_eeg1=600,   
         validate_ratio = 0.2
         ):

    if not os.path.exists(dirs):
        os.makedirs(dirs)

    result_write_metric = ExcelWriter(dirs+"/result_metric.xlsx")
    
    result_metric_dict = {}
    y_true_pred_dict = { }

    process_write = ExcelWriter(dirs+"/process_train.xlsx")
    pred_true_write = ExcelWriter(dirs+"/pred_true.xlsx")
    subjects_result = []
    best_epochs = []
    
    for i in range(N_SUBJECT):      
        
        starttime = datetime.datetime.now()
        seed_n = np.random.randint(2024)
        print('seed is ' + str(seed_n))
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)
        index_round =0
        print('Subject %d' % (i+1))
        exp = ExP(i + 1, DATA_DIR, dirs, EPOCHS, N_AUG, N_SEG, gpus, 
                  evaluate_mode = evaluate_mode,
                  heads=heads, 
                  emb_size=emb_size,
                  depth=depth, 
                  dataset_type=dataset_type,
                  eeg1_f1 = eeg1_f1,
                  eeg1_kernel_size = eeg1_kernel_size,
                  eeg1_D = eeg1_D,
                  eeg1_pooling_size1 = eeg1_pooling_size1,
                  eeg1_pooling_size2 = eeg1_pooling_size2,
                  eeg1_dropout_rate = eeg1_dropout_rate,
                  flatten_eeg1 = flatten_eeg1,  
                  validate_ratio = validate_ratio
                  )

        testAcc, Y_true, Y_pred, df_process, best_epoch = exp.train()
        true_cpu = Y_true.cpu().numpy().astype(int)
        pred_cpu = Y_pred.cpu().numpy().astype(int)
        df_pred_true = pd.DataFrame({'pred': pred_cpu, 'true': true_cpu})
        df_pred_true.to_excel(pred_true_write, sheet_name=str(i+1))
        y_true_pred_dict[i] = df_pred_true

        accuracy, precison, recall, f1, kappa = calMetrics(true_cpu, pred_cpu)
        subject_result = {'accuray': accuracy*100,
                          'precision': precison*100,
                          'recall': recall*100,
                          'f1': f1*100, 
                          'kappa': kappa*100
                          }
        subjects_result.append(subject_result)
        df_process.to_excel(process_write, sheet_name=str(i+1))
        best_epochs.append(best_epoch)
    
        print(' THE BEST ACCURACY IS ' + str(testAcc) + "\tkappa is " + str(kappa) )
    

        endtime = datetime.datetime.now()
        print('subject %d duration: '%(i+1) + str(endtime - starttime))

        if i == 0:
            yt = Y_true
            yp = Y_pred
        else:
            yt = torch.cat((yt, Y_true))
            yp = torch.cat((yp, Y_pred))
                
        df_result = pd.DataFrame(subjects_result)
    process_write.close()
    pred_true_write.close()


    print('**The average Best accuracy is: ' + str(df_result['accuray'].mean()) + "kappa is: " + str(df_result['kappa'].mean()) + "\n" )
    print("best epochs: ", best_epochs)
    #df_result.to_excel(result_write_metric, index=False)
    result_metric_dict = df_result

    mean = df_result.mean(axis=0)
    mean.name = 'mean'
    std = df_result.std(axis=0)
    std.name = 'std'
    df_result = pd.concat([df_result, pd.DataFrame(mean).T, pd.DataFrame(std).T])
    
    df_result.to_excel(result_write_metric, index=False)
    print('-'*9, ' all result ', '-'*9)
    print(df_result)
    
    print("*"*40)

    result_write_metric.close()

    
    return result_metric_dict

if __name__ == "__main__":
    #----------------------------------------
    DATA_DIR = r'../mymat_raw/'
    EVALUATE_MODE = 'LOSO-No' # leaving one subject out subject-dependent  subject-indenpedent

    N_SUBJECT = 9       # BCI 
    N_AUG = 3           # data augmentation times for benerating artificial training data set
    N_SEG = 8           # segmentation times for S&R

    EPOCHS = 1000
    EMB_DIM = 16
    HEADS = 2
    DEPTH = 6
    TYPE = 'B'
    validate_ratio = 0.3 # split raw train dataset into real train dataset and validate dataset

    EEGNet1_F1 = 8
    EEGNet1_KERNEL_SIZE=64
    EEGNet1_D=2
    EEGNet1_POOL_SIZE1 = 8
    EEGNet1_POOL_SIZE2 = 8
    FLATTEN_EEGNet1 = 240

    if EVALUATE_MODE!='LOSO':
        EEGNet1_DROPOUT_RATE = 0.5
    else:
        EEGNet1_DROPOUT_RATE = 0.25    

    
    parameters_list = ['A']
    for TYPE in parameters_list:
        number_class, number_channel = numberClassChannel(TYPE)
        RESULT_NAME = "CTNet_{}_heads_{}_depth_{}_{}".format(TYPE, HEADS, DEPTH, int(time.time()))
    
        sModel = EEGTransformer(
            heads=HEADS, 
            emb_size=EMB_DIM,
            depth=DEPTH, 
            database_type=TYPE,
            eeg1_f1=EEGNet1_F1, 
            eeg1_D=EEGNet1_D,
            eeg1_kernel_size=EEGNet1_KERNEL_SIZE,
            eeg1_pooling_size1 = EEGNet1_POOL_SIZE1,
            eeg1_pooling_size2 = EEGNet1_POOL_SIZE2,
            eeg1_dropout_rate = EEGNet1_DROPOUT_RATE,
            eeg1_number_channel = number_channel,
            flatten_eeg1 = FLATTEN_EEGNet1,  
            ).cuda()
        summary(sModel, (1, number_channel, 1000)) 
    
        print(time.asctime(time.localtime(time.time())))
        
        result = main(RESULT_NAME,
                        evaluate_mode = EVALUATE_MODE,
                        heads=HEADS, 
                        emb_size=EMB_DIM,
                        depth=DEPTH, 
                        dataset_type=TYPE,
                        eeg1_f1 = EEGNet1_F1,
                        eeg1_kernel_size = EEGNet1_KERNEL_SIZE,
                        eeg1_D = EEGNet1_D,
                        eeg1_pooling_size1 = EEGNet1_POOL_SIZE1,
                        eeg1_pooling_size2 = EEGNet1_POOL_SIZE2,
                        eeg1_dropout_rate = EEGNet1_DROPOUT_RATE,
                        flatten_eeg1 = FLATTEN_EEGNet1,
                        validate_ratio = validate_ratio,
                      )
        print(time.asctime(time.localtime(time.time())))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 8, 22, 1000]             512
       BatchNorm2d-2          [-1, 8, 22, 1000]              16
            Conv2d-3          [-1, 16, 1, 1000]             352
       BatchNorm2d-4          [-1, 16, 1, 1000]              32
               ELU-5          [-1, 16, 1, 1000]               0
         AvgPool2d-6           [-1, 16, 1, 125]               0
           Dropout-7           [-1, 16, 1, 125]               0
            Conv2d-8           [-1, 16, 1, 125]           4,096
       BatchNorm2d-9           [-1, 16, 1, 125]              32
              ELU-10           [-1, 16, 1, 125]               0
        AvgPool2d-11            [-1, 16, 1, 15]               0
          Dropout-12            [-1, 16, 1, 15]               0
        Rearrange-13               [-1, 15, 16]               0
PatchEmbeddingCNN-14               [-1,

1_2 train_acc: 0.3146 train_loss: 1.977199	val_acc: 0.436275 val_loss: 1.2736508
1_3 train_acc: 0.3333 train_loss: 1.755401	val_acc: 0.465686 val_loss: 1.2389588
1_4 train_acc: 0.3109 train_loss: 1.899612	val_acc: 0.504902 val_loss: 1.2232871
1_5 train_acc: 0.3820 train_loss: 1.592768	val_acc: 0.563725 val_loss: 1.1546565
1_6 train_acc: 0.4045 train_loss: 1.511896	val_acc: 0.612745 val_loss: 1.0918190
1_7 train_acc: 0.4195 train_loss: 1.373556	val_acc: 0.671569 val_loss: 0.9923918
1_8 train_acc: 0.4045 train_loss: 1.432190	val_acc: 0.661765 val_loss: 0.9919081
1_9 train_acc: 0.4607 train_loss: 1.275380	val_acc: 0.681373 val_loss: 0.9286916
1_10 train_acc: 0.4120 train_loss: 1.381421	val_acc: 0.715686 val_loss: 0.8334206
1_11 train_acc: 0.5169 train_loss: 1.131351	val_acc: 0.750000 val_loss: 0.7718888
1_12 train_acc: 0.5393 train_loss: 1.087529	val_acc: 0.759804 val_loss: 0.7452695
1_13 train_acc: 0.4944 train_loss: 1.200974	val_acc: 0.769608 val_loss: 0.6812358
1_14 train_acc: 0.5805 t

2_30 train_acc: 0.5693 train_loss: 0.993822	val_acc: 0.779412 val_loss: 0.7722383
2_31 train_acc: 0.5543 train_loss: 1.009177	val_acc: 0.759804 val_loss: 0.7513396
2_33 train_acc: 0.5918 train_loss: 1.009498	val_acc: 0.779412 val_loss: 0.7212336
2_35 train_acc: 0.6180 train_loss: 0.933551	val_acc: 0.764706 val_loss: 0.7078946
2_36 train_acc: 0.6554 train_loss: 0.871794	val_acc: 0.774510 val_loss: 0.6741776
2_38 train_acc: 0.6180 train_loss: 0.969289	val_acc: 0.828431 val_loss: 0.6347200
2_41 train_acc: 0.5843 train_loss: 0.945853	val_acc: 0.813725 val_loss: 0.6230605
2_44 train_acc: 0.6255 train_loss: 0.904311	val_acc: 0.818627 val_loss: 0.6127021
2_46 train_acc: 0.6592 train_loss: 0.872641	val_acc: 0.813725 val_loss: 0.5762656
2_50 train_acc: 0.6629 train_loss: 0.850594	val_acc: 0.818627 val_loss: 0.5724664
2_52 train_acc: 0.6742 train_loss: 0.833364	val_acc: 0.838235 val_loss: 0.5640951
2_54 train_acc: 0.5993 train_loss: 0.928978	val_acc: 0.833333 val_loss: 0.5452684
2_55 train_acc: 

3_23 train_acc: 0.6180 train_loss: 0.954799	val_acc: 0.794118 val_loss: 0.5971347
3_28 train_acc: 0.6742 train_loss: 0.812142	val_acc: 0.799020 val_loss: 0.5860439
3_29 train_acc: 0.6742 train_loss: 0.800904	val_acc: 0.784314 val_loss: 0.5835966
3_30 train_acc: 0.6629 train_loss: 0.787365	val_acc: 0.808824 val_loss: 0.5331498
3_36 train_acc: 0.6891 train_loss: 0.801861	val_acc: 0.823529 val_loss: 0.4819491
3_39 train_acc: 0.7154 train_loss: 0.776719	val_acc: 0.843137 val_loss: 0.4503557
3_44 train_acc: 0.7191 train_loss: 0.703306	val_acc: 0.852941 val_loss: 0.4172957
3_46 train_acc: 0.7341 train_loss: 0.683137	val_acc: 0.862745 val_loss: 0.4033098
3_51 train_acc: 0.6966 train_loss: 0.712021	val_acc: 0.877451 val_loss: 0.3869971
3_52 train_acc: 0.7416 train_loss: 0.605647	val_acc: 0.867647 val_loss: 0.3639707
3_65 train_acc: 0.7978 train_loss: 0.482566	val_acc: 0.862745 val_loss: 0.3441083
3_66 train_acc: 0.7715 train_loss: 0.514149	val_acc: 0.892157 val_loss: 0.3112175
3_67 train_acc: 

4_115 train_acc: 0.7528 train_loss: 0.646819	val_acc: 0.887255 val_loss: 0.3070424
4_121 train_acc: 0.8052 train_loss: 0.559751	val_acc: 0.897059 val_loss: 0.2935484
4_123 train_acc: 0.7828 train_loss: 0.555073	val_acc: 0.911765 val_loss: 0.2908670
4_127 train_acc: 0.7603 train_loss: 0.603969	val_acc: 0.916667 val_loss: 0.2789775
4_134 train_acc: 0.7678 train_loss: 0.603673	val_acc: 0.921569 val_loss: 0.2741978
4_135 train_acc: 0.7678 train_loss: 0.568837	val_acc: 0.926471 val_loss: 0.2616286
4_137 train_acc: 0.7978 train_loss: 0.570422	val_acc: 0.921569 val_loss: 0.2464073
4_148 train_acc: 0.7903 train_loss: 0.568547	val_acc: 0.931373 val_loss: 0.2425217
4_159 train_acc: 0.7940 train_loss: 0.518425	val_acc: 0.921569 val_loss: 0.2417615
4_169 train_acc: 0.8015 train_loss: 0.538417	val_acc: 0.946078 val_loss: 0.2412428
4_170 train_acc: 0.7715 train_loss: 0.581387	val_acc: 0.931373 val_loss: 0.2339395
4_173 train_acc: 0.7341 train_loss: 0.612290	val_acc: 0.946078 val_loss: 0.2122555
4_18

5_88 train_acc: 0.8165 train_loss: 0.485470	val_acc: 0.921569 val_loss: 0.2457034
5_90 train_acc: 0.7865 train_loss: 0.511651	val_acc: 0.921569 val_loss: 0.2424754
5_93 train_acc: 0.7865 train_loss: 0.559664	val_acc: 0.931373 val_loss: 0.2129847
5_100 train_acc: 0.8464 train_loss: 0.382793	val_acc: 0.936275 val_loss: 0.2060985
5_109 train_acc: 0.8277 train_loss: 0.430944	val_acc: 0.941176 val_loss: 0.2027691
5_114 train_acc: 0.8315 train_loss: 0.413532	val_acc: 0.936275 val_loss: 0.1957652
5_116 train_acc: 0.8202 train_loss: 0.478459	val_acc: 0.950980 val_loss: 0.1938128
5_126 train_acc: 0.8502 train_loss: 0.428709	val_acc: 0.950980 val_loss: 0.1935576
5_129 train_acc: 0.8202 train_loss: 0.538778	val_acc: 0.965686 val_loss: 0.1740549
5_130 train_acc: 0.8277 train_loss: 0.449379	val_acc: 0.960784 val_loss: 0.1691879
5_140 train_acc: 0.8052 train_loss: 0.450481	val_acc: 0.965686 val_loss: 0.1690423
5_142 train_acc: 0.8539 train_loss: 0.399115	val_acc: 0.946078 val_loss: 0.1653057
5_144 t

6_90 train_acc: 0.7715 train_loss: 0.542261	val_acc: 0.867647 val_loss: 0.3709896
6_91 train_acc: 0.7266 train_loss: 0.762061	val_acc: 0.872549 val_loss: 0.3654485
6_96 train_acc: 0.7678 train_loss: 0.566222	val_acc: 0.862745 val_loss: 0.3533607
6_102 train_acc: 0.7154 train_loss: 0.701094	val_acc: 0.877451 val_loss: 0.3504062
6_104 train_acc: 0.7715 train_loss: 0.571141	val_acc: 0.887255 val_loss: 0.3344275
6_111 train_acc: 0.7566 train_loss: 0.558341	val_acc: 0.897059 val_loss: 0.3094836
6_127 train_acc: 0.7341 train_loss: 0.547154	val_acc: 0.906863 val_loss: 0.2992566
6_136 train_acc: 0.8202 train_loss: 0.501302	val_acc: 0.887255 val_loss: 0.2866499
6_149 train_acc: 0.8015 train_loss: 0.554018	val_acc: 0.892157 val_loss: 0.2865650
6_153 train_acc: 0.8127 train_loss: 0.533188	val_acc: 0.911765 val_loss: 0.2700042
6_168 train_acc: 0.8052 train_loss: 0.506948	val_acc: 0.897059 val_loss: 0.2648866
6_177 train_acc: 0.8015 train_loss: 0.505126	val_acc: 0.936275 val_loss: 0.2374460
6_197 t

7_323 train_acc: 0.9326 train_loss: 0.201233	val_acc: 0.985294 val_loss: 0.0402221
7_324 train_acc: 0.9363 train_loss: 0.174498	val_acc: 0.985294 val_loss: 0.0384179
7_337 train_acc: 0.9326 train_loss: 0.173008	val_acc: 0.995098 val_loss: 0.0213172
7_413 train_acc: 0.9401 train_loss: 0.195256	val_acc: 0.995098 val_loss: 0.0126396
7_468 train_acc: 0.9513 train_loss: 0.140612	val_acc: 1.000000 val_loss: 0.0119383
7_559 train_acc: 0.9700 train_loss: 0.062205	val_acc: 0.995098 val_loss: 0.0108549
7_566 train_acc: 0.9775 train_loss: 0.063677	val_acc: 1.000000 val_loss: 0.0094584
7_581 train_acc: 0.9738 train_loss: 0.061441	val_acc: 1.000000 val_loss: 0.0090113
7_599 train_acc: 0.9588 train_loss: 0.129488	val_acc: 1.000000 val_loss: 0.0081659
7_612 train_acc: 0.9850 train_loss: 0.044685	val_acc: 1.000000 val_loss: 0.0080768
7_632 train_acc: 0.9663 train_loss: 0.064959	val_acc: 1.000000 val_loss: 0.0080399
7_653 train_acc: 0.9625 train_loss: 0.116922	val_acc: 1.000000 val_loss: 0.0050517
7_70

8_714 train_acc: 0.9438 train_loss: 0.140781	val_acc: 1.000000 val_loss: 0.0046548
8_720 train_acc: 0.9513 train_loss: 0.115896	val_acc: 1.000000 val_loss: 0.0043697
8_791 train_acc: 0.9513 train_loss: 0.131804	val_acc: 1.000000 val_loss: 0.0032694
8_845 train_acc: 0.9738 train_loss: 0.089156	val_acc: 1.000000 val_loss: 0.0030615
8_871 train_acc: 0.9551 train_loss: 0.127484	val_acc: 1.000000 val_loss: 0.0028688
8_904 train_acc: 0.9738 train_loss: 0.100220	val_acc: 1.000000 val_loss: 0.0028444
8_906 train_acc: 0.9401 train_loss: 0.115065	val_acc: 1.000000 val_loss: 0.0025835
8_994 train_acc: 0.9700 train_loss: 0.084614	val_acc: 1.000000 val_loss: 0.0021090
epoch:  994 	The test accuracy is: 0.875
 THE BEST ACCURACY IS 0.875	kappa is 0.8333333333333334
subject 8 duration: 0:20:28.284846
seed is 1195
Subject 9
-------------------- train size： (288, 1, 22, 1000) test size： (288, 22, 1000)
9_0 train_acc: 0.2434 train_loss: 2.446332	val_acc: 0.245098 val_loss: 1.4741368
9_1 train_acc: 0.2772