In [None]:
%matplotlib widget
import flammkuchen as fl
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.signal import medfilt
from scipy.stats import wilcoxon
from sklearn.linear_model import LinearRegression
from tqdm import tqdm

import lotr.plotting as pltltr
from lotr import LotrExperiment, dataset_folders, DATASET_LOCATION
from lotr.result_logging import ResultsLogger
from lotr.utils import convolve_with_tau, interpolate, zscore

# Colour palette shared by every figure in this notebook.
COLS = pltltr.COLS

# Collector for the numeric results logged with `logger.add_entry`.
logger = ResultsLogger()

## Eye angle convention
- Negative angles: the fish is looking left.
- Positive angles: the fish is looking right.

In [None]:
# Datasets whose folder name contains "eyes" but not "noeyes".
fish_with_eyes = [
    fish_path
    for fish_path in dataset_folders
    if "noeyes" not in fish_path.name and "eyes" in fish_path.name
]

In [None]:
def _process_reg(reg, tau_s=None, fs=None):
    """Convolve a regressor with ``convolve_with_tau`` and z-score the result.

    Parameters
    ----------
    reg : array-like
        Regressor trace sampled at ``fs``.
    tau_s : float, optional
        Time constant in seconds. Defaults to the module-level ``TAU_S``,
        which is looked up at call time.
    fs : float, optional
        Sampling rate used to convert ``tau_s`` to samples. Defaults to
        ``exp.fs`` of the module-level ``exp`` currently loaded.

    Returns
    -------
    The z-scored, convolved regressor (as returned by ``zscore``).
    """
    # Defaults preserve the original behaviour, which read these globals
    # implicitly; passing them explicitly avoids hidden notebook state.
    if tau_s is None:
        tau_s = TAU_S
    if fs is None:
        fs = exp.fs
    reg = convolve_with_tau(reg, int(tau_s * fs))
    return zscore(reg)


def draw_train_test(n_pts, wnd):
    """Draw random, non-overlapping train and test windows of equal length.

    Both window starts are drawn from ``[0, n_pts - wnd)`` using the global
    NumPy RNG (``np.random``), so results follow ``np.random.seed``.

    Parameters
    ----------
    n_pts : int
        Number of samples in the data being split.
    wnd : int or float
        Window length in samples (cast to ``int``).

    Returns
    -------
    (slice, slice)
        ``(train, test)`` slices, each ``wnd`` samples long, sharing no sample.

    Raises
    ------
    ValueError
        If ``n_pts`` is too short to fit a training window, or to fit a test
        window that does not overlap the drawn training window.
    """
    wnd = int(wnd)
    n_starts = n_pts - wnd
    if n_starts <= 0:
        raise ValueError(f"n_pts={n_pts} is too short for a window of {wnd} samples")

    # Draw beginning of training window:
    train_i_start = np.random.randint(0, n_starts)
    train_i_end = train_i_start + wnd

    # A test window [s, s + wnd) overlaps the training window unless it ends
    # before the training start or begins at/after the training end. The
    # previous version only excluded starts inside the training window, so
    # test windows starting shortly before it leaked training samples.
    candidates = np.arange(n_starts)
    non_overlapping = (candidates + wnd <= train_i_start) | (candidates >= train_i_end)
    candidates = candidates[non_overlapping]
    if candidates.size == 0:
        raise ValueError(
            f"No test window of {wnd} samples fits outside the training window "
            f"[{train_i_start}, {train_i_end}) in {n_pts} samples"
        )
    test_i_start = int(np.random.choice(candidates))

    return slice(train_i_start, train_i_end), slice(test_i_start, test_i_start + wnd)

In [None]:
# Regressor sets compared in the cross-validated fit. Keys are the labels used
# in all downstream tables and figures; values are columns of `data_diff_df`.
REGRESSOR_SETS = {
    "gaze": ["eye_pos_regr"],
    "heading": ["mov_regr"],
    "heading+gaze": ["mov_regr", "eye_pos_regr"],
}
reg_cols = list(REGRESSOR_SETS)
np.random.seed(24324215)
TAU_S = 3  # seconds; converted to samples inside `_process_reg`
WND = 300  # seconds; length of each train and test window
N_ITERATIONS = 500  # random train/test splits per fish

