In [1]:
"""
This script fine-tunes the Conformer project (https://github.com/eeyhsong/EEG-Conformer) and achieves better experimental results compared to the previous project. 

fine-tunes by author: zhaowei701@163.com

"""

import os
gpus = [1]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
import numpy as np
import pandas as pd
import random
import datetime
import time

from pandas import ExcelWriter
from torchsummary import summary
import torch
from torch.backends import cudnn
from utils import calMetrics
from utils import calculatePerClass
from utils import numberClassChannel

import warnings
warnings.filterwarnings("ignore")
cudnn.benchmark = False
cudnn.deterministic = True



import torch
from torch import nn
from torch import Tensor
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange, reduce, repeat
import torch.nn.functional as F

from utils import numberClassChannel
from utils import load_data_evaluate

import numpy as np
import pandas as pd
from torch.autograd import Variable


    
    
class Conformer(nn.Module):
    def __init__(self, number_channel=22, nb_classes=4, dropout_rate=0.5):
        # self.patch_size = patch_size
        super().__init__()

        self.shallownet = nn.Sequential(
            nn.Conv2d(1, 40, (1, 25), (1, 1)),
            nn.Conv2d(40, 40, (number_channel, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.AvgPool2d((1, 75), (1, 15)),  # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT
            nn.Dropout(dropout_rate),
        )

        self.projection = nn.Sequential(
            nn.Conv2d(40, 40, (1, 1), stride=(1, 1)),  # transpose, conv could enhance fiting ability slightly
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        self.trans = TransformerEncoder(10, 6, 40)
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(2440, nb_classes)
        )

    def forward(self, x: Tensor) -> Tensor:
#         b, _, _, _ = x.shape
        x = self.shallownet(x)
        x = self.projection(x)
        x = self.trans(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


    
    
  
    
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
    





class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x


class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion, drop_p):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )


class GELU(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0)))


class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, drop_p),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))


class TransformerEncoder(nn.Sequential):
    def __init__(self, num_heads, depth, emb_size):
        super().__init__(*[TransformerEncoderBlock(emb_size, num_heads) for _ in range(depth)])




class BranchEEGNetTransformer(nn.Sequential):
    def __init__(self, heads=4, 
                 depth=6, 
                 emb_size=40, 
                 number_channel=22,
                 f1 = 20,
                 kernel_size = 64,
                 D = 2,
                 pooling_size1 = 8,
                 pooling_size2 = 8,
                 dropout_rate = 0.3,
                 **kwargs):
        super().__init__(
            PatchEmbeddingCNN(f1=f1, 
                                 kernel_size=kernel_size,
                                 D=D, 
                                 pooling_size1=pooling_size1, 
                                 pooling_size2=pooling_size2, 
                                 dropout_rate=dropout_rate,
                                 number_channel=number_channel,
                                 emb_size=emb_size),
            TransformerEncoder(heads, depth, emb_size),
        )



    
class EEGTransformer(nn.Module):
    def __init__(self, heads=4, 
                 emb_size=40,
                 depth=6, 
                 database_type='A', 
                 eeg1_f1 = 20,
                 eeg1_kernel_size = 64,
                 eeg1_D = 2,
                 eeg1_pooling_size1 = 8,
                 eeg1_pooling_size2 = 8,
                 eeg1_dropout_rate = 0.3,
                 eeg1_number_channel = 22,
                 flatten_eeg1 = 600,
                 **kwargs):
        super().__init__()
        self.number_class, self.number_channel = numberClassChannel(database_type)
        self.flatten_eeg1 = flatten_eeg1
        self.net = Conformer(number_channel=self.number_channel, nb_classes=self.number_class)
        # print('self.number_channel', self.number_channel)
#         self.eegNet = BranchEEGNetTransformer(heads, depth, emb_size, number_channel=self.number_channel,
#                                               f1 = eeg1_f1,
#                                               kernel_size = eeg1_kernel_size,
#                                               D = eeg1_D,
#                                               pooling_size1 = eeg1_pooling_size1,
#                                               pooling_size2 = eeg1_pooling_size2,
#                                               dropout_rate = eeg1_dropout_rate,
#                                               )
#         # self.cnn_module = Branchcnn_moduleTransformer(heads, depth, emb_size)
#         self.flatten = nn.Flatten()
#         self.classification = ClassificationHead(self.flatten_eeg1 , self.number_class) # FLATTEN_EEGNet + FLATTEN_cnn_module
    def forward(self, x):
#         branchEEGNet1 = self.eegNet(x)
#         # branchcnn_module1 = self.cnn_module(x)
#         features = torch.cat([branchEEGNet1], dim=-2) # branchcnn_module1, 
#         net = torch.cat([self.flatten(branchEEGNet1)], dim=-1) # , self.flatten(branchcnn_module1), 
#         out = self.classification(net)
        out = self.net(x)
        return 0, out


