In [None]:
import pandas as pd
import sys
import os
import time
import shutil
import numpy as np
import matplotlib
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
import math
import data_util as du
import exact_util as eu

from data_util import Time
from brainflow.board_shim import BoardShim, BrainFlowInputParams, LogLevels, BoardIds
from brainflow.data_filter import DataFilter, FilterTypes, AggOperations, WindowFunctions, DetrendOperations

import nolds

import iceland_tool as it
from config import Config, get_cfg_defaults

import scipy as sp
import wfdb
import file_util as fu
import nolds

In [None]:
def sample_entropy(signal):
    """Return the sample entropy of ``signal`` computed by ``nolds.sampen``.

    The tolerance is one quarter of the module-level ``global_threshold``;
    ``analysis`` sets that global to the standard deviation of the segment
    it is processing before building the feature curve.
    """
    tolerance = global_threshold * 0.25
    return nolds.sampen(signal, tolerance=tolerance)


def root_mean_square(signal):
    """Return the root-mean-square amplitude of ``signal``: sqrt(sum(x**2) / N).

    Raises ZeroDivisionError for an empty ``signal``.
    """
    sum_of_squares = np.sum(np.power(signal, 2))
    return np.sqrt(sum_of_squares * (1 / len(signal)))


def curve(signal, method, fs, window_size=60, shift=1, mean_filter_size=10):
    """Compute a sliding-window feature curve of ``signal`` using ``method``.

    This is a thin wrapper around ``eu.get_curve``; ``window_size`` and
    ``shift`` are forwarded unchanged. ``fs`` and ``mean_filter_size`` are
    accepted for interface compatibility but currently unused: a
    rolling-mean smoothing step (``DataFilter.perform_rolling_filter``)
    that consumed them was disabled.
    """
    feature_curve = eu.get_curve(signal, window_size=window_size, shift=shift, method=method)
    return feature_curve

def analysis(signal, method):
    """Run the detection pipeline on one signal segment and return peak indices.

    Steps:
      1. Set the module-level ``global_threshold`` to ``np.std(signal)``
         (read by ``sample_entropy``).
      2. Build a sliding-window feature curve with ``method`` via ``curve``.
      3. Pass the curve through ``eu.remove_base``.
      4. Locate event peaks with ``detect_edge``, using the module-level
         ``bound_weight`` as threshold weight.

    Window parameters come from the module-level ``cfg.DETECT`` and are
    scaled by the module-level ``fs``.

    Returns:
        The ``ans`` array produced by ``detect_edge``. These are presumably
        sample-aligned with ``signal`` (``main`` compares them against sample
        indices); TODO confirm against ``eu.get_curve``.
    """
    global global_threshold
    det_cfg = cfg.DETECT
    window_size = det_cfg.WINDOW_SIZE
    shift = det_cfg.SHIFT
    base_window_size = det_cfg.BASE_WINDOW_SIZE
    # NOTE(review): BASE_SHIFT is read but never used -- remove_base below uses
    # a fixed fs*10 shift instead. Confirm which of the two is intended.
    base_shift = det_cfg.BASE_SHIFT
    mean_filter_size = det_cfg.MEAN_FILTER_SIZE

    global_threshold = np.std(signal)

    feature_curve = curve(
        signal,
        method,
        fs,
        window_size=window_size*fs,
        shift=fs*shift,
        mean_filter_size=mean_filter_size,
    )
    baseline_removed = eu.remove_base(
        feature_curve,
        base_window_size=base_window_size*fs,
        shift=fs*10,
    )

    # d = 5*fs: how many consecutive samples must agree before an onset/offset is accepted.
    _onsets, _offsets, peaks = detect_edge(
        baseline_removed, 5*fs, weight=bound_weight, fs=fs, window_size=window_size
    )
    return peaks


