In [1]:
from bqplot import pyplot as plt
import bqplot as bq
import ipywidgets as widgets
import math
import matplotlib
import numpy as np
from numpy import matlib
import mat73
from scipy import signal
from saclabel import tools

# ---- session state ----
iTrial    = 0      # index of the trial currently displayed
save_flag = False  # presumably marks whether the current trial was saved -- TODO confirm

# make sure the csv file exists in the /results folder (created if missing)
tools.check_results_file()

# ---- load the eye traces (trials by timepoints) ----
#x_matrix,y_matrix  = tools.read_eye_data_mat() # these are trials by timepoints
x_matrix = tools.read_eye_data("data/X_train.csv")
y_matrix = tools.read_eye_data("data/Y_train.csv")

# number of samples in the first trial
trialtime = len(x_matrix[iTrial])

# containers for predicted saccade onsets / offsets (start out empty)
onsets, offsets = np.array([]), np.array([])

# ---- filter design ----
Fs = 1000          # sampling rate handed to the notch filters
w  = 50 / (Fs / 2) # 50 Hz cutoff normalised by the Nyquist frequency

# 4th-order low-pass butterworth
b_but, a_but = signal.butter(4, w, 'low')

# notch filters at the line frequency and its next two multiples, all with Q = 5
w_notch1, w_notch2, w_notch3 = 50, 100, 150
(b_notch1, a_notch1), (b_notch2, a_notch2), (b_notch3, a_notch3) = (
    signal.iirnotch(notch_freq, 5, Fs)
    for notch_freq in (w_notch1, w_notch2, w_notch3)
)

## functions are defined here ##
def new_trial(*ignore):
    """Show the trial at index ``iTrial`` in the overview plot and reset labels.

    Positional arguments are accepted and ignored so the function can be used
    directly as a widget callback.

    Rebinds the module globals ``trialtime`` (length of the trial) and
    ``sacvec`` (fresh all-zero integer label vector of 1000 entries), clears
    the binary label plot and refreshes the "current trial" text box.
    """
    global trialtime, sacvec

    # the x trace of the trial determines the trial length
    samples_x = x_matrix[iTrial]
    trialtime = len(samples_x)

    # stretch the index selector and the patch scale to the new length
    intsel.scale.max = trialtime
    patch.scales['x'].max = trialtime

    # hand the traces to the overview lines (y values first, then x values)
    Line_eye_x.y = samples_x  # - np.mean(x_matrix[iTrial])
    Line_eye_y.y = y_matrix[iTrial]  # - np.mean(y_matrix[iTrial])
    Line_eye_x.x = np.arange(trialtime)
    Line_eye_y.x = np.arange(trialtime)
    x_ax.scale.max = trialtime

    # start from an empty binary saccade vector and an empty label plot
    sacvec = np.zeros(1000, dtype=int)
    binary_plot.y = []

    # show which trial is on screen
    current_trial_text.value = str(iTrial)

def update_patch(*ignore):
    """Move the 1000-sample zoom window to the position picked in ``intsel``.

    Positional arguments are accepted and ignored so the function can be used
    as a trait observer or called directly.

    Rebinds the module globals ``patch_begin`` (selected sample index,
    truncated to an integer and never negative) and ``patch_end``
    (``patch_begin + 1000``), redraws the shaded patch on the overview plot,
    loads the matching slice of the current trial into the zoomed lines and
    aligns the zoom axis and the saccade selector with the window.

    Does nothing when there is no (or an empty) selection.
    """
    global patch_begin, patch_end

    selection = intsel.selected
    if selection is None:
        return
    # the selector may hand back a scalar or a one-element array; flatten so
    # both work without relying on implicit array -> int conversion
    picked = np.asarray(selection, dtype=float).reshape(-1)
    if picked.size == 0:
        return

    # truncate to an integer sample index; a negative start would wrap the slice
    patch_begin = max(0, int(picked[0]))
    # the zoom window always spans 1000 samples
    patch_end = patch_begin + 1000

    # shaded rectangle marking the window on the overview plot
    patch.x = [patch_begin, patch_begin, patch_end, patch_end]
    patch.y = [-1, 1, 1, -1]

    window_x = x_matrix[iTrial][patch_begin:patch_end]
    window_y = y_matrix[iTrial][patch_begin:patch_end]

    # one x coordinate per sample actually present, at integer sample indices
    # (the slice is shorter than 1000 when the window runs past the trial end)
    Line_eye_x_zoomed.x = np.arange(patch_begin, patch_begin + len(window_x))
    Line_eye_x_zoomed.y = window_x
    Line_eye_y_zoomed.x = np.arange(patch_begin, patch_begin + len(window_y))
    Line_eye_y_zoomed.y = window_y

    # update axes
    x_ax_zoom.scale.min    = patch_begin
    x_ax_zoom.scale.max    = patch_end
    sac_selector.scale.min = patch_begin
    sac_selector.scale.max = patch_end

