In [1]:
from torch_tools.m_dataset import MDataset,ch_names 
from torch_tools.train_test import train,test
from torch_tools.DOCinformer import DOCinformer
from torch.utils.data import DataLoader
from glob import glob
import torch
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split,KFold
import gc
import pandas as pd
import optuna
import os
import joblib

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Signal segmentation settings shared by MDataset and DOCinformer:
#   t_window  -> window length in seconds (passed as input_window_seconds)
#   t_overlap -> overlap between consecutive windows (units defined by MDataset — TODO confirm)
#   sf        -> sampling frequency (passed as sfreq)
t_window, t_overlap, sf = 10, 0.0, 250

In [3]:
# Cross-validation and checkpoint settings.
n_folds, valid_ratio = 5, 0.2  # subject-wise folds; share of each fold's training subjects held out for validation
random_state = 555             # seed for the fold shuffle and the train/valid split
ckp_parent_dir = 'ckp//pdoc//DOCinformer'  # root folder containing the per-trial checkpoint directories

In [4]:
# Load the pDoC per-subject .mat files.
# The default points at a machine-specific folder ("resting-state data without ICA");
# set the PDOC_DATA_GLOB environment variable to use another location without editing the notebook.
sub_regex = os.environ.get('PDOC_DATA_GLOB', 'F:\\PLL\\静息态数据没有ICA\\*.mat')
# NOTE(review): glob order is filesystem-dependent, and KFold assigns subjects to folds by list
# position. This order must therefore match the one used when the fold checkpoints were trained.
# Do not sort here unless the training notebook sorted too.
sub_paths = glob(sub_regex)

In [5]:
def sub_kfold_withvalid(sub_dataset_list, n_folds, valid_ratio, random_state=0):
    """Split a list of per-subject datasets into subject-wise k folds with a validation split.

    Subjects (list elements) are the unit of splitting, so a subject's samples never appear
    in both the train and the test part of the same fold.

    Args:
        sub_dataset_list: sequence of per-subject datasets.
        n_folds: number of KFold splits.
        valid_ratio: fraction of each fold's training subjects moved to validation.
            If 0, the fold's test subjects are reused as the validation set.
            Any model selection done on that validation set then sees the test fold.
        random_state: seed used both for the KFold shuffle and for train_test_split.

    Returns:
        (train_folds, valid_folds, test_folds): three dicts keyed by fold index
        0..n_folds-1, each mapping to the list of subject datasets for that fold.
    """
    splitter = KFold(n_splits=n_folds, shuffle=True, random_state=random_state)
    train_folds, valid_folds, test_folds = {}, {}, {}
    for fold_idx, (train_idx, test_idx) in enumerate(splitter.split(sub_dataset_list)):
        fold_test = [sub_dataset_list[k] for k in test_idx]
        fold_train = [sub_dataset_list[k] for k in train_idx]
        if valid_ratio != 0:
            # Carve the validation subjects out of this fold's training subjects.
            fold_train, fold_valid = train_test_split(
                fold_train, test_size=valid_ratio, shuffle=True, random_state=random_state)
        else:
            fold_valid = fold_test
        # Record which subjects belong to each part of this fold.
        train_folds[fold_idx] = fold_train
        valid_folds[fold_idx] = fold_valid
        test_folds[fold_idx] = fold_test
    return train_folds, valid_folds, test_folds

