In [None]:
%matplotlib widget
import flammkuchen as fl
import lotr.plotting as pltltr
import numpy as np
import pandas as pd
from lotr import DATASET_LOCATION, LotrExperiment, dataset_folders
from lotr.data_preprocessing.dlc_tracking import export_dlc_behavior
from lotr.utils import convolve_with_tau, interpolate, nan_phase_jumps, zscore
from matplotlib import pyplot as plt
from scipy.stats import ttest_rel, wilcoxon
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from tqdm import tqdm
from lotr.utils import crop

# Local shorthand for the COLS attribute of lotr.plotting (imported as pltltr).
COLS = pltltr.COLS

In [None]:
from lotr import DATASET_LOCATION

# Collect the parent folder of every "selected.h5" file that sits in a
# second-level directory whose name contains "eyes", leaving out the
# directories whose name contains "noeyes".
eye_selection_files = DATASET_LOCATION.glob("*/*eyes*/selected.h5")
fish_with_eyes = [
    fish_dir
    for fish_dir in (selected.parent for selected in eye_selection_files)
    if "noeyes" not in fish_dir.name
]
fish_with_eyes

In [None]:
from numba import njit

def find_above_threshold(trace, threshold, skipval=25):
    """Return the indexes at which ``trace`` exceeds ``threshold``.

    The trace is scanned from its first sample. Every sample strictly greater
    than ``threshold`` is recorded, after which the scan jumps ahead so that
    the next inspected sample lies ``skipval + 1`` positions later.

    Parameters
    ----------
    trace : sequence
        Values to scan; must support ``len`` and integer indexing.
    threshold : number
        Value a sample has to exceed (strictly) to be recorded.
    skipval : int
        Number of samples skipped after each recorded index.

    Returns
    -------
    list of int
        Recorded indexes, in increasing order.
    """
    hits = []
    n_samples = len(trace)
    cursor = 0

    while cursor < n_samples:
        above = trace[cursor] > threshold
        if above:
            hits.append(cursor)
        # One step normally, skipval extra steps right after a hit.
        cursor += skipval + 1 if above else 1

    return hits

In [None]:
def _process_reg(reg):
    """Pass a regressor through ``convolve_with_tau`` and then ``zscore``.

    The second argument handed to ``convolve_with_tau`` is
    ``int(TAU_S * exp.fs)``. Both ``TAU_S`` and ``exp`` are read from the
    notebook's global namespace when the function is called, so they must be
    defined beforehand.
    """
    tau_pts = int(TAU_S * exp.fs)
    smoothed = convolve_with_tau(reg, tau_pts)
    return zscore(smoothed)


def draw_train_test(n_pts, wnd):
    """Draw two random, non-overlapping windows of ``wnd`` points each.

    A training window is drawn first; the test window is then drawn uniformly
    among all start positions whose window does not share any sample with the
    training window. The global ``np.random`` state is used, so results are
    reproducible after ``np.random.seed``.

    Parameters
    ----------
    n_pts : int
        Total number of available points.
    wnd : int
        Length of each window in points. Float values (e.g. seconds times a
        sampling rate) are truncated to int so they can be used as slice
        bounds.

    Returns
    -------
    tuple of slice
        ``(train_slice, test_slice)``, both of length ``wnd``.

    Raises
    ------
    ValueError
        If ``wnd`` does not leave room for a training window, or if no test
        window can be placed without overlapping the training window.
    """
    n_pts = int(n_pts)
    wnd = int(wnd)
    n_starts = n_pts - wnd

    # Draw beginning of training window (randint raises if n_starts <= 0):
    train_i_start = int(np.random.randint(0, n_starts))
    train_i_end = train_i_start + wnd

    # Integer candidates are used directly: the former NaN-masking approach
    # relied on the ``np.float`` alias, which recent NumPy no longer provides.
    starts = np.arange(n_starts)

    # Keep only the starts whose window lies entirely before or entirely
    # after the training window, so the two windows never share a sample.
    fits_before = starts + wnd <= train_i_start
    fits_after = starts >= train_i_end
    candidates = starts[fits_before | fits_after]

    if candidates.size == 0:
        raise ValueError(
            "No test window of {} points fits outside the training window "
            "within {} points".format(wnd, n_pts)
        )

    test_i_start = int(np.random.choice(candidates))

    return slice(train_i_start, train_i_end), slice(test_i_start, test_i_start + wnd)

