In [14]:
import datetime
import glob
import os
import re
import sys
from contextlib import contextmanager

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display, HTML
from ipywidgets import interact, interactive

from pecking_analysis.peck_data import (
    load_pecking_days,
    get_labels_by_combining_columns,
    plot_data,
    peck_data,
    color_by_reward
)

# Render every float in displayed DataFrames with thousands separators and 4 decimals.
pd.set_option("display.float_format", '{:,.4f}'.format)

@contextmanager
def block_print():
    """Temporarily silence anything written to ``sys.stdout`` inside the ``with`` body.

    ``sys.stdout`` is pointed at ``os.devnull`` for the duration of the block. On exit
    it is restored to whatever stream was active *before* entering. Restoration also
    happens if the body raises, and the devnull handle is closed.

    Restoring the previous stream rather than ``sys.__stdout__`` matters in Jupyter.
    There, ``sys.stdout`` is normally replaced by the kernel's output stream.
    Resetting it to ``sys.__stdout__`` would route all later prints to the kernel
    process's terminal instead of the notebook.
    """
    previous_stdout = sys.stdout
    with open(os.devnull, 'w') as devnull:
        sys.stdout = devnull
        try:
            yield
        finally:
            # Always undo the redirect, even when the body raises.
            sys.stdout = previous_stdout

## Fill in the data directory (it should contain one folder per subject)

In [15]:
# Root directory holding one sub-folder per subject (bird).
# Set the PECKING_DATADIR environment variable to use another location
# without editing the notebook; otherwise the default path below is used.
DATADIR = os.environ.get("PECKING_DATADIR", "/auto/tdrive/billewood/pecking test data")

In [16]:
# Subject folders in DATADIR, most recently modified first.
# A subject name is six letters, four digits and an optional M/F suffix.
# Only directories count, because each subject is later opened as
# DATADIR/<bird>/<date>; a stray file with a matching name would otherwise
# be offered as a subject and fail to load.
BIRDS = [
    os.path.basename(path)
    for path in sorted(glob.glob(os.path.join(DATADIR, "*")), key=os.path.getmtime, reverse=True)
    if os.path.isdir(path) and re.search(r"^[a-zA-Z]{6}[0-9]{4}[MF]?$", os.path.basename(path))
]

## Run this cell to launch the interactive viewer

In [18]:
def cached_load(bird, date):
    """Return ``(blocks, stim_blocks)`` for one bird on one day, loading it at most once.

    The day's folder is ``DATADIR/<bird>/<DDMMYY>``, and it is read with
    ``load_pecking_days``. Results are memoised in ``cached_load._cache`` under the key
    ``(bird, date)``. The cached objects are returned as-is, not copied, so any
    caller that mutates them changes what later calls receive.
    """
    key = (bird, date)
    cache = cached_load._cache
    if key in cache:
        return cache[key]

    day_dir = os.path.join(DATADIR, bird, date.strftime("%d%m%y"))
    blocks, stim_blocks = load_pecking_days(day_dir)
    cache[key] = (blocks, stim_blocks)
    return cache[key]


cached_load._cache = {}


def _pecking_labels(block, mode):
    """Build per-trial labels for ``block`` for the given view ``mode``.

    "by stim" produces labels "<Bird Name> <Class> <Call Type>".
    "by reward" produces labels "<Class> <Call Type>".
    Any other mode raises ``ValueError`` immediately, instead of failing later on an
    unbound ``labels`` variable.
    """
    if mode == "by stim":
        return get_labels_by_combining_columns(
            block,
            ["Class", "Call Type", "Bird Name"],
            lambda x: "{} {} {}".format(x[2], x[0], x[1])
        )
    if mode == "by reward":
        return get_labels_by_combining_columns(
            block,
            ["Class", "Call Type"],
            lambda x: "{} {}".format(x[0], x[1])
        )
    raise ValueError("Unknown view mode {!r}; expected 'by stim' or 'by reward'".format(mode))


