# Check behavior and adaptation

In this notebook, we **manually** check the behavior quality of all experiments and record it in a new `behavior` column of the experiment dataframe (loaded from `exp_df_raw.h5`). The dataframe with this new column is then saved as `exp_df.h5`.

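We also calculate an adaptation index for every fish: $-\log_{10}(p)$ of a two-sample Kolmogorov–Smirnov test comparing the bout-duration distributions of gain-0 and gain-1 bouts.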

In [None]:
%matplotlib widget
from ipywidgets import HBox, VBox, interact, interact_manual, widgets

In [None]:
from pathlib import Path

import flammkuchen as fl
import numpy as np
import pandas as pd
import seaborn as sns
from bouter.angles import reduce_to_pi
from ec_code.analysis_utils import *
from ec_code.file_utils import get_dataset_location
from ec_code.plotting_utils import *
from matplotlib import pyplot as plt
from tqdm import tqdm

# sns.set(style="ticks", palette="deep")

In [None]:
# Root folder of the "fb_effect" dataset; most files below are read from and
# written to it.
DATASET_NAME = "fb_effect"
master_path = get_dataset_location(DATASET_NAME)

## Check behavior quality

In [None]:
# Older version of the experiment table, kept for reference.
# NOTE(review): machine-specific absolute path; it will not exist on other machines.
OLD_EXP_DF_PATH = "/Users/luigipetrucco/Google Drive/data/ECs_E50/oldones/exp_df.h5"
old_df = fl.load(OLD_EXP_DF_PATH)

In [None]:
# Display the older exp_df loaded above (presumably for comparison with the
# annotations made below).
old_df

In [None]:
# Manual behavior annotation: load the raw experiment table, give it a
# "behavior" column if it is missing, and build the interactive viewer widgets.
BEHAVIOR_DESCRIPTORS = ["-", "good", "bad"]

exp_df = fl.load(master_path / "exp_df_raw.h5")
if "behavior" not in exp_df.columns:
    # Every fish starts out labelled "good" until marked otherwise.
    exp_df["behavior"] = BEHAVIOR_DESCRIPTORS[1]

# Create the figure inside the Output widget so that it is displayed there.
output = widgets.Output()
with output:
    fig, ax = plt.subplots(constrained_layout=True, figsize=(10, 3))

# Placeholder lines; their data is replaced whenever a fish is selected.
line = ax.plot([0, 0])[0]
line2 = ax.plot([0, 0])[0]

ax.set(ylim=(-3.5, 3.5), xlim=(0, 3600))
sns.despine()

fish_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(exp_df) - 1,
    step=1,
    description="Fish n:",
)
behavior_buttons = widgets.ToggleButtons(
    options=BEHAVIOR_DESCRIPTORS,
    description="Behavior:",
)


def update_behavior(change):
    """Store the toggled behavior label for the fish currently on the slider.

    Observer for the ``index`` trait of ``behavior_buttons``; ``change.new`` is
    the newly selected button index into ``BEHAVIOR_DESCRIPTORS``.
    """
    current_fish = exp_df.index[fish_slider.value]
    label = BEHAVIOR_DESCRIPTORS[change.new]
    exp_df.loc[current_fish, "behavior"] = label


def update(change):
    """Redraw the behavior trace for the fish selected on the slider.

    Observer for the ``value`` trait of ``fish_slider``: ``change.new`` is the
    positional index of the fish in ``exp_df``. Loads that fish's resampled
    behavior log, plots the wrapped ``tail_sum`` trace plus the number of
    whole multiples of pi that were removed from it, and syncs the title and
    the toggle buttons with the stored behavior label.
    """
    fid = exp_df.index[change.new]
    beh_log = fl.load(master_path / "resamp_beh_dict.h5", f"/{fid}")

    # Work on a copy so the loaded log itself is not modified.
    tail = beh_log.tail_sum.copy()
    # Whole multiples of pi in |tail|, computed on the raw trace BEFORE
    # wrapping: after the in-place wrap below it would be 0 for every finite
    # value, which made the second line a flat zero.
    n_wraps = np.abs(tail) // np.pi
    sel = n_wraps > 0
    tail[sel] = np.mod(tail[sel], np.pi)

    line.set_data(beh_log.index, tail)
    line2.set_data(beh_log.index, n_wraps)

    b_idx = BEHAVIOR_DESCRIPTORS.index(exp_df.loc[fid, "behavior"])
    ax.set_title(f"{fid}, behavior: {BEHAVIOR_DESCRIPTORS[b_idx]}")
    behavior_buttons.set_trait("index", b_idx)

    # Draw last so the new title is included. Go through the axes' own figure
    # rather than the global `fig`, which later plotting cells may rebind.
    ax.figure.canvas.draw()


from types import SimpleNamespace

# Connect callbacks to widget traits.
fish_slider.observe(update, "value")
behavior_buttons.observe(update_behavior, "index")

controls = widgets.HBox([fish_slider, behavior_buttons])

# Draw the initially selected fish directly. Toggling the slider (1 -> 0) to
# trigger the observer does nothing when exp_df has a single fish: the slider
# clamps 1 to its max of 0, so no change event is emitted.
update(SimpleNamespace(new=fish_slider.value))

widgets.VBox([output, controls])

In [None]:
# Persist the table including the manually set "behavior" column.
annotated_path = master_path / "exp_df.h5"
fl.save(annotated_path, exp_df)

## Check adaptation

In [None]:
# Compute motor adaptation as -log10 of the p-value of the difference between
# bouts w/ and w/o visual feedback (gain 1 vs gain 0), using a two-sample
# Kolmogorov-Smirnov test on the bout-parameter distributions.
from scipy.stats import kstest

# Bout-selection parameters. They are defined here because this cell uses them;
# previously they were only set in a later cell, so this cell raised NameError
# on a fresh kernel. The later example-fish cell uses the same values.
min_dur_s = 0.05  # only bouts longer than this (s) are used
bout_param_stat = "duration"  # bout property whose distributions are compared

bouts_df = fl.load(master_path / "bouts_df.h5")

# The duration filter does not depend on the fish, so apply it once.
long_bouts = bouts_df.loc[
    bouts_df["duration"] > min_dur_s, ["fid", bout_param_stat, "gain"]
]

exp_df["adaptation"] = np.nan
for fid in tqdm(exp_df.index):
    fish_bouts_df = long_bouts.loc[
        long_bouts["fid"] == fid, [bout_param_stat, "gain"]
    ]
    g0_dur, g1_dur = [
        fish_bouts_df.loc[fish_bouts_df["gain"] == g, bout_param_stat].values
        for g in [0, 1]
    ]
    # kstest raises ValueError on an empty sample; leave adaptation as NaN
    # for fish without bouts in either gain condition.
    if len(g0_dur) == 0 or len(g1_dur) == 0:
        continue
    exp_df.loc[fid, "adaptation"] = -np.log10(kstest(g0_dur, g1_dur).pvalue)

fl.save(master_path / "exp_df.h5", exp_df)

## Motor adaptation: matching-procedure plots

In [None]:
from bouter import utilities

In [None]:
# Reload the annotated experiment table together with the per-cell tables.
exp_df, cells_df, traces_df = (
    fl.load(master_path / filename)
    for filename in ("exp_df.h5", "cells_df.h5", "traces_df.h5")
)

In [None]:
# Example fish and bout-selection parameters used by the plots below.
fid = "200828_f4_clol"
bout_param_stat = "duration"  # bout property compared between gain conditions
min_dur_s = 0.05  # minimum bout duration (s) for a bout to be included

In [None]:
# Bout-parameter distributions of the example fish, split by gain condition.
fish_bouts_df = bouts_df.loc[
    (bouts_df["duration"] > min_dur_s) & (bouts_df["fid"] == fid),
    [bout_param_stat, "gain", "matched"],
]
g0_dur, g1_dur = [
    fish_bouts_df.loc[fish_bouts_df["gain"] == g, bout_param_stat].values
    for g in [0, 1]
]
# Same distributions restricted to matched bouts (not plotted in this figure).
g0_dur_m, g1_dur_m = [
    fish_bouts_df.loc[
        (fish_bouts_df["gain"] == g) & (fish_bouts_df["matched"]), bout_param_stat
    ].values
    for g in [0, 1]
]

# Use dedicated names instead of the global `fig`/`ax`, which belong to the
# interactive behavior viewer at the top of the notebook.
hist_fig, hist_ax = plt.subplots(figsize=(3.5, 2.5))
bin_edges = np.arange(0, 2, 0.1)  # edges 0.0 ... 1.9, step 0.1
for i, durations in enumerate([g0_dur, g1_dur]):
    counts, edges = np.histogram(durations, bin_edges)
    hist_ax.step(
        (edges[:-1] + edges[1:]) / 2,
        counts,
        label=f"all: gain{i}",
        linewidth=1,
        c=cols[i],
        alpha=0.8,
        where="mid",
    )

sns.despine(ax=hist_ax)
hist_ax.set_xlabel("bout duration (s)")
hist_ax.set_ylabel("count")
hist_ax.legend(frameon=False)
hist_fig.tight_layout()

hist_fig.savefig("/Users/luigipetrucco/Desktop/bout_nomatching.pdf")

In [None]:
# Crop the traces of this fish's cells around each included bout.
dt = 0.2  # time step (s) used to convert bout start times to trace indices
pre_int_s = 2
post_int_s = 6

fish_bouts = bouts_df.loc[
    (bouts_df["fid"] == fid) & bouts_df["mindist_included"], :
].copy()

timepoints = fish_bouts["t_start"]
traces_block = traces_df[cells_df[cells_df["fid"] == fid].index].values
# `np.int` was removed in NumPy 1.24; the builtin `int` is its replacement.
start_idxs = np.round(timepoints / dt).astype(int)
# round() protects against float error in the division (e.g. 29.999... -> 29).
pre_int = int(round(pre_int_s / dt))
post_int = int(round(post_int_s / dt))
bt_crop_f = utilities.crop(
    traces_block, start_idxs, pre_int=pre_int, post_int=post_int
)

# Mean over samples pre_int .. pre_int + post_int (10:40 with the values above).
# Assumes crop's first axis is time with the bout onset at index pre_int —
# TODO confirm against bouter.utilities.crop.
mean_resps = bt_crop_f[pre_int : pre_int + post_int, :, :].mean(0)

In [None]:
# Per-bout mean response of one example cell vs. bout duration, by gain.
cell_i = 4  # positional index of the example cell among this fish's cells
fish_cell_ids = cells_df[cells_df["fid"] == fid].index

scatter_fig, scatter_ax = plt.subplots(figsize=(4, 3))
for g in [0, 1]:
    # Plain boolean array: indexes both the bouts table and the rows of
    # mean_resps (which follow the order of fish_bouts).
    is_gain = (fish_bouts["gain"] == g).to_numpy()
    scatter_ax.scatter(
        fish_bouts.loc[is_gain, "duration"],
        mean_resps[is_gain, cell_i],
        c=cols[g],
        s=10,
        label=f"gain {g}",
    )

# mean_resps is a mean over time samples (not a max), and dF/F is a unitless ratio.
scatter_ax.set_ylabel("mean dF/F")
scatter_ax.set_xlabel("bout duration (s)")
scatter_ax.set_title(f"cell id: {fish_cell_ids[cell_i]}")
sns.despine(ax=scatter_ax)
scatter_ax.legend(frameon=False)
scatter_fig.savefig("/Users/luigipetrucco/Desktop/resp_duration.pdf")

In [None]:
# Sanity check: the rows are indexed with a per-bout mask in the scatter plot,
# so presumably the shape is (n_bouts, n_cells) — confirm.
np.shape(mean_resps)

In [None]:
# Length of the per-bout gain mask; it should equal mean_resps.shape[0] for the
# boolean indexing in the scatter plot to work.
fish_bouts["gain"].eq(0).shape