# Some probability checks

So far, so good: our network seems to integrate directional motion over tens of seconds.
But there are some important sanity checks left to do. In this brief notebook we will address the following:
 1. **Are all network phases equally likely?** 
 2. **Is the probability of being in a given phase the same for left and right bouts?**

In [None]:
%matplotlib widget

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm import tqdm

import lotr.plotting as pltltr
from lotr import LotrExperiment, dataset_folders

COLS = pltltr.COLS

## Probability of different phases
We expect all phases to be equally likely. To check this, we load the network phase of every fish and compute its histogram.
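As a reference for reading the plots below: with `density=True`, the bar heights integrate to 1 over $[-\pi, \pi]$, so a perfectly flat histogram sits at $1/(2\pi) \approx 0.16$ in every bin. A minimal sketch with synthetic uniform phases (illustrative data, not the fish recordings) shows this level with the same 24 bins used here.

```python
import numpy as np

# Same binning as in the analysis: 24 bins spanning [-pi, pi]
hist_base = np.linspace(-np.pi, np.pi, 25)

# Synthetic phases drawn uniformly on the circle (illustrative data only)
rng = np.random.default_rng(0)
phases = rng.uniform(-np.pi, np.pi, size=100_000)

# density=True normalizes so that sum(height * bin_width) == 1,
# hence a flat histogram has height 1 / (2 * pi) in every bin
hist, _ = np.histogram(phases, hist_base, density=True)
expected = 1 / (2 * np.pi)

print(hist.round(3))
print(f"expected flat level: {expected:.3f}")
```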

In [None]:
# bins to compute histogram
hist_base = np.linspace(-np.pi, np.pi, 25)

# Load phases from all fish:
phase_hists = []
for path in tqdm(dataset_folders):
    phase_hist, _ = np.histogram(
        LotrExperiment(path).network_phase, hist_base, density=True
    )
    phase_hists.append(phase_hist)
phase_hists = np.array(phase_hists)

Let's plot them. The histograms look quite flat! Importantly, there is no location of consistent preferential activation across fish.

In [None]:
hist_x = (hist_base[1:] + hist_base[:-1]) / 2

f, ax = plt.subplots(figsize=(3, 2), gridspec_kw=dict(left=0.15, bottom=0.2))
for h in phase_hists:
    ax.bar(
        hist_x,
        h,
        width=hist_x[1] - hist_x[0],
        fc=pltltr.shift_lum(COLS["ph_plot"], -0.1),
        lw=0.0,
        alpha=0.2,
    )
ax.step(
    hist_x,
    np.nanmean(phase_hists, 0),
    lw=1.5,
    c=pltltr.shift_lum(COLS["ph_plot"], 0.1),
    where="mid",
)
ax.set(ylim=(0, 0.8), ylabel=r"$P(\Phi)$")
pltltr.despine(ax)

ax.set(**pltltr.get_pi_labels(0.5), xlabel=r"Network phase ($\Phi$)")

pltltr.savefig("network_phase_probability")

The lower probability at $\Phi=0$ and $\Phi=\pi$ might reflect the fact that there are fewer cells along the midline.
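A binning-free way to quantify such dips is to compute circular (trigonometric) moments $m_k = \langle e^{ik\Phi} \rangle$: $m_1$ captures a single preferred phase, while $m_2$ captures a symmetric pair of over- or under-represented phases such as $0$ and $\pi$ (less mass there makes $\mathrm{Re}(m_2)$ negative). A minimal sketch follows; `circular_moments` is a hypothetical helper, not part of `lotr`, and both phase sets are synthetic.

```python
import numpy as np


def circular_moments(phases, orders=(1, 2)):
    """Return {k: mean(exp(1j * k * phase))} for each order k.

    For uniformly distributed phases every moment is close to 0.
    """
    phases = np.asarray(phases, dtype=float)
    phases = phases[~np.isnan(phases)]  # drop missing time points, if any
    return {k: np.mean(np.exp(1j * k * phases)) for k in orders}


rng = np.random.default_rng(1)

# Synthetic example 1: uniform phases
uniform = rng.uniform(-np.pi, np.pi, 50_000)

# Synthetic example 2: dips at 0 and pi, drawn by rejection sampling
# from a density proportional to 1 - 0.3 * cos(2 * phi)
candidates = rng.uniform(-np.pi, np.pi, 200_000)
keep = rng.uniform(0, 1.3, candidates.size) < 1 - 0.3 * np.cos(2 * candidates)
dipped = candidates[keep]

m_uniform = circular_moments(uniform)
m_dipped = circular_moments(dipped)
for name, m in [("uniform", m_uniform), ("dips at 0, pi", m_dipped)]:
    print(name, {k: round(abs(v), 3) for k, v in m.items()}, "Re(m2) =", round(m[2].real, 3))
```

For the density above, $\mathrm{Re}(m_2) = -0.3/2 = -0.15$. On the real data one would pass each fish's `network_phase` to the helper; a negative $\mathrm{Re}(m_2)$ consistent across fish would support the dips. Imaging frames are autocorrelated, so the usual $1/\sqrt{n}$ noise level of a moment underestimates the true noise.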

## Probability of phase given bout direction

Let's check whether the occurrence of directional bouts is related to the network phase $\Phi$. If our network only integrates past directional motion, it should not be (it would be, for example, in a region like the ARTR, whose activity predicts the direction of upcoming turns).
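A raw histogram of the phase at bout times is not enough here: if the network spends more time in some phases, bouts land there more often even when they are unrelated to the phase. This is why the analysis divides the bout histogram by the phase-occupancy histogram. A minimal sketch on synthetic data (a hypothetical phase trace with uneven occupancy and bouts that occur independently of it) illustrates the effect.

```python
import numpy as np

rng = np.random.default_rng(2)
hist_base = np.linspace(-np.pi, np.pi, 25)

# Hypothetical phase trace with uneven occupancy: more time near +-pi/2,
# drawn by rejection sampling from a density proportional to 1 - 0.5 * cos(2 * phi)
n_frames = 200_000
candidates = rng.uniform(-np.pi, np.pi, 4 * n_frames)
keep = rng.uniform(0, 1.5, candidates.size) < 1 - 0.5 * np.cos(2 * candidates)
phase = candidates[keep][:n_frames]

# Bouts independent of phase: every frame has the same 20% chance of a bout
bout_idx = np.flatnonzero(rng.uniform(size=n_frames) < 0.2)
bout_phase = phase[bout_idx]

phase_hist, _ = np.histogram(phase, hist_base)
bouts_hist, _ = np.histogram(bout_phase, hist_base)

raw = bouts_hist / bouts_hist.sum()  # inherits the uneven occupancy
ratio = bouts_hist / phase_hist      # proportional to P(bout | phase)
ratio = ratio / np.nansum(ratio)     # same normalization as in the analysis

print("raw max/min:", round(raw.max() / raw.min(), 2))
print("occupancy-normalized max/min:", round(ratio.max() / ratio.min(), 2))
```

Strictly speaking, the normalized ratio is proportional to $P(bout|\Phi)$ rather than $P(\Phi|bout)$; by Bayes' rule the two have the same shape when phase occupancy is flat, as the previous section suggests.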

In [None]:
# For each fish, compute the histogram of network phase at bout times, separately for each bout direction
hist_df = []  # df with info on all stacked histograms
phase_given_bout_hists = []  # array of all stacked histograms

for path in tqdm(dataset_folders):
    exp = LotrExperiment(path)
    # occupancy: number of imaging frames falling in each phase bin
    phase_hist, _ = np.histogram(exp.network_phase, hist_base)
    # work on a copy to avoid modifying the experiment's bouts dataframe
    df = exp.bouts_df.copy()
    df["phase"] = exp.network_phase[df["idx_imaging"]]

    for d in df["direction"].unique():
        if (df["direction"] == d).sum() > 1:
            bouts_hist, _ = np.histogram(
                df.loc[df["direction"] == d, "phase"], hist_base
            )
            # normalize by occupancy, so that phases the network visits more often
            # do not look over-represented; empty bins give NaN, skipped by nansum
            with np.errstate(invalid="ignore", divide="ignore"):
                ratio = bouts_hist / phase_hist
            ratio = ratio / np.nansum(ratio)

            # append histogram array and info to the dataframe:
            hist_df.append(dict(fid=exp.dir_name, direction=d))
            phase_given_bout_hists.append(ratio)

hist_df = pd.DataFrame(hist_df)  # dataframe with info
phase_given_bout_hists = np.array(phase_given_bout_hists)  # matrix of histograms

In [None]:
hist_x = (hist_base[1:] + hist_base[:-1]) / 2
f, axs = plt.subplots(
    1,
    3,
    figsize=(7, 1.5),
    gridspec_kw=dict(left=0.075, bottom=0.25, wspace=0.35, right=0.98),
)
for ax, d in zip(axs, ["lf", "rt", "fw"]):
    for h in phase_given_bout_hists[hist_df["direction"] == d, :]:
        ax.bar(
            hist_x,
            h,
            width=hist_x[1] - hist_x[0],
            fc=pltltr.shift_lum(COLS["sides"][d], -0.1),
            lw=0.0,
            alpha=0.2,
        )
    ax.step(
        hist_x,
        np.nanmean(phase_given_bout_hists[hist_df["direction"] == d, :], 0),
        lw=1.5,
        c=pltltr.shift_lum(COLS["sides"][d], 0.1),
        where="mid",
    )
    ax.set(ylim=(0, 0.21), ylabel=r"$P(\Phi|bout_{" + f"{d}" + "})$")
    pltltr.despine(ax)

    ax.set(**pltltr.get_pi_labels(0.5), xlabel=r"Network phase ($\Phi$)")

pltltr.savefig("phase_given_bout_hists", folder="S6")
plt.show()

These histograms also look quite flat, as expected: the network phase does not seem to depend on the direction of the bout.
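To go beyond visual inspection, one could test whether left and right bouts occur at different phases with a label-permutation test. A minimal sketch, assuming the phases at left and right bouts of one fish are available as two arrays; the statistic (distance between the mean resultant vectors of the two groups) and the helper names are our own choice, not part of `lotr`, and the data are synthetic.

```python
import numpy as np


def lr_phase_difference(phase_lf, phase_rt):
    """Distance between the mean resultant vectors of left and right bout phases."""
    return abs(np.mean(np.exp(1j * phase_lf)) - np.mean(np.exp(1j * phase_rt)))


def permutation_pvalue(phase_lf, phase_rt, n_perm=2000, seed=0):
    """Shuffle left/right labels to build a null distribution of the statistic."""
    rng = np.random.default_rng(seed)
    observed = lr_phase_difference(phase_lf, phase_rt)
    pooled = np.concatenate([phase_lf, phase_rt])
    n_lf = len(phase_lf)
    null = np.empty(n_perm)
    for i in range(n_perm):
        perm = rng.permutation(pooled)
        null[i] = lr_phase_difference(perm[:n_lf], perm[n_lf:])
    # +1 correction so that the p-value is never exactly 0
    return (np.sum(null >= observed) + 1) / (n_perm + 1)


rng = np.random.default_rng(3)
# Synthetic null case: both directions uniform in phase
p_null = permutation_pvalue(rng.uniform(-np.pi, np.pi, 200), rng.uniform(-np.pi, np.pi, 200))
# Synthetic phase-dependent case: left bouts near +pi/2, right bouts near -pi/2
p_biased = permutation_pvalue(rng.vonmises(np.pi / 2, 2, 200), rng.vonmises(-np.pi / 2, 2, 200))
print(f"uniform: p = {p_null:.3f}; biased: p = {p_biased:.4f}")
```

This statistic only detects a shift in the mean phase; a difference in spread, or in a bimodal pattern, would need a different statistic (for example, the same distance computed on $2\Phi$).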