In [9]:
import matplotlib
%matplotlib nbagg

import sys
import os
import itertools
import multiprocessing as mp

sys.path.append(os.path.expanduser("../GitHub/py2p/tools"))
sys.path.append(os.path.expanduser("../GitHub/haussmeister"))
import intan
os.environ['USER'] = "hello"
import spectral
from haussmeister import spectral as hspectral

import numpy as np
from scipy.optimize import minimize
from scipy.optimize import differential_evolution
from scipy.optimize import brute
from scipy.interpolate import interp1d
from scipy.signal import decimate

import matplotlib.pyplot as plt
import seaborn as sns

In [10]:
def make_intan_file(year, month, day, time, aux=False, user="chunlei", prefix="ephys_"):
    """
    Build the full path of an Intan clamp (.clp) file from the recording date.

    Parameters
    ----------
    year, month, day : int
        Recording date (four-digit year).
    time : int
        Time stamp of the recording, zero-padded to six digits (e.g. 173608).
    aux : bool
        If True, build the "_AUX" file name, otherwise the "_A" file name.
    user : str
        User sub-directory below the day directory.
    prefix : str
        Prefix of both the session directory and the file name.

    Returns
    -------
    fullpath : str
        <root>/YYYY/YYYY-MM/YYYY-MM-DD/<user>/<prefix>_YYMMDD_TTTTTT/
        <prefix><_A|_AUX>_YYMMDD_TTTTTT.clp
    """
    if sys.platform == 'win32':
        # The root needs its separator: os.path.join("Y:", "data") gives the
        # drive-relative path "Y:data", not "Y:\data".
        path = os.path.join("Y:\\", "data")
    else:
        path = os.path.join(os.path.expanduser("~"), "Trillian2", "data")
    yeardir = "{0:04d}".format(year)
    monthdir = "{0}-{1:02d}".format(yeardir, month)
    daydir = "{0}-{1:02d}".format(monthdir, day)
    # Intan date stamps use a two-digit year; % 100 also works before 2000.
    intan_datestring = "_{0:02d}{1:02d}{2:02d}_{3:06d}".format(
        year % 100, month, day, time)
    auxstring = "_AUX" if aux else "_A"
    return os.path.join(
        path, yeardir, monthdir, daydir, user,
        prefix + intan_datestring,
        prefix + auxstring + intan_datestring + ".clp")

In [11]:
def rt_envelope(timestamps, band, treset):
    """
    Real time envelope function.

    Walks through ``band`` sample by sample. When the rectified signal drops
    below its value one reset window earlier, the envelope is set to the peak
    of that preceding window; otherwise the previous envelope value is held.

    Parameters
    ----------
    timestamps : numpy.ndarray
        The time stamps for each sampling point
    band : numpy.ndarray
        The signal for the envelope computation
    treset : float
        The reset time for the envelope computation, in the units of
        ``timestamps``. It becomes a look-back window of
        ``round(treset / dt)`` samples, clamped to at least one sample.

    Returns
    -------
    env : numpy.ndarray
        The envelope of band (float, same shape as band)
    """
    dt = np.median(np.diff(timestamps))
    # A zero-sample window would compare band[n] with itself (never smaller)
    # and silently return an all-zero envelope; one sample is the minimum.
    ireset = max(int(np.round(treset / dt)), 1)
    absband = np.abs(band)
    env = np.zeros(band.shape)
    reset = False
    # The first ireset samples have no full look-back window and stay 0.
    for nts in range(ireset, len(timestamps)):
        if not reset and absband[nts] < absband[nts - ireset]:
            # Amplitude fell: take the peak of the preceding window.
            env[nts] = np.max(absband[nts - ireset:nts])
            reset = True
        else:
            # Hold the previous value; a reset is always followed by a hold.
            env[nts] = env[nts - 1]
            reset = False
    return env

