
import matplotlib.gridspec as gridspec

from lotr.data_preprocessing.anatomy import anatomical_angle_remapping
from lotr.notebook_utils import print_source
from lotr.pca import pca_and_phase
from lotr.result_logging import ResultsLogger
from lotr.rpca_calculation import get_zero_mean_weights, reorient_pcs
from lotr.utils import (
    circular_corr,
    get_rot_matrix,
    get_vect_angle,
    reduce_to_pi,
    zscore,
)
from matplotlib import cm
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit
from scipy.stats import mannwhitneyu

# `pltltr` is otherwise only imported in a later cell; without this import the
# `COLS = pltltr.COLS` line below raises NameError when the notebook is run
# top to bottom.
import lotr.plotting as pltltr

logger = ResultsLogger()

COLS = pltltr.COLS

In [None]:
%matplotlib widget
from pathlib import Path

import lotr.plotting as pltltr
import numpy as np
import pandas as pd
from lotr import A_FISH, FIGURES_LOCATION, LotrExperiment, dataset_folders
from lotr.default_vals import (
    DEFAULT_FN,
    POST_BOUT_WND_S,
    PRE_BOUT_WND_S,
    WND_DELTA_PHASE_S,
)
from lotr.utils import crop
from matplotlib import pyplot as plt
from tqdm import tqdm

# Shortcut to the lotr.plotting color palette. This binds the same dict object, so
# in-place edits made later (e.g. COLS["sides"]["nb"] = ...) also change pltltr.COLS.
COLS = pltltr.COLS


