In [2]:
import datetime
import glob
import os
import re
import sys

# import cv2
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.io.wavfile
import traitlets

from IPython.display import display, HTML, Audio, Pretty
from ipywidgets import interact, interactive
from soundsig.sound import spectrogram, plot_spectrogram
from collections import deque

from pecking_analysis.peck_data import (
    load_pecking_days,
    get_dates
)

# Show floats in DataFrames with thousands separators and four decimal places.
pd.set_option("display.float_format", "{:,.4f}".format)

In [3]:
import matplotlib.pyplot as plt

# Tell the program where to find the birds' pecking-test data folders and the stimulus folders

In [9]:
# Root folder holding one sub-folder per bird, and the local copy of the stimulus
# library. The defaults are the original lab-machine paths; set the
# PECKING_DATADIR / PECKING_STIMULI_DIR environment variables to run elsewhere.
DATADIR = os.environ.get("PECKING_DATADIR", "/home/fet/data/")
STIMULI_DIR = os.environ.get("PECKING_STIMULI_DIR", "/home/fet/stimuli_chubby/")

In [10]:
# Bird folders are named with six letters, four digits and an optional M/F
# suffix (e.g. "XXXOra0039F"). List them most recently modified first.
BIRDS = [
    name
    for name in (
        os.path.basename(path)
        for path in sorted(glob.glob(os.path.join(DATADIR, "*")), key=os.path.getmtime, reverse=True)
    )
    if re.search(r"^[a-zA-Z]{6}[0-9]{4}[MF]?$", name)
]

In [11]:
## Helper functions

In [14]:
def cached_load(bird, date):
    """Return the ``(blocks, stim_blocks)`` pair for one bird and day, memoized.

    On the first request for a ``(bird, date)`` key, the day folder
    ``DATADIR/<bird>/<DDMMYY>`` is loaded with ``load_pecking_days`` and passed
    ``("Playback",)`` (presumably a session-type filter). Later calls with the
    same key return the stored tuple without touching disk.
    """
    key = (bird, date)
    cache = cached_load._cache
    if key in cache:
        return cache[key]

    day_dir = os.path.join(DATADIR, bird, date.strftime("%d%m%y"))
    blocks, stim_blocks = load_pecking_days(day_dir, ("Playback",))
    cache[key] = (blocks, stim_blocks)
    return cache[key]

# Process-wide memo table keyed by (bird, date).
cached_load._cache = {}

def get_new_location(datadir, stimulus_path, stimuli_dir=None):
    """Swap out the stimulus directory of a stimulus path with a local one.

    Everything after the first ``"stimuli/"`` in ``stimulus_path`` is kept and
    re-rooted under the local stimulus folder.

    Args:
        datadir: Unused; kept so existing callers keep working.
        stimulus_path: Stimulus path as recorded in the pecking data.
        stimuli_dir: Local stimulus root. Defaults to the module-level
            ``STIMULI_DIR`` when ``None``.

    Returns:
        The stimulus path relocated under the local stimulus root.

    Raises:
        ValueError: If ``stimulus_path`` has no ``"stimuli/"`` component.
    """
    # partition keeps the whole tail. split(...)[1] would drop everything after
    # a second "stimuli/" occurrence, including the file name.
    _, marker, rest = stimulus_path.partition("stimuli/")
    if not marker:
        raise ValueError(
            "Cannot relocate stimulus path without a 'stimuli/' component: {!r}".format(stimulus_path)
        )
    if stimuli_dir is None:
        stimuli_dir = STIMULI_DIR
    return os.path.join(stimuli_dir, rest)

# Rolling window of the 10 most recent alignment offsets (used only by the
# disabled template-matching alignment).
offsets = deque(maxlen=10)