class ExP():
    def __init__(self, nsub, data_dir, result_name, 
                 epochs=2000, 
                 number_aug=2,
                 number_seg=8, 
                 gpus=[0], 
                 evaluate_mode = 'subject-dependent',
                 heads=4, 
                 emb_size=40,
                 depth=6, 
                 dataset_type='A',
                 eeg1_f1 = 20,
                 eeg1_kernel_size = 64,
                 eeg1_D = 2,
                 eeg1_pooling_size1 = 8,
                 eeg1_pooling_size2 = 8,
                 eeg1_dropout_rate = 0.3,
                 flatten_eeg1 = 600, 
                 validate_ratio = 0.2,
                 learning_rate = 0.001,
                 batch_size = 72,  
                 ):
        
        super(ExP, self).__init__()
        self.dataset_type = dataset_type
        self.batch_size = batch_size
        self.lr = learning_rate
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_epochs = epochs
        self.nSub = nsub
        self.number_augmentation = number_aug
        self.number_seg = number_seg
        self.root = data_dir
        self.heads=heads
        self.emb_size=emb_size
        self.depth=depth
        self.result_name = result_name
        self.evaluate_mode = evaluate_mode
        self.validate_ratio = validate_ratio

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.number_class, self.number_channel = numberClassChannel(self.dataset_type)
        self.model = EEGTransformer(
             heads=self.heads, 
             emb_size=self.emb_size,
             depth=self.depth, 
            database_type=self.dataset_type, 
            eeg1_f1=eeg1_f1, 
            eeg1_D=eeg1_D,
            eeg1_kernel_size=eeg1_kernel_size,
            eeg1_pooling_size1 = eeg1_pooling_size1,
            eeg1_pooling_size2 = eeg1_pooling_size2,
            eeg1_dropout_rate = eeg1_dropout_rate,
            eeg1_number_channel = self.number_channel,
            flatten_eeg1 = flatten_eeg1,  
            ).cuda()
        #self.model = nn.DataParallel(self.model, device_ids=gpus)
        self.model = self.model.cuda()
        self.model_filename = self.result_name + '/model_{}.pth'.format(self.nSub)

    # Segmentation and Reconstruction (S&R) data augmentation
    def interaug(self, timg, label):  
        aug_data = []
        aug_label = []
        number_records_by_augmentation = self.number_augmentation * int(self.batch_size / self.number_class)
        number_segmentation_points = 1000 // self.number_seg
        for clsAug in range(self.number_class):
            cls_idx = np.where(label == clsAug + 1)
            tmp_data = timg[cls_idx]
            tmp_label = label[cls_idx]
            
            tmp_aug_data = np.zeros((number_records_by_augmentation, 1, self.number_channel, 1000))
            for ri in range(number_records_by_augmentation):
                for rj in range(self.number_seg):
                    rand_idx = np.random.randint(0, tmp_data.shape[0], self.number_seg)
                    tmp_aug_data[ri, :, :, rj * number_segmentation_points:(rj + 1) * number_segmentation_points] = \
                        tmp_data[rand_idx[rj], :, :, rj * number_segmentation_points:(rj + 1) * number_segmentation_points]

            aug_data.append(tmp_aug_data)
            aug_label.append(tmp_label[:number_records_by_augmentation])
        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle, :, :]
        aug_label = aug_label[aug_shuffle]

        aug_data = torch.from_numpy(aug_data).cuda()
        aug_data = aug_data.float()
        aug_label = torch.from_numpy(aug_label-1).cuda()
        aug_label = aug_label.long()
        return aug_data, aug_label



    def get_source_data(self):
        (self.train_data,    # (batch, channel, length)
         self.train_label, 
         self.test_data, 
         self.test_label) = load_data_evaluate(self.root, self.dataset_type, self.nSub, mode_evaluate=self.evaluate_mode)

        self.train_data = np.expand_dims(self.train_data, axis=1)  # (288, 1, 22, 1000)
        self.train_label = np.transpose(self.train_label)  

        self.allData = self.train_data
        self.allLabel = self.train_label[0]  

        shuffle_num = np.random.permutation(len(self.allData))
        # print("len(self.allData):", len(self.allData))
        self.allData = self.allData[shuffle_num, :, :, :]  # (288, 1, 22, 1000)
        # print("shuffle_num", shuffle_num)
        # print("self.allLabel", self.allLabel)
        self.allLabel = self.allLabel[shuffle_num]


        print('-'*20, "train size：", self.train_data.shape, "test size：", self.test_data.shape)
        # self.test_data = np.transpose(self.test_data, (2, 1, 0))
        self.test_data = np.expand_dims(self.test_data, axis=1)
        self.test_label = np.transpose(self.test_label)

        self.testData = self.test_data
        self.testLabel = self.test_label[0]


        # standardize
        target_mean = np.mean(self.allData)
        target_std = np.std(self.allData)
        self.allData = (self.allData - target_mean) / target_std
        self.testData = (self.testData - target_mean) / target_std
        
        isSaveDataLabel = False #True
        if isSaveDataLabel:
            np.save("./gradm_data/train_data_{}.npy".format(self.nSub), self.allData)
            np.save("./gradm_data/train_lable_{}.npy".format(self.nSub), self.allLabel)
            np.save("./gradm_data/test_data_{}.npy".format(self.nSub), self.testData)
            np.save("./gradm_data/test_label_{}.npy".format(self.nSub), self.testLabel)

        
        # data shape: (trial, conv channel, electrode channel, time samples)
        return self.allData, self.allLabel, self.testData, self.testLabel


    def train(self):
        img, label, test_data, test_label = self.get_source_data()
        # print("label size:", label.shape)
        # print("label size:", label)
        
        img = torch.from_numpy(img)
        label = torch.from_numpy(label - 1)
        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=False)

        # Optimizers
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        test_data = Variable(test_data.type(self.Tensor))
        test_label = Variable(test_label.type(self.LongTensor))
        best_epoch = 0
        num = 0
        min_loss = 100
        # recording train_acc, train_loss, test_acc, test_loss
        result_process = []
        # Train the cnn model
        for e in range(self.n_epochs):
            epoch_process = {}
            epoch_process['epoch'] = e
            # in_epoch = time.time()
            self.model.train()
            outputs_list = []
            label_list = []
            # 验证集
            val_data_list = []
            val_label_list = []
            for i, (img, label) in enumerate(self.dataloader):
                number_sample = img.shape[0]
                number_validate = int(self.validate_ratio * number_sample)
                
                # split raw train dataset into real train dataset and validate dataset
                train_data = img[:-number_validate]
                train_label = label[:-number_validate]
                
                val_data_list.append(img[number_validate:])
                val_label_list.append(label[number_validate:])
                
                # real train dataset
                img = Variable(train_data.type(self.Tensor))
                label = Variable(train_label.type(self.LongTensor))
                
                # data augmentation
                aug_data, aug_label = self.interaug(self.allData, self.allLabel)
                # concat real train dataset and generate aritifical train dataset
                img = torch.cat((img, aug_data))
                label = torch.cat((label, aug_label))

                # training model
                features, outputs = self.model(img)
                outputs_list.append(outputs)
                label_list.append(label)
                # print("train outputs: ", outputs.shape, type(outputs))
                # print(features.size())
                loss = self.criterion_cls(outputs, label) 
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            del img
            torch.cuda.empty_cache()
            # out_epoch = time.time()
            # test process
            if (e + 1) % 1 == 0:
                self.model.eval()
                # validate model
                val_data = torch.cat(val_data_list).cuda()
                val_label = torch.cat(val_label_list).cuda()
                val_data = val_data.type(self.Tensor)
                val_label = val_label.type(self.LongTensor)            
                
                val_dataset = torch.utils.data.TensorDataset(val_data, val_label)
                self.val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=self.batch_size, shuffle=False)
                outputs_list = []
                with torch.no_grad():
                    for i, (img, _) in enumerate(self.val_dataloader):
                        # val model
                        img = img.type(self.Tensor).cuda()
                        _, Cls = self.model(img)
                        outputs_list.append(Cls)
                        del img, Cls
                        torch.cuda.empty_cache()
                    
                Cls = torch.cat(outputs_list)
                
                val_loss = self.criterion_cls(Cls, val_label)
                val_pred = torch.max(Cls, 1)[1]
                val_acc = float((val_pred == val_label).cpu().numpy().astype(int).sum()) / float(val_label.size(0))
                
                epoch_process['val_acc'] = val_acc                
                epoch_process['val_loss'] = val_loss.detach().cpu().numpy()  
                
                train_pred = torch.max(outputs, 1)[1]
                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
                epoch_process['train_acc'] = train_acc
                epoch_process['train_loss'] = loss.detach().cpu().numpy()

                num = num + 1

                # if min_loss>val_loss:                
                if min_loss>val_loss:
                    min_loss = val_loss
                    best_epoch = e
                    epoch_process['epoch'] = e
                    torch.save(self.model, self.model_filename)
                    print("{}_{} train_acc: {:.4f} train_loss: {:.6f}\tval_acc: {:.6f} val_loss: {:.7f}".format(self.nSub,
                                                                                           epoch_process['epoch'],
                                                                                           epoch_process['train_acc'],
                                                                                           epoch_process['train_loss'],
                                                                                           epoch_process['val_acc'],
                                                                                           epoch_process['val_loss'],
                                                                                        ))
            
                
            result_process.append(epoch_process)  

        
            del label, val_data, val_label
            torch.cuda.empty_cache()
        
        # load model for test
        self.model.eval()
        self.model = torch.load(self.model_filename).cuda()
        outputs_list = []
        with torch.no_grad():
            for i, (img, label) in enumerate(self.test_dataloader):
                img_test = Variable(img.type(self.Tensor)).cuda()
                # label_test = Variable(label.type(self.LongTensor))

                # test model
                features, outputs = self.model(img_test)
                val_pred = torch.max(outputs, 1)[1]
                outputs_list.append(outputs)
        outputs = torch.cat(outputs_list) 
        y_pred = torch.max(outputs, 1)[1]
        
        
        test_acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
        
        print("epoch: ", best_epoch, '\tThe test accuracy is:', test_acc)


        df_process = pd.DataFrame(result_process)

        return test_acc, test_label, y_pred, df_process, best_epoch
        # writer.close()
        







