In [1]:
import datetime
import glob
import os
import re
import sys
from contextlib import contextmanager

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display, HTML, clear_output
from ipywidgets import interact, interactive
from scipy.stats import fisher_exact

from pecking_analysis.peck_data import (
    load_pecking_days,
    get_labels_by_combining_columns,
    plot_data,
    peck_data,
    color_by_reward,
    windows_by_reward,
    get_dates,
    PythonCSV
)

pd.options.display.float_format = '{:,.4f}'.format

import warnings
warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2

@contextmanager
def block_print():
    """Temporarily silence ``print`` by pointing ``sys.stdout`` at ``os.devnull``.

    The stream that was active on entry is put back on exit, even if the body
    raises. The previous stream is restored rather than ``sys.__stdout__``
    because inside a notebook kernel ``sys.stdout`` is not the interpreter's
    original stream; resetting to ``sys.__stdout__`` would divert every later
    ``print`` away from the notebook output.
    """
    previous_stdout = sys.stdout
    # The ``with`` block guarantees the devnull handle is closed afterwards.
    with open(os.devnull, 'w') as devnull:
        sys.stdout = devnull
        try:
            yield
        finally:
            sys.stdout = previous_stdout

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
# pip install 'pandas<2'

##  Fill in data directory containing folders for each subject

In [3]:
DATADIR = "/data/pecking_test/behavior/"

In [4]:
# Names of the entries found directly under DATADIR, most recently modified first.
# Only names matching the pattern are kept: six letters, then six characters that
# are each a digit or "X", then an optional trailing "M" or "F".
BIRDS = [
    os.path.basename(bird)
    for bird in sorted(glob.glob(os.path.join(DATADIR, "*")), key=os.path.getmtime, reverse=True)
    if re.search(r"^[a-zA-Z]{6}[X0-9]{6}[MF]?$", os.path.basename(bird))]

In [5]:
BIRDS

['HpiPur058093M',
 'BluYel037018F',
 'HpiRed004141M',
 'PurWhi127037F',
 'TesTes123123M',
 'PurGre124015M']

In [6]:
MAX_DATE_RANGE = 7

## Run this cell

In [7]:
def compute_odds_ratio(group, versus, *args):
    """Compare the response rates of two sets of trials with Fisher's exact test.

    A 2x2 table of (responded, did not respond) counts is built from the
    "Response" column of ``group`` (first row) and ``versus`` (second row).

    Parameters
    ----------
    group, versus : DataFrame-like objects with a "Response" column.
    *args : forwarded unchanged to ``fisher_exact`` after the table
        (the callers in this notebook pass "less" or "greater").

    Returns
    -------
    (odds_ratio, pvalue)
        ``pvalue`` comes from the raw table. ``odds_ratio`` comes from a copy
        of the table in which every empty cell has been replaced by 1.
    """
    def _response_counts(trials):
        responses = trials["Response"]
        return [len(trials[responses == True]), len(trials[responses == False])]

    table = [_response_counts(group), _response_counts(versus)]

    # The significance test uses the counts exactly as observed.
    pvalue = fisher_exact(table, *args)[1]

    # For the reported ratio, bump empty cells to 1 before recomputing.
    adjusted = [[count if count != 0 else 1 for count in row] for row in table]
    odds_ratio = fisher_exact(adjusted, *args)[0]

    return odds_ratio, pvalue
 

def set_oddsratio_yticks(ax, biggest):
    """Format ``ax`` as a log-scaled odds-ratio axis symmetric around 1.

    The y-range is set to ``[2**-biggest, 2**biggest]`` and ticks are placed at
    powers of two labelled "x1/4", "x1/2", "x1", "x2", ... The powers are
    thinned so that at most six non-negative powers are labelled. A dashed
    horizontal line is drawn at an odds ratio of 1 across the current x-range
    of ``ax``, and both axis labels are set.

    Parameters
    ----------
    ax : matplotlib Axes to format.
    biggest : largest power of two to show in either direction.
    """
    ax.set_ylim(np.power(2., -biggest), np.power(2., biggest))
    ax.set_yscale("log")

    powers = np.arange(0, biggest + 1)
    n = len(powers)
    # Keep every (n // 6 + 1)-th power so the axis does not get crowded.
    powers = powers[::n // 6 + 1]
    # Mirror the powers below 1; powers[0] (== 0) is only listed once.
    vals = np.concatenate([-powers, powers[1:]])

    ticks = np.power(2., vals)
    labels = [r"x{:d}".format(int(2 ** v)) if v >= 0 else r"x1/{:d}".format(int(2 ** -v)) for v in vals]

    ax.set_ylabel("Odds Ratio", fontsize=12)
    ax.set_xlabel("Trial", fontsize=12)
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels, fontsize=12)
    # Use the limits of the axes we were given, not whatever pyplot considers
    # the "current" axes, so the reference line always spans ``ax``.
    ax.hlines(1, *ax.get_xlim(), linestyle="--", zorder=-1)