In [12]:
def simulate_arduino_envelope(timestamps, ripple_env, noise_env, ardconfig, maxripples=None):
    """
    Simulate the threshold-based Arduino ripple detector on recorded envelopes.

    Each upward crossing of ``ardconfig['thetaripple']`` by ``ripple_env`` is
    accepted as a ripple unless ``noise_env`` exceeded ``ardconfig['thetanoise']``
    during the preceding 10 ms, or the crossing lies inside the refractory
    window of an already accepted ripple.

    Parameters
    ----------
    timestamps : numpy.ndarray
        Sample times in seconds (converted to microseconds internally).
    ripple_env, noise_env : numpy.ndarray
        Ripple-band and noise-band envelopes, one value per timestamp.
    ardconfig : dict
        'thetanoise', 'thetaripple' thresholds and 'trefractorymicros'
        refractory period in microseconds.
    maxripples : int or None
        Give up (return None) once more than this many ripples are accepted.
        Defaults to 100.

    Returns
    -------
    ripple_detected : numpy.ndarray of bool, or None
        True for ``trefractorymicros`` after each accepted onset. None if either
        threshold exceeds the maximum of its envelope, if there is no
        ripple-threshold crossing, or if more than ``maxripples`` are accepted.
    """
    if ardconfig['thetanoise'] > np.max(noise_env):
        return None
    if ardconfig['thetaripple'] > np.max(ripple_env):
        return None
    if maxripples is None:
        maxripples = 100

    dt = np.median(np.diff(timestamps)) * 1e6  # sampling interval in us
    irefractory = int(np.round(ardconfig['trefractorymicros'] / dt))
    tnoiseback = 10000  # us of noise look-back before each onset
    inoiseback = int(np.round(tnoiseback / dt))

    ripple_trigger = ripple_env > ardconfig['thetaripple']
    # The noise veto is driven by the noise envelope, not the ripple envelope.
    noise_trigger = noise_env > ardconfig['thetanoise']
    # Index of the last sub-threshold sample before each upward crossing.
    ripple_onsets = np.flatnonzero(np.diff(ripple_trigger.astype(int)) == 1)
    if not len(ripple_onsets):
        return None

    ripple_detected = np.zeros(timestamps.shape, dtype=bool)
    nripples = 0
    for ro in ripple_onsets:
        # Crossing falls inside the refractory window of an accepted ripple.
        if ripple_detected[ro]:
            continue
        # Clamp at 0: a negative start would wrap around to the end of the array.
        if not np.any(noise_trigger[max(ro - inoiseback, 0):ro]):
            ripple_detected[ro:ro + irefractory] = True
            nripples += 1
            if nripples > maxripples:
                return None
    return ripple_detected

In [13]:
def false_negatives(simulated, offline, tolerance):
    """
    Count offline ripples that the simulated detector missed.

    An offline ripple is a true positive if ``simulated`` has a False->True
    transition inside [start - tolerance, stop + tolerance), otherwise it is
    a false negative.

    Parameters
    ----------
    simulated : numpy.ndarray of bool
        Per-sample simulated detection trace.
    offline : numpy.ndarray
        Offline ripples: row 0 start indices, row 1 stop indices.
    tolerance : int
        Samples added on both sides of each offline ripple.

    Returns
    -------
    (falseneg, truepos) : tuple of int
    """
    # rising[k] is True where simulated goes False -> True between k and k+1.
    # Computed once instead of re-casting the whole trace for every ripple.
    rising = np.diff(np.asarray(simulated).astype(int)) > 0
    falseneg, truepos = 0, 0
    for ripplestart, ripplestop in zip(offline[0, :], offline[1, :]):
        # Clamp at 0 so an early ripple does not produce a wrapped slice.
        lo = max(int(ripplestart) - tolerance, 0)
        hi = int(ripplestop) + tolerance
        # Transitions inside samples [lo, hi) are rising[lo:hi - 1].
        if np.any(rising[lo:max(hi - 1, 0)]):
            truepos += 1
        else:
            falseneg += 1
    return falseneg, truepos

def false_positives(simulated, offline, tolerance):
    """
    Classify each simulated detection onset as matched or spurious.

    An onset (False->True transition of ``simulated``) is a true positive if it
    lies inside [start - tolerance, stop + tolerance) of any offline ripple,
    otherwise a false positive.

    Parameters
    ----------
    simulated : numpy.ndarray of bool
        Per-sample simulated detection trace.
    offline : numpy.ndarray
        Offline ripples: row 0 start indices, row 1 stop indices.
    tolerance : int
        Samples added on both sides of each offline ripple.

    Returns
    -------
    (falsepos, truepos) : tuple of int
    """
    sim_onsets = np.flatnonzero(np.diff(np.asarray(simulated).astype(int)) > 0)
    starts = np.asarray(offline[0, :]) - tolerance
    stops = np.asarray(offline[1, :]) + tolerance
    falsepos, truepos = 0, 0
    for sim_onset in sim_onsets:
        # Previously every onset was also counted as a false positive because
        # the increment sat after the inner loop; classify each onset once.
        if np.any((sim_onset >= starts) & (sim_onset < stops)):
            truepos += 1
        else:
            falsepos += 1
    return falsepos, truepos

