In [None]:
# Mount Google Drive inside the Colab runtime so that the project directory
# under /content/drive becomes reachable from the following cells.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)  # force_remount=True is passed explicitly

In [None]:
!pip install wfdb==3.0.1


In [None]:
# Global settings
# autoreload mode 2: re-import edited modules before each cell execution.
%load_ext autoreload
%autoreload 2
# Project root on the mounted Drive.
root_dir = '/content/drive/MyDrive/ecg_segmention/ecg/keras'

# Make the project root the working directory; the `utils` package imported
# in the next cell is presumably resolved relative to it — verify.
import os
os.chdir(root_dir)

# NOTE(review): commonly set to work around the "duplicate OpenMP runtime"
# abort — confirm it is still needed in this environment.
os.environ['KMP_DUPLICATE_LIB_OK']='True'


In [None]:
import os

import utils.data_util as du
import utils.file_util as fu
from utils.utils import setup_gpu
from utils.postprocess import get_predict_points, print_points_accuracy

from sklearn.model_selection import KFold

from utils.Config import Config

from utils.models.model import get_model

import numpy as np
import random
from tensorflow.keras.models import load_model
import copy
import tensorflow as tf
from tensorflow import keras
from random import shuffle
import pandas as pd


from tensorflow.keras.callbacks import EarlyStopping


from sklearn.metrics import classification_report
from sklearn.metrics._plot import confusion_matrix

from utils.utils import plot_confusion_matrix


In [None]:
def setup_seed(seed):
    """Seed the random number sources used by this notebook.

    Applies ``seed`` to NumPy's global RNG, Python's ``random`` module and
    TensorFlow's global generator (in that order), then exports it as the
    ``PYTHONHASHSEED`` environment variable.
    """
    for apply_seed in (np.random.seed, random.seed, tf.random.set_seed):
        apply_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

   
# Configure the GPU through the project helper; '0' is presumably the device
# id to use — confirm against utils.utils.setup_gpu.
setup_gpu('0')

In [None]:


def train(model, config):
    """Fit ``model`` on the data carried by ``config`` and persist the outcome.

    Training uses early stopping on ``val_accuracy`` (patience 4) and restores
    the best weights. The fitted model is saved to ``config.fname_model`` and
    the history dictionary to ``config.fname_history``.

    Args:
        model: Keras model exposing ``fit`` and ``save``.
        config: Settings object. Reads ``data`` (a 12-tuple ordered
            train / val / test, each as x, y, y_p, y_name), ``batch_size``,
            ``epochs``, ``train_verbose``, ``fname_model`` and
            ``fname_history``.

    Returns:
        The history object returned by ``model.fit``.
    """
    num_classes = 4

    # Only the train/val inputs and labels are needed here; the remaining
    # entries of the 12-tuple are ignored.
    (x_train, y_train, _, _,
     x_val, y_val, _, _,
     _, _, _, _) = config.data
    y_train = keras.utils.to_categorical(y_train, num_classes=num_classes, dtype='float32')
    y_val = keras.utils.to_categorical(y_val, num_classes=num_classes, dtype='float32')

    early_stopping = keras.callbacks.EarlyStopping(
        monitor='val_accuracy',
        min_delta=0,
        patience=4,
        verbose=0,
        mode='auto',
        baseline=None,
        restore_best_weights=True,
    )

    history = model.fit(
        x_train,
        y_train,
        batch_size=config.batch_size,
        epochs=config.epochs,
        validation_data=(x_val, y_val),
        verbose=config.train_verbose,
        callbacks=[early_stopping],
    )

    model.save(config.fname_model)

    # Persist through `fu` (utils.file_util), the alias this notebook imports.
    fu.save(history.history, config.fname_history)

    return history




