# Explore network activity profile

We will now explore how consistent over time the wave of activation that spreads through our ring of neurons appears to be.

In [None]:
%matplotlib widget
from pathlib import Path

import flammkuchen as fl
import lotr.plotting as ltrplt
import numpy as np
import pandas as pd
import seaborn as sns
from bouter.utilities import crop
from lotr import A_FISH, FIGURES_LOCATION, LotrExperiment, dataset_folders
from lotr.pca import fit_phase_neurons, pca_and_phase
from lotr.utils import interpolate, roll_columns_jit, zscore
from matplotlib import pyplot as plt
from tqdm import tqdm

COLS = ltrplt.COLS

# Output folder for the figures of this notebook. `parents=True` also creates
# FIGURES_LOCATION itself when it is missing, instead of raising
# FileNotFoundError; `exist_ok=True` keeps re-runs of the cell harmless.
fig_location = FIGURES_LOCATION / "3 - activity_prof"
fig_location.mkdir(parents=True, exist_ok=True)

## Profile over time

First, we will shift the columns of the activation matrix over time so that the peak of the network activation always sits in roughly the same position.

In [None]:
exp = LotrExperiment(A_FISH)

# Order the head-direction neurons by their angle in rPC space and pull out
# their traces (time x neurons) in that order.
sort_idxs = np.argsort(exp.rpc_angles)
sorted_traces = exp.traces[:, exp.hdn_indexes[sort_idxs]]

# Per-timepoint bump position expressed in ROI units: phase / 2π is the
# fraction of a full turn, which is then scaled by the n_hdns - 1 index span.
phase_shifts = (exp.network_phase / (2 * np.pi)) * (exp.n_hdns - 1)

# Roll each timepoint by the opposite of its rounded shift, so that the bump
# ends up at a fixed position across time.
shifted_traces = roll_columns_jit(sorted_traces, -np.round(phase_shifts))

In [None]:
flims = 1.7

f, axs = plt.subplots(2, 1, figsize=(4, 4))

# Top panel: traces sorted by rPC angle; bottom panel: the same traces after
# the per-timepoint shift. Both use the same symmetric colour range.
for ax, traces in zip(axs, [sorted_traces, shifted_traces]):
    c = ax.imshow(
        traces.T,
        aspect="auto",
        extent=(0, exp.n_pts / exp.fn, 0, exp.n_hdns),
        vmin=-flims,
        vmax=flims,
        cmap=COLS["dff_plot"],
    )

# Overlay the bump position: the phase-derived position (offset to mid-height)
# on the sorted traces, and a constant mid-height line on the shifted ones.
for ax, line in zip(
    axs, [phase_shifts + exp.n_hdns / 2, np.ones(exp.n_pts) * exp.n_hdns / 2]
):
    ax.plot(exp.time_arr, line, lw=0.5, c=COLS["ph_plot"])

ltrplt.despine(axs[0], "all")
ltrplt.despine(axs[1], ["left", "right", "top"])

# Colorbar attached to the top panel, labelled only with "-" / "+" ticks.
cbar = ltrplt.add_cbar(
    c,
    axs[0],
    (1.05, 0.05, 0.02, 0.22),
    ticks=[-1, 1],
    ticklabels=["$-$", "$+$"],
    title="ΔF",
    titlesize=8,
    labelsize=6,
)
axs[0].set(ylabel="sorted ROI n.")
axs[1].set(xlabel="time (s)", ylabel="dist. from peak")

f.savefig(fig_location / "phase_centered_traces.pdf")

### Improve shift by interpolation

This looks nice, but we could potentially improve it! Above, we shifted the activation bump assuming that neurons were equally spaced across all possible angles in rPC space. However, this might not be the case: we know that neurons can be non-homogeneously distributed along the circle.
Let's see if we can improve things by interpolating the traces to sample homogeneously along the circle before shifting:

In [None]:
# Arbitrary number of bins for the resampling:
N_BINS_RESAMPLED = 100

# Resampling angles: N_BINS_RESAMPLED evenly spaced points over [-pi, pi]
# (np.linspace includes both endpoints).
resampling_base = np.linspace(-np.pi, np.pi, N_BINS_RESAMPLED)

# Loop invariants: neuron angles and trace columns in matching (sorted) order.
sorted_angles = exp.rpc_angles[sort_idxs]
sorted_hdn_idxs = exp.hdn_indexes[sort_idxs]

