In [1]:
"""
Attention-based multi-scale convolutional neural network for motor imagery classification

author: zhaowei701@163.com

"""

import os
gpu_number = 1
gpus = [gpu_number]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_number)
import numpy as np
import pandas as pd
import random
import datetime
import time

from pandas import ExcelWriter
from torchsummary import summary
import torch
from torch.backends import cudnn
from utils import calMetrics
from utils import calculatePerClass
from utils import numberClassChannel
import math
import warnings
warnings.filterwarnings("ignore")
cudnn.benchmark = False
cudnn.deterministic = True



import torch
from torch import nn
from torch import Tensor
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange, reduce, repeat
import torch.nn.functional as F

from utils import numberClassChannel
from utils import load_data_evaluate

import numpy as np
import pandas as pd
from torch.autograd import Variable


class PatchEmbeddingCNN(nn.Module):
    def __init__(self, 
                 f1=16, 
                 pooling_size=52, 
                 dropout_rate=0.5, 
                 number_channel=22):
        super().__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(1, f1, (1, 85), (1, 1), padding='same'),
            nn.Conv2d(f1, f1, (number_channel, 1), (1, 1), groups=f1),
            nn.BatchNorm2d(f1),
            nn.ELU(),
            nn.AvgPool2d((1,pooling_size)), 
            nn.Dropout(dropout_rate),
        )
        self.cnn2 = nn.Sequential(
            nn.Conv2d(1, f1, (1, 65), (1, 1), padding='same'),
            nn.Conv2d(f1, f1, (number_channel, 1), (1, 1), groups=f1),
            nn.BatchNorm2d(f1),
            nn.ELU(),
            nn.AvgPool2d((1,pooling_size)), 
            nn.Dropout(dropout_rate),
        )        
        self.cnn3 = nn.Sequential(
            nn.Conv2d(1, f1, (1, 45), (1, 1), padding='same'),
            nn.Conv2d(f1, f1, (number_channel, 1), (1, 1), groups=f1),
            nn.BatchNorm2d(f1),
            nn.ELU(),
            nn.AvgPool2d((1,pooling_size)), 
            nn.Dropout(dropout_rate),
        )
        self.projection = nn.Sequential(
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        
        
    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x1 = self.cnn1(x)
        x2 = self.cnn2(x)
        x3 = self.cnn3(x)
        #通道方向合并
        x = torch.cat([x1, x2, x3], dim=1)
        x = self.projection(x)
        return x    
    
  
    
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
    


# PointWise FFN
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion, drop_p):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )



class ClassificationHead(nn.Sequential):
    def __init__(self, flatten_number, n_classes):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Dropout(0.25),
            nn.Linear(flatten_number, n_classes)
        )

    def forward(self, x):
        out = self.fc(x)
        
        return out

# 先求和再LayerNorm
class ResidualAdd(nn.Module):
    def __init__(self, fn, emb_size, drop_p):
        super().__init__()
        self.fn = fn
        self.drop = nn.Dropout(drop_p)
        self.layernorm = nn.LayerNorm(emb_size)

    def forward(self, x, **kwargs):
        x_input = x
        res = self.fn(x, **kwargs)
        out = self.layernorm(self.drop(res)+x_input)
        return out

    
class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size,
                 num_heads=4,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        super().__init__(
            ResidualAdd(nn.Sequential(
                MultiHeadAttention(emb_size, num_heads, drop_p),
                ), emb_size, drop_p),
            ResidualAdd(nn.Sequential(
                FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                ), emb_size, drop_p)
            )    
        
        
class TransformerEncoder(nn.Sequential):
    def __init__(self, heads, depth, emb_size):
        super().__init__(*[TransformerEncoderBlock(emb_size, heads) for _ in range(depth)])


class BranchEEGNetTransformer(nn.Sequential):
    def __init__(self, parameters,
                 **kwargs):
        super().__init__(
            PatchEmbeddingCNN(f1=parameters.f1, 
                                 pooling_size=parameters.pooling_size, 
                                 dropout_rate=parameters.dropout_rate,
                                 number_channel=parameters.number_channel,
                                 ),
#             TransformerEncoder(heads, depth, emb_size),
        )


