# Analyses on the stability of phase over time

In [None]:
%matplotlib widget

import numpy as np
from matplotlib import pyplot as plt

import lotr.plotting as pltltr
from lotr import A_FISH
from lotr.utils import crop

import pandas as pd
from tqdm import tqdm

from lotr import dataset_folders
from lotr.default_vals import DEFAULT_FN, POST_BOUT_WND_S, PRE_BOUT_WND_S
from lotr.experiment_class import LotrExperiment
from lotr.utils import interpolate, resample_matrix
from scipy.stats import mannwhitneyu
from scipy.stats import gaussian_kde

# Colour lookup shared by the plots below; cells index it as COLS["sides"][<direction>].
COLS = pltltr.COLS

In [None]:
def get_nobouts_period_idxs(
    exp,
    exclude_pre_bts_s=30,
    exclude_post_bts_s=30,
    exclude_overlap_pre=PRE_BOUT_WND_S,
    exclude_overlap_post=POST_BOUT_WND_S,
):
    """Find imaging indexes that lie far from any bout (resting periods).

    Parameters
    ----------
    exp
        Experiment object; ``exp.fs``, ``exp.n_pts`` and
        ``exp.bouts_df["idx_imaging"]`` are read.
    exclude_pre_bts_s, exclude_post_bts_s
        Margins (multiplied by ``exp.fs``) that are kept free of bouts before
        and after each returned index; the same margins are removed from the
        start and the end of the recording.
    exclude_overlap_pre, exclude_overlap_post
        Window sizes (multiplied by ``exp.fs``) used to thin the surviving
        indexes: neighbours closer than half of these windows are dropped.

    Returns
    -------
    np.ndarray
        The surviving indexes, in increasing order.
    """
    pre_margin = exclude_pre_bts_s * exp.fs
    post_margin = exclude_post_bts_s * exp.fs

    # Start from every index far enough from the recording edges, then drop the
    # ones that fall inside the exclusion window of any bout:
    candidates = np.arange(pre_margin, exp.n_pts - post_margin)
    for bout_idx in exp.bouts_df["idx_imaging"]:
        before_bout = candidates < bout_idx - pre_margin
        after_bout = candidates > bout_idx + post_margin
        candidates = candidates[before_bout | after_bout]

    # Thin the remaining indexes so that crops barely overlap. Half windows are
    # used, which tolerates a loose overlap of half a crop. The array shrinks
    # while we walk over it, hence the explicit position counter.
    half_pre = exclude_overlap_pre * exp.fs // 2
    half_post = exclude_overlap_post * exp.fs // 2
    pos = 0
    while pos < len(candidates):
        anchor = candidates[pos]
        keep = (
            (candidates < anchor - half_pre)
            | (candidates > anchor + half_post)
            | (candidates == anchor)
        )
        candidates = candidates[keep]
        pos += 1

    return candidates

In [None]:
# For every fish, crop the unwrapped network phase around resting periods and
# around all non-forward bouts:
all_cropped = dict()
for path in tqdm(dataset_folders):
    exp = LotrExperiment(path)

    rest_idxs = get_nobouts_period_idxs(exp)
    not_forward = exp.bouts_df["direction"] != "fw"
    event_idxs = dict(
        rest=rest_idxs,
        bouts=exp.bouts_df.loc[not_forward, "idx_imaging"],
    )

    unwrap_phase = np.unwrap(exp.network_phase)
    fish_crops = dict()
    for k, idxs in event_idxs.items():
        cropped = crop(
            unwrap_phase,
            idxs,
            pre_int=PRE_BOUT_WND_S * exp.fs,
            post_int=POST_BOUT_WND_S * exp.fs,
        )
        # Remove the mean of the pre-event window from each crop:
        baseline = np.mean(cropped[: PRE_BOUT_WND_S * exp.fs, :], axis=0)
        cropped = cropped - baseline
        # Multiply each crop by the sign of its last sample, so that every
        # trace ends on the non-negative side:
        cropped = cropped * np.sign(cropped[-1, :])

        fish_crops[k] = cropped

    # True when the experiment folder contains at least one "*dlc*" entry:
    fish_crops["has_eyes"] = any(path.glob("*dlc*"))
    all_cropped[exp.exp_code] = fish_crops

In [None]:
calculate_drift_after_s = 20  # drift after n seconds

plt.close("all")
cols = [".5", "C0"]


f, ax = plt.subplots(figsize=(1.5, 2.5), gridspec_kw=dict(left=0.3, bottom=0.25))

# NOTE(review): `exp` is not defined in this cell; it is whichever experiment an
# earlier cell loaded last, so its sampling rate is applied to every fish.
# Confirm that all fish share the same `fs`, or store it next to the crops.
baseline_idx = PRE_BOUT_WND_S * exp.fs
drift_idx = (PRE_BOUT_WND_S + calculate_drift_after_s) * exp.fs

all_rest = []
all_bouts = []
for fish_crops in all_cropped.values():
    # Only fish flagged as not having eye data are included:
    if fish_crops["has_eyes"]:
        continue
    try:
        rest_drift, bouts_drift = [
            np.mean(fish_crops[key][drift_idx, :] - fish_crops[key][baseline_idx, :])
            for key in ("rest", "bouts")
        ]
    except IndexError:
        # The crop does not reach the requested time point. Skip the fish for
        # BOTH conditions so the two lists stay paired (they are stacked below).
        continue
    all_rest.append(rest_drift)
    all_bouts.append(bouts_drift)