angle_resampled_traces = np.zeros((exp.n_pts, N_BINS_RESAMPLED))
for i in range(exp.n_pts):
    # The neurons lie on a ring, so interpolate periodically. Without
    # `period`, np.interp clamps every bin between the outermost neuron angle
    # and ±pi to the value of that edge neuron. With it, those bins are
    # interpolated across the wrap-around instead.
    angle_resampled_traces[i, :] = np.interp(
        resampling_base,
        sorted_angles,
        exp.traces[i, sorted_hdn_idxs],
        period=2 * np.pi,
    )

# Shift, in resampled-bin units, that re-centres the bump at each timepoint:
# phase / 2π is the fraction of a full turn, scaled by the
# N_BINS_RESAMPLED - 1 span of the resampling grid.
phase_shifts_res = (exp.network_phase / (2 * np.pi)) * (N_BINS_RESAMPLED - 1)

# Then, apply shifts to traces:
shifted_resamp_traces = roll_columns_jit(
    angle_resampled_traces, -np.round(phase_shifts_res)
)

In [None]:
f_lims = dict(vmin=-1.7, vmax=1.7)
f, ax = plt.subplots(1, 1, figsize=(3.5, 2))

# Time-averaged, bump-centred profile (with a ±1 s.d. band) before and after
# resampling the traces uniformly in angle.
for label, angles, trace_mat in zip(
    ["not resamp.", "resamp."],
    [exp.rpc_angles[sort_idxs], resampling_base],
    [shifted_traces, shifted_resamp_traces],
):
    mn, sd = trace_mat.mean(0), trace_mat.std(0)
    ax.plot(angles, mn, lw=2, label=label)
    ax.fill_between(
        angles, mn - sd, mn + sd, lw=0, alpha=0.5, label="_nolegend_"
    )

ax.set(xlabel="phase dist. from peak", ylabel="avg. ΔF", ylim=(-2, 2))
ax.legend(bbox_to_anchor=(1.3, 1, 0.01, 0.0))

# Light reference lines at -π/2, 0, π/2 and at zero ΔF, drawn behind the data.
for xpos in [-np.pi / 2, 0, np.pi / 2]:
    ax.axvline(xpos, lw=0.5, c=".7", zorder=-100)
ax.axhline(0, lw=0.5, c=".7", zorder=-100)
ltrplt.despine(ax)

f.tight_layout()
f.savefig(fig_location / "pre_post_interp.pdf")

## Loop over all fish 
Now, let's compare phase activations across fish from the entire dataset!

In [None]:
N_BINS_RESAMPLED = 100
# Resampling angles: N_BINS_RESAMPLED evenly spaced points over [-pi, pi]
# (np.linspace includes both endpoints).
resampling_base = np.linspace(-np.pi, np.pi, N_BINS_RESAMPLED)

# One column per fish: time-averaged mean and s.d. of the bump-centred profile.
all_mn_profiles = np.zeros((N_BINS_RESAMPLED, len(dataset_folders)))
all_std_profiles = np.zeros((N_BINS_RESAMPLED, len(dataset_folders)))
# Wrap the sequence itself (not the enumerate object) so tqdm knows the total.
for fi, path in enumerate(tqdm(dataset_folders)):
    exp = LotrExperiment(path)
    sort_idxs = np.argsort(exp.rpc_angles)

    # Loop invariants: neuron angles and trace columns in matching order.
    sorted_angles = exp.rpc_angles[sort_idxs]
    sorted_hdn_idxs = exp.hdn_indexes[sort_idxs]

    angle_resampled_traces = np.zeros((exp.n_pts, N_BINS_RESAMPLED))
    for i in range(exp.n_pts):
        # Periodic interpolation: neurons lie on a ring, so bins beyond the
        # outermost neuron angles wrap around rather than being clamped.
        angle_resampled_traces[i, :] = np.interp(
            resampling_base,
            sorted_angles,
            exp.traces[i, sorted_hdn_idxs],
            period=2 * np.pi,
        )

    # Shift, in resampled-bin units, that re-centres the bump at each
    # timepoint: fraction of a full turn scaled to the N_BINS_RESAMPLED - 1 span.
    phase_shifts_res = (exp.network_phase / (2 * np.pi)) * (N_BINS_RESAMPLED - 1)

    # Then, apply shifts to traces:
    shifted_resamp_traces = roll_columns_jit(
        angle_resampled_traces, -np.round(phase_shifts_res)
    )

    all_mn_profiles[:, fi] = shifted_resamp_traces.mean(0)
    all_std_profiles[:, fi] = shifted_resamp_traces.std(0)