In [8]:
def cached_load(bird, date, date_end=None):
    """Load pecking data for ``bird``, reusing results from earlier calls.

    Results are stored in ``cached_load._cache`` under the key
    ``(bird, date, date_end)``. With ``date_end`` given, the bird's folder is
    loaded with ``date_range=(date, date_end)``; otherwise only the folder for
    ``date`` (named with the "%d%m%y" format) is loaded.

    Returns the ``(blocks, stim_blocks)`` pair produced by ``load_pecking_days``.
    """
    key = (bird, date, date_end)
    cache = cached_load._cache

    if key in cache:
        return cache[key]

    if date_end is None:
        loaded = load_pecking_days(os.path.join(DATADIR, bird, date.strftime("%d%m%y")))
    else:
        loaded = load_pecking_days(os.path.join(DATADIR, bird), date_range=(date, date_end))

    blocks, stim_blocks = loaded
    cache[key] = (blocks, stim_blocks)
    return cache[key]
cached_load._cache = {}

fig = None  # Most recently drawn figure; closed on redraw and written out by save_fig.

def view_pecking_data(date, date_end, bird, mode, show_stims_re, show_stims_un, window_size, sig_threshold=0.05):
    """Draw the plot for the selected bird and date in the chosen view mode.

    This is the callback behind the main plot output. Besides plotting it also
    keeps the date dropdowns consistent with the dates available for ``bird``;
    whenever it has to change a dropdown it returns without plotting (changing
    a widget is expected to trigger this callback again).

    Parameters
    ----------
    date : datetime.date or None
        Day to display. ``None`` draws nothing.
    date_end : datetime.date or the string "None"
        Optional end of a date range. Must lie between 0 and MAX_DATE_RANGE
        days after ``date``; otherwise the end-date picker is reset.
    bird : str
        Subject folder name under DATADIR.
    mode : str
        "bystim", "byreward", "oddsratios" or "windowedoddsratios".
    show_stims_re, show_stims_un : sequences of stimulus "Bird Name" values
        Concatenated into one filter. An empty filter shows everything.
        The filter applies to the "bystim" and "windowedoddsratios" modes.
    window_size : int
        Passed as ``n`` to ``windows_by_reward`` in "windowedoddsratios" mode.
    sig_threshold : float
        p-value below which a result is marked as significant (diamond markers
        in the windowed mode, "*" above the bar in "oddsratios" mode).
    """
    global fig
    if fig is not None:
        plt.close(fig)
    if date is None:
        return

    show_stims = show_stims_re + show_stims_un

    # Reject an end date that precedes the start date or is too far after it.
    if date_end != "None" and (((date_end - date).days < 0) or ((date_end - date).days > MAX_DATE_RANGE)):
        if date in date_end_picker.options:
            date_end_picker.value = date
        else:
            # Widget options come back as a tuple, so convert before concatenating
            # with a list (list + tuple raises TypeError).
            date_end_picker.options = ["None"] + list(date_picker.options)
            date_end_picker.value = "None"
        return

    # Fix dates
    valid_dates = get_dates(os.path.join(DATADIR, bird))
    valid_dates = list(reversed(valid_dates))
    old_date = date_picker.value

    # NOTE(review): zip() stops at the shorter sequence, so a date list that is a
    # strict prefix of the current options is not detected as a change - confirm
    # whether that is intended.
    if np.any([d1 != d2 for d1, d2 in zip(date_picker.options, valid_dates)]):
        date_picker.options = valid_dates
        if old_date not in valid_dates:
            date = valid_dates[0]
            date_picker.value = date
            date_end_picker.options = ["None"] + valid_dates
            date_end_picker.value = date
        return

    date_end_picker.options = ["None"] + valid_dates

    if date_end == "None":
        blocks, stim_blocks = cached_load(bird, date)
    else:
        blocks, stim_blocks = cached_load(bird, date, date_end)

    # THE MAIN STUFF
    for block, stims in zip(blocks, stim_blocks):
        if block.date == date:

            # Positional index so the group labels returned by groupby can be
            # used with .iloc below.
            block.data.index = pd.Series(np.arange(len(block.data)))

            if mode == "windowedoddsratios":

                fig = plt.figure(figsize=(8, 4))

                grouped = block.data.groupby(["Class", "Call Type", "Bird Name"])

                biggest = 0
                for (rewarded, call_type, bird_name), group in sorted(
                        grouped.groups.items(),
                        key=lambda x: (x[0][0], x[0][2])):

                    if len(show_stims) and bird_name not in show_stims:
                        continue

                    is_rewarded = rewarded == "Rewarded"
                    color = "blue" if is_rewarded else "red"
                    alternative = "less" if is_rewarded else "greater"

                    group = block.data.iloc[group]
                    alts = block.data[block.data["Class"] != rewarded]

                    windowed, windowed_rest = windows_by_reward(group, alts, is_rewarded, n=window_size)
                    x = []
                    y = []
                    sigs = []
                    for window, rest in zip(windowed, windowed_rest):
                        x.append(np.mean(window["OverallTrial"]))
                        odds, pvalue = compute_odds_ratio(window, rest, alternative)
                        y.append(odds)
                        sigs.append(pvalue < sig_threshold) # / len(windowed))

                    if not y:
                        # No windows for this stimulus: nothing to draw, and
                        # np.max() on an empty array below would raise.
                        continue

                    if len(x) > 1:
                        plt.plot(x, y, color=color, linewidth=1)
                    else:
                        plt.scatter(x, y, s=5, marker="x", color=color)

                    biggest = max(np.max(np.round(1 + np.abs(np.log2(np.array(y))))), biggest)

                    # Mark the windows whose p-value is below the threshold.
                    for x_, y_, sig_ in zip(x, y, sigs):
                        if sig_:
                            plt.scatter([x_], [y_], marker="d", s=100, color=color)

                plt.xlim(0, np.max(block.data["OverallTrial"]))
                set_oddsratio_yticks(plt.gca(), biggest)

                plt.show()
                return

            if mode == "oddsratios":

                fig = plt.figure(figsize=(10, 4))

                grouped = block.data.groupby(["Class", "Call Type", "Bird Name"])
                results = []
                labels = []

                bar_idx = 0
                # Only the first bar of each class gets a legend entry.
                labeled_already = {
                    "Rewarded": False,
                    "Unrewarded": False,
                }
                for (rewarded, call_type, bird_name), group in sorted(grouped.groups.items(),
                                                                      key=lambda x: (x[0][0], x[0][2])):
                    color = "blue" if rewarded == "Rewarded" else "red"

                    group = block.data.iloc[group]
                    alts = block.data[block.data["Class"] != rewarded]

                    result = compute_odds_ratio(group, alts, "less" if rewarded == "Rewarded" else "greater")
                    results.append([rewarded, result])
                    labels.append("{}".format(bird_name))

                    if labeled_already[rewarded]:
                        plt.bar(bar_idx, result[0], color=color)
                    else:
                        plt.bar(bar_idx, result[0], color=color, label=rewarded)
                        labeled_already[rewarded] = True
                    bar_idx += 1

                # Use the caller's threshold here as well, so both odds-ratio
                # modes agree on what counts as significant.
                for i, v in enumerate(results):
                    plt.text(i, v[1][0], "*" if v[1][1] < sig_threshold else "", color='black', fontweight='bold',
                            horizontalalignment="center")
                plt.xticks(np.arange(len(results)) + 0.5, labels, rotation=315, fontsize=12)

                biggest = np.max(np.round(1 + np.abs(np.log2(np.array([r[1][0] for r in results])))))
                set_oddsratio_yticks(plt.gca(), biggest)
                plt.legend(loc="upper left", fontsize=12)
                plt.show()

                # Same numbers as the bars, as a table (p-values included).
                df = pd.DataFrame(
                    data=list(zip(*[
                        labels,
                        [r[1][0] for r in results],
                        [r[1][1] for r in results]
                    ])),
                    columns=["Bird", "Odds Ratio", "p-value"],
                )
                with pd.option_context('display.float_format', '{:,.4f}'.format):
                    display(HTML(df.to_html(index=False)))
                return

            # "bystim" temporarily swaps a filtered frame into the cached block;
            # try/finally guarantees the original is put back even if plotting fails.
            old_data = block.data
            try:
                orig_len = None
                if mode == "bystim":
                    if len(show_stims):
                        orig_len = len(block.data)
                        block.data = block.data[block.data["Bird Name"].isin(show_stims)]
                    labels = get_labels_by_combining_columns(
                        block,
                        ["Class", "Call Type", "Bird Name"],
                        lambda x: "{} {} {}".format(x[2], x[0], x[1])
                    )

                elif mode == "byreward":
                    labels = get_labels_by_combining_columns(
                        block,
                        ["Class", "Call Type"],
                        lambda x: "{} {}".format(x[0], x[1])
                    )

                fig = plot_data(block, labels, force_len=orig_len, label_order=lambda x: x.split()[1], index_by="trial", label_to_color=color_by_reward)

                plt.title("{} {}".format(bird, block.date.strftime("%d%m%y")))

                plt.show()
            finally:
                block.data = old_data