# One light line per fish, joining its resting and motion values:
ax.plot(np.stack([all_rest, all_bouts]), lw=1, c=".8")
w = 0.05  # half-width of the median tick
for i, (col, vals) in enumerate(zip(cols, [all_rest, all_bouts])):
    # Interquartile range as a vertical bar...
    ax.plot([i, i], [np.percentile(vals, p) for p in [25, 75]], c=col)
    # ...and the median as a short horizontal tick:
    ax.plot([i - w, i + w], [np.percentile(vals, 50)] * 2, c=col)

ax.set(
    xlim=(-0.2, 1.2),
    # Keep the label in sync with the time point actually used above:
    ylabel=f"abs. phase drift after {calculate_drift_after_s} s",
    xticks=[0, 1],
    xticklabels=["resting", "motion"],
)
pltltr.despine(ax)
# Show only once the figure is fully styled:
plt.show()

In [None]:
# Example fish: load it and unwrap its network phase.
exp = LotrExperiment(A_FISH)
unwrapped_ph = np.unwrap(exp.network_phase)

# Both signals are cropped with the same windows around every bout:
crop_windows = dict(
    pre_int=PRE_BOUT_WND_S * exp.fn,
    post_int=POST_BOUT_WND_S * exp.fn,
)
n_baseline_pts = PRE_BOUT_WND_S * exp.fn

# Network phase, minus the mean of the pre-bout window:
cropped_phase = crop(unwrapped_ph, exp.bouts_df["idx_imaging"], **crop_windows)
phase_baseline = np.mean(cropped_phase[:n_baseline_pts, :], 0)
cropped_phase = cropped_phase - phase_baseline

# Fictive heading, processed in the same way:
cropped_head = crop(exp.fictive_heading, exp.bouts_df["idx_imaging"], **crop_windows)
head_baseline = np.mean(cropped_head[:n_baseline_pts, :], 0)
cropped_head = cropped_head - head_baseline

# Time axis of the crops, with 0 at the end of the pre-bout window:
time_arr = np.arange(cropped_phase.shape[0]) / exp.fn - PRE_BOUT_WND_S

In [None]:
def custom_crop_shifts_all_dataset(crop_stimulus=False):
    """Crop fictive heading and network phase around events from all fish.

    Events are the bouts listed in each experiment's ``bouts_df`` plus the
    resting periods returned by ``get_nobouts_period_idxs`` (stored with
    direction ``"nb"``). Every crop has the mean of its pre-event window
    subtracted and, when the fish sampling rate differs from ``DEFAULT_FN``,
    is resampled on a common time base.
    For a demo of what is happening, see the "4. Phase dynamics.ipynb" notebook.

    Parameters
    ----------
    crop_stimulus : bool
        If True, the ``"cl2D_theta"`` column of the stimulus log is cropped as
        well (for experiments that provide it) and returned.

    Returns
    -------
    tuple
        ``(all_phase_cropped, all_head_cropped, events_df, time_arr)`` or, when
        ``crop_stimulus`` is True, ``(all_phase_cropped, all_head_cropped,
        all_stim_cropped, events_df, time_arr)``. The cropped results are
        n_tpts x n_events matrices concatenated over fish, ``events_df`` is the
        dataframe with the info about all events, and ``time_arr`` is the
        common time axis relative to the event.

    """
    fn = DEFAULT_FN

    all_phase_cropped = []
    all_head_cropped = []
    all_stim_cropped = []
    # We will create a dataframe to keep track of events from all fish.
    # Mostly a way of keeping together the crop and the bouts:
    events_df = []

    # Define temporal array for the resampling:
    time_arr = (
        np.arange(1, ((PRE_BOUT_WND_S + POST_BOUT_WND_S) * fn) + 1) / fn
        - PRE_BOUT_WND_S
    )
    for path in tqdm(dataset_folders):
        exp = LotrExperiment(path)
        # TODO recompute to avoid this bugfix
        exp.bouts_df["fid"] = path.name
        nobout_locations = get_nobouts_period_idxs(exp)

        crop_events_df = exp.bouts_df.loc[
            :, ["bias", "direction", "t_start", "idx_imaging", "fid"]
        ]

        # Resting periods get the same columns as bouts, with no bias and the
        # "nb" direction label:
        n_nob = len(nobout_locations)
        nobout_events_df = pd.DataFrame(
            dict(
                bias=np.full(n_nob, np.nan),
                direction=np.full(n_nob, "nb"),
                t_start=nobout_locations / exp.fs,
                idx_imaging=nobout_locations,
                fid=path.name,
            )
        )

        crop_events_df = pd.concat([crop_events_df, nobout_events_df], axis=0)

        # The stimulus trace stays all-NaN unless it is requested and available:
        stim_interp = np.full(exp.n_pts, np.nan)
        if crop_stimulus:
            # Best effort: an experiment without a stimulus log is skipped.
            # Only the attribute lookup is guarded, so errors raised inside
            # `interpolate` are no longer silently swallowed.
            stim_df = getattr(exp, "stimulus_log", None)
            if stim_df is not None and "cl2D_theta" in stim_df.columns:
                stim_interp = interpolate(
                    stim_df["t"], stim_df["cl2D_theta"], exp.time_arr
                )

        # Crop both the fictive heading (cumulative tail theta sum) and network phase
        # in the same way:
        for dest_list, to_crop in zip(
            [all_phase_cropped, all_head_cropped, all_stim_cropped],
            [np.unwrap(exp.network_phase), exp.fictive_heading, stim_interp],
        ):
            # Crop around events:
            cropped = crop(
                to_crop,
                crop_events_df["idx_imaging"],
                pre_int=PRE_BOUT_WND_S * exp.fs,
                post_int=POST_BOUT_WND_S * exp.fs,
            )

            # Subtract baseline:
            cropped = cropped - np.mean(cropped[: PRE_BOUT_WND_S * exp.fs, :], 0)

            # Interpolate if necessary:
            if exp.fs != fn:
                fish_time_arr = (
                    np.arange(1, cropped.shape[0] + 1) / exp.fs - PRE_BOUT_WND_S
                )
                cropped = resample_matrix(time_arr, fish_time_arr, cropped)

            dest_list.append(cropped)
        # The index is rebuilt by the final concat (ignore_index=True):
        events_df.append(crop_events_df)

    # Concatenate all the results:
    all_phase_cropped = np.concatenate(all_phase_cropped, axis=1)
    all_head_cropped = np.concatenate(all_head_cropped, axis=1)
    try:
        all_stim_cropped = np.concatenate(all_stim_cropped, axis=1)
    except ValueError:
        # AxisError subclasses ValueError; catching the base class keeps this
        # working on NumPy 2.x, where `np.AxisError` is no longer available.
        all_stim_cropped = None
    events_df = pd.concat(events_df, ignore_index=True)

    if crop_stimulus:
        return (
            all_phase_cropped,
            all_head_cropped,
            all_stim_cropped,
            events_df,
            time_arr,
        )
    else:
        return all_phase_cropped, all_head_cropped, events_df, time_arr

