# Explore network activity profile

We will now explore how consistent over time the wave of activation that spreads through our ring of neurons appears to be.

In [None]:
%matplotlib widget
from pathlib import Path

import flammkuchen as fl
import lotr.plotting as pltltr
import numpy as np
import pandas as pd
import seaborn as sns
from bouter.utilities import crop
from lotr import A_FISH, FIGURES_LOCATION, LotrExperiment, dataset_folders
from lotr.pca import fit_phase_neurons, pca_and_phase
from lotr.utils import interpolate, roll_columns_jit, zscore
from matplotlib import pyplot as plt
from tqdm import tqdm

COLS = pltltr.COLS

# Folder where every figure produced by this notebook is saved.
# `parents=True` also creates FIGURES_LOCATION itself when it is missing,
# instead of failing with FileNotFoundError on a fresh checkout.
fig_location = FIGURES_LOCATION / "3 - activity_prof"
fig_location.mkdir(parents=True, exist_ok=True)

## Profile over time

First, we will shift the columns of the activation matrix over time so that the peak of the network activation always sits at roughly the same position.

In [None]:
exp = LotrExperiment(A_FISH)

# Order the heading-direction neurons by their angle in rPC space and pull
# out their traces (time along rows, sorted neurons along columns).
sort_idxs = np.argsort(exp.rpc_angles)
sorted_rois = exp.hdn_indexes[sort_idxs]
sorted_traces = exp.traces[:, sorted_rois]

# Convert the network phase into a shift in ROI units: phase / (2 pi) maps
# (assuming the phase lies in [-pi, pi]) onto [-0.5, 0.5], which is then
# scaled by (n_hdns - 1). Rolling by minus that amount places phase 0 of the
# network on the position of angle 0.
phase_shifts = exp.network_phase / (2 * np.pi) * (exp.n_hdns - 1)
shifted_traces = roll_columns_jit(sorted_traces, -np.round(phase_shifts))

In [None]:
flims = 1.7

f, axs = plt.subplots(2, 1, figsize=(4, 4))

# Top panel: angle-sorted traces with the (offset) network phase overlaid.
# Bottom panel: phase-centered traces with a flat guide line at the center.
# NOTE(review): with extent bottom=0 / top=n_hdns and (if unchanged) the
# default origin="upper", sorted ROI 0 is drawn at the top of the panel while
# the guide line maps phase -pi close to y=0 -- confirm the top overlay is not
# vertically mirrored with respect to the bump.
panels = [
    (sorted_traces, phase_shifts + exp.n_hdns / 2),
    (shifted_traces, np.ones(exp.n_pts) * exp.n_hdns / 2),
]
for panel_ax, (panel_traces, guide_line) in zip(axs, panels):
    c = panel_ax.imshow(
        panel_traces.T,
        aspect="auto",
        extent=(0, exp.n_pts / exp.fn, 0, exp.n_hdns),
        vmin=-flims,
        vmax=flims,
        cmap=COLS["dff_plot"],
    )
    panel_ax.plot(exp.time_arr, guide_line, lw=0.5, c=COLS["ph_plot"])

# `c` is the last image drawn; both panels share vmin/vmax, so it is a valid
# source for the colorbar.
pltltr.despine(axs[0], "all")
pltltr.despine(axs[1], ["left", "right", "top"])
pltltr.add_dff_cbar(c, axs[0], (1.05, 0.78, 0.02, 0.22))

axs[0].set(ylabel="sorted ROI n.")
axs[1].set(xlabel="time (s)", ylabel="dist. from peak")
f.savefig(fig_location / "phase_centered_traces.pdf")

### Improve shift by interpolation

This looks nice, but we could potentially improve it! Above, we shifted the activation bump assuming that neurons were equally spaced across all possible angles in rPC space. However, this might not be the case; in fact, we know that the distribution of neurons along the circle can be non-homogeneous.
Let's see if we can improve things by interpolating the traces so that they are sampled homogeneously along the circle before shifting:

In [None]:
# Arbitrary number of bins for the resampling:
N_BINS_RESAMPLED = 100

# Uniform grid of angles, spanning -pi..pi with both endpoints included,
# on which every frame is resampled.
resampling_base = np.linspace(-np.pi, np.pi, N_BINS_RESAMPLED)

# Neuron angles (x) and their traces (y), both in angle-sorted order.
sorted_angles = exp.rpc_angles[sort_idxs]
sorted_hdn_traces = exp.traces[:, exp.hdn_indexes[sort_idxs]]

