### 配置环境

In [None]:
# from google.colab import drive
# drive.mount('/gdrive')
# %cd /gdrive/My Drive/wcd

In [None]:
# %tensorflow_version 2.x

### 第一部分 所用到的库

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

# TensorFlow and tf.keras
# import tensorflow as tf

import os
import io
# from tensorflow.keras import layers



# Helper libraries
# import imageio
import datetime
import numpy as np
import time
import matplotlib.pyplot as plt
# from IPython import display
# import PIL
# import glob
from scipy import signal
# print(tf.__version__)

### 第二部分 常用预处理函数


In [None]:
def presolveH(path):
    """Load one channel of a WFDB record, band-limit it and rescale it.

    Samples 180..1260 of channel 1 are read with ``wfdb.rdrecord``.  The
    trace is filtered with a 5th-order Butterworth low-pass (40 Hz) and a
    5th-order Butterworth high-pass (0.5 Hz), both applied forward-backward
    with ``filtfilt``, then median filtered with SciPy's default kernel,
    shifted by ``abs(min)`` and divided by its maximum.

    Parameters
    ----------
    path : record name/path handed to ``wfdb.rdrecord``.

    Returns
    -------
    The record object, with ``p_signal[:, 0]`` replaced by the processed
    trace.

    NOTE(review): ``wfdb`` is not imported in the visible part of this file;
    it has to be imported before this function is called.
    """
    record = wfdb.rdrecord(path, sampfrom=180, sampto=1260, channels=[1])
    trace = record.p_signal[:, 0]
    rate = record.fs  # sampling frequency of the record

    # (cut-off in Hz, filter type); cut-offs are normalised as 2 * f / fs.
    for cutoff_hz, band in ((40, 'lowpass'), (0.5, 'highpass')):
        coeff_b, coeff_a = signal.butter(5, 2 * cutoff_hz / rate, band)
        trace = signal.filtfilt(coeff_b, coeff_a, trace)

    trace = signal.medfilt(trace, kernel_size=None)
    trace = trace + abs(min(trace))
    trace = trace / max(trace)
    record.p_signal[:, 0] = trace
    return record


def myfilter(x, low=0, high=0, fs=250, Order=5, **kwargs):
    """Simple zero-phase Butterworth low-pass / high-pass / band-pass filter.

    Parameters
    ----------
    x     : input signal; it is transposed before filtering and transposed
            back afterwards, so filtering runs along the first axis.
    low   : low-pass cut-off in Hz (e.g. 40); 0 disables the low-pass stage.
    high  : high-pass cut-off in Hz (e.g. 0.4); 0 disables the high-pass stage.
    fs    : sampling frequency of ``x`` in Hz.
    Order : order of each Butterworth stage.
    kwargs: accepted and ignored.

    Returns
    -------
    The filtered signal, same orientation as ``x``.

    Author: starhou, 2019.1.23, 1029588176@qq.com
    """
    filtered = x.T
    # Low-pass first, then high-pass; each stage is applied forward-backward.
    for cutoff, band in ((low, 'lowpass'), (high, 'highpass')):
        if cutoff == 0:
            continue
        normalised = 2 * cutoff / fs  # cut-off relative to Nyquist
        coeff_b, coeff_a = signal.butter(Order, normalised, band)
        filtered = signal.filtfilt(coeff_b, coeff_a, filtered)
    return filtered.T


def resamH(data, fs, tofs):
    """Resample a signal from ``fs`` Hz to ``tofs`` Hz (Fourier method).

    The number of output samples is ``round(len(data) * tofs / fs)``.  The
    previous formula, ``len(data) // fs * tofs``, floored the duration to
    whole seconds, so any input whose length was not a multiple of ``fs``
    came out at the wrong rate.  Inputs covering a whole number of seconds
    produce the same output as before.

    Parameters
    ----------
    data : signal to resample; resampling runs along the first axis.
    fs   : sampling frequency of ``data``.
    tofs : target sampling frequency.

    Returns
    -------
    y : the resampled data (``num`` samples).
    x : column vector of shape ``(num, 1)`` from ``np.linspace(0, num, num)``,
        intended as an x axis for plotting.

    Raises
    ------
    ValueError : if the input is too short to yield at least one sample.

    Author: starhou, 2019.1.23, 1029588176@qq.com
    """
    lendata = np.shape(data)[0]
    num = int(round(lendata * tofs / fs))
    if num < 1:
        raise ValueError("input is too short to resample to the requested rate")
    x = np.linspace(0, num, num).reshape(num, -1)
    y = signal.resample(data, num)
    return y, x


def standardH(data, mode):
    """Normalise a numeric array.

    The previous element-wise loop ran one index past the end of ``data``
    (always raising IndexError) and recomputed mean/min/max from data that
    was already partly overwritten.  The statistics are now computed once
    from the untouched input and applied in a single vectorised step.

    Parameters
    ----------
    data : array-like of numbers.
    mode : which normalisation to apply (numbering follows the code's
           ``mode ==`` tests, not the older docstring):
           0 -> z-score:  (data - mean) / std
           1 -> min-max:  (data - min) / (max - min)
           2 -> simple:   shift by ``abs(min)``, then divide by the new max
           Any other value returns the data unchanged (as float).

    Returns
    -------
    A new float ndarray; the input is not modified.  A constant input gives
    a zero denominator and therefore NaN/inf, as with the plain formulas.

    Author: starhou, 2019.1.23, 1029588176@qq.com
    """
    values = np.asarray(data, dtype=float)
    if mode == 0:
        return (values - values.mean()) / values.std()
    if mode == 1:
        lowest = values.min()
        return (values - lowest) / (values.max() - lowest)
    if mode == 2:
        shifted = values + np.abs(values.min())
        return shifted / shifted.max()
    return values.copy()
    
def GradientSignal(x):
    """First difference of a signal, padded back to the input length.

    Parameters
    ----------
    x : data to differentiate (converted with ``np.array``).

    Returns
    -------
    ``np.diff(x)`` with its last value inserted once more just before the
    final element, so the result has as many samples as ``x``.

    Author: starhou, 2019.1.23, 1029588176@qq.com
    """
    steps = np.diff(np.array(x))
    final_step = steps[-1]
    return np.insert(steps, -1, final_step)
    
def smooth(a, WSZ):
    """Moving-average smoothing with shrinking windows at both ends.

    Parameters
    ----------
    a   : 1-D NumPy array to smooth (flatten with ``np.ravel`` or
          ``np.squeeze`` first if it is not 1-D).
    WSZ : window size; must be odd, as in the original MATLAB ``smooth``.

    Returns
    -------
    Array of the same length as ``a``: a centred mean over ``WSZ`` samples
    in the interior, and means over 1, 3, 5, ... samples towards each end.
    """
    interior = np.convolve(a, np.ones(WSZ, dtype=int), 'valid') / WSZ
    # Window lengths 1, 3, 5, ... used for the edge samples.
    edge_widths = np.arange(1, WSZ - 1, 2)
    head = np.cumsum(a[:WSZ - 1])[::2] / edge_widths
    tail_reversed = np.cumsum(a[:-WSZ:-1])[::2] / edge_widths
    return np.concatenate((head, interior, tail_reversed[::-1]))


def zac_whiten(x):
    """ZCA-style whitening of a 2-D array.

    The original note describes ``x`` as n rows by m columns with m being
    the samples.  The matrix ``x.T @ x / x.shape[1]`` is decomposed by SVD;
    because it is symmetric, U holds its eigenvectors and S its eigenvalues.
    The data are projected with ``U.T``, scaled by ``1 / sqrt(S + 1e-5)``
    and rotated back with ``U``.

    Returns
    -------
    Array of shape ``(x.shape[1], x.shape[0])``.
    """
    covariance = x.T.dot(x) / x.shape[1]
    U, eigenvalues, _ = np.linalg.svd(covariance)
    scaling = np.diag(1.0 / np.sqrt(eigenvalues + 1e-5))
    projected = x.dot(U.T)
    return U.dot(scaling.dot(projected.T))

def PCA_whitening(x):
    """PCA-style whitening of a 2-D array.

    The symmetric matrix ``x.T @ x / x.shape[1]`` is decomposed by SVD (U
    holds its eigenvectors, S its eigenvalues).  The data are projected
    with ``U.T`` and each component is scaled by ``1 / sqrt(S + 1e-5)``.
    Unlike ``zac_whiten`` the result is not rotated back with ``U``.

    Returns
    -------
    Array of shape ``(x.shape[1], x.shape[0])``.
    """
    covariance = x.T.dot(x) / x.shape[1]
    U, eigenvalues, _ = np.linalg.svd(covariance)
    scaling = np.diag(1.0 / np.sqrt(eigenvalues + 1e-5))
    projected = x.dot(U.T)
    return scaling.dot(projected.T)


def baselineRemove(x:np.array)->np.array:
  '''
  Remove the slowly varying (DC / baseline) component of a signal.

  The baseline is estimated with ``smooth`` using a 125-sample window and
  subtracted from the input.

  input   x   1-D signal
  output      x minus its smoothed baseline

  Author: starhou, 2020.5.8, 1029588176@qq.com
  '''
  return x - smooth(x, 125)
    
# encoding: UTF-8
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date    : 2019-11-08 11.47
# @Author  : Star Monkey (houxin@fudan.edu.cn)
# @Link    : 1029588176@qq.com
# @Version : $1$
def LMS_filter(xn, dn, M, mu):
    """Adaptive LMS (least-mean-squares) filter.

    Fixes over the previous version:
    * the weight vector was shaped (M, 1) and broadcast to (M, M) on its
      first update, so the second iteration failed for any M > 1;
    * integer ``dn`` made the error array integer and truncated the errors;
    * only the last prediction was returned although the documented output
      is the whole prediction sequence;
    * ``len(dn) <= M`` raised IndexError instead of doing nothing.

    Parameters
    ----------
    xn : 1-D input sequence (at least as long as ``dn``).
    dn : 1-D desired ("teacher") sequence.
    M  : filter order, positive integer.
    mu : step size, 0 < mu < 1.

    Returns
    -------
    y : predictions, same length as ``dn``; the first ``M`` entries stay 0.
    e : errors ``dn - y``, same length as ``dn``; the first ``M`` stay 0.

    Call: y, e = LMS_filter(xn, dn, M, mu)

    Author: starhou, 2019.11.28, 1029588176@qq.com
    """
    xn = np.asarray(xn, dtype=float)
    dn = np.asarray(dn, dtype=float)
    y = np.zeros_like(dn)
    e = np.zeros_like(dn)
    w = np.zeros(M)
    for i in range(dn.shape[0] - M):
        k = i + M
        # Regressor xn[k], xn[k-1], ..., xn[i+1] (most recent sample first).
        x = xn[k:i:-1]
        y[k] = np.dot(w, x)
        e[k] = dn[k] - y[k]
        w = w + 2 * mu * e[k] * x
    return y, e
    
def SinWave(T, fs, L):
    '''
    Sine-wave generator.

    input:
           T   number of periods (the phase runs from -T*pi to T*pi)
           fs  sampling frequency
           L   duration in seconds; L*fs samples are produced
    output:
           t   phase axis, np.linspace(-T*pi, T*pi, L*fs)
           f   sin(t)
    '''
    half_span = T * np.pi
    sample_count = L * fs
    t = np.linspace(-half_span, half_span, sample_count)
    return t, np.sin(t)

### 第三部分 常用的信号处理方法

myemd是自己写的版本