In [None]:
# Run the cropping over the whole dataset (stimulus traces are not requested):
all_phase_cropped, all_head_cropped, events_df, time_arr = (
    custom_crop_shifts_all_dataset()
)

In [None]:
def plot_bout_trig(
    events_df, cropped_list, labels_list, ylims=(-4.2, 4.2), legend_lab="{} bouts"
):
    """Plot event-triggered traces, split by direction, in two side-by-side panels.

    Parameters
    ----------
    events_df : pd.DataFrame
        One row per trace; only its ``"direction"`` column is read.
    cropped_list : list of np.ndarray
        One (time points x traces) matrix per panel. Columns are selected with
        boolean masks built from ``events_df``, so they must line up with its rows.
    labels_list : list of str
        Panel titles, one per matrix.
    ylims : tuple
        Y limits applied to the panels.
    legend_lab : str or None
        Format string that receives the direction name and labels the mean
        trace. If None, no legend is drawn.

    Returns
    -------
    (f, axs)
        The figure and its two axes.

    Notes
    -----
    The x values are taken from the notebook-level ``time_arr`` variable, not
    from an argument, so it must match the first dimension of the matrices.
    """
    f, axs = plt.subplots(
        1,
        2,
        figsize=(5, 2.5),
        gridspec_kw=dict(left=0.08, bottom=0.15, top=0.9, right=0.73),
        sharey=True,
    )

    for ax, lab, cropped in zip(axs, labels_list, cropped_list):
        for d in events_df["direction"].unique():
            sel = events_df["direction"] == d
            # Individual traces; the leading underscore keeps them out of the legend:
            ax.plot(
                time_arr,
                cropped[:, sel],
                lw=0.3,
                c=COLS["sides"][d],
                label="_nolegend_",
            )
            # Mean trace for this direction, drawn on top:
            ax.plot(
                time_arr,
                np.mean(cropped[:, sel], 1),
                lw=2,
                c=pltltr.dark_col(COLS["sides"][d]),
                label=legend_lab.format(d) if legend_lab is not None else "_nolegend_",
                zorder=30,
            )

        pltltr.despine(ax)
        ax.set(
            xlabel="time from bout (s)",
            ylim=ylims,
            **pltltr.get_pi_labels(0.5, ax="y"),
        )
        ax.set_title(lab, weight="bold")
        # Vertical line at the event time:
        ax.axvline(0, lw=0.5, c=".5")
    if legend_lab is not None:
        axs[1].legend(bbox_to_anchor=(1.05, 0.8, 0.2, 0.2))

    return f, axs

In [None]:
# Average response of each fish for each direction. Only fish/direction groups
# with more than 10 events are kept; the row order built here is reused by
# later cells to index per-group results.
all_phase_means = []
all_head_means = []
mean_events_df = []
for fid in events_df["fid"].unique():
    for d in events_df["direction"].unique():
        sel = (events_df["direction"] == d) & (events_df["fid"] == fid)
        if sel.sum() <= 10:
            continue

        all_phase_means.append(all_phase_cropped[:, sel].mean(axis=1))
        all_head_means.append(all_head_cropped[:, sel].mean(axis=1))
        # The first event of the group carries its fish id and direction:
        mean_events_df.append(events_df.loc[sel].iloc[0])

# Time points along the rows, one column per fish/direction group:
all_phase_means = np.stack(all_phase_means, axis=1)
all_head_means = np.stack(all_head_means, axis=1)
mean_events_df = pd.DataFrame(mean_events_df)

In [None]:
# Colour used for the resting-period ("nb") traces:
COLS["sides"]["nb"] = ".8"
f, axs = plot_bout_trig(
    events_df=mean_events_df,
    cropped_list=[all_phase_means, all_head_means],
    labels_list=["Δphase", "Δheading"],
    ylims=(-np.pi, np.pi),
    legend_lab="{} (all fish + mn)",
)