In [None]:
# Per-fish loop: load the experiment and the DLC behavior table, build
# differentiated regressors, detect threshold crossings in the eye-trace
# derivative and plot event-triggered eye and network-phase traces.
#
# NOTE(review): nothing in the active code of this cell draws random numbers;
# the seed only matters for the disabled regression block further down.
np.random.seed(24324215)
TAU_S = 5  # read by _process_reg as int(TAU_S * exp.fs); presumably seconds — TODO confirm
WND = 300  # only referenced (as WND * exp.fs) by the disabled regression block
results_df = []
for path in tqdm(fish_with_eyes):
    # `exp` is also read as a global by _process_reg, so it has to be bound
    # here before the regressors are computed.
    exp = LotrExperiment(path)
    dlc_df = fl.load(path / "behavior_from_dlc.h5", "/data")

    fictive_head = exp.fictive_heading
    phase = exp.network_phase

    # Sum of the two median-filtered eye columns, passed to `interpolate`
    # together with the DLC time column and exp.time_arr (presumably to
    # resample the eye trace onto the imaging time base — TODO confirm).
    eyes_arr = dlc_df["rt_eye_medfilt"] + dlc_df["lf_eye_medfilt"]
    interp_eye = interpolate(dlc_df["t"], eyes_arr, exp.time_arr)

    # First differences of unwrapped phase, fictive heading and eye trace,
    # each passed through _process_reg. In the active code this frame is only
    # consumed by the disabled regression block below.
    data_diff_df = pd.DataFrame(
        dict(
            phase=_process_reg(np.diff(np.unwrap(phase))),
            mov_regr=_process_reg(np.diff(fictive_head)),
            eye_pos_regr=_process_reg(np.diff(interp_eye)),
        )
    )
    #plt.plot(_process_reg(np.diff(np.unwrap(phase))))
    # plt.plot(zscore(interp_eye))
    # plt.plot(zscore(_process_reg(interp_eye)))
    #plt.plot(_process_reg(np.diff(interp_eye)))
    #plt.plot(_process_reg(np.diff(fictive_head)))
    # Indexes where the eye-trace derivative exceeds 0.025 in the positive
    # ("_l") or negative ("_r") direction.
    # NOTE(review): which sign corresponds to left/right is not established
    # in this notebook — confirm against the DLC convention.
    saccades_l = find_above_threshold(np.diff(interp_eye), 0.025)
    saccades_r = find_above_threshold(-np.diff(interp_eye), 0.025)

    # 125 points = pre_int (25) + post_int (100) used in the crops below; the
    # axis runs from -5 to 19.8.
    # NOTE(review): the divisor 5 presumably stands for a 5 Hz sampling rate —
    # confirm against exp.fs.
    x_arr = np.arange(125)/5 - 5
    f, ax = plt.subplots(1, 2, figsize=(6, 2.5))
    for c, pts in zip(["r", "b"], [saccades_l, saccades_r]):
        # Left panel: eye trace around each detected event. `crop` is assumed
        # to return an array with time along axis 0 and one column per event,
        # as implied by the baseline and mean computations below.
        sac_resp = crop(np.unwrap(interp_eye), np.array(pts), pre_int=25, post_int=100)
        # Subtract each event's mean over its first 10 samples.
        sac_resp = sac_resp - np.nanmean(sac_resp[:10, :], 0)
        ax[0].plot(x_arr, sac_resp, c=c, lw=0.5, alpha=0.3)
        ax[0].plot(x_arr, np.nanmean(sac_resp, 1), c=c, lw=1.5)
        ax[0].set_ylim(-0.3, 0.3)

        # Right panel: same processing applied to the unwrapped network phase.
        sac_resp = crop(np.unwrap(phase), np.array(pts), pre_int=25, post_int=100)
        sac_resp = sac_resp - np.nanmean(sac_resp[:10, :], 0)
        ax[1].plot(x_arr, sac_resp, c=c, lw=0.5, alpha=0.3)
        ax[1].plot(x_arr, np.nanmean(sac_resp, 1), c=c, lw=1.5)
        ax[1].set_ylim(-0.3, 0.3)
    plt.show()

    # The regression analysis below is disabled: it is wrapped in a string
    # literal, so it is never executed and nothing is appended to results_df.
    """
    for to_fit, tofit_lab in zip([data_diff_df], ["data"]):
        # all_coefs = []

        for _ in range(500):
            res_df = dict()
            for lab, cols in zip(
                ["tail", "eye", "both"],
                [["mov_regr"], ["eye_pos_regr"], ["mov_regr", "eye_pos_regr"]],
            ):
                train, test = draw_train_test(len(phase), WND * exp.fs)

                regr = LinearRegression()
                regr.fit(
                    to_fit[cols].values[train, :], to_fit["phase"].values[train],
                )

                # if len(cols) == 2:
                prediction = regr.predict(to_fit[cols].values[test, :])
                # else:
                #    prediction = to_fit[cols[0]].values[test]
                res_df[lab] = np.abs(
                    np.corrcoef(prediction, to_fit["phase"][test])[0, 1]
                )
            # res_df["batch"] = tname
            res_df["data"] = tofit_lab
            res_df["fid"] = path.name

            results_df.append(res_df)
            """

# With the regression block disabled, results_df is still the empty list
# created above, so this yields an empty DataFrame.
results_df = pd.DataFrame(results_df)

In [None]:
# Event-triggered traces built from `interp_eye` and `phase`, which are not
# defined in this cell: they are whatever an earlier cell left in the
# namespace (in the per-fish loop, the values of the last processed fish).
eye_velocity = np.diff(interp_eye)
saccades_l = find_above_threshold(eye_velocity, 0.025)
saccades_r = find_above_threshold(-eye_velocity, 0.025)

f, ax = plt.subplots(1, 2, figsize=(6, 2.5))
for c, pts in (("r", saccades_l), ("b", saccades_r)):
    # One (signal, post_int) pair per panel: eye trace on the left with 35
    # samples after the event, network phase on the right with 25.
    panels = ((interp_eye, 35), (phase, 25))
    for panel, (signal, post_int) in zip(ax, panels):
        sac_resp = crop(
            np.unwrap(signal), np.array(pts), pre_int=10, post_int=post_int
        )
        # Subtract each event's mean over its first 10 samples.
        sac_resp = sac_resp - np.nanmean(sac_resp[:10, :], 0)
        panel.plot(sac_resp, c=c, lw=0.5, alpha=0.3)
        panel.plot(np.nanmean(sac_resp, 1), c=c, lw=1.5)
        panel.set_ylim(-0.3, 0.3)
plt.show()

In [None]:
# Display the "idx_imaging" column of the bouts table of `exp`, i.e. of
# whichever experiment was bound last by an earlier cell.
exp.bouts_df["idx_imaging"]

In [None]:
# NOTE(review): saccades_l is a plain Python list of indexes, so this
# element-wise subtraction presumably only works when it has exactly as many
# entries as the bouts table — confirm the intent (e.g. offset of each
# detected event from the nearest bout) before relying on the result.
saccades_l - exp.bouts_df["idx_imaging"]