class PositioinalEncoding(nn.Module):
    def __init__(self, embedding, length=100, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.encoding = nn.Parameter(torch.randn(1, length, embedding))
    def forward(self, x): # x-> [batch, embedding, length]
        x = x + self.encoding[:, :x.shape[1], :].cuda()
        return self.dropout(x)        
        
    
class EEGTransformer(nn.Module):
    def __init__(self, 
                 parameters,
                 database_type='A', 
                 
                 **kwargs):
        super().__init__()
        self.number_class, self.number_channel = numberClassChannel(database_type)
        self.emb_size = parameters.emb_size
        parameters.number_channel = self.number_channel
        self.cnn = BranchEEGNetTransformer(parameters)
        self.position = PositioinalEncoding(parameters.emb_size, dropout=0.1)
        self.trans = TransformerEncoder(parameters.heads, 
                                        parameters.depth, 
                                        parameters.emb_size)

        self.flatten = nn.Flatten()
        self.classification = ClassificationHead(self.emb_size , self.number_class) 
    def forward(self, x):
        x = self.cnn(x)
        b, l, e = x.shape
        x = torch.cat((torch.zeros((b, 1, e),requires_grad=True).cuda(),x), 1)
        x = x * math.sqrt(self.emb_size)
        x = self.position(x)
        trans = self.trans(x)
        features = trans[:, 0, :]
        
        out = self.classification(features)
        return features, out


class ExP():
    def __init__(self, nsub, data_dir, result_name,
                 parameters,
                 evaluate_mode = 'LOSO-no',
                 dataset_type='A',
                 n_fold = 0,
                 ):
        
        super(ExP, self).__init__()
        self.n_fold = n_fold
        self.dataset_type = dataset_type
        self.batch_size = parameters.batch_size
        self.lr = parameters.learning_rate
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_epochs = parameters.epochs
        self.nSub = nsub
        self.nFold = n_fold
        self.number_augmentation = parameters.number_aug
        self.number_seg = parameters.number_seg
        self.root = data_dir
        self.result_name = result_name
        self.evaluate_mode = evaluate_mode
        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.number_class, self.number_channel = numberClassChannel(dataset_type)
        self.model = EEGTransformer(
            database_type=self.dataset_type, 
            parameters = parameters, 
            ).cuda()
        #self.model = nn.DataParallel(self.model, device_ids=gpus)
        self.model = self.model.cuda()
        self.model_filename = self.result_name + '/model_nsub_{}_nfold_{}.pth'.format(self.nSub, n_fold+1)

    # Segmentation and Reconstruction (S&R) data augmentation
    def interaug(self, timg, label):  
        aug_data = []
        aug_label = []
        number_records_by_augmentation = self.number_augmentation * int(self.batch_size / self.number_class)
        number_segmentation_points = 1000 // self.number_seg
        for clsAug in range(self.number_class):
            cls_idx = np.where(label == clsAug + 1)
            tmp_data = timg[cls_idx]
            tmp_label = label[cls_idx]
            tmp_aug_data = np.zeros((number_records_by_augmentation, 1, self.number_channel, 1000))
            for ri in range(number_records_by_augmentation):
                for rj in range(self.number_seg):
                    rand_idx = np.random.randint(0, tmp_data.shape[0], self.number_seg)
                    tmp_aug_data[ri, :, :, rj * number_segmentation_points:(rj + 1) * number_segmentation_points] = \
                        tmp_data[rand_idx[rj], :, :, rj * number_segmentation_points:(rj + 1) * number_segmentation_points]

            aug_data.append(tmp_aug_data)
            aug_label.append(tmp_label[:number_records_by_augmentation])
        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle, :, :]
        aug_label = aug_label[aug_shuffle]

        aug_data = torch.from_numpy(aug_data).cuda()
        aug_data = aug_data.float()
        aug_label = torch.from_numpy(aug_label-1).cuda()
        aug_label = aug_label.long()
        return aug_data, aug_label



    def get_source_data(self):
        (self.train_data,    # (batch, channel, length)
         self.train_label, 
         self.test_data, 
         self.test_label) = load_data_evaluate(self.root, self.dataset_type, self.nSub, mode_evaluate=self.evaluate_mode)

        self.train_data = np.expand_dims(self.train_data, axis=1)  # (288, 1, 22, 1000)
        self.train_label = np.transpose(self.train_label)  

        self.allData = self.train_data
        self.allLabel = self.train_label[0]  

        shuffle_num = np.random.permutation(len(self.allData))
        # print("len(self.allData):", len(self.allData))
        self.allData = self.allData[shuffle_num, :, :, :]  # (288, 1, 22, 1000)
        # print("shuffle_num", shuffle_num)
        # print("self.allLabel", self.allLabel)
        self.allLabel = self.allLabel[shuffle_num]


        print('-'*20, 
              "raw train size：", self.train_data.shape, 
              "test size：", self.test_data.shape, 
              "subject:", self.nSub,
              "fold:", self.nFold+1)
        # self.test_data = np.transpose(self.test_data, (2, 1, 0))
        self.test_data = np.expand_dims(self.test_data, axis=1)
        self.test_label = np.transpose(self.test_label)

        self.testData = self.test_data
        self.testLabel = self.test_label[0]
        
        
        
        # standardize
        target_mean = np.mean(self.allData)
        target_std = np.std(self.allData)
        self.allData = (self.allData - target_mean) / target_std
        self.testData = (self.testData - target_mean) / target_std
        
        isSaveDataLabel = False #True
        if isSaveDataLabel:
            np.save("./gradm_data/train_data_{}.npy".format(self.nSub), self.allData)
            np.save("./gradm_data/train_lable_{}.npy".format(self.nSub), self.allLabel)
            np.save("./gradm_data/test_data_{}.npy".format(self.nSub), self.testData)
            np.save("./gradm_data/test_label_{}.npy".format(self.nSub), self.testLabel)

        
        # data shape: (trial, conv channel, electrode channel, time samples)
        return self.allData, self.allLabel, self.testData, self.testLabel

    
    def test_model(self, model, dataloader):
        model.eval()
        outputs_list = []
        label_list = []
        with torch.no_grad():
            for i, (img, label) in enumerate(dataloader):
                # val model
                img = img.type(self.Tensor).cuda()
                label = label.type(self.LongTensor).cuda()
                _, Cls = model(img)
                outputs_list.append(Cls)
                del img, Cls
                torch.cuda.empty_cache()
                label_list.append(label)
            
        Cls = torch.cat(outputs_list)
        val_label = torch.cat(label_list)
        val_loss = self.criterion_cls(Cls, val_label)
        val_pred = torch.max(Cls, 1)[1]
        val_acc = float((val_pred == val_label).cpu().numpy().astype(int).sum()) / float(val_label.size(0))
        return val_acc, val_loss, val_pred

    
    
    

    def train(self):
        timg, label, test_data, test_label = self.get_source_data()
        
        # 训练数据集合做5折交叉验证
        # 先分类别
        train_data_list_per_class = []
        train_label_list_per_class = []
        for clsAug in range(self.number_class):
            cls_idx = np.where(label == clsAug + 1)
            tmp_data = timg[cls_idx]
            tmp_label = label[cls_idx]        
            train_data_list_per_class.append(tmp_data)
            train_label_list_per_class.append(tmp_label)
            
        # 再从每个类别里取出一折用于测试，其他用于训练
        train_data_list = []
        train_label_list = []
        val_data_list, val_label_list =  [], []
        seed = 1234+self.nSub
        for clsAug in range(self.number_class):
            # 计算每个类别的数量
            number_samples = len(train_data_list_per_class[clsAug])
            number_test = number_samples // 5
            # 索引乱序，用于切分训练集和测试集
            index_list = list(range(number_samples))
            np.random.seed(seed+clsAug)
            # 乱序的索引，根据索引取数，相当于对数据集做shuffle
            index_shuffled = np.random.permutation(index_list)    
            # print(index_shuffled[:10])
            # 训练集和测试的索引的索引序号, 用于取出第几折的乱序索引
            if self.n_fold!=4 :
                index_val = [i for i in range(self.n_fold*number_test, (self.n_fold+1)*number_test)]
            else:
                #     由于288无法被5整除，最后一折取剩下所有的
                index_val = [i for i in range(self.n_fold*number_test, number_samples)]

            index_train = [i for i in range(number_samples) if i not in index_val]
            # 训练集和测试集的索引
            index_train = index_shuffled[index_train]
            index_val = index_shuffled[index_val]   
            
            train_data_class = train_data_list_per_class[clsAug][index_train]
            train_label_class = train_label_list_per_class[clsAug][index_train]
            train_data_list.append(train_data_class)
            train_label_list.append(train_label_class)
            
            val_data_class = train_data_list_per_class[clsAug][index_val]
            val_label_class = train_label_list_per_class[clsAug][index_val]            
            val_data_list.append(val_data_class)
            val_label_list.append(val_label_class)        
        # 合并各类别的各折的数据
        img = np.concatenate(train_data_list)
        label = np.concatenate(train_label_list)
        val_data = np.concatenate(val_data_list)
        val_label = np.concatenate(val_label_list)