def view_stim_data(date, date_end, bird):
    """Show the stimulus table for ``date`` and refresh the two stimulus pickers.

    For each loaded block whose date equals ``date``, the rewarded stimuli
    become the options of ``stim_picker_re`` and all other stimuli the options
    of ``stim_picker_un`` (each shown as "<bird name> (<class>)"), and the
    stimulus table is displayed as HTML. Does nothing when ``date`` is None.
    """
    if date is None:
        return

    load_args = (bird, date) if date_end == "None" else (bird, date, date_end)
    blocks, stim_blocks = cached_load(*load_args)

    for block, stims in zip(blocks, stim_blocks):
        if block.date != date:
            continue

        # Split the stimuli into the two picker lists in a single pass.
        rewarded_options = []
        other_options = []
        for bird_name, rewarded in zip(stims["Bird Name"], stims["Class"]):
            option = ("{} ({})".format(bird_name, rewarded), bird_name)
            if rewarded == "Rewarded":
                rewarded_options.append(option)
            else:
                other_options.append(option)

        stim_picker_re.options = rewarded_options
        stim_picker_un.options = other_options

        display(HTML(stims[["Bird Name", "Call Type", "Class", "Trials"]].to_html()))

    
def view_raw_dataframe(date, date_end, show_stims_re, show_stims_un, bird):
    """Display the raw trial table for ``date``, optionally filtered by stimulus.

    Parameters
    ----------
    date : datetime.date or None
        Day to display. ``None`` displays nothing, like the other view callbacks.
    date_end : datetime.date or the string "None"
        Optional end of the date range to load.
    show_stims_re, show_stims_un : sequences of stimulus "Bird Name" values
        Concatenated into one filter. An empty filter shows all trials.
    bird : str
        Subject folder name under DATADIR.
    """
    if date is None:
        return

    show_stims = show_stims_re + show_stims_un

    if date_end == "None":
        blocks, stim_blocks = cached_load(bird, date)
    else:
        blocks, stim_blocks = cached_load(bird, date, date_end)

    columns = [
        "Date",
        "Trial",
        "Bird Name",
        "Class",
        "Response",
        "RT",
        "Call Type",
        "Stimulus Name",
    ]

    for block, stims in zip(blocks, stim_blocks):
        if block.date != date:
            continue

        # Filter into a local frame; the cached block itself is never modified,
        # so nothing needs restoring if rendering fails.
        trials = block.data
        if len(show_stims):
            trials = trials[trials["Bird Name"].isin(show_stims)]

        display(HTML(trials[columns].to_html(index=False)))
        
        