In [14]:
# Order of the Arduino parameters when packed into / unpacked from a flat
# parameter vector (used by to_params, to_grid and from_params).
key_sequencede = ['thetanoise', 'thetaripple', 'trefractorymicros']

def to_params(ardconf, ardcons):
    """
    Pack a config dict and its bounds dict into arrays ordered by key_sequencede.

    Returns
    -------
    params : numpy.ndarray
        ``ardconf[key]`` for each key.
    constraints : numpy.ndarray
        ``ardcons[key]`` for each key, stacked (e.g. one [low, high] row per key).
    """
    ordered_values = [ardconf[key] for key in key_sequencede]
    ordered_bounds = [ardcons[key] for key in key_sequencede]
    return np.array(ordered_values), np.array(ordered_bounds)

def to_grid(ardcons, N):
    """Return one array of N evenly spaced values from ardcons[key][0] to ardcons[key][1] per key in key_sequencede."""
    grid_axes = []
    for key in key_sequencede:
        low, high = ardcons[key][0], ardcons[key][1]
        grid_axes.append(np.linspace(low, high, N))
    return grid_axes

def from_params(params):
    """Map a flat parameter vector back to a config dict keyed by key_sequencede."""
    return dict(zip(key_sequencede, params))

def target(params, timestamps, rippleband, noiseband, ripples, tolerance_us=20000.0):
    """
    Score one Arduino parameter vector against the offline ripple detection.

    Parameters
    ----------
    params : sequence
        Parameter vector ordered like ``key_sequencede``.
    timestamps : numpy.ndarray
        Sample times in seconds.
    rippleband, noiseband : numpy.ndarray
        Envelopes passed to simulate_arduino_envelope as ripple_env / noise_env.
    ripples : numpy.ndarray
        Offline ripples: row 0 start indices, row 1 stop indices.
    tolerance_us : float
        Matching tolerance in microseconds on both sides of each offline ripple.

    Returns
    -------
    numpy.nan if the simulation returned None, otherwise
    (false positives, true positives counted per offline ripple).
    NOTE(review): the mixed return type means this cannot be handed to a
    scalar optimizer (e.g. differential_evolution) as-is.
    """
    currentardconfig = from_params(params)
    nripples = ripples.shape[1]
    # The simulation gives up once it accepts more than 8x the offline count.
    sim = simulate_arduino_envelope(
        timestamps, rippleband, noiseband, currentardconfig, nripples * 8)
    if sim is None:
        return np.nan
    # Median interval (as in simulate_arduino_envelope) rather than the first
    # difference, so a single irregular sample cannot skew the tolerance.
    dt = np.median(np.diff(timestamps)) * 1e6  # us
    tolerance = int(np.round(tolerance_us / dt))
    fn, tp = false_negatives(sim, ripples, tolerance=tolerance)
    fp, _ = false_positives(sim, ripples, tolerance=tolerance)
    if tp > 3:
        print(currentardconfig, fn, fp, tp, fn + fp - tp)
    elif tp > 0:
        sys.stdout.write('.')
        sys.stdout.flush()
    return fp, tp

class ParGrid(object):
    """
    Callable wrapper around ``target`` with the data bound as attributes, so a
    parameter grid can be scored with ``multiprocessing.Pool.map``.
    """

    def __init__(self, timestamps, rippleband, noiseband, ripples):
        # Data shared by every evaluation.
        self.timestamps, self.rippleband = timestamps, rippleband
        self.noiseband, self.ripples = noiseband, ripples

    def __call__(self, pars):
        """Return ``(pars, score)`` so results stay paired with their inputs."""
        score = target(
            pars, self.timestamps, self.rippleband, self.noiseband, self.ripples)
        return pars, score

### Load data

In [15]:
# Recording to analyse: date and six-digit time stamp passed to make_intan_file.
year, month, day, time = 2018, 9, 19, 173608

