# The aHB circuit tracks fictive trajectories

We have seen in the previous notebooks that the aHB integrator may be capable of keeping track of past motion over a significant amount of time. In this notebook, we will see how well it can keep track of past (fictive) trajectories over long time spans.

## A first look at cumulative directional motion and network phase

In [None]:
%matplotlib widget
from pathlib import Path

import lotr.plotting as pltltr
import numpy as np
import pandas as pd
from bouter.utilities import crop
from lotr import A_FISH, LotrExperiment, dataset_folders
from lotr.behavior import get_bouts_props_array
from lotr.default_vals import REGRESSOR_TAU_S
from lotr.utils import convolve_with_tau
from matplotlib import pyplot as plt
from scipy.stats import spearmanr, wilcoxon
from tqdm import tqdm

# Shorthand for the lotr.plotting color palette (keys used in this notebook:
# "beh", "th_plot", "ph_plot", "shuf").
COLS = pltltr.COLS

In [None]:
# First, get the unwrapped network phase for the example fish:
exp = LotrExperiment(A_FISH)
network_phase = np.unwrap(exp.network_phase)

# Now, reconstruct the past motion of the fish. This is an exp.n_pts-long
# array that is (presumably) zero outside bouts and holds each bout's "bias"
# value at the bout timepoints — the angle turned (dθ) plotted below:
theta_turned = get_bouts_props_array(
    exp.n_pts, exp.bouts_df, selection="all", value="bias",
)

# Fictive heading is the cumulative sum of this array, smoothed with a kernel
# that matches the slower dynamics of the neurons (results are not much
# affected by the smoothing):
fictive_head = convolve_with_tau(np.cumsum(theta_turned), REGRESSOR_TAU_S * exp.fn)

In [None]:
# Three stacked panels showing how the fictive heading is built:
# tail angle -> angle turned per bout (dθ) -> cumulative fictive heading (θ).
f, axs = plt.subplots(
    3, 1, figsize=(5, 3), sharex=True, gridspec_kw=dict(bottom=0.12)
)
lw = 1

labels = ["Tail angle", "Angle turned (dθ)", "Fict. heading (θ)"]
col_seq = [COLS["beh"], pltltr.shift_lum(COLS["th_plot"], 0.2), COLS["th_plot"]]
traces = [
    (exp.behavior_log["t"], exp.behavior_log["tail_sum"]),
    (exp.time_arr, theta_turned),
    (exp.time_arr, fictive_head),
]
for i, (panel, (x, y), label, col) in enumerate(zip(axs, traces, labels, col_seq)):
    # Rasterize only the tail-angle trace, for exporting:
    panel.plot(x, y, lw=lw, c=col, label=label, rasterized=i == 0)
    panel.legend(loc=1, bbox_to_anchor=(1.1, 1.1))

axs[0].set(**pltltr.get_pi_labels(1, ax="y"))
axs[1].set(**pltltr.get_pi_labels(0.5, ax="y"), ylim=(-1.8, 1.8))
axs[2].set(xlabel="Time (s)", **pltltr.get_pi_labels(coefs=(0, -5, -10, -15), ax="y"))

for panel in axs.flatten():
    pltltr.despine(panel)

# Save before showing: with blocking backends plt.show() can close the figure,
# and pltltr.savefig (presumably saving the current figure) would then write
# an empty one.
pltltr.savefig("fictive_traj_computation")
plt.show()

Now, how does the network phase compare with the fictive heading?

In [None]:
# Overlay of fictive heading (left axis) and unwrapped network phase (right
# axis) for the example fish.
f, ax = plt.subplots(
    1, 1, figsize=(4, 1.5), gridspec_kw=dict(bottom=0.25, right=0.85)
)
twin_ax = ax.twinx()
ax.plot(exp.time_arr, fictive_head, c=COLS["th_plot"])
# Reuse the unwrapped phase computed for this fish instead of unwrapping again:
twin_ax.plot(exp.time_arr, network_phase, c=COLS["ph_plot"])

ax.set_ylabel("Fict. heading (θ)", c=COLS["th_plot"])
ax.set(xlabel="Time (s)", **pltltr.get_pi_labels(coefs=(0, -5, -10, -15), ax="y"))
twin_ax.set_ylabel("Network phase (ϕ)", c=COLS["ph_plot"])
# Heading ticks are negative while phase ticks are positive for this fish, so
# the phase axis is inverted to overlay the two traces:
twin_ax.set(
    ylim=twin_ax.get_ylim()[::-1], **pltltr.get_pi_labels(coefs=(0, 2, 4, 6), ax="y")
)
# Optional annotation with the whole-trace correlation:
# rho, pval = spearmanr(fictive_head, network_phase)
# ax.text(1200, -np.pi, "$ρ_r=" + f"{rho:0.3f}" + "$" + pltltr.get_pval_stars(pval))
for a in (ax, twin_ax):
    pltltr.despine(a, ["top"])

pltltr.savefig("fictive_traj_phase_onefish")

## Quantify correlation over the entire dataset