def eval(ecgs, y_true, y_pred, labels, target_names, plot_conf_matrix=False, config = None):
    """Report per-sample classification quality of the segmentation.

    Note: the name shadows the built-in ``eval``; it is kept because callers
    in this notebook use it.

    Args:
        ecgs: Not used by this function; accepted for caller compatibility.
        y_true: Array of reference class ids (any shape; flattened here).
        y_pred: Array of predicted class ids (any shape; flattened here).
        labels: Class ids to include in the confusion matrix and report.
        target_names: Display names matching ``labels``.
        plot_conf_matrix: When true, plot a normalized confusion matrix.
        config: Optional settings object. When given, the training history
            stored at ``config.fname_history`` is loaded and plotted; when
            ``None`` that step is skipped.

    Returns:
        None. The classification report is printed.
    """
    # Public scikit-learn entry point, imported locally so this function does
    # not depend on the private ``sklearn.metrics._plot`` module.
    from sklearn.metrics import confusion_matrix as compute_confusion_matrix

    if config is not None:
        history = fu.load(config.fname_history)
        # NOTE(review): plot_results is not defined or imported in the visible
        # part of this notebook — confirm where it comes from.
        plot_results(history)

    y_true = np.asarray(y_true).flatten()
    y_pred = np.asarray(y_pred).flatten()

    if plot_conf_matrix:
        conf_mat = compute_confusion_matrix(y_true=y_true, y_pred=y_pred, labels=labels)

        plot_confusion_matrix(
            confusion_matrix=conf_mat,
            target_names=target_names,
            title='Confusion matrix',
            normalize=True
        )

    report = classification_report(y_true=y_true, y_pred=y_pred, digits = 4, labels=labels, target_names=target_names)

    print('Report: ' + str(report))

In [None]:
def evaluation(points, points_pred,ecgs=None,ecg_names=None,config=None,verbose = False):
    """Compare predicted point positions with reference positions.

    ``points`` and ``points_pred`` are 2-D (rows x point types) and use ``-1``
    for a missing entry. For each tolerance in (40, 70, 150) the per-column
    rate ``TP / (TP + FN) * 100`` is computed, where a prediction is a hit
    when it lies closer than ``FC / 1000 * tol`` to the reference. For the
    default tolerance (150) the signed error ``diff * 1000 / FC`` of the hits
    is summarized per column as mean and standard deviation.

    ``FC`` is fixed at 250 — presumably the sampling frequency in Hz, which
    would make tolerances and errors milliseconds; confirm against the data.

    Args:
        points: Reference positions, shape (rows, columns).
        points_pred: Predicted positions, same shape as ``points``.
        ecgs, ecg_names, config, verbose: Not used by this function; accepted
            for caller compatibility.

    Returns:
        (results, output):
            results: list with one rounded rate array per tolerance followed
                by a list of "mean+std" strings (one per column).
            output: float array of shape (5, columns) holding the three
                unrounded rate rows, the error means and the error stds.
                Columns without any hit get ``nan`` mean/std.
    """
    points = np.array(points)
    points_pred = np.array(points_pred)

    size = points.shape[1]

    FC = 250
    tol_default = 150
    tolerances = [40, 70, tol_default]

    has_pred = points_pred != -1
    has_ref = points != -1
    diff = points - points_pred
    diff_abs = np.abs(diff)

    results = []
    output = []
    for tol in tolerances:
        hit = has_pred & (diff_abs < FC/1000 * tol)
        # A miss is either a prediction outside the tolerance or a reference
        # point with no prediction.
        # NOTE(review): a prediction where the reference is -1 is still
        # compared against -1 and lands in TP or FN — confirm this is the
        # intended scoring.
        miss = (has_pred & ~hit) | (~has_pred & has_ref)

        # Counts are kept as floats so the printed form stays "[2. 1.]".
        TP = hit.sum(axis=0).astype(float)
        FN = miss.sum(axis=0).astype(float)

        sens = TP/(TP + FN) *100
        output.append(sens)
        results.append(np.round(sens,2))

        if tol == tol_default:
            print(TP,FN)
            print('SENS:', np.round(sens,2))

    # Errors are only measured where both positions exist and the prediction
    # is within the default tolerance.
    measurable = has_pred & has_ref & (diff_abs < FC/1000 * tol_default)

    print('ERROR:', end=' ')
    errors_result = []
    err_means = np.zeros((size))
    err_stds = np.zeros((size))
    for i in range(size):
        # Per-column error vectors have different lengths, so they are
        # reduced one column at a time instead of being stacked.
        err = diff[measurable[:, i], i] *1000/FC
        if err.size:
            err_m = np.mean(err)
            err_s = np.std(err)
        else:
            err_m = err_s = np.nan
        err_means[i] = err_m
        err_stds[i] = err_s
        tmp = "{:.1f}+{:.1f}".format(err_m, err_s)
        errors_result.append(tmp)
        print(tmp,end=' ')
    print('')

    output.append(err_means)
    output.append(err_stds)
    output = np.array(output)
    results.append(errors_result)
    print(results)

    return results, output