# Fish x bins matrix of mean profiles (one row per fish). The original
# `np.array(all_profiles)` referenced a name that was never defined here.
all_profiles = all_mn_profiles.T

In [None]:
fig, axs = plt.subplots(2, 1, constrained_layout=True, figsize=(3, 3), sharex=True)

# Bottom panel: one faint ±1 s.d. band per fish, plus the grand mean.
for mn, std in zip(all_mn_profiles.T, all_std_profiles.T):
    axs[1].fill_between(
        resampling_base,
        mn - std,
        mn + std,
        lw=0,
        fc=".0",
        alpha=0.02,
        label="_nolegend_",
    )
# Empty proxy artist, so the legend shows one clearly visible entry for the bands.
axs[1].fill_between([], [], [], lw=0, fc=".0", alpha=0.2, label="fish (mn+/-sd)")
axs[1].plot(resampling_base, all_mn_profiles.mean(1), label="mean")
axs[1].legend()
# The x axis is the resampling grid in radians, so label it as a phase distance.
axs[1].set(xlabel="phase dist. from peak", ylabel="mean dF")

# Top panel: the same mean profiles as a heatmap, one row per fish.
fish_profiles = all_mn_profiles.T
axs[0].imshow(
    fish_profiles,
    aspect="auto",
    extent=[-np.pi, np.pi, fish_profiles.shape[0], 0],
    cmap=COLS["dff_plot"],
)
axs[0].set(ylabel="Fish n.")

# Reference lines on the profile panel, drawn behind the data.
for xpos in [-np.pi / 2, 0, np.pi / 2]:
    axs[1].axvline(xpos, lw=0.5, c=".7", zorder=-100)
axs[1].axhline(0, lw=0.5, c=".7", zorder=-100)
for ax in axs:
    ltrplt.despine(ax)

# constrained_layout is already active; calling tight_layout() would override it.
# Save *this* figure (`fig`), not the one from an earlier cell (`f`).
fig.savefig(fig_location / "all_fish_profiles.pdf")

In [None]:
# NOTE(review): this cell duplicates the one at the end of the notebook. It
# expects `all_profiles` to be a list of per-fish 2D arrays (time bins x
# neurons), as built by the phase-binning cell further down, so it has to be
# run after that cell.
all_mean_profiles = [np.mean(p, 0) for p in all_profiles]

all_interpd = []
# Common normalized network-position axis onto which every fish is resampled:
new_x = np.arange(0, 1, 0.01)

fig, axs = plt.subplots(2, 1, constrained_layout=True, figsize=(4, 5), sharex=True)

for mean_p, p in zip(all_mean_profiles, all_profiles):
    std_p = np.std(p, 0)
    n_rois = len(mean_p)
    # Circularly roll both profiles so the peak of the mean sits at index
    # n_rois // 4 (x ≈ 0.25 on the normalized axis).
    # NOTE(review): the reference line below is drawn at x = 0.5, not at the
    # peak position; confirm this offset is intended.
    roll_pts = -np.argmax(mean_p) + n_rois // 4
    circ_perm_mn = np.roll(mean_p, roll_pts)
    circ_perm_std = np.roll(std_p, roll_pts)

    # Map neuron index onto [0, 1) and resample onto new_x.
    x = np.arange(n_rois) / n_rois
    interp_mn = interpolate(x, circ_perm_mn, new_x)
    interp_std = interpolate(x, circ_perm_std, new_x)

    axs[1].fill_between(
        new_x, interp_mn - interp_std, interp_mn + interp_std, lw=0, alpha=0.1
    )
    axs[1].plot(new_x, interp_mn)

    all_interpd.append(interp_mn)

axs[1].axvline(0.5, lw=1.0, c=(0.2,) * 3)
axs[1].set(xlabel="Norm. network position", ylabel="mean dF")
# Top panel: aligned, interpolated mean profiles, one row per fish.
axs[0].imshow(
    all_interpd, aspect="auto", extent=[0, 1, len(all_interpd), 0],
)
axs[0].set(ylabel="Fish n.")
# Layout is handled by constrained_layout; plt.tight_layout() would override it.