def save_new_sac(*ignore):
    """Mark the interval brushed in ``sac_selector`` as a saccade.

    Positional arguments are accepted and ignored so the function can be used
    as a button callback.

    The brushed interval is rounded to integer samples and converted to
    indices relative to ``patch_begin``; ``sacvec`` is set to 1 between onset
    (inclusive) and offset (exclusive) and the binary label plot is refreshed.
    Rebinds the module globals ``newsac_onset`` and ``newsac_offset``.

    Does nothing unless the selector holds exactly two values, so a click
    without a brushed interval neither fails nor re-applies an old saccade.
    """
    global sacvec, newsac_onset, newsac_offset

    selection = sac_selector.selected
    if selection is None or len(selection) != 2:
        return

    # indices relative to the start of the zoom window; clamp at 0 because a
    # negative index would wrap around to the end of sacvec
    newsac_onset  = max(0, int(np.round(selection[0])) - patch_begin)
    newsac_offset = max(0, int(np.round(selection[1])) - patch_begin)

    # make the binary sacvec ones between saccade onset and offset
    sacvec[newsac_onset:newsac_offset] = 1

    # sacvec is modified in place; the plot data is emptied first, presumably
    # so that re-assigning the same array is registered as a change
    binary_plot.y = []
    binary_plot.y = sacvec

def clear_sacs(*ignore):
    """Discard all saccade labels of the current zoom window.

    Positional arguments are accepted and ignored so the function can be used
    as a button callback.

    Rebinds the module global ``sacvec`` to a fresh all-zero vector of 1000
    entries and empties the binary label plot.
    """
    global sacvec
    # integer dtype, matching the vector created in new_trial, so the labels
    # keep the same type whether or not they were cleared in between
    sacvec        = np.zeros(1000, dtype=int)
    binary_plot.y = []

def lp_filter_trace(*ignore):
    """Low-pass filter both zoomed traces in place.

    Positional arguments are accepted and ignored so the function can be used
    as a button callback. Each trace is converted to float64 and run forwards
    and backwards (zero phase) through the butterworth filter
    ``b_but`` / ``a_but``; the y trace is processed first, then the x trace.
    """
    for trace in (Line_eye_y_zoomed, Line_eye_x_zoomed):
        samples = np.asarray(trace.y, dtype="float64")
        trace.y = signal.filtfilt(b_but, a_but, samples)

    
def show_raw_trace(*ignore):
    """Restore the unfiltered data of the current zoom window.

    Positional arguments are accepted and ignored so the function can be used
    as a button callback. Reloads the slice ``patch_begin:patch_end`` of the
    current trial into both zoomed lines, undoing any filtering applied to
    the displayed traces.

    Does nothing when no zoom window has been chosen yet (``patch_begin`` /
    ``patch_end`` do not exist until a window is selected).
    """
    if 'patch_begin' not in globals() or 'patch_end' not in globals():
        return

    raw_x = x_matrix[iTrial][patch_begin:patch_end]
    raw_y = y_matrix[iTrial][patch_begin:patch_end]

    # one x coordinate per sample actually present, at integer sample indices
    # (the slice is shorter than the window when it runs past the trial end)
    Line_eye_x_zoomed.x = np.arange(patch_begin, patch_begin + len(raw_x))
    Line_eye_x_zoomed.y = raw_x
    Line_eye_y_zoomed.x = np.arange(patch_begin, patch_begin + len(raw_y))
    Line_eye_y_zoomed.y = raw_y