In [None]:
# Window used to quantify the change after the event: it starts
# `quantify_after_s` seconds after the event and lasts `quantify_for_s` seconds.
quantify_after_s = 20
quantify_for_s = 5
wnd_delta_phase = np.array([quantify_after_s, quantify_after_s + quantify_for_s])

# Same window expressed in samples of the cropped matrices:
wnd_pts = (PRE_BOUT_WND_S + wnd_delta_phase) * DEFAULT_FN
events_df["Δphase"] = np.nanmean(
    all_phase_cropped[wnd_pts[0] : wnd_pts[1], :], axis=0
)
events_df["Δhead"] = np.nanmean(all_head_cropped[wnd_pts[0] : wnd_pts[1], :], axis=0)

xarr = np.arange(-np.pi * 1.5, np.pi * 1.5, 0.01)
# For every fish and direction with more than 10 events, estimate the
# distribution of the changes with a gaussian KDE evaluated on `xarr`:
all_phase_kdes = []
all_head_kdes = []
for fid in events_df["fid"].unique():
    for d in events_df["direction"].unique():
        sel = (events_df["direction"] == d) & (events_df["fid"] == fid)
        if sel.sum() <= 10:
            continue

        all_phase_kdes.append(gaussian_kde(events_df.loc[sel, "Δphase"])(xarr))
        all_head_kdes.append(gaussian_kde(events_df.loc[sel, "Δhead"])(xarr))

# One row per fish/direction group:
all_phase_kdes = np.stack(all_phase_kdes)
all_head_kdes = np.stack(all_head_kdes)

In [None]:
def simpleaxis(ax):
    """Hide the top and right spines of ``ax`` and keep ticks on the bottom and left."""
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

    x_axis = ax.get_xaxis()
    x_axis.tick_bottom()
    y_axis = ax.get_yaxis()
    y_axis.tick_left()

In [None]:
plt.close("all")
COLS["sides"]["nb"] = ".8"
f, axs = plot_bout_trig(
    events_df=mean_events_df,
    cropped_list=[all_head_means, all_phase_means],
    labels_list=["Δheading", "Δphase"],
    legend_lab=None,
    ylims=(-np.pi, np.pi),
)

# Extra axes on the right: distribution of the phase change, drawn sideways so
# that it shares the y scale of the traces.
ax = f.add_axes((0.8, 0.15, 0.19, 0.75))
for direction in ["lf", "rt", "nb"]:
    mn_hist = []
    for fid in events_df["fid"].unique():
        c = COLS["sides"][direction]
        same_dir = mean_events_df["direction"] == direction
        sel = same_dir & (mean_events_df["fid"] == fid)
        if not sel.any():
            continue

        # Rows of the KDE matrix follow the rows of `mean_events_df`:
        kde_f = all_phase_kdes[sel, :]
        mn_hist.append(kde_f)
        fish_kde = kde_f[0, :]
        plt.plot(fish_kde / np.sum(fish_kde), xarr, c=c, lw=1, alpha=0.5)

    # Average over fish, normalised to unit sum and drawn on top:
    mn_hist = np.mean(np.concatenate(mn_hist), 0)
    darker = pltltr.shift_lum(c, -0.2)
    plt.plot(mn_hist / np.sum(mn_hist), xarr, c=darker, zorder=100, lw=1.5)

y_range = (-np.pi * 0.7, np.pi * 0.7)
axs[0].set(ylim=y_range, **pltltr.get_pi_labels(0.5, ax="y"))
ax.set(
    xlabel="P(drift)",
    xlim=(-0.001, 0.015),
    ylim=y_range,
    yticks=pltltr.get_pi_labels(0.5, ax="y")["yticks"],
    yticklabels=[],
)
plt.show()
pltltr.despine(ax)

In [None]:
plt.close("all")
COLS["sides"]["nb"] = ".8"
f, axs = plot_bout_trig(
    events_df=mean_events_df,
    cropped_list=[all_head_means, all_phase_means],
    labels_list=["Δheading", "Δphase"],
    legend_lab=None,
    ylims=(-np.pi, np.pi),
)

# Side panel with the phase-change distributions of the resting periods only:
ax = f.add_axes((0.8, 0.15, 0.19, 0.75))
for direction in ["nb"]:
    mn_hist = []
    for fid in events_df["fid"].unique():
        c = COLS["sides"][direction]
        is_direction = mean_events_df["direction"] == direction
        is_fish = mean_events_df["fid"] == fid
        sel = is_direction & is_fish
        if sel.any():
            kde_f = all_phase_kdes[sel, :]
            mn_hist.append(kde_f)
            # Thin line: KDE of this fish, normalised to unit sum
            first_row = kde_f[0, :]
            plt.plot(first_row / np.sum(first_row), xarr, c=c, lw=1, alpha=0.5)

    # Thick line: average of the per-fish KDEs
    mn_hist = np.mean(np.concatenate(mn_hist), 0)
    mean_col = pltltr.shift_lum(c, -0.2)
    plt.plot(mn_hist / np.sum(mn_hist), xarr, c=mean_col, zorder=100, lw=1.5)