# Linearly interpolate each frame's activity onto the uniform grid
# (np.interp holds the edge values outside the neurons' angle range).
angle_resampled_traces = np.empty((exp.n_pts, N_BINS_RESAMPLED))
for frame_i in range(exp.n_pts):
    angle_resampled_traces[frame_i] = np.interp(
        resampling_base, sorted_angles, sorted_hdn_traces[frame_i]
    )

# Same centering as before, now expressed in resampled-bin units:
phase_shifts_resamp = exp.network_phase / (2 * np.pi) * (N_BINS_RESAMPLED - 1)
shifted_traces_resamp = roll_columns_jit(
    angle_resampled_traces, -np.round(phase_shifts_resamp)
)
)

In [None]:
f, ax = plt.subplots(1, 1, figsize=(3.5, 2))

# Mean +/- std over time of the peak-centered profile, computed without and
# with resampling on a uniform angle grid.
for lab, x, traces in zip(
    ["not resamp.", "resamp."],
    [exp.rpc_angles[sort_idxs], resampling_base],
    [shifted_traces, shifted_traces_resamp],
):
    mn, sd = traces.mean(0), traces.std(0)
    ax.plot(x, mn, lw=2, label=lab)
    ax.fill_between(x, mn - sd, mn + sd, lw=0, alpha=0.5, label="_nolegend_")

ax.set(
    xlabel="phase dist. from peak",
    ylabel="avg. ΔF",
    ylim=(-2, 2),
    **pltltr.get_pi_labels(d=0.5)
)
ax.legend(bbox_to_anchor=(1.3, 1, 0.01, 0.0))

# Light reference lines at -pi/2, 0, pi/2 and at zero activity, drawn
# behind the data (low zorder).
for frac in [-1 / 2, 0, 1 / 2]:
    ax.axvline(frac * np.pi, lw=0.5, c=".7", zorder=-100)
ax.axhline(0, lw=0.5, c=".7", zorder=-100)
pltltr.despine(ax)

f.tight_layout()
f.savefig(fig_location / "pre_post_interp.pdf")

This looks definitely better!

## Loop over all fish 
Now, let's compare phase activations across fish from the entire dataset. The steps above have been wrapped in the `lotr.analysis.activity_profile.resample_and_shift()` function; we first check that it reproduces the results computed step by step:

In [None]:
from lotr.analysis.activity_profile import resample_and_shift

# Sanity check: the packaged helper (called with its default number of bins)
# must reproduce the step-by-step arrays computed above for the same fish.
resamp, reshaped = resample_and_shift(exp)
for manual, packaged in (
    (angle_resampled_traces, resamp),
    (shifted_traces_resamp, reshaped),
):
    assert np.allclose(manual, packaged, rtol=0.001)

In [None]:
N_BINS_RESAMP = 100

# For every fish: resample on a uniform angle grid, center on the network
# phase, and summarize the centered profile by its mean and std over time.
profile_stats = []
for fish_path in tqdm(dataset_folders):
    _, centered_traces = resample_and_shift(
        LotrExperiment(fish_path), n_bins_resampling=N_BINS_RESAMP
    )
    profile_stats.append((centered_traces.mean(0), centered_traces.std(0)))

# One row per fish.
all_mn_profiles = np.array([fish_mn for fish_mn, _ in profile_stats])
all_std_profiles = np.array([fish_sd for _, fish_sd in profile_stats])

In [None]:
# Converting again is harmless: np.array of an existing array just copies it.
all_mn_profiles, all_std_profiles = (
    np.array(all_mn_profiles),
    np.array(all_std_profiles),
)

In [None]:
fig, axs = plt.subplots(
    2,
    1,
    gridspec_kw=dict(left=0.15, bottom=0.15, right=0.8),
    figsize=(3, 3),
    sharex=True,
)

# Phase axis derived from the profiles themselves, so it always matches the
# number of resampling bins used when computing them.
profile_x = np.linspace(-np.pi, np.pi, all_mn_profiles.shape[1])

# Bottom: faint mean +/- std band for every fish, plus the across-fish mean.
for mn, std in zip(all_mn_profiles, all_std_profiles):
    axs[1].fill_between(
        profile_x,
        mn - std,
        mn + std,
        lw=0,
        fc=".0",
        alpha=0.05,
        label="_nolegend_",
    )