def view_stats(date, date_end, bird):
    """Display the ``peck_data`` summary for the selected bird and date(s) as HTML.

    Anything ``peck_data`` prints while running is suppressed. Does nothing
    when ``date`` is None.
    """
    if date is None:
        return

    load_args = [bird, date]
    if date_end != "None":
        load_args.append(date_end)
    blocks, stim_blocks = cached_load(*load_args)

    with block_print():
        data = peck_data(blocks)

    display(HTML(data.to_html()))


# Start-date dropdown. It is seeded with the last 60 calendar days;
# view_pecking_data later replaces these options with the dates that exist
# for the selected bird.
date_picker = widgets.Dropdown(
    options=[datetime.date.today() - datetime.timedelta(days=x) for x in range(60)],
    value=datetime.date.today(),
    description="Date",
    disabled=False,
)
# Optional end-date dropdown. The string "None" is the sentinel the view
# callbacks compare against to mean "no end date". Because date_picker starts
# at today, the range() below is empty and "None" is the only initial option.
date_end_picker = widgets.Dropdown(
    options=["None"] + [datetime.date.today() - datetime.timedelta(days=x) for x in range(
        (datetime.date.today() - date_picker.value).days
    )],
    value="None",
    description="Date end",
    disabled=False,
)

# Controls for saving the current figure; the text is used by save_fig as the
# file-name prefix.
save_button = widgets.Button(
    description="Save Figure",
)
path_input = widgets.Text(
    value="figure",
    placeholder='figure',
    description='Save fig to:',
    disabled=False
)