shared_ylim = (-np.pi * 0.7, np.pi * 0.7)
axs[0].set(ylim=shared_ylim, **pltltr.get_pi_labels(0.5, ax="y"))
ax.set(
    xlabel="P(drift)",
    xlim=(-0.001, 0.015),
    ylim=shared_ylim,
    yticks=pltltr.get_pi_labels(0.5, ax="y")["yticks"],
    yticklabels=[],
)
plt.show()
pltltr.despine(ax)

In [None]:
plt.close("all")
COLS["sides"]["nb"] = ".8"
f, axs = plot_bout_trig(
    mean_events_df,
    [all_head_means, all_phase_means],
    ["Δheading", "Δphase"],
    legend_lab=None,
    ylims=(-np.pi, np.pi),
)

# Side panel: distributions of the phase change, drawn sideways so that they
# share the y scale of the traces.
ax = f.add_axes((0.8, 0.15, 0.19, 0.75))

# Resting periods: one KDE per fish (thin lines) and their average (thick line).
for direction in ["nb"]:
    mn_hist = []
    for fid in events_df["fid"].unique():
        c = COLS["sides"][direction]
        sel = (mean_events_df["direction"] == direction) & (
            mean_events_df["fid"] == fid
        )
        if sum(sel) > 0:
            kde_f = all_phase_kdes[sel, :]
            mn_hist.append(kde_f)
            ax.plot(kde_f[0, :] / np.sum(kde_f[0, :]), xarr, c=c, lw=1, alpha=0.5)

    mn_hist = np.mean(np.concatenate(mn_hist), 0)
    ax.plot(
        mn_hist / np.sum(mn_hist), xarr, c=pltltr.shift_lum(c, -0.2), zorder=100, lw=1.5
    )

# Bouts: a single KDE over the per-fish mean phase change, for each direction.
for direction in ["lf", "rt", "fw"]:
    sel_df = events_df[events_df["direction"] == direction]
    # Select the column before averaging: averaging the whole frame would also
    # try to average the string columns, which raises on recent pandas.
    fish_means = sel_df.groupby("fid")["Δphase"].mean()
    kde_f = gaussian_kde(fish_means)
    mn_hist = kde_f(xarr)
    ax.plot(
        mn_hist / np.sum(mn_hist),
        xarr,
        c=COLS["sides"][direction],
        zorder=100,
        lw=1.5,
    )

axs[0].set(ylim=(-np.pi * 0.7, np.pi * 0.7), **pltltr.get_pi_labels(0.5, ax="y"))
ax.set(
    xlabel="P(drift)",
    xlim=(-0.001, 0.015),
    ylim=(-np.pi * 0.7, np.pi * 0.7),
    yticks=pltltr.get_pi_labels(0.5, ax="y")["yticks"],
    yticklabels=[],
)
pltltr.despine(ax)
# Save before showing: some backends close the figure when it is shown.
pltltr.savefig("phase_shift_with_histo")
plt.show()

In [None]:
plt.close("all")
COLS["sides"]["nb"] = ".8"
f, axs = plot_bout_trig(
    mean_events_df,
    [all_head_means, all_phase_means],
    ["Δheading", "Δphase"],
    legend_lab=None,
    ylims=(-np.pi, np.pi),
)

# Side panel: for left and right bouts, a KDE over the per-fish mean phase
# change, drawn sideways so that it shares the y scale of the traces.
ax = f.add_axes((0.8, 0.15, 0.19, 0.75))
for direction in ["lf", "rt"]:
    sel_df = events_df[events_df["direction"] == direction]
    # Select the column before averaging: averaging the whole frame would also
    # try to average the string columns, which raises on recent pandas.
    fish_means = sel_df.groupby("fid")["Δphase"].mean()
    kde_f = gaussian_kde(fish_means)
    mn_hist = kde_f(xarr)
    ax.plot(
        mn_hist / np.sum(mn_hist),
        xarr,
        c=COLS["sides"][direction],
        zorder=100,
        lw=1.5,
    )
axs[0].set(ylim=(-np.pi * 0.7, np.pi * 0.7), **pltltr.get_pi_labels(0.5, ax="y"))
ax.set(
    xlabel="P(drift)",
    xlim=(-0.001, 0.015),
    ylim=(-np.pi * 0.7, np.pi * 0.7),
    yticks=pltltr.get_pi_labels(0.5, ax="y")["yticks"],
    yticklabels=[],
)
pltltr.despine(ax)
# Show only once the figure is fully styled:
plt.show()

In [None]:
# Export a figure under this file name through the project's savefig helper.
# NOTE(review): presumably this saves the figure built in the previous cell —
# confirm how pltltr.savefig picks the figure to save.
pltltr.savefig("stability_histograms.pdf")

In [None]:
# Display the events table (bouts and resting periods of all fish) as cell output.
events_df

## Heading decoder

In [None]:
N_BINS = 1000  # bins for cumulative counts
COUNTS_LIMS = np.pi * 1.2

all_ids = events_df["fid"].unique()  # all fish ids
bt_dirs = ["lf", "rt", "fw", "nb"]  # directions identifiers

bin_angles = np.linspace(-COUNTS_LIMS, COUNTS_LIMS, N_BINS + 1)

# n_directions x n_fish x n_bins matrix with the cumulative distribution of the
# phase change. Rows stay NaN for fish without enough events of a direction.
distributions_matrix = np.full((len(bt_dirs), len(all_ids), N_BINS), np.nan)