def notch_filter_trace(*ignore):
    """Remove line noise from both zoomed traces in place.

    Positional arguments are accepted and ignored so the function can be used
    as a button callback. Each trace is passed forwards and backwards (zero
    phase) through the three notch filters in turn: the base frequency
    (``b_notch1`` / ``a_notch1``), then the first and second harmonic.

    The samples are converted to float64, the same precision lp_filter_trace
    uses; half precision would quantise the signal before every filter pass.
    """
    notch_filters = (
        (b_notch1, a_notch1),  # w_0 base freq
        (b_notch2, a_notch2),  # first harmonic
        (b_notch3, a_notch3),  # second harmonic
    )
    for trace in (Line_eye_y_zoomed, Line_eye_x_zoomed):
        samples = np.asarray(trace.y, dtype="float64")
        for b_coef, a_coef in notch_filters:
            samples = signal.filtfilt(b_coef, a_coef, samples)
        # assign once, after the whole cascade has been applied
        trace.y = samples

def save_trial(*ignore):
    """Append the labels and raw data of the current zoom window to the result files.

    Positional arguments are accepted and ignored so the function can be used
    as a button callback. Three rows are appended, in this order: ``sacvec``
    to binary_labels.csv, then the x and y samples of the window
    ``patch_begin:patch_end`` to eye_x.csv and eye_y.csv.
    """
    tools.append_list_as_row('./results/binary_labels.csv', sacvec)

    # always store the original samples, since the plotted lines may be filtered
    raw_sources = (
        ('./results/eye_x.csv', x_matrix),
        ('./results/eye_y.csv', y_matrix),
    )
    for csv_path, matrix in raw_sources:
        tools.append_list_as_row(csv_path, matrix[iTrial][patch_begin:patch_end])

def next_trial(obj):
    """Advance to the next trial and redraw both plots.

    Args:
        obj: the widget that triggered the callback (unused).

    Increments the module global ``iTrial``, resets ``save_flag`` (the flag
    initialised at module level) and refreshes the overview plot and the zoom
    window. Does nothing when the last trial is already shown, so the index
    can never run past the end of the data.
    """
    global iTrial, save_flag

    # stay on the last trial instead of indexing past the end of x_matrix
    if iTrial + 1 >= len(x_matrix):
        return

    iTrial += 1
    save_flag = False
    # new_trial recomputes trialtime for the new index
    new_trial()
    update_patch()
 

def previous_trial(obj):
    """Go back to the previous trial and redraw both plots.

    Args:
        obj: the widget that triggered the callback (unused).

    Decrements the module global ``iTrial`` and refreshes the overview plot
    and the zoom window. Does nothing when the first trial is already shown:
    a negative index would not fail but silently wrap around to the trials
    at the end of the data.
    """
    global iTrial

    if iTrial <= 0:
        return

    iTrial -= 1
    new_trial()
    update_patch()

# ---- Buttons ----
# One button per action; the handlers are attached in the next section.
next_trial_button     = widgets.Button(description='Next trial',              disabled=False,button_style='') 
previous_trial_button = widgets.Button(description='Previous trial',          disabled=False,button_style='') 
save_new_sac_button   = widgets.Button(description='New saccade',             disabled=False,button_style='') 
clearsac_button       = widgets.Button(description='Clear saccades',          disabled=False,button_style='danger') 
lp_filter_button      = widgets.Button(description='Low pass filter',         disabled=False,button_style='info') 
notch_filter_button   = widgets.Button(description='Line noise notch filter', disabled=False,button_style='info') 
save_trial_button     = widgets.Button(description='Save trial',              disabled=False,button_style='success') 
show_raw_trace_button = widgets.Button(description='Raw trace',               disabled=False,button_style='')