# Subject selector; BIRDS must contain at least one entry for BIRDS[0] to work.
bird_picker = widgets.Dropdown(
    options=BIRDS,
    value=BIRDS[0],
    description="Bird",
    disabled=False,
)
# Each option is (label shown to the user, mode string checked in view_pecking_data).
mode_picker = widgets.RadioButtons(
    options=[
        ("by stim", "bystim"),
        ("by reward", "byreward"),
        ("odds ratio", "oddsratios"),
        ("odds ratio windowed", "windowedoddsratios"),
    ],
    description='view mode:',
    disabled=False
)

def view_all(*args):
    """Button callback: clear both stimulus selections so nothing is filtered out."""
    stim_picker_re.value = []
    stim_picker_un.value = []
    
view_all_button = widgets.Button(
    description="View all stims",
)
view_all_button.on_click(view_all)
# Multi-select list of rewarded stimuli; its options are filled in by view_stim_data.
stim_picker_re = widgets.SelectMultiple(
    options=[],
    value=[],
#     description='Select stims\n(ctrl click)',
    disabled=False
)
# Multi-select list of the remaining (non-rewarded) stimuli; also filled in by view_stim_data.
stim_picker_un = widgets.SelectMultiple(
    options=[],
    value=[],
#     description='Select stims\n(ctrl click)',
    disabled=False
)

# Window size handed to view_pecking_data as window_size (used in the windowed
# odds-ratio mode). continuous_update=False means the value only changes when
# the slider is released.
window_size_picker = widgets.IntSlider(
    description="Window size",
    min=2,
    step=1,
    max=30,
    continuous_update=False,
    value=4,
)

def save_fig(butt):
    """Button callback: save the current figure as an EPS file under ``saved_figures/``.

    The file is named "<prefix>_<mode>.eps", where the prefix comes from
    ``path_input`` and the mode from ``mode_picker``. If that file already
    exists, a numeric suffix ("_0", "_1", ...) is appended so that existing
    files are never overwritten. Does nothing (apart from a short message)
    when no figure has been drawn yet.

    Parameters
    ----------
    butt : the button that was clicked (unused).
    """
    global fig
    if fig is None:
        # Nothing has been plotted yet; calling methods on None would raise.
        print("No figure to save yet.")
        return

    os.makedirs("saved_figures", exist_ok=True)
    path = os.path.join("saved_figures", "{}_{}.eps".format(path_input.value, mode_picker.value))
    i = 0
    while os.path.exists(path):
        path = os.path.join("saved_figures", "{}_{}_{}.eps".format(path_input.value, mode_picker.value, i))
        i += 1
    fig.tight_layout()
    fig.savefig(path, format="eps")

save_button.on_click(save_fig)


# Each interactive_output below binds one view function to the widgets whose
# values it takes as keyword arguments.

# Main plot.
out_plot = widgets.interactive_output(
    view_pecking_data, 
    {
        "date": date_picker,
        "date_end": date_end_picker,
        "bird": bird_picker,
        "window_size": window_size_picker,
        "show_stims_re": stim_picker_re,
        "show_stims_un": stim_picker_un,
        "mode": mode_picker
    }
)

