# Phase dynamics and visual feedback

In this notebook, we build upon the previous quantifications of the relationship between network phase and directional swimming, and test whether visual feedback is necessary for, or involved in, the rotation of the network.

In [None]:
from scipy.stats import wilcoxon
from tqdm import tqdm

In [None]:
%matplotlib widget

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from scipy.stats import wilcoxon
from tqdm import tqdm

import lotr.plotting as pltltr
from lotr import A_FISH, FIGURES_LOCATION, LotrExperiment, dataset_folders
from lotr.behavior import get_fictive_heading
from lotr.result_logging import ResultsLogger
from lotr.utils import crop, interpolate, linear_regression, resample_matrix

# Alias of the lotr plotting module's shared colour palette (same dict object).
COLS = pltltr.COLS

# Collects the summary statistics registered through `logger.add_entry`.
logger = ResultsLogger()

## Load stimulus information

We use the `LotrExperiment.stim_trials_df` attribute to load from each experiment a dataframe describing the sequence of stimuli, and then concatenate all of them in a single dataframe.
The resulting `all_stim_conds` dataframe contains a `condition` column that can take one of the following values:
 - **`darkness`**: black screen, fish swimming in the darkness
 - **`closed_loop`**: the fish experiences visual feedback when moving. This consists of a pink-noise pattern that, every time the fish performs a bout, rotates proportionally to the laterality index of the bout (with inverted sign) and translates backward with an average speed of approx. 10 mm/s. In the gainmod experiments, a gain factor that scales the amount of rotation induced by each directional bout alternates between 0.5 (**low gain**), 1.0 (**normal gain**), 2.0 (**high gain**) and -1 (**inverted motion**).
 - **`natural_motion`**: pink noise that is translated/rotated independently of the fish's motion, in ways similar to what the fish would experience when swimming over the pattern — a sort of playback of a closed-loop experiment. This stimulus induces very small responses in visual directional-motion/rotation cells, as every movement is very brief.
 - **`directional_motion`**: pink noise that moves in 8 equally spaced directions on the plane below the fish in periods of 10 s, separated by 10 s pauses in which the pattern is static. As the fish mostly swims during directional motion, for this quantification all bouts from both the motion and the static periods are, for now, pooled in the `directional_motion` condition (TODO: maybe split them).

In [None]:
# Stimulus-trial table of every experiment, stacked into a single dataframe.
all_stim_logs = []
for exp_path in tqdm(dataset_folders):
    all_stim_logs.append(LotrExperiment(exp_path).stim_trials_df)

all_stim_conds = pd.concat(all_stim_logs, ignore_index=True, sort=True)

### Visualize all experiment types

Let's have a quick look at all the experiment types:

In [None]:
def plot_exp_condition(exp_stim_df, ax=None, alpha=0.3, **kwargs):
    """Plot experiment conditions as vertical ranges on some axes.

    Each row of ``exp_stim_df`` becomes one ``axvspan`` coloured according to
    the module-level ``COLS["stim_conditions"]`` palette. Closed-loop rows are
    coloured by their ``gain_theta`` value. Closed-loop rows without gain
    information (no "gain_theta" column, or a NaN value) use the gain-1 colour.

    Parameters
    ----------
    exp_stim_df : pd.DataFrame
        Dataframe containing the "t_start", "t_stop" and "condition" entries,
        and optionally "gain_theta" for closed-loop trials.
    ax : plt.Axes (optional)
        Axes over which to plot, by default the current ones.
    alpha : float (optional)
        Transparency (default=0.3).
    kwargs : dict
        Additional arguments for the plt.axvspan function.

    Returns
    -------
    list
        All the axvspan outputs, one per row of ``exp_stim_df``.

    Raises
    ------
    KeyError
        If a condition (or a closed-loop gain) has no colour in
        ``COLS["stim_conditions"]``.

    """
    stim_colors = COLS["stim_conditions"]
    if ax is None:
        ax = plt.gca()

    has_gain = "gain_theta" in exp_stim_df.columns

    vspan = []
    # Positional row iteration: also correct when index labels are duplicated.
    for _, trial in exp_stim_df.iterrows():
        condition = trial["condition"]
        if condition == "closed_loop":
            gain = trial["gain_theta"] if has_gain else np.nan
            if pd.isna(gain):
                # No gain recorded for this trial: draw it with the gain-1 colour.
                gain = 1
            col = stim_colors["closed_loop"][gain]
        else:
            col = stim_colors[condition]
        vspan.append(
            ax.axvspan(
                trial["t_start"], trial["t_stop"], fc=col, lw=0, alpha=alpha, **kwargs
            )
        )

    return vspan