In [6]:
from torch_tools.train_test import test_roc
def testnfold(standard_chs,nets,ckp_dir):
    """Evaluate one model per fold on that fold's held-out test subjects and average the metrics.

    Rebuilds the per-subject datasets from the module-level ``sub_paths`` (with ``t_window``,
    ``t_overlap`` and the given channels). It then re-derives the subject-wise folds with
    ``sub_kfold_withvalid`` using the module-level ``n_folds``/``valid_ratio``/``random_state``.
    For fold ``i`` it runs ``test`` and ``test_roc`` with ``nets[i]`` and the checkpoint path
    ``<ckp_dir>/fold_{i}.pth`` (presumably those helpers load the weights from it — TODO confirm).

    NOTE(review): the result is only a true held-out estimate if these folds match the ones used
    for training. That requires the same ``sub_paths`` order and the same subjects surviving the
    empty-dataset filter.

    Args:
        standard_chs: channel names passed to ``MDataset`` when loading each subject.
        nets: sequence of at least ``n_folds`` models, indexed by fold.
        ckp_dir: directory containing the ``fold_{i}.pth`` checkpoints.

    Returns:
        dict with the unweighted fold means of "precision", "recall", "f1", "acc" and "auc"
        (also printed).
    """
    # Only these two labels are mapped; 'MCS-' and 'MCS+' (both -> 1) are intentionally not included.
    pdoc_label_dict = {
        'VS': 0,
        'MCS': 1,
    }
    pdoc_sub_dataset_list = []
    for sub_path in sub_paths:
        sub_dataset = MDataset(sub_path, t_window, t_overlap, standard_chs, pdoc_label_dict)
        # Skip subjects that yield no samples (what makes a subject empty is decided inside MDataset).
        if len(sub_dataset) > 0:
            pdoc_sub_dataset_list.append(sub_dataset)
    # Evaluation needs only the test folds; the train/valid splits are discarded.
    _, _, pdoc_test_folds = sub_kfold_withvalid(pdoc_sub_dataset_list, n_folds, valid_ratio, random_state)

    metric_names = ("precision", "recall", "f1", "acc", "auc")
    fold_test_results = []
    for i in range(n_folds):
        # os.path.join rather than a hard-coded '\\' so the path also resolves off Windows.
        fold_ckp_path = os.path.join(ckp_dir, f'fold_{i}.pth')
        net = nets[i]
        # Evaluation only: no shuffling (fixed sample order, no wasted permutation work), and one
        # loader is reused for both the metric pass and the ROC pass.
        test_loader = DataLoader(ConcatDataset(pdoc_test_folds[i]), batch_size=32, shuffle=False, drop_last=False)
        # Per-fold classification metrics.
        fold_test_result = test(net, fold_ckp_path, test_loader, False)
        # Per-fold ROC; the interpolated curve (first two values) is not used here, only the AUC.
        _mean_fpr, _interp_tpr, roc_auc = test_roc(net, fold_ckp_path, test_loader)
        fold_test_result["auc"] = roc_auc
        fold_test_results.append(fold_test_result)

    # Unweighted mean across folds: every fold counts equally regardless of its size.
    avg_result = {name: sum(fold[name] for fold in fold_test_results) / n_folds for name in metric_names}
    print(avg_result)
    return avg_result

In [7]:
def result(trial,standard_chs,part_chs = None):
    """Rebuild the per-fold DOCinformer models for ``trial`` and evaluate them on a channel subset.

    Args:
        trial: optuna trial object (here the loaded best trial). Its ``number`` selects the
            checkpoint directory, and its stored params are read back through the single-value
            ``suggest_*`` calls below, so those search spaces must match the ones used at training.
        standard_chs: full channel list the models are constructed with. Its length is also part
            of the checkpoint directory name ``ch{len}-trail{number}``.
        part_chs: channels to keep at evaluation time, applied via ``net.select_channel`` and used
            to load the data. ``None`` (the default) means all of ``standard_chs``.

    Returns:
        dict of fold-averaged metrics from ``testnfold``.
    """
    if part_chs is None:
        # Previously None was forwarded to select_channel/MDataset unchanged; treat it as "all channels".
        part_chs = standard_chs
    # 'trail' (sic) is kept: it must match the directory names the checkpoints were saved under.
    ckp_dir = os.path.join(ckp_parent_dir, f'ch{len(standard_chs)}-trail{trial.number}')
    # Model-structure hyperparameters. Every search space has a single value, so these are fixed.
    pool_time_stride = trial.suggest_categorical("pool_time_stride", [200])
    n_filters_time = trial.suggest_categorical("n_filters_time", [64])
    filter_time_length = trial.suggest_categorical("filter_time_length", [25])
    depth = trial.suggest_int('depth', 1, 1)
    att_heads = trial.suggest_int('att_heads', 1, 1)
    # Training hyperparameters are not needed for evaluation (no optimizer is built). They are
    # still suggested so the trial's parameter set is validated/registered as in training.
    trial.suggest_categorical("batch_size", [16])
    trial.suggest_int('lr', 1, 1)
    trial.suggest_categorical("wd", [0.1])
    # One freshly constructed model per fold; weights come from that fold's checkpoint in testnfold.
    # The pooling length is deliberately set equal to the pooling stride.
    nets = [DOCinformer(n_outputs=2, input_window_seconds=t_window, sfreq=sf, standard_chs=standard_chs,
                        positionEncoding='coordinate', channel_drop_prob=0.2, att_drop_prob=0.25,
                        filter_time_length=filter_time_length, att_depth=depth, att_heads=att_heads,
                        n_filters_time=n_filters_time, pool_time_stride=pool_time_stride,
                        pool_time_length=pool_time_stride)
            for _ in range(n_folds)]
    for net in nets:
        net.select_channel(part_chs)
    return testnfold(part_chs, nets, ckp_dir)