def view_pecking_data(date, bird, mode):
    """Plot every pecking block of ``bird`` whose ``block.date`` equals ``date``.

    Parameters
    ----------
    date : date or None
        Day to show; ``None`` (e.g. a cleared date picker) draws nothing.
    bird : str
        Subject folder name under ``DATADIR``.
    mode : str
        "by stim" or "by reward"; selects how trials are labelled.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported view modes and a matching block exists.
    """
    if date is None:
        return

    blocks, stim_blocks = cached_load(bird, date)

    for block, _stims in zip(blocks, stim_blocks):
        if block.date != date:
            continue

        # Renumber the rows 0..n-1 as trial indices. This mutates the block held in
        # the cached_load cache, but it always yields the same index, so redraws are safe.
        block.data.index = pd.Series(np.arange(len(block.data)))
        labels = _pecking_labels(block, mode)

        # NOTE(review): label_order sorts on the second word of each label.
        # That word is the Class in "by stim" mode but the Call Type in
        # "by reward" mode. Confirm this is intended.
        fig = plot_data(block, labels, label_order=lambda x: x.split()[1], index_by="trial", label_to_color=color_by_reward)
        plt.title("{} {}".format(bird, block.date.strftime("%d%m%y")))

        plt.show()
        plt.close(fig)


def view_stim_data(date, bird):
    """Show the stimulus table of each of ``bird``'s blocks whose ``block.date`` is ``date``.

    Each table is restricted to the "Bird Name", "Call Type", "Class" and "Trials"
    columns and rendered as HTML. Nothing is shown when ``date`` is None.
    """
    if date is None:
        return

    columns = ["Bird Name", "Call Type", "Class", "Trials"]
    day_blocks, day_stim_blocks = cached_load(bird, date)
    matching_stims = (stims for block, stims in zip(day_blocks, day_stim_blocks) if block.date == date)
    for stim_table in matching_stims:
        display(HTML(stim_table[columns].to_html()))
            
            
def view_stats(date, bird):
    """Render the ``peck_data`` summary table for ``bird``'s blocks loaded for ``date``.

    Anything ``peck_data`` writes to stdout is suppressed with ``block_print``. The
    table it returns is displayed as HTML. Nothing is shown when ``date`` is None.
    """
    if date is not None:
        day_blocks = cached_load(bird, date)[0]
        # NOTE(review): unlike the plot and stimulus views, this passes every block
        # loaded for the day without filtering on block.date. Confirm this is intended.
        with block_print():
            summary = peck_data(day_blocks)
        display(HTML(summary.to_html()))


# Without any subject folders the dropdown below has nothing to select.
# Fail with an actionable message rather than an IndexError on BIRDS[0].
if not BIRDS:
    raise RuntimeError(
        "No subject folders found in DATADIR={!r}; check that the path is correct.".format(DATADIR)
    )

# Controls: defaults to yesterday's data for the most recently modified subject.
date_picker = widgets.DatePicker(
    description='Pick a Date',
    value=datetime.date.today() - datetime.timedelta(days=1),
    disabled=False
)
bird_picker = widgets.Dropdown(
    options=BIRDS,
    value=BIRDS[0],
    description="Bird",
    disabled=False,
)
mode_picker = widgets.RadioButtons(
    options=["by stim", "by reward"],
    description='view mode:',
    disabled=False
)

# Every view is driven by the date and bird controls; only the plot also uses the mode.
shared_controls = {
    "date": date_picker,
    "bird": bird_picker,
}

out_plot = widgets.interactive_output(view_pecking_data, dict(shared_controls, mode=mode_picker))
out_stim = widgets.interactive_output(view_stim_data, dict(shared_controls))
out_stats = widgets.interactive_output(view_stats, dict(shared_controls))

# Layout: stats across the top; controls and stimulus table on the left, plot on the right.
viewer = widgets.VBox([
    out_stats,
    widgets.HBox([widgets.VBox([date_picker, bird_picker, mode_picker, out_stim]), out_plot]),
])
viewer


VBox(children=(Output(), HBox(children=(VBox(children=(DatePicker(value=datetime.date(2019, 3, 11), descriptio…