def beeswarm(y, off=0, c=1, nbins=None):
    """
    Return x coordinates for the points in `y`, so that plotting `x` and
    `y` results in a bee swarm plot.

    Points are grouped into `nbins` bins along y. Within each bin they are
    sorted by y and placed alternately to the right and to the left of 0.

    Parameters
    ----------
    y : array-like
        1D values to lay out.
    off : float
        Offset added to every x coordinate (e.g. the category position).
    c : float
        Scale factor applied to the horizontal spread before adding `off`.
    nbins : int, optional
        Number of bins along y. Defaults to ``len(y) // 6``; always at least 1.

    Returns
    -------
    np.ndarray
        One x coordinate per element of `y` (empty if `y` is empty).
    """
    y = np.asarray(y)
    x = np.zeros(len(y))
    if len(y) == 0:
        # np.min/np.max below would raise on an empty array.
        return c * x + off

    if nbins is None:
        nbins = len(y) // 6
    # Fewer than 6 points would give 0 bins and a division by zero below.
    nbins = max(1, nbins)

    # Upper bounds of the first nbins - 1 bins; the last bin takes the rest.
    ylo = np.min(y)
    yhi = np.max(y)
    dy = (yhi - ylo) / nbins
    ybins = np.linspace(ylo + dy, yhi - dy, nbins - 1)

    # Divide indices (and their values) into bins:
    remaining_idx = np.arange(len(y))
    remaining_y = y
    bin_idxs = []
    bin_ys = []
    for upper in ybins:
        in_bin = remaining_y <= upper
        bin_idxs.append(remaining_idx[in_bin])
        bin_ys.append(remaining_y[in_bin])
        remaining_idx, remaining_y = remaining_idx[~in_bin], remaining_y[~in_bin]
    bin_idxs.append(remaining_idx)
    bin_ys.append(remaining_y)
    nmax = max(len(idx) for idx in bin_idxs)

    # Horizontal step. nmax // 2 is 0 when no bin holds 2+ points; in that case
    # every x stays 0, so the guard only prevents a ZeroDivisionError.
    dx = 1 / max(1, nmax // 2)

    # Assign x positions: sorted points alternate right (a) and left (b). For an
    # odd count, the lowest point stays at 0 and the others get a 1/3 step offset.
    for idx, vals in zip(bin_idxs, bin_ys):
        if len(idx) > 1:
            j = len(idx) % 2
            idx = idx[np.argsort(vals)]
            a = idx[j::2]
            b = idx[j + 1 :: 2]
            offsets = (0.5 + j / 3 + np.arange(len(b))) * dx
            x[a] = offsets
            x[b] = -offsets

    return c * x + off

In [None]:
def get_nobouts_period_idxs(
    exp,
    exclude_pre_bts_s=30,
    exclude_post_bts_s=30,
    exclude_overlap_pre=PRE_BOUT_WND_S,
    exclude_overlap_post=POST_BOUT_WND_S,
):
    """Return imaging-frame indexes usable as centres of bout-free ("resting") crops.

    Parameters
    ----------
    exp : LotrExperiment
        Only `exp.fs`, `exp.n_pts` and `exp.bouts_df["idx_imaging"]` are read.
    exclude_pre_bts_s, exclude_post_bts_s : float
        Seconds. Candidates closer than `exclude_pre_bts_s` to the recording start
        or `exclude_post_bts_s` to its end are never considered. A candidate t is
        also dropped when some bout b satisfies
        ``b - exclude_pre_bts_s * fs <= t <= b + exclude_post_bts_s * fs``.
    exclude_overlap_pre, exclude_overlap_post : float
        Seconds. Surviving candidates are thinned greedily, from the earliest: each
        visited candidate v removes every other candidate in
        ``[v - exclude_overlap_pre * fs // 2, v + exclude_overlap_post * fs // 2]``.

    Returns
    -------
    np.ndarray
        Remaining candidate indexes, in ascending order.

    NOTE(review): if `exp.fs` is not an integer, these are float-valued indexes;
    confirm that callers (e.g. `crop`) accept them.
    """

    # Generate array of indexes with no occurrences of bouts before and after
    valid_locations = np.arange(
        exclude_pre_bts_s * exp.fs, exp.n_pts - exclude_post_bts_s * exp.fs
    )
    for bout_i in exp.bouts_df["idx_imaging"]:
        filter_in = (valid_locations < bout_i - exclude_pre_bts_s * exp.fs) | (
            valid_locations > bout_i + exclude_post_bts_s * exp.fs
        )
        valid_locations = valid_locations[filter_in]

    # avoid too close croppings of resting periods (//2 allows some very loose overlapping of half, ie, basically no overlap):
    # NOTE(review): with exclude_overlap_pre > exclude_overlap_post, a candidate can
    # remove an earlier, already-kept one. The array then shifts left and `i += 1`
    # skips a candidate. With equal windows this cannot happen.
    i = 0
    while i < len(valid_locations):
        filter_in = (
            (valid_locations < valid_locations[i] - exclude_overlap_pre * exp.fs // 2)
            | (
                valid_locations
                > valid_locations[i] + exclude_overlap_post * exp.fs // 2
            )
            | (valid_locations == valid_locations[i])  # keep the current candidate
        )
        valid_locations = valid_locations[filter_in]
        i += 1

    return valid_locations

In [None]:
# Crop the unwrapped network phase around resting periods and around
# non-forward bouts, for every fish in the dataset.
all_cropped = dict()
for path in tqdm(dataset_folders):
    exp = LotrExperiment(path)
    # Resting-period centres: indexes far from any bout.
    idxs_rest = get_nobouts_period_idxs(exp)
    # Bouts, excluding forward ("fw") ones.
    idxs_bouts = exp.bouts_df.loc[exp.bouts_df["direction"] != "fw", "idx_imaging"]

    unwrap_phase = np.unwrap(exp.network_phase)
    fish_crops = dict()
    for k, idxs in zip(["rest", "bouts"], [idxs_rest, idxs_bouts]):
        cropped = crop(
            unwrap_phase,
            idxs,
            pre_int=PRE_BOUT_WND_S * exp.fs,
            post_int=POST_BOUT_WND_S * exp.fs,
        )
        # Subtract, from each trace, its mean over the pre-event window:
        cropped = cropped - np.mean(cropped[: PRE_BOUT_WND_S * exp.fs, :], axis=0)
        # Flip traces so each one ends at a non-negative value
        # (a trace ending exactly at 0 is multiplied by 0).
        cropped = cropped * np.sign(cropped[-1, :])
        fish_crops[k] = cropped

    fish_crops["has_eyes"] = any(path.glob("*dlc*"))
    # Keep this fish's sampling rate so later cells do not depend on whichever
    # `exp` happened to be loaded last.
    fish_crops["fs"] = exp.fs
    all_cropped[exp.exp_code] = fish_crops

plt.close("all")
cols = [".5", "C0"]
n_plots = len(all_cropped)
# squeeze=False keeps `all_axs` 2D (n_plots x 3) even when there is a single fish.
f, all_axs = plt.subplots(n_plots, 3, figsize=(6, n_plots), squeeze=False)

for axs, dict_to_plot in zip(all_axs, all_cropped.values()):
    # Columns: rest traces, bout traces, final values of both.
    summary_ax = axs[2]
    for i, (k2, col, ax) in enumerate(zip(["rest", "bouts"], cols, axs)):
        cropped = dict_to_plot[k2]
        xarr = np.linspace(-PRE_BOUT_WND_S, POST_BOUT_WND_S, cropped.shape[0])
        ax.plot(xarr, cropped, lw=0.5, c=col)
        ax.set(ylim=(-1.1, 3.2))
        ax.axvline(0, lw=0.5, c=".3", zorder=-100)

        # Final value of every trace, jittered horizontally in [i, i + 0.5):
        final_vals = cropped[-1, :]
        summary_ax.scatter(
            np.random.rand(len(final_vals)) / 2 + i,
            final_vals,
            color=col,
            s=15,
            alpha=0.25,
            lw=0,
        )
        # Mean final value, marked at the centre of the jitter band:
        summary_ax.scatter(
            i + 0.25, np.mean(final_vals), color=col, s=15, marker="_",
        )
        summary_ax.set(ylim=(-1.1, 3.2))

# NOTE(review): hard-coded, user-specific output path.
f.savefig("/Users/luigipetrucco/Desktop/nomotion.png")

In [None]:
# Inspection only: shape of the last `cropped` array left over from the previous cells.
cropped.shape

In [None]:
# Scratch arithmetic; the meaning of 175 and 5 is not documented in this notebook.
175 / 5

In [None]:
calculate_drift_after_s = 20  # drift after n seconds

plt.close("all")
cols = [".5", "C0"]


f, ax = plt.subplots(figsize=(1.5, 2.5), gridspec_kw=dict(left=0.3, bottom=0.25))

# Mean phase drift between event onset and `calculate_drift_after_s` later, per
# fish without eye tracking. A fish is kept only when both the rest and the bout
# drifts can be computed, so all_rest[i] and all_bouts[i] always describe the
# same fish (they are joined by a line below).
all_rest = []
all_bouts = []
for dict_to_plot in all_cropped.values():
    if dict_to_plot["has_eyes"]:
        continue
    # Use this fish's own sampling rate when it was stored; otherwise fall back
    # to the last loaded experiment.
    fs = dict_to_plot["fs"] if "fs" in dict_to_plot else exp.fs
    start_idx = PRE_BOUT_WND_S * fs
    end_idx = (PRE_BOUT_WND_S + calculate_drift_after_s) * fs
    try:
        rest_drift, bouts_drift = [
            np.mean(dict_to_plot[k2][end_idx, :] - dict_to_plot[k2][start_idx, :])
            for k2 in ["rest", "bouts"]
        ]
    except IndexError:
        # The crop does not reach `end_idx` for this fish: skip it entirely.
        continue
    all_rest.append(rest_drift)
    all_bouts.append(bouts_drift)

# One light line per fish, connecting its resting and motion drift.
ax.plot(np.stack([all_rest, all_bouts]), lw=1, c=".8")
w = 0.05
for i, (col, values) in enumerate(zip(cols, [all_rest, all_bouts])):
    # Interquartile range (vertical bar) and median (short horizontal tick):
    ax.plot([i, i], [np.percentile(values, p) for p in [25, 75]], c=col)
    ax.plot([i - w, i + w], [np.percentile(values, 50)] * 2, c=col)

ax.set(
    xlim=(-0.2, 1.2),
    ylabel=f"abs. phase drift after {calculate_drift_after_s} s",
    xticks=[0, 1],
    xticklabels=["resting", "motion"],
)
pltltr.despine(ax)
plt.show()

In [None]:
# Crop the unwrapped network phase and the fictive heading around every bout of
# a single example fish (A_FISH), subtracting each trace's pre-bout mean.
# NOTE(review): this cell uses `exp.fn`, while the rest of the notebook uses
# `exp.fs` for the sampling rate; confirm both attributes exist and agree.
# NOTE(review): `time_arr` defined here is overwritten by the
# custom_crop_shifts_all_dataset() call further down.

# Get unwrapped phase:
exp = LotrExperiment(A_FISH)
unwrapped_ph = np.unwrap(exp.network_phase)

# Crop network phase:
cropped_phase = crop(
    unwrapped_ph,
    exp.bouts_df["idx_imaging"],
    pre_int=PRE_BOUT_WND_S * exp.fn,
    post_int=POST_BOUT_WND_S * exp.fn,
)
# Subtract baseline:
cropped_phase = cropped_phase - np.mean(cropped_phase[: PRE_BOUT_WND_S * exp.fn, :], 0)

# Same, for heading direction:
cropped_head = crop(
    exp.fictive_heading,
    exp.bouts_df["idx_imaging"],
    pre_int=PRE_BOUT_WND_S * exp.fn,
    post_int=POST_BOUT_WND_S * exp.fn,
)
cropped_head = cropped_head - np.mean(cropped_head[: PRE_BOUT_WND_S * exp.fn, :], 0)


# Time (s) of each cropped sample relative to the bout onset:
time_arr = np.arange(cropped_phase.shape[0]) / exp.fn - PRE_BOUT_WND_S

In [None]:
# import numpy as np
import pandas as pd
from bouter.utilities import crop
from lotr import dataset_folders
from lotr.default_vals import DEFAULT_FN, POST_BOUT_WND_S, PRE_BOUT_WND_S
from lotr.experiment_class import LotrExperiment
from lotr.utils import interpolate, resample_matrix
from tqdm import tqdm


def _bouts_and_rest_events(exp, fid):
    """Build one fish's events table: its bouts, then its resting periods ("nb")."""
    bouts_events_df = exp.bouts_df.loc[
        :, ["bias", "direction", "t_start", "idx_imaging", "fid"]
    ]
    nobout_locations = get_nobouts_period_idxs(exp)
    n_nob = len(nobout_locations)
    nobout_events_df = pd.DataFrame(
        dict(
            bias=np.full(n_nob, np.nan),
            direction=np.full(n_nob, "nb"),
            t_start=nobout_locations / exp.fs,
            idx_imaging=nobout_locations,
            fid=fid,
        )
    )
    return pd.concat([bouts_events_df, nobout_events_df], axis=0)


def _interpolated_stimulus_theta(exp):
    """Return the "cl2D_theta" stimulus on the imaging time axis (all NaN if absent)."""
    stim_interp = np.full(exp.n_pts, np.nan)
    try:
        stim_df = exp.stimulus_log
        if "cl2D_theta" in stim_df.columns:
            stim_interp = interpolate(stim_df["t"], stim_df["cl2D_theta"], exp.time_arr)
    except AttributeError:
        # Experiments without a usable stimulus log keep the all-NaN trace.
        pass
    return stim_interp


def _crop_baseline_resample(trace, idxs, fs, fn, time_arr):
    """Crop `trace` around `idxs`, subtract the pre-event mean, resample to `time_arr`."""
    cropped = crop(
        trace,
        idxs,
        pre_int=PRE_BOUT_WND_S * fs,
        post_int=POST_BOUT_WND_S * fs,
    )
    # Subtract baseline (mean over the pre-event window, per event):
    cropped = cropped - np.mean(cropped[: PRE_BOUT_WND_S * fs, :], 0)

    # Resample onto the common time grid if this fish was acquired at another rate:
    if fs != fn:
        fish_time_arr = np.arange(1, cropped.shape[0] + 1) / fs - PRE_BOUT_WND_S
        cropped = resample_matrix(time_arr, fish_time_arr, cropped)
    return cropped


def custom_crop_shifts_all_dataset(crop_stimulus=False):
    """Crop fictive heading and network phase around bouts and resting periods
    of all fish in the dataset.

    For a demo of what is happening, see the "4. Phase dynamics.ipynb" notebook.

    Parameters
    ----------
    crop_stimulus : bool
        If True, also crop the "cl2D_theta" column of each fish's stimulus log.
        Fish without it contribute NaN traces. The result is included in the
        returned tuple.

    Returns
    -------
    (all_phase_cropped, all_head_cropped, events_df, time_arr)
        when `crop_stimulus` is False, or
    (all_phase_cropped, all_head_cropped, all_stim_cropped, events_df, time_arr)
        when it is True. The cropped arrays are n_tpts x n_events matrices,
        baseline-subtracted and resampled to DEFAULT_FN. `events_df` has one row
        per column of those matrices, in the same order: fish by fish, bouts
        first, then resting periods with direction "nb". `time_arr` is the time
        (s) of each row relative to the event.
    """
    fn = DEFAULT_FN

    # Define temporal array for the resampling:
    time_arr = (
        np.arange(1, ((PRE_BOUT_WND_S + POST_BOUT_WND_S) * fn) + 1) / fn
        - PRE_BOUT_WND_S
    )

    all_phase_cropped = []
    all_head_cropped = []
    all_stim_cropped = []
    # One events table per fish, concatenated at the end:
    events_df = []
    for path in tqdm(dataset_folders):
        exp = LotrExperiment(path)
        # TODO recompute to avoid this bugfix
        exp.bouts_df["fid"] = path.name

        crop_events_df = _bouts_and_rest_events(exp, path.name)

        # Crop the network phase and the fictive heading (cumulative tail theta
        # sum), and optionally the stimulus, in the same way:
        traces = [np.unwrap(exp.network_phase), exp.fictive_heading]
        dest_lists = [all_phase_cropped, all_head_cropped]
        if crop_stimulus:
            traces.append(_interpolated_stimulus_theta(exp))
            dest_lists.append(all_stim_cropped)

        for dest_list, trace in zip(dest_lists, traces):
            dest_list.append(
                _crop_baseline_resample(
                    trace, crop_events_df["idx_imaging"], exp.fs, fn, time_arr
                )
            )
        events_df.append(crop_events_df)

    # Concatenate all the results:
    all_phase_cropped = np.concatenate(all_phase_cropped, axis=1)
    all_head_cropped = np.concatenate(all_head_cropped, axis=1)
    events_df = pd.concat(events_df, ignore_index=True)

    if crop_stimulus:
        all_stim_cropped = np.concatenate(all_stim_cropped, axis=1)
        return (
            all_phase_cropped,
            all_head_cropped,
            all_stim_cropped,
            events_df,
            time_arr,
        )
    return all_phase_cropped, all_head_cropped, events_df, time_arr

In [None]:
# Crop phase and heading around all bouts and resting periods of every fish.
# This overwrites any `time_arr` defined in earlier cells.
(
    all_phase_cropped,
    all_head_cropped,
    events_df,
    time_arr,
) = custom_crop_shifts_all_dataset()

In [None]:
# Inspection only: number of events (rows) and columns of the events table.
events_df.shape

In [None]:
def plot_bout_trig(
    events_df,
    cropped_list,
    labels_list,
    ylims=(-4.2, 4.2),
    legend_lab="{} bouts",
    t_arr=None,
):
    """Plot event-triggered traces in two panels, colored by event direction.

    Parameters
    ----------
    events_df : pd.DataFrame
        One row per column of the matrices in `cropped_list`. Its "direction"
        column groups the traces and selects their color in COLS["sides"].
    cropped_list : list of np.ndarray
        n_timepoints x n_events matrices. Only the first two are drawn.
    labels_list : list of str
        Panel titles, paired with `cropped_list`.
    ylims : tuple
        y limits of both panels.
    legend_lab : str or None
        Format string for the legend entry of each direction's mean trace.
        None disables the legend.
    t_arr : np.ndarray, optional
        Time axis (s) of the traces. Defaults to the notebook-level `time_arr`.

    Returns
    -------
    f, axs
        The figure and its two axes.
    """
    x = time_arr if t_arr is None else t_arr

    f, axs = plt.subplots(
        1,
        2,
        figsize=(5, 2.5),
        gridspec_kw=dict(left=0.08, bottom=0.15, top=0.9, right=0.73),
        sharey=True,
    )

    for ax, lab, cropped in zip(axs, labels_list, cropped_list):
        for d in events_df["direction"].unique():
            sel = (events_df["direction"] == d).to_numpy()
            col = COLS["sides"][d]
            # Individual traces (thin, kept out of the legend):
            ax.plot(x, cropped[:, sel], lw=0.3, c=col, label="_nolabel_")
            # Mean trace for this direction (thick, darker):
            ax.plot(
                x,
                np.mean(cropped[:, sel], 1),
                lw=2,
                c=pltltr.dark_col(col),
                label=legend_lab.format(d) if legend_lab is not None else "__nolabel__",
                zorder=30,
            )

        pltltr.despine(ax)
        ax.set(
            xlabel="time from bout (s)",
            ylim=ylims,
            **pltltr.get_pi_labels(0.5, ax="y"),
        )
        ax.set_title(lab, weight="bold")
        ax.axvline(0, lw=0.5, c=".5")
    if legend_lab is not None:
        axs[1].legend(bbox_to_anchor=(1.05, 0.8, 0.2, 0.2))

    return f, axs


# f, axs = plot_bout_trig(
#    exp.bouts_df, [all_phase_cropped, all_head_cropped], ["Δphase", "Δheading"]
# )
# pltltr.savefig("bout_trig_phase_change_onefish")

In [None]:
# For every fish, compute average response for each direction:
# Only (fish, direction) pairs with more than 10 events are kept. Column i of
# all_phase_means / all_head_means corresponds to row i of mean_events_df.
all_phase_means = []
all_head_means = []
mean_events_df = []
for fid in events_df["fid"].unique():
    for d in events_df["direction"].unique():
        sel = (events_df["direction"] == d) & (events_df["fid"] == fid)

        if sum(sel) > 10:
            all_phase_means.append(all_phase_cropped[:, sel].mean(1))
            all_head_means.append(all_head_cropped[:, sel].mean(1))
            # Row of the group's first event, used as a label ("fid", "direction").
            # Its other columns describe that single event, not the mean.
            mean_events_df.append(events_df[sel].iloc[0, :])

all_phase_means = np.stack(all_phase_means).T
all_head_means = np.stack(all_head_means).T
mean_events_df = pd.DataFrame(mean_events_df)

In [None]:
# Add a color for resting periods ("nb"): plot_bout_trig looks up
# COLS["sides"][direction] for every direction in the events table.
COLS["sides"]["nb"] = ".8"
f, axs = plot_bout_trig(
    mean_events_df,
    [all_phase_means, all_head_means],
    ["Δphase", "Δheading"],
    legend_lab="{} (all fish + mn)",
    ylims=(-np.pi, np.pi),
)

In [None]:
from scipy.stats import gaussian_kde

In [None]:
# values from calculation of phase change after bout:
# Δphase / Δhead of each event = mean of its cropped trace between
# quantify_after_s and quantify_after_s + quantify_for_s seconds after onset.
quantify_after_s = 20
quantify_for_s = 5
wnd_delta_phase = np.array([quantify_after_s, quantify_after_s + quantify_for_s])

wnd_pts = (PRE_BOUT_WND_S + wnd_delta_phase) * DEFAULT_FN
events_df["Δphase"] = np.nanmean(all_phase_cropped[slice(*wnd_pts), :], 0)
events_df["Δhead"] = np.nanmean(all_head_cropped[slice(*wnd_pts), :], 0)

# Grid on which the KDEs are evaluated (radians):
xarr = np.arange(-np.pi * 1.5, np.pi * 1.5, 0.01)
# For every fish, compute average response for each direction:
# One KDE per (fish, direction) pair with more than 10 events, using the same
# loop order and threshold as mean_events_df, so row i of all_phase_kdes matches
# row i of mean_events_df. Later cells rely on this alignment.
all_phase_kdes = []
all_head_kdes = []
for fid in events_df["fid"].unique():
    for d in events_df["direction"].unique():
        sel = (events_df["direction"] == d) & (events_df["fid"] == fid)

        if sum(sel) > 10:
            for var, list_to_use in zip(
                ["Δphase", "Δhead"], [all_phase_kdes, all_head_kdes]
            ):
                kde_f = gaussian_kde(events_df.loc[sel, var])
                list_to_use.append(kde_f(xarr))

all_phase_kdes = np.stack(all_phase_kdes)
all_head_kdes = np.stack(all_head_kdes)

In [None]:
def simpleaxis(ax):
    """Strip an axes to a minimal frame.

    The top and right spines are hidden; ticks are kept only on the bottom
    x axis and the left y axis.
    """
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    x_axis = ax.get_xaxis()
    x_axis.tick_bottom()
    y_axis = ax.get_yaxis()
    y_axis.tick_left()

In [None]:
# Traces panels plus a vertical side panel with the Δphase distributions
# (per-fish KDEs, thin, and their average, thick) for lf, rt and nb events.
plt.close("all")
COLS["sides"]["nb"] = ".8"
f, axs = plot_bout_trig(
    mean_events_df,
    [all_head_means, all_phase_means],
    ["Δheading", "Δphase"],
    legend_lab=None,
    ylims=(-np.pi, np.pi),
)

# axs[1].clear()
# axs[1].
# simpleaxis(axs[1])

# Side panel; plt.plot below draws on it because add_axes makes it the current axes.
ax = f.add_axes((0.8, 0.15, 0.19, 0.75))
for direction in ["lf", "rt", "nb"]:
    # direction = "nb"
    mn_hist = []
    for fid in events_df["fid"].unique():
        c = COLS["sides"][direction]
        # Rows of mean_events_df are aligned with rows of all_phase_kdes.
        sel = (mean_events_df["direction"] == direction) & (
            mean_events_df["fid"] == fid
        )
        if sum(sel) > 0:
            kde_f = all_phase_kdes[sel, :]
            mn_hist.append(kde_f)
            # plt.fill_between(xarr, np.zeros(len(xarr)), kde_f[0, :], fc=c, lw=0, alpha=0.05)
            # Normalized to sum 1 and plotted with Δphase on the y axis:
            plt.plot(kde_f[0, :] / np.sum(kde_f[0, :]), xarr, c=c, lw=1, alpha=0.5)

    # Average KDE across fish (raises if no fish has this direction):
    mn_hist = np.mean(np.concatenate(mn_hist), 0)
    plt.plot(
        mn_hist / np.sum(mn_hist), xarr, c=pltltr.shift_lum(c, -0.2), zorder=100, lw=1.5
    )
axs[0].set(ylim=(-np.pi * 0.7, np.pi * 0.7), **pltltr.get_pi_labels(0.5, ax="y"))
ax.set(
    xlabel="P(drift)",
    xlim=(-0.001, 0.015),
    ylim=(-np.pi * 0.7, np.pi * 0.7),
    yticks=pltltr.get_pi_labels(0.5, ax="y")["yticks"],
    yticklabels=[],
)
plt.show()
pltltr.despine(ax)

In [None]:
# Same figure as above, but the side panel shows only the resting-period ("nb")
# Δphase distributions (per-fish KDEs, thin, and their average, thick).
plt.close("all")
COLS["sides"]["nb"] = ".8"
f, axs = plot_bout_trig(
    mean_events_df,
    [all_head_means, all_phase_means],
    ["Δheading", "Δphase"],
    legend_lab=None,
    ylims=(-np.pi, np.pi),
)

# axs[1].clear()
# axs[1].
# simpleaxis(axs[1])

# Side panel; plt.plot below draws on it because add_axes makes it the current axes.
ax = f.add_axes((0.8, 0.15, 0.19, 0.75))
for direction in ["nb"]:
    # direction = "nb"
    mn_hist = []
    for fid in events_df["fid"].unique():
        c = COLS["sides"][direction]
        # Rows of mean_events_df are aligned with rows of all_phase_kdes.
        sel = (mean_events_df["direction"] == direction) & (
            mean_events_df["fid"] == fid
        )
        if sum(sel) > 0:
            kde_f = all_phase_kdes[sel, :]
            mn_hist.append(kde_f)
            # plt.fill_between(xarr, np.zeros(len(xarr)), kde_f[0, :], fc=c, lw=0, alpha=0.05)
            plt.plot(kde_f[0, :] / np.sum(kde_f[0, :]), xarr, c=c, lw=1, alpha=0.5)

    # Average KDE across fish, normalized to sum 1:
    mn_hist = np.mean(np.concatenate(mn_hist), 0)
    plt.plot(
        mn_hist / np.sum(mn_hist), xarr, c=pltltr.shift_lum(c, -0.2), zorder=100, lw=1.5
    )
axs[0].set(ylim=(-np.pi * 0.7, np.pi * 0.7), **pltltr.get_pi_labels(0.5, ax="y"))
ax.set(
    xlabel="P(drift)",
    xlim=(-0.001, 0.015),
    ylim=(-np.pi * 0.7, np.pi * 0.7),
    yticks=pltltr.get_pi_labels(0.5, ax="y")["yticks"],
    yticklabels=[],
)
plt.show()
pltltr.despine(ax)

In [None]:
plt.close("all")
COLS["sides"]["nb"] = ".8"
f, axs = plot_bout_trig(
    mean_events_df,
    [all_head_means, all_phase_means],
    ["Δheading", "Δphase"],
    legend_lab=None,
    ylims=(-np.pi, np.pi),
)

# Side panel, drawn vertically so it shares the Δphase range of the traces.
ax = f.add_axes((0.8, 0.15, 0.19, 0.75))

# Resting periods ("nb"): per-fish KDEs (thin) and their average (thick, darker).
# Rows of mean_events_df are aligned with rows of all_phase_kdes.
for direction in ["nb"]:
    mn_hist = []
    for fid in events_df["fid"].unique():
        c = COLS["sides"][direction]
        sel = (mean_events_df["direction"] == direction) & (
            mean_events_df["fid"] == fid
        )
        if sum(sel) > 0:
            kde_f = all_phase_kdes[sel, :]
            mn_hist.append(kde_f)
            ax.plot(kde_f[0, :] / np.sum(kde_f[0, :]), xarr, c=c, lw=1, alpha=0.5)

    mn_hist = np.mean(np.concatenate(mn_hist), 0)
    ax.plot(
        mn_hist / np.sum(mn_hist), xarr, c=pltltr.shift_lum(c, -0.2), zorder=100, lw=1.5
    )

# Turning bouts: KDE of the per-fish mean Δphase for each direction.
for direction in ["lf", "rt"]:
    sel_df = events_df[events_df["direction"] == direction]
    # Average only the Δphase column: a DataFrame-wide groupby mean raises a
    # TypeError on the string "direction" column with pandas >= 2.0.
    kde_f = gaussian_kde(sel_df.groupby("fid")["Δphase"].mean())
    mn_hist = kde_f(xarr)
    ax.plot(
        mn_hist / np.sum(mn_hist),
        xarr,
        c=COLS["sides"][direction],
        zorder=100,
        lw=1.5,
    )

axs[0].set(ylim=(-np.pi * 0.7, np.pi * 0.7), **pltltr.get_pi_labels(0.5, ax="y"))
ax.set(
    xlabel="P(drift)",
    xlim=(-0.001, 0.015),
    ylim=(-np.pi * 0.7, np.pi * 0.7),
    yticks=pltltr.get_pi_labels(0.5, ax="y")["yticks"],
    yticklabels=[],
)
plt.show()
pltltr.despine(ax)

In [None]:
plt.close("all")
COLS["sides"]["nb"] = ".8"
f, axs = plot_bout_trig(
    mean_events_df,
    [all_head_means, all_phase_means],
    ["Δheading", "Δphase"],
    legend_lab=None,
    ylims=(-np.pi, np.pi),
)

# Side panel, drawn vertically so it shares the Δphase range of the traces:
# KDE of the per-fish mean Δphase for left and right bouts.
ax = f.add_axes((0.8, 0.15, 0.19, 0.75))
for direction in ["lf", "rt"]:
    sel_df = events_df[events_df["direction"] == direction]
    # Average only the Δphase column: a DataFrame-wide groupby mean raises a
    # TypeError on the string "direction" column with pandas >= 2.0.
    kde_f = gaussian_kde(sel_df.groupby("fid")["Δphase"].mean())
    # `mn_hist` is left holding the last ("rt") KDE; later cells read it.
    mn_hist = kde_f(xarr)
    ax.plot(
        mn_hist / np.sum(mn_hist),
        xarr,
        c=COLS["sides"][direction],
        zorder=100,
        lw=1.5,
    )
axs[0].set(ylim=(-np.pi * 0.7, np.pi * 0.7), **pltltr.get_pi_labels(0.5, ax="y"))
ax.set(
    xlabel="P(drift)",
    xlim=(-0.001, 0.015),
    ylim=(-np.pi * 0.7, np.pi * 0.7),
    yticks=pltltr.get_pi_labels(0.5, ax="y")["yticks"],
    yticklabels=[],
)
plt.show()
pltltr.despine(ax)

In [None]:
# Save the figure currently bound to `f` (the last one created above).
# NOTE(review): hard-coded, user-specific output path; FIGURES_LOCATION is
# imported at the top of the notebook and may be the intended destination.
f.savefig("/Users/luigipetrucco/Desktop/histograms.pdf")

In [None]:
# 5% and 95% points of the distribution currently held in `mn_hist`. If cells
# are run in order, that is the last KDE computed in the histogram cells above.
# Each entry of `fork` is the index k at which the cumulative sum strictly
# crosses the threshold between k and k + 1. If no strict crossing exists
# (e.g. a value exactly equals the threshold), `[0, 0]` raises IndexError.
cumsum = np.cumsum(mn_hist / np.sum(mn_hist))
thrs = (0.05, 0.95)
fork = np.array(
    [
        np.argwhere((cumsum[1:] > thrs[0]) & (cumsum[:-1] < thrs[0]))[0, 0],
        np.argwhere((cumsum[1:] > thrs[1]) & (cumsum[:-1] < thrs[1]))[0, 0],
    ]
)

In [None]:
# Δphase values (grid points just before the crossings) at the 5% / 95% thresholds.
xarr[fork]

In [None]:
from scipy.optimize import curve_fit


def gaus(x, a, x0, sigma):
    """Unnormalized Gaussian with peak height `a`, centre `x0` and width `sigma`."""
    squared_dist = np.square(x - x0)
    return a * np.exp(-squared_dist / (2 * np.square(sigma)))


# Fit a Gaussian (initial guess: amplitude 1, centre 0, sigma 0.5) to whatever
# distribution `mn_hist` currently holds (set by the histogram cells above).
popt, pcov = curve_fit(gaus, xarr, mn_hist, p0=[1, 0, 0.5])

In [None]:
# Fitted (amplitude, centre, sigma) of the Gaussian.
popt

In [None]:
# Fit a Gaussian to each fish's resting-period ("nb") Δphase KDE and collect the
# fitted widths. Rows of mean_events_df are aligned with rows of all_phase_kdes.
all_sigmas = []
direction = "nb"
for fid in events_df["fid"].unique():
    sel = (mean_events_df["direction"] == direction) & (mean_events_df["fid"] == fid)
    if sum(sel) > 0:
        kde_f = all_phase_kdes[sel, :]
        popt, pcov = curve_fit(gaus, xarr, kde_f[0, :], p0=[1, 0, 0.5])
        # sigma only enters gaus() squared, so the fit can converge to a negative
        # value. Keep its magnitude so the median across fish is meaningful.
        all_sigmas.append(abs(popt[2]))

In [None]:
# Median fitted Gaussian width across fish (resting periods).
np.median(all_sigmas)

In [None]:
# Visual symmetry check: overlay `mn_hist` with its reversed copy (mirrored about
# the middle of the array).
# NOTE(review): check what `mn_hist` holds when this cell runs; if it is an empty
# list, nothing is drawn.
plt.figure()
plt.plot(mn_hist)
plt.plot(mn_hist[::-1])

In [None]:
# Re-draw the per-fish mean traces (phase and heading) for every direction,
# including resting periods ("nb"), which need a color in the shared palette.
COLS["sides"]["nb"] = ".8"
f, axs = plot_bout_trig(
    mean_events_df,
    [all_phase_means, all_head_means],
    ["Δphase", "Δheading"],
    legend_lab="{} (all fish + mn)",
    ylims=(-np.pi, np.pi),
)

In [None]:
# Copy of the plotting loop inside plot_bout_trig.
# NOTE(review): `labels_list`, `cropped_list` and `legend_lab` are never defined
# at notebook level (they are only parameter names of plot_bout_trig), so this
# cell raises NameError as written. Even with those names defined, the columns of
# per-fish mean matrices would not align with `events_df` rows; they align with
# `mean_events_df`. Consider removing this cell or replacing it with a
# plot_bout_trig call.
for ax, lab, cropped in zip(axs, labels_list, cropped_list):
    for d in events_df["direction"].unique():
        sel = events_df["direction"] == d
        ax.plot(
            time_arr, cropped[:, sel], lw=0.3, c=COLS["sides"][d], label="_nolabel_",
        )
        ax.plot(
            time_arr,
            np.mean(cropped[:, sel], 1),
            lw=2,
            c=pltltr.dark_col(COLS["sides"][d]),
            label=legend_lab.format(d),
            zorder=30,
        )