#         print("image size:", img.shape, label.shape)
        

        
        print('-'*20, 
              "train size：", img.shape, 
              "val size：", val_data.shape, )
        
        img = torch.from_numpy(img)
        label = torch.from_numpy(label - 1)
        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
        
        val_data = torch.from_numpy(val_data)
        val_label = torch.from_numpy(val_label - 1)
        val_dataset = torch.utils.data.TensorDataset(val_data, val_label)
        self.val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=self.batch_size, shuffle=True)
                
        
        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=False)

        # Optimizers
        weight_decay = 0 if TYPE=='A' else 0.001
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2),
                                         weight_decay=weight_decay
                                         )

        test_data = Variable(test_data.type(self.Tensor))
        test_label = Variable(test_label.type(self.LongTensor))
        best_epoch = 0
        num = 0
        min_loss = 100
        max_acc = 0
        # recording train_acc, train_loss, test_acc, test_loss
        result_process = []
        # Train the cnn model
        for e in range(self.n_epochs):
            epoch_process = {}
            epoch_process['epoch'] = e
            # in_epoch = time.time()
            self.model.train()
            outputs_list = []
            label_list = []
            for i, (img, label) in enumerate(self.dataloader):
                number_sample = img.shape[0]
                
                # split raw train dataset into real train dataset and validate dataset
                train_data = img
                train_label = label

                
                # real train dataset
                img = Variable(train_data.type(self.Tensor))
                label = Variable(train_label.type(self.LongTensor))
                
                # data augmentation
                aug_data, aug_label = self.interaug(self.allData, self.allLabel)
                # concat real train dataset and generate aritifical train dataset
                img = torch.cat((img, aug_data))
                label = torch.cat((label, aug_label))

                # training model
                features, outputs = self.model(img)
                outputs_list.append(outputs)
                label_list.append(label)
                # print("train outputs: ", outputs.shape, type(outputs))
                # print(features.size())
                loss = self.criterion_cls(outputs, label) 
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            del img
            torch.cuda.empty_cache()
            # out_epoch = time.time()
            # test process
            if (e + 1) % 1 == 0:
                self.model.eval()
                # validate model
                val_acc, val_loss, _ = self.test_model(self.model, self.val_dataloader)
                
                epoch_process['val_acc'] = val_acc                
                epoch_process['val_loss'] = val_loss.detach().cpu().numpy()  
                
                train_pred = torch.max(outputs, 1)[1]
                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
                epoch_process['train_acc'] = train_acc
                epoch_process['train_loss'] = loss.detach().cpu().numpy()

                num = num + 1

                # if min_loss>val_loss:                
                if max_acc<val_acc or (max_acc==val_acc and min_loss>val_loss):
                    max_acc = val_acc
                    min_loss = val_loss
                    
                    best_epoch = e
                    epoch_process['epoch'] = e
                    torch.save(self.model, self.model_filename)
                    test_acc, test_loss, y_pred  = self.test_model(self.model, self.test_dataloader)
                    print("{}_{} train_acc: {:.4f} train_loss: {:.6f}\tval_acc: {:.6f} val_loss: {:.7f} test_acc:{:.6f}".format(self.nSub,
                                                                                           epoch_process['epoch'],
                                                                                           epoch_process['train_acc'],
                                                                                           epoch_process['train_loss'],
                                                                                           epoch_process['val_acc'],
                                                                                           epoch_process['val_loss'],
                                                                                           test_acc                                     
                                                                                        ))
            
                
            result_process.append(epoch_process)  

        
        # load model for test
        self.model.eval()
        self.model = torch.load(self.model_filename).cuda()
        outputs_list = []
        with torch.no_grad():
            for i, (img, label) in enumerate(self.test_dataloader):
                img_test = Variable(img.type(self.Tensor)).cuda()
                # label_test = Variable(label.type(self.LongTensor))

                # test model
                features, outputs = self.model(img_test)
                val_pred = torch.max(outputs, 1)[1]
                outputs_list.append(outputs)
        outputs = torch.cat(outputs_list) 
        y_pred = torch.max(outputs, 1)[1]
        
        
        test_acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
        
        print("epoch: ", best_epoch, '\tThe test accuracy is:', test_acc)


        df_process = pd.DataFrame(result_process)

        return test_acc, test_label, y_pred, df_process, best_epoch, outputs
        # writer.close()
        

def main(dirs,           
         paramters,
         evaluate_mode = 'subject-dependent', # 评估模式：LOSO（跨个体）或其他（subject-dependent, subject-specific），
         dataset_type='A',    # A->'BCI IV2a', B->'BCI IV2b'
         ):

    if not os.path.exists(dirs):
        os.makedirs(dirs)
    result_write_metric = ExcelWriter(dirs+"/result_metric.xlsx")
    result_metric_dict = {}
    y_true_pred_dict = { }

    process_write = ExcelWriter(dirs+"/process_train.xlsx")
    pred_true_write = ExcelWriter(dirs+"/pred_true.xlsx")
    pred_softmax = ExcelWriter(dirs+"/pred_softmax.xlsx")
    subjects_result = []
    
    
    best_epochs = []
    result_fold = []
    for i in range(paramters.subject_number):      
        starttime = datetime.datetime.now()
        seed_n = np.random.randint(2024)
        print('seed is ' + str(seed_n))
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)
        index_round =0
        print('Subject %d' % (i+1))
        # 每个人做5折交叉验证
        subjects_result_fold = []
        for n_fold in range(5):
            exp = ExP(i + 1, DATA_DIR, dirs, 
                      paramters,
                      evaluate_mode = evaluate_mode,
                      dataset_type=dataset_type,
                      n_fold=n_fold,
                      )

            testAcc, Y_true, Y_pred, df_process, best_epoch,pred_output = exp.train()
            probs = torch.softmax(pred_output, dim=1).cpu().numpy()
            df_probs = pd.DataFrame(probs)
            df_probs.to_excel(pred_softmax, sheet_name=str(i+1)+'_'+str(n_fold))
            true_cpu = Y_true.cpu().numpy().astype(int)
            pred_cpu = Y_pred.cpu().numpy().astype(int)
            df_pred_true = pd.DataFrame({'pred': pred_cpu, 'true': true_cpu})
            df_pred_true.to_excel(pred_true_write, sheet_name=str(i+1)+'_'+str(n_fold))
            y_true_pred_dict[i] = df_pred_true

            accuracy, precison, recall, f1, kappa = calMetrics(true_cpu, pred_cpu)
            subject_result = {'accuray': accuracy*100,
                              'precision': precison*100,
                              'recall': recall*100,
                              'f1': f1*100, 
                              'kappa': kappa*100
                              }
            df_process.to_excel(process_write, sheet_name=str(i+1)+'_'+str(n_fold))
            best_epochs.append(best_epoch)
            print(' THE BEST ACCURACY IS ' + str(testAcc) + "\tkappa is " + str(kappa) )
    

            endtime = datetime.datetime.now()
            print('subject %d duration: '%(i+1) + str(endtime - starttime))

            if i == 0:
                yt = Y_true
                yp = Y_pred
            else:
                yt = torch.cat((yt, Y_true))
                yp = torch.cat((yp, Y_pred))
            subjects_result_fold.append(subject_result)
            # 每个人每一折的结果都写入excel存盘
            df_result_fold = pd.DataFrame(subjects_result)
            
        df = pd.DataFrame(subjects_result_fold)
        df.to_excel(result_write_metric, index=False,  sheet_name=str(i+1))
        result_fold_mean = df.mean()
        print("{} subject {} fold mean: \n {}".format(i+1, n_fold+1, result_fold_mean))
        subjects_result.append(result_fold_mean)
    df_result = pd.DataFrame(subjects_result)
    process_write.close()
    pred_true_write.close()
    pred_softmax.close()


    print('**The average Best accuracy is: ' + str(df_result['accuray'].mean()) + "kappa is: " + str(df_result['kappa'].mean()) + "\n" )
    print("best epochs: ", best_epochs)
    #df_result.to_excel(result_write_metric, index=False)
    result_metric_dict = df_result

    mean = df_result.mean(axis=0)
    mean.name = 'mean'
    std = df_result.std(axis=0)
    std.name = 'std'
    df_result = pd.concat([df_result, pd.DataFrame(mean).T, pd.DataFrame(std).T])
    
    df_result.to_excel(result_write_metric, index=False, sheet_name='mean')
    print('-'*9, ' all result ', '-'*9)
    print(df_result)
    
    print("*"*40)

    result_write_metric.close()

    
    return result_metric_dict