# Stimulus table (also refreshes the stimulus pickers).
out_stim = widgets.interactive_output(
    view_stim_data, 
    {
        "date": date_picker,
        "date_end": date_end_picker,
        "bird": bird_picker,
    }
)

# Raw trial table.
raw_dataframe = widgets.interactive_output(
    view_raw_dataframe, 
    {
        "date": date_picker,
        "date_end": date_end_picker,
        "show_stims_re": stim_picker_re,
        "show_stims_un": stim_picker_un,
        "bird": bird_picker,
    }
)


# Summary statistics table.
out_stats = widgets.interactive_output(
    view_stats, 
    {
        "date": date_picker,
        "date_end": date_end_picker,
        "bird": bird_picker,
    }
)

# Page layout: statistics on top; controls (left) next to the plot (right);
# raw trial table at the bottom.
widgets.VBox([
    out_stats,
    widgets.HBox([widgets.VBox([
        bird_picker,
        date_picker,
        date_end_picker,
        widgets.VBox([path_input, save_button]),
        mode_picker,
        view_all_button,
        widgets.VBox([stim_picker_re, stim_picker_un]),
        window_size_picker,
        widgets.Label("Stimulus list:"),
        out_stim
    ]), out_plot]),
    widgets.HBox([
        
        widgets.Label("Raw trial data: "),
        raw_dataframe
    ])
])