In [None]:
# Paths of the "_A" and "_AUX" files for the recording, then open both.
intanfn, intanfn_aux = (
    make_intan_file(year, month, day, time, aux=is_aux) for is_aux in (False, True))
intanf, intanf_aux = intan.IntanFile(intanfn), intan.IntanFile(intanfn_aux)

### Offline ripple detection

In [None]:
# Offline ripple-detection settings; band corner frequencies are in kHz.
config = dict(
    vsub=-40,
    vmin=-150,
    apthreshold=-30,
    min_nripples=3,
    std_ripple_thresholds=(2, 10),      # standard-deviation thresholds to detect a ripple
    ripple_bandpass=(0.1, 0.2),         # ripple bandpass corner frequencies (kHz)
    noise_bandpass=(0.3, 0.5),          # noise bandpass corner frequencies (kHz)
    spindle_bandpass=(0.01, 0.02),
    resting_threshold=1.0,
    running_threshold=1.0,
    running_duration=1000.0,
)

In [None]:
class Timeseries(object):
    """Minimal container with ``data`` and sampling interval ``dt``, as passed to spectral.bandpass."""

    def __init__(self, data, dt):
        self.data, self.dt = data, dt
# Sampling interval of the Intan time base scaled by 1e3 (ms if 'Time' is in
# seconds), matching the kHz band edges in config.
dt_ms = np.mean(np.diff(intanf_aux.data['Time'])) * 1e3
adc_broadband = intanf_aux.data['ADC'][1]
# Median-centre the broadband LFP and cast the whole trace to float. np.float no
# longer exists in NumPy, and the cast previously applied only to the median.
LFP_data = Timeseries((adc_broadband - np.median(adc_broadband)).astype(float), dt_ms)
LFP_data_bp = spectral.bandpass(LFP_data, config['ripple_bandpass'][0], config['ripple_bandpass'][1])
LFP_data = spectral.bandpass(LFP_data, 0.002, 10.0)
# NOTE(review): the noise band is taken from the 0.002-10 kHz filtered trace,
# while the ripple band above used the unfiltered one; confirm this is intended.
LFP_noise_data_bp = spectral.bandpass(LFP_data, config['noise_bandpass'][0], config['noise_bandpass'][1])

In [None]:
# Quick look at each offline trace and the Intan time base: values, then shape.
for arr in (LFP_data.data, LFP_data_bp.data, LFP_noise_data_bp.data,
            intanf_aux.data['Time']):
    print(arr)
    print(arr.shape)

In [None]:
# Hilbert file for findRipples inside ./dat. intanfn is an absolute path, and
# os.path.join drops every component before an absolute one, so joining the full
# path silently ignored 'dat'; use only the file name.
if not os.path.isdir('dat'):
    os.makedirs('dat')
fn_hilbert = os.path.join('dat', os.path.basename(intanfn) + '_hilbert.mat')
ripples, rippleargmaxs = hspectral.findRipples(
    LFP_data_bp, LFP_noise_data_bp,
    std_thresholds=config['std_ripple_thresholds'], fn_hilbert=fn_hilbert)
print(ripples)
assert ripples.shape[1] > 0, "offline detection found no ripples"

In [None]:
np.shape(ripples)  # row 0: start indices, row 1: stop indices; one column per ripple

### Simulate Arduino ripple detection

In [None]:
# Arduino detector settings: envelope thresholds and refractory period (us).
ardconfigde = dict(thetanoise=0.05, thetaripple=0.16, trefractorymicros=1000000)

# [lower, upper] search bounds for each detector setting.
ardconstraintsde = dict(
    thetanoise=[0.05, 0.20],
    thetaripple=[0.05, 0.40],
    trefractorymicros=[10000.0, 2000000],
)

In [None]:
nip = 1          # decimation factor for the online (analog) channels
iend = 20000000  # slice end used for the envelopes and plots

timestamps_ip = intanf_aux.data['Time'][::nip]
# Median-centre the online ripple (ADC 2) and noise (ADC 3) channels, then decimate.
# NOTE(review): scipy's decimate() may still apply its anti-aliasing low-pass
# when nip == 1; confirm whether the undecimated channels were intended.
rippleband_ip_analog, noiseband_ip_analog = [
    decimate(channel - np.median(channel), nip)
    for channel in (intanf_aux.data['ADC'][2], intanf_aux.data['ADC'][3])
]