In [None]:
qualitative = COLS["qualitative"]
cl = COLS["sides"]["lf"]  # alternative: qualitative[1]
ol = qualitative[3]

# Closed-loop gains 0.5 / 1 / 2 are luminance-shifted versions of the
# closed-loop colour; the inverted gain (-1) is a shifted open-loop colour.
closed_loop_shades = {
    gain: pltltr.shift_lum(base_col, lum_shift)
    for gain, base_col, lum_shift in [
        (0.5, cl, -0.15),
        (1, cl, 0),
        (2, cl, +0.15),
        (-1, ol, -0.15),
    ]
}

# Colour of every stimulus condition (keys are the values of `condition`).
COLS["stim_conditions"] = dict(
    darkness=".5",
    natural_motion=ol,
    directional_motion=[0.62, 0.29, 0.63],
    closed_loop=closed_loop_shades,
)

In [None]:
all_exp_types = all_stim_conds.exp_type.unique()

# One panel per experiment type showing the stimulus sequence of an example
# fish, plus a bottom panel that only carries the shared time axis.
f, axs = plt.subplots(
    len(all_exp_types) + 1,
    1,
    figsize=(5, 4),
    gridspec_kw=dict(top=0.8, right=0.8, hspace=2),
)
*type_axs, time_ax = axs

for type_ax, exp_type in zip(type_axs, all_exp_types):
    fids = all_stim_conds.loc[all_stim_conds["exp_type"] == exp_type, "fid"].unique()
    example_fid = fids[0]  # the first fish of this type is shown

    pltltr.plot_exp_condition(
        all_stim_conds[all_stim_conds["fid"] == example_fid], ax=type_ax, alpha=1
    )
    type_ax.set(title=f"n={len(fids)}", xlim=(0.0, 2500))
    pltltr.despine(type_ax, "all")

pltltr.despine(time_ax, ["left", "top", "right"])
time_ax.set(xlabel="Time (s)", xlim=(0.0, 2500))

# Colour-coded condition labels, stacked top to bottom on a text-only axis.
legend_ax = f.add_axes((0.75, 0.95, 0.2, 0.2))
legend_ax.axis("off")

legend_items = []
for cond, cond_col in COLS["stim_conditions"].items():
    cond_label = cond.replace("_", " ").capitalize()
    if cond == "closed_loop":
        # one entry per gain value
        legend_items.extend(
            (f"{cond_label} (gain {g})", g_col) for g, g_col in cond_col.items()
        )
    else:
        legend_items.append((cond_label, cond_col))

y_pos = 0
for label, label_col in legend_items:
    legend_ax.text(0.01, y_pos, label, c=label_col, weight="bold", fontsize=8)
    y_pos -= 0.17

pltltr.savefig("experiment_types", folder="S8b")

## Bout-induced phase changes with and without visual feedback

First of all, we crop the phase shifts and the cumulative theta turned around all bouts, as we did in the previous notebook. Then, for every bout we add a label specifying the stimulus that was ongoing at its onset time:

In [None]:
from lotr.analysis.shift_cropping import crop_shifts_all_dataset
from lotr.default_vals import (
    DEFAULT_FN,
    POST_BOUT_WND_S,
    PRE_BOUT_WND_S,
    WND_DELTA_PHASE_S,
)

(
    all_phase_cropped,
    all_head_cropped,
    all_stim_cropped,
    events_df,
    time_arr,
) = crop_shifts_all_dataset(crop_stimulus=True)

# TODO: fix this misnaming in actual preprocessing
events_df.loc[events_df["fid"] == "210601_f0_2d_vr_eyes", "fid"] = "210601_f0_2dvr_eyes"


def _stimulus_labels(trials_by_fid, fid, t_onset):
    """Return (condition, exp_type, gain_theta) of the trial ongoing at a bout onset.

    Trials are treated as half-open intervals [t_start, t_stop), so a bout that
    starts exactly at a trial onset belongs to that trial. If several trials
    match, the first one is used. If none matches, ("-", NaN, NaN) is returned.
    gain_theta is only filled for closed-loop trials.
    """
    trials = trials_by_fid.get(fid)
    if trials is None:
        return "-", np.nan, np.nan
    ongoing = trials[(trials["t_start"] <= t_onset) & (trials["t_stop"] > t_onset)]
    if len(ongoing) == 0:
        return "-", np.nan, np.nan
    trial = ongoing.iloc[0]
    if trial["condition"] == "closed_loop":
        gain = trial.get("gain_theta", np.nan)
    else:
        gain = np.nan
    return trial["condition"], trial["exp_type"], gain


