In [None]:
%matplotlib widget
import flammkuchen as fl
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm import tqdm

import lotr.plotting as pltltr
from lotr import LotrExperiment, dataset_folders
from lotr.utils import convolve_with_tau, crop, interpolate, zscore

# Local shorthand for the color definitions exposed by lotr.plotting.
COLS = pltltr.COLS

In [None]:
# Keep only the dataset folders whose name contains "eyes", leaving out the
# ones tagged "noeyes" (which would otherwise match the same substring).
fish_with_eyes = [
    folder
    for folder in dataset_folders
    if "noeyes" not in folder.name and "eyes" in folder.name
]

In [None]:
def find_above_threshold(trace, threshold, skipval=25):
    """Find the indices at which a trace exceeds a threshold.

    The trace is scanned from left to right. Every time a sample is strictly
    greater than ``threshold`` its index is stored and the next ``skipval``
    samples are skipped, so two consecutive events are always at least
    ``skipval + 1`` samples apart.

    Parameters
    ----------
    trace : sequence of numbers
        One-dimensional trace to scan.
    threshold : number
        Value a sample has to exceed (strictly) to be counted as an event.
    skipval : int
        Number of samples skipped after each detected event.

    Returns
    -------
    np.ndarray
        Integer array with the indices of the detected events. The dtype is
        integer also when no event is found, so that the result can always be
        used to index another array.
    """
    events = []
    n_pts = len(trace)
    k = 0

    while k < n_pts:
        if trace[k] > threshold:
            events.append(k)
            # Refractory period: jump over the samples following the event.
            k += skipval
        k += 1

    # Explicit dtype: np.array([]) would be float64 and unusable as an index.
    return np.array(events, dtype=int)

In [None]:
def _process_reg(reg):
    """Convolve a regressor with ``convolve_with_tau`` and z-score the result.

    Note that this helper reads two notebook globals at call time: ``TAU_S``
    and the experiment currently bound to ``exp``. The second argument given
    to ``convolve_with_tau`` is ``int(TAU_S * exp.fs)``.
    """
    tau_pts = int(TAU_S * exp.fs)
    convolved = convolve_with_tau(reg, tau_pts)
    return zscore(convolved)

In [None]:
# Seed numpy's global random generator.
np.random.seed(24324215)

# --- Analysis parameters -----------------------------------------------------
TAU_S = 5  # scaled by exp.fs in _process_reg to get the convolve_with_tau argument
WND = 300  # NOTE(review): not used in the cells visible here
DEF_FN = 5  # points per second used to convert the windows below to samples
SACC_THR = 0.025  # threshold on the absolute sample-to-sample eye position change
MIN_BOUT_DISTANCE = 5  # multiplied by exp.fn: minimum saccade-to-bout index distance

# --- Cropping window around each saccade -------------------------------------
pre_int_s, post_int_s = 5, 20
pre_int_pts, post_int_pts = pre_int_s * DEF_FN, post_int_s * DEF_FN
# Time axis of the cropped snippets, with 0 at the saccade:
x_arr = np.arange(pre_int_pts + post_int_pts) / DEF_FN - pre_int_s

# --- Accumulators, one entry appended per fish and saccade direction ---------
means_l, means_r = [], []
means_eye_l, means_eye_r = [], []