def view_pecking_data(date, bird, trial):
    """Render spectrograms and audio players for one playback trial.

    Used as the ``interactive_output`` callback for the bird, date, and trial
    widgets. It reads and writes the module-level widgets ``date_picker``,
    ``slider``, ``status_field`` and ``curr_stim``.

    Args:
        date: Selected session date (compared with ``block.date``). ``None``
            means nothing is selected and the call is a no-op.
        bird: Bird folder name under ``DATADIR``.
        trial: Row position of the trial within the day's trial table.
    """
    if date is None:
        return

    # Update the dates dropdown with the dates found for this bird (reversed order).
    valid_dates = list(reversed(get_dates(os.path.join(DATADIR, bird))))
    if not valid_dates:
        status_field.value = "No data found for {}".format(bird)
        return
    date_picker.options = valid_dates
    if date not in valid_dates:
        date_picker.value = valid_dates[0]
        date = valid_dates[0]

    blocks, stim_blocks = cached_load(bird, date)
    if not blocks:
        status_field.value = "Data for {} does not exist".format(date)
        return
    status_field.value = "Loading {}".format(date)

    found = False
    for block, _stims in zip(blocks, stim_blocks):
        if block.date != date:
            continue
        found = True

        # Positional 0..n-1 index so `trial` addresses rows directly. Built on a
        # copy so the cached block's data is not mutated.
        stimuli = block.data[["Trial", "Bird Name", "Stimulus"]].reset_index(drop=True)
        if stimuli.empty:
            # slider.max = -1 would be below slider.min, so bail out here.
            status_field.value = "No trials recorded for {}".format(date)
            continue

        slider.max = len(stimuli) - 1
        # If `trial` is past the new max, ipywidgets clamps slider.value, which
        # should re-trigger this callback with a valid trial. Skip the stale
        # request instead of raising IndexError.
        if trial >= len(stimuli):
            continue

        _render_playback_trial(bird, block, stimuli.iloc[trial])

    if not found:
        status_field.value = "No playback block found for {}".format(date)


def _render_playback_trial(bird, block, row):
    """Plot recorded vs. stimulus spectrograms for one trial row and show audio players."""
    curr_stim.value = "Trial {Trial}, Vocalizer {Bird Name}".format(**row)
    date_dir = os.path.join(DATADIR, bird, block.date.strftime("%d%m%y"))
    session = os.path.basename(os.path.splitext(block.annotations["filename"])[0])

    t_rec, f_rec, spec_rec, widget_rec = get_playback_trial(date_dir, session, row["Trial"])
    t_stim, f_stim, spec_stim, widget_stim = get_stimulus(get_new_location(DATADIR, row["Stimulus"]))
    # Small positive floor, presumably so log-scaled plotting never sees zeros.
    spec_rec += 1e-12
    spec_stim += 1e-12

    # Template-matching alignment (formerly cv2.matchTemplate on the 2-5 kHz band)
    # is disabled, so the recording is plotted without a time shift.

    fig = plt.figure(figsize=(10, 5))
    ax_rec = fig.add_axes([0, 0, 1, 0.7])
    ax_stim = fig.add_axes([0, 0.7, 1, 0.3])

    plot_spectrogram(t_rec, f_rec, spec_rec, ax=ax_rec, colorbar=False)
    ax_rec.text(0.01, 0.99, 'Recorded', fontsize=14, horizontalalignment='left', verticalalignment='top', transform=ax_rec.transAxes)

    plot_spectrogram(t_stim, f_stim, spec_stim, ax=ax_stim, colorbar=False)
    ax_stim.text(0.01, 0.99, 'Stimulus: {}'.format(row["Stimulus"]), fontsize=14, horizontalalignment='left', verticalalignment='top', transform=ax_stim.transAxes)

    for ax in (ax_rec, ax_stim):
        ax.set_xlim(-1, 7)
        ax.set_ylim(0, 10000)

    status_field.value = ""
    plt.show()
    plt.close(fig)

    display(widget_stim)
    display(widget_rec)
            
            
