# Motor activity and phase dynamics

How does the phase evolve over time? What happens to it when the fish is moving? This is what we will investigate in this notebook.

In [None]:
%matplotlib widget
from itertools import cycle

import lotr.plotting as pltltr
import numpy as np
import pandas as pd
from lotr import A_FISH, FIGURES_LOCATION, LotrExperiment, dataset_folders
from lotr.behavior import get_fictive_heading
from lotr.utils import linear_regression
from matplotlib import pyplot as plt
from scipy.stats import wilcoxon
from tqdm import tqdm

# Shortcut to the colour palette of lotr.plotting (indexed below by keys such as
# "ph_plot", "sides" and "fish_cols").
COLS = pltltr.COLS

## Overall probability of network phase over time

First of all, we want to check whether the network crosses all phases with similar probability. We will use the function for calculating network phase defined in notebook 2.

To start, let's look at the histogram of the phase over the full experiment, for all fish:

In [None]:
# Load the network phase of every fish in the dataset (one entry per fish):
phases = []
for fish_path in tqdm(dataset_folders):
    phases.append(LotrExperiment(fish_path).network_phase)

In [None]:
f, ax = plt.subplots(figsize=(3, 2))
# 21 equal bins (~0.3 rad wide) whose edges span the full [-pi, pi] range, so
# that no phase value falls outside the histogram (and density normalization
# accounts for every sample).
hist_base = np.linspace(-np.pi, np.pi, 22)

# One faint, density-normalized histogram per fish:
for i, ph in enumerate(phases):
    ax.hist(
        ph,
        hist_base,
        fc=COLS["ph_plot"],
        lw=0,
        density=True,
        alpha=0.1,
        label="fish" if i == 0 else "_nolegend_",
    )

# Histogram of all fish pooled together, drawn as an outline on top:
ax.hist(
    np.concatenate(phases),
    hist_base,
    density=True,
    histtype="step",
    lw=1,
    label="mean",
    ec=pltltr.dark_col(COLS["ph_plot"], 0.3),
)
ax.set(xlabel="phase", ylabel="p. density", ylim=(0, 0.8), **pltltr.get_pi_labels(0.5))
pltltr.despine(ax)
ax.legend()

# Lay out and save before showing: depending on the backend, plt.show() may
# render/close the figure, so layout changes or saves issued afterwards may
# not apply to the displayed figure.
plt.tight_layout()
pltltr.savefig("all_network_phase_hist")
plt.show()

This looks quite flat! We'll come back to this result later. For now, it is interesting to note that there is no location of consistent preferential activation across fish.

## Dynamics of the network and behavior

Let's now get to the main point: what happens to the phase when the fish is performing directional motion?