In [8]:
# Best hyper-parameter trial saved by the search (presumably an optuna FrozenTrial — TODO confirm).
# NOTE(review): joblib.load unpickles the file and can execute arbitrary code; only load trusted files.
best_trial = joblib.load(f"{ckp_parent_dir}//best_trail.joblib")

In [9]:
# Full 30-electrode set the models are built with; the order is significant and must not change.
standard_chs = [
    'Fpz', 'P8', 'Oz', 'TP9', 'C4', 'FC6', 'F4', 'CP6', 'CP1', 'P4',
    'T7', 'FC1', 'O1', 'TP10', 'FC2', 'CP2', 'C3', 'P7', 'FC5', 'F7',
    'P3', 'Fp2', 'O2', 'F8', 'Pz', 'CP5', 'F3', 'Fp1', 'T8', 'Fz',
]
# Baseline: evaluate with every channel kept.
result(best_trial, standard_chs, part_chs=standard_chs)

{'precision': 0.7950405765864472, 'recall': 0.895813049856877, 'f1': 0.838874785967129, 'acc': 0.7746225619358532, 'auc': 0.7560074289283094}


{'precision': 0.7950405765864472,
 'recall': 0.895813049856877,
 'f1': 0.838874785967129,
 'acc': 0.7746225619358532,
 'auc': 0.7560074289283094}

In [10]:
# Electrode groups for the ablation study; together they cover the 30 channels of the montage
# exactly once. Note that 'C' also holds the centro-parietal CP* sites and 'T' holds TP9/TP10.
lobe_group = dict(
    Fp=['Fp1', 'Fp2', 'Fpz'],
    F=['Fz', 'F7', 'F8', 'F3', 'F4'],
    FC=['FC5', 'FC6', 'FC1', 'FC2'],
    C=['CP5', 'CP6', 'CP1', 'CP2', 'C3', 'C4'],
    P=['P7', 'P8', 'P3', 'P4', 'Pz'],
    O=['O1', 'O2', 'Oz'],
    T=['T7', 'T8', 'TP9', 'TP10'],
)

In [11]:
# Group ablation: remove one electrode group at a time and re-evaluate the same checkpoints on the
# remaining channels. Each row records the fold-averaged metrics plus how many channels were kept.
lobe_remove_result = {}
for lobe_name, lobe_chs in lobe_group.items():
    print(f"remove lobe:{lobe_name}")
    other_chs = [ch for ch in standard_chs if ch not in lobe_chs]
    result_dict = result(best_trial, standard_chs, other_chs)
    result_dict['channels'] = len(other_chs)  # number of channels left after removal
    lobe_remove_result[lobe_name] = result_dict
    print()
# One row per removed group, saved for later comparison against the all-channel baseline.
lobe_remove_df = pd.DataFrame.from_dict(lobe_remove_result, orient='index')
lobe_remove_df.to_csv('lobe_remove_result.csv')

remove lobe:Fp
{'precision': 0.683562204538702, 'recall': 0.9800601493515286, 'f1': 0.8011484449330016, 'acc': 0.6800450345491514, 'auc': 0.6216516028386624}

remove lobe:F
{'precision': 0.7967619041995913, 'recall': 0.8681558264086144, 'f1': 0.8274414055749049, 'acc': 0.762530352650384, 'auc': 0.7531449935399319}

remove lobe:FC
{'precision': 0.7958536197954473, 'recall': 0.8826185212752087, 'f1': 0.8335636965808222, 'acc': 0.7688477588465721, 'auc': 0.7522546138255825}

remove lobe:C
{'precision': 0.7945917395443516, 'recall': 0.90748899981834, 'f1': 0.8436797912071574, 'acc': 0.7799016894830976, 'auc': 0.7550803603301317}

remove lobe:P
{'precision': 0.7899729800920433, 'recall': 0.9086154530209172, 'f1': 0.8414757369157538, 'acc': 0.7761770001707797, 'auc': 0.7586790847503762}

remove lobe:O
{'precision': 0.795154076200936, 'recall': 0.9041163572002432, 'f1': 0.8429565874264047, 'acc': 0.780025561599105, 'auc': 0.7587357507290353}

remove lobe:T
{'precision': 0.797743924171362, 're