# Viewing pecking test raw data

With added comments to describe GUI elements

In [1]:
import datetime
import glob
import os
import re
import sys
from contextlib import contextmanager

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display, HTML, clear_output
from ipywidgets import interact, interactive
from scipy.stats import fisher_exact

from pecking_analysis.peck_data import (
    load_pecking_days,
    get_labels_by_combining_columns,
    plot_data,
    peck_data,
    color_by_reward,
    windows_by_reward,
    get_dates
)

pd.options.display.float_format = '{:,.4f}'.format

%load_ext autoreload
%autoreload 2

@contextmanager
def block_print():
    """Temporarily discard everything written to ``sys.stdout``.

    Inside the ``with`` block ``sys.stdout`` points at ``os.devnull``. On
    exit, whether the block finished normally or raised, the stream that was
    active on entry is put back and the devnull handle is closed.
    """
    # Remember the stream that is active right now instead of assuming
    # sys.__stdout__: stdout may already have been replaced by something else
    # (an interactive front end, a test runner, an outer redirect).
    previous_stdout = sys.stdout
    with open(os.devnull, 'w') as devnull:
        sys.stdout = devnull
        try:
            yield
        finally:
            # Always undo the redirect, even when the body raised.
            sys.stdout = previous_stdout

## Fill in the data directory containing a folder for each subject

In [2]:
DATADIR = "/auto/tdrive/billewood/pecking test data/"

In [3]:
# Names of the subject folders found directly under DATADIR, most recently
# modified first. Only names made of six letters, four digits and an optional
# trailing "M" or "F" are kept.
BIRDS = [
    folder_name
    for folder_name in map(
        os.path.basename,
        sorted(glob.glob(os.path.join(DATADIR, "*")), key=os.path.getmtime, reverse=True),
    )
    if re.search(r"^[a-zA-Z]{6}[0-9]{4}[MF]?$", folder_name)
]

In [4]:
BIRDS

['GreBla5671F',
 'GreBla7410M',
 'BroWhi0502F',
 'WhiRed9510F',
 'RedBla0907M',
 'XXXOra0037F',
 'HpiBlu6194F',
 'YelPur7906M',
 'WhiWhi2526M',
 'BluYel2571F',
 'YelRed3010F',
 'GraWhi4040F',
 'BlaGre1349M',
 'XXXHpi0038M',
 'GreBlu5039F',
 'GreBla3404M',
 'XXXRed0088M',
 'XXXOra0039F',
 'XXXBla0054F',
 'XXXBla0055M',
 'XXXBla0081M',
 'BluHpi2765M',
 'OraOra4449F',
 'YelOra8629F']

In [5]:
# Largest allowed gap, in days, between the "Date" and "Date end" selections.
# view_pecking_data resets the end-date picker when the selected span is
# negative or larger than this.
MAX_DATE_RANGE = 7

## Run this cell to view GUI

In [8]:
###
### The plot and data shown are refreshed any time a new option is selected in the drop-down menus.
### Some input choices don't actually change which data range is needed (e.g. filtering by
### vocalizers) but still refresh the display, so it's better not to reload data
### that we've loaded before.
###
def cached_load(bird, date, date_end=None):
    """Return ``(blocks, stim_blocks)`` for a subject, loading each request once.

    Results are memoised in ``cached_load._cache`` under the key
    ``(bird, date, date_end)``.

    Args:
        bird: Subject name, used as the folder name under DATADIR.
        date: Day to load, or the first day of the range.
        date_end: Last day of the range. When None, only the single day
            folder ``DATADIR/<bird>/<ddmmyy>`` is loaded.

    Returns:
        The ``(blocks, stim_blocks)`` pair produced by ``load_pecking_days``.
    """
    memo = cached_load._cache
    key = (bird, date, date_end)
    if key in memo:
        return memo[key]

    if date_end is None:
        # Single day: point the loader straight at that day's folder.
        blocks, stim_blocks = load_pecking_days(os.path.join(DATADIR, bird, date.strftime("%d%m%y")))
    else:
        # Date range: hand the loader the subject folder plus the range.
        blocks, stim_blocks = load_pecking_days(os.path.join(DATADIR, bird), date_range=(date, date_end))

    memo[key] = (blocks, stim_blocks)
    return memo[key]