def main(dirs,                
         evaluate_mode = 'subject-dependent', # 评估模式：LOSO（跨个体）或其他（subject-dependent, subject-specific），
         heads=8,             # heads of MHA
         emb_size=48,         # token embding dim
         depth=3,             # Transformer encoder depth
         dataset_type='A',    # A->'BCI IV2a', B->'BCI IV2b'
         eeg1_f1=20,          # features of temporal conv
         eeg1_kernel_size=64, # kernel size of temporal conv
         eeg1_D=2,            # depth-wise conv 
         eeg1_pooling_size1=8,# p1
         eeg1_pooling_size2=8,# p2
         eeg1_dropout_rate=0.3,
         flatten_eeg1=600,   
         validate_ratio = 0.2
         ):

    if not os.path.exists(dirs):
        os.makedirs(dirs)

    result_write_metric = ExcelWriter(dirs+"/result_metric.xlsx")
    
    result_metric_dict = {}
    y_true_pred_dict = { }

    process_write = ExcelWriter(dirs+"/process_train.xlsx")
    pred_true_write = ExcelWriter(dirs+"/pred_true.xlsx")
    subjects_result = []
    best_epochs = []
    
    for i in range(N_SUBJECT):      
        
        starttime = datetime.datetime.now()
        seed_n = np.random.randint(2024)
        print('seed is ' + str(seed_n))
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)
        index_round =0
        print('Subject %d' % (i+1))
        exp = ExP(i + 1, DATA_DIR, dirs, EPOCHS, N_AUG, N_SEG, gpus, 
                  evaluate_mode = evaluate_mode,
                  heads=heads, 
                  emb_size=emb_size,
                  depth=depth, 
                  dataset_type=dataset_type,
                  eeg1_f1 = eeg1_f1,
                  eeg1_kernel_size = eeg1_kernel_size,
                  eeg1_D = eeg1_D,
                  eeg1_pooling_size1 = eeg1_pooling_size1,
                  eeg1_pooling_size2 = eeg1_pooling_size2,
                  eeg1_dropout_rate = eeg1_dropout_rate,
                  flatten_eeg1 = flatten_eeg1,  
                  validate_ratio = validate_ratio
                  )

        testAcc, Y_true, Y_pred, df_process, best_epoch = exp.train()
        true_cpu = Y_true.cpu().numpy().astype(int)
        pred_cpu = Y_pred.cpu().numpy().astype(int)
        df_pred_true = pd.DataFrame({'pred': pred_cpu, 'true': true_cpu})
        df_pred_true.to_excel(pred_true_write, sheet_name=str(i+1))
        y_true_pred_dict[i] = df_pred_true

        accuracy, precison, recall, f1, kappa = calMetrics(true_cpu, pred_cpu)
        subject_result = {'accuray': accuracy*100,
                          'precision': precison*100,
                          'recall': recall*100,
                          'f1': f1*100, 
                          'kappa': kappa*100
                          }
        subjects_result.append(subject_result)
        df_process.to_excel(process_write, sheet_name=str(i+1))
        best_epochs.append(best_epoch)
    
        print(' THE BEST ACCURACY IS ' + str(testAcc) + "\tkappa is " + str(kappa) )
    

        endtime = datetime.datetime.now()
        print('subject %d duration: '%(i+1) + str(endtime - starttime))

        if i == 0:
            yt = Y_true
            yp = Y_pred
        else:
            yt = torch.cat((yt, Y_true))
            yp = torch.cat((yp, Y_pred))
                
        df_result = pd.DataFrame(subjects_result)
    process_write.close()
    pred_true_write.close()


    print('**The average Best accuracy is: ' + str(df_result['accuray'].mean()) + "kappa is: " + str(df_result['kappa'].mean()) + "\n" )
    print("best epochs: ", best_epochs)
    #df_result.to_excel(result_write_metric, index=False)
    result_metric_dict = df_result

    mean = df_result.mean(axis=0)
    mean.name = 'mean'
    std = df_result.std(axis=0)
    std.name = 'std'
    df_result = pd.concat([df_result, pd.DataFrame(mean).T, pd.DataFrame(std).T])
    
    df_result.to_excel(result_write_metric, index=False)
    print('-'*9, ' all result ', '-'*9)
    print(df_result)
    
    print("*"*40)

    result_write_metric.close()

    
    return result_metric_dict