def detect_edge(signal, d, weight, fs, window_size):
    """Find threshold-crossing events in a feature curve and the peak of each.

    The first and last ``fs*(window_size+1)`` samples of ``signal`` are
    ignored. On the remaining samples a threshold
    ``th = min + weight * (max - min)`` is computed. The scan moves forward
    one sample at a time:
      * an event opens (onset) at the first index whose next ``d`` samples
        (fewer at the very end) are all ``>= th``;
      * it closes (offset) at the first later index whose next ``d`` samples
        are all ``< th``.
    For each closed event, the index of the maximum of the samples in
    ``[onset, offset)`` is recorded.

    Args:
        signal: 1-D sequence of curve values.
        d: number of consecutive samples that must agree before an onset or
            offset is accepted; must be >= 1.
        weight: relative position of the threshold within [min, max].
        fs: sampling rate; with ``window_size`` it sets the trimmed margin.
        window_size: window length (presumably seconds, since callers pass
            ``window_size*fs`` to the curve) used to size the margin.

    Returns:
        ``(onsets, offsets, ans)`` as numpy arrays of indices into the
        untrimmed ``signal``.
          * If an event is still open at the end, a closing offset at the
            last sample is appended, but no peak is reported for it.
          * If the event would open on the very last sample, it is dropped.
          * All three arrays are empty when nothing remains after trimming.

    Raises:
        ValueError: if ``d < 1``.
    """
    if d < 1:
        raise ValueError("d must be at least 1, got {}".format(d))

    margin = fs * (window_size + 1)
    trimmed = np.asarray(signal)[margin:-margin]
    n = len(trimmed)
    onsets, offsets, ans = [], [], []

    if n > 0:
        f_min = np.min(trimmed)
        f_max = np.max(trimmed)
        th = f_min + weight * (f_max - f_min)

        # "All samples in [i, min(i+d, n)) are above / below th" for every i,
        # via prefix sums: O(n) instead of rescanning d samples per step.
        starts = np.arange(n)
        ends = np.minimum(starts + d, n)
        lengths = ends - starts
        above_cum = np.concatenate(([0], np.cumsum(trimmed >= th)))
        below_cum = np.concatenate(([0], np.cumsum(trimmed < th)))
        all_above = (above_cum[ends] - above_cum[starts]) == lengths
        all_below = (below_cum[ends] - below_cum[starts]) == lengths

        in_event = False
        for i in range(n):
            if all_above[i] and not in_event:
                in_event = True
                onsets.append(i)
            if all_below[i] and in_event:
                in_event = False
                offsets.append(i)
                ans.append(onsets[-1] + np.argmax(trimmed[onsets[-1]:i]))

        if len(onsets) > len(offsets):
            # An event is still open at the end of the segment.
            if onsets[-1] != n - 1:
                # NOTE(review): no peak is appended to ans for this event, so
                # an event running into the segment end is never reported.
                offsets.append(n - 1)
            else:
                onsets = onsets[:len(offsets)]

    # Shift back to indices of the untrimmed signal.
    onsets = np.array(onsets) + margin
    offsets = np.array(offsets) + margin
    ans = np.array(ans) + margin
    return onsets, offsets, ans


        

In [None]:
def get_intervals(signal, fs, thr=30):
    """Split a recording into ``split`` analysis windows of at most ~``thr`` minutes.

    Args:
        signal: the recording; only ``len(signal)`` is used.
        fs: sampling rate, passed to ``du.get_total_minute`` (presumably
            returns the recording length in minutes -- ``main`` converts the
            bounds with ``*60*fs``).
        thr: maximum nominal window length; must be > 0.

    Returns:
        List of ``[onset, offset]`` pairs in the same unit as
        ``du.get_total_minute``. The first window starts at 0.5 and the last
        ends at ``minutes - 0.5``.

    Raises:
        ValueError: if ``thr <= 0``; previously the loop never terminated.
    """
    if thr <= 0:
        raise ValueError("thr must be positive, got {}".format(thr))

    minutes = du.get_total_minute(len(signal), fs)

    # Smallest number of windows such that minutes/split <= thr.
    split = 1
    while minutes / split > thr:
        split += 1

    minute_interval = []
    for i in range(split):
        # NOTE(review): interior onsets use (minutes - 2) while offsets use
        # minutes, so window i starts 2*i/split before window i-1 ends
        # (overlap). References in the overlap are counted in both windows by
        # main(); confirm this is intended.
        onset = 0.5 if i == 0 else (minutes - 2) / split * i
        offset = minutes - 0.5 if i == split - 1 else minutes / split * (i + 1)
        minute_interval.append([onset, offset])
    return minute_interval