# Empty proxy artist so the legend shows a more visible (alpha=0.2) patch:
axs[1].fill_between([], [], [], lw=0, fc=".0", alpha=0.2, label="fish (mn+/-sd)")
axs[1].plot(profile_x, all_mn_profiles.mean(0), label="mean")
axs[1].legend()
axs[1].set(
    xlabel="Norm. network position",
    ylabel="mean dF",
    ylim=(-2.2, 2.2),
    **pltltr.get_pi_labels(d=0.5)
)

# Top: one row per fish with its mean profile. Same symmetric color limits
# as the other ΔF heatmaps in this notebook.
c = axs[0].imshow(
    all_mn_profiles,
    aspect="auto",
    extent=[-np.pi, np.pi, all_mn_profiles.shape[0], 0],
    vmin=-1.7,
    vmax=1.7,
    cmap=COLS["dff_plot"],
)
# The colorbar must describe the image drawn on this figure.
pltltr.add_dff_cbar(c, axs[0], (1.1, 0.8, 0.03, 0.25))
axs[0].set(ylabel="Fish n.")
for ax in axs:
    pltltr.despine(ax)

fig.savefig(fig_location / "all_fish_profiles.pdf")

## Activation profile and behavior
The next interesting question is: is the bump sustained even when the fish is not swimming? How long can it persist after the last bout?

We will proceed as follows: we create an array that specifies, for each frame, how much time has elapsed since the last bout, and we then average the activity of all frames that share the same elapsed time.

In [None]:
from lotr.behavior import get_bouts_props_array

exp = LotrExperiment(A_FISH)

# Per-frame bout indicator; presumably 1 on bout frames (value=1) and 0
# elsewhere -- only the `== 1` test below is relied upon.
bouts_arr = get_bouts_props_array(
    exp.n_pts, exp.bouts_df, min_bias=0, selection="all", value=1
)

# Number of frames elapsed since the most recent bout (0 on bout frames).
# Frames before the first bout have no "last bout": they are marked NaN so
# they never match an elapsed count in the averages below, instead of being
# counted as if a bout had occurred at frame 0.
frame_idxs = np.arange(len(bouts_arr))
# Index of the latest bout at or before each frame (-1 if none yet):
last_bout_idx = np.maximum.accumulate(
    np.where(np.asarray(bouts_arr) == 1, frame_idxs, -1)
)
frames_elapsed = np.where(last_bout_idx >= 0, frame_idxs - last_bout_idx, np.nan)

In [None]:
f, ax = plt.subplots(figsize=(6, 2))

# Time since the last bout (s), with a thin vertical line at every bout
# (frame 0 is not marked).
ax.plot(exp.time_arr, frames_elapsed / exp.fn)
bout_onsets = [t for t in range(1, len(bouts_arr)) if bouts_arr[t] == 1]
for t in bout_onsets:
    ax.axvline(t / exp.fn, lw=0.2, c=".4")
ax.set(xlabel="time (s)", ylabel="time since last bout (s)")

pltltr.despine(ax)
plt.tight_layout()
f.savefig(fig_location / "time_elapsed_expl.pdf")

In [None]:
# For each possible elapsed time (in frames), average all activation profiles
# at that distance from the last bout, and count how many frames contribute.
# As a control, the same averaging is applied to the non-shifted resampled
# array: it gives a noise level and checks that the network is not simply
# converging to one stable activation profile.

max_el_count = 600  # maximum elapsed count, in frames

n_bins = shifted_traces_resamp.shape[1]
mean_activations = np.zeros((max_el_count, n_bins))
cnt_activations = np.zeros((max_el_count, angle_resampled_traces.shape[1]))
elapsed_hist = np.zeros(max_el_count)

for n_elapsed in range(max_el_count):
    is_sel = frames_elapsed == n_elapsed
    # Rows with no matching frames become NaN (mean of an empty selection).
    mean_activations[n_elapsed] = shifted_traces_resamp[is_sel].mean(axis=0)
    cnt_activations[n_elapsed] = angle_resampled_traces[is_sel].mean(axis=0)
    elapsed_hist[n_elapsed] = is_sel.sum()

In [None]:
# Control: average *non-centered* resampled activity (cnt_activations) per
# number of frames elapsed since the last bout. Rows = elapsed frames (top to
# bottom), columns = resampled angle bins from -pi to pi; the array is drawn
# untransposed so that it matches this extent.
f, ax = plt.subplots(figsize=(3, 4))
ax.imshow(
    cnt_activations,
    aspect="auto",
    extent=[-np.pi, np.pi, cnt_activations.shape[0], 0],
    vmin=-1.7,
    vmax=1.7,
    cmap=COLS["dff_plot"],
)
ax.set(
    xlabel="rPC angle",
    ylabel="frames since last bout",
    **pltltr.get_pi_labels(d=0.5)
)