VBox(children=(Output(), HBox(children=(VBox(children=(Dropdown(description='Bird', options=('HpiPur058093M', …

In [None]:
# Debugging aid: cache keys are (bird, date, date_end) tuples, so this prints
# the bird of every entry currently held in the load cache.
for key in cached_load._cache.keys():
    print(key[0])

In [None]:
cached_load._cache

In [None]:
# Debugging aid: load a single day for one bird directly, bypassing the
# widgets and the cache. The day folder is named with the "%d%m%y" format.
bird = 'PurWhi127037F'
date = datetime.date(2025, 2, 27)
blocks, stim_blocks = load_pecking_days(os.path.join(DATADIR, bird, date.strftime("%d%m%y")))

In [None]:
# Debugging aid: collect the trial CSV files for `bird` / `date`.
file_list = []
directory = os.path.join(DATADIR, bird, date.strftime("%d%m%y"))

# Optional (first_day, last_day) filter used when `directory` is a bird folder
# rather than a single day folder. None keeps every day. It is defined here so
# that the multi-day branch below cannot fail with a NameError.
date_range = None


def _usable_csvs(folder, min_bytes=500):
    """Return the CSV files in ``folder`` that are at least ``min_bytes`` large.

    Smaller files are treated as empty and skipped.
    """
    return [
        csv_file
        for csv_file in glob.glob(os.path.join(folder, "*.csv"))
        if os.path.getsize(csv_file) >= min_bytes
    ]


if re.search("^[0-9]{6}$", os.path.basename(directory)):
    # `directory` is a single day folder (six digits).
    file_list.extend(_usable_csvs(directory))
else:
    # `directory` holds one folder per day. The loop variable is deliberately
    # not called `date`, so the day selected in the earlier cell is preserved.
    for day in get_dates(directory):
        if date_range is None or (date_range[0] <= day <= date_range[1]):
            file_list.extend(_usable_csvs(os.path.join(directory, day.strftime("%d%m%y"))))

In [None]:
file_list

In [None]:
blocks = PythonCSV.parse(file_list)

In [None]:
def parse_filename(fname):
    """Extract the subject name and timestamp string from a trial-data file name.

    Three layouts are tried in order (case-insensitively, anchored at the
    start of ``fname``):

    1. ``<name>_<datestr>_trialdata.csv`` with 3-character number groups in the name
    2. ``<name>_<datestr>_trialdata.csv`` with 2-character number groups in the name
    3. ``<name>_trialdata_<datestr>.csv`` with 2-character number groups in the name

    Parameters
    ----------
    fname : str
        File name without its directory.

    Returns
    -------
    dict or None
        ``{"name": ..., "datestr": ...}`` for the first layout that matches,
        or ``None`` if none does. ``datestr`` may be an empty string.
    """
    name_2 = r"(?P<name>(?:[A-Za-z]{3}){1,2}(?:[X0-9]{2}){1,2}[MF]?)"
    name_3 = r"(?P<name>(?:[A-Za-z]{3}){1,2}(?:[X0-9]{3}){1,2}[MF]?)"
    datestr = r"(?P<datestr>[0-9]*)"

    patterns = [
        "_".join([name_3, datestr, r"trialdata\.csv"]),
        "_".join([name_2, datestr, r"trialdata\.csv"]),
        # Layout with the timestamp after "trialdata"; this pattern used to be
        # built but never tried, so such files could not be parsed.
        "_".join([name_2, "trialdata", datestr + r"\.csv"]),
    ]

    for pattern in patterns:
        m = re.match(pattern, fname, re.IGNORECASE)
        if m is None:
            continue
        fields = m.groupdict()
        for key in ("name", "datestr"):
            if fields[key] is None:
                fields[key] = "Unknown"
        return fields

    return None

In [None]:
# Debugging aid: build one Block per parsable file in `file_list`.
# NOTE(review): `objects`, `cls` and `get_block_data` are not defined anywhere
# in the visible notebook; this cell looks like it was copied out of a class
# method and needs those names supplied before it can run - confirm.
blocks = list()

for file in file_list:

    fname = os.path.split(file)[1]
    m = parse_filename(fname)
    if m is not None:
        # Called `timestamp`, not `datetime`: assigning to `datetime` would
        # overwrite the imported module for every cell that runs afterwards.
        timestamp = pd.to_datetime(m["datestr"])
        blk = objects.Block(name=cls.get_name(m["name"]),
                            date=timestamp.date(),
                            start=timestamp.time(),
                            filename=file,
                            data=get_block_data(file))

        blocks.append(blk)
    else:
        print("Could not parse filename %s. Skipping" % file)

In [None]:
m

In [None]:
import pandas as pd

In [None]:
b = pd.read_csv("/data/drive/pecking_test_data/plump_synced/behavior/Lilac91M/170224/Lilac91M_20240217092151_trialdata.csv")

In [None]:
sum(b.reward)

In [None]:
# Plot a centred rolling mean of "response" for each value of "reward".
for l, label_df in b.groupby('reward'):
    win_size = 15
    win_size_half = win_size // 2
    rolled = label_df["response"].rolling(win_size, center=True).mean()
    if len(rolled) > win_size_half:
        # The centred window leaves the first and last win_size_half points
        # empty; pad them with the nearest value inside the valid range.
        rolled.iloc[:win_size_half] = rolled.iloc[win_size_half]
        rolled.iloc[-win_size_half:] = rolled.iloc[-win_size_half - 1]
    plt.plot(
            label_df.index,
            rolled,
            label=l,
            alpha=1.0,  # 0.5
            linewidth=3
        )

In [None]:
# Plot responses against trial number, separately for trials where Reward is
# true and where it is false.
# NOTE(review): the cell above assigns `b` from pd.read_csv; this cell reads
# `b.data` with Trial/Reward/Response columns, so it presumably expects `b` to
# be a block object instead - confirm which `b` is intended.
plt.plot(b.data.Trial[b.data.Reward],b.data.Response[b.data.Reward].values)
plt.plot(b.data.Trial[~b.data.Reward],b.data.Response[~b.data.Reward].values)

In [None]:
# Build "<Bird Name> <Class> <Call Type>" labels and plot with them, using the
# same column set and label format as the "bystim" view.
# NOTE(review): the "bystim" view passes a block object to these helpers,
# whereas `b` was last assigned from pd.read_csv - confirm `b` is the right input.
labels = get_labels_by_combining_columns(
    b,
    ["Class", "Call Type", "Bird Name"],
    lambda x: "{} {} {}".format(x[2], x[0], x[1])
)
plot_data(b,labels)

In [None]:
import pandas as pd

In [None]:
pd.read_csv("/data/drive/pecking_test_data/plump_synced/behavior/XXXBla31XXF/260223/XXXBla31XXF_20230226085339_trialdata.csv")

In [None]:
pd.read_csv("/data/drive/pecking_test_data/plump_synced/behavior/XXXBla31XXF/250223/XXXBla31XXF_trialdata_20230225091940.csv")