# Simulated real-time envelopes with 0.005 s (ripple) and 0.002 s (noise) reset windows.
# NOTE(review): later cells pass np.abs() of the bands to simulate_arduino_envelope,
# not these envelopes.
ripple_env = rt_envelope(timestamps_ip[:iend], rippleband_ip_analog[:iend], 0.005)
noise_env = rt_envelope(timestamps_ip[:iend], noiseband_ip_analog[:iend], 0.002)

In [None]:
# Online signals: rectified ripple band, ripple band, rectified noise band, and
# the unfiltered LFP (ADC 1), on a shared time axis.
t_online = timestamps_ip[:iend]         # time base of the nip-decimated channels
t_raw = intanf_aux.data['Time'][:iend]  # time base of the undecimated ADC 1

fig = plt.figure()
ax_online_ripple = fig.add_subplot(411)
ax_online_ripple.plot(t_online, np.abs(rippleband_ip_analog[:iend]))
ax_online_ripple.set_ylabel("|ripple band|")

ax_online_signal = fig.add_subplot(412, sharex=ax_online_ripple)
ax_online_signal.plot(t_online, rippleband_ip_analog[:iend])
ax_online_signal.set_ylabel("ripple band")

ax_online_noise = fig.add_subplot(413, sharex=ax_online_ripple)
ax_online_noise.plot(t_online, np.abs(noiseband_ip_analog[:iend]))
ax_online_noise.set_ylabel("|noise band|")

ax_online_unfiltered = fig.add_subplot(414, sharex=ax_online_ripple)
ax_online_unfiltered.plot(t_raw, intanf_aux.data['ADC'][1][:iend])
ax_online_unfiltered.set_ylabel("broadband LFP")
ax_online_unfiltered.set_xlabel("Time (s)")

In [None]:
simulated_ripples = simulate_arduino_envelope(
    timestamps_ip,
    np.abs(rippleband_ip_analog),
    np.abs(noiseband_ip_analog),
    ardconfigde,
    200
)
# Fail with a clear message instead of a TypeError in false_negatives below.
if simulated_ripples is None:
    raise RuntimeError(
        "simulate_arduino_envelope returned None (threshold above envelope "
        "maximum, no ripple-threshold crossing, or more than 200 detections); "
        "adjust ardconfigde")
dt = np.median(np.diff(timestamps_ip)) * 1e6  # us
tolerance = int(np.round(20000.0 / dt))  # 20 ms on either side of each offline ripple
fn, tp1 = false_negatives(simulated_ripples, ripples, tolerance=tolerance)
fp, tp2 = false_positives(simulated_ripples, ripples, tolerance=tolerance)
print(simulated_ripples.shape)
print(fn, tp1)
print(fp, tp2)

# Plot the simulated ripple detections

In [None]:
# Rectified online signals with a green tick at every sample flagged by the
# simulated detector. (`sr is not False` is always true for numpy booleans, so
# the old loop drew a marker at every sample, one plot call each.)
detected_idx = np.flatnonzero(simulated_ripples)
fig3 = plt.figure()
ax_online_ripple = fig3.add_subplot(111)
ax_online_ripple.plot(timestamps_ip, np.abs(rippleband_ip_analog), 'b', alpha=0.75)
ax_online_ripple.plot(timestamps_ip, np.abs(noiseband_ip_analog), 'r--')
ax_online_ripple.plot(
    timestamps_ip[detected_idx], np.full(detected_idx.shape, 0.05), "|g")
ax_online_ripple.set_xlabel("Time (s)")
ax_online_ripple.grid(True)

In [None]:
# Sample indices flagged by the simulated detector. The old `sr is not False`
# test was true for every numpy boolean, so it printed the entire trace.
print(np.flatnonzero(simulated_ripples))

In [None]:
# Full factorial grid: 20 values for each detector parameter.
grid = to_grid(ardconstraintsde, 20)
grid_all = [combo for combo in itertools.product(*grid)]

# Arguments shared by every target() / ParGrid evaluation.
args_ip_envelope = (
    timestamps_ip,
    np.abs(rippleband_ip_analog),
    np.abs(noiseband_ip_analog),
    ripples,
)