results_records = []
for path in fish_with_eyes:
    exp = LotrExperiment(path)
    dlc_df = fl.load(path / "behavior_from_dlc.h5", "/data")

    fictive_head = exp.fictive_heading
    phase = exp.network_phase

    # Sum of the two eye angles, resampled onto the imaging time base.
    eyes_arr = dlc_df["rt_eye_medfilt"] + dlc_df["lf_eye_medfilt"]
    interp_eye = interpolate(dlc_df["t"], eyes_arr, exp.time_arr)

    # Fit the phase derivative from processed heading/eye derivatives.
    data_diff_df = pd.DataFrame(
        dict(
            phase=np.diff(np.unwrap(phase)),
            mov_regr=_process_reg(np.diff(fictive_head)),
            eye_pos_regr=_process_reg(np.diff(interp_eye)),
        )
    )
    phase_diff = data_diff_df["phase"].values

    for _ in tqdm(range(N_ITERATIONS)):
        # One split per iteration, shared by every regressor set, so the
        # models are compared on exactly the same train and test samples.
        train, test = draw_train_test(len(data_diff_df), WND * exp.fs)

        record = dict()
        for lab, cols in REGRESSOR_SETS.items():
            X = data_diff_df[cols].values
            regr = LinearRegression().fit(X[train, :], phase_diff[train])
            prediction = regr.predict(X[test, :])
            record[lab] = np.corrcoef(prediction, phase_diff[test])[0, 1]
        record["fid"] = path.name

        results_records.append(record)


results_df = pd.DataFrame(results_records)

In [None]:
def _to_float(df):
    df.iloc[:, :].values = df.iloc[:, :].values.astype(np.float)


# Per-fish quartiles of each model's cross-validated correlation, indexed by
# (fid, percentile).
quantiles_df = (
    results_df[reg_cols + ["fid"]]
    .groupby("fid")
    .quantile([0.25, 0.5, 0.75])
    .rename_axis(["fid", "percentile"])
)

In [None]:
f, ax = plt.subplots(figsize=(2.2, 2), gridspec_kw=dict(left=0.25, bottom=0.2, top=0.8))
median_corr = quantiles_df.xs(0.5, level="percentile")

# Quartile bars of the cross-validated correlation, one set per fish.
for fish_path in fish_with_eyes:
    pltltr.tick_with_bars(
        quantiles_df.xs(fish_path.name, level="fid")[reg_cols],
        cols=[COLS["qualitative"][0]] * 3,
        xdisperse=0.25,
    )

# Grey lines join each fish's medians across the three models.
plt.plot(median_corr[reg_cols].values.T, c=".5", lw=0.5)

ax.set(
    ylabel="Fit correlation (cross-val)",
    xticks=[0, 1, 2],
    xticklabels=[r.replace("+", " +\n") for r in reg_cols],
)
pltltr.despine(ax)

# Significance stars (Wilcoxon signed-rank test on per-fish medians),
# centred between the two model columns being compared.
for x_text, (col_a, col_b) in [
    (0.5, ("gaze", "heading")),
    (1.5, ("heading+gaze", "heading")),
]:
    ax.text(
        x_text,
        0.5,
        pltltr.get_pval_stars(wilcoxon(median_corr[col_a], median_corr[col_b])),
        ha="center",
        va="bottom",
    )

pltltr.savefig("eye_tail_correlations", folder="S9")

In [None]:
# Paired test across fish: median fit of the gaze model vs. the heading model.
median_corr = quantiles_df.xs(0.5, level="percentile")
wilcoxon(median_corr["gaze"], median_corr["heading"])

In [None]:
# Log each model's per-fish median fit correlation, labelled by fish id.
median_corr = quantiles_df.xs(0.5, level="percentile")
for model_name in ["heading+gaze", "gaze", "heading"]:
    model_medians = median_corr[model_name]
    logger.add_entry(
        f"eyes_fit_{model_name}",
        model_medians.values,
        list(model_medians.index),
        moment="median",
    )

## Single fish example

In [None]:
path = DATASET_LOCATION / "lightsheet" / "210601_f0b" / "210601_f0_2dvr_eyes"
exp = LotrExperiment(path)
dlc_df = fl.load(path / "behavior_from_dlc.h5", "/data")

fictive_head = exp.fictive_heading
phase = exp.network_phase
unwrapped_phase = np.unwrap(phase)

# Sum of the two eye angles, resampled onto the imaging time base.
eyes_arr = dlc_df["rt_eye_medfilt"] + dlc_df["lf_eye_medfilt"]
interp_eye = interpolate(dlc_df["t"], eyes_arr, exp.time_arr)

# Derivative regressors, processed exactly as in the cross-validated fit.
diff_columns = {
    "phase": np.diff(unwrapped_phase),
    "mov_regr": _process_reg(np.diff(fictive_head)),
    "eye_pos_regr": _process_reg(np.diff(interp_eye)),
}
data_diff_df = pd.DataFrame(diff_columns)