[copyemd](https://github.com/parkus/emd)

[MEMD](https://github.com/mariogrune/MEMD-Python-/blob/master/MEMD_all.py)

In [None]:
def GradientSignal(x):
  '''
  Product of neighbouring first differences of a signal.

  input   x         1-D numpy array
  output  gradient  array of the same length as x: entry k (for the interior
                    samples) is diff[k-1] * diff[k]; the first and last
                    entries are 0

  example:
  gradient = GradientSignal(x)

  author: star hou  2019.8.9
  email: 1029588176@qq.com
  '''
  interior = np.zeros((x.shape[0] - 2,))
  steps = np.diff(x)
  # Multiply every difference with its right-hand neighbour in one go.
  interior[:] = steps[:-1] * steps[1:]
  return np.concatenate(([0], interior, [0]))

def diffSignal(x:np.array)->np.array:
  '''
  First difference of a signal, padded to the input length.

  input   x   signal segment
  output  y   np.diff(x) with its last value repeated once at the end

  example:
  y = diffSignal(x)

  author: star hou  2020.5.8
  email: 1029588176@qq.com
  '''
  steps = np.diff(x)
  final_step = steps[-1]
  return np.append(steps, final_step)
def integralSignal(data:np.array, fs:float=250, K:float=0.10)->np.array:
  '''
  Running integral of a signal.

  Neighbouring samples are summed pairwise (data[k] + data[k+1]), the last
  pair sum is repeated so the length matches the input, and the cumulative
  sum is divided by K*fs.

  Changes from the previous version:
  * the running sum uses np.cumsum instead of re-summing the prefix for
    every sample (O(n) instead of O(n^2));
  * integer input is converted to float first, so results are no longer
    truncated to integers;
  * fs and K, previously hard-coded to 250 and 0.10, are parameters with
    those values as defaults.

  input   data  1-D signal with at least two samples
          fs    sampling frequency used in the scale factor
          K     additional scale constant
  output  y     integrated signal, same length as data

  example:
  y = integralSignal(x)

  author: star hou  2020.5.8
  email: 1029588176@qq.com
  '''
  data = np.asarray(data, dtype=float)
  pair_sums = data[:-1] + data[1:]
  pair_sums = np.append(pair_sums, pair_sums[-1])
  return np.cumsum(pair_sums) / (K * fs)

def timeFrequencyDiagram(data:np.array, f:int=6, fs:float=250):
  '''
  Plot a signal and its continuous-wavelet time-frequency map.

  The Morlet wavelet ('morl') is evaluated with pywt.cwt over the scales
  np.arange(f, 31, 0.1).  The top panel shows the signal, the bottom panel
  a filled contour of abs(coefficients) over time and frequency.

  Changes from the previous version:
  * the time axis is built from the actual number of samples; it used to
    come from data.size // sampling_rate, which no longer matched the data
    when the length was not a whole number of seconds, so plotting failed;
  * the sampling rate, previously hard-coded to 250, is the parameter fs;
  * unused imports and locals were removed.

  input   data  1-D signal segment
          f     first (smallest) wavelet scale
          fs    sampling frequency in Hz

  output  none; the figure is shown with plt.show()

  example:
  timeFrequencyDiagram(data)

  author: star hou  2019.8.9
  email: 1029588176@qq.com
  '''
  import pywt
  import matplotlib.pyplot as plt

  data = np.asarray(data)
  sampling_rate = fs
  # One time stamp per sample, whatever the length of the segment.
  t = np.arange(data.size) / sampling_rate

  wavename = 'morl'
  scales = np.arange(f, 31, 0.1)
  cwtmatr, frequencies = pywt.cwt(data, scales, wavename, 1.0 / sampling_rate)

  plt.figure(figsize=(8, 4))
  plt.subplot(211)
  plt.plot(t, data)
  plt.xlabel(u"Time(Second)")
  plt.title(u"300Hz 200Hz 100Hz Time spectrum")
  plt.subplot(212)
  cntr2 = plt.contourf(t, frequencies, abs(cwtmatr))
  plt.colorbar(cntr2)
  plt.ylabel(u"Frequency(Hz)")
  plt.xlabel(u"Time(Second)")
  plt.subplots_adjust(hspace=0.4)
  plt.show()

def SignalSpectrum(x, fs = 250, figname = "signal", verbose=0):
  '''
  Amplitude spectrum of a real signal and its three strongest peaks.

  Changes from the previous version:
  * findpeaks was called with two arguments although it is declared with
    three, which raised TypeError; it is now called with choose=0 and a
    dummy seg.  Only the maxima are used here, and they do not depend on
    choose.
  * the frequency axis comes from np.fft.rfftfreq, which is also correct
    for odd signal lengths and odd sampling rates (the previous linspace
    assumed an even length and truncated fs/2 to an integer).
  * the unused fmax was removed.

  input   x        1-D signal segment
          fs       sampling frequency in Hz
          figname  title of the optional figure
          verbose  if truthy, draw the signal and its spectrum up to 30 Hz

  output  a 4-tuple:
          frequencies of the (up to) 3 largest spectral peaks,
          amplitudes of those peaks,
          the frequency axis restricted to <= 30 Hz,
          the amplitude spectrum restricted to <= 30 Hz

  example:
  peak_f, peak_a, freqs, amp = SignalSpectrum(x, fs)

  author: star hou  2019.8.9
  email: 1029588176@qq.com
  '''
  x = np.asarray(x)
  fft_size = x.shape[0]    # number of samples
  # rfft returns N//2+1 complex bins covering 0 .. fs/2; dividing by the
  # number of samples scales the magnitudes.
  xfp = np.abs(np.fft.rfft(x) / fft_size)
  freqs = np.fft.rfftfreq(fft_size, d=1.0 / fs)
  index = np.where(freqs <= 30)[0]

  top_k = 3
  maximum, _, maximumID, _ = findpeaks(xfp, 0, 1)
  # Keep the indices usable for fancy indexing even when no peak exists.
  maximumID = np.asarray(maximumID, dtype=int)
  top3 = maximum.argsort()[::-1][0:top_k]

  if verbose:
    plt.figure(figsize=(8,4))
    plt.subplot(211)
    plt.plot(x)
    plt.xlabel(u"Time(S)")
    plt.title(figname)
    plt.subplot(212)
    plt.plot(freqs[index], xfp[index])
    plt.xlabel(u"Freq(Hz)")
    plt.plot(freqs[maximumID[top3]], maximum[top3], 'ro')
    plt.xlabel(u"Freq(Hz)")
    plt.subplots_adjust(hspace=0.4)

  return freqs[maximumID[top3]], maximum[top3], freqs[index], xfp[index]

#   return absY

def SignalEnvelope(x):
  '''
  Amplitude envelope of a signal.

  The envelope is sqrt(x**2 + H(x)**2), where H(x) is the Hilbert transform
  computed by scipy.fftpack.hilbert.

  input   x         signal segment (numpy array)
  output  envelope  array of the same length as x

  example:
  envelope = SignalEnvelope(x)

  author: star hou  2019.8.9
  email: 1029588176@qq.com
  '''
  from scipy import fftpack
  quadrature = fftpack.hilbert(x)
  return np.sqrt(pow(x, 2) + pow(quadrature, 2))

def autocorrelation(x, lags):
  '''
  Normalised autocorrelation of a signal for lags 1 .. lags.

  The signal is mean-centred; the sum of products at lag k is divided by
  the variance and by the number of overlapping samples (n - k).

  The previous version selected the lags with negative indices; for
  lags == n - 1 the stop index became 0, the slice was empty and the
  division failed.  Positive indices are used now, so every lag up to
  n - 1 works.  Results for smaller lags are unchanged.

  input   x     1-D signal (array-like)
          lags  number of lags to return, 0 <= lags <= len(x) - 1
  output  array of length lags with the autocorrelation at lags 1..lags

  raises  ValueError if lags is outside 0 .. len(x) - 1
  '''
  x = np.asarray(x, dtype=float)
  n = x.shape[0]
  if lags < 0 or lags > n - 1:
    raise ValueError("lags must be between 0 and len(x) - 1")
  variance = x.var()
  centred = x - x.mean()
  # 'full' output has length 2n-1 with lag 0 at index n-1, so lag k is at
  # index n-1+k.
  full = np.correlate(centred, centred, mode='full')
  overlap = np.arange(n - 1, n - 1 - lags, -1)
  return full[n:n + lags] / (variance * overlap)

In [None]:
def autoCorrelate(x:np.array,y:np.array,mode:str = 'circle'):
  '''
  ref. https://fanyublog.com/2015/11/16/corr_python/

  Normalised cross-correlation of two sequences.

  mode: 'circle' (default)  circular cross-correlation, computed as
                            irfft(rfft(x) * rfft(reversed y)); x and y must
                            have the same length
        anything else       linear correlation, np.correlate(x, y, 'valid')

  The result is divided by norm(x) * norm(y) and its absolute value is
  returned.

  Fix: the circular branch used scipy.fftpack.rfft, whose output is a
  packed *real* array [y0, Re y1, Im y1, ...].  Multiplying two such arrays
  element-wise is not a complex product, so the result was not a
  correlation.  numpy.fft.rfft / irfft (complex spectra) are used instead.

  Author: starhou
  Date: 2020.5.8
  E-mail: 1029588176@qq.com
  '''
  import numpy as np
  x = np.asarray(x, dtype=float)
  y = np.asarray(y, dtype=float)
  nom = np.linalg.norm(x) * np.linalg.norm(y)
  if mode == 'circle':
    spectrum = np.fft.rfft(x) * np.fft.rfft(y[::-1])
    # Pass n explicitly so odd lengths are reconstructed correctly.
    cor = np.fft.irfft(spectrum, n=x.shape[0])
  else:
    cor = np.correlate(x, y, mode='valid')
  return np.abs(cor / nom)

###### myemd

In [None]:
def cubic_spline_3pts(x, y, T):
    """
    Cubic spline through exactly three points.

    Apparently scipy.interpolate.interp1d does not support cubic splines
    for fewer than 4 points, hence this hand-written version.

    Parameters
    ----------
    x, y : the three knot abscissae (x0, x1, x2) and ordinates (y0, y1, y2).
    T    : 1-D array of positions; only those within [x0, x2] are evaluated.

    Returns
    -------
    t : the entries of T that lie in [x0, x2].
    q : spline values at those entries (left segment followed by right
        segment).
    """
    x0, x1, x2 = x
    y0, y1, y2 = y

    width_l, width_r = x1 - x0, x2 - x1
    rise_l, rise_r = y1 - y0, y2 - y1
    inv_l, inv_r = 1. / width_l, 1. / width_r

    # Tridiagonal system for the knot derivatives k0, k1, k2.
    M = np.array([
        [2 * inv_l, inv_l, 0],
        [inv_l, 2. * (inv_l + inv_r), inv_r],
        [0, inv_r, 2. * inv_r],
    ])
    rhs_first = 3 * rise_l * inv_l * inv_l
    rhs_last = 3 * rise_r * inv_r * inv_r
    rhs = np.array([rhs_first, rhs_first + rhs_last, rhs_last]).T
    k = np.array(np.linalg.inv(M).dot(rhs))

    a_l = k[0] * width_l - rise_l
    b_l = -k[1] * width_l + rise_l
    a_r = k[1] * width_r - rise_r
    b_r = -k[2] * width_r + rise_r

    from_x0 = np.r_[T >= x0]
    upto_x2 = np.r_[T <= x2]
    t = T[from_x0 & upto_x2]
    # Local coordinate in [0, 1) on the left segment and [0, 1] on the right.
    u_l = (T[from_x0 & np.r_[T < x1]] - x0) / width_l
    u_r = (T[np.r_[T >= x1] & upto_x2] - x1) / width_r
    v_l, v_r = 1. - u_l, 1. - u_r

    q_l = v_l * y0 + u_l * y1 + u_l * v_l * (a_l * v_l + b_l * u_l)
    q_r = v_r * y1 + u_r * y2 + u_r * v_r * (a_r * v_r + b_r * u_r)

    return t, np.append(q_l, q_r)
def findpeaks(x, choose=False, seg=125):
  '''
  maximum, minimum, maximumID, minimumID = findpeaks(x)

  Locate the local extrema of a 1-D signal.

  x       input signal
  choose  if truthy, thin out the minima: the signal is cut into windows of
          seg samples and only the lowest minimum of each window is kept
          (the maxima are never thinned)
  seg     window length in samples used when choose is truthy

  Background for the default window (from the original notes): chest
  compressions are at least 5 cm and at most 6 cm deep at 100-120 per
  minute with a 30:2 compression/ventilation ratio; at 120 bpm an 8 s
  segment holds 16 extrema, i.e. one extremum per 125 samples.

  returns
    maximum    values at the local maxima
    minimum    values at the (possibly thinned) local minima
    maximumID  integer indices of the maxima
    minimumID  integer indices of the minima

  Changes from the previous version:
  * choose and seg have defaults, because the other functions in this file
    call findpeaks(x) and findpeaks(x, 1);
  * when choose is set, the lowest minimum of the LAST window is returned
    too (it used to be dropped silently);
  * index arrays always have an integer dtype, also when they are empty.
  '''
  x = np.asarray(x)
  # An extremum sits wherever two consecutive differences change sign.
  df = np.diff(x)
  peakID = np.where(df[:-1] * df[1:] < 0)[0] + 1
  rising = x[peakID] > x[peakID - 1]
  maximumID = peakID[rising]
  minimumID = peakID[~rising]

  if choose:
    lowest = {}  # window number -> index of the lowest minimum seen so far
    for index in minimumID:
      window = index // seg
      if window not in lowest or x[index] < x[lowest[window]]:
        lowest[window] = index
    minimumID = np.array([lowest[w] for w in sorted(lowest)], dtype=int)

  return x[maximumID], x[minimumID], maximumID, minimumID

  # # 筛选符合条件的
  # i = 125
  # maxi = []
  # mini = []
  # while i < x.size:
  #   maxi.apend(np.argmax(maximum[i-125,i]))
  #   mini.apend(np.argmax(minimum[i-125,i]))

#判断当前的序列是否为 IMF 序列
def isImf(x, peakNum):
  '''
  Decide whether a sequence qualifies as an IMF.

  x        candidate sequence (numpy array)
  peakNum  number of extrema of the sequence

  Returns True when the number of sign changes between neighbouring
  samples differs from peakNum by at most one, False otherwise.
  '''
  zero_crossings_count = np.sum(x[:-1] * x[1:] < 0)
  return not abs(zero_crossings_count - peakNum) > 1

def isStop(x, IMF, range_thr=0.001, total_power_thr=0.005):
  '''
  Stopping test for a decomposition.

  x                signal
  IMF              sequence of components extracted so far
  range_thr        threshold on (max - min) of the residual
  total_power_thr  threshold on the sum of absolute residual values

  The residual is x minus the sum of the components.  Returns True when the
  residual range is below range_thr or its absolute sum is below
  total_power_thr, otherwise False.
  '''
  residual = x - np.sum(np.array(IMF), axis=0)
  spread = np.max(residual) - np.min(residual)
  if spread < range_thr:
      return True
  return bool(np.sum(np.abs(residual)) < total_power_thr)

#获取当前样条曲线
def getspline(x):
  '''
  Spline envelopes through the local maxima and minima of a signal.

  x   1-D numpy array

  Returns (upper, lower, peakNum): the spline through the maxima and the
  spline through the minima, both evaluated at every sample index, and the
  total number of extrema.  When there are not more than 3 maxima and more
  than 3 minima, (0, 0, 3) is returned; callers treat peakNum == 3 as
  "no envelope could be built".

  Changes from the previous version:
  * findpeaks is called with all three arguments (choose off, dummy seg);
    the one-argument call raised TypeError against its declared signature;
  * the leftover debug print of the extrema counts was removed.
  '''
  from scipy import interpolate
  maximum, minimum, maximumID, minimumID = findpeaks(x, 0, 1)
  if maximum.size > 3 and minimum.size > 3:
    t = np.linspace(0, x.size - 1, x.size)
    peakNum = maximum.size + minimum.size
    maxSpline = interpolate.splrep(maximumID, maximum)
    minSpline = interpolate.splrep(minimumID, minimum)
    return interpolate.splev(t, maxSpline), interpolate.splev(t, minSpline), peakNum
  return 0, 0, 3


def emd(x):
  '''
  Empirical mode decomposition by repeated sifting.

  x   1-D numpy array

  Returns a list of arrays, each of the same length as x: the extracted
  modes in order, followed by the final residual.

  The signal is extended on both sides by a quarter of its length before
  sifting and the extension is cut off again from every returned
  component.  A mode is accepted when the relative change between two
  sifting steps drops to 0.1 or below, or when isImf accepts it.  The
  decomposition ends when getspline reports too few extrema
  (peakNum == 3).

  Changes from the previous version:
  * the extension was written as x[0]*np.zeros(...), i.e. zero padding,
    although the expression shows edge-value padding was intended;
    np.pad(..., mode='edge') is used now;
  * with fewer than 4 samples the pad width is 0 and x[0:-0] returned
    empty components; an explicit slice avoids that;
  * unreachable statements after the endless loop were removed.
  '''
  x = np.asarray(x, dtype=float)
  imf = []
  padding = int(x.size / 4)
  x = np.pad(x, padding, mode='edge')
  core = slice(padding, x.size - padding)  # the un-padded part
  while True:
    x1 = x
    sd = np.inf
    peakNum = x.size
    while sd > 0.1 and (not isImf(x1, peakNum)):
      upper, lower, peakNum = getspline(x1)
      if peakNum == 3:
        # Too few extrema left: what remains is the residual.
        imf.append(x[core])
        return imf
      x2 = x1 - (upper + lower) / 2
      sd = np.sum((x1 - x2)**2) / np.sum(x1**2)
      x1 = x2
    imf.append(x1[core])
    x = x - x1

###### copyemd


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

def saw_sift(t, y, bc='extrap', tol=0.0):
    """
    Use the sawtooth transform to find the dominant intrinsic mode in the data.

    Parameters
    ----------
    t : 1D array-like
        The independent data, length N.
    y : 1D array-like
        The dependent data, length N.
    bc : {'extrap'|'even'|'odd'|'periodic'}, optional
        Boundary condition used to create two extra extrema at each end.
        extrap :
            default. extrapolate the envelopes linearly from the nearest
            extrema of the same envelope; needs at least 4 extrema.
        even :
            use the endpoints as extrema and reflect the nearest extrema
            to extend the opposing envelope
        odd :
            reflect and flip the nearest two extrema about each endpoint
            without using the endpoint as an extremum (like an odd function
            with the endpoint as the origin) to extrapolate both envelopes
        periodic :
            treat the function (thus extrema) as periodic to append
            the necessary extra extrema
    tol : float
        tolerance. changes between points below this level are set to zero.

    Returns
    -------
    h : 1D array
        The intrinsic mode, length N.

    Raises
    ------
    FlatFunction
        If fewer than 2 relative extrema are found (fewer than 4 for
        bc='extrap').
    ValueError
        If bc is not one of the values listed above.

    References
    ----------
    http://arxiv.org/pdf/0710.3170.pdf
    """
    t, y = map(np.asarray, [t, y])

    # identify the relative extrema
    argext = _allrelextrema(y, tol=tol)

    # if there are too few extrema, raise an exception
    if len(argext) < 2:
        raise FlatFunction('Too few relative max and min to sift the series')

    # parse out the relative extrema
    T = t[argext]
    E = y[argext]

    ## add extra extrema as necessary for boundary conditions
    # Naming: (tn1, En1) and (t0, E0) are the two points prepended on the
    # left, (tmn1, Emn1) and (tm, Em) the two points appended on the right.
    if bc == 'extrap':
        if len(argext) < 4:
            raise FlatFunction('Too few relative max and min to sift the series')
        # both added points sit at the first/last sample; their values come
        # from straight lines through extrema two positions apart
        t0, tn1 = t[0], t[0]
        E0 = (E[3] - E[1]) / (T[3] - T[1]) * (t0 - T[1]) + E[1]
        En1 = (E[2] - E[0]) / (T[2] - T[0]) * (tn1 - T[0]) + E[0]
        tmn1, tm = t[-1], t[-1]
        Emn1 = (E[-4] - E[-2]) / (T[-4] - T[-2]) * (tmn1 - T[-2]) + E[-2]
        Em = (E[-3] - E[-1]) / (T[-3] - T[-1]) * (tm - T[-1]) + E[-1]
    elif bc == 'even':
        t0, E0 = t[0], y[0]
        tn1, En1 = _reflect(t[0], T[0]), E[0]
        tmn1, Emn1 = t[-1], y[-1]
        tm, Em = _reflect(t[-1], T[-1]), E[-1]
    elif bc == 'odd':
        t0, tn1 = _reflect(t[0], T[:2])
        E0, En1 = _reflect(y[0], E[:2])
        tm, tmn1 = _reflect(t[-1], T[-2:])
        Em, Emn1 = _reflect(y[-1], E[-2:])
    elif bc == 'periodic':
        if _oppsign(y[1] - y[0], y[0] - y[-1], tol):
            # left endpt is a relative extremum
            t0, E0 = t[0], y[0]
            tn1, En1 = t[0] - (t[-1] - T[-1]), E[-1]
        else:
            t0, E0 = t[0] - (t[-1] - T[-1]), E[-1]
            tn1, En1 = t[0] - (t[-1] - T[-2]), E[-2]
        if _oppsign(y[-1] - y[-2], y[0] - y[-1], tol):
            # right endpt is a relative extremum
            tmn1, Emn1 = t[-1], y[-1]
            tm, Em = t[-1] + (T[0] - t[0]), E[0]
        else:
            tmn1, Emn1 = t[-1] + (T[0] - t[0]), E[0]
            tm, Em = t[-1] + (T[1] - t[0]), E[1]
    else:
        raise ValueError('Boundary condition (bc) not understood.')

    # add the boundary points to the extrema
    N = len(T)
    T = np.insert(T, [0, 0, N, N], [tn1, t0, tmn1, tm])
    E = np.insert(E, [0, 0, N, N], [En1, E0, Emn1, Em])
    # the added points map to the first sample (0) and one past the last
    # sample (len(t)) so the slices in _saw_transform cover all of y
    argext = np.insert(argext, [0, 0, N, N], [0, 0, len(t), len(t)])

    # parse the saw function and envelope points
    Tsaw, Ysaw = T[1:-1], E[1:-1]
    # every other extremum belongs to the same envelope
    env1 = np.interp(Tsaw, T[::2], E[::2])
    env2 = np.interp(Tsaw, T[1::2], E[1::2])

    # subtract envelope mean from sawtooth (sawtooth is the extrema)
    env_mean = (env1 + env2) / 2.0
    Hsaw = Ysaw - env_mean

    # transform from sawtooth to data space
    u = _saw_transform(t, y, T, E, argext)
    h = np.interp(u, Tsaw, Hsaw)

    return h

def _saw_transform(t, y, T, E, argext):
    """
    Return the sawtooth transform of the t coordinate.

    For each pair of neighbouring extrema (k, k+1), skipping the first and
    the last entry of ``argext``, the samples ``y[argext[k]:argext[k+1]]``
    are mapped linearly from the value range [E[k], E[k+1]] onto the
    coordinate range [T[k], T[k+1]].  ``t`` itself is not used.
    """
    mapped = []
    for k in range(1, len(argext) - 2):
        samples = y[argext[k]:argext[k + 1]]
        fraction = (samples - E[k]) / (E[k + 1] - E[k])
        mapped.extend(T[k] + fraction * (T[k + 1] - T[k]))
    return np.array(mapped)

def saw_emd(t, y, Nmodes=None, bc='extrap', tol=1e-10):
    """
    Decompose a function into "intrinsic modes" using the sawtooth
    decomposition method of Lu [1].

    Parameters
    ----------
    t : 1D array-like
        The independent data, length N.
    y : 1D array-like
        The dependent data, length N.
    Nmodes : int, optional
        The maximum number of modes to return.
    bc : {'extrap'|'even'|'odd'|'periodic'}, optional
        Boundary condition, passed on to saw_sift.
        extrap :
            default. extrapolate the envelopes from the nearest extrema
        even :
            use the endpoints as extrema and reflect the nearest extrema
            to extend the opposing envelope
        odd :
            reflect and flip the nearest two extrema about each endpoint
            without using the endpoint as an extremum
        periodic :
            treat the function (thus extrema) as periodic
    tol : float
        tolerance relative to the initial range of the function. decomposition
        will stop once wiggles in y are below this level.

    Returns
    -------
    c : 2D array
        An NxM array giving M empirical modes as columns.
    r : 1D array
        The residual, length N.

    References
    ----------
    [1] http://arxiv.org/pdf/0710.3170.pdf

    Notes
    -----
    The function does not properly handle the special (and presumably rare)
    case where two consecutive, identical points form a relative maximum or
    minimum in the supplied data.
    """
    # groom the input
    t = np.asarray(t)
    y = np.asarray(y)
    if t.ndim > 1:
        raise ValueError("t array must be 1D")
    if y.ndim > 1:
        raise ValueError("y array must be 1D")

    # absolute tolerance derived from the range of the data
    atol = tol * (np.max(y) - np.min(y))

    modes = []
    residual = np.copy(y)
    while True:
        try:
            mode = saw_sift(t, residual, bc=bc, tol=atol)
        except FlatFunction:  # the residue has too few extrema
            break
        modes.append(mode)
        residual = residual - mode
        if len(modes) == Nmodes:
            break

    return np.transpose(modes), residual

def emd(t, y, Nmodes=None):
    """
    Decompose a function into "intrinsic modes" using empirical mode
    decomposition from Huang et al. [1].

    Parameters
    ----------
    t : 1D array-like
        The independent data, length N.
    y : 1D array-like
        The dependent data, length N.
    Nmodes : int, optional
        The maximum number of modes to return.

    Returns
    -------
    c : 2D array
        An NxM array giving M empirical modes as columns.
    r : 1D array
        The residual, length N.

    Raises
    ------
    ValueError
        If t or y has more than one dimension.

    References
    ----------
    [1] Huang et al. (1998; RSPA 454:903)

    Notes
    -----
    The function does not properly handle the special (and presumably rare)
    case where two consecutive, identical points form a relative maximum or
    minimum in the supplied data.

    Changes from the previous version: the Nmodes limit was compared with
    ``len(c)``, which is the number of samples N rather than the number of
    modes, so the limit only took effect by coincidence; it now uses
    ``c.shape[1]``.  The convergence ratio divides by zero on the first
    sifting pass (``hold`` starts at zero); that is intentional (the ratio
    becomes inf/nan and the pass is never accepted), so the floating-point
    warnings are silenced locally.
    """

    # groom the input
    t, y = map(np.asarray, [t, y])
    if t.ndim > 1:
        raise ValueError("t array must be 1D")
    if y.ndim > 1:
        raise ValueError("y array must be 1D")

    c = np.empty([len(y), 0])
    h, r = map(np.copy, [y, y])
    hold = np.zeros(y.shape)
    while True:
        try:
            while True:
                h = sift(t, h)
                with np.errstate(divide='ignore', invalid='ignore'):
                    var = np.sum((h-hold)**2 / hold**2)
                if var < 0.25:
                    c = np.append(c, h[:, np.newaxis], axis=1)
                    r = r - h

                    # if the user doesn't want any more modes
                    if c.shape[1] == Nmodes:
                        return c, r

                    # start sifting the residual for the next mode
                    h = r
                    hold = np.zeros(y.shape)
                    break
                hold = h
        except FlatFunction:  # if the residue has too few extrema
            return c, r

class FlatFunction(Exception):
    """Raised when a series has too few relative extrema to be sifted."""

def sift(t, y, nref=100, plot=False):
    """
    Identify the dominant "intrinsic mode" in a series of data by fitting
    spline envelopes to the extrema.

    Parameters
    ----------
    t : 1D array
        The independent data, length N.
    y : 1D array
        The dependent data, length N.
    nref : int, optional
        Number of extrema to reflect about each end when fitting splines.
    plot : {True|False}, optional
        If True, create a diagnostic plot of the function and results using
        the matplotlib.pyplot plot functions. If there is already a plot
        window active, the plotting will be done there. Plot handles are not
        returned.

    Returns
    -------
    h : 1D array
        The intrinsic mode, length N.

    Raises
    ------
    FlatFunction
        If the series has fewer than 2 relative extrema, or fewer than 4
        minima or 4 maxima after the end points and reflections were added.

    Summary
    -------
    Identifies the relative max and min in the series, fits spline curves
    to these to estimate an envelope, then subtracts the mean of the envelope
    from the series. The difference is then returned. The extrema are reflected
    about the extrema nearest each end of the series to mitigate end
    effects, where nref controls the maximum total number of extrema (max and
    min) that are reflected.

    References
    ----------
    Huang et al. (1998; RSPA 454:903)
    """

    # identify the relative extrema
    argext = _allrelextrema(y)

    # if there are too few extrema, raise an exception
    if len(argext) < 2:
        raise FlatFunction('Too few max and min in the series to sift')

    # include the left and right endpoints as extrema if they are beyond the
    # limits set by the nearest two extrema
    # (y[[0]] / y[[-1]] are passed as one-element arrays)
    inclleft = not _inrange(y[[0]], y[argext[0]], y[argext[1]])
    inclright = not _inrange(y[[-1]], y[argext[-2]], y[argext[-1]])
    # the right end point is addressed with index -1
    if inclleft and inclright: argext = np.concatenate([[0],argext,[-1]])
    if inclleft and not inclright: argext = np.insert(argext,0,0)
    if not inclleft and inclright: argext = np.append(argext,-1)
    #if neither, do nothing

    # now reflect the extrema about both sides
    T, E  = t[argext], y[argext]
    # up to nref extrema are mirrored about the first / last extremum; the
    # slices simply come out shorter when fewer extrema are available
    tleft, yleft = T[0] - (T[nref:0:-1] - T[0]) , E[nref:0:-1]
    tright, yright = T[-1] + (T[-1] - T[-2:-nref-2:-1]), E[-2:-nref-2:-1]
    tall = np.concatenate([tleft, T, tright])
    yall = np.concatenate([yleft, E, yright])

    # parse out the min and max. the extrema must alternate, so just figure out
    # whether a min or max comes first
    if yall[0] < yall[1]:
        tmin, tmax, ymin, ymax = tall[::2], tall[1::2], yall[::2], yall[1::2]
    else:
        tmin, tmax, ymin, ymax = tall[1::2], tall[::2], yall[1::2], yall[::2]

    # check again if there are enough extrema, now that the endpoints may have
    # been added
    if len(tmin) < 4 or len(tmax) < 4:
        raise FlatFunction('Too few max and min in the series to sift')

    # compute spline envelopes and mean
    spline_min, spline_max = map(interp1d, [tmin,tmax], [ymin,ymax], ['cubic']*2)
    m = (spline_min(t) + spline_max(t))/2.0
    h = y - m

    if plot:
        plt.plot(t, y, '-', t, m, '-')
        plt.plot(tmin, ymin, 'g.', tmax, ymax, 'k.')
        # dense grids for drawing the two envelopes
        tmin = np.linspace(tmin[0], tmin[-1], 1000)
        tmax = np.linspace(tmax[0], tmax[-1], 1000)
        plt.plot(tmin, spline_min(tmin), '-r', tmax, spline_max(tmax), 'r-')

    return h

def _allrelextrema(y, tol=0.0):
    """
    Finds all of the relative extrema in order in about half the time
    as using the scipy.signal.argrel{min|max} functions and combining the
    results. The scipy.signal version also misses multi-point max and mins.
    This version returns the midpoint of multi point extrema, or the point
    just left of the middle for multi-point exrtema.
    """

    # compute difference between successive values (like the slope)
    slope = np.diff(y)
    slope[np.abs(slope) < tol] = 0.0

    # remove all zeros while tracking original indices
    nonzero = (slope != 0.0)
    slope = slope[nonzero]
    indices = np.arange(len(y) - 1)
    indices = indices[nonzero]

    # we just want the sign of the slope
    slope_sign = np.zeros(len(slope), 'i1')
    slope_sign[slope > 0] = 1
    slope_sign[slope < 0] = -1

    # so that we can find the sign of the curvature at points with differing
    # slope signs to either side
    curve_sign = np.diff(slope_sign)
    arg_curve_chng = np.nonzero(curve_sign != 0)[0]
    i0 = indices[arg_curve_chng]
    i1 = indices[arg_curve_chng + 1] + 1
    i = np.floor((i0 + i1) / 2.0)

    return i.astype(int)

def _inrange(y, y0, y1):
    """
    Return True if y lies strictly between y0 and y1 (given in either order).
    """
    lower, upper = (y1, y0) if y0 > y1 else (y0, y1)
    return (y > lower) and (y < upper)

def _reflect(x0, x):
    """Mirror ``x`` about the point ``x0`` (works element-wise on arrays)."""
    offset = x - x0
    return x0 - offset

def _oppsign(x, y, tol):
    """
    Tell whether x and y have opposite signs, ignoring magnitudes within tol.

    True when x is below -tol while y is above tol, or x is above tol while
    y is below -tol.
    """
    negative_then_positive = x < -tol and y > tol
    if negative_then_positive:
        return negative_then_positive
    return x > tol and y < -tol

###### MEMD

In [None]:
import numpy as np
from scipy.interpolate import interp1d,CubicSpline
from math import pi,sqrt,sin,cos
import warnings
import sys



# =============================================================================

def hamm(n, base):
    """
    Generate ``n`` points of one dimension of a Hammersley sequence.

    Parameters
    ----------
    n    : number of points.
    base : for ``base > 1`` the radical-inverse (van der Corput) sequence
           in that base is returned, shape (1, n); otherwise the values
           ``(mod(k, 1 - base) + 0.5) / (-base)`` for k = 1..n are
           returned, shape (n,).
    """
    indices = np.arange(1, n + 1)
    if not 1 < base:
        return (np.remainder(indices, (-base + 1)) + 0.5) / (-base)

    seq = np.zeros((1, n))
    remaining = indices
    weight = 1 / base
    # Peel off one base-`base` digit per pass until every index is used up.
    while np.any(remaining != 0):
        seq = seq + np.remainder(remaining, base) * weight
        weight = weight / base
        remaining = np.floor(remaining / base)
    return seq

# =============================================================================

def zero_crossings(x):
    """
    Return the indices of the zero crossings of a signal.

    A sign change between samples k and k+1 is reported as index k.
    Samples that are exactly zero are reported as well; a run of
    consecutive zeros contributes a single index, the (rounded) middle of
    the run.

    Parameters
    ----------
    x : 1-D array, or a column vector such as the (N, 1) projection built
        in envelope_mean; it is flattened first.

    Returns
    -------
    Sorted integer array of indices.

    Changes from the previous version: the run detection built
    ``np.diff([0, zer, 0])`` from a ragged list (scalar, array, scalar),
    which NumPy cannot turn into an array, so any signal with two adjacent
    zeros failed; the pieces are concatenated properly now.  The Python
    ``any`` calls were replaced by array reductions.
    """
    x = np.ravel(x)
    indzer = np.where(x[:-1] * x[1:] < 0)[0]

    is_zero = x == 0
    if is_zero.any():
        iz = np.where(is_zero)[0]
        if np.any(np.diff(iz) == 1):
            # +1 marks the start of a run of zeros, -1 the sample after it.
            dz = np.diff(np.concatenate(([0], is_zero.astype(int), [0])))
            debz = np.where(dz == 1)[0]
            finz = np.where(dz == -1)[0] - 1
            indz = np.round((debz + finz) / 2).astype(int)
        else:
            indz = iz
        indzer = np.sort(np.concatenate((indzer, indz)))

    return indzer

# =============================================================================

#defines new extrema points to extend the interpolations at the edges of the
#signal (mainly mirror symmetry)
def boundary_conditions(indmin,indmax,t,x,z,nbsym):
    """
    Extend the extrema beyond both ends of the signal by mirror symmetry so
    that the envelope interpolation is defined at the edges.

    Parameters
    ----------
    indmin, indmax : index arrays of the minima / maxima (cast to int here).
    t     : time axis, indexed with the extrema indices.
    x     : signal whose first/last samples are compared with the nearest
            extrema to choose the mirror point.
    z     : array indexed as ``z[idx, :]``; its rows at the extrema are
            returned (presumably the multivariate signal -- confirm with
            the caller).
    nbsym : number of extrema to mirror at each end.

    Returns
    -------
    (tmin, tmax, zmin, zmax, mode)
        Times and ``z`` rows of the extended minima and maxima.  ``mode`` is
        0 (and the four arrays are None) when fewer than 3 extrema exist in
        total, otherwise 1.

    NOTE(review): the function calls ``sys.exit('bug')`` in two internal
    consistency checks, which terminates the interpreter.
    NOTE(review): ``indmax[0]`` / ``indmin[0]`` are read without checking
    that each array is non-empty; only their combined length is checked.
    """
    lx = len(x)-1  # index of the last sample
    end_max = len(indmax)-1
    end_min = len(indmin)-1
    indmin = indmin.astype(int)
    indmax = indmax.astype(int)

    if len(indmin) + len(indmax) < 3:
        # too few extrema: report mode 0 and no data
        mode = 0
        tmin=tmax=zmin=zmax=None
        return(tmin,tmax,zmin,zmax,mode)
    else:
        mode=1 # at least 3 extrema are available
    #boundary conditions for interpolations :
    # --- left end: choose the extrema to mirror and the mirror index lsym
    if indmax[0] < indmin[0]:
        if x[0] > x[indmin[0]]:
            lmax = np.flipud(indmax[1:min(end_max+1,nbsym+1)])
            lmin = np.flipud(indmin[:min(end_min+1,nbsym)])
            lsym = indmax[0]

        else:
            lmax = np.flipud(indmax[:min(end_max+1,nbsym)])
            lmin = np.concatenate((np.flipud(indmin[:min(end_min+1,nbsym-1)]),([0])))
            lsym = 0

    else:
        if x[0] < x[indmax[0]]:
            lmax = np.flipud(indmax[:min(end_max+1,nbsym)])
            lmin = np.flipud(indmin[1:min(end_min+1,nbsym+1)])
            lsym = indmin[0]

        else:
            lmax = np.concatenate((np.flipud(indmax[:min(end_max+1,nbsym-1)]),([0])))
            lmin = np.flipud(indmin[:min(end_min+1,nbsym)])
            lsym = 0

    # --- right end: same choice, mirror index rsym
    if indmax[-1] < indmin[-1]:
        if x[-1] < x[indmax[-1]]:
            rmax = np.flipud(indmax[max(end_max-nbsym+1,0):])
            rmin = np.flipud(indmin[max(end_min-nbsym,0):-1])
            rsym = indmin[-1]

        else:
            rmax = np.concatenate((np.array([lx]),np.flipud(indmax[max(end_max-nbsym+2,0):])))
            rmin = np.flipud(indmin[max(end_min-nbsym+1,0):])
            rsym = lx

    else:
        if x[-1] > x[indmin[-1]]:
            rmax = np.flipud(indmax[max(end_max-nbsym,0):-1])
            rmin = np.flipud(indmin[max(end_min-nbsym+1,0):])
            rsym = indmax[-1]

        else:
            rmax = np.flipud(indmax[max(end_max-nbsym+1,0):])
            rmin = np.concatenate((np.array([lx]),np.flipud(indmin[max(end_min-nbsym+2,0):])))
            rsym = lx

    # mirrored times: reflection of t about t[lsym] / t[rsym]
    tlmin = 2*t[lsym]-t[lmin]
    tlmax = 2*t[lsym]-t[lmax]
    trmin = 2*t[rsym]-t[rmin]
    trmax = 2*t[rsym]-t[rmax]

    #in case symmetrized parts do not extend enough
    if tlmin[0] > t[0] or tlmax[0] > t[0]:
        if lsym == indmax[0]:
            lmax = np.flipud(indmax[:min(end_max+1,nbsym)])
        else:
            lmin = np.flipud(indmin[:min(end_min+1,nbsym)])
        if lsym == 1:
            sys.exit('bug')
        # fall back to mirroring about the first sample
        lsym = 0
        tlmin = 2*t[lsym]-t[lmin]
        tlmax = 2*t[lsym]-t[lmax]

    if trmin[-1] < t[lx] or trmax[-1] < t[lx]:
        if rsym == indmax[-1]:
            rmax = np.flipud(indmax[max(end_max-nbsym+1,0):])
        else:
            rmin = np.flipud(indmin[max(end_min-nbsym+1,0):])
        if rsym == lx:
            sys.exit('bug')
        # fall back to mirroring about the last sample
        rsym = lx
        trmin = 2*t[rsym]-t[rmin]
        trmax = 2*t[rsym]-t[rmax]

    # rows of z at the mirrored extrema
    zlmax =z[lmax,:]
    zlmin =z[lmin,:]
    zrmax =z[rmax,:]
    zrmin =z[rmin,:]

    # mirrored left part, original extrema, mirrored right part
    tmin = np.hstack((tlmin,t[indmin],trmin))
    tmax = np.hstack((tlmax,t[indmax],trmax))
    zmin = np.vstack((zlmin,z[indmin,:],zrmin))
    zmax = np.vstack((zlmax,z[indmax,:],zrmax))

    return(tmin,tmax,zmin,zmax,mode)

# =============================================================================

# computes the mean of the envelopes and the mode amplitude estimate
def envelope_mean(m,t,seq,ndir,N,N_dim): #new
    """Average the multivariate envelopes of ``m`` over ``ndir`` projections.

    For every row of ``seq`` a direction vector is built, ``m`` is projected
    onto it, and the extrema of the projection are used (via
    ``boundary_conditions``) to fit cubic-spline minimum/maximum envelopes
    that are evaluated at ``t``.

    Parameters
    ----------
    m : array indexed as (sample, channel); it is multiplied by an
        (N_dim, 1) direction vector.
    t : 1-D array of sample positions at which the envelopes are evaluated.
    seq : 2-D array; row ``it`` provides the point used to build the
        ``it``-th direction vector.
    ndir : number of projection directions.
    N, N_dim : shape of the zero arrays returned when no direction is usable.

    Returns
    -------
    (env_mean, nem, nzm, amp)
        env_mean : mean of the upper and lower envelopes, averaged over the
            directions for which ``mode`` was truthy.
        nem : per-direction count of extrema (minima + maxima).
        nzm : per-direction count of zero crossings.
        amp : half the Euclidean norm of (upper - lower) envelope, averaged
            over the same directions.
    """

    NBSYM = 2     # forwarded unchanged to boundary_conditions
    count = 0     # number of directions for which no envelope was computed

    env_mean=np.zeros((len(t),N_dim))
    amp = np.zeros((len(t)))
    nem = np.zeros((ndir))
    nzm = np.zeros((ndir))
    
    # Reused for every direction; each branch below overwrites all entries.
    dir_vec = np.zeros((N_dim,1))
    for it in range(0,ndir):
        if N_dim !=3:     # Multivariate signal (for N_dim ~=3) with hammersley sequence
            #Linear normalisation of hammersley sequence in the range of -1.00 - 1.00
            b=2*seq[it,:]-1 
            
            # Find angles corresponding to the normalised sequence
            tht = np.arctan2(np.sqrt(np.flipud(np.cumsum(b[:0:-1]**2)))\
                             ,b[:N_dim-1]).transpose()
            
            # Find coordinates of unit direction vectors on n-sphere
            dir_vec[:,0] = np.cumprod(np.concatenate(([1],np.sin(tht))))
            dir_vec[:N_dim-1,0] =  np.cos(tht)*dir_vec[:N_dim-1,0]
            
        else:     # Trivariate signal with hammersley sequence
            # Linear normalisation of hammersley sequence in the range of -1.0 - 1.0
            tt = 2*seq[it,0]-1
            # Clamp to [-1, 1] so the square root below stays real
            if tt>1:
                tt=1
            elif tt<-1:
                tt=-1         
            
            # Normalize angle from 0 - 2*pi
            phirad = seq[it,1]*2*pi
            st = sqrt(1.0-tt*tt)
            
            dir_vec[0]=st*cos(phirad)
            dir_vec[1]=st*sin(phirad)
            dir_vec[2]=tt
           
        # Projection of input signal on nth (out of total ndir) direction vectors
        y  = np.dot(m,dir_vec)

        # Calculates the extrema of the projected signal
        indmin,indmax = local_peaks(y)      

        nem[it] = len(indmin) + len(indmax)
        indzer = zero_crossings(y)
        nzm[it] = len(indzer)

        tmin,tmax,zmin,zmax,mode = boundary_conditions(indmin,indmax,t,y,m,NBSYM)
        
        # Calculate multidimensional envelopes using spline interpolation.
        # ``mode`` is decided by boundary_conditions; presumably it is falsy
        # when the projection has too few extrema -- confirm in that function.
        if mode:
            fmin = CubicSpline(tmin,zmin,bc_type='not-a-knot')
            env_min = fmin(t)
            fmax = CubicSpline(tmax,zmax,bc_type='not-a-knot')
            env_max = fmax(t)
            amp = amp + np.sqrt(np.sum(np.power(env_max-env_min,2),axis=1))/2
            env_mean = env_mean + (env_max+env_min)/2
        else:     # if the projected signal has inadequate extrema
            count=count+1
            
    if ndir>count:
        # Average over the directions that actually contributed
        env_mean = env_mean/(ndir-count)
        amp = amp/(ndir-count)
    else:
        # No direction was usable: return all-zero results
        env_mean = np.zeros((N,N_dim))
        amp = np.zeros((N))
        nem = np.zeros((ndir))
        
    return(env_mean,nem,nzm,amp)

# =============================================================================

#Stopping criterion
def stop(m,t,sd,sd2,tol,seq,ndir,N,N_dim):
    """Evaluate the 'stop' sifting criterion for the current mode ``m``.

    The envelope mean is normalised by the mode amplitude (when every
    amplitude sample is non-zero) and compared against the thresholds.

    Parameters
    ----------
    m, t, seq, ndir, N, N_dim : forwarded to ``envelope_mean``.
    sd : threshold on the normalised envelope mean.
    tol : maximum tolerated fraction of samples exceeding ``sd``.
    sd2 : threshold that no sample may exceed.

    Returns
    -------
    (stp, env_mean)
        stp is 1 when sifting should stop and 0 otherwise. If the envelope
        computation fails, stp is 1 and env_mean is an (N, N_dim) zero array.
    """
    try:
        env_mean,nem,nzm,amp = envelope_mean(m,t,seq,ndir,N,N_dim)
        sx = np.sqrt(np.sum(np.power(env_mean,2),axis=1))

        # Only normalise when no amplitude sample is zero (avoids 0-division)
        if all(amp):
            sx = sx/amp

        keep_sifting = (np.mean(sx > sd) > tol or any(sx > sd2)) and any(nem > 2)
        stp = 0 if keep_sifting else 1
    except (Exception, SystemExit):
        # Best-effort fallback. SystemExit is included because helper code
        # signals internal errors with sys.exit(); KeyboardInterrupt is
        # deliberately NOT caught so a long decomposition can be interrupted.
        env_mean = np.zeros((N,N_dim))
        stp = 1

    return(stp,env_mean)
    
# =============================================================================
    
def fix(m,t,seq,ndir,stp_cnt,counter,N,N_dim):
    """Evaluate the 'fix_h' sifting criterion for the current mode ``m``.

    ``counter`` counts consecutive calls in which, for at least one
    direction, the numbers of zero crossings and extrema differ by at most
    one. Sifting stops once ``counter`` reaches ``stp_cnt``.

    Returns
    -------
    (stp, env_mean, counter)
        stp is truthy when sifting should stop. If the envelope computation
        fails, stp is 1, env_mean is an (N, N_dim) zero array and counter is
        returned unchanged.
    """
    try:
        env_mean,nem,nzm,amp = envelope_mean(m,t,seq,ndir,N,N_dim)

        if all(np.abs(nzm-nem)>1):
            # Counts differ by more than one in every direction: restart
            stp = 0
            counter = 0
        else:
            counter = counter+1
            stp = (counter >= stp_cnt)
    except (Exception, SystemExit):
        # Best-effort fallback. SystemExit is included because helper code
        # signals internal errors with sys.exit(); KeyboardInterrupt is
        # deliberately NOT caught so a long decomposition can be interrupted.
        env_mean = np.zeros((N,N_dim))
        stp = 1

    return(stp,env_mean,counter)

# =============================================================================

def peaks(X):
    """Return the values and positions of the strict local maxima of ``X``.

    A sample is a maximum when the signal rises into it and falls right
    after it. ``X`` may be 1-D or a column vector; positions index the
    first axis.
    """
    slope_sign = np.sign(np.diff(X.transpose())).transpose()
    rising_before = slope_sign[:-1] > 0
    falling_after = slope_sign[1:] < 0
    locs_max = np.where(np.logical_and(rising_before, falling_after))[0] + 1
    return (X[locs_max], locs_max)

# =============================================================================

def local_peaks(x):
    """Locate the local minima and maxima of a projected signal.

    Flat runs are collapsed first so that an extremum lying on a plateau is
    reported once, near the middle of the plateau.

    Returns
    -------
    (indmin, indmax) : index arrays into ``x`` of the minima and maxima;
        empty arrays when none are found.
    """
    # NOTE(review): this is a signed comparison (not abs), so a signal whose
    # samples are all negative is also replaced by zeros -- confirm intent.
    if all(x < 1e-5):
        x=np.zeros((1,len(x)))

    m = len(x)-1     # last valid index along the first axis
    
    # Calculates the extrema of the projected signal
    # Difference between subsequent elements:
    dy = np.diff(x.transpose()).transpose()
    # Positions where the signal actually changes (flat steps are dropped)
    a = np.where(dy!=0)[0]
    # Entries of ``a`` that follow a gap, i.e. a flat run was skipped
    lm = np.where(np.diff(a)!=1)[0] + 1
    d = a[lm] - a[lm-1] 
    # Move those entries back by half the gap, towards the plateau centre
    a[lm] = a[lm] - np.floor(d/2)
    # Always keep the last sample as a candidate
    a = np.insert(a,len(a),m)
    ya  = x[a]
    
    if len(ya) > 1:
        # Maxima
        pks_max,loc_max=peaks(ya)
        # Minima
        pks_min,loc_min=peaks(-ya)
        
        # Map positions within ``ya`` back to indices of ``x``
        if len(pks_min)>0:
            indmin = a[loc_min]
        else:
            indmin = np.asarray([])
            
        if len(pks_max)>0:
            indmax = a[loc_max]
        else:
            indmax = np.asarray([])
    else:
        # NOTE(review): these empty arrays have float dtype.
        indmin=np.array([])
        indmax=np.array([])
        
    return(indmin, indmax)

# =============================================================================

def stop_emd(r,seq,ndir,N_dim):
    """Decide whether the residue ``r`` can no longer be decomposed.

    ``r`` is projected onto ``ndir`` direction vectors built from the rows
    of ``seq``. Returns True when every projection has fewer than three
    extrema.
    """
    extrema_counts = np.zeros((ndir,1))
    direction = np.zeros((N_dim,1))     # reused; fully overwritten each pass

    for k in range(ndir):
        if N_dim == 3:
            # Trivariate signal: the first coordinate is rescaled to
            # [-1, 1] (and clamped), the second becomes an angle in [0, 2*pi]
            tt = 2*seq[k,0]-1
            if tt>1:
                tt=1
            elif tt<-1:
                tt=-1

            phirad = seq[k,1]*2*pi
            st = sqrt(1.0-tt*tt)

            direction[0]=st*cos(phirad)
            direction[1]=st*sin(phirad)
            direction[2]=tt
        else:
            # General multivariate signal: rescale the point to [-1, 1]
            point = 2*seq[k,:]-1

            # Angles corresponding to the normalised point
            tail_norm = np.sqrt(np.flipud(np.cumsum(point[:0:-1]**2)))
            angles = np.arctan2(tail_norm, point[:N_dim-1]).transpose()

            # Coordinates of the unit direction vector on the n-sphere
            direction[:,0] = np.cumprod(np.concatenate(([1],np.sin(angles))))
            direction[:N_dim-1,0] = np.cos(angles)*direction[:N_dim-1,0]

        # Project the residue on this direction and count its extrema
        projection = np.dot(r,direction)
        minima, maxima = local_peaks(projection)
        extrema_counts[k] = len(minima) + len(maxima)

    # Stop once all projected signals have fewer than 3 extrema
    return all(extrema_counts<3)

# =============================================================================

def is_prime(x):
    """Return True when the integer ``x`` is a prime number.

    Values below 2 (including 0, 1 and negatives) are not prime. Only odd
    divisors up to the square root of ``x`` are tried.
    """
    if x < 2:
        return False
    if x < 4:          # 2 and 3
        return True
    if x % 2 == 0:
        return False
    divisor = 3
    while divisor * divisor <= x:
        if x % divisor == 0:
            return False
        divisor += 2
    return True

# =============================================================================

def nth_prime(n):
    """Return a list with the first ``n`` prime numbers, in increasing order.

    Despite its name the function returns the whole list, not only the n-th
    prime. An empty list is returned for ``n < 1``.
    """
    primes = []
    candidate = 2
    while len(primes) < n:
        # Trial division by the primes found so far, up to sqrt(candidate)
        candidate_is_prime = True
        for p in primes:
            if p * p > candidate:
                break
            if candidate % p == 0:
                candidate_is_prime = False
                break
        if candidate_is_prime:
            primes.append(candidate)
        candidate += 1
    return primes
# =============================================================================

def set_value(*args):
    """Parse and validate the arguments of ``memd`` and build its work data.

    ``args[0]`` is the tuple received by ``memd``:
    ``(signal[, num_dir[, stop_criteria[, stop_vector_or_stop_count]]])``.

    * signal : 2-D array with at least 3 channels; it is transposed when it
      has fewer rows than columns so that rows are samples.
    * num_dir : integer >= 6 (default 64; ``None`` selects the default when
      a stop criterion is also given).
    * stop_criteria : ``'stop'`` (default) or ``'fix_h'``.
    * last argument : for ``'stop'`` a 3-element vector ``[sd, sd2, tol]``
      (default ``[0.075, 0.75, 0.075]``); for ``'fix_h'`` a non-negative
      integer stop count (default 2).

    Invalid input terminates through ``sys.exit`` with a message.

    Returns
    -------
    (q, seq, t, ndir, N_dim, N, sd, sd2, tol, nbit, MAXITERATIONS,
     stp_crit, stp_cnt)
    """
    args = args[0]
    narg = len(args)

    # Check the argument count before reading args[0]
    if narg == 0:
        sys.exit('Not enough input arguments.')
    elif narg > 4:
        sys.exit('Too many input arguments.')

    q = args[0]

    ndir = 64                                        # default
    stp_crit = 'stop'                                # default
    stp_vec = []
    stp_cnt = sd = sd2 = tol = None
    default_stp_vec = np.array([0.075,0.75,0.075])   # default

    if narg == 1:
        stp_vec = default_stp_vec
    elif narg == 2:
        ndir = args[1]
        stp_vec = default_stp_vec
    else:
        if args[1] is not None:
            ndir = args[1]
        stp_crit = args[2]
        if isinstance(stp_crit, str) and stp_crit == 'stop':
            stp_vec = default_stp_vec if narg == 3 else args[3]
        elif isinstance(stp_crit, str) and stp_crit == 'fix_h':
            stp_cnt = 2 if narg == 3 else args[3]

    # Rescale input signal if required
    if len(q) == 0:
        sys.exit('emptyDataSet. Data set cannot be empty.')
    if np.shape(q)[0] < np.shape(q)[1]:
        q=q.transpose()

    # Dimension of input signal
    N_dim = np.shape(q)[1]
    if N_dim < 3:
        sys.exit('Function only processes signals having 3 or more channels.')

    # Length of input signal
    N = np.shape(q)[0]

    # Check validity of input parameters
    if not isinstance(ndir,int) or ndir < 6:
        sys.exit('invalid num_dir. num_dir should be an integer greater than or equal to 6.')
    if not isinstance(stp_crit, str) or (stp_crit != 'stop' and stp_crit != 'fix_h'):
        sys.exit('invalid stop_criteria. stop_criteria should be either fix_h or stop')
    stp_vec_ok = isinstance(stp_vec,(list, tuple, np.ndarray)) and all(
        isinstance(v,(int, float, complex, np.number)) for v in stp_vec)
    if stp_vec_ok and stp_crit == 'stop' and len(stp_vec) < 3:
        stp_vec_ok = False
    if not stp_vec_ok:
        sys.exit('invalid stop_vector. stop_vector should be a list with three elements e.g. default is [0.075,0.75,0.075]')
    if stp_cnt is not None:
        if not isinstance(stp_cnt,int) or stp_cnt < 0:
            sys.exit('invalid stop_count. stop_count should be a nonnegative integer.')

    # Thresholds are only defined for the 'stop' criterion
    if stp_crit == 'stop':
        sd,sd2,tol = stp_vec[0],stp_vec[1],stp_vec[2]

    # Bases for the Hammersley function: -ndir first, then 2 for a
    # trivariate signal or the first N_dim-1 primes otherwise
    if N_dim==3:
        base = [-ndir, 2]
    else:
        base = [-ndir] + list(nth_prime(N_dim-1))[:N_dim-1]

    # Find the pointset for the given input signal (one column per base)
    seq = np.zeros((ndir,len(base)))
    for it in range(len(base)):
        seq[:,it] = hamm(ndir,base[it])

    # Define t
    t = np.arange(1,N+1)
    # Counter
    nbit = 0
    MAXITERATIONS = 1000     # default

    return(q,seq,t,ndir,N_dim,N,sd,sd2,tol,nbit,MAXITERATIONS,stp_crit,stp_cnt)
    
# =============================================================================
    
def memd(*args):
    """Multivariate empirical mode decomposition.

    Call as ``memd(signal[, num_dir[, stop_criteria[, stop_vec_or_cnt]]])``;
    the arguments are parsed and validated by ``set_value``.

    Returns
    -------
    numpy array built from the list of extracted modes followed by the
    residue; every entry is stored transposed, i.e. as (channel, sample).
    """
    import warnings

    x,seq,t,ndir,N_dim,N,sd,sd2,tol,nbit,MAXITERATIONS,stop_crit,stp_cnt = set_value(args)

    r=x
    q = []

    while not stop_emd(r,seq,ndir,N_dim):
        # current mode
        m = r

        # computation of mean and stopping criterion
        if stop_crit == 'stop':
            stop_sift,env_mean = stop(m,t,sd,sd2,tol,seq,ndir,N,N_dim)
        else:
            counter=0
            stop_sift,env_mean,counter = fix(m,t,seq,ndir,stp_cnt,counter,N,N_dim)

        # In case the current mode is so small that machine precision can cause
        # spurious extrema to appear
        if np.max(np.abs(m)) < (1e-10)*(np.max(np.abs(x))):
            if not stop_sift:
                warnings.warn('forced stop of EMD : too small amplitude')
            else:
                print('forced stop of EMD : too small amplitude')
            break

        # sifting loop
        while not stop_sift and nbit < MAXITERATIONS:
            # sifting
            m = m - env_mean

            # computation of mean and stopping criterion
            if stop_crit =='stop':
                stop_sift,env_mean = stop(m,t,sd,sd2,tol,seq,ndir,N,N_dim)
            else:
                stop_sift,env_mean,counter = fix(m,t,seq,ndir,stp_cnt,counter,N,N_dim)

            nbit=nbit+1

            if nbit == (MAXITERATIONS-1) and  nbit > 100:
                warnings.warn('forced stop of sifting : too many iterations')

        q.append(m.transpose())

        # remove the extracted mode and restart the iteration counter
        r = r - m
        nbit = 0

    # Stores the residue
    q.append(r.transpose())
    q = np.asarray(q)

    return(q)

### 第四部分 复现的一种LMS建模滤波


[Ref: 余明 胸外按压过程中伪迹抑制与心电节律辨识算法研究](https://kns.cnki.net/KCMS/detail/detail.aspx?dbcode=CDFD&dbname=CDFDLAST2017&filename=1017230620.nh&uid=WEEvREcwSlJHSldRa1FhcTdnTnhXY2c0T0VXK0hUSStVN2g5eVJsWXNrND0=$9A4hF_YAuvQ5obgVAqNKPCYcEjKensW4IQMovwHtwkF4VYPoHbKxJw!!&v=MDEyOTMzcVRyV00xRnJDVVI3cWZaT1pvRkNqaFZiL0FWRjI2R2JHN0h0Zk9yNUViUElSOGVYMUx1eFlTN0RoMVQ=)



其中所用模型为，

$$S_{C P R}(n)=\sum_{k=1}^{N} a_{k}(n) \times \cos (k \phi(n))+b_{k}(n) \times \sin (k \phi(n))......(1)$$

其中，$N$为滤波器阶数，$\phi(n)$是根据参考信号获得的相位，按照LMS准则，迭代更新$a_{k}(n),b_{k}(n)$，其中，所求参数可表示为

$$
\begin{aligned}
&a(n)=\left[a_{1}(n), \cdots \cdots, a_{N}(n)\right]^{T}\\
&b(n)=\left[b_{1}(n), \cdots \cdots, b_{N}(n)\right]^{T}
\end{aligned}
$$

求取瞬时频率

$$f_{i}=\frac{1}{T_{s}\left(n_{i+1}-n_{i}\right)}=\frac{f_{s}}{\Delta n_{i}} \quad n_{i} \leq n<n_{i+1}$$

求取相位函数

$$\phi(n)=\frac{2 \pi}{\Delta n_{i}}\left(n-n_{i}\right)+i \times 2 \pi \quad n_{i} \leq n<n_{i+1}$$

相位部分可表示为，

$$
\begin{array}{l}
s_{I}(n)=[\cos (\phi(n)), \cdots \cdots, \cos (N \phi(n))] \\
s_{Q}(n)=[\sin (\phi(n)), \cdots \cdots, \sin (N \phi(n))]
\end{array}
$$

公式(1)可表示为

$$
\begin{array}{c}
S_{C P R}(n)=s_{I}(n) a(n)+s_{Q}(n) b(n) ......(2)\\
S_{ECG}(n)=S_{ECG\_MIXED}(n)-S_{CPR}(n)
\end{array}
$$

根据LMS准则，可以得到所求参数的更新公式为，

$$
\begin{array}{l}
a(n+1)=a(n)+2 S_{E C G}(n) M s_{I}^{T}(n) \\
b(n+1)=b(n)+2 S_{E C G}(n) M s_{Q}^{T}(n)
\end{array}
$$

其中$M$为LMS自适应滤波器中步长，定义如下

$$
\begin{aligned}
&M=\operatorname{diag}\left(\mu_{1}, \cdots \cdots, \mu_{N}\right)\\
&\mu_{k}=\frac{1}{k} \mu_{0}
\end{aligned}
$$

自适应滤波器的阶数和步长按照滤波前后信噪比增益（gSNR）最大化的准则选取，取值为

$$\mathrm{N}=2, \quad \mu_{0}=0.52$$

In [None]:
def modelRLS(ecg, cpr, verbose):
  '''
  Model-based RLS adaptive filter.

  The interference is modelled as a sum of N harmonics of the phase derived
  from the minima of the reference signal; the harmonic weights are tracked
  with recursive least squares.

  Scpr,Serr = modelRLS(ecg, cpr, verbose)
  input:
      ecg       signal to be filtered
      cpr       reference signal
      verbose   truthy to plot the reference, input, estimate and output
  output:
      Scpr      modelled CPR (interference) signal
      Serr      filtered signal (ecg minus the modelled interference)
  author: star hou
  date:   2020.3.7
  email:  1029588176@qq.com
  '''
  # Derive the phase from the minima of the reference signal
  maximum, minimum, maximumID, minimumID = findpeaks(cpr, seg=0)
  phase = getPhase(minimumID, ecg.shape[0])

  N = 1       # model order (number of harmonics)
  lmd = 0.97  # forgetting factor

  # Regressor of every sample: [cos(phi), sin(phi), cos(2*phi), sin(2*phi), ...]^T
  harmonic_phase = np.outer(phase, np.arange(1, N+1))
  phy = np.zeros((phase.size,N*2,1))
  phy[:, 0::2, 0] = np.cos(harmonic_phase)
  phy[:, 1::2, 0] = np.sin(harmonic_phase)

  w = np.zeros((2*N,1))          # harmonic weights
  fn = 0.03*np.identity(2*N)     # inverse-correlation matrix estimate

  Serr = np.zeros_like(phase)
  Scpr = np.zeros_like(phase)

  for i in range(phase.size):
    phi = phy[i]
    Scpr[i] = np.dot(w.T, phi).item()
    Serr[i] = ecg[i] - Scpr[i]
    # RLS update with matrix products:
    #   P <- (P - P*phi*phi^T*P / (lambda + phi^T*P*phi)) / lambda
    # (the element-wise products used previously dropped the off-diagonal
    #  terms of the update)
    fn_phi = np.dot(fn, phi)
    denom = lmd + np.dot(phi.T, fn_phi).item()
    fn = (fn - np.dot(fn_phi, np.dot(phi.T, fn))/denom)/lmd
    # With the updated matrix, P*phi is the gain vector
    w = w + np.dot(fn, phi)*Serr[i]

  if verbose:
    fig, ax = plt.subplots(4, 1, figsize=(6, 6), constrained_layout=True)
    ax[0].plot(cpr)
    ax[0].plot(minimumID,minimum,'ro')
    ax[0].set_title('ref signal and maxminum peaks')

    ax[1].plot(ecg)
    ax[1].set_title("source signal")

    ax[2].plot(Scpr)
    ax[2].set_title("estimate cpr signal")

    ax[3].plot(Serr)
    ax[3].set_title("filtered signal")

  return Scpr,Serr


def getPhase(n, shape):
  '''
  Build a piecewise-linear phase from the positions of the minima.

  input:
      n      array with the sample indices of the minima
      shape  total number of samples of the signal
  output:
      array with one phase value per sample; the phase grows by 2*pi over
      every interval between consecutive boundaries (0, the minima, shape)

  example:
  phase = getPhase(n, shape)

  author: star hou
  date:   2020.3.7
  email:  1029588176@qq.com
  '''
  # Boundaries of the cycles: 0, every minimum, and the end of the signal
  bounds = np.insert(n, n.size, shape)
  bounds = np.insert(bounds, 0, 0)

  pieces = []
  for cycle in range(bounds.size-1):
    start = bounds[cycle]
    stop = bounds[cycle+1]
    if stop <= start:
      # empty interval: contributes no samples
      continue
    offsets = np.arange(start, stop) - start
    pieces.append(np.pi*2/(stop-start)*offsets+cycle*2*np.pi)

  if not pieces:
    return np.array([])
  return np.concatenate(pieces)
  
def _thin_extrema(x, ids, seg, thr, keep_min):
  '''Keep one extremum per block of `seg` samples, then drop near-duplicates.

  Within each block (index // seg) the most extreme candidate is kept. If two
  survivors are closer than `thr` samples, the less extreme one is removed.
  '''
  sign = 1 if keep_min else -1   # compare sign*x so that "smaller" is better

  # One candidate per block; candidates arrive in increasing index order
  kept = []
  block = None
  for idx in ids:
    current = idx//seg
    if current == block:
      if sign*x[idx] < sign*x[kept[-1]]:
        kept[-1] = idx
    else:
      kept.append(idx)
      block = current
  kept = np.asarray(kept, dtype=int)

  # Of two neighbours closer than thr, drop the less extreme one
  drop = []
  for it in np.where(np.diff(kept) < thr)[0]:
    if sign*x[kept[it]] < sign*x[kept[it+1]]:
      drop.append(it+1)
    else:
      drop.append(it)
  return np.delete(kept, np.asarray(drop, dtype=int))


def findpeaks(x,seg,thr=100):
  '''
  maximum, minimum, maximumID, minimumID = findpeaks(x,seg)
  Find the local extrema of a 1-D signal.

  x    input signal
  seg  0 to return every extremum; otherwise keep one maximum and one
       minimum per block of `seg` samples
  thr  minimum spacing (in samples) between kept extrema when seg is used

  Background from the original notes: chest compressions are performed at
  100-120 per minute. The notes suggest one extremum per 125 samples at
  120 bpm and one per 150 samples at 100 bpm.

  Works for signals of any length (no fixed-length sentinel is used).

  author: star hou
  date:   2020.3.6
  email:  1029588176@qq.com
  '''
  x = np.asarray(x)

  # An extremum is where consecutive differences change sign
  df = np.diff(x)
  peakID = np.where(df[:-1]*df[1:]<0)[0]+1
  is_max = x[peakID] > x[peakID-1]
  maximumID = peakID[is_max]
  minimumID = peakID[~is_max]

  if seg:
    minimumID = _thin_extrema(x, minimumID, seg, thr, keep_min=True)
    maximumID = _thin_extrema(x, maximumID, seg, thr, keep_min=False)

  return x[maximumID], x[minimumID], maximumID, minimumID

def modelLMS(ecg, cpr, verbose):
  '''
  Model-based LMS adaptive filter.

  The interference is modelled as two harmonics of the phase derived from
  the minima of the reference signal; the in-phase and quadrature weights
  are adapted sample by sample.

  Scpr,Serr = modelLMS(ecg, cpr, verbose)
  input:
      ecg       signal to be filtered
      cpr       reference signal
      verbose   truthy to plot the reference, input, estimate and output
  output:
      Scpr      modelled CPR (interference) signal
      Serr      filtered signal
  author: star hou
  date:   2020.3.7
  email:  1029588176@qq.com
  '''
  # Phase of the interference, from the minima of the reference
  maximum, minimum, maximumID, minimumID = findpeaks(cpr, seg=0)
  phase = getPhase(minimumID, ecg.shape[0])

  # Per-harmonic step sizes: M = diag(u0/1, u0/2)
  N = 2
  u0 = 0.52
  M = np.diag(u0/np.arange(1,N+1))
  step = 0.03

  # In-phase / quadrature regressors of every sample (two harmonics)
  SI = np.zeros((phase.size,2,1))
  SQ = np.zeros((phase.size,2,1))
  for k,pha in enumerate(phase):
    SI[k,:,0] = (np.cos(pha),np.cos(2*pha))
    SQ[k,:,0] = (np.sin(pha),np.sin(2*pha))

  Serr = np.zeros_like(phase)
  Scpr = np.zeros_like(phase)
  a = np.zeros((N,1))   # current in-phase weights
  b = np.zeros((N,1))   # current quadrature weights

  for k in range(phase.size):
    Scpr[k] = np.dot(a.T,SI[k])+np.dot(b.T,SQ[k])
    Serr[k] = ecg[k] - Scpr[k]
    gain = step*Serr[k]
    a = a+gain*np.dot(M,SI[k])
    b = b+gain*np.dot(M,SQ[k])

  if verbose:
    fig, ax = plt.subplots(4, 1, figsize=(6, 6), constrained_layout=True)
    ax[0].plot(cpr)
    ax[0].plot(minimumID,minimum,'ro')
    ax[0].set_title('ref signal and maxminum peaks')

    ax[1].plot(ecg)
    ax[1].set_title("source signal")

    ax[2].plot(Scpr)
    ax[2].set_title("estimate cpr signal")

    ax[3].plot(Serr)
    ax[3].set_title("filtered signal")

  return Scpr,Serr

def mixSignal(ecg:np.array, cpr:np.array, SNR:int)->np.array:
  '''
  Return the ECG corrupted by the CPR interference at a given SNR.

  The interference is rescaled so that
  20*log10(std(ecg)/std(added interference)) equals SNR.

  Para
    ecg  ECG signal (len, )
    cpr  CPR interference signal (len, )
    SNR  target signal-to-noise ratio in dB
  '''
  gain = np.power(10,-SNR/20)*np.std(ecg)/np.std(cpr)
  return ecg+gain*cpr

def evalue_rSNR(filtered:np.array, ecg:np.array)->float:
  '''
  Quality of the restored signal (restored SNR, in dB).

  Para
    filtered  signal after filtering
    ecg       clean original signal
  '''
  residual = ecg-filtered
  return 20*np.log10(np.std(ecg)/np.std(residual))
  
def evalue_SNR(cpr:np.array, ecg:np.array)->float:
  '''
  Signal-to-noise ratio (in dB) of the clean ECG against the interference.

  Para
    cpr  CPR interference signal
    ecg  clean original signal
  '''
  ratio = np.std(ecg)/np.std(cpr)
  return 20*np.log10(ratio)

def evalue_gSNR(filtered:np.array, cpr:np.array, ecg:np.array)->float:
  '''
  SNR gain of a filter (in dB): restored SNR minus input SNR.

  Para
    filtered  signal after filtering
    cpr       CPR interference signal
    ecg       clean original signal
  '''
  ecg_std = np.std(ecg)
  snr_in = 20*np.log10(ecg_std/np.std(cpr))
  snr_restored = 20*np.log10(ecg_std/np.std(ecg-filtered))
  return snr_restored-snr_in
  
# Spline through the extrema of the signal
def getspline(x, seg):
  '''
  Interpolate a spline through the extrema of x and evaluate it at every
  sample position.

  x    1-D input signal
  seg  forwarded to findpeaks (0 keeps every extremum)

  The first and last samples are added as knots so the curve is anchored at
  both ends. scipy's splrep needs more knots than the spline degree, so a
  signal with too few extrema raises an error there.
  '''
  from scipy import interpolate
  maximum, minimum, maximumID, minimumID = findpeaks(x,seg)
  # np.insert returns a new array, so the knots are built explicitly here;
  # np.unique sorts them and removes duplicates
  knots = np.concatenate(([0], maximumID, minimumID, [x.size-1]))
  knots = np.unique(knots.astype(int))
  t = np.linspace(0,x.size-1,x.size)
  Spline = interpolate.splrep(knots, x[knots])
  return interpolate.splev(t,Spline)

def getRefSig(ecg, verbose):
  '''
  Derive a reference signal from the mixed signal with MEMD.

  The signal is stacked with three channels of Gaussian noise and decomposed
  with memd; the first-channel components whose SignalSpectrum frequencies
  are all below 2.5 are summed to form the reference.

  input:
     ecg      mixed signal secg+scpr
     verbose  forwarded to SignalSpectrum
  output:
     cpr      reconstructed reference signal
  author: star hou
  date:   2020.3.7
  email:  1029588176@qq.com
  '''
  column = np.expand_dims(ecg,1)
  # three auxiliary Gaussian-noise channels
  noise = np.random.randn(column.size,3)
  imf = memd(np.concatenate((column,noise),1))

  cpr = np.zeros_like(imf[0,0,:])
  for idx, component in enumerate(imf[:,0,:]):
    fmax = SignalSpectrum(component,250,figname = str(idx),verbose=verbose)
    if max(fmax[0])<2.5:
      cpr+=component

  # spectra of the reference and of what remains after removing it
  SignalSpectrum(cpr,250,verbose=verbose)
  SignalSpectrum(ecg-cpr,250,verbose=verbose)
  return cpr

def getPSD(x):
  '''
  Return the squared magnitude of the one-sided FFT of x.

  The spectrum from np.fft.rfft is divided by the number of samples before
  squaring, so the result has len(x)//2 + 1 entries covering frequencies
  from 0 up to half the sampling rate.
  '''
  fft_size = x.shape[0]
  xf = np.abs(np.fft.rfft(x)/fft_size)
  return np.power(xf, 2)

def getRef(scpr, smix):
  '''
  Choose the reference signal for the adaptive filter.

  Five candidates are built from scpr: the smoothed, baseline-removed
  signal, its first and second integrals, and its first and second
  differences. The candidate whose PSD gives the largest autoCorrelate
  score against the PSD of smix is returned.

  input:
    scpr  thoracic impedance or compression depth
    smix  mixed signal
  output:
    sref  selected reference signal
  '''
  base = smooth(baselineRemove(scpr), 25)
  integ1 = baselineRemove(integralSignal(base))
  integ2 = baselineRemove(integralSignal(integ1))
  deriv1 = smooth(diffSignal(base), 25)
  deriv2 = smooth(diffSignal(deriv1), 25)

  # one candidate per column
  candidates = (base, integ1, integ2, deriv1, deriv2)
  sref = np.concatenate([np.expand_dims(c,1) for c in candidates],1)

  best_score = 0
  best_col = 0
  for col in range(sref.shape[1]):
    score = autoCorrelate(getPSD(smix),getPSD(sref[:,col]),mode='linear')
    if score>best_score:
      best_score = score
      best_col = col
  return sref[:,best_col]

### 第五部分 定义读写函数
还包含几个常用的小函数

In [None]:
def readtxt(path):
  '''
  Read a recording from a txt file and decode it.

  The first three lines are skipped. The remaining lines are split into an
  ECG half and a second half (returned as res); the line between the two
  halves is skipped. Both halves are rescaled with fixed constants.

  # Author: starhou
  # Email: 1029588176@qq.com
  # Date:  2020.02.21
  '''
  # reading
  with open(path, 'r') as f:
    lines = f.readlines()
  # decoding
  samples = np.array(lines[3:])
  half = int((samples.shape[0]-2)/2)
  ecg = samples[:half+1].astype(np.float32)
  res = samples[half+2:].astype(np.float32)
  return ecg/6524472.5*1000-1300, res/330


In [None]:
def SinWave(T,fs,L):
    '''
    Sine-wave generator.
    input:
           T   number of periods
           fs  sampling frequency
           L   duration in seconds
    output:
           t   L*fs phase values evenly spaced from -T*pi to T*pi
           f   sin(t)
    '''
    half_span = T*np.pi
    phase = np.linspace(-half_span, half_span, L*fs)
    return phase,np.sin(phase)

In [None]:
def maxMinScale(signal, up=1, down=-1):
  '''
  Linearly rescale a signal to the range [down, up].

  input:
      signal  1-D np.array, or 2-D np.array with one signal per row
      up      upper bound of the output range (default 1)
      down    lower bound of the output range (default -1)
  output:
      out     array of the same shape; 2-D input is scaled row by row
  A constant signal has zero range and yields non-finite values.

  Author: Starhou
  E-mail: 1029588176@qq.com
  Date : 2020.02.24
  '''
  z = np.asarray(signal)
  if z.ndim == 1:
    lowest = np.min(z)
    highest = np.max(z)
  else:
    # keepdims keeps a column vector so each row is scaled by its own range
    lowest = np.min(z, axis=1, keepdims=True)
    highest = np.max(z, axis=1, keepdims=True)
  zstd = (z - lowest)/(highest - lowest)
  return zstd * (up - down) + down

In [None]:
def show(ecg):
  '''
  Plot a signal in a new figure and display it.
  '''
  plt.figure()   # open a new figure instead of drawing on the current one
  plt.plot(ecg)
  plt.show()

### 第六部分 开始开发

#### 6.1 读取数据

两种方式，一种直接读取txt文件，还有一种是直接读取保存好的npy文件

In [None]:
### Option 1: read the recording directly from the txt file
path = 'data/data/NSHR_couple.txt'
secg, sres = readtxt(path)

In [None]:
### ECG signals
database = 'ahadb'
feat_label = np.load('data/'+database+'_8.npy')
# Keep the first 2000 columns of every row
trainECG   = feat_label[:,:2000]
# Resample every row to 2000 samples
trainECG = signal.resample(trainECG,2000,axis=1)
### CPR signals (thoracic impedance)
cpr = np.load('data/new/cpr.npy')

In [None]:
# Load the saved ECG and CPR arrays. Forward slashes are used, as in the
# other cells, so the paths work on every platform.
ECG = np.load('data/data/ECG.npy')
CPR = np.load('data/data/CPR.npy')

In [None]:
# Inspect the dimensions of the loaded CPR array
CPR.shape

#### 6.2 频谱分析和时频谱分析

###### 6.2.1 典型SR信号

>典型的SR信号频率在[0.5,30]Hz之间，信号具有很强的准周期性，但能量最高的谐波所对应的频率不一定是心率。不同人的心电信号特点各不相同，且心电检测时可能引入不同程度的干扰（如肌电干扰、工频干扰等），因此在判别可电击心律时，一般先对信号进行[0.5,30]Hz的带通滤波。




In [None]:
# Row 297 is used as the typical SR example; show its spectrum
SR = trainECG[297,:]
SignalSpectrum(SR, 250, verbose=1)

In [None]:
# Time-frequency diagram of the SR example
cwtmatr, frequencies = timeFrequencyDiagram(SR,f = 6)

###### 6.2.2 典型的VF信号
VF信号比较复杂，基本没有准周期性，频率成分一般在[0,20]Hz之内，无差别滤波，我们一般分析其[0.5,30]Hz之内的频谱。其频谱呈现不集中的特点。

In [None]:
# Row 0 is used as the typical VF example; show its spectrum
VF = trainECG[0,:]
SignalSpectrum(VF, 250, verbose=1)

In [None]:
# Time-frequency diagram of the VF example
cwtmatr, frequencies = timeFrequencyDiagram(VF,f = 6)

###### 6.2.3 典型的VT信号
VT信号的频谱成分较为集中，可以辨别明显的主频，信号也有较强的准周期性，但其谐波成分相较SR来说较少

In [None]:
# Row 161 is used as the typical VT example; show its spectrum
VT = trainECG[161,:]
SignalSpectrum(VT,250, verbose=1)

In [None]:
# Time-frequency diagram of the VT example
cwtmatr, frequencies = timeFrequencyDiagram(VT,f = 6)

###### 6.2.4 带有CPR的ASYS信号
典型 ASYS 中所包含的能量很低，频谱中除包含有少量低频漂移成份外，基本类似于高斯噪声信号。
下图是带有CPR的停搏信号和对应的胸阻抗信号（用来近似CPR），可以看到，CPR信号的干扰使得ASYS信号的频谱变得类似VT和VF。


In [None]:
# ASYS recording with CPR: band-limit both channels to 0.5-30 Hz and show
# the spectrum of each
path = 'data/data/ecg12.txt'
ASYS, ASYSres = readtxt(path)
ASYS = myfilter(ASYS, low=30, high=0.5, fs=250)
SignalSpectrum(ASYS,250, verbose=1)
ASYSres = myfilter(ASYSres, low=30, high=0.5, fs=250)
SignalSpectrum(ASYSres,250, verbose=1)

In [None]:
# Time-frequency diagram of the second channel of the ASYS recording
cwtmatr, frequencies = timeFrequencyDiagram(ASYSres,f = 6)

###### 6.2.5 带有CPR干扰的SR
SR信号在CPR干扰下，信号频谱不再集中，尤其是低频成分,CPR信号的主要频谱分布在10Hz之内，主频在2.5Hz左右

In [None]:
# SR recording with CPR interference: band-limit both channels to
# 0.5-30 Hz, then show the spectrum of each
path = 'data/data/ecg27.txt'
CPR_SR, CPR_SRres = readtxt(path)
CPR_SR = myfilter(CPR_SR, low=30, high=0.5, fs=250)
CPR_SRres = myfilter(CPR_SRres, low=30, high=0.5, fs=250)
SignalSpectrum(CPR_SR,250, verbose=1)
SignalSpectrum(CPR_SRres,250, verbose=1)

##### 6.2.6 典型的NSHRcouple信号

In [None]:
# NSHR-couple recording: band-limit the ECG to 1-30 Hz and show its spectrum
path = 'data/data/NSHR_couple.txt'
secg, sres = readtxt(path)
secg = myfilter(secg,low=30,high=1,fs=250)
# NOTE(review): the other cells pass 250 as the second argument of
# SignalSpectrum; here it is omitted -- confirm the function has a default.
SignalSpectrum(secg,verbose=1)

In [None]:
# Time-frequency diagram of the filtered NSHR-couple ECG
timeFrequencyDiagram(secg)

##### 6.2.7 典型的ASYScouple信号

In [None]:
# ASYS-couple recording: band-limit the ECG to 1-30 Hz and show its spectrum
path = 'data/data/ASYS_couple.txt'
secg, sres = readtxt(path)
secg = myfilter(secg,low=30,high=1,fs=250)
# NOTE(review): the other cells pass 250 as the second argument of
# SignalSpectrum; here it is omitted -- confirm the function has a default.
SignalSpectrum(secg,verbose=1)

In [None]:
# Time-frequency diagram of the filtered ASYS-couple ECG
timeFrequencyDiagram(secg)

#### 6.3 获取参考信号

###### 6.3.1 参考信号积分效果

In [None]:
# First integral of the reference, baseline removed.
# NOTE(review): scpr is not assigned in any earlier cell visible here -- it
# presumably comes from a cell further down; confirm the execution order.
scpr1 = integralSignal(scpr)
scpr1 = baselineRemove(scpr1)
timeFrequencyDiagram(scpr1,f=20)

In [None]:
# Second integral (integral of scpr1), baseline removed
scpr2 = integralSignal(scpr1)
scpr2 = baselineRemove(scpr2)
timeFrequencyDiagram(scpr2,f=20)

###### 6.3.2 参考信号差分效果


In [None]:
# First difference of the reference, smoothed
scpr3 = diffSignal(scpr)
scpr3 = smooth(scpr3, 25)
timeFrequencyDiagram(scpr3,f=20)

In [None]:
# Second difference (difference of scpr3), smoothed
scpr4 = diffSignal(scpr3)
scpr4 = smooth(scpr4, 25)
timeFrequencyDiagram(scpr4,f=20)

#### 6.4 总结
大部分CPR信号集中在低频部分，这和VF、VT类似，两者的频谱也有不可忽视的重叠部分。如果仅通过滤去低频成分的方法去除干扰，容易把VT和VF信号误判为ASYS。

### 第七部分 目标工作流程图
signal(secg,scpr) --> 信号归一化（maxMinScale）--> 滤波预处理myfilter(0.5-30Hz) --> 信号混合（smix）--> 滤波预处理myfilter(0.5-30Hz) --> EMD+自适应滤波法，神经网络(GAN)法 --> secg

**目前完成：**
   EMD+自适应滤波法(第四部分)
   信号混合
   信号预处理
   
**待完成：**
   优化评估现有工作


#### 7.1 求取参考信号
基于6.2频谱分析部分，混合三路零均值的高斯白噪声，采用MEMD方法分解得到不同频率成分的子信号，使用所有主频小于2.5Hz的子信号重构作为参考信号

In [None]:
# Pick the reference signal for the CPR-corrupted SR recording
sref = getRef(CPR_SRres,CPR_SR)

#### 7.2 测试效果

#### 7.3 Mixing two signals
\begin{equation}
S_{ECG\_MIXED}(n)=S_{ECG}(n)+10^{-\frac{SNR}{20}} \times \frac{\operatorname{std}\left[S_{ECG}(n)\right]}{\operatorname{std}\left[S_{CPR}(n)\right]} S_{CPR}(n)
\end{equation}

\begin{equation}
S_{ECG\_MIXED}(n)=S_{ECG}(n)+K \times S_{CPR}(n)
\end{equation}

衡量信号干扰大小

$$S N R=20 \log _{10}\left(\frac{\operatorname{std}\left(S_{E C G}(n)\right)}{\operatorname{std}\left(S_{C P R}(n)\right)}\right)$$

衡量恢复信号的质量

$$r S N R=20 \log _{10}\left(\frac{\operatorname{std}\left(S_{E C G}(n)\right)}{\operatorname{std}\left(S_{E C G}(n)-S_{E C G}(n)^{\prime}\right)}\right)$$



衡量滤波器的好坏
$$g S N R=r S N R-S N R$$




In [None]:
# Mix with a fixed gain K (second mixing formula above)
K = 0.8
smix = secg + K*scpr
# SNR-controlled alternative (first mixing formula above):
# smix = secg+np.power(10,-SNR/20)*np.std(secg)/np.std(scpr)*scpr

##### 7.3.1 混合效果

In [None]:
# Synthetic sine reference: 25 periods, fs = 250, 10 s long
_,sincpr = SinWave(25, 250, 10)
plt.plot(sincpr)

In [None]:
# Pick one ECG row and one CPR row, band-limit the CPR to 0.5-30 Hz
secg = trainECG[290,:]
# cprid =  np.random.randint(0,31)
scpr = cpr[9,:] 
# scpr = baselineRemove(scpr)
scpr = myfilter(scpr, low=30, high=0.5, fs=250)
# scpr = getspline(scpr, 125)
# Print the first two values returned by SignalSpectrum for each signal
mainf = SignalSpectrum(secg, 250, verbose=1)
print(mainf[0],mainf[1])
# SignalSpectrum(scpr, 250, verbose=1)
mainf = SignalSpectrum(scpr, 250, verbose=1)
print(max(mainf[0]),mainf[1])

##### 7.3.2 滤波效果

In [None]:
# Mix ECG and CPR at an SNR of -10 dB and inspect the spectrum
smix = mixSignal(secg, scpr, -10)
# smix = diffSignal(smix)
mainf = SignalSpectrum(smix, 250, verbose=1)
print(mainf[0],mainf[1])

In [None]:
# First two values returned by SignalSpectrum for the CPR signal; the
# second is printed relative to its first entry
frefref, amref,_,_ = SignalSpectrum(scpr, 250, verbose=1)
print(frefref, amref/amref[0])

# NOTE(review): the call below repeats the one above with the same arguments.
frefref, amref,_,_ = SignalSpectrum(scpr, 250, verbose=1)
# Three cascaded RLS stages, each driven by a sine built from one entry of
# frefref; every stage filters the output of the previous one
_,sincpr = SinWave(frefref[0]*10, 250, 10)
Scpr, Serr = modelRLS(smix, sincpr, verbose=1)
_,sincpr = SinWave(frefref[1]*10, 250, 10)
Scpr, Serr = modelRLS(Serr, sincpr, verbose=1)
_,sincpr = SinWave(frefref[2]*10, 250, 10)
Scpr, Serr = modelRLS(Serr[:], sincpr[:], verbose=1)

# Spectra after and before filtering
frefref, amref,_,_ = SignalSpectrum(Serr, 250, verbose=1)
frefref, amref,_,_ = SignalSpectrum(smix, 250, verbose=1)

#### 7.4 混合滤除数据库

> 随机从21条干扰中选取1条干扰，混入纯净信号

###### 我的方法

In [None]:
# Databases to process and the array holding the signals of each one.
# NOTE(review): ahadb, vfdb and cudb are not assigned in the visible part of
# the notebook -- they must be loaded before this cell runs.
databases = ["vfdb","cudb","ahadb"]
mapping = {"ahadb":ahadb,"vfdb":vfdb,"cudb":cudb}

In [None]:
# For every database and SNR: corrupt each ECG row with a randomly chosen
# CPR row, remove the interference with three cascaded RLS stages, and save
# the mixed and the filtered signals.
for db in databases:
  a = [-3,-5,-8,-10]
  trainECG = mapping[db]
  for SNR in a:
    mixed_rows = []      # rows of swithcpr
    filtered_rows = []   # rows of sfilteredcpr

    # Every row, including row 0, goes through the same pipeline so that all
    # saved rows are band-limited consistently.
    for i in range(trainECG.shape[0]):
      # signal
      secg = trainECG[i,:]
      # interference
      cprid =  np.random.randint(0,21)
      scpr = cpr[cprid,:]
      scpr = myfilter(scpr, low=30, high=0.5, fs=250)
      frefref, amref, _, _ = SignalSpectrum(scpr, 250, verbose=0)

      # mixing
      smix = mixSignal(secg, scpr, SNR)

      # three cascaded RLS stages, one per entry of frefref
      _,sincpr = SinWave(frefref[0]*10, 250, 10)
      Scpr, Serr = modelRLS(smix, amref[0]/amref[0]*sincpr, verbose=0)
      _,sincpr = SinWave(frefref[1]*10, 250, 10)
      Scpr, Serr = modelRLS(Serr, amref[1]/amref[0]*sincpr, verbose=0)
      _,sincpr = SinWave(frefref[2]*10, 250, 10)
      Scpr, Serr = modelRLS(Serr[250:2250], amref[2]/amref[0]*sincpr[250:2250], verbose=0)

      # band-limit what is stored
      smix = myfilter(smix, low=30, high=0.5, fs=250)
      Serr = myfilter(Serr, low=30, high=0.5, fs=250)

      mixed_rows.append(smix)
      filtered_rows.append(Serr)

    # the two saved variables (stacked once instead of growing row by row)
    swithcpr = np.vstack(mixed_rows)
    sfilteredcpr = np.vstack(filtered_rows)

    np.save(os.path.join("data", "data", db, "scpr_" + str(SNR) +".npy"), swithcpr)
    np.save(os.path.join("data", "data", db, "sfiltered_" + str(SNR) +".npy"), sfilteredcpr)

###### 余明的方法

In [None]:
  # Single-record trial of the LMS-based method on record 2.
  # NOTE(review): `trainECG` and `SNR` are not assigned in this cell; they are
  # whatever an earlier cell left in the notebook namespace -- confirm the
  # intended database and SNR before running.
  # Clean signal.
  secg = trainECG[2,:]
  # Artifact: pick one of the 21 artifact rows at random and band-limit it
  # (low-pass 30, high-pass 0.5, fs=250).
  cprid =  np.random.randint(0,21)
  scpr = cpr[cprid,:]
  scpr = myfilter(scpr, low=30, high=0.5, fs=250)

  # Mix clean signal and artifact at the current SNR.
  smix = mixSignal(secg, scpr, SNR)


  # Reference signal derived from the mixture itself, then LMS filtering on
  # samples 250..2249 of both the mixture and the reference.
  scref = getRefSig(smix,0)
  Scpr, Serr = modelLMS(smix[250:2250], scref[250:2250], verbose=0)

In [None]:
# Build the dataset for Yu Ming's method: for every SNR level, mix each clean
# record with one randomly chosen artifact, derive a reference from the
# mixture with getRefSig, run LMS filtering, and save both the mixed and the
# filtered signals (one row per record).
#
# NOTE(review): `trainECG` is not assigned in this cell; it is whatever an
# earlier cell left in the notebook namespace -- confirm the intended
# database before running.
a = [-3,-5,-8,-10]  # SNR levels handed to mixSignal

for SNR in a:
  # Rows are collected in lists and stacked once per SNR level instead of
  # growing the arrays with np.vstack on every iteration.
  mixed_rows = []
  filtered_rows = []

  # Every record, including record 0, goes through the same path, so all
  # stored rows receive the final myfilter pass (record 0 used to skip it).
  for i in range(trainECG.shape[0]):
    # Clean signal.
    secg = trainECG[i,:]

    # Artifact: pick one of the 21 artifact rows at random and band-limit it
    # (low-pass 30, high-pass 0.5, fs=250).
    cprid = np.random.randint(0,21)
    scpr = cpr[cprid,:]
    scpr = myfilter(scpr, low=30, high=0.5, fs=250)

    # Mix clean signal and artifact at the requested SNR.
    # Bug fix: the original inner loop never called mixSignal, so `smix`
    # stayed the previous iteration's (already filtered) mixture and every
    # saved row was derived from record 0 instead of record i.
    smix = mixSignal(secg, scpr, SNR)

    # Reference derived from the mixture itself, then LMS filtering on
    # samples 250..2249 of both the mixture and the reference.
    scref = getRefSig(smix,0)
    Scpr, Serr = modelLMS(smix[250:2250], scref[250:2250], verbose=0)

    # Band-limit both outputs before storing them.
    smix = myfilter(smix, low=30, high=0.5, fs=250)
    Serr = myfilter(Serr, low=30, high=0.5, fs=250)

    mixed_rows.append(smix)
    filtered_rows.append(Serr)

  # One row per record. With a single record the result is now 2-D
  # (1, n_samples) rather than a bare 1-D vector.
  swithcpr = np.vstack(mixed_rows)
  sfilteredcpr = np.vstack(filtered_rows)

  np.save("data\\paper1\\PUDB\\scpr_" + str(SNR) +".npy", swithcpr)
  np.save("data\\paper1\\PUDB\\sfiltered_" + str(SNR) +".npy", sfilteredcpr)

###### 龚余顺的方法

In [None]:
# Build the dataset for Gong Yushun's method: for every SNR level, mix each
# clean record with one randomly chosen artifact, build the reference with
# getRef, run LMS filtering, and save both the mixed and the filtered signals
# (one row per record).
#
# NOTE(review): `trainECG` is not assigned in this cell; it is whatever an
# earlier cell left in the notebook namespace -- confirm the intended
# database before running.
a = [-3,-5,-8,-10]  # SNR levels handed to mixSignal

for SNR in a:
  # Rows are collected in lists and stacked once per SNR level instead of
  # growing the arrays with np.vstack on every iteration.
  mixed_rows = []
  filtered_rows = []

  # Every record, including record 0, goes through the same path.
  # Previously record 0 was stored without the final myfilter pass while all
  # other rows were band-limited. The SignalSpectrum call that used to run
  # for record 0 only is dropped: its results were never used in this cell.
  for i in range(trainECG.shape[0]):
    # Clean signal.
    secg = trainECG[i,:]

    # Artifact: pick one of the 21 artifact rows at random and band-limit it
    # (low-pass 30, high-pass 0.5, fs=250).
    cprid = np.random.randint(0,21)
    scpr = cpr[cprid,:]
    scpr = myfilter(scpr, low=30, high=0.5, fs=250)

    # Mix clean signal and artifact at the requested SNR.
    smix = mixSignal(secg, scpr, SNR)

    # Reference built from the artifact and the mixture (the name `scpr` is
    # reused for it, as in the original), then LMS filtering on samples
    # 250..2249 of both the mixture and the reference.
    scpr = getRef(scpr, smix)
    Scpr, Serr = modelLMS(smix[250:2250], scpr[250:2250], verbose=0)

    # Band-limit both outputs before storing them.
    smix = myfilter(smix, low=30, high=0.5, fs=250)
    Serr = myfilter(Serr, low=30, high=0.5, fs=250)

    mixed_rows.append(smix)
    filtered_rows.append(Serr)

  # One row per record. With a single record the result is now 2-D
  # (1, n_samples) rather than a bare 1-D vector.
  swithcpr = np.vstack(mixed_rows)
  sfilteredcpr = np.vstack(filtered_rows)

  np.save("data\\paper2\\PUDB\\scpr_" + str(SNR) +".npy", swithcpr)
  np.save("data\\paper2\\PUDB\\sfiltered_" + str(SNR) +".npy", sfilteredcpr)