# Storage for cached_load, keyed by (bird, date, date_end).
cached_load._cache = {}

fig = None


###########################################################################################
### Display functions                                                                   ###
###   These functions take inputs from dropdowns and sliders and create a display       ###
###   usually it is either a plot or a dataframe in HTML format.                        ###
###   Where the output of these functions pop up is defined at the very end of the cell ###
###########################################################################################

#
# This is the main function that is reloaded to update the plot when a drop down selection is changed.
# The arguments to this function represent the values that will be passed from all the little
# dropdown and slider widgets - they will be linked when we call widgets.interactive_output() later
#
def view_pecking_data(
        date,
        date_end,
        bird,
        mode,
        show_stims_re,
        show_stims_un,
        limit_trials,
        sig_threshold=0.05):
    """Redraw the trial plot for the current widget selections.

    Args:
        date: Selected start date. Nothing is drawn when it is None.
        date_end: Selected end date, or the string "None" for a single day.
        bird: Subject name; also the name of its folder under DATADIR.
        mode: "bystim" to label trials per stimulus, or "byreward" to label
            them by class and call type only.
        show_stims_re: Selected rewarded stimuli ("Bird Name" values).
        show_stims_un: Selected unrewarded stimuli ("Bird Name" values).
        limit_trials: Upper x-axis limit; 0 leaves the axis unrestricted.
        sig_threshold: Not used by this function.

    Raises:
        ValueError: If ``mode`` is neither "bystim" nor "byreward".
    """
    # Every time we update the plot, make sure to close the previous figure.
    global fig
    if fig is not None:
        plt.close(fig)
    if date is None:
        return

    # Stimuli to restrict the plot to; empty means no restriction.
    show_stims = show_stims_re + show_stims_un

    ###############################################################
    # Date-range bookkeeping (not needed to understand the plot)  #
    ###############################################################
    # An end date before the start date, or more than MAX_DATE_RANGE days
    # after it, is rejected: the end-date picker is reset and nothing is drawn
    # on this call.
    if date_end != "None" and (((date_end - date).days < 0) or ((date_end - date).days > MAX_DATE_RANGE)):
        if date in date_end_picker.options:
            date_end_picker.value = date
        else:
            date_end_picker.options = ["None"] + list(date_picker.options)
            date_end_picker.value = "None"
        return

    valid_dates = get_dates(os.path.join(DATADIR, bird))
    valid_dates = list(reversed(valid_dates))
    old_date = date_picker.value

    # Swap in this bird's dates when the picker is showing different ones.
    # NOTE(review): zip() stops at the shorter sequence, so a picker whose
    # options merely start with valid_dates (or the reverse) is not detected
    # as different. Left as is because the early return below presumably
    # relies on the widget assignments re-triggering this function - confirm
    # before changing.
    if np.any([d1 != d2 for d1, d2 in zip(date_picker.options, valid_dates)]):
        date_picker.options = valid_dates
        if old_date not in valid_dates:
            date = valid_dates[0]
            date_picker.value = date
            date_end_picker.options = ["None"] + valid_dates
            date_end_picker.value = date
        return
    ###############################
    # End date-range bookkeeping  #
    ###############################

    # Load the data from the chosen date range
    if date_end == "None":
        blocks, stim_blocks = cached_load(bird, date)
    else:
        blocks, stim_blocks = cached_load(bird, date, date_end)

    # Only the first loaded block is plotted.
    block = blocks[0]

    # Give the trials a plain 0..n-1 index, then remember the unfiltered data
    # so it can be put back afterwards: the block object lives in the
    # cached_load cache and is reused by the other display functions.
    block.data.index = pd.Series(np.arange(len(block.data)))
    old_data = block.data

    try:
        #########################################################################
        # Prepare the data for plotting and call the plotting function          #
        #########################################################################
        orig_len = None
        if mode == "bystim":
            if len(show_stims):
                orig_len = len(block.data)
                block.data = block.data[block.data["Bird Name"].isin(show_stims)]
            labels = get_labels_by_combining_columns(
                block,
                ["Class", "Call Type", "Bird Name"],
                lambda x: "{} {} {}".format(x[2], x[0], x[1])
            )
        elif mode == "byreward":
            labels = get_labels_by_combining_columns(
                block,
                ["Class", "Call Type"],
                lambda x: "{} {}".format(x[0], x[1])
            )
        else:
            raise ValueError("Unknown view mode: {!r}".format(mode))

        # This call actually draws the data.
        fig = plot_data(block, labels, force_len=orig_len, label_order=lambda x: x.split()[1], index_by="trial", label_to_color=color_by_reward)

        if limit_trials:
            plt.xlim(0, limit_trials)
        plt.title("{} {}".format(bird, block.date.strftime("%d%m%y")))
        plt.show()
    finally:
        # Restore the unfiltered data even if labelling or plotting failed,
        # so the cached block is never left filtered.
        block.data = old_data