In [None]:
PHASE_BINS_S = 10

all_profiles = []
all_phases = []
# `file_list` was never defined; iterate over the dataset like the other cells.
for path in tqdm(dataset_folders):
    exp = LotrExperiment(path)
    n_pts = exp.n_pts
    n_sel = len(exp.hdn_indexes)

    # Compute network phase with the usual fit of the ring over PCA space:
    _, phase, _, _ = pca_and_phase(
        exp.traces[exp.pca_t_slice, exp.hdn_indexes], exp.traces[:, exp.hdn_indexes]
    )

    # Compute preferred phase of each neuron:
    neuron_phases, _ = fit_phase_neurons(
        exp.traces[:, exp.hdn_indexes], phase, disable_bar=True
    )

    # Find sorting over neurons according to phases...
    sort = np.argsort(neuron_phases)
    # ...and apply it to sort the traces (time x neurons):
    sorted_traces = exp.traces[:, exp.hdn_indexes[sort]]

    # From the phase array, compute the amount of shift in units of neurons:
    # phase / 2π + 0.5 is a fraction of a turn, scaled to the 0 .. n_sel - 1 range.
    phase_bins = np.round(((phase / (2 * np.pi)) + 0.5) * (n_sel - 1)).astype(int)

    # Shift every timepoint (row) by its own amount, all at once. This is
    # equivalent to np.roll(sorted_traces[i], phase_bins[i]) for each row i,
    # i.e. shifted[i, j] = sorted_traces[i, (j - phase_bins[i]) % n_sel].
    src_cols = (np.arange(n_sel)[np.newaxis, :] - phase_bins[:, np.newaxis]) % n_sel
    shifted = np.take_along_axis(sorted_traces, src_cols, axis=1).astype(
        float, copy=False
    )

    # Bin in consecutive temporal chunks of PHASE_BINS_S seconds. Trailing
    # timepoints that do not fill a whole chunk are dropped. The window is
    # kept >= 1 sample so that n_pts // bin_wnd cannot divide by zero.
    bin_wnd = max(1, int(PHASE_BINS_S * exp.fn))  # binning window, in timepoints
    n_bins = n_pts // bin_wnd
    shifted_all_binned = (
        shifted[: n_bins * bin_wnd].reshape(n_bins, bin_wnd, n_sel).mean(1)
    )

    all_profiles.append(shifted_all_binned)
    all_phases.append(neuron_phases)

In [None]:
# Time-averaged profile of each fish (p is time bins x neurons):
all_mean_profiles = [np.mean(p, 0) for p in all_profiles]

all_interpd = []
# Common normalized network-position axis onto which every fish is resampled:
new_x = np.arange(0, 1, 0.01)

fig, axs = plt.subplots(2, 1, constrained_layout=True, figsize=(4, 5), sharex=True)

for mean_p, p in zip(all_mean_profiles, all_profiles):
    std_p = np.std(p, 0)
    n_rois = len(mean_p)
    # Circularly roll both profiles so the peak of the mean sits at index
    # n_rois // 4 (x ≈ 0.25 on the normalized axis).
    # NOTE(review): the reference line below is drawn at x = 0.5, not at the
    # peak position; confirm this offset is intended.
    roll_pts = -np.argmax(mean_p) + n_rois // 4
    circ_perm_mn = np.roll(mean_p, roll_pts)
    circ_perm_std = np.roll(std_p, roll_pts)

    # Map neuron index onto [0, 1) and resample onto new_x.
    x = np.arange(n_rois) / n_rois
    interp_mn = interpolate(x, circ_perm_mn, new_x)
    interp_std = interpolate(x, circ_perm_std, new_x)

    axs[1].fill_between(
        new_x, interp_mn - interp_std, interp_mn + interp_std, lw=0, alpha=0.1
    )
    axs[1].plot(new_x, interp_mn)

    all_interpd.append(interp_mn)

axs[1].axvline(0.5, lw=1.0, c=(0.2,) * 3)
axs[1].set(xlabel="Norm. network position", ylabel="mean dF")
# Top panel: aligned, interpolated mean profiles, one row per fish.
axs[0].imshow(
    all_interpd, aspect="auto", extent=[0, 1, len(all_interpd), 0],
)
axs[0].set(ylabel="Fish n.")
# Layout is handled by constrained_layout; plt.tight_layout() would override it.