# Add condition, experiment-type and gain labels to every bout, extracting
# scalar values from the matching trial. Trials are grouped by fish once, so
# each bout only scans the trials of its own fish.
trials_by_fid = dict(tuple(all_stim_conds.groupby("fid")))
labels = pd.DataFrame(
    [
        _stimulus_labels(trials_by_fid, fid, t_onset)
        for fid, t_onset in tqdm(
            zip(events_df["fid"], events_df["t_start"]), total=len(events_df)
        )
    ],
    index=events_df.index,
    columns=["condition", "exp_type", "gain_theta"],
)
events_df["condition"] = labels["condition"]
events_df["exp_type"] = labels["exp_type"]
events_df["gain_theta"] = labels["gain_theta"].astype(float)

# Calculate amount of shift around each motor event: average over the window
# WND_DELTA_PHASE_S offset by PRE_BOUT_WND_S, converted from seconds to
# integer sample indices so that it can be used as a slice.
wnd_pts = np.round(np.asarray(PRE_BOUT_WND_S + WND_DELTA_PHASE_S) * DEFAULT_FN).astype(int)
events_df["delta_phase"] = np.nanmean(all_phase_cropped[slice(*wnd_pts), :], 0)
events_df["delta_head"] = np.nanmean(all_head_cropped[slice(*wnd_pts), :], 0)

### Effects of gain modulation

Now, for the `gainmod` experiments, we can test whether the gain affects the relationship between the peri-bout changes in heading and in network phase, quantified by their linear-regression coefficient.

In [None]:
gains = [0.5, 1, 2, -1]  # fix sequence for consistency

# select bouts from gain-modulation experiments:
gainmod_df = events_df[events_df["exp_type"] == "gainmod"]

# One row per fish, one column per gain: regression coefficient of Δhead vs.
# Δphase (second output of `linear_regression`) over the bouts at that gain.
gain_rows = []
for fid in gainmod_df["fid"].unique():
    fish_df = events_df[events_df["fid"] == fid]
    row = dict(fid=fid)
    for g in gains:
        sel_df = fish_df[fish_df["gain_theta"] == g]
        _, coef = linear_regression(sel_df["delta_phase"], sel_df["delta_head"])
        row[g] = coef
    gain_rows.append(row)

# Always a DataFrame (possibly empty) indexed by fish id, so that downstream
# cells can rely on its type.
gain_results_df = pd.DataFrame(gain_rows, columns=["fid"] + gains).set_index("fid")

In [None]:
if len(gainmod_df) > 0:
    # One colour per gain, in the same order as the columns of gain_results_df.
    cols = list(COLS["stim_conditions"]["closed_loop"].values())

    f, ax = plt.subplots(figsize=(2.0, 2), gridspec_kw=dict(left=0.2, bottom=0.2))
    pltltr.tick_with_bars(
        gain_results_df, cols=cols, lw=1.5, s=0.08,
    )
    # Grey lines join the values of each fish across gains.
    plt.plot(gain_results_df.values.T, lw=1, c=".8")
    ax.set(
        ylabel=(r"Δhead Δphase $\alpha$"),
        xlabel="Gain",
        xticks=[0, 1, 2, 3],
        xticklabels=[0.5, 1, 2, -1],
        xlim=(-0.3, 3.3),
        ylim=(-4, 1),
    )
    pltltr.despine(ax)

    # Run all possible pair-wise (paired, per-fish) Wilcoxon tests.
    # NOTE(review): p-values are not corrected for multiple comparisons.
    all_res = []
    for i, gain_a in enumerate(gains):
        for gain_b in gains[i + 1:]:
            test_res = wilcoxon(gain_results_df[gain_a], gain_results_df[gain_b])
            print(f"{gain_a} vs. {gain_b}: p={test_res.pvalue}")
            all_res.append(test_res.pvalue)

    # A single "... on all comparisons" label is only truthful when the most
    # and the least significant comparisons fall in the same class.
    strongest = pltltr.get_pval_stars(min(all_res))
    weakest = pltltr.get_pval_stars(max(all_res))
    if strongest == weakest:
        pval_label = strongest + " on all comparisons"
    else:
        pval_label = f"{weakest} to {strongest} across comparisons"
    ax.text(1.5, 0.5, pval_label, ha="center")

    pltltr.savefig("different_gains_comparison")

In [None]:
# Display the gain values, in the order used for the columns of gain_results_df.
gains

In [None]:
# Display the raw (uncorrected) p-values of the pair-wise Wilcoxon tests between gains.
# NOTE(review): `all_res` only exists if the gain-comparison cell ran its tests.
all_res