#
# This function handles displaying the stimulus list (bottom left of display)
# Like view_pecking_data(), the parameters are values from drop downs / sliders that are relevant
# to this display and will be linked by widgets.interactive_output().
# It also updates the available options for selecting individual stimuli.
#
def view_stim_data(date, date_end, bird):
    """Show the stimulus table and refresh the two stimulus pickers.

    Args:
        date: Selected start date. Nothing happens when it is None.
        date_end: Selected end date, or the string "None" for a single day.
        bird: Subject name.
    """
    if date is None:
        return

    if date_end == "None":
        blocks, stim_blocks = cached_load(bird, date)
    else:
        blocks, stim_blocks = cached_load(bird, date, date_end)

    # Only the stimulus table of the first loaded block is used.
    stims = stim_blocks[0]

    # Split the stimuli into the two picker lists in a single pass. Each
    # option is a (label shown to the user, value passed on) pair.
    rewarded_options = []
    other_options = []
    for bird_name, stim_class in zip(stims["Bird Name"], stims["Class"]):
        option = ("{} ({})".format(bird_name, stim_class), bird_name)
        if stim_class == "Rewarded":
            rewarded_options.append(option)
        else:
            other_options.append(option)

    # Limit the pickers to stimuli present in the currently selected data.
    stim_picker_re.options = rewarded_options
    stim_picker_un.options = other_options

    # Display the dataframe as HTML. Where it appears is decided by the layout
    # at the end of the cell.
    display(HTML(stims[["Bird Name", "Call Type", "Class", "Trials"]].to_html()))
    
#
# This function handles displaying the entire raw dataframe of trials (at the very bottom of display)
# Again, the parameters are values from drop downs / sliders that are relevant
# to this display - they will be linked when we call widgets.interactive_output() later
#
def view_raw_dataframe(date, date_end, show_stims_re, show_stims_un, bird):
    """Show the raw trial table for the block recorded on ``date``.

    Args:
        date: Selected start date. Nothing happens when it is None.
        date_end: Selected end date, or the string "None" for a single day.
        show_stims_re: Selected rewarded stimuli ("Bird Name" values).
        show_stims_un: Selected unrewarded stimuli ("Bird Name" values).
        bird: Subject name.
    """
    # Same guard as the other display callbacks: with no date selected there
    # is nothing to load.
    if date is None:
        return

    show_stims = show_stims_re + show_stims_un

    if date_end == "None":
        blocks, stim_blocks = cached_load(bird, date)
    else:
        blocks, stim_blocks = cached_load(bird, date, date_end)

    columns = [
        "Date",
        "Trial",
        "Bird Name",
        "Class",
        "Response",
        "RT",
        "Call Type",
        "Stimulus Name",
    ]

    for block, _ in zip(blocks, stim_blocks):
        if block.date != date:
            continue

        # Filter into a local variable instead of reassigning block.data: the
        # block comes from the cached_load cache and must stay unfiltered.
        trials = block.data
        if len(show_stims):
            trials = trials[trials["Bird Name"].isin(show_stims)]

        # Display the trial data as HTML. Where it appears is decided by the
        # layout at the end of the cell.
        display(HTML(trials[columns].to_html(index=False)))
        