In [None]:
# Grid search over the detector parameters: one simulation per entry of
# grid_all, so it is off by default. target() returns an (fp, tp) tuple, so it
# cannot be passed to scalar optimizers such as differential_evolution as-is.
RUN_GRID_SEARCH = False
if RUN_GRID_SEARCH:
    pargrid = ParGrid(*args_ip_envelope)
    pool = mp.Pool(processes=2)
    try:
        res = pool.map(pargrid, grid_all)
    finally:
        pool.close()
        pool.join()

In [None]:
# Spot-check the time base, the broadband ADC, the centred online ripple channel
# and the offline ripple band. The channel is explicit: `adc` is only defined by
# the plotting loop in a later cell, so this failed on a fresh kernel.
adc_channel = 2  # online ripple channel
print(intanf_aux.data['Time'][:iend])
print(intanf_aux.data['ADC'][1])
print(intanf_aux.data['ADC'][adc_channel][:iend] - np.median(intanf_aux.data['ADC'][adc_channel]))
print(LFP_data_bp.data[:iend])

In [None]:
# Overview: ADC channels 1-3 with offline ripple marks, the Arduino digital
# lines, and the simulated detection trace, all sharing one time axis.
adc_labels = {1: "Broadband LFP", 2: "Ripple LFP", 3: "Noise LFP"}
# Offline band-passed trace drawn over each online channel (none for ADC 1).
adc_overlays = {2: LFP_data_bp, 3: LFP_noise_data_bp}
dig_labels = ["VR sync", "Ripple threshold", "Noise threshold",
              "Arduino ripple detection"]

t_raw = intanf_aux.data['Time'][:iend]
fig = plt.figure(figsize=(8, 16))
axref = None
nsp = 1  # running subplot number in the 9-row layout (910 + nsp)
for adc in range(1, 4):
    ax = fig.add_subplot(910 + nsp, sharex=axref)
    if axref is None:
        axref = ax
    adc_trace = intanf_aux.data['ADC'][adc]
    ax.plot(t_raw, adc_trace[:iend] - np.median(adc_trace), '-k', alpha=0.5)
    if adc in adc_overlays:
        ax.plot(t_raw, adc_overlays[adc].data[:iend], '-r', alpha=0.5)
    ax.set_ylabel(adc_labels[adc])
    nsp += 1

# Offline ripples on the first panel: start (green), stop (red), argmax (black).
lfpmax = np.max(intanf_aux.data['ADC'][1] - np.median(intanf_aux.data['ADC'][1]))
for ripplestart, ripplestop, rippleargmax in zip(ripples[0, :], ripples[1, :], rippleargmaxs):
    axref.plot(intanf_aux.data['Time'][ripplestart], lfpmax, "|g")
    axref.plot(intanf_aux.data['Time'][ripplestop], lfpmax, "|r")
    axref.plot(intanf_aux.data['Time'][rippleargmax], lfpmax, "|k")

for dig, label in enumerate(dig_labels):
    ax = fig.add_subplot(910 + nsp, sharex=axref)
    ax.plot(t_raw, intanf_aux.data['DigitalInAll'][dig][:iend])
    ax.set_ylabel(label)
    nsp += 1

ax = fig.add_subplot(910 + nsp, sharex=axref)
# simulated_ripples lives on timestamps_ip (Time decimated by nip), not on Time.
ax.plot(timestamps_ip[:iend], simulated_ripples[:iend])
ax.set_ylabel("Simulated ripple detection")
ax.set_xlabel("Time (s)")
sns.despine()

# The FIR Filter (Dutta et al. 2017)

In [None]:
from scipy.signal import firwin
from scipy.signal import lfilter

# Sampling rate (Hz) the online filter taps below are designed for.
# NOTE(review): later cells apply these taps to the Intan channels at their
# recorded rate (timestamps_ip); confirm that rate really is FS.
FS = 3000

# Online filter taps for the simulated detection.
# Bandpass FIR filter coefficients, 150-250 Hz passband. fs=FS is equivalent to
# the old nyq=FS/2; the `nyq` keyword is deprecated and removed in recent SciPy.
bandpassFilterTaps = firwin(30, [150, 250], fs=FS, pass_zero=False)