In [None]:
gainmod_fish = gainmod_df["fid"].unique()

# 2-column grid of example traces: keep at least the original 3 rows and add
# rows when there are more fish than panels.
n_cols = 2
n_rows = max(3, int(np.ceil(len(gainmod_fish) / n_cols)))
f, axs = plt.subplots(
    n_rows,
    n_cols,
    figsize=(8, n_rows),
    gridspec_kw=dict(wspace=0.35, hspace=0.4),
    squeeze=False,
)
axs_flat = axs.flatten()

for fid, ax in zip(tqdm(gainmod_fish), axs_flat):
    path = [d for d in dataset_folders if d.name == fid][0]
    exp = LotrExperiment(path)
    phase = np.unwrap(exp.network_phase)
    fictive_head = exp.fictive_heading

    # Stimulus epochs as coloured background.
    pltltr.plot_exp_condition(
        all_stim_conds[all_stim_conds["fid"] == fid], ax=ax, alpha=0.2
    )
    # Heading on the main axis, unwrapped phase on a twin axis.
    twin_ax = ax.twinx()
    ax.plot(exp.time_arr, fictive_head, c=COLS["th_plot"], label="Estimated heading")
    # Label-only handle so that the legend on `ax` also lists the phase trace.
    ax.plot([], [], c=COLS["ph_plot"], label="Network phase")
    twin_ax.plot(exp.time_arr, phase, c=COLS["ph_plot"])
    ax.set_title(exp.exp_code, fontsize=6, loc="left")
    # Flip the phase axis (swap its lower and upper limits).
    lims = twin_ax.get_ylim()
    twin_ax.set_ylim((lims[1], lims[0]))
    pltltr.add_scalebar(
        ax, xlen=500, xpos=100, ylen=0, disable_axis=False, xlabel="", ylabel=""
    )

    for a in (ax, twin_ax):
        pltltr.despine(a, ["top", "bottom"])
    # Colour each y axis like the trace it refers to.
    for a, col in zip([ax, twin_ax], [COLS["th_plot"], COLS["ph_plot"]]):
        for tick in [*a.yaxis.get_ticklines(), *a.yaxis.get_ticklabels()]:
            tick.set_color(col)

# Remove all spines from the panels left unused.
for spare_ax in axs_flat[len(gainmod_fish):]:
    pltltr.despine(spare_ax, "all")

axs[0, -1].legend(
    loc=2, bbox_to_anchor=(0.55, 1.7), labelcolor="linecolor", handlelength=0.0
)

# Colour key for the closed-loop gains.
ax_lab = f.add_axes((0.1, 0.02, 0.4, 0.05))
for i, (g, col) in enumerate(COLS["stim_conditions"]["closed_loop"].items()):
    ax_lab.text(i * 0.25, 0, f"Gain {g}", c=col, weight="bold", fontsize=8)
ax_lab.set(xlim=(-0.1, 1.1), ylim=(-0.1, 1))
ax_lab.axis("off")

pltltr.savefig("gains_comparison", folder="S8")

In [None]:
# Log the per-fish coefficients of each gain (summarised with moment="median").
# len() works whether gain_results_df is a DataFrame or an empty list.
if len(gain_results_df) > 0:
    for g in gain_results_df.columns:
        logger.add_entry(
            f"slope_gain{g}",
            gain_results_df[g],
            fids=list(gain_results_df.index),
            moment="median",
        )

There seems to be no effect of the gain!

### Effects of closed- vs. open-loop stimuli

In [None]:
# Select fish from closed- vs. open-loop ("clol") experiments:
included_fish = events_df.loc[events_df["exp_type"] == "clol", "fid"].unique()

# A condition needs more than this many bouts to get a regression coefficient.
MIN_BOUTS_PER_CONDITION = 10

is_closed_loop = events_df["condition"] == "closed_loop"
clol_rows = []
for fid in included_fish:
    is_fish = events_df["fid"] == fid
    row = dict(fid=fid)
    # NOTE: "ol" pools every bout whose condition is not closed loop.
    for key, cond_sel in [("cl", is_closed_loop), ("ol", ~is_closed_loop)]:
        sel_df = events_df[cond_sel & is_fish]
        if len(sel_df) > MIN_BOUTS_PER_CONDITION:
            _, coef = linear_regression(sel_df["delta_phase"], sel_df["delta_head"])
        else:
            coef = np.nan
        row[key] = coef
    clol_rows.append(row)