def get_playback_trial(datedir, session, trial):
    """Load one recorded playback trial and compute its spectrogram.

    Args:
        datedir: The bird's day folder.
        session: Session name. The recording is read from
            ``<datedir>/audio_recordings/<session>/trial<trial>.wav``.
        trial: Trial number used in the recording's file name.

    Returns:
        ``(t_spec, f_spec, spec, audio_widget)``, exactly as ``get_stimulus``
        returns for that wav file.
    """
    recording_file = os.path.join(datedir, "audio_recordings", session, "trial{}.wav".format(trial))
    # Recordings and stimuli share one wav -> spectrogram -> Audio pipeline, so
    # the spectrogram settings cannot drift apart between the two plots.
    return get_stimulus(recording_file)


def get_stimulus(stim_path):
    """Read a wav file and return its spectrogram together with an audio player.

    Returns:
        ``(t_spec, f_spec, spec, audio_widget)``: the spectrogram's time and
        frequency axes and values (computed with ``cmplx=False``), plus an IPython
        ``Audio`` widget for ``stim_path``.
    """
    sample_rate, samples = scipy.io.wavfile.read(stim_path)
    # 500 and 100 are presumably soundsig's spectrogram sample rate and
    # frequency spacing (Hz) -- confirm against soundsig.sound.spectrogram.
    times, freqs, values, _rms = spectrogram(samples, sample_rate, 500, 100, cmplx=False)
    player = Audio(stim_path)
    return times, freqs, values, player

def next_page(x):
    """Button callback: advance the trial slider by one (``x`` is the clicked button)."""
    slider.value = slider.value + 1
    
def prev_page(x):
    """Button callback: move the trial slider back by one (``x`` is the clicked button)."""
    slider.value = slider.value - 1
    
def first_page(x):
    """Button callback: jump the trial slider to its lowest value."""
    lowest = slider.min
    slider.value = lowest
    
def last_page(x):
    """Button callback: jump the trial slider to its highest value."""
    highest = slider.max
    slider.value = highest

# Run this cell to view preference test trials

In [15]:
# Fail early with a clear message instead of an IndexError on BIRDS[0] below.
if not BIRDS:
    raise RuntimeError("No bird folders found in {}".format(DATADIR))

# Trial navigation buttons. interactive_output requires ipywidgets >= 7, whose
# Button.icon takes FontAwesome names WITHOUT the "fa-" prefix.
first_page_widget = widgets.Button(description='First')
last_page_widget = widgets.Button(description="Last")
next_page_widget = widgets.Button(description='Next Stim', icon='arrow-right')
previous_page_widget = widgets.Button(description='Previous Stim', icon='arrow-left')
first_page_widget.on_click(first_page)
last_page_widget.on_click(last_page)
next_page_widget.on_click(next_page)
previous_page_widget.on_click(prev_page)

curr_stim = widgets.Label(value="")
status_field = widgets.Label(value="")

# Placeholder options (last 60 days). view_pecking_data replaces them with the
# dates actually found for the selected bird.
date_picker = widgets.Dropdown(
    options=[datetime.date.today() - datetime.timedelta(days=x) for x in range(60)],
    value=datetime.date.today() - datetime.timedelta(days=1),
    description="Date",
    disabled=False,
)

bird_picker = widgets.Dropdown(
    options=BIRDS,
    value=BIRDS[0],
    description="Bird",
    disabled=False,
)

# max is provisional; view_pecking_data sets it to the day's trial count - 1.
slider = widgets.IntSlider(
    value=0,
    min=0,
    max=100,
    step=1,
    description='Trial:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)

# Re-render whenever the date, bird, or trial changes.
out_plot = widgets.interactive_output(
    view_pecking_data,
    {
        "date": date_picker,
        "bird": bird_picker,
        "trial": slider,
    }
)

# Last expression: displays the assembled viewer.
widgets.VBox([
    status_field,
    widgets.HBox([
        widgets.VBox([
            bird_picker,
            date_picker,
        ]),
    ]),
    widgets.VBox([
        widgets.HBox([
            first_page_widget,
            previous_page_widget,
            slider,
            next_page_widget,
            last_page_widget
        ])
    ]),
    out_plot,
    curr_stim,
])