#
# This is the function responsible for plotting the summary stats at the top of the display
# Again, the parameters are values from drop downs / sliders that are relevant
# to this display - they will be linked when we call widgets.interactive_output() later
#
def view_stats(date, date_end, bird):
    """Show the summary table computed by ``peck_data`` for the selection.

    Args:
        date: Selected start date. Nothing happens when it is None.
        date_end: Selected end date, or the string "None" for a single day.
        bird: Subject name.
    """
    if date is None:
        return

    # Omit the end date entirely when the picker is on its "None" entry.
    load_args = (bird, date) if date_end == "None" else (bird, date, date_end)
    blocks, stim_blocks = cached_load(*load_args)

    # Anything peck_data prints is suppressed; only the table is shown.
    with block_print():
        summary = peck_data(blocks)

    display(HTML(summary.to_html()))


    
# These two functions are triggered when specific buttons are pressed
# This function resets which stimuli are selected,
# Search for view_all_button and view_all_button.on_click to see where it is linked
def view_all(*args):
    """Clear the selection in both stimulus pickers.

    Any positional arguments are accepted and ignored.
    """
    for picker in (stim_picker_re, stim_picker_un):
        picker.value = []
    
# Saves the current figure
# Look for save_button and save_button.on_click to see how it is triggered.
def save_fig(butt):
    """Save the current figure as an EPS file under ``saved_figures/``.

    The file is named ``<path_input value>_<view mode>.eps``. If that name is
    taken, ``_0``, ``_1``, ... is appended until an unused name is found, so
    existing files are never overwritten.

    Args:
        butt: The button that was clicked (unused).
    """
    # Nothing has been plotted yet, so there is nothing to save.
    if fig is None:
        return

    os.makedirs("saved_figures", exist_ok=True)

    base_name = "{}_{}".format(path_input.value, mode_picker.value)
    path = os.path.join("saved_figures", "{}.eps".format(base_name))
    i = 0
    while os.path.exists(path):
        path = os.path.join("saved_figures", "{}_{}.eps".format(base_name, i))
        i += 1

    fig.tight_layout()
    fig.savefig(path, format="eps")


##########################
### Widget definitions ###
##########################

# Start date. Initially offers today and the 59 days before it;
# view_pecking_data replaces the options with the dates returned by
# get_dates() for the selected bird.
date_picker = widgets.Dropdown(
    options=[datetime.date.today() - datetime.timedelta(days=x) for x in range(60)],
    value=datetime.date.today(),
    description="Date",
    disabled=False,
)

# Optional end date; the string "None" means "single day".
# date_picker.value is today at this point, so the range below is empty and
# "None" is the only initial option. view_pecking_data fills in real dates.
date_end_picker = widgets.Dropdown(
    options=["None"] + [datetime.date.today() - datetime.timedelta(days=x) for x in range(
        (datetime.date.today() - date_picker.value).days
    )],
    value="None",
    description="Date end",
    disabled=False,
)

# Upper x-axis limit for the plot; 0 leaves the axis unrestricted.
limit_trials_picker = widgets.BoundedIntText(
    value=0,
    min=0,
    max=200,
    step=5,
    description='Limit trials (0 for no limit):',
    disabled=False
)