# Symmetric 33-tap lowpass FIR filter coefficients, applied after the absolute value.
lowpassFilterTaps = np.asarray([
    0.0203770957, 0.0108532903, 0.0134954582, 0.0163441640, 0.0193546202,
    0.0224738014, 0.0256417906, 0.0287934511, 0.0318603667, 0.0347729778,
    0.0374628330, 0.0398648671, 0.0419196133, 0.0435752600, 0.0447894668,
    0.0455308624, 0.0457801628, 0.0455308624, 0.0447894668, 0.0435752600,
    0.0419196133, 0.0398648671, 0.0374628330, 0.0347729778, 0.0318603667,
    0.0287934511, 0.0256417906, 0.0224738014, 0.0193546202, 0.0163441640,
    0.0134954582, 0.0108532903, 0.0203770957])



def rippleBandFilterSimulated(lfp, time, FS, bpFilterTaps, lpFilterTaps):
    """
    Ripple band filter and envelope simulating the real-time algorithm.

    ``lfp`` is band-passed with ``bpFilterTaps``, rectified, and smoothed with
    ``lpFilterTaps``; both are causal FIR filters applied with ``lfilter``.
    ``time`` and ``FS`` are accepted but not used.

    Returns
    -------
    (smoothed_envelope, rippleData)
        The smoothed rectified signal and the band-passed signal.
    """
    band_passed = lfilter(bpFilterTaps, 1, lfp)
    envelope = lfilter(lpFilterTaps, 1, np.absolute(band_passed))
    return envelope, band_passed

In [None]:
# Run the simulated online filter chain on the ripple and noise channels.
# NOTE(review): the noise channel is filtered with the 150-250 Hz ripple taps
# as well; confirm this is intended.
filter_args = (timestamps_ip[:iend], FS, bandpassFilterTaps, lowpassFilterTaps)
smoothed_envelope2, rippleData2 = rippleBandFilterSimulated(
    rippleband_ip_analog[:iend], *filter_args)
smoothed_envelope_noise, rippleData_noise = rippleBandFilterSimulated(
    noiseband_ip_analog[:iend], *filter_args)

In [None]:
# Rectified filter output and its smoothed envelope for both online channels.
t_online = timestamps_ip[:iend]  # time base of the nip-decimated channels
fig3 = plt.figure()
ax_online_ripple = fig3.add_subplot(211)
ax_online_ripple.plot(t_online, np.abs(rippleData2))
ax_online_ripple.plot(t_online, smoothed_envelope2)
ax_online_ripple.legend(('rectified ripple data', 'smoothed envelope'), loc='best')
# The second panel belongs to fig3 (it was being added to the unrelated `fig`).
ax_online_noise = fig3.add_subplot(212, sharex=ax_online_ripple)
ax_online_noise.plot(t_online, np.abs(rippleData_noise))
ax_online_noise.plot(t_online, smoothed_envelope_noise)
ax_online_noise.legend(('rectified noise data', 'smoothed envelope for noise'), loc='best')
ax_online_noise.set_xlabel("Time (s)")

# Example: `scipy.signal.lfilter` vs. `filtfilt` (from the SciPy documentation)

In [None]:
from scipy import signal

# Compare single and double causal filtering (lfilter) with zero-phase
# forward-backward filtering (filtfilt) on a noisy test signal.
t = np.linspace(-1, 1, 201)
x = (np.sin(2*np.pi*0.75*t*(1-t) + 2.1) +
     0.1*np.sin(2*np.pi*1.25*t + 1) +
     0.18*np.cos(2*np.pi*3.85*t))
# Seeded so the figure is identical on every re-run.
rng = np.random.RandomState(0)
xn = x + rng.randn(len(t)) * 0.08

b, a = signal.butter(3, 0.05)

# Start each pass in the steady state for its first sample to avoid a start-up transient.
zi = signal.lfilter_zi(b, a)
z, _ = signal.lfilter(b, a, xn, zi=zi*xn[0])
# Filtering the output again squares the filter's frequency response.
z2, _ = signal.lfilter(b, a, z, zi=zi*z[0])
# Forward-backward filtering: no phase lag.
y = signal.filtfilt(b, a, xn)

# Own axes name so the notebook's ax_online_ripple is not overwritten.
fig2 = plt.figure()
ax_demo = fig2.add_subplot(111)
ax_demo.plot(t, xn, 'b', alpha=0.75)
ax_demo.plot(t, z, 'r--', t, z2, 'r', t, y, 'k')
ax_demo.legend(('noisy signal', 'lfilter, once', 'lfilter, twice',
                'filtfilt'), loc='best')
ax_demo.grid(True)