# NOTE(review): nothing is appended to results_df in the visible code; it is
# kept (and converted below) in case other cells rely on the name.
results_df = []
for path in tqdm(fish_with_eyes):
    exp = LotrExperiment(path)
    dlc_df = fl.load(path / "behavior_from_dlc.h5", "/data")

    fictive_head = exp.fictive_heading
    phase = exp.network_phase

    # Sum of the two median-filtered eye traces, passed through interpolate()
    # together with the behavior time base and exp.time_arr.
    eyes_arr = dlc_df["rt_eye_medfilt"] + dlc_df["lf_eye_medfilt"]
    interp_eye = interpolate(dlc_df["t"], eyes_arr, exp.time_arr)

    # NOTE(review): data_diff_df is not used by the cells visible here.
    data_diff_df = pd.DataFrame(
        dict(
            phase=_process_reg(np.diff(np.unwrap(phase))),
            mov_regr=_process_reg(np.diff(fictive_head)),
            eye_pos_regr=_process_reg(np.diff(interp_eye)),
        )
    )

    # Sample-to-sample change of the eye trace, computed once and reused for
    # both saccade detection and saccade direction.
    eye_vel = np.diff(interp_eye)

    # Cast to int so that an empty result can still be used as an index.
    saccades = find_above_threshold(np.abs(eye_vel), SACC_THR).astype(int)

    # Discard saccades that are too close to any bout. A boolean mask is used
    # instead of a -1 sentinel so that a saccade at index 0 is not dropped.
    bout_idxs = exp.bouts_df["idx_imaging"]
    min_bout_dist_pts = MIN_BOUT_DISTANCE * exp.fn
    far_from_bouts = np.ones(len(saccades), dtype=bool)
    for i in range(len(saccades)):
        min_bout_dist_from_sacc = np.min(np.abs(saccades[i] - bout_idxs))
        if min_bout_dist_from_sacc < min_bout_dist_pts:
            far_from_bouts[i] = False
    saccades = saccades[far_from_bouts]

    # Split by the sign of the eye position change at the saccade.
    # NOTE(review): the plotting cell labels the `_l` group "Rightward
    # saccades" - confirm the sign convention.
    sacc_vel = eye_vel[saccades]
    saccades_l = saccades[sacc_vel > 0]
    saccades_r = saccades[sacc_vel < 0]

    for pts, ls, ls_eyes in zip(
        [saccades_l, saccades_r],
        [means_l, means_r],
        [means_eye_l, means_eye_r],
    ):
        # Only fish with more than 2 saccades of this direction contribute.
        if len(pts) <= 2:
            continue

        # Crop with the same window that defines x_arr, then subtract a
        # pre-saccade baseline from every snippet.
        sac_eye = crop(
            np.unwrap(interp_eye),
            np.array(pts),
            pre_int=pre_int_pts,
            post_int=post_int_pts,
        )
        sac_eye = sac_eye - np.nanmean(sac_eye[:10, :], 0)

        sac_resp = crop(
            np.unwrap(phase),
            np.array(pts),
            pre_int=pre_int_pts,
            post_int=post_int_pts,
        )
        # NOTE(review): the baseline window here (samples 8-9) differs from
        # the one used for the eye trace (samples 0-9) - confirm intended.
        sac_resp = sac_resp - np.nanmean(sac_resp[8:10, :], 0)

        # Average over saccades: one trace per fish and direction.
        ls.append(np.nanmean(sac_resp, 1))
        ls_eyes.append(np.nanmean(sac_eye, 1))


results_df = pd.DataFrame(results_df)
means_r = np.array(means_r)
means_l = np.array(means_l)

means_eye_r = np.array(means_eye_r)
means_eye_l = np.array(means_eye_l)

In [None]:
# Two side-by-side panels: network phase (left) and gaze angle (right).
fig, axs = plt.subplots(
    1,
    2,
    figsize=(5, 2.5),
    gridspec_kw=dict(left=0.2, bottom=0.2, right=1, wspace=0.5),
)

y_lim = 0.5  # symmetric y axis limit shared by both panels
sacc_colors = [COLS["qualitative"][0], COLS["qualitative"][1]]
sacc_labels = ["Rightward saccades", "Leftward saccades"]

# (axis, title, per-direction arrays); each array has one row per fish.
panels = [
    (axs[1], "Gaze angle", [means_eye_l, means_eye_r]),
    (axs[0], "Network phase", [means_l, means_r]),
]

for ax, title, direction_means in panels:
    for color, fish_means, label in zip(sacc_colors, direction_means, sacc_labels):
        # Thin lines: individual rows; thick line: their average.
        ax.plot(
            x_arr, fish_means.T, lw=0.5, alpha=0.3, c=color, label="__nolegend__"
        )
        ax.plot(x_arr, fish_means.mean(0), lw=1.5, c=color, label=label)

    ax.set_title(title)
    ax.set(
        ylim=(-y_lim, y_lim),
        xlabel="Time from saccade (s)",
        **pltltr.get_pi_labels(coefs=[-0.125, -1 / 16, 0, 1 / 16, 0.125], ax="y"),
    )
    pltltr.despine(ax)

# Legend only on the gaze panel, with text colored like the lines
# (loc=2 is matplotlib's code for "upper left").
axs[1].legend(
    loc=2,
    bbox_to_anchor=(0.25, 0.25),
    labelcolor="linecolor",
    handlelength=0,
)

pltltr.savefig("saccade_triggered", folder="S9")