In [None]:
max_el_count = 900  # maximum elapsed count, in frames

# Per fish, (max_el_count, n_bins) arrays indexed by frames elapsed since the
# last bout: `all_counts` holds the phase-centered profile averages and
# `all_cnt` the non-centered control. Rows without matching frames stay NaN.
all_cnt = []
all_counts = []

for path in tqdm(dataset_folders):
    # Load this iteration's fish (not A_FISH).
    exp = LotrExperiment(path)

    # Resample on a uniform angle grid and roll to center the bump; this
    # helper was checked above against the step-by-step computation.
    angle_resampled_traces, shifted_resamp_traces = resample_and_shift(
        exp, n_bins_resampling=N_BINS_RESAMPLED
    )

    bouts_arr = get_bouts_props_array(
        exp.n_pts, exp.bouts_df, min_bias=0, selection="all", value=1
    )

    # Frames since the most recent bout (0 on bout frames); NaN before the
    # first bout, where "time since last bout" is undefined.
    frame_idxs = np.arange(len(bouts_arr))
    last_bout_idx = np.maximum.accumulate(
        np.where(np.asarray(bouts_arr) == 1, frame_idxs, -1)
    )
    frames_elapsed = np.where(
        last_bout_idx >= 0, frame_idxs - last_bout_idx, np.nan
    )

    n_bins = shifted_resamp_traces.shape[1]
    mean_activations = np.full((max_el_count, n_bins), np.nan)
    cnt_activations = np.full((max_el_count, n_bins), np.nan)

    for f_count in range(max_el_count):
        frames_sel = frames_elapsed == f_count
        if not frames_sel.any():
            continue  # keep NaN; avoids "Mean of empty slice" warnings
        mean_activations[f_count, :] = shifted_resamp_traces[frames_sel, :].mean(0)
        cnt_activations[f_count, :] = angle_resampled_traces[frames_sel, :].mean(0)

    all_counts.append(mean_activations)
    all_cnt.append(cnt_activations)

In [None]:
%%time
for f_count in range(max_el_count):
    frames_sel = frames_elapsed == f_count
    mean_activations[f_count, :] = shifted_resamp_traces[frames_sel, :].mean(0)
    cnt_activations[f_count, :] = angle_resampled_traces[frames_sel, :].mean(0)

In [None]:
# Stack the per-fish results into (n_fish, max_el_count, n_bins) arrays.
all_counts, all_cnt = np.array(all_counts), np.array(all_cnt)

In [None]:
# Frame rate of every fish (presumably to check that elapsed counts in
# frames are comparable across fish).
[fish.fn for fish in map(LotrExperiment, dataset_folders)]

In [None]:
# Across-fish mean of the phase-centered profile as a function of frames
# elapsed since the last bout: x = elapsed frames, y = network position
# (resampled bins spanning -pi..pi, drawn from the bottom up).
mean_map = np.nanmean(all_counts, 0)

plt.figure()
plt.imshow(
    mean_map.T,
    aspect="auto",
    origin="lower",
    extent=[0, mean_map.shape[0], -np.pi, np.pi],
    cmap=COLS["dff_plot"],
)
plt.xlabel("frames since last bout")
plt.ylabel("Norm. network position")

In [None]:
mn = np.nanmean(all_counts, 0)
# Angle axis of the resampled bins (-pi..pi), so the curves are plotted
# against position rather than bin index.
profile_x = np.linspace(-np.pi, np.pi, mn.shape[1])

plt.figure()
# One line per elapsed-frame count: phase-centered profiles (red) versus
# the non-centered control (black).
plt.plot(profile_x, mn.T, "r")
plt.plot(profile_x, np.nanmean(all_cnt, 0).T, "k")
plt.xlabel("position (rad)")
plt.ylabel("mean dF")

plt.show()

In [None]:
# Network phase time series of every fish in the dataset.
all_phases = [fish.network_phase for fish in map(LotrExperiment, dataset_folders)]

In [None]:
# rPC-space angles of the neurons of every fish in the dataset.
all_angles = [fish.rpc_angles for fish in map(LotrExperiment, dataset_folders)]

In [None]:
# Distribution of neuron angles pooled across all fish.
pooled_angles = np.concatenate(all_angles)

plt.figure()
plt.hist(pooled_angles, bins=50)
plt.show()