# ---- Button callbacks ----
# Each button calls the function of the matching name defined above.
next_trial_button.on_click(next_trial)
previous_trial_button.on_click(previous_trial)
save_new_sac_button.on_click(save_new_sac)
clearsac_button.on_click(clear_sacs)
lp_filter_button.on_click(lp_filter_trace)
notch_filter_button.on_click(notch_filter_trace)
save_trial_button.on_click(save_trial)
show_raw_trace_button.on_click(show_raw_trace)


# ---- Selectors ----
# Index selector on the overview plot; it picks where the zoom window starts.
intsel = bq.interacts.IndexSelector(scale=bq.LinearScale(min=1,max=trialtime),color='white')
# No `names` filter is given here, so update_patch is registered without
# restricting it to a particular trait of the selector.
intsel.observe(update_patch)

# Interval brush on the zoomed plot; save_new_sac reads its `selected` value.
sac_selector  = bq.interacts.BrushIntervalSelector(scale=bq.LinearScale(),color='#F5CBA7')

# ---- Plotting ----
# NOTE(review): animation_time is not used anywhere in the visible code; the
# figures below pass their own animation_duration values.
animation_time = 150 # in ms

# Axes for the overview plot (x_ax, y_ax) and for the zoomed plot (*_zoom).
y_ax        = bq.Axis(label="x/y position",scale=bq.LinearScale(),orientation="vertical")
x_ax        = bq.Axis(label="time (ms)",scale=bq.LinearScale(min=1),orientation="horizontal")
y_ax_zoom   = bq.Axis(label="x/y position",scale=bq.LinearScale(),orientation="vertical")
x_ax_zoom   = bq.Axis(label="time (ms)",scale=bq.LinearScale(),orientation="horizontal")

# Line marks: full-trial x/y traces, the zoomed x/y traces (1000 points) and
# the binary saccade labels. Every mark is created with its own LinearScale
# objects, which are separate instances from the scales of the axes above.
Line_eye_x         = plt.plot(x=np.arange(trialtime),scales={'x':bq.LinearScale(),'y':bq.LinearScale()},colors = ['#229954'])
Line_eye_y         = plt.plot(x=np.arange(trialtime),scales={'x':bq.LinearScale(),'y':bq.LinearScale()},colors = ['#2E86C1'])
Line_eye_x_zoomed  = plt.plot(x=np.arange(1000),scales={'x':bq.LinearScale(),'y':bq.LinearScale()},colors = ['#229954'])
Line_eye_y_zoomed  = plt.plot(x=np.arange(1000),scales={'x':bq.LinearScale(),'y':bq.LinearScale()},colors = ['#2E86C1'])
binary_plot        = plt.plot(x=np.arange(1000),scales={'x':bq.LinearScale(),'y':bq.LinearScale()},colors = ['#F7DC6F'])


# Filled, closed, semi-transparent shape drawn on the overview plot;
# update_patch sets its corner coordinates to mark the zoom window.
patch = plt.plot(scales={'x':bq.LinearScale(min=1,max=trialtime),'y':bq.LinearScale()},close_path=True,
                 stroke_width=0,fill='inside',opacities=[0.25],colors=['#839192'])

# Overview figure: both full-trial traces plus the patch, driven by intsel.
eyetrace_fig = bq.Figure(animation_duration=50,layout=widgets.Layout(flex='1 1 60%', width='auto'),axes=[x_ax,y_ax],
    marks=[Line_eye_x,Line_eye_y,patch],fig_margin=dict(top=25,bottom=50,left=50,right=25),interaction=intsel,title="Original trial eye trace")