# Raw (non-differentiated) signals, used only for plotting.
raw_columns = {
    "phase": unwrapped_phase,
    "mov_regr": fictive_head,
    "eye_pos_regr": interp_eye,
}
data_df = pd.DataFrame(raw_columns)

In [None]:
t_slice_s = (1350, 1880)
t_start_s, t_stop_s = t_slice_s
f, ax = plt.subplots(figsize=(2, 2), gridspec_kw=dict(left=0.2, bottom=0.2, right=0.99))
lw = 1

# Behaviour samples strictly inside the time window.
beh_log = exp.behavior_log
sel_beh = beh_log[(beh_log["t"] > t_start_s) & (beh_log["t"] < t_stop_s)]

# Imaging-frame slice for the same window, with a time axis starting at 0 s.
t_slice = slice(t_start_s * exp.fs, t_stop_s * exp.fs)
t_arr = np.arange(t_slice.stop - t_slice.start) / exp.fs

ax.plot(
    t_arr,
    -zscore(np.unwrap(phase[t_slice])),
    c=COLS["ph_plot"],
    lw=lw,
    label="Network phase",
)
ax.plot(
    t_arr,
    zscore(fictive_head[t_slice]),
    c=COLS["th_plot"],
    lw=lw,
    label="Heading",
)
ax.plot(
    t_arr,
    -zscore(interp_eye[t_slice]) / 4 - 3,
    lw=lw,
    c=COLS["qualitative"][1],
    label="Eyes (mean)",
)

# Tail trace, with time re-referenced to its first sample in the window.
beh_t = sel_beh["t"]
ax.plot(
    beh_t - beh_t.values[0],
    zscore(sel_beh["tail_sum"]) / 12 - 5,
    c=".6",
    lw=lw,
    label="Tail sum",
    rasterized=True,
)

ax.set(xlabel="Time (s)")
pltltr.despine(ax, ["left", "top", "right"])
ax.legend(
    loc=2,
    bbox_to_anchor=(0.0, 1.1),
    fontsize=6,
    labelcolor="linecolor",
    handlelength=0.0,
)
pltltr.savefig("eye_tail_traces")

## Regressors plot

In [None]:
t_slice_s = (1350, 2000)
t_slice = slice(t_slice_s[0] * exp.fs, t_slice_s[1] * exp.fs)
t_arr = np.arange(t_slice.stop - t_slice.start) / exp.fs

f, axs = plt.subplots(
    1,
    2,
    figsize=(5, 2.5),
    gridspec_kw=dict(bottom=0.2, left=0.05, right=0.95, wspace=0.3),
)
cols = [COLS["ph_plot"], COLS["th_plot"], COLS["qualitative"][1]]

# Left panel: raw signals (phase median-filtered), z-scored within the window.
ax = axs[0]
ax.plot(t_arr, zscore(medfilt(data_df["phase"].iloc[t_slice], 11)) - 15, c=cols[0])
ax.plot(t_arr, zscore(data_df["mov_regr"].iloc[t_slice]), c=cols[1])
ax.plot(t_arr, zscore(data_df["eye_pos_regr"].iloc[t_slice]) / 2 - 5, c=cols[2])
for panel in axs:
    pltltr.despine(panel, ["left", "right", "top"])
raw_offsets = [-15, 0, -5]
raw_labels = ["phase (to predict)", "heading regressor", "gaze_pos regressor"]
for col, off, lab in zip(cols, raw_offsets, raw_labels):
    ax.text(t_arr[-1], off + 3, lab, ha="right", c=col)
ax.set(xlabel="Time (s)")

# Right panel: derivative signals, z-scored over the whole recording and
# drawn as filled traces closed at their offset baseline.
ax = axs[1]
lw = 1
diff_traces = [
    medfilt(data_diff_df["phase"], 11),
    data_diff_df["mov_regr"],
    data_diff_df["eye_pos_regr"],
]
diff_offsets = [-50, 0, -18]
diff_labels = [
    "d(phase)/dt (to predict)",
    "d(heading)/dt regressor",
    "d(gaze_pos)/dt regressor",
]
for arr, col, off, lab in zip(diff_traces, cols, diff_offsets, diff_labels):
    arr_toplot = zscore(np.array(arr))[t_slice] + off
    # Pin both ends to the baseline so `fill` closes the polygon there.
    arr_toplot[0] = off
    arr_toplot[-1] = off
    ax.fill(t_arr, arr_toplot, lw=lw, fc=col, ec=col, alpha=0.5)
    ax.text(t_arr[-1], off - 8, lab, ha="right", c=col)

ax.set(xlabel="Time (s)", ylim=(-62, 7))
pltltr.savefig("regressors_descr", folder="S9")