for k, fid in enumerate(all_ids):  # loop over fish
    fish_df = events_df[events_df["fid"] == fid]

    for i, bt_dir in enumerate(bt_dirs):  # loop over bout directions
        dphase = fish_df.loc[fish_df["direction"] == bt_dir, "Δphase"]

        # Fraction of events with a phase change below each bin edge
        # (requires more than 3 events):
        if len(dphase) > 3:
            counts, _ = np.histogram(dphase, bin_angles)
            n_counted = counts.sum()
            # If no value fell inside the bins, leave the row as NaN instead
            # of dividing by zero:
            if n_counted > 0:
                distributions_matrix[i, k, :] = np.cumsum(counts) / n_counted

### Illustrative plot for ROCs

In [None]:
# Illustration of how the ROC curves are built, for a single example fish:
# top row = histograms of the phase change for a pair of conditions,
# bottom row = the cumulative distribution of one condition plotted against
# the cumulative distribution of the other.
fig, axs = plt.subplots(
    2, 3, figsize=(3.5, 3.5), gridspec_kw=dict(bottom=0.2, hspace=0.8)
)

# Example fish and plotting parameters:
f_idx = 1
fish_df = events_df[events_df["fid"] == all_ids[f_idx]]
eps = 1
# Bin edges of the illustrative histograms (coarser than the ROC bins):
hist_array = np.linspace(-np.pi - eps, np.pi + eps, 15)
direction_labels = dict(lf="Left", rt="Right", fw="Forward", nb="No bout")
alpha = 0.6
lumshift = 0.1
lw = 1
ylims = -0.03, 0.5
# NOTE: xlims is defined here but not used anywhere in this cell.
xlims = hist_array[0] - 0.2, hist_array[-1] + 0.2


# Pairs of conditions to compare, as indexes into bt_dirs (one column each):
pairs = [(0, 1), (0, 2), (2, 3)]


for i, pair in enumerate(pairs):
    ax_hist, ax_roc = axs[:, i]

    # Histograms of the two conditions, rescaled so that the bars sum to 1:
    for p in pair:
        col = COLS["sides"][bt_dirs[p]]
        w_array = fish_df.loc[fish_df["direction"] == bt_dirs[p], "Δphase"]
        count, b = np.histogram(w_array, hist_array, density=True)
        ax_hist.fill_between(
            (b[1:] + b[:-1]) / 2,
            count / sum(count),
            step="mid",
            alpha=alpha,
            lw=lw,
            fc=pltltr.shift_lum(col, lumshift),
            ec=pltltr.shift_lum(col, -lumshift),
            label="__nolegend__",
        )

    # Only the first column keeps its y axis decorations:
    if i > 0:
        pltltr.despine(ax_hist, ["top", "right", "left"])
        ax_roc.set(yticklabels=[])
    else:
        pltltr.despine(ax_hist)

        ax_hist.set(ylabel="Probability")
        ax_roc.set(yticks=(0, 0.5, 1), ylabel="True positives (fract.)")

    pltltr.despine(ax_roc)
    ax_hist.set(**pltltr.get_pi_labels(1), ylim=ylims)

    # ROC curve: cumulative distribution of the first condition on x, of the
    # second condition on y:
    ax_roc.plot(
        distributions_matrix[pair[0], f_idx, :],
        distributions_matrix[pair[1], f_idx, :],
        c=".5",
    )

    ax_roc.set(
        ylim=(-0.1, 1.1),
        xlim=(-0.1, 1.1),
        xticks=(0, 0.5, 1),
    )

    # Colour each ROC axis with the colour of the condition it represents:
    for p, spine, ax_n in zip(pair, ["bottom", "left"], ["x", "y"]):
        col = COLS["sides"][bt_dirs[p]]
        ax_roc.spines[spine].set_color(col)
        ax_roc.tick_params(axis=ax_n, colors=col)

axs[1, 1].set(xlabel="False positives (fract)")
axs[0, 1].set(xlabel=r"$\Delta$ phase after 20 s")
# Legend: empty patches, drawn only to create one legend entry per condition.
legend_ax = axs[0, 2]
for d in bt_dirs:
    col = COLS["sides"][d]
    legend_ax.fill_between(
        [],
        [],
        step="mid",
        alpha=alpha,
        lw=lw,
        fc=pltltr.shift_lum(col, lumshift),
        ec=pltltr.shift_lum(col, -lumshift),
        label=direction_labels[d],
    )

legend_ax.legend(loc=1, bbox_to_anchor=(0.2, 0.6, 0.5, 0.5), ncol=2)

# Breakpoints: mark four example thresholds (I-IV) on the first histogram and
# the matching points on the first ROC curve. Positions are given as
# percentiles of the bin edges and of the cumulative values.
midpoints = (5, 33.3, 66.6, 98)

for r, m in zip(["I", "II", "III", "IV"], midpoints):
    # Vertical line and label on the histogram panel:
    x = np.percentile(hist_array, m)
    axs[0, 0].plot([x, x], [0, 0.4], lw=0.5, c=".3")
    axs[0, 0].text(x, 0.42, r, ha="center", va="bottom", fontsize=8)

    # Corresponding dot and label on the ROC panel:
    x, y = np.percentile(distributions_matrix[[0, 1], f_idx, :], m, axis=1)
    axs[1, 0].scatter(
        x,
        y,
        color=".3",
        s=15,
        zorder=100,
    )
    axs[1, 0].text(x + 0.08, y, r, ha="left", va="center", fontsize=8)

pltltr.savefig("roc_explanation", folder="S7b")

### ROCs plot and statistics

