# Phase dynamics and visual feedback

In this notebook, we build on the previous quantification of the relationship between network phase and directional swimming, and ask whether visual feedback is necessary for, or at least involved in, the rotation of the network.

In [None]:
%matplotlib widget

import lotr.plotting as pltltr
import numpy as np
import pandas as pd
from lotr import A_FISH, FIGURES_LOCATION, LotrExperiment, dataset_folders
from lotr.behavior import get_fictive_heading
from lotr.utils import crop, resample_matrix
from matplotlib import pyplot as plt
from tqdm import tqdm

COLS = pltltr.COLS

## Load stimulus information

We will use the `LotrExperiment.stim_trials_df` attribute to load, from each experiment, a dataframe describing the sequence of stimuli, and then concatenate all of them into a single dataframe.
The resulting `all_stim_conds` contains a `condition` column that can take one of the following values:
 - **`darkness`**: black screen; the fish swims in darkness.
 - **`closed_loop`**: the fish experiences visual feedback when moving. This consists of a pink noise pattern that, every time the fish performs a bout, rotates proportionally to the laterality index of the bout (with inverted sign) and translates backward with an average speed of approx. 10 mm/s. In the gainmod experiments, a gain factor that scales the rotation induced by each directional bout alternates between 0.5 (**low gain**), 1.0 (**normal gain**), 2.0 (**high gain**) and -1 (**inverted motion**).
 - **`natural_mot`**: pink noise that is translated/rotated independently of the fish's motion, in ways similar to what the fish would experience when swimming over the pattern: a sort of playback of a closed-loop experiment. This stimulus induces only very small responses in visual directional motion/rotation cells, as every movement is very brief.
 - **`directional_mot`**: pink noise that moves in 8 equally spaced directions on the plane below the fish, in 10 s periods separated by 10 s pauses in which the pattern is static. As the fish swims mostly during the motion periods, for the sake of this quantification all bouts from both the motion and the static periods are for now pooled in the `directional_mot` condition (TODO: maybe split them).
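As a rough illustration of the rotational part of the closed-loop rule described above, the toy sketch below maps a bout's laterality index to a pattern rotation. The function name `closed_loop_rotation` and the proportionality constant `k` are hypothetical; the actual rig implementation is not part of this notebook, and the backward translation is ignored here.

```python
import numpy as np


def closed_loop_rotation(laterality, gain=1.0, k=1.0):
    """Pattern rotation triggered by one bout (hypothetical units).

    The pattern rotates proportionally to the bout laterality index, with
    inverted sign, scaled by the experimental gain. `k` is a hypothetical
    proportionality constant, not a value taken from the setup.
    """
    return -gain * k * np.asarray(laterality, dtype=float)


# The same bout (laterality 0.4) under the four gainmod conditions:
for gain in [0.5, 1.0, 2.0, -1.0]:
    print(gain, closed_loop_rotation(0.4, gain=gain))
```

With a gain of -1 the sign flips, so the pattern rotates in the same direction as the bout laterality: this is the "inverted motion" condition.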

In [None]:
all_stim_logs = [LotrExperiment(path).stim_trials_df for path in tqdm(dataset_folders)]
all_stim_conds = pd.concat(all_stim_logs, ignore_index=True, sort=True)

### Visualize all experiment types

Let's have a quick look at all the experiment types:

In [None]:
all_exp_types = all_stim_conds.exp_type.unique()

f, axs = plt.subplots(
    len(all_exp_types),
    1,
    figsize=(5, 4),
    gridspec_kw=dict(top=0.8, right=0.8, hspace=1),
)
for i, exp_type in enumerate(all_exp_types):
    # Get an example fish for the experiment:
    fid = all_stim_conds.loc[all_stim_conds["exp_type"] == exp_type, "fid"].unique()[0]

    # Plot with custom lotr function:
    pltltr.plot_exp_condition(
        all_stim_conds[all_stim_conds["fid"] == fid], ax=axs[i], alpha=1
    )
    axs[i].set(title=exp_type, xlim=(0.0, 2500))

for ax in axs[:-1]:
    pltltr.despine(ax, "all")
pltltr.despine(axs[-1], ["left", "top", "right"])

axs[-1].set_xlabel("Time (s)")

ax = f.add_axes((0.75, 0.95, 0.2, 0.2))
ax.axis("off")
v_space = -0.17
v_o, h_o = 0, 0.01

for text, item in COLS["stim_conditions"].items():
    lab = " ".join(text.split("_")).capitalize()
    if text != "closed_loop":
        ax.text(h_o, v_o, lab, c=item, weight="bold", fontsize=8)
        v_o += v_space
    else:
        for g, col in item.items():
            ax.text(h_o, v_o, f"{lab} (gain {g})", c=col, weight="bold", fontsize=8)
            v_o += v_space

pltltr.savefig("experiment_types")

## Bout-induced phase changes w/ and w/o visual feedback 

First of all, let's crop the network phase shifts and the cumulative angle turned by the fish (fictive heading) around all bouts, as we did in the previous notebook:
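The cells below rely on `lotr.utils.crop`, whose exact indexing and padding behaviour is not shown here. The following is a hypothetical stand-alone sketch of the same idea: cut a window of `pre + post` samples around each event (row `pre` is the event sample), pad with NaN where the window exceeds the trace, and subtract the mean of the pre-event samples so that every window starts from a zero baseline.

```python
import numpy as np


def crop_around_events(trace, event_idxs, pre, post):
    """Return a (pre + post, n_events) matrix of windows around each event.

    Row `pre` corresponds to the event sample; windows that extend past
    either end of the trace are padded with NaN.
    """
    trace = np.asarray(trace, dtype=float)
    out = np.full((pre + post, len(event_idxs)), np.nan)
    for j, idx in enumerate(event_idxs):
        start, stop = idx - pre, idx + post
        # Clip the window to the valid range of the trace:
        src_start, src_stop = max(start, 0), min(stop, len(trace))
        if src_start < src_stop:
            out[src_start - start : src_stop - start, j] = trace[src_start:src_stop]
    return out


def subtract_baseline(cropped, pre):
    """Subtract, for each event, the mean of its pre-event samples."""
    return cropped - np.nanmean(cropped[:pre, :], axis=0)


# Toy example: a phase trace that steps by 1 rad at sample 50
phase = np.where(np.arange(100) >= 50, 1.0, 0.0)
windows = subtract_baseline(crop_around_events(phase, [50], pre=10, post=25), pre=10)
```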

In [None]:
from lotr.analysis.shift_cropping import crop_shifts_all_dataset

all_phase_cropped, all_head_cropped, events_df, time_arr = crop_shifts_all_dataset()

In [None]:
FN = 5  # we impose a sampling frequency, fish with a different one will be resampled

directions = ["rt", "lf", "fw"]
pre_wnd_s, post_wnd_s = 10, 25

all_phase_cropped = []
all_head_cropped = []

# We will create a dataframe to keep track of events from all fish.
# Mostly a way of keeping together the crop and the bouts:
events_df = []

time_arr = np.arange(1, ((pre_wnd_s + post_wnd_s) * FN) + 1) / FN - pre_wnd_s
for path in tqdm(dataset_folders):
    exp = LotrExperiment(path)

    unwrapped_ph = np.unwrap(exp.network_phase)
    fictive_head = get_fictive_heading(exp.n_pts, exp.bouts_df)

    for d in directions:
        sel_bouts = exp.bouts_df.loc[exp.bouts_df["direction"] == d, :]

        events_df.append(sel_bouts.reindex())

        for dest_list, to_crop in zip(
            [all_phase_cropped, all_head_cropped], [unwrapped_ph, fictive_head]
        ):

            cropped = crop(
                to_crop,
                sel_bouts["idx_imaging"],
                pre_int=pre_wnd_s * exp.fn,
                post_int=post_wnd_s * exp.fn,
            )

            # Subtract baseline:
            cropped = cropped - np.mean(cropped[: pre_wnd_s * exp.fn, :], 0)

            # Interpolate if necessary:
            if exp.fn != FN:
                fish_time_arr = np.arange(1, cropped.shape[0] + 1) / exp.fn - pre_wnd_s
                cropped = resample_matrix(time_arr, fish_time_arr, cropped)

            dest_list.append(cropped)

all_phase_cropped = np.concatenate(all_phase_cropped, axis=1)
all_head_cropped = np.concatenate(all_head_cropped, axis=1)
events_df = pd.concat(events_df, ignore_index=True)
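Fish imaged at a rate other than `FN` are brought onto the common `time_arr` with `resample_matrix(time_arr, fish_time_arr, cropped)`. Its implementation is not reproduced here; assuming it performs linear interpolation, an equivalent sketch interpolates every column (one per bout) with `np.interp`. The helper name `resample_columns` and the 3 Hz example fish are hypothetical, and toy variable names are chosen so they do not overwrite the notebook's own.

```python
import numpy as np


def resample_columns(new_t, old_t, mat):
    """Linearly interpolate every column of `mat` from times `old_t` onto `new_t`.

    np.interp holds the first/last value constant outside the range of `old_t`.
    """
    mat = np.asarray(mat, dtype=float)
    return np.stack(
        [np.interp(new_t, old_t, mat[:, j]) for j in range(mat.shape[1])], axis=1
    )


# Common time base, built as in the cropping cell (5 Hz, from -10 s to +25 s)
fn_common, pre_s, post_s = 5, 10, 25
common_t = np.arange(1, (pre_s + post_s) * fn_common + 1) / fn_common - pre_s

# A hypothetical fish acquired at 3 Hz, with two toy events linear in time
fish_fn = 3
fish_t = np.arange(1, (pre_s + post_s) * fish_fn + 1) / fish_fn - pre_s
toy_cropped = np.stack([fish_t, 2 * fish_t], axis=1)

resampled = resample_columns(common_t, fish_t, toy_cropped)
```

Note that the first common time point (-9.8 s) falls before the first sample of the 3 Hz fish (about -9.67 s), so it takes the edge value instead of a true interpolation.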

In [None]:
# TODO: fix this in actual preprocessing
events_df.loc[events_df["fid"] == "210601_f0_2d_vr_eyes", "fid"] = "210601_f0_2dvr_eyes"

events_df["condition"] = "-"
events_df["exp_type"] = "-"
events_df["gain_theta"] = np.nan

for i in tqdm(events_df.index):
    # Find the stimulus trial of the same fish that contains the bout onset:
    sel_trial = (
        (all_stim_conds["t_start"] < events_df.loc[i, "t_start"])
        & (all_stim_conds["t_stop"] > events_df.loc[i, "t_start"])
        & (all_stim_conds["fid"] == events_df.loc[i, "fid"])
    )
    if not sel_trial.any():
        continue  # the bout falls outside any logged trial

    # Assign scalars (not arrays) to single cells:
    trial = all_stim_conds.loc[sel_trial].iloc[0]
    events_df.loc[i, "condition"] = trial["condition"]
    events_df.loc[i, "exp_type"] = trial["exp_type"]

    if trial["condition"] == "closed_loop":
        events_df.loc[i, "gain_theta"] = trial["gain_theta"]
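The loop above scans the full stimulus table once per bout. A vectorized alternative, sketched below under the assumption that both tables carry float `t_start` columns and a `fid` column, is `pd.merge_asof`: for each bout it picks, within the same fish, the last trial starting strictly before the bout, and then discards matches where the bout is not before that trial's `t_stop`. The helper name `assign_trials` and the toy tables are hypothetical.

```python
import numpy as np
import pandas as pd


def assign_trials(events, trials):
    """Attach trial columns to each event whose t_start lies strictly inside a trial."""
    ev = events.reset_index().sort_values("t_start")
    tr = trials.rename(columns={"t_start": "trial_start"}).sort_values("trial_start")
    merged = pd.merge_asof(
        ev,
        tr,
        left_on="t_start",
        right_on="trial_start",
        by="fid",
        direction="backward",  # last trial starting before the event...
        allow_exact_matches=False,  # ...strictly before, as in the loop
    )
    # Events after the end of the matched trial (or with no match) get NaN:
    outside = ~(merged["t_start"] < merged["t_stop"])
    trial_cols = [c for c in tr.columns if c != "fid"]
    merged.loc[outside, trial_cols] = np.nan
    return merged.set_index("index").sort_index()


toy_trials = pd.DataFrame(
    {
        "fid": ["a", "a", "b"],
        "t_start": [0.0, 100.0, 0.0],
        "t_stop": [100.0, 200.0, 50.0],
        "condition": ["darkness", "closed_loop", "natural_mot"],
        "gain_theta": [np.nan, 2.0, np.nan],
    }
)
toy_events = pd.DataFrame({"fid": ["a", "a", "b", "b"], "t_start": [10.0, 150.0, 20.0, 60.0]})
print(assign_trials(toy_events, toy_trials)[["fid", "t_start", "condition", "gain_theta"]])
```

If the real bout table shares other column names with the stimulus table, `merge_asof` will suffix them, so the column selection may need adjusting.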

In [None]:
t_slice = slice(160, None)  # from t ≈ 22.2 s after the bout to the end of the window

cols = ("r", "k", "b", "y")
plt.figure(figsize=(3, 3))
for g, c in zip([0.5, 1, 2, -1], cols):
    sel = (events_df["exp_type"] == "gainmod") & (events_df["gain_theta"] == g)
    print(sum(sel))
    plt.scatter(
        all_phase_cropped[t_slice, sel].mean(0),
        all_head_cropped[t_slice, sel].mean(0),
        c=c,
        alpha=0.1,
        lw=0,
    )
plt.axis("equal")
plt.xlabel("delta_phase")
plt.ylabel("delta_head")
plt.tight_layout()

In [None]:
import statsmodels.formula.api as sm

all_coefs = []
for g, c in zip([0.5, 1, 2, -1], cols):
    sel = (events_df["exp_type"] == "gainmod") & (events_df["gain_theta"] == g)
    gain_arr = []
    for fid in events_df[sel].fid.unique():
        fish_sel = sel & (events_df["fid"] == fid)
        data_df = pd.DataFrame(
            {
                "x": all_phase_cropped[t_slice, fish_sel].mean(0),
                "y": all_head_cropped[t_slice, fish_sel].mean(0),
            }
        )
        ols_model = sm.ols(formula="y ~ x", data=data_df)
        results = ols_model.fit()
        # coefficients
        # gain_arr.append(np.corrcoef(data_df.values.T)[0, 1])
        # print("Intercept, x-Slope : {}".format(results.params))
        # y_pred = ols_model.fit().predict(data_df["x"])
        gain_arr.append(results.params["x"])

    all_coefs.append(gain_arr)
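The slope of a statsmodels fit with formula `y ~ x` (ordinary least squares with an intercept) is the same quantity as `np.polyfit(x, y, 1)[0]`. As a lighter-weight sketch, the per-fish loop can be written with `groupby` and `polyfit`, returning NaN for fish with too few events (the threshold of 10 mirrors the one used in the closed-loop comparison). The helper name `per_fish_slopes` and the synthetic data are hypothetical.

```python
import numpy as np
import pandas as pd


def per_fish_slopes(dphase, dhead, fids, min_events=10):
    """OLS slope of Δ heading vs. Δ phase for each fish (NaN if too few events)."""
    df = pd.DataFrame({"x": dphase, "y": dhead, "fid": fids}).dropna()
    slopes = {}
    for fid, g in df.groupby("fid"):
        if len(g) > min_events:
            slopes[fid] = np.polyfit(g["x"], g["y"], 1)[0]  # [slope, intercept]
        else:
            slopes[fid] = np.nan
    return pd.Series(slopes, name="slope")


# Synthetic data: fish "a" has slope 2, fish "b" slope -1, fish "c" too few events
rng = np.random.default_rng(0)
x = rng.normal(size=60)
fids = np.array(["a"] * 30 + ["b"] * 25 + ["c"] * 5)
y = np.where(fids == "a", 2 * x, -x) + 0.5
slopes = per_fish_slopes(x, y, fids)
```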

In [None]:
# Inspect the fit for the last fish/gain combination of the loop above:
plt.figure()
plt.scatter(data_df["x"], data_df["y"], c=c, alpha=0.1, lw=0)
# Predict from a DataFrame so that the formula can find the "x" column:
plt.plot(data_df["x"], results.predict(data_df))

In [None]:
data_df

In [None]:
included_fish = events_df.loc[events_df["condition"] == "closed_loop", "fid"].unique()

all_coefs = []
for fid in included_fish:
    cond_arr = []
    for cl_sel in [
        (events_df["condition"] == "closed_loop"),
        (events_df["condition"] != "closed_loop"),
    ]:
        sel = cl_sel & (events_df["fid"] == fid)
        # print(sum(sel))
        data_df = pd.DataFrame(
            {
                "x": all_phase_cropped[t_slice, sel].mean(0),
                "y": all_head_cropped[t_slice, sel].mean(0),
            }
        )
        if len(data_df) > 10:
            ols_model = sm.ols(formula="y ~ x", data=data_df)
            results = ols_model.fit()
            # coefficients
            # cond_arr.append(np.corrcoef(data_df.values.T)[0, 1])
            # print("Intercept, x-Slope : {}".format(results.params))
            # y_pred = ols_model.fit().predict(data_df["x"])
            cond_arr.append(results.params["x"])
        else:
            cond_arr.append(np.nan)

    all_coefs.append(cond_arr)

In [None]:
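# One line per fish: slope in closed loop (left) vs. all other conditions (right)
plt.figure()
plt.plot(np.array(all_coefs).T, "-o")
plt.xticks([0, 1], ["closed_loop", "other"])
plt.ylabel("Slope of Δ heading vs. Δ phase")
plt.show()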

In [None]:
for con in ["closed_loop", "natural_mot", "directional_mot", "darkness"]:
    sel = events_df["condition"] == con

    data_df = pd.DataFrame(
        {
            "x": all_phase_cropped[t_slice, sel].mean(0),
            "y": all_head_cropped[t_slice, sel].mean(0),
        }
    )
    ols_model = sm.ols(formula="y ~ x", data=data_df)
    results = ols_model.fit()
    # Report the Pearson correlation and the OLS slope for each condition:
    r = np.corrcoef(data_df.values.T)[0, 1]
    print(f"{con}: r = {r:.2f}, slope = {results.params['x']:.2f}")
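The same per-condition correlation can be obtained in one pass by grouping the bouts by condition. This is a sketch with a hypothetical helper `corr_by_condition`, assuming one Δ phase and one Δ heading value per bout, as produced by the `.mean(0)` over `t_slice` above; `Series.corr` computes the Pearson coefficient by default.

```python
import numpy as np
import pandas as pd


def corr_by_condition(dphase, dhead, conditions):
    """Pearson correlation between Δ phase and Δ heading within each condition."""
    df = pd.DataFrame({"x": dphase, "y": dhead, "condition": conditions})
    return pd.Series(
        {c: g["x"].corr(g["y"]) for c, g in df.groupby("condition")}, name="r"
    )


# Synthetic check: perfectly correlated in darkness, anti-correlated in closed loop
x = np.linspace(-1, 1, 20)
cond = np.array(["darkness"] * 10 + ["closed_loop"] * 10)
y = np.where(cond == "darkness", x, -2 * x)
r_by_cond = corr_by_condition(x, y, cond)
```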

In [None]:
for g in [0.5, 1, 2]:
    # The gain is stored in the "gain_theta" column (the "gain" column is never filled):
    sel = (events_df["condition"] == "closed_loop") & (events_df["gain_theta"] == g)

    data_df = pd.DataFrame(
        {
            "x": all_phase_cropped[t_slice, sel].mean(0),
            "y": all_head_cropped[t_slice, sel].mean(0),
        }
    )
    ols_model = sm.ols(formula="y ~ x", data=data_df)
    results = ols_model.fit()
    r = np.corrcoef(data_df.values.T)[0, 1]
    print(f"gain {g}: r = {r:.2f}, slope = {results.params['x']:.2f}")

In [None]:
t_slice = slice(160, None)

cols = ("r", "k", "b")
plt.figure(figsize=(3, 3))
# for g, c in zip([0.5, 1, 2], cols):
#     # sel = (events_df["condition"] == "closed_loop") & (events_df["gain"] == g)
plt.scatter(
    all_phase_cropped[t_slice, :].mean(0),
    all_head_cropped[t_slice, :].mean(0) / 2,
    c="k",
    alpha=0.5,
    lw=0,
)

s = np.abs(events_df["bias"]) < 0.001
plt.scatter(
    all_phase_cropped[t_slice, s].mean(0),
    all_head_cropped[t_slice, s].mean(0) / 2,
    c="r",
    alpha=0.5,
    lw=0,
)

plt.axis("equal")
plt.xlabel("delta_phase")
plt.ylabel("delta_head")
plt.tight_layout()

In [None]:
events_df