In [None]:

def qtdb_split(record_list, seed=None):
    """Build five train/test folds over QTDB record names, stratified by group.

    Records are partitioned into three groups: a fixed list of names
    (``bih``), names containing ``'sele'`` (``stt``) and everything else
    (``mit``). Each group is shuffled once and cut into five consecutive
    20 % slices; fold ``k`` uses slice ``k`` of every group as test records
    and the rest as training records.

    Args:
        record_list: Iterable of record names available for splitting.
        seed: Seed for the shuffle. When ``None`` the notebook-level
            ``config.seed`` is used.

    Returns:
        List of five ``(train_records, test_records)`` tuples; both entries
        are lists of record names.

    Side effects:
        Re-seeds the global ``random`` module.
    """
    bih_names = ['sel30','sel31','sel32','sel33','sel34','sel35','sel36',
         'sel37','sel38','sel39','sel40','sel41','sel42','sel43','sel44','sel45','sel46','sel47',
         'sel48','sel49','sel50','sel51','sel52','sel17152']

    # Only keep fixed-list names that are actually available, so records the
    # caller filtered out are not put back into the folds.
    available = set(record_list)
    bih_list = [name for name in bih_names if name in available]
    stt_list = [name for name in record_list if 'sele' in name]

    grouped = set(bih_names) | set(stt_list)
    mit_list = [name for name in record_list if name not in grouped]

    if seed is None:
        seed = config.seed
    random.seed(seed)

    all_list = [bih_list, stt_list, mit_list]
    for r_list in all_list:
        shuffle(r_list)

    idx = []
    for i in range(5):
        ratio = [i*0.2,(i+1)*0.2]
        train_idx,test_idx = [],[]
        for r_list in all_list:
            train_idx_tmp, test_idx_tmp = __split_train_test__(r_list, ratio)
            train_idx.extend(train_idx_tmp)
            test_idx.extend(test_idx_tmp)
        idx.append((train_idx, test_idx))
    return idx

def __split_train_test__(r_list, ratio = (0.8, 1)):
    """Cut ``r_list`` into a test slice and the remaining training entries.

    Args:
        r_list: Sequence of record names.
        ratio: ``(start, end)`` fractions of the sequence that delimit the
            test slice.

    Returns:
        ``(train, test)`` as NumPy arrays. Training entries are all entries
        whose value does not occur in the test slice.
    """
    r_list = np.array(r_list)
    cnt = len(r_list)
    test_start = int(ratio[0]*cnt)
    test_end = int(ratio[1]*cnt)
    test_idx = r_list[test_start:test_end]
    train_idx = r_list[~np.isin(r_list, test_idx)]
    return train_idx,test_idx



In [None]:
def exec(config):
    """Train one model on ``config.data`` and predict its test split.

    Note: the name shadows the built-in ``exec``; it is kept because callers
    in this notebook use it.

    The model is built with ``get_model(config)``, trained via ``train`` and
    then reloaded from ``config.fname_model`` before predicting.

    Args:
        config: Settings object. Reads ``seed``, ``data`` (a 12-tuple ordered
            train / val / test, each as x, y, y_p, y_name) and
            ``fname_model``.

    Returns:
        (ecgs, labels, predictions, names): the test inputs, the reference
        class ids, the predicted class ids (argmax over axis 2) and the test
        record names.
    """
    setup_seed(config.seed)
    fold_model = get_model(config)

    # Only the test inputs, labels and names are consumed here; train() reads
    # the train/val entries from config.data itself.
    (_, _, _, _,
     _, _, _, _,
     x_test, y_test, _, y_name_test) = config.data

    train(
      fold_model,
      config = config
    )

    model = load_model(config.fname_model)

    y_test = keras.utils.to_categorical(y_test, num_classes=4, dtype='float32')
    y_pred = model.predict(x_test)

    labels_list = np.argmax(y_test,  axis=2)
    predicted_list = np.argmax(y_pred, axis=2)

    return x_test, labels_list, predicted_list, y_name_test