In [None]:
def get_interp_roc(x, y, stepsize=0.001):
    """Resample the curve ``y(x)`` on a regular grid covering [0, 1).

    The grid starts at 0 and advances by ``stepsize``; 1 itself is excluded.
    Values are obtained by linear interpolation (``np.interp``).
    """
    grid = np.arange(0, 1, stepsize)
    resampled = np.interp(grid, x, y)
    return resampled


def get_aoc(x, y, stepsize=0.001):
    """Calculate AOC: sum of the curve resampled on [0, 1), times the step size."""
    curve = get_interp_roc(x, y, stepsize)
    return stepsize * np.sum(curve)


def get_aoc_pval(aocs):
    """Return the Mann-Whitney U p-value of the AOCs against a constant 0.5.

    NaN entries are dropped first; the comparison sample is a list of 0.5
    values with the same length as the remaining AOCs.
    Follows https://rmets.onlinelibrary.wiley.com/doi/10.1256/003590002320603584.
    """
    valid_aocs = list(filter(lambda value: not np.isnan(value), aocs))
    chance_level = [0.5 for _ in valid_aocs]
    return mannwhitneyu(valid_aocs, chance_level).pvalue

In [None]:
# Grid of ROC curves for every pair of conditions (lower triangle only).
fig, axs = plt.subplots(
    len(bt_dirs),
    len(bt_dirs),
    figsize=(4, 4),
    sharex=True,
    sharey=True,
    gridspec_kw=dict(width_ratios=(1, 1, 1, 0.001), height_ratios=(0.001, 1, 1, 1)),
)

f_c = "0.7"
f_lw = 0.5
m_c = "C0"
m_lw = 1

for i, lab1 in enumerate(bt_dirs):
    for j, lab2 in enumerate(bt_dirs):
        ax = axs[i, j]
        if j < i:
            # Cumulative distributions (n_fish x n_bins) of the two conditions.
            # NOTE(review): the row condition (lab1) is plotted on x and the
            # column condition (lab2) on y, while the axis labels set below
            # are the other way round — confirm which one is intended.
            dist1, dist2 = distributions_matrix[[i, j], :, :]

            # One thin curve per fish, then the median curve across fish:
            for fish_d1, fish_d2 in zip(dist1, dist2):
                ax.plot(fish_d1, fish_d2, f_c, linewidth=f_lw)
            ax.plot(
                np.nanmedian(dist1, axis=0),
                np.nanmedian(dist2, axis=0),
                m_c,
                linewidth=m_lw,
            )

            # Area for every fish in the matrix (no hard-coded fish count):
            aocs = np.array(
                [get_aoc(fish_d1, fish_d2) for fish_d1, fish_d2 in zip(dist1, dist2)]
            )

            ax.text(
                0.1, 0.75, f"AOC={np.nanmedian(aocs):0.2f}\np={get_aoc_pval(aocs):0.2e}"
            )

        else:
            ax.axis("off")

        if i == len(bt_dirs) - 1:
            ax.set(xlabel=direction_labels[lab2], yticks=[0, 1])
            ax.xaxis.label.set_color(COLS["sides"][lab2])
        if j == 0:
            ax.set(ylabel=direction_labels[lab1], yticks=[0, 1])
            ax.yaxis.label.set_color(COLS["sides"][lab1])

# Empty lines, drawn only to create the legend entries:
legend_axs = axs[1, 1]
legend_axs.plot([], [], lw=f_lw, c=f_c, label="single fish")
legend_axs.plot([], [], lw=m_lw, c=m_c, label="median")
legend_axs.legend()


# Save before showing: some backends close the figure when it is shown.
pltltr.savefig("roc_curves", folder="S7b")
plt.show()

## Shuffle illustration

**Warning:** the code below might take some minutes to run with 100k shuffles.

In [None]:
n_synt_neurons = len(all_ids)
n_shuffles = 100000

f_c = "0.7"
f_lw = 0.5
m_c = "C0"
m_lw = 1

i, j = 0, 1  # left and right bouts
# Pool the per-fish cumulative distributions of the two conditions:
dist = np.concatenate([distributions_matrix[i, :, :], distributions_matrix[j, :, :]], axis=0)

# Null distribution: on every draw, rows of the pooled matrix are sampled with
# replacement and split in two groups; the ROC between the two group medians
# and its area are stored.
all_shuf_aocs = np.zeros(n_shuffles)
all_shuf_rocs = np.zeros((n_shuffles, distributions_matrix.shape[-1]))
for draw in tqdm(range(n_shuffles)):
    indices = np.random.randint(0, len(all_ids) * 2, n_synt_neurons * 2)
    first_group = indices[:n_synt_neurons]
    second_group = indices[n_synt_neurons:]
    dist1 = np.nanmedian(dist[first_group], axis=0)
    dist2 = np.nanmedian(dist[second_group], axis=0)

    all_shuf_aocs[draw] = get_aoc(dist1, dist2)
    all_shuf_rocs[draw, :] = get_interp_roc(dist1, dist2)

# Restore the real (unshuffled) distributions of the two conditions:
dist1, dist2 = distributions_matrix[[i, j]]

In [None]:
plt.close("all")
f, axs = plt.subplots(
    1, 2, figsize=(4, 1.8), gridspec_kw=dict(left=0.2, bottom=0.2, wspace=0.5)
)
roc_ax, hist_ax = axs