class Parameters():
    def __init__(self, dropout_rate):
        self.heads = 8
        self.depth = 5
        self.emb_size = 16*3
        self.f1 = 16
        self.pooling_size = 52
        self.dropout_rate = dropout_rate
        self.subject_number = 9
        self.learning_rate = 0.001
        self.batch_size = 72 
        
        self.epochs=1000
        self.number_aug=3
        # 训练一个batch的真实数量为 self.batch_size*(1+self.number_aug)
        self.number_seg=8
        self.gpus=gpus        

if __name__ == "__main__":
    #----------------------------------------
    DATA_DIR = r'../mymat_raw/'
    EVALUATE_MODE = 'LOSO-No' # leaving one subject out subject-dependent  subject-indenpedent

    TYPE = 'B'
    if EVALUATE_MODE!='LOSO':
        CNN_DROPOUT_RATE = 0.5
    else:
        CNN_DROPOUT_RATE = 0.25    

    
    parameters = Parameters(CNN_DROPOUT_RATE)

    
    parameters_list = ['B']
    for i in parameters_list:
#         parameters.heads = i
        TYPE = i
        number_class, number_channel = numberClassChannel(TYPE)
        RESULT_NAME = "MSCFormer_画ROC_{}_heads_{}_depth_{}_pool_{}".format(TYPE, 
                                                                       parameters.heads,
                                                                       parameters.depth,
                                                                       parameters.pooling_size)
    
        sModel = EEGTransformer(
            database_type=TYPE, 
            parameters = parameters,  
            ).cuda()
        summary(sModel, (1, number_channel, 1000)) 
    
        print(time.asctime(time.localtime(time.time())))
        
        result = main(RESULT_NAME,
                      parameters,
                        evaluate_mode = EVALUATE_MODE,
                        dataset_type=TYPE,
                      )
        print(time.asctime(time.localtime(time.time())))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 16, 3, 1000]           1,376
            Conv2d-2          [-1, 16, 1, 1000]              64
       BatchNorm2d-3          [-1, 16, 1, 1000]              32
               ELU-4          [-1, 16, 1, 1000]               0
         AvgPool2d-5            [-1, 16, 1, 19]               0
           Dropout-6            [-1, 16, 1, 19]               0
            Conv2d-7          [-1, 16, 3, 1000]           1,056
            Conv2d-8          [-1, 16, 1, 1000]              64
       BatchNorm2d-9          [-1, 16, 1, 1000]              32
              ELU-10          [-1, 16, 1, 1000]               0
        AvgPool2d-11            [-1, 16, 1, 19]               0
          Dropout-12            [-1, 16, 1, 19]               0
           Conv2d-13          [-1, 16, 3, 1000]             736
           Conv2d-14          [-1, 16, 

1_14 train_acc: 0.7702 train_loss: 0.530515	val_acc: 0.737500 val_loss: 0.5836277 test_acc:0.696875
1_16 train_acc: 0.7823 train_loss: 0.515269	val_acc: 0.775000 val_loss: 0.5249513 test_acc:0.712500
1_17 train_acc: 0.7218 train_loss: 0.585335	val_acc: 0.775000 val_loss: 0.4673623 test_acc:0.646875
1_19 train_acc: 0.7379 train_loss: 0.545433	val_acc: 0.812500 val_loss: 0.4738925 test_acc:0.715625
1_22 train_acc: 0.7702 train_loss: 0.488601	val_acc: 0.850000 val_loss: 0.4012694 test_acc:0.693750
1_31 train_acc: 0.7823 train_loss: 0.455490	val_acc: 0.850000 val_loss: 0.3462326 test_acc:0.728125
1_33 train_acc: 0.8266 train_loss: 0.391750	val_acc: 0.850000 val_loss: 0.3222168 test_acc:0.715625
1_36 train_acc: 0.8790 train_loss: 0.342267	val_acc: 0.875000 val_loss: 0.3378281 test_acc:0.706250
1_37 train_acc: 0.8185 train_loss: 0.401773	val_acc: 0.887500 val_loss: 0.3020096 test_acc:0.718750
1_44 train_acc: 0.8871 train_loss: 0.333597	val_acc: 0.900000 val_loss: 0.3177181 test_acc:0.712500


1_907 train_acc: 0.9516 train_loss: 0.145320	val_acc: 0.987500 val_loss: 0.0580287 test_acc:0.731250
1_952 train_acc: 0.9315 train_loss: 0.172159	val_acc: 1.000000 val_loss: 0.0562394 test_acc:0.753125
epoch:  952 	The test accuracy is: 0.753125
 THE BEST ACCURACY IS 0.753125	kappa is 0.50625
subject 1 duration: 0:34:33.466817
-------------------- raw train size： (400, 1, 3, 1000) test size： (320, 3, 1000) subject: 1 fold: 3
-------------------- train size： (320, 1, 3, 1000) val size： (80, 1, 3, 1000)
1_0 train_acc: 0.5444 train_loss: 0.718674	val_acc: 0.500000 val_loss: 0.6950446 test_acc:0.500000
1_1 train_acc: 0.5282 train_loss: 0.753140	val_acc: 0.612500 val_loss: 0.6794239 test_acc:0.528125
1_2 train_acc: 0.5444 train_loss: 0.704839	val_acc: 0.625000 val_loss: 0.6714302 test_acc:0.540625
1_4 train_acc: 0.5645 train_loss: 0.685477	val_acc: 0.662500 val_loss: 0.6551999 test_acc:0.534375
1_5 train_acc: 0.5242 train_loss: 0.696659	val_acc: 0.712500 val_loss: 0.6399666 test_acc:0.54375

1_0 train_acc: 0.4274 train_loss: 0.802234	val_acc: 0.500000 val_loss: 0.7250905 test_acc:0.500000
1_1 train_acc: 0.4395 train_loss: 0.784945	val_acc: 0.500000 val_loss: 0.6893985 test_acc:0.503125
1_3 train_acc: 0.4839 train_loss: 0.722601	val_acc: 0.500000 val_loss: 0.6890133 test_acc:0.500000
1_4 train_acc: 0.5282 train_loss: 0.703718	val_acc: 0.537500 val_loss: 0.6872090 test_acc:0.534375
1_5 train_acc: 0.5363 train_loss: 0.702082	val_acc: 0.575000 val_loss: 0.6861282 test_acc:0.550000
1_10 train_acc: 0.5685 train_loss: 0.661645	val_acc: 0.625000 val_loss: 0.6783231 test_acc:0.556250
1_14 train_acc: 0.5887 train_loss: 0.656324	val_acc: 0.700000 val_loss: 0.6231054 test_acc:0.565625
1_15 train_acc: 0.6371 train_loss: 0.621148	val_acc: 0.712500 val_loss: 0.5879614 test_acc:0.606250
1_19 train_acc: 0.6815 train_loss: 0.597222	val_acc: 0.712500 val_loss: 0.5635213 test_acc:0.587500
1_20 train_acc: 0.7177 train_loss: 0.579765	val_acc: 0.737500 val_loss: 0.6236994 test_acc:0.581250
1_21 

-------------------- raw train size： (400, 1, 3, 1000) test size： (280, 3, 1000) subject: 2 fold: 2
-------------------- train size： (320, 1, 3, 1000) val size： (80, 1, 3, 1000)
2_0 train_acc: 0.4476 train_loss: 0.753676	val_acc: 0.500000 val_loss: 0.6939727 test_acc:0.500000
2_1 train_acc: 0.5363 train_loss: 0.713642	val_acc: 0.500000 val_loss: 0.6900002 test_acc:0.500000
2_2 train_acc: 0.5000 train_loss: 0.721695	val_acc: 0.500000 val_loss: 0.6892195 test_acc:0.500000
2_3 train_acc: 0.5121 train_loss: 0.710267	val_acc: 0.575000 val_loss: 0.6868281 test_acc:0.467857
2_5 train_acc: 0.5202 train_loss: 0.699205	val_acc: 0.587500 val_loss: 0.6831863 test_acc:0.439286
2_10 train_acc: 0.5040 train_loss: 0.707498	val_acc: 0.587500 val_loss: 0.6759554 test_acc:0.485714
2_11 train_acc: 0.5565 train_loss: 0.679644	val_acc: 0.612500 val_loss: 0.6662125 test_acc:0.492857
2_12 train_acc: 0.5685 train_loss: 0.675106	val_acc: 0.625000 val_loss: 0.6609021 test_acc:0.496429
2_13 train_acc: 0.5363 trai

2_515 train_acc: 0.8992 train_loss: 0.244166	val_acc: 0.937500 val_loss: 0.2109477 test_acc:0.682143
2_521 train_acc: 0.8427 train_loss: 0.326522	val_acc: 0.937500 val_loss: 0.1980005 test_acc:0.707143
2_598 train_acc: 0.8992 train_loss: 0.227365	val_acc: 0.937500 val_loss: 0.1951571 test_acc:0.696429
2_602 train_acc: 0.8629 train_loss: 0.348106	val_acc: 0.950000 val_loss: 0.2010336 test_acc:0.692857
2_643 train_acc: 0.8710 train_loss: 0.270516	val_acc: 0.950000 val_loss: 0.1952489 test_acc:0.703571
2_706 train_acc: 0.8871 train_loss: 0.283793	val_acc: 0.962500 val_loss: 0.1930168 test_acc:0.675000
epoch:  706 	The test accuracy is: 0.675
 THE BEST ACCURACY IS 0.675	kappa is 0.35
subject 2 duration: 0:52:20.295119
-------------------- raw train size： (400, 1, 3, 1000) test size： (280, 3, 1000) subject: 2 fold: 4
-------------------- train size： (320, 1, 3, 1000) val size： (80, 1, 3, 1000)
2_0 train_acc: 0.5363 train_loss: 0.741428	val_acc: 0.500000 val_loss: 0.6955260 test_acc:0.500000

2_321 train_acc: 0.8347 train_loss: 0.369875	val_acc: 0.900000 val_loss: 0.3258873 test_acc:0.696429
2_351 train_acc: 0.8226 train_loss: 0.365924	val_acc: 0.900000 val_loss: 0.3005592 test_acc:0.689286
2_498 train_acc: 0.8790 train_loss: 0.305876	val_acc: 0.912500 val_loss: 0.3135300 test_acc:0.692857
2_715 train_acc: 0.9234 train_loss: 0.216241	val_acc: 0.912500 val_loss: 0.2569596 test_acc:0.703571
2_726 train_acc: 0.8710 train_loss: 0.348070	val_acc: 0.925000 val_loss: 0.3109320 test_acc:0.696429
2_986 train_acc: 0.8790 train_loss: 0.285302	val_acc: 0.925000 val_loss: 0.2419203 test_acc:0.725000
epoch:  986 	The test accuracy is: 0.725
 THE BEST ACCURACY IS 0.725	kappa is 0.44999999999999996
subject 2 duration: 1:27:50.504354
2 subject 5 fold mean: 
 accuray      71.214286
precision    71.435455
recall       71.214286
f1           71.141761
kappa        42.428571
dtype: float64
seed is 629
Subject 3
-------------------- raw train size： (400, 1, 3, 1000) test size： (320, 3, 1000) sub

3_497 train_acc: 0.8790 train_loss: 0.315837	val_acc: 0.900000 val_loss: 0.2612152 test_acc:0.828125
3_536 train_acc: 0.8468 train_loss: 0.341511	val_acc: 0.900000 val_loss: 0.2410101 test_acc:0.821875
3_552 train_acc: 0.8548 train_loss: 0.351045	val_acc: 0.912500 val_loss: 0.2304839 test_acc:0.825000
3_620 train_acc: 0.8710 train_loss: 0.331571	val_acc: 0.912500 val_loss: 0.2274104 test_acc:0.843750
3_631 train_acc: 0.8306 train_loss: 0.375688	val_acc: 0.925000 val_loss: 0.2315803 test_acc:0.825000
3_787 train_acc: 0.8871 train_loss: 0.279292	val_acc: 0.925000 val_loss: 0.2116683 test_acc:0.815625
3_825 train_acc: 0.8952 train_loss: 0.261759	val_acc: 0.925000 val_loss: 0.1971940 test_acc:0.834375
3_862 train_acc: 0.8952 train_loss: 0.257443	val_acc: 0.937500 val_loss: 0.2191662 test_acc:0.818750
epoch:  862 	The test accuracy is: 0.81875
 THE BEST ACCURACY IS 0.81875	kappa is 0.6375
subject 3 duration: 0:35:30.459526
-------------------- raw train size： (400, 1, 3, 1000) test size： (3

3_47 train_acc: 0.6774 train_loss: 0.584470	val_acc: 0.775000 val_loss: 0.5109367 test_acc:0.837500
3_65 train_acc: 0.7258 train_loss: 0.538360	val_acc: 0.775000 val_loss: 0.4946293 test_acc:0.843750
3_68 train_acc: 0.7419 train_loss: 0.524675	val_acc: 0.787500 val_loss: 0.4746709 test_acc:0.846875
3_77 train_acc: 0.7218 train_loss: 0.544410	val_acc: 0.800000 val_loss: 0.4564450 test_acc:0.840625
3_98 train_acc: 0.7863 train_loss: 0.471331	val_acc: 0.800000 val_loss: 0.4435409 test_acc:0.837500
3_110 train_acc: 0.8347 train_loss: 0.396775	val_acc: 0.800000 val_loss: 0.4355713 test_acc:0.828125
3_113 train_acc: 0.7621 train_loss: 0.492736	val_acc: 0.825000 val_loss: 0.4375253 test_acc:0.828125
3_114 train_acc: 0.7944 train_loss: 0.468530	val_acc: 0.825000 val_loss: 0.4175421 test_acc:0.843750
3_123 train_acc: 0.8065 train_loss: 0.433712	val_acc: 0.837500 val_loss: 0.4092407 test_acc:0.825000
3_149 train_acc: 0.7984 train_loss: 0.413807	val_acc: 0.850000 val_loss: 0.4206168 test_acc:0.83

4_221 train_acc: 0.9886 train_loss: 0.037371	val_acc: 0.976190 val_loss: 0.1293990 test_acc:0.950000
4_242 train_acc: 0.9924 train_loss: 0.024103	val_acc: 0.976190 val_loss: 0.1253626 test_acc:0.968750
4_320 train_acc: 0.9886 train_loss: 0.032823	val_acc: 0.976190 val_loss: 0.1205523 test_acc:0.968750
4_322 train_acc: 0.9735 train_loss: 0.055207	val_acc: 0.976190 val_loss: 0.1053252 test_acc:0.956250
4_389 train_acc: 0.9848 train_loss: 0.051930	val_acc: 0.976190 val_loss: 0.0888669 test_acc:0.978125
4_714 train_acc: 0.9924 train_loss: 0.031529	val_acc: 0.976190 val_loss: 0.0850956 test_acc:0.978125
4_804 train_acc: 0.9811 train_loss: 0.048090	val_acc: 0.976190 val_loss: 0.0840593 test_acc:0.984375
4_826 train_acc: 0.9886 train_loss: 0.039879	val_acc: 0.988095 val_loss: 0.0840576 test_acc:0.978125
4_830 train_acc: 0.9886 train_loss: 0.039940	val_acc: 0.988095 val_loss: 0.0833713 test_acc:0.978125
4_845 train_acc: 0.9811 train_loss: 0.043742	val_acc: 0.988095 val_loss: 0.0764157 test_acc

epoch:  984 	The test accuracy is: 0.96875
 THE BEST ACCURACY IS 0.96875	kappa is 0.9375
subject 4 duration: 0:26:23.753517
-------------------- raw train size： (420, 1, 3, 1000) test size： (320, 3, 1000) subject: 4 fold: 4
-------------------- train size： (336, 1, 3, 1000) val size： (84, 1, 3, 1000)
4_0 train_acc: 0.5341 train_loss: 0.729927	val_acc: 0.500000 val_loss: 0.7409879 test_acc:0.500000
4_1 train_acc: 0.5189 train_loss: 0.748656	val_acc: 0.500000 val_loss: 0.7363310 test_acc:0.500000
4_2 train_acc: 0.5303 train_loss: 0.702019	val_acc: 0.500000 val_loss: 0.7119554 test_acc:0.500000
4_5 train_acc: 0.8068 train_loss: 0.453220	val_acc: 0.761905 val_loss: 0.6239733 test_acc:0.843750
4_7 train_acc: 0.8750 train_loss: 0.273490	val_acc: 0.892857 val_loss: 0.2269846 test_acc:0.768750
4_8 train_acc: 0.9318 train_loss: 0.200438	val_acc: 0.928571 val_loss: 0.1803121 test_acc:0.812500
4_11 train_acc: 0.9205 train_loss: 0.222794	val_acc: 0.940476 val_loss: 0.1825447 test_acc:0.771875
4_14

5_123 train_acc: 0.8561 train_loss: 0.324633	val_acc: 0.892857 val_loss: 0.2318776 test_acc:0.987500
5_131 train_acc: 0.8371 train_loss: 0.389686	val_acc: 0.904762 val_loss: 0.2501477 test_acc:0.987500
5_136 train_acc: 0.8636 train_loss: 0.330001	val_acc: 0.904762 val_loss: 0.2414875 test_acc:0.984375
5_139 train_acc: 0.8712 train_loss: 0.314029	val_acc: 0.916667 val_loss: 0.2444827 test_acc:0.981250
5_160 train_acc: 0.8939 train_loss: 0.245031	val_acc: 0.916667 val_loss: 0.2339142 test_acc:0.993750
5_162 train_acc: 0.8864 train_loss: 0.291151	val_acc: 0.940476 val_loss: 0.2057809 test_acc:0.987500
5_466 train_acc: 0.9242 train_loss: 0.175417	val_acc: 0.940476 val_loss: 0.1666596 test_acc:0.981250
5_744 train_acc: 0.9583 train_loss: 0.119452	val_acc: 0.940476 val_loss: 0.1575481 test_acc:0.968750
5_752 train_acc: 0.9470 train_loss: 0.122725	val_acc: 0.940476 val_loss: 0.1573568 test_acc:0.981250
5_766 train_acc: 0.9091 train_loss: 0.207574	val_acc: 0.952381 val_loss: 0.2093232 test_acc

5_539 train_acc: 0.9205 train_loss: 0.219909	val_acc: 0.940476 val_loss: 0.1786373 test_acc:0.925000
5_698 train_acc: 0.9394 train_loss: 0.150688	val_acc: 0.952381 val_loss: 0.2002592 test_acc:0.959375
epoch:  698 	The test accuracy is: 0.959375
 THE BEST ACCURACY IS 0.959375	kappa is 0.91875
subject 5 duration: 0:20:18.843883
-------------------- raw train size： (420, 1, 3, 1000) test size： (320, 3, 1000) subject: 5 fold: 4
-------------------- train size： (336, 1, 3, 1000) val size： (84, 1, 3, 1000)
5_0 train_acc: 0.5000 train_loss: 0.766244	val_acc: 0.511905 val_loss: 0.7000568 test_acc:0.500000
5_1 train_acc: 0.5189 train_loss: 0.734119	val_acc: 0.523810 val_loss: 0.6907372 test_acc:0.500000
5_2 train_acc: 0.5455 train_loss: 0.697166	val_acc: 0.583333 val_loss: 0.6794089 test_acc:0.684375
5_3 train_acc: 0.4735 train_loss: 0.740458	val_acc: 0.595238 val_loss: 0.6811785 test_acc:0.621875
5_4 train_acc: 0.5265 train_loss: 0.711941	val_acc: 0.619048 val_loss: 0.6809248 test_acc:0.67187

6_17 train_acc: 0.7137 train_loss: 0.579358	val_acc: 0.750000 val_loss: 0.5147780 test_acc:0.750000
6_19 train_acc: 0.7661 train_loss: 0.516005	val_acc: 0.762500 val_loss: 0.5029281 test_acc:0.784375
6_21 train_acc: 0.7661 train_loss: 0.481603	val_acc: 0.800000 val_loss: 0.4548945 test_acc:0.784375
6_22 train_acc: 0.7379 train_loss: 0.524080	val_acc: 0.800000 val_loss: 0.3849216 test_acc:0.818750
6_23 train_acc: 0.7460 train_loss: 0.517063	val_acc: 0.837500 val_loss: 0.4318218 test_acc:0.800000
6_24 train_acc: 0.7702 train_loss: 0.509250	val_acc: 0.875000 val_loss: 0.3399148 test_acc:0.850000
6_26 train_acc: 0.8710 train_loss: 0.328195	val_acc: 0.875000 val_loss: 0.3222870 test_acc:0.843750
6_40 train_acc: 0.8548 train_loss: 0.363826	val_acc: 0.912500 val_loss: 0.2796374 test_acc:0.837500
6_75 train_acc: 0.8548 train_loss: 0.334173	val_acc: 0.912500 val_loss: 0.2462497 test_acc:0.868750
6_82 train_acc: 0.8710 train_loss: 0.302166	val_acc: 0.925000 val_loss: 0.2488469 test_acc:0.865625


6_50 train_acc: 0.7984 train_loss: 0.457573	val_acc: 0.812500 val_loss: 0.3867066 test_acc:0.871875
6_51 train_acc: 0.8024 train_loss: 0.435206	val_acc: 0.825000 val_loss: 0.3851470 test_acc:0.796875
6_56 train_acc: 0.8548 train_loss: 0.381751	val_acc: 0.837500 val_loss: 0.3968234 test_acc:0.846875
6_62 train_acc: 0.7823 train_loss: 0.460117	val_acc: 0.862500 val_loss: 0.3799202 test_acc:0.890625
6_71 train_acc: 0.8508 train_loss: 0.322710	val_acc: 0.887500 val_loss: 0.3736117 test_acc:0.909375
6_90 train_acc: 0.8871 train_loss: 0.293846	val_acc: 0.900000 val_loss: 0.3546025 test_acc:0.896875
6_125 train_acc: 0.8508 train_loss: 0.307904	val_acc: 0.900000 val_loss: 0.2877376 test_acc:0.893750
6_167 train_acc: 0.8468 train_loss: 0.319771	val_acc: 0.900000 val_loss: 0.2867590 test_acc:0.893750
6_194 train_acc: 0.8831 train_loss: 0.273394	val_acc: 0.912500 val_loss: 0.2639467 test_acc:0.878125
6_315 train_acc: 0.9234 train_loss: 0.185661	val_acc: 0.925000 val_loss: 0.2817964 test_acc:0.884

6_205 train_acc: 0.8911 train_loss: 0.283564	val_acc: 0.925000 val_loss: 0.2239411 test_acc:0.840625
6_236 train_acc: 0.8831 train_loss: 0.261085	val_acc: 0.925000 val_loss: 0.2226407 test_acc:0.856250
6_260 train_acc: 0.8669 train_loss: 0.310353	val_acc: 0.925000 val_loss: 0.2181106 test_acc:0.859375
6_268 train_acc: 0.8750 train_loss: 0.297853	val_acc: 0.937500 val_loss: 0.2107200 test_acc:0.896875
6_351 train_acc: 0.9194 train_loss: 0.216214	val_acc: 0.937500 val_loss: 0.2104785 test_acc:0.865625
6_366 train_acc: 0.8992 train_loss: 0.241132	val_acc: 0.937500 val_loss: 0.1922111 test_acc:0.871875
6_388 train_acc: 0.9113 train_loss: 0.209440	val_acc: 0.950000 val_loss: 0.2072396 test_acc:0.878125
6_457 train_acc: 0.9355 train_loss: 0.193930	val_acc: 0.950000 val_loss: 0.1936299 test_acc:0.896875
6_679 train_acc: 0.8992 train_loss: 0.232549	val_acc: 0.950000 val_loss: 0.1926452 test_acc:0.853125
6_682 train_acc: 0.8952 train_loss: 0.257253	val_acc: 0.950000 val_loss: 0.1919731 test_acc

epoch:  779 	The test accuracy is: 0.9375
 THE BEST ACCURACY IS 0.9375	kappa is 0.875
subject 7 duration: 0:13:30.953003
-------------------- raw train size： (400, 1, 3, 1000) test size： (320, 3, 1000) subject: 7 fold: 3
-------------------- train size： (320, 1, 3, 1000) val size： (80, 1, 3, 1000)
7_0 train_acc: 0.4194 train_loss: 0.812793	val_acc: 0.500000 val_loss: 0.7511035 test_acc:0.500000
7_2 train_acc: 0.5242 train_loss: 0.707848	val_acc: 0.512500 val_loss: 0.7198338 test_acc:0.500000
7_4 train_acc: 0.5242 train_loss: 0.711763	val_acc: 0.525000 val_loss: 0.6805753 test_acc:0.528125
7_5 train_acc: 0.5081 train_loss: 0.713054	val_acc: 0.562500 val_loss: 0.6716880 test_acc:0.553125
7_6 train_acc: 0.5363 train_loss: 0.684835	val_acc: 0.662500 val_loss: 0.6628063 test_acc:0.600000
7_8 train_acc: 0.5524 train_loss: 0.672927	val_acc: 0.700000 val_loss: 0.6449684 test_acc:0.615625
7_9 train_acc: 0.5685 train_loss: 0.677424	val_acc: 0.700000 val_loss: 0.6288022 test_acc:0.650000
7_10 tra

7_3 train_acc: 0.4919 train_loss: 0.716983	val_acc: 0.537500 val_loss: 0.6792852 test_acc:0.506250
7_4 train_acc: 0.5645 train_loss: 0.695164	val_acc: 0.725000 val_loss: 0.6653610 test_acc:0.581250
7_9 train_acc: 0.5847 train_loss: 0.676397	val_acc: 0.750000 val_loss: 0.5858803 test_acc:0.678125
7_11 train_acc: 0.6411 train_loss: 0.629894	val_acc: 0.750000 val_loss: 0.5786523 test_acc:0.700000
7_12 train_acc: 0.6290 train_loss: 0.629341	val_acc: 0.787500 val_loss: 0.5352093 test_acc:0.703125
7_13 train_acc: 0.6371 train_loss: 0.636659	val_acc: 0.787500 val_loss: 0.5125898 test_acc:0.709375
7_19 train_acc: 0.7863 train_loss: 0.514949	val_acc: 0.800000 val_loss: 0.4692780 test_acc:0.800000
7_32 train_acc: 0.7863 train_loss: 0.489108	val_acc: 0.812500 val_loss: 0.4583374 test_acc:0.828125
7_55 train_acc: 0.8911 train_loss: 0.273865	val_acc: 0.837500 val_loss: 0.4712297 test_acc:0.893750
7_62 train_acc: 0.8911 train_loss: 0.308432	val_acc: 0.837500 val_loss: 0.3979265 test_acc:0.921875
7_8

8_190 train_acc: 0.8750 train_loss: 0.305935	val_acc: 0.886364 val_loss: 0.3556738 test_acc:0.915625
8_244 train_acc: 0.8643 train_loss: 0.301197	val_acc: 0.886364 val_loss: 0.3229134 test_acc:0.918750
8_245 train_acc: 0.8786 train_loss: 0.272556	val_acc: 0.886364 val_loss: 0.3063569 test_acc:0.918750
8_272 train_acc: 0.9107 train_loss: 0.235992	val_acc: 0.909091 val_loss: 0.2662140 test_acc:0.918750
8_361 train_acc: 0.8821 train_loss: 0.272445	val_acc: 0.909091 val_loss: 0.2540729 test_acc:0.934375
epoch:  361 	The test accuracy is: 0.934375
 THE BEST ACCURACY IS 0.934375	kappa is 0.86875
subject 8 duration: 0:13:24.984174
-------------------- raw train size： (440, 1, 3, 1000) test size： (320, 3, 1000) subject: 8 fold: 3
-------------------- train size： (352, 1, 3, 1000) val size： (88, 1, 3, 1000)
8_0 train_acc: 0.5179 train_loss: 0.744946	val_acc: 0.500000 val_loss: 0.7707474 test_acc:0.500000
8_1 train_acc: 0.5500 train_loss: 0.701707	val_acc: 0.500000 val_loss: 0.7189288 test_acc:0

8_3 train_acc: 0.6679 train_loss: 0.638616	val_acc: 0.670455 val_loss: 0.6063108 test_acc:0.818750
8_4 train_acc: 0.7107 train_loss: 0.562976	val_acc: 0.681818 val_loss: 0.6083933 test_acc:0.850000
8_7 train_acc: 0.7143 train_loss: 0.565500	val_acc: 0.727273 val_loss: 0.5564250 test_acc:0.893750
8_8 train_acc: 0.7679 train_loss: 0.523723	val_acc: 0.738636 val_loss: 0.5994162 test_acc:0.865625
8_9 train_acc: 0.7714 train_loss: 0.519314	val_acc: 0.738636 val_loss: 0.5505543 test_acc:0.871875
8_11 train_acc: 0.7643 train_loss: 0.476076	val_acc: 0.761364 val_loss: 0.5089976 test_acc:0.893750
8_12 train_acc: 0.7286 train_loss: 0.528536	val_acc: 0.761364 val_loss: 0.5017951 test_acc:0.900000
8_15 train_acc: 0.8250 train_loss: 0.409372	val_acc: 0.772727 val_loss: 0.4942768 test_acc:0.884375
8_17 train_acc: 0.8036 train_loss: 0.426530	val_acc: 0.772727 val_loss: 0.4929410 test_acc:0.884375
8_23 train_acc: 0.8429 train_loss: 0.424363	val_acc: 0.795455 val_loss: 0.4931730 test_acc:0.896875
8_29 

9_515 train_acc: 0.9194 train_loss: 0.206637	val_acc: 0.937500 val_loss: 0.3021417 test_acc:0.890625
9_608 train_acc: 0.8911 train_loss: 0.240185	val_acc: 0.937500 val_loss: 0.2754226 test_acc:0.890625
9_898 train_acc: 0.9073 train_loss: 0.247302	val_acc: 0.937500 val_loss: 0.2376166 test_acc:0.881250
epoch:  898 	The test accuracy is: 0.88125
 THE BEST ACCURACY IS 0.88125	kappa is 0.7625
subject 9 duration: 0:13:36.773288
-------------------- raw train size： (400, 1, 3, 1000) test size： (320, 3, 1000) subject: 9 fold: 3
-------------------- train size： (320, 1, 3, 1000) val size： (80, 1, 3, 1000)
9_0 train_acc: 0.5282 train_loss: 0.754011	val_acc: 0.500000 val_loss: 0.7175877 test_acc:0.500000
9_1 train_acc: 0.5403 train_loss: 0.708905	val_acc: 0.525000 val_loss: 0.6950948 test_acc:0.500000
9_9 train_acc: 0.6492 train_loss: 0.637504	val_acc: 0.650000 val_loss: 0.6754226 test_acc:0.771875
9_12 train_acc: 0.7258 train_loss: 0.549865	val_acc: 0.662500 val_loss: 0.6974292 test_acc:0.78125

**The average Best accuracy is: 87.99603174603175kappa is: 75.9920634920635

best epochs:  [950, 952, 984, 943, 570, 809, 713, 706, 705, 986, 838, 862, 933, 907, 941, 930, 906, 984, 906, 840, 977, 434, 698, 990, 982, 945, 844, 944, 850, 930, 819, 779, 729, 814, 920, 541, 361, 937, 820, 337, 941, 898, 969, 362, 990]
---------  all result  ---------
        accuray  precision     recall         f1      kappa
0     78.062500  82.740546  78.062500  77.229602  56.125000
1     71.214286  71.435455  71.214286  71.141761  42.428571
2     82.750000  83.089528  82.750000  82.703109  65.500000
3     97.687500  97.705734  97.687500  97.687271  95.375000
4     96.812500  97.128354  96.812500  96.800028  93.625000
5     87.812500  88.230574  87.812500  87.777378  75.625000
6     94.000000  94.036362  94.000000  93.998723  88.000000
7     94.750000  94.769710  94.750000  94.749454  89.500000
8     88.875000  89.333736  88.875000  88.841873  77.750000
mean  87.996032  88.718889  87.996032  87.881022  