To quantify the correlation over the entire dataset, we do not use the whole trace, so that a large drift at one point of the experiment cannot disrupt the correlation for the rest of it. Instead, we compute the correlation in overlapping windows of 5 minutes, including only windows that contain at least 2 directional bouts (otherwise the fictive heading is a flat line, for which the correlation is ill-defined). As a control, the phase in each window is also correlated with the fictive heading (with random sign) from a different window of the same fish, starting at least half a window away.

In [None]:
CORR_WND_S = 300  # window over which correlation will be computed (s)
N_OVERLAPS = 10  # number of overlaps per window
MIN_BOUTS = 2  # minimum number of directional bouts to include a window
SHUFFLE_SEED = 0  # seed for the shuffle control, so results are reproducible

rng = np.random.default_rng(SHUFFLE_SEED)
results_df = []  # per-window records; turned into a DataFrame below
for path in tqdm(dataset_folders):
    exp = LotrExperiment(path)
    phase = np.unwrap(exp.network_phase)
    theta_turned = get_bouts_props_array(
        exp.n_pts, exp.bouts_df, selection="all", value="bias",
    )

    # Fictive heading is the cumulative sum of the angles turned:
    fictive_head = np.cumsum(theta_turned)
    wnd_pts = int(CORR_WND_S * exp.fn)
    overlap_pts = wnd_pts // N_OVERLAPS  # step between consecutive window starts

    # Window starts (the +1 keeps the last full window). A window is valid only
    # if it contains at least MIN_BOUTS directional (non-zero dθ) bouts;
    # otherwise fictive heading is flat and the correlation ill-defined:
    valid_indexes = np.array(
        [
            i
            for i in range(0, len(phase) - wnd_pts + 1, overlap_pts)
            if np.count_nonzero(theta_turned[i : i + wnd_pts]) >= MIN_BOUTS
        ],
        dtype=int,
    )

    for i in valid_indexes:
        # Shuffle control: fictive heading from another valid window of the
        # same fish, starting at least half a window away from the real one.
        candidates = valid_indexes[np.abs(valid_indexes - i) >= wnd_pts // 2]
        if len(candidates) == 0:
            # Skip the window rather than pairing it with an overlapping
            # "shuffle", which would bias the paired real-vs-shuffle test.
            print("could not find a valid shuffle for ", path.name)
            continue
        shuf_i = rng.choice(candidates)

        wnd = slice(i, i + wnd_pts)
        toshuf = fictive_head[shuf_i : shuf_i + wnd_pts]
        # Random sign flip, so the control has no preferred turning direction:
        if rng.random() > 0.5:
            toshuf = -toshuf

        rho, _ = spearmanr(phase[wnd], fictive_head[wnd])
        rho_shuf, _ = spearmanr(phase[wnd], toshuf)

        results_df.append(dict(rho=rho, shuf=rho_shuf, fid=exp.dir_name))

In [None]:
res_df = pd.DataFrame(results_df)

# Per-fish quartiles of the real and shuffled correlations
# (index: (fid, quantile); columns: rho, shuf):
quantiles_df = res_df.groupby("fid").quantile((0.25, 0.50, 0.75))

# Fish ordered by their median real correlation:
sort_order = (
    quantiles_df.xs(0.5, level=1)
    .sort_values("rho")
    .index
)

# Paired Wilcoxon test, real vs. shuffled correlation: one p value per fish,
# in the order of res_df["fid"].unique():
all_p_vals = [
    wilcoxon(
        res_df.loc[res_df["fid"] == fid, "rho"],
        res_df.loc[res_df["fid"] == fid, "shuf"],
    ).pvalue
    for fid in res_df["fid"].unique()
]

In [None]:
def _wilcoxon_p(df, keys=("rho", "shuf")):
    """Return the p value of a Wilcoxon signed-rank test on columns of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding the paired samples (e.g. one fish's windows).
    keys : sequence of str
        Column names passed, in order, to ``scipy.stats.wilcoxon``; by default
        the real ("rho") vs. shuffled ("shuf") correlations. A tuple is used
        as default to avoid a shared mutable default argument.

    Returns
    -------
    float
        The p value reported by ``scipy.stats.wilcoxon`` (default two-sided).
    """
    return wilcoxon(*[df[k] for k in keys]).pvalue


# One p value per fish (a Series indexed by fid). Only the value columns are
# handed to apply: operating on the grouping column is deprecated in pandas 2.2.
pvals_df = res_df.groupby("fid")[["rho", "shuf"]].apply(_wilcoxon_p)
pvals_df

In [None]:
# Per-fish quartile bars of Spearman ρ, fish sorted by median real ρ; the
# shuffled control is drawn first and the real correlations after it.
f, ax = plt.subplots(figsize=(3, 2), gridspec_kw=dict(bottom=0.2, left=0.18))
n_fish = len(sort_order)
for metric, color in (("shuf", COLS["shuf"]), ("rho", COLS["ph_plot"])):
    # One column per fish, rows indexed by quantile:
    per_fish_q = quantiles_df[[metric]].stack().unstack(level=0)
    pltltr.tick_with_bars(per_fish_q[sort_order], cols=[color] * n_fish, s=0.2)

ax.set(xlabel="Fish n.", ylabel="Correlation (Spearman ρ)")
ax.axhline(0, lw=1, c=".4")
pltltr.despine(ax)
plt.show()