# x positions of the resampled shuffle ROCs: a regular grid over [0, 1) that
# excludes 1. This assumes the curves were resampled with a step of
# 1 / n_points, which is the case with the default step size and 1000 points.
n_roc_pts = all_shuf_rocs.shape[1]
roc_x = np.arange(n_roc_pts) / n_roc_pts

# Nested percentile bands of the shuffled ROC curves:
for i, (col, perc) in enumerate(zip([".8", ".7", ".6"], [0.05, 0.5, 2.5])):
    perc1 = np.nanpercentile(all_shuf_rocs, perc, axis=0)
    perc2 = np.nanpercentile(all_shuf_rocs, 100 - perc, axis=0)
    roc_ax.fill_between(roc_x, perc1, perc2, lw=0, fc=col)
    roc_ax.text(0.7, 0.1 * i, f"{100-perc*2}%", c=pltltr.shift_lum(col, -0.2))
# ROC between the median distributions of the real data:
roc_ax.plot(
    np.nanmedian(dist1, axis=0),
    np.nanmedian(dist2, axis=0),
    m_c,
    linewidth=m_lw,
)
roc_ax.set(
    xlabel="False positives",
    ylabel="True positives",
    yticks=[0, 0.5, 1],
    xticks=[0, 0.5, 1],
)

hist_ax.hist(all_shuf_aocs, np.arange(0, 1, 0.01), lw=0, fc=".7", label="shuffles")
data_aoc = get_aoc(np.nanmedian(dist1, axis=0), np.nanmedian(dist2, axis=0))
# One-sided p-value: fraction of shuffles with an area larger than the data.
pval = np.mean(all_shuf_aocs > data_aoc)
hist_ax.axvline(data_aoc, label="data")
hist_ax.text(0.5, 4000, f"p={pval:0.1e}")
hist_ax.set(xlabel="AOC", ylabel="Counts", ylim=(0, 8000))
hist_ax.legend()
pltltr.despine(hist_ax)
# Save before showing: some backends close the figure when it is shown.
pltltr.savefig("stat_illustration", folder="S7b")
plt.show()

In [None]:
# Grid of ROC curves for every pair of conditions (lower triangle only), with
# a shuffle test on the area of each pair.
fig, axs = plt.subplots(
    len(bt_dirs),
    len(bt_dirs),
    figsize=(4, 4),
    sharex=True,
    sharey=True,
    gridspec_kw=dict(width_ratios=(1, 1, 1, 0.001), height_ratios=(0.001, 1, 1, 1)),
)

n_shuffles = 10000
n_fish = len(all_ids)

for i, lab1 in tqdm(list(enumerate(bt_dirs))):
    for j, lab2 in enumerate(bt_dirs):
        ax = axs[i, j]
        if j < i:
            # Cumulative distributions (n_fish x n_bins) of the two conditions:
            dist1, dist2 = distributions_matrix[[i, j], :, :]

            # One thin curve per fish, then the median curve across fish:
            for fish_d1, fish_d2 in zip(dist1, dist2):
                ax.plot(fish_d1, fish_d2, f_c, linewidth=f_lw)
            ax.plot(
                np.nanmedian(dist1, axis=0),
                np.nanmedian(dist2, axis=0),
                m_c,
                linewidth=m_lw,
            )

            # Area for every fish in the matrix (no hard-coded fish count):
            aocs = np.array(
                [get_aoc(fish_d1, fish_d2) for fish_d1, fish_d2 in zip(dist1, dist2)]
            )

            # Shuffle: the null distribution is built from the pooled rows of
            # the two conditions compared in THIS panel, not from a pool left
            # over from a previous cell. Rows are drawn with replacement and
            # split in two groups.
            pooled = np.concatenate([dist1, dist2], axis=0)
            all_shuf_aocs = np.zeros(n_shuffles)
            for draw in range(n_shuffles):
                indices = np.random.randint(0, 2 * n_fish, 2 * n_fish)
                shuf1 = np.nanmedian(pooled[indices[:n_fish]], axis=0)
                shuf2 = np.nanmedian(pooled[indices[n_fish:]], axis=0)
                all_shuf_aocs[draw] = get_aoc(shuf1, shuf2)

            # NOTE(review): the data statistic is the median of the per-fish
            # areas, while each shuffle gives the area between two median
            # curves — confirm that comparing the two is intended.
            data_aoc = np.nanmedian(aocs)
            # Smaller of the two tail fractions:
            pval = min(
                np.mean(all_shuf_aocs > data_aoc),
                np.mean(all_shuf_aocs < data_aoc),
            )

            ax.text(0.1, 0.75, f"AOC={data_aoc:0.2f}\np={pval:0.2e}")

        else:
            ax.axis("off")

        if i == len(bt_dirs) - 1:
            ax.set(xlabel=direction_labels[lab2], yticks=[0, 1])
            ax.xaxis.label.set_color(COLS["sides"][lab2])
        if j == 0:
            ax.set(ylabel=direction_labels[lab1], yticks=[0, 1])
            ax.yaxis.label.set_color(COLS["sides"][lab1])

# Empty lines, drawn only to create the legend entries:
legend_axs = axs[1, 1]
legend_axs.plot([], [], lw=f_lw, c=f_c, label="single fish")
legend_axs.plot([], [], lw=m_lw, c=m_c, label="median")
legend_axs.legend()


# Save before showing: some backends close the figure when it is shown.
pltltr.savefig("roc_curves_shuffletest", folder="S7b")
plt.show()