def evaluate(r_ref, r_ans, fs=20, thr=30):
    """Score detected event positions against reference positions and print the result.

    For every segment ``i``, each reference ``r_ref[i][j]`` is:
      * a true positive if at least one detection in ``r_ans[i]`` lies within
        ``thr*fs`` samples of it;
      * otherwise a false negative.
    False positives are extra detections matched to the same reference plus
    detections matching no reference. The per-segment FP count is clamped at
    0, because a detection near two references is counted in both matches.

    Args:
        r_ref: per-segment sequences of reference sample indices.
        r_ans: per-segment sequences of detected sample indices, indexed in
            parallel with ``r_ref``.
        fs: sampling rate, used to convert ``thr`` to samples and timing
            errors to seconds.
        thr: matching tolerance in seconds.

    Returns:
        ``(DR, EDR)``: the detection rate TP/(TP+FN) and the erroneous
        detection rate FP/(FP+TP). Each is 0 when its numerator is 0.
    """
    tolerance = thr * fs
    all_TP = 0
    all_FN = 0
    all_FP = 0
    errors = []

    for i in range(len(r_ref)):
        answers = np.asarray(r_ans[i])
        TP = 0
        FN = 0
        FP = 0
        detect_loc = 0
        for ref in r_ref[i]:
            loc = np.where(np.abs(answers - ref) <= tolerance)[0]
            detect_loc += len(loc)
            if len(loc) == 0:
                FN += 1
                continue
            TP += 1
            FP += len(loc) - 1
            # Timing error in seconds w.r.t. the earliest matching detection.
            errors.append((ref - answers[loc[0]]) / fs)
        # Detections not consumed by any reference are false positives.
        FP += len(answers) - detect_loc
        # NOTE(review): FP above equals len(answers) - TP and would go negative
        # when one detection lies within tolerance of two references (TP still
        # credits both; one-to-one matching would be stricter).
        FP = max(FP, 0)

        all_FP += FP
        all_FN += FN
        all_TP += TP

    DR = all_TP / (all_FN + all_TP) if all_TP != 0 else 0
    EDR = all_FP / (all_FP + all_TP) if all_FP != 0 else 0
    print("DR:{}, EDR:{}".format(DR, EDR))

    abs_errors = np.abs(errors)
    print("error mean+std:{}+{},abs_mean+std:{}+{}".format(
        np.mean(errors), np.std(errors), np.mean(abs_errors), np.std(abs_errors)))
    return DR, EDR

In [None]:
# Module-level settings and shared state read by the functions above.
global_threshold = 0  # overwritten by analysis(), read by sample_entropy()

record_time = Time(0, 0, 0)

cfg = get_cfg_defaults()
# NOTE: `filter` shadows the built-in filter(); main() reads it by this name.
fs, channel, filter = cfg.PARAM.FS, cfg.PARAM.CHANNEL, cfg.PARAM.FILTER

# Feature function that main() hands to analysis(); its name is recorded in the config.
method = sample_entropy
cfg.DETECT.METHOD_NAME = method.__name__
method_name = cfg.DETECT.METHOD_NAME

# Threshold position within the curve's [min, max] range (detect_edge `weight`).
bound_weight = 0.2
print(cfg)

In [None]:
from contraction_time import con_loc

def main(cfg):
    """Detect events in every annotated recording and print evaluation metrics.

    For each recording name in ``con_loc``:
      * the reference times ``con_loc[name]`` are scaled by ``fs*60``
        (i.e. treated as minutes) into sample indices;
      * the signal is loaded with ``it.get_signal`` using the module-level
        ``filter``, ``channel`` and ``fs`` settings;
      * the signal is cut into the windows from ``get_intervals``. For each
        window, the references strictly inside it are rebased to the window
        start, and ``analysis`` is run on the window.
    All windows are then scored together by ``evaluate``.

    Args:
        cfg: configuration node; only ``cfg.PARAM.NORMALIZE`` is read here
            (``analysis`` uses the module-level ``cfg``).
    """
    # Matching tolerance in seconds (evaluate converts it to samples via thr*fs).
    evaluate_thr = 30

    refs = []
    ans = []

    for name in con_loc:
        true_ref = np.array(con_loc[name])*fs*60
        signal = it.get_signal(name, filter=filter, channels=channel,
                               target_fs=fs, normalize=cfg.PARAM.NORMALIZE)

        for minute_onset, minute_offset in get_intervals(signal, fs=fs):
            s_onset = int(minute_onset*60*fs)
            s_offset = int(minute_offset*60*fs)
            segment = signal[s_onset:s_offset]

            # References strictly inside the window, relative to its start.
            inside = (true_ref > s_onset) & (true_ref < s_offset)
            refs.append(true_ref[inside] - s_onset)

            ans.append(analysis(segment, method))

    evaluate(refs, ans, thr=evaluate_thr, fs=fs)

In [None]:
# Run detection + evaluation over all annotated recordings; prints DR/EDR and timing-error statistics.
main(cfg)