if __name__ == "__main__":
    #----------------------------------------
    DATA_DIR = r'../mymat_raw/'
    EVALUATE_MODE = 'LOSO-No' # leaving one subject out subject-dependent  subject-indenpedent

    N_SUBJECT = 9       # BCI 
    N_AUG = 3           # data augmentation times for benerating artificial training data set
    N_SEG = 8           # segmentation times for S&R

    EPOCHS = 1000
    EMB_DIM = 48
    HEADS = 8
    DEPTH = 3
    TYPE = 'B'
    validate_ratio = 0.3 # split raw train dataset into real train dataset and validate dataset

    EEGNet1_F1 = 24
    EEGNet1_KERNEL_SIZE=64
    EEGNet1_D=2
    EEGNet1_POOL_SIZE1 = 8
    EEGNet1_POOL_SIZE2 = 8

    FLATTEN_EEGNet1 = 720

    
    if EVALUATE_MODE!='LOSO':
        EEGNet1_DROPOUT_RATE = 0.5
    else:
        EEGNet1_DROPOUT_RATE = 0.25  
    
    
    parameters_list = ['A', 'B']
    for TYPE in parameters_list:
        number_class, number_channel = numberClassChannel(TYPE)
        RESULT_NAME = "Conformer_同等_{}_".format(TYPE)
    
        sModel = EEGTransformer(
            heads=HEADS, 
            emb_size=EMB_DIM,
            depth=DEPTH, 
            database_type=TYPE,
            eeg1_f1=EEGNet1_F1, 
            eeg1_D=EEGNet1_D,
            eeg1_kernel_size=EEGNet1_KERNEL_SIZE,
            eeg1_pooling_size1 = EEGNet1_POOL_SIZE1,
            eeg1_pooling_size2 = EEGNet1_POOL_SIZE2,
            eeg1_dropout_rate = EEGNet1_DROPOUT_RATE,
            eeg1_number_channel = number_channel,
            flatten_eeg1 = FLATTEN_EEGNet1,  
            ).cuda()
        summary(sModel, (1, number_channel, 1000)) 
    
        print(time.asctime(time.localtime(time.time())))
        
        result = main(RESULT_NAME,
                        evaluate_mode = EVALUATE_MODE,
                        heads=HEADS, 
                        emb_size=EMB_DIM,
                        depth=DEPTH, 
                        dataset_type=TYPE,
                        eeg1_f1 = EEGNet1_F1,
                        eeg1_kernel_size = EEGNet1_KERNEL_SIZE,
                        eeg1_D = EEGNet1_D,
                        eeg1_pooling_size1 = EEGNet1_POOL_SIZE1,
                        eeg1_pooling_size2 = EEGNet1_POOL_SIZE2,
                        eeg1_dropout_rate = EEGNet1_DROPOUT_RATE,
                        flatten_eeg1 = FLATTEN_EEGNet1,
                        validate_ratio = validate_ratio,
                      )
        print(time.asctime(time.localtime(time.time())))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 40, 22, 976]           1,040
            Conv2d-2           [-1, 40, 1, 976]          35,240
       BatchNorm2d-3           [-1, 40, 1, 976]              80
               ELU-4           [-1, 40, 1, 976]               0
         AvgPool2d-5            [-1, 40, 1, 61]               0
           Dropout-6            [-1, 40, 1, 61]               0
            Conv2d-7            [-1, 40, 1, 61]           1,640
         Rearrange-8               [-1, 61, 40]               0
         LayerNorm-9               [-1, 61, 40]              80
           Linear-10               [-1, 61, 40]           1,640
           Linear-11               [-1, 61, 40]           1,640
           Linear-12               [-1, 61, 40]           1,640
          Dropout-13           [-1, 10, 61, 61]               0
           Linear-14               [-1,

1_17 train_acc: 0.7903 train_loss: 0.531925	val_acc: 0.852941 val_loss: 0.3680937
1_20 train_acc: 0.7978 train_loss: 0.543555	val_acc: 0.862745 val_loss: 0.3142221
1_26 train_acc: 0.8614 train_loss: 0.367490	val_acc: 0.901961 val_loss: 0.2921475
1_33 train_acc: 0.9176 train_loss: 0.208764	val_acc: 0.887255 val_loss: 0.2638387
1_35 train_acc: 0.8689 train_loss: 0.320940	val_acc: 0.887255 val_loss: 0.2609027
1_39 train_acc: 0.8914 train_loss: 0.263110	val_acc: 0.887255 val_loss: 0.2559395
1_40 train_acc: 0.8614 train_loss: 0.374204	val_acc: 0.931373 val_loss: 0.1919388
1_42 train_acc: 0.9326 train_loss: 0.190137	val_acc: 0.901961 val_loss: 0.1916173
1_43 train_acc: 0.9326 train_loss: 0.158938	val_acc: 0.946078 val_loss: 0.1543802
1_49 train_acc: 0.9101 train_loss: 0.233476	val_acc: 0.950980 val_loss: 0.1175394
1_54 train_acc: 0.9326 train_loss: 0.174541	val_acc: 0.960784 val_loss: 0.1094528
1_76 train_acc: 0.9513 train_loss: 0.141057	val_acc: 0.950980 val_loss: 0.0883269
1_85 train_acc: 

2_228 train_acc: 0.9288 train_loss: 0.196045	val_acc: 1.000000 val_loss: 0.0026569
2_253 train_acc: 0.9476 train_loss: 0.128107	val_acc: 1.000000 val_loss: 0.0024609
2_257 train_acc: 0.9476 train_loss: 0.136660	val_acc: 1.000000 val_loss: 0.0018530
2_258 train_acc: 0.9288 train_loss: 0.152044	val_acc: 1.000000 val_loss: 0.0014816
2_267 train_acc: 0.9476 train_loss: 0.140395	val_acc: 1.000000 val_loss: 0.0012483
2_281 train_acc: 0.9775 train_loss: 0.071973	val_acc: 1.000000 val_loss: 0.0009510
2_283 train_acc: 0.9738 train_loss: 0.094998	val_acc: 1.000000 val_loss: 0.0009349
2_285 train_acc: 0.9888 train_loss: 0.062351	val_acc: 1.000000 val_loss: 0.0005657
2_348 train_acc: 0.9888 train_loss: 0.052947	val_acc: 1.000000 val_loss: 0.0004125
2_349 train_acc: 0.9588 train_loss: 0.102466	val_acc: 1.000000 val_loss: 0.0003722
2_369 train_acc: 0.9663 train_loss: 0.115746	val_acc: 1.000000 val_loss: 0.0003409
2_376 train_acc: 0.9625 train_loss: 0.094445	val_acc: 1.000000 val_loss: 0.0003292
2_37

4_70 train_acc: 0.8951 train_loss: 0.264941	val_acc: 0.965686 val_loss: 0.1146097
4_72 train_acc: 0.8614 train_loss: 0.402047	val_acc: 0.965686 val_loss: 0.1128757
4_73 train_acc: 0.9026 train_loss: 0.259905	val_acc: 0.980392 val_loss: 0.0949056
4_74 train_acc: 0.9064 train_loss: 0.257923	val_acc: 0.975490 val_loss: 0.0795194
4_87 train_acc: 0.9101 train_loss: 0.233146	val_acc: 0.975490 val_loss: 0.0676463
4_90 train_acc: 0.8727 train_loss: 0.321016	val_acc: 0.995098 val_loss: 0.0553836
4_97 train_acc: 0.9326 train_loss: 0.183821	val_acc: 0.995098 val_loss: 0.0415028
4_102 train_acc: 0.8876 train_loss: 0.311548	val_acc: 1.000000 val_loss: 0.0308961
4_119 train_acc: 0.9251 train_loss: 0.214155	val_acc: 0.995098 val_loss: 0.0250017
4_126 train_acc: 0.9326 train_loss: 0.193490	val_acc: 0.990196 val_loss: 0.0240490
4_128 train_acc: 0.9700 train_loss: 0.114817	val_acc: 1.000000 val_loss: 0.0119049
4_143 train_acc: 0.9176 train_loss: 0.200697	val_acc: 1.000000 val_loss: 0.0101807
4_147 train

5_476 train_acc: 0.9850 train_loss: 0.043259	val_acc: 1.000000 val_loss: 0.0001151
5_481 train_acc: 0.9850 train_loss: 0.040852	val_acc: 1.000000 val_loss: 0.0000982
5_484 train_acc: 0.9850 train_loss: 0.058219	val_acc: 1.000000 val_loss: 0.0000935
5_496 train_acc: 0.9775 train_loss: 0.046338	val_acc: 1.000000 val_loss: 0.0000799
5_507 train_acc: 0.9738 train_loss: 0.063079	val_acc: 1.000000 val_loss: 0.0000687
5_522 train_acc: 0.9850 train_loss: 0.033498	val_acc: 1.000000 val_loss: 0.0000298
5_523 train_acc: 0.9775 train_loss: 0.050686	val_acc: 1.000000 val_loss: 0.0000170
5_649 train_acc: 0.9888 train_loss: 0.021600	val_acc: 1.000000 val_loss: 0.0000083
5_725 train_acc: 0.9850 train_loss: 0.057643	val_acc: 1.000000 val_loss: 0.0000067
5_737 train_acc: 0.9888 train_loss: 0.023736	val_acc: 1.000000 val_loss: 0.0000065
5_830 train_acc: 0.9888 train_loss: 0.025827	val_acc: 1.000000 val_loss: 0.0000051
5_835 train_acc: 0.9850 train_loss: 0.063178	val_acc: 1.000000 val_loss: 0.0000050
5_84

7_30 train_acc: 0.9326 train_loss: 0.189407	val_acc: 0.911765 val_loss: 0.2384445
7_31 train_acc: 0.8727 train_loss: 0.313953	val_acc: 0.926471 val_loss: 0.1927720
7_38 train_acc: 0.9026 train_loss: 0.212736	val_acc: 0.946078 val_loss: 0.1417108
7_45 train_acc: 0.9513 train_loss: 0.138357	val_acc: 0.955882 val_loss: 0.1171893
7_54 train_acc: 0.9326 train_loss: 0.205837	val_acc: 0.955882 val_loss: 0.1081715
7_56 train_acc: 0.8839 train_loss: 0.267905	val_acc: 0.980392 val_loss: 0.0583233
7_65 train_acc: 0.9438 train_loss: 0.173838	val_acc: 0.985294 val_loss: 0.0542993
7_67 train_acc: 0.9438 train_loss: 0.169603	val_acc: 0.975490 val_loss: 0.0512776
7_69 train_acc: 0.9326 train_loss: 0.174287	val_acc: 0.980392 val_loss: 0.0511000
7_71 train_acc: 0.9438 train_loss: 0.169604	val_acc: 0.985294 val_loss: 0.0335278
7_86 train_acc: 0.9663 train_loss: 0.098888	val_acc: 0.990196 val_loss: 0.0297412
7_88 train_acc: 0.9326 train_loss: 0.176562	val_acc: 0.995098 val_loss: 0.0245194
7_89 train_acc: 

9_54 train_acc: 0.9738 train_loss: 0.071242	val_acc: 0.995098 val_loss: 0.0213025
9_55 train_acc: 0.9363 train_loss: 0.170063	val_acc: 0.995098 val_loss: 0.0166856
9_69 train_acc: 0.9551 train_loss: 0.088410	val_acc: 0.995098 val_loss: 0.0155613
9_70 train_acc: 0.9551 train_loss: 0.120158	val_acc: 0.995098 val_loss: 0.0088509
9_71 train_acc: 0.9401 train_loss: 0.134399	val_acc: 1.000000 val_loss: 0.0068639
9_78 train_acc: 0.9775 train_loss: 0.068026	val_acc: 1.000000 val_loss: 0.0061781
9_79 train_acc: 0.9888 train_loss: 0.047694	val_acc: 1.000000 val_loss: 0.0052984
9_88 train_acc: 0.9625 train_loss: 0.093630	val_acc: 1.000000 val_loss: 0.0049481
9_94 train_acc: 0.9888 train_loss: 0.036819	val_acc: 1.000000 val_loss: 0.0035867
9_107 train_acc: 0.9775 train_loss: 0.054968	val_acc: 1.000000 val_loss: 0.0022312
9_114 train_acc: 0.9813 train_loss: 0.047123	val_acc: 1.000000 val_loss: 0.0011052
9_116 train_acc: 0.9700 train_loss: 0.086519	val_acc: 1.000000 val_loss: 0.0006453
9_148 train_a

1_0 train_acc: 0.5738 train_loss: 0.787573	val_acc: 0.522968 val_loss: 0.6818902
1_1 train_acc: 0.5738 train_loss: 0.740669	val_acc: 0.639576 val_loss: 0.6516152
1_2 train_acc: 0.5779 train_loss: 0.741151	val_acc: 0.671378 val_loss: 0.6066140
1_3 train_acc: 0.6311 train_loss: 0.637833	val_acc: 0.696113 val_loss: 0.6026442
1_4 train_acc: 0.6475 train_loss: 0.691990	val_acc: 0.717314 val_loss: 0.5871862
1_6 train_acc: 0.6721 train_loss: 0.633349	val_acc: 0.756184 val_loss: 0.5088624
1_10 train_acc: 0.7131 train_loss: 0.566417	val_acc: 0.756184 val_loss: 0.5055631
1_11 train_acc: 0.7500 train_loss: 0.484529	val_acc: 0.809187 val_loss: 0.3832566
1_20 train_acc: 0.7418 train_loss: 0.563054	val_acc: 0.851590 val_loss: 0.3593350
1_22 train_acc: 0.7869 train_loss: 0.467433	val_acc: 0.840989 val_loss: 0.3584891
1_24 train_acc: 0.8402 train_loss: 0.335536	val_acc: 0.869258 val_loss: 0.2845435
1_28 train_acc: 0.8566 train_loss: 0.336474	val_acc: 0.883392 val_loss: 0.2658859
1_30 train_acc: 0.8770

2_304 train_acc: 0.9344 train_loss: 0.198449	val_acc: 0.992933 val_loss: 0.0351071
2_321 train_acc: 0.9098 train_loss: 0.190162	val_acc: 0.992933 val_loss: 0.0278896
2_339 train_acc: 0.9139 train_loss: 0.209165	val_acc: 1.000000 val_loss: 0.0244654
2_341 train_acc: 0.9303 train_loss: 0.169127	val_acc: 1.000000 val_loss: 0.0221939
2_348 train_acc: 0.9467 train_loss: 0.172295	val_acc: 1.000000 val_loss: 0.0205490
2_356 train_acc: 0.9303 train_loss: 0.131061	val_acc: 1.000000 val_loss: 0.0187838
2_372 train_acc: 0.9426 train_loss: 0.137031	val_acc: 1.000000 val_loss: 0.0124411
2_402 train_acc: 0.9344 train_loss: 0.145060	val_acc: 1.000000 val_loss: 0.0112761
2_429 train_acc: 0.9508 train_loss: 0.142136	val_acc: 1.000000 val_loss: 0.0092350
2_460 train_acc: 0.9303 train_loss: 0.183681	val_acc: 1.000000 val_loss: 0.0086558
2_467 train_acc: 0.9426 train_loss: 0.132482	val_acc: 1.000000 val_loss: 0.0083351
2_468 train_acc: 0.9549 train_loss: 0.117106	val_acc: 1.000000 val_loss: 0.0048715
2_53

3_384 train_acc: 0.9180 train_loss: 0.207238	val_acc: 0.996466 val_loss: 0.0294163
3_389 train_acc: 0.9139 train_loss: 0.231242	val_acc: 0.992933 val_loss: 0.0293332
3_406 train_acc: 0.9426 train_loss: 0.143380	val_acc: 0.996466 val_loss: 0.0289945
3_424 train_acc: 0.9508 train_loss: 0.166514	val_acc: 0.996466 val_loss: 0.0271979
3_447 train_acc: 0.9221 train_loss: 0.185551	val_acc: 0.996466 val_loss: 0.0262817
3_453 train_acc: 0.9303 train_loss: 0.184258	val_acc: 0.992933 val_loss: 0.0251587
3_462 train_acc: 0.9344 train_loss: 0.146560	val_acc: 1.000000 val_loss: 0.0166653
3_514 train_acc: 0.9303 train_loss: 0.181820	val_acc: 1.000000 val_loss: 0.0120772
3_540 train_acc: 0.9508 train_loss: 0.144549	val_acc: 1.000000 val_loss: 0.0120000
3_546 train_acc: 0.9098 train_loss: 0.206978	val_acc: 1.000000 val_loss: 0.0114043
3_559 train_acc: 0.9385 train_loss: 0.166770	val_acc: 1.000000 val_loss: 0.0087712
3_569 train_acc: 0.9221 train_loss: 0.212366	val_acc: 1.000000 val_loss: 0.0074810
3_60

5_66 train_acc: 0.9302 train_loss: 0.175630	val_acc: 0.925926 val_loss: 0.2040293
5_70 train_acc: 0.9147 train_loss: 0.213165	val_acc: 0.925926 val_loss: 0.2032105
5_76 train_acc: 0.9496 train_loss: 0.165164	val_acc: 0.936027 val_loss: 0.1871921
5_86 train_acc: 0.9264 train_loss: 0.168132	val_acc: 0.939394 val_loss: 0.1561217
5_96 train_acc: 0.9419 train_loss: 0.140750	val_acc: 0.939394 val_loss: 0.1499123
5_100 train_acc: 0.9109 train_loss: 0.206050	val_acc: 0.939394 val_loss: 0.1387348
5_105 train_acc: 0.9612 train_loss: 0.127908	val_acc: 0.946128 val_loss: 0.1326508
5_111 train_acc: 0.9574 train_loss: 0.112821	val_acc: 0.946128 val_loss: 0.1233876
5_129 train_acc: 0.9419 train_loss: 0.144130	val_acc: 0.959596 val_loss: 0.1197082
5_132 train_acc: 0.9574 train_loss: 0.109922	val_acc: 0.946128 val_loss: 0.1193534
5_135 train_acc: 0.9806 train_loss: 0.082856	val_acc: 0.949495 val_loss: 0.1143817
5_143 train_acc: 0.9496 train_loss: 0.134846	val_acc: 0.962963 val_loss: 0.1106636
5_145 tra

6_736 train_acc: 0.9672 train_loss: 0.095610	val_acc: 1.000000 val_loss: 0.0017166
6_760 train_acc: 0.9672 train_loss: 0.070845	val_acc: 1.000000 val_loss: 0.0014565
6_816 train_acc: 0.9795 train_loss: 0.052410	val_acc: 1.000000 val_loss: 0.0014169
6_818 train_acc: 0.9836 train_loss: 0.043009	val_acc: 1.000000 val_loss: 0.0013949
6_852 train_acc: 0.9590 train_loss: 0.077055	val_acc: 1.000000 val_loss: 0.0009286
6_914 train_acc: 0.9672 train_loss: 0.078325	val_acc: 1.000000 val_loss: 0.0008453
6_915 train_acc: 0.9836 train_loss: 0.066440	val_acc: 1.000000 val_loss: 0.0006100
6_918 train_acc: 0.9713 train_loss: 0.065291	val_acc: 1.000000 val_loss: 0.0005620
6_984 train_acc: 0.9508 train_loss: 0.135954	val_acc: 1.000000 val_loss: 0.0005009
6_988 train_acc: 0.9795 train_loss: 0.056521	val_acc: 1.000000 val_loss: 0.0004210
6_997 train_acc: 0.9754 train_loss: 0.061832	val_acc: 1.000000 val_loss: 0.0003732
epoch:  997 	The test accuracy is: 0.83125
 THE BEST ACCURACY IS 0.83125	kappa is 0.662

8_643 train_acc: 0.9775 train_loss: 0.052930	val_acc: 1.000000 val_loss: 0.0025450
8_682 train_acc: 0.9730 train_loss: 0.075195	val_acc: 1.000000 val_loss: 0.0018492
8_735 train_acc: 0.9820 train_loss: 0.036802	val_acc: 1.000000 val_loss: 0.0014335
8_745 train_acc: 0.9595 train_loss: 0.061406	val_acc: 1.000000 val_loss: 0.0014145
8_748 train_acc: 0.9820 train_loss: 0.078801	val_acc: 1.000000 val_loss: 0.0012711
8_796 train_acc: 0.9865 train_loss: 0.040257	val_acc: 1.000000 val_loss: 0.0011821
8_808 train_acc: 0.9640 train_loss: 0.097603	val_acc: 1.000000 val_loss: 0.0011058
8_818 train_acc: 0.9550 train_loss: 0.110861	val_acc: 1.000000 val_loss: 0.0008342
8_834 train_acc: 0.9414 train_loss: 0.180813	val_acc: 1.000000 val_loss: 0.0005599
8_911 train_acc: 0.9640 train_loss: 0.100402	val_acc: 1.000000 val_loss: 0.0004088
8_932 train_acc: 0.9550 train_loss: 0.135827	val_acc: 1.000000 val_loss: 0.0003086
8_990 train_acc: 0.9865 train_loss: 0.040848	val_acc: 1.000000 val_loss: 0.0002272
epoc