# Zoomed figure: the two zoomed traces plus the binary labels, driven by sac_selector.
interval_fig = bq.Figure(animation_duration=300,layout=widgets.Layout(flex='1 1 40%', width='auto'),axes=[x_ax_zoom,y_ax_zoom],
    marks=[Line_eye_x_zoomed,Line_eye_y_zoomed,binary_plot],fig_margin=dict(top=25,bottom=50,left=50,right=25),interaction=sac_selector,title="1s interval trace")

# Text box showing the index of the current trial; new_trial updates its value.
current_trial_text = widgets.Text(value="{}".format(iTrial),description='Current trial:',
                                  disabled=False,style = {'description_width': 'initial'},layout=widgets.Layout( width='auto'))

# Wrap each of the two plots in its own flex row box (they are stacked
# vertically in `app` below, with the button row in between).
box_layout  = widgets.Layout(display='flex',flex_flow='row',align_items='stretch',width='100%',border='solid 2px white',padding="10px")
plot_box1   = widgets.HBox(children=[eyetrace_fig],layout=box_layout)
plot_box2   = widgets.HBox(children=[interval_fig],layout=box_layout)

# Single row holding all buttons and the current-trial text box.
button_box = widgets.HBox(children=[previous_trial_button, next_trial_button, save_new_sac_button, lp_filter_button,
                                    notch_filter_button,show_raw_trace_button, clearsac_button, save_trial_button,
                                   current_trial_text],layout=widgets.Layout(display='flex', flex_flow='row', align_items='flex-start', width='100%', justify_content='space-around', border='solid 2px white', padding="20px"))


# NOTE(review): both_buttons and box2 are built but not added to `app` in the
# visible code.
both_buttons = widgets.HBox(children=[button_box])
box2         = widgets.HBox(children=[save_new_sac_button,clearsac_button,lp_filter_button],layout=widgets.Layout(display='inline-flex',width='96%',justify_content='flex-end'))
# Final layout, top to bottom: overview plot, button row, zoomed plot.
app          = widgets.VBox(children=[plot_box1,button_box ,plot_box2],layout=widgets.Layout(border="solid 2px gray",width="1500px"))

# Run the app: load the first trial, then leave `app` as the last expression
# of the cell so the notebook displays it.
new_trial()
app


VBox(children=(HBox(children=(Figure(animation_duration=50, axes=[Axis(label='time (ms)', scale=LinearScale(ma…

In [6]:
# Inspect the values currently held by the zoomed x trace.
# NOTE(review): the output recorded below shows quoted entries, i.e. the
# samples appear to be stored as strings rather than floats -- confirm against
# tools.read_eye_data.
print(Line_eye_x_zoomed.y)

['0.7500521601674208' '0.7486723353723548' '0.7467609360629819'
 '0.7430122910414472' '0.7397015859208338' '0.7409474233460038'
 '0.7500915503300956' '0.7637974752857417' '0.7751258047740937'
 '0.7807234036355428' '0.7796984821915895' '0.774959283462799'
 '0.7714949888148581' '0.7721666788838917' '0.7756809756809915'
 '0.7777708701280428' '0.7764386683787876' '0.7724217273879344'
 '0.7668904354857133' '0.7627644475147211' '0.7603990056507381'
 '0.7563954065856784' '0.7497022909046631' '0.7431439280189629'
 '0.7414845646885926' '0.7470262099123257' '0.7590666669689643'
 '0.7734272818472778' '0.7807261462832997' '0.7767998519398578'
 '0.7689493986547475' '0.7613820432603013' '0.757781465585861'
 '0.7611800101700013' '0.7692419695422217' '0.7768248064092589'
 '0.7792158764848445' '0.7755288937303284' '0.7677145273859569'
 '0.7582639766416772' '0.7506281049637042' '0.7457101010883314'
 '0.7423560882136453' '0.7396491469657356' '0.7390309672690396'
 '0.7425427872676842' '0.7490435226647207'