# Clicking this button calls save_fig.
save_button = widgets.Button(
    description="Save Figure",
)
save_button.on_click(save_fig)

# First part of the file name used by save_fig.
path_input = widgets.Text(
    value="figure",
    placeholder='figure',
    description='Save fig to:',
    disabled=False
)

# Subject to display; starts on the first entry of BIRDS.
# NOTE(review): BIRDS[0] raises IndexError if no subject folders were found.
bird_picker = widgets.Dropdown(
    options=BIRDS,
    value=BIRDS[0],
    description="Bird",
    disabled=False,
)

# Each option is (label shown, value passed on as `mode`).
mode_picker = widgets.RadioButtons(
    options=[
        ("by stim", "bystim"),
        ("by reward", "byreward"),
    ],
    description='view mode:',
    disabled=False
)
    
# Clicking this button calls view_all, which clears both stimulus pickers.
view_all_button = widgets.Button(
    description="View all stims",
)
view_all_button.on_click(view_all)

# Picker for stimuli whose class is "Rewarded"; options are filled in by
# view_stim_data.
stim_picker_re = widgets.SelectMultiple(
    options=[],
    value=[],
#     description='Select stims\n(ctrl click)',
    disabled=False
)

# Picker for all other stimuli; options are filled in by view_stim_data.
stim_picker_un = widgets.SelectMultiple(
    options=[],
    value=[],
#     description='Select stims\n(ctrl click)',
    disabled=False
)


###############################################################################
### Linking the output functions above to their inputs (dropdowns, sliders) ###
###   The keys of the dictionary passed correspond to the argument names    ###
###   of the display function passed                                        ###
###############################################################################

# Trial plot. sig_threshold is not linked to a widget, so view_pecking_data
# uses its default for that argument.
out_plot = widgets.interactive_output(
    view_pecking_data, 
    {
        "date": date_picker,
        "date_end": date_end_picker,
        "bird": bird_picker,
        "show_stims_re": stim_picker_re,
        "show_stims_un": stim_picker_un,
        "limit_trials": limit_trials_picker,
        "mode": mode_picker
    }
)

# Stimulus table; view_stim_data also refreshes the stimulus picker options.
out_stim = widgets.interactive_output(
    view_stim_data, 
    {
        "date": date_picker,
        "date_end": date_end_picker,
        "bird": bird_picker,
    }
)

# Raw trial table, filtered by the two stimulus pickers.
raw_dataframe = widgets.interactive_output(
    view_raw_dataframe, 
    {
        "date": date_picker,
        "date_end": date_end_picker,
        "show_stims_re": stim_picker_re,
        "show_stims_un": stim_picker_un,
        "bird": bird_picker,
    }
)

# Summary statistics table.
out_stats = widgets.interactive_output(
    view_stats, 
    {
        "date": date_picker,
        "date_end": date_end_picker,
        "bird": bird_picker,
    }
)

#####################################
### Defining the layout positions ###
###   VBox: vertical layout       ###         
###   HBox: horizontal layout     ###
###   These can be nested         ###
#####################################

# Top to bottom: summary stats, then controls (left) beside the plot (right),
# then the raw trial table. This is the last expression of the cell.
widgets.VBox([
    out_stats,
    # Left column of controls and stimulus list, with the plot to its right
    widgets.HBox([widgets.VBox([
        bird_picker,
        date_picker,
        date_end_picker,
        limit_trials_picker,
        widgets.VBox([path_input, save_button]),
        mode_picker,
        view_all_button,
        widgets.VBox([stim_picker_re, stim_picker_un]),
        widgets.Label("Stimulus list:"),
        out_stim
    ]), out_plot]),
    # Bottom row: label followed by the raw trial table
    widgets.HBox([
        widgets.Label("Raw trial data: "),
        raw_dataframe
    ]),
])


VBox(children=(Output(), HBox(children=(VBox(children=(Dropdown(description='Bird', options=('GreBla5671F', 'G…