In [None]:
def make_crossvalidation(config, kfold_splits = 5):
    """Run cross-validation, store the pooled predictions and report metrics.

    For every fold a model is trained and its test predictions collected
    (``exec``). The pooled results are saved under
    ``Config.RESOURCES_DIR + '/result/'``, scored with ``evaluation`` and
    summarized with ``eval``.

    Args:
        config: Settings object. Reads ``seed``, ``train``, ``dataset``,
            ``test_dataset`` and ``fname_data``; ``data`` is overwritten for
            every fold.
        kfold_splits: Number of folds for ``'ludb'`` (``KFold``). For
            ``'qtdb'`` the folds come from ``qtdb_split``, which always
            yields five.

    Returns:
        The ``output`` array returned by ``evaluation``.

    Raises:
        ValueError: If the selected dataset is neither ``'qtdb'`` nor
            ``'ludb'``.
    """
    def load_data_and_labels():
        # Only the first two entries of the stored object are used; indexing
        # keeps this working whether or not a third entry is present.
        loaded = fu.load(config.fname_data)
        return loaded[0], loaded[1]

    setup_seed(config.seed)
    dataset = config.dataset if config.train else config.test_dataset

    data, labels = load_data_and_labels()
    actual_labels = None

    if dataset == 'qtdb':
        record_list = du.get_ids(Config.QTDB_RECORD_DIR)
    elif dataset == 'ludb':
        record_list = du.get_ids(Config.LUDB_RECORD_DIR)
    else:
        raise ValueError("Unsupported dataset: {!r}".format(dataset))

    # Drop records that have no data. The result is always an ndarray so it
    # can be indexed with the integer arrays produced by KFold.
    record_list = np.array([
        record for record in record_list
        if record in data.keys() and len(data[record]) > 0
    ])

    if dataset == 'ludb':
        fold_idx = KFold(n_splits=kfold_splits, shuffle=True).split(record_list)
    else:
        fold_idx = qtdb_split(record_list)

    ecgs_parts = []
    label_parts = []
    predict_parts = []
    point_parts = []
    ecg_names = []

    for fold_no, (train_idx, test_idx) in enumerate(fold_idx, start=1):
        print('5 fold-No.',fold_no)

        # KFold yields integer positions, qtdb_split yields record names.
        if dataset == 'ludb':
            train_records, test_records = record_list[train_idx], record_list[test_idx]
        else:
            train_records, test_records = train_idx, test_idx

        # Training data is built from a freshly loaded copy — presumably
        # because __get_data__ may modify its inputs; verify.
        data_, labels_ = load_data_and_labels()
        x_train, y_train, y_p_train,y_name_train = du.__get_data__(data_, labels_,actual_labels, train_records, config)
        x_test, y_test, y_p_test,y_name_test = du.__get_data__(data, labels,actual_labels, test_records, config)
        print(x_train.shape[0] + x_test.shape[0], x_train.shape,
            x_test.shape)

        # NOTE(review): the validation slot is filled with the test split, so
        # early stopping in train() selects weights on test data — confirm
        # this is intended.
        config.data =  x_train, y_train,y_p_train, y_name_train, x_test, y_test, y_p_test,y_name_test, x_test, y_test, y_p_test,y_name_test

        ecgs_list, labels_list, predicted_list, names_list = exec(config)
        ecgs_parts.append(ecgs_list)
        label_parts.append(labels_list)
        predict_parts.append(predicted_list)
        point_parts.append(y_p_test)
        ecg_names.extend(list(names_list))

    # Stack once at the end instead of re-copying the pooled arrays per fold.
    ecgs = np.vstack(ecgs_parts)
    y_labels = np.vstack(label_parts)
    predicts = np.vstack(predict_parts)
    point_labels = np.vstack(point_parts)

    fu.save((ecgs, y_labels, predicts, ecg_names, point_labels, config), Config.RESOURCES_DIR + '/result/' + config.dataset + '.pkl')
    points_pred = get_predict_points(predicts)
    _, output = evaluation(point_labels, points_pred,ecgs,ecg_names,config,verbose=False)

    # Only keyword arguments that eval() declares are passed.
    eval(
      ecgs=ecgs,
      y_true=y_labels,
      y_pred=predicts,
      labels=[0, Config.P_H, Config.QRS_H, Config.T_H],
      target_names=['none', 'p_wave', 'qrs', 't_wave'],
      plot_conf_matrix=True,
      config = config
    )

    return output
   


In [None]:
# Run the cross-validation experiment on the 'ludb' dataset.
config = Config()
config.dataset = 'ludb'
# refresh() is called after changing the dataset — presumably to recompute
# dataset-dependent settings; verify against utils.Config.
config.refresh()
config.print()
output = make_crossvalidation(config)