# Explicit columns keep set_index working when no clol fish are present; fish
# without enough bouts in either condition are dropped.
clol_results_df = (
    pd.DataFrame(clol_rows, columns=["fid", "cl", "ol"]).dropna().set_index("fid")
)

In [None]:
if len(clol_results_df) > 0:
    cols = [
        COLS["stim_conditions"]["closed_loop"][1],
        COLS["stim_conditions"]["natural_motion"],
    ]

    f, ax = plt.subplots(figsize=(1.5, 2), gridspec_kw=dict(left=0.5, bottom=0.2))
    pltltr.tick_with_bars(
        clol_results_df, cols=cols, lw=1.5,
    )
    # Grey lines join the closed- and open-loop values of each fish.
    plt.plot(clol_results_df.values.T, lw=1, c=".8")
    ax.set(
        ylabel=(r"Δhead Δphase $\alpha$"),
        xticks=[0, 1],
        xticklabels=["closed \nloop", "open \nloop"],
        xlim=(-0.3, 1.3),
        ylim=(-4, 1),
    )
    # Paired test across fish; get_pval_stars is given the p-value (as in the
    # gain comparison), not the WilcoxonResult object.
    test_res = wilcoxon(clol_results_df["cl"], clol_results_df["ol"])
    ax.text(0.5, 0.5, pltltr.get_pval_stars(test_res.pvalue), ha="center")
    pltltr.despine(ax)

    pltltr.savefig("cl_ol_comparison")

Closed- vs. open-loop feedback does not seem to have any effect either!

In [None]:
# 2-column grid of example traces: keep at least the original 4 rows and add
# rows when there are more fish than panels.
n_cols = 2
n_rows = max(4, int(np.ceil(len(included_fish) / n_cols)))
f, axs = plt.subplots(
    n_rows,
    n_cols,
    figsize=(8, n_rows),
    gridspec_kw=dict(wspace=0.35, hspace=0.4),
    squeeze=False,
)
axs_flat = axs.flatten()

for fid, ax in zip(tqdm(included_fish), axs_flat):
    path = [d for d in dataset_folders if d.name == fid][0]
    exp = LotrExperiment(path)
    phase = np.unwrap(exp.network_phase)
    fictive_head = exp.fictive_heading

    # Stimulus epochs as coloured background.
    pltltr.plot_exp_condition(
        all_stim_conds[all_stim_conds["fid"] == fid], ax=ax, alpha=0.2
    )
    # Heading on the main axis, unwrapped phase on a twin axis.
    twin_ax = ax.twinx()
    ax.plot(exp.time_arr, fictive_head, c=COLS["th_plot"], label="Estimated heading")
    # Label-only handle so that the legend on `ax` also lists the phase trace.
    ax.plot([], [], c=COLS["ph_plot"], label="Network phase")
    twin_ax.plot(exp.time_arr, phase, c=COLS["ph_plot"])
    ax.set_title(exp.exp_code, fontsize=6, loc="left")
    # Flip the phase axis (swap its lower and upper limits).
    lims = twin_ax.get_ylim()
    twin_ax.set_ylim((lims[1], lims[0]))
    pltltr.add_scalebar(
        ax, xlen=500, xpos=100, ylen=0, disable_axis=False, xlabel="", ylabel=""
    )

    for a in (ax, twin_ax):
        pltltr.despine(a, ["top", "bottom"])
    # Colour each y axis like the trace it refers to.
    for a, col in zip([ax, twin_ax], [COLS["th_plot"], COLS["ph_plot"]]):
        for tick in [*a.yaxis.get_ticklines(), *a.yaxis.get_ticklabels()]:
            tick.set_color(col)

# Remove all spines from the panels left unused.
for spare_ax in axs_flat[len(included_fish):]:
    pltltr.despine(spare_ax, "all")

axs[0, -1].legend(
    loc=2, bbox_to_anchor=(0.55, 1.7), labelcolor="linecolor", handlelength=0.0
)

# Colour key for the closed- and open-loop conditions.
ax_lab = f.add_axes((0.1, 0.07, 0.4, 0.05))
key_items = [
    ("Closed loop", COLS["stim_conditions"]["closed_loop"][1]),
    ("Open loop", COLS["stim_conditions"]["natural_motion"]),
]
for i, (lab, col) in enumerate(key_items):
    ax_lab.text(i * 0.25, 0, lab, weight="bold", fontsize=8, c=col)
ax_lab.set(xlim=(-0.1, 0.8), ylim=(-0.1, 1))
ax_lab.axis("off")

pltltr.savefig("cl_ol_comparison", folder="S8")