We will start with a very naive approach: simply cropping the phase after every single directional bout to see how it changes. From now on, we will work with the unwrapped version of the phase (see [np.unwrap](https://numpy.org/doc/stable/reference/generated/numpy.unwrap.html)), to avoid jumps from `-pi` to `pi`.

In [None]:
from lotr.default_vals import DEFAULT_FN, POST_BOUT_WND_S, PRE_BOUT_WND_S
from lotr.utils import crop, resample_matrix

In [None]:
# Get unwrapped phase of the example fish:
exp = LotrExperiment(A_FISH)
unwrapped_ph = np.unwrap(exp.network_phase)

# Pre/post-bout window lengths in samples. Cast to int so that they can be used
# as slice bounds even if the sampling frequency is stored as a float.
pre_pts = int(round(PRE_BOUT_WND_S * exp.fn))
post_pts = int(round(POST_BOUT_WND_S * exp.fn))


def _crop_baselined(trace, idxs, pre_pts, post_pts):
    """Crop `trace` around each index in `idxs` and remove each crop's baseline.

    The result is the (time, n_events) matrix returned by `crop`, with the
    mean of the first `pre_pts` samples (the pre-event window) of each column
    subtracted from that column.
    """
    cropped = crop(trace, idxs, pre_int=pre_pts, post_int=post_pts)
    return cropped - np.mean(cropped[:pre_pts, :], 0)


bout_idxs = exp.bouts_df["idx_imaging"]
# Network phase and fictive heading, cropped and baselined identically:
cropped_phase = _crop_baselined(unwrapped_ph, bout_idxs, pre_pts, post_pts)
cropped_head = _crop_baselined(exp.fictive_heading, bout_idxs, pre_pts, post_pts)

# Time (s) of each row of the crops, relative to bout onset:
time_arr = np.arange(cropped_phase.shape[0]) / exp.fn - PRE_BOUT_WND_S

In [None]:
def plot_bout_trig(
    events_df,
    cropped_list,
    labels_list,
    ylims=(-4.2, 4.2),
    legend_lab="{} bouts",
    time_s=None,
):
    """Plot bout-triggered traces, grouped and coloured by bout direction.

    Each panel shows every individual trace (thin lines) plus the mean trace
    of each direction (thick line).

    Parameters
    ----------
    events_df : pd.DataFrame
        One row per column of the matrices in `cropped_list`; its "direction"
        column is used to group and colour the traces.
    cropped_list : list of np.ndarray
        (time, n_events) matrices, one per panel. Only the first two are
        plotted, as the figure has two panels.
    labels_list : list of str
        Panel titles, matched one-to-one with `cropped_list`.
    ylims : tuple of float, optional
        y-axis limits (the y axis is shared between the panels).
    legend_lab : str, optional
        Format string for each direction's mean-trace legend entry; ``{}`` is
        replaced by the direction label.
    time_s : np.ndarray, optional
        Time axis (s) matching the rows of the matrices. When None, the
        notebook-level ``time_arr`` is used (original behaviour); pass it
        explicitly to avoid depending on global notebook state.

    Returns
    -------
    f : matplotlib Figure
    axs : array of the two matplotlib Axes
    """
    if time_s is None:
        # Backward-compatible fallback to the global time axis.
        time_s = time_arr

    f, axs = plt.subplots(
        1,
        2,
        figsize=(5, 2.5),
        gridspec_kw=dict(left=0.08, bottom=0.15, top=0.9, right=0.73),
        sharey=True,
    )

    directions = events_df["direction"].unique()
    for ax, lab, cropped in zip(axs, labels_list, cropped_list):
        for d in directions:
            # Plain boolean array -> positional selection of matching columns.
            sel = (events_df["direction"] == d).to_numpy()
            # Individual traces (kept out of the legend):
            ax.plot(
                time_s,
                cropped[:, sel],
                lw=0.3,
                c=COLS["sides"][d],
                label="_nolabel_",
            )
            # Mean trace for this direction, drawn on top:
            ax.plot(
                time_s,
                np.mean(cropped[:, sel], 1),
                lw=2,
                c=pltltr.dark_col(COLS["sides"][d]),
                label=legend_lab.format(d),
                zorder=30,
            )

        pltltr.despine(ax)
        ax.set(
            xlabel="time from bout (s)",
            ylim=ylims,
            **pltltr.get_pi_labels(0.5, ax="y"),
        )
        ax.set_title(lab, weight="bold")
        ax.axvline(0, lw=0.5, c=".5")
    axs[1].legend(bbox_to_anchor=(1.05, 0.8, 0.2, 0.2))

    return f, axs


# Example fish: bout-triggered phase change next to the heading change.
panels = {"Δphase": cropped_phase, "Δheading": cropped_head}
f, axs = plot_bout_trig(exp.bouts_df, list(panels.values()), list(panels.keys()))
pltltr.savefig("bout_trig_phase_change_onefish")

## Crop phase changes across all fish

We can crop events from all fish in the dataset in a similar way. As we will also be doing this in other notebooks, we will use the function defined in `lotr.analysis.shift_cropping`, which performs the operations above for all experiments, plus an additional interpolation step for experiments with different sampling frequencies. For the example fish used above (`A_FISH`), we will check that the results are consistent.

In [None]:
from lotr.analysis.shift_cropping import crop_shifts_all_dataset

all_phase_cropped, all_head_cropped, events_df, time_arr = crop_shifts_all_dataset()

# Control consistency with notebook pipeline: the example fish's events, as
# cropped by the shared function, must match the crops computed above.
sel = (events_df["fid"] == exp.dir_name).to_numpy()
for from_pipeline, from_notebook in [
    (all_phase_cropped[:, sel], cropped_phase),
    (all_head_cropped[:, sel], cropped_head),
]:
    # Check shapes explicitly: np.allclose would otherwise broadcast.
    assert from_pipeline.shape == from_notebook.shape
    # equal_nan: NaN samples (if any) in matching positions count as equal.
    assert np.allclose(from_pipeline, from_notebook, rtol=0.001, equal_nan=True)

In [None]:
# For every fish, compute the average response for each direction, keeping the
# first events_df row of each (fish, direction) pair (its "direction" is used
# when plotting).
phase_mean_cols, head_mean_cols, representative_rows = [], [], []
directions = events_df["direction"].unique()
for fid in events_df["fid"].unique():
    is_fish = events_df["fid"] == fid
    for d in directions:
        mask = (is_fish & (events_df["direction"] == d)).to_numpy()
        if not mask.any():
            continue  # this fish has no bouts in this direction
        phase_mean_cols.append(all_phase_cropped[:, mask].mean(1))
        head_mean_cols.append(all_head_cropped[:, mask].mean(1))
        representative_rows.append(events_df[mask].iloc[0, :])

# One column per (fish, direction) pair, time along the rows:
all_phase_means = np.stack(phase_mean_cols, axis=1)
all_head_means = np.stack(head_mean_cols, axis=1)
mean_events_df = pd.DataFrame(representative_rows)

In [None]:
# All fish: one mean trace per fish and direction, plus the mean over fish.
f, axs = plot_bout_trig(
    mean_events_df,
    cropped_list=[all_phase_means, all_head_means],
    labels_list=["Δphase", "Δheading"],
    ylims=(-np.pi, np.pi),
    legend_lab="{} (all fish + mn)",
)
pltltr.savefig("bout_trig_phase_change_allfish")

Note that we have not introduced any arbitrary sign flip to get to this plot: we only imposed a registration with the anatomy, and we get this agreement of phase changes across fish for free!
Given our convention for the phase definition:
 - **left bouts induce cw rotations**
 - **right bouts induce ccw rotations**
 
Before finalizing these statements we should cross-check everything very carefully; so far, I have re-checked the following:
 - the definition of left and right bouts can be cross-checked using the localization of motor-selective ROIs as a reference. With our definition of left and right bouts, left bouts activate the left side of the aHB and right bouts activate the right side of the aHB (consistent with e.g. Chen et al., 2018);
 - our definition of phase rotation is consistent with visual inspection of the data for the A_FISH dataset.

## Correlation between phase change and heading change (θ turned)

In [None]:
# window in which we will calculate the delta change
# Window (s from bout onset) over which the Δ change is averaged:
WND_DELTA_S = np.array([15, 20])
# Matching sample indices in the cropped matrices, whose first row is
# PRE_BOUT_WND_S before the bout. Rounded to ints so they are valid slice
# bounds even if the product is a float.
# NOTE(review): assumes the all-fish crops are sampled at DEFAULT_FN — confirm.
wnd_pts = np.round((PRE_BOUT_WND_S + WND_DELTA_S) * DEFAULT_FN).astype(int)

In [None]:
# Per-event Δ: NaN-ignoring mean of each crop over the `wnd_pts` window.
delta_window = slice(*wnd_pts)
for col_name, crop_mat in (("Δphase", all_phase_cropped), ("Δhead", all_head_cropped)):
    events_df[col_name] = np.nanmean(crop_mat[delta_window, :], axis=0)

In [None]:
# For each fish, regress Δhead on Δphase over its directional (non-"fw") bouts,
# and compare the slope with the one obtained after shuffling Δphase across bouts.
np.random.seed(60)
per_fish_slopes = []
for fid in events_df["fid"].unique():
    is_fish = events_df["fid"] == fid
    is_turn = events_df["direction"] != "fw"
    sel_df = events_df[is_fish & is_turn]
    # Legacy global RNG: same draw as shuffling np.arange(len(sel_df)) in place.
    shuffle = np.random.permutation(len(sel_df))
    _, slope_data = linear_regression(sel_df["Δphase"], sel_df["Δhead"])
    _, slope_shuf = linear_regression(
        sel_df["Δphase"].values[shuffle], sel_df["Δhead"]
    )
    per_fish_slopes.append({"Data": slope_data, "Shuffle": slope_shuf})
results_df = pd.DataFrame(per_fish_slopes)

In [None]:
f, axs = plt.subplots(
    1,
    2,
    figsize=(5, 2.5),
    gridspec_kw=dict(left=0.2, bottom=0.15, wspace=0.5, width_ratios=[1, 0.5]),
)

# Left panel: per-event Δhead vs Δphase for directional bouts, one colour per
# fish, with each fish's linear fit θ = αφ + β.
# cycle(): reuse palette colours rather than silently dropping fish when there
# are more fish than colours (a plain zip would truncate).
for fid, col in zip(events_df["fid"].unique(), cycle(COLS["fish_cols"])):
    sel_df = events_df[(events_df["fid"] == fid) & (events_df["direction"] != "fw")]
    axs[0].scatter(
        sel_df["Δphase"],
        sel_df["Δhead"],
        alpha=0.5,
        facecolors="none",
        edgecolors=col,
        lw=0.5,
        s=3,
    )
    o, c = linear_regression(sel_df["Δphase"], sel_df["Δhead"])
    # Draw the fit as one segment over the data range, not as a polyline
    # through the (unsorted) data points.
    x_fit = np.array([sel_df["Δphase"].min(), sel_df["Δphase"].max()])
    axs[0].plot(x_fit, x_fit * c + o, lw=1, c=col, zorder=100)
pltltr.despine(axs[0])
axs[0].set(
    xlabel=r"$Δphase_{15-20s}(\phi)$",
    ylabel=r"$Δhead_{15-20s} (\theta)$",
    **pltltr.get_pi_labels(coefs=(-4, -2, 0, 2, 4), ax="y"),
    **pltltr.get_pi_labels(coefs=(-2, 0, 2), ax="x")
)
axs[0].text(2, 10, r"$\theta=\alpha \phi + \beta$", fontsize=8)

# Right panel: per-fish slope α for data vs shuffle (paired lines + bars).
axs[1].plot(results_df.T, lw=1, c=".7", alpha=0.5, zorder=100)
axs[1].axhline(0, linestyle="dashed", c=".5")
pltltr.bar_with_bars(
    results_df, axs[1], cols=[COLS["ph_plot"], ".3"], empty=False, lw=1.0
)
axs[1].set(ylabel=(r"$\alpha$"), xlim=(-0.3, 1.3))
pltltr.despine(axs[1])

# Paired Wilcoxon signed-rank test, data vs shuffle, across fish:
test_res = wilcoxon(results_df["Data"], results_df["Shuffle"])
axs[1].text(0.5, -2, pltltr.get_pval_stars(test_res), ha="center")

# NOTE(review): the other figures are saved by name without an extension;
# confirm pltltr.savefig handles the explicit ".pdf" suffix as intended.
pltltr.savefig("correlation_quantification.pdf")

In [None]:
# NOTE(review): scratch arithmetic with unexplained constants (evaluates to 39.0);
# document what these numbers represent, or remove this cell.
(30 * 78 * 2) / 120