In [None]:
# Log the per-fish closed-/open-loop coefficients (summarised with moment="median").
# Fish ids are passed by keyword, as in the gain logging.
if len(clol_results_df) > 0:
    for c in clol_results_df.columns:
        logger.add_entry(
            f"slope_{c}",
            clol_results_df[c],
            fids=list(clol_results_df.index),
            moment="median",
        )

# Phase bout-triggered averages in closed loop experiments

In [None]:
if all_stim_cropped is not None:
    included_fish = events_df.loc[events_df["exp_type"] == "gainmod", "fid"].unique()
    min_bouts = 5  # a fish contributes to a trace only with more bouts than this

    f, axs = plt.subplots(1, 3, figsize=(9, 3), gridspec_kw=dict(right=0.8))

    for d in ["fw", "lf", "rt"]:
        side_col = COLS["sides"][d]
        # Gains 0.5 / 1 / 2: increasing luminance shifts of the direction colour;
        # gain -1: the same colour with its channels permuted.
        gain_cols = [pltltr.shift_lum(side_col, s) for s in [0, 0.2, 0.4]] + [
            np.array(side_col)[[2, 0, 1]]
        ]
        for gain, lum_c in zip([0.5, 1, 2, -1], gain_cols):
            # Per-fish median traces of phase, sign-flipped heading and stimulus.
            phase_meds, head_meds, stim_meds = [], [], []
            for fid in included_fish:
                sel_df = events_df[
                    (events_df["gain_theta"] == gain)
                    & (events_df["direction"] == d)
                    & (events_df["fid"] == fid)
                ]
                if len(sel_df) > min_bouts:
                    sel_idx = sel_df.index
                    phase_meds.append(np.nanmedian(all_phase_cropped[:, sel_idx], 1))
                    # median(-x) == -median(x): negate the small result instead of
                    # allocating a negated copy of the whole heading array.
                    head_meds.append(-np.nanmedian(all_head_cropped[:, sel_idx], 1))
                    stim_meds.append(np.nanmedian(all_stim_cropped[:, sel_idx], 1))

            # Average the per-fish medians across fish.
            for ax, meds in zip(axs, [phase_meds, head_meds, stim_meds]):
                ax.plot(
                    np.nanmean(meds, 0), c=lum_c, lw=2, label=f"gain {gain}, dir {d}"
                )

    for ax, title in zip(axs, ["phase", "heading", "stim"]):
        pltltr.despine(ax)
        ax.set_title(title)
    axs[2].legend(loc=2, bbox_to_anchor=(1, 1))

In [None]:
if all_stim_cropped is not None:
    included_fish = events_df.loc[events_df["exp_type"] == "gainmod", "fid"].unique()
    plt.figure()
    for d in ["fw", "lf", "rt"]:
        side_col = COLS["sides"][d]
        # Same colour code as the three-panel bout-triggered figure, so that the
        # four gains of one direction can be told apart.
        gain_cols = [pltltr.shift_lum(side_col, s) for s in [0, 0.2, 0.4]] + [
            np.array(side_col)[[2, 0, 1]]
        ]
        for gain, gain_col in zip([0.5, 1, 2, -1], gain_cols):
            # Mean heading trace of each fish, then averaged across fish.
            # NOTE(review): heading is not sign-flipped here, unlike the
            # three-panel figure — confirm which convention is intended.
            avgs = []
            for fid in included_fish:
                sel_df = events_df[
                    (events_df["gain_theta"] == gain)
                    & (events_df["direction"] == d)
                    & (events_df["fid"] == fid)
                ]
                avgs.append(np.nanmean(all_head_cropped[:, sel_df.index], 1))

            avgs = np.nanmean(np.array(avgs), 0)
            plt.plot(avgs, c=gain_col, lw=2, label=f"gain {gain}, dir {d}")
    plt.legend(fontsize=6, ncol=3)

In [None]:
if all_stim_cropped is not None:
    included_fish = events_df.loc[events_df["exp_type"] == "gainmod", "fid"].unique()
    plt.figure()
    for d in ["fw", "lf", "rt"]:
        side_col = COLS["sides"][d]
        # Same colour code as the three-panel bout-triggered figure, so that the
        # four gains of one direction can be told apart.
        gain_cols = [pltltr.shift_lum(side_col, s) for s in [0, 0.2, 0.4]] + [
            np.array(side_col)[[2, 0, 1]]
        ]
        for gain, gain_col in zip([0.5, 1, 2, -1], gain_cols):
            # Mean phase trace of each fish, then averaged across fish.
            avgs = []
            for fid in included_fish:
                sel_df = events_df[
                    (events_df["gain_theta"] == gain)
                    & (events_df["direction"] == d)
                    & (events_df["fid"] == fid)
                ]
                avgs.append(np.nanmean(all_phase_cropped[:, sel_df.index], 1))

            avgs = np.nanmean(np.array(avgs), 0)
            plt.plot(avgs, c=gain_col, lw=2, label=f"gain {gain}, dir {d}")
    plt.legend(fontsize=6, ncol=3)

In [None]:
# Quick look at the most recently computed `avgs` trace, drawn on the current axes.
current_ax = plt.gca()
current_ax.plot(avgs.T)
plt.show()

In [None]:
# Closed- vs. open-loop fish, recomputed here because the bout-triggered cells
# above reassign `included_fish` to the gain-modulation fish.
clol_fish = events_df.loc[events_df["exp_type"] == "clol", "fid"].unique()

# 2-column grid of example traces: keep at least the original 4 rows and add
# rows when there are more fish than panels.
n_cols = 2
n_rows = max(4, int(np.ceil(len(clol_fish) / n_cols)))
f, axs = plt.subplots(
    n_rows,
    n_cols,
    figsize=(8, n_rows),
    gridspec_kw=dict(wspace=0.35, hspace=0.4),
    squeeze=False,
)
axs_flat = axs.flatten()

for fid, ax in zip(tqdm(clol_fish), axs_flat):
    path = [d for d in dataset_folders if d.name == fid][0]
    exp = LotrExperiment(path)
    phase = np.unwrap(exp.network_phase)
    fictive_head = exp.fictive_heading

    # Stimulus epochs as coloured background.
    pltltr.plot_exp_condition(
        all_stim_conds[all_stim_conds["fid"] == fid], ax=ax, alpha=0.2
    )
    # Heading on the main axis, unwrapped phase on a twin axis.
    twin_ax = ax.twinx()
    ax.plot(exp.time_arr, fictive_head, c=COLS["th_plot"], label="Estimated heading")
    # Label-only handle so that the legend on `ax` also lists the phase trace.
    ax.plot([], [], c=COLS["ph_plot"], label="Network phase")
    twin_ax.plot(exp.time_arr, phase, c=COLS["ph_plot"])
    ax.set_title(exp.exp_code, fontsize=6, loc="left")
    # Flip the phase axis (swap its lower and upper limits).
    lims = twin_ax.get_ylim()
    twin_ax.set_ylim((lims[1], lims[0]))
    pltltr.add_scalebar(
        ax, xlen=500, xpos=100, ylen=0, disable_axis=False, xlabel="", ylabel=""
    )

    for a in (ax, twin_ax):
        pltltr.despine(a, ["top", "bottom"])
    # Colour each y axis like the trace it refers to.
    for a, col in zip([ax, twin_ax], [COLS["th_plot"], COLS["ph_plot"]]):
        for tick in [*a.yaxis.get_ticklines(), *a.yaxis.get_ticklabels()]:
            tick.set_color(col)

# Remove all spines from the panels left unused.
for spare_ax in axs_flat[len(clol_fish):]:
    pltltr.despine(spare_ax, "all")

axs[0, -1].legend(
    loc=2, bbox_to_anchor=(0.55, 1.7), labelcolor="linecolor", handlelength=0.0
)

# Colour key for the closed- and open-loop conditions.
ax_lab = f.add_axes((0.1, 0.07, 0.4, 0.05))
key_items = [
    ("Closed loop", COLS["stim_conditions"]["closed_loop"][1]),
    ("Open loop", COLS["stim_conditions"]["natural_motion"]),
]
for i, (lab, col) in enumerate(key_items):
    ax_lab.text(i * 0.25, 0, lab, weight="bold", fontsize=8, c=col)
ax_lab.set(xlim=(-0.1, 0.8), ylim=(-0.1, 1))
ax_lab.axis("off")

pltltr.savefig("cl_ol_comparison", folder="S8")

In [None]:
# Log the per-fish closed-/open-loop coefficients (summarised with moment="median").
# Fish ids are passed by keyword, as in the gain logging.
# NOTE(review): this cell repeats an identical logging cell earlier in the
# notebook, so the same entries are registered a second time.
if len(clol_results_df) > 0:
    for c in clol_results_df.columns:
        logger.add_entry(
            f"slope_{c}",
            clol_results_df[c],
            fids=list(clol_results_df.index),
            moment="median",
        )

# Phase bout-triggered averages in closed loop experiments

In [None]:
if all_stim_cropped is not None:
    included_fish = events_df.loc[events_df["exp_type"] == "gainmod", "fid"].unique()
    min_bouts = 5  # a fish contributes to a trace only with more bouts than this

    f, axs = plt.subplots(1, 3, figsize=(9, 3), gridspec_kw=dict(right=0.8))

    for d in ["fw", "lf", "rt"]:
        side_col = COLS["sides"][d]
        # Gains 0.5 / 1 / 2: increasing luminance shifts of the direction colour;
        # gain -1: the same colour with its channels permuted.
        gain_cols = [pltltr.shift_lum(side_col, s) for s in [0, 0.2, 0.4]] + [
            np.array(side_col)[[2, 0, 1]]
        ]
        for gain, lum_c in zip([0.5, 1, 2, -1], gain_cols):
            # Per-fish median traces of phase, sign-flipped heading and stimulus.
            phase_meds, head_meds, stim_meds = [], [], []
            for fid in included_fish:
                sel_df = events_df[
                    (events_df["gain_theta"] == gain)
                    & (events_df["direction"] == d)
                    & (events_df["fid"] == fid)
                ]
                if len(sel_df) > min_bouts:
                    sel_idx = sel_df.index
                    phase_meds.append(np.nanmedian(all_phase_cropped[:, sel_idx], 1))
                    # median(-x) == -median(x): negate the small result instead of
                    # allocating a negated copy of the whole heading array.
                    head_meds.append(-np.nanmedian(all_head_cropped[:, sel_idx], 1))
                    stim_meds.append(np.nanmedian(all_stim_cropped[:, sel_idx], 1))

            # Average the per-fish medians across fish.
            for ax, meds in zip(axs, [phase_meds, head_meds, stim_meds]):
                ax.plot(
                    np.nanmean(meds, 0), c=lum_c, lw=2, label=f"gain {gain}, dir {d}"
                )

    for ax, title in zip(axs, ["phase", "heading", "stim"]):
        pltltr.despine(ax)
        ax.set_title(title)
    axs[2].legend(loc=2, bbox_to_anchor=(1, 1))

In [None]:
if all_stim_cropped is not None:
    included_fish = events_df.loc[events_df["exp_type"] == "gainmod", "fid"].unique()
    plt.figure()
    for d in ["fw", "lf", "rt"]:
        side_col = COLS["sides"][d]
        # Same colour code as the three-panel bout-triggered figure, so that the
        # four gains of one direction can be told apart.
        gain_cols = [pltltr.shift_lum(side_col, s) for s in [0, 0.2, 0.4]] + [
            np.array(side_col)[[2, 0, 1]]
        ]
        for gain, gain_col in zip([0.5, 1, 2, -1], gain_cols):
            # Mean heading trace of each fish, then averaged across fish.
            # NOTE(review): heading is not sign-flipped here, unlike the
            # three-panel figure — confirm which convention is intended.
            avgs = []
            for fid in included_fish:
                sel_df = events_df[
                    (events_df["gain_theta"] == gain)
                    & (events_df["direction"] == d)
                    & (events_df["fid"] == fid)
                ]
                avgs.append(np.nanmean(all_head_cropped[:, sel_df.index], 1))

            avgs = np.nanmean(np.array(avgs), 0)
            plt.plot(avgs, c=gain_col, lw=2, label=f"gain {gain}, dir {d}")
    plt.legend(fontsize=6, ncol=3)

In [None]:
if all_stim_cropped is not None:
    included_fish = events_df.loc[events_df["exp_type"] == "gainmod", "fid"].unique()
    plt.figure()
    for d in ["fw", "lf", "rt"]:
        side_col = COLS["sides"][d]
        # Same colour code as the three-panel bout-triggered figure, so that the
        # four gains of one direction can be told apart.
        gain_cols = [pltltr.shift_lum(side_col, s) for s in [0, 0.2, 0.4]] + [
            np.array(side_col)[[2, 0, 1]]
        ]
        for gain, gain_col in zip([0.5, 1, 2, -1], gain_cols):
            # Mean phase trace of each fish, then averaged across fish.
            avgs = []
            for fid in included_fish:
                sel_df = events_df[
                    (events_df["gain_theta"] == gain)
                    & (events_df["direction"] == d)
                    & (events_df["fid"] == fid)
                ]
                avgs.append(np.nanmean(all_phase_cropped[:, sel_df.index], 1))

            avgs = np.nanmean(np.array(avgs), 0)
            plt.plot(avgs, c=gain_col, lw=2, label=f"gain {gain}, dir {d}")
    plt.legend(fontsize=6, ncol=3)

In [None]:
# Quick look at the most recently computed `avgs` trace, drawn on the current axes.
current_ax = plt.gca()
current_ax.plot(avgs.T)
plt.show()