# Explore network activity profile

We now explore how consistent over time the wave of activation spreading through our ring of neurons appears to be.

In [None]:
%matplotlib widget
from pathlib import Path

import flammkuchen as fl
import lotr.plotting as ltrplt
import numpy as np
import pandas as pd
import seaborn as sns
from bouter.utilities import crop
from lotr import A_FISH, DATASET_LOCATION, LotrExperiment
from lotr.pca import fit_phase_neurons, pca_and_phase
from lotr.utils import interpolate, zscore
from matplotlib import pyplot as plt
from tqdm import tqdm

# Shared colour definitions from the lotr plotting module (e.g. COLS["dff_plot"] is used as a colormap below):
COLS = ltrplt.COLS

In [None]:
exp = LotrExperiment(A_FISH)

# Preferred phase of each selected neuron, as provided by the experiment object:
rpc_angles = exp.rpc_angles

# Find the ordering of neurons by preferred phase...
sort = np.argsort(rpc_angles)

# ...and apply it to sort the traces (rows are timepoints, columns are neurons):
sorted_traces = exp.traces[:, exp.hdn_indexes[sort]]

# Convert the network phase into a shift expressed in number of neurons.
# The phase is circular, so one full 2*pi turn must map to exactly n_hdns
# positions (not n_hdns - 1): otherwise -pi and +pi, which are the same angle,
# would produce shifts differing by one neuron.
# NOTE(review): the "+ 0.5" offset assumes network_phase lies in [-pi, pi) — confirm.
phase_bins = np.round(
    ((exp.network_phase / (2 * np.pi)) + 0.5) * exp.n_hdns
).astype(int)

# Circularly shift every timepoint row of the sorted traces by an amount
# specified by the phase. This is the vectorized equivalent of
# `np.roll(sorted_traces[i], -phase_bins[i])` for every row i: element j of the
# shifted row is element (j + phase_bins[i]) % n_hdns of the sorted row.
# NOTE(review): the per-fish loop further down rolls by +phase_bins instead;
# confirm which sign matches the convention of exp.network_phase.
roll_idx = (
    np.arange(exp.n_hdns)[np.newaxis, :] + phase_bins[:, np.newaxis]
) % exp.n_hdns
shifted = np.take_along_axis(sorted_traces, roll_idx, axis=1)


PHASE_BINS_S = 20
# Average in non-overlapping temporal chunks of PHASE_BINS_S seconds; the
# incomplete chunk at the end of the recording is dropped.
bin_wnd = int(PHASE_BINS_S * exp.fn)  # binning window, in timepoints units
n_time_bins = shifted.shape[0] // bin_wnd
shifted_all_binned = (
    shifted[: n_time_bins * bin_wnd]
    .reshape(n_time_bins, bin_wnd, shifted.shape[1])
    .mean(axis=1)
)

In [None]:
# Colour limits shared by both panels so that they are directly comparable:
f_lims = {"vmin": -1.7, "vmax": 1.7}
f, axs = plt.subplots(2, 1, figsize=(6, 4))

# Top panel: traces sorted by preferred phase; bottom panel: phase-shifted traces.
# Matrices are transposed so that neurons run along y and time along x.
for ax, trace_matrix in zip(axs, (sorted_traces, shifted)):
    ax.imshow(trace_matrix.T, cmap=COLS["dff_plot"], aspect="auto", **f_lims)
    ltrplt.despine(ax)

In [None]:
# Sign-flipped, unwrapped network phase over timepoints for the example fish:
phase_fig, phase_ax = plt.subplots()
phase_ax.plot(-np.unwrap(exp.network_phase))

In [None]:
master_path = Path(DATASET_LOCATION)

# Set ALL_FISH to True to analyse every fish found under DATASET_LOCATION
# (folders matching "*/*[0-9]_f*" that contain a "selected.h5" file).
# By default only the example fish A_FISH is analysed.
ALL_FISH = False
file_list = (
    sorted(f.parent for f in master_path.glob("*/*[0-9]_f*/selected.h5"))
    if ALL_FISH
    else [A_FISH]
)

In [None]:
PHASE_BINS_S = 10

# Per fish: time-binned activity of the selected neurons after shifting every
# timepoint by the network phase, shape (n_time_bins, n_selected_neurons).
all_profiles = []
# Per fish: preferred phase of each selected neuron.
all_phases = []
for path in tqdm(file_list):
    exp = LotrExperiment(path)
    n_pts = exp.n_pts
    n_sel = len(exp.hdn_indexes)
    hdn_traces = exp.traces[:, exp.hdn_indexes]

    # Compute network phase with the usual fit of the ring over PCA space
    # (only the phase output is used here):
    _, phase, _, _ = pca_and_phase(
        exp.traces[exp.pca_t_slice, exp.hdn_indexes], hdn_traces
    )

    # Compute preferred phase of each neuron:
    neuron_phases, _ = fit_phase_neurons(hdn_traces, phase, disable_bar=True)

    # Find the ordering of neurons by preferred phase...
    sort = np.argsort(neuron_phases)
    # ...and apply it to sort the traces (rows are timepoints):
    sorted_traces = exp.traces[:, exp.hdn_indexes[sort]]

    # From the phase array, compute the shift in number-of-neurons units.
    # A full 2*pi turn maps to exactly n_sel positions so that -pi and +pi
    # (the same angle) give the same circular shift.
    # NOTE(review): the "+ 0.5" offset assumes phase lies in [-pi, pi) — confirm.
    phase_bins = np.round(((phase / (2 * np.pi)) + 0.5) * n_sel).astype(int)

    # Circularly shift every timepoint row by an amount specified by the phase.
    # Vectorized equivalent of `np.roll(sorted_traces[i], phase_bins[i])` per row:
    # element j of the shifted row is element (j - phase_bins[i]) % n_sel.
    # NOTE(review): the single-fish cell above rolls by -phase_bins; confirm
    # which sign is intended.
    roll_idx = (
        np.arange(n_sel)[np.newaxis, :] - phase_bins[:, np.newaxis]
    ) % n_sel
    shifted = np.take_along_axis(sorted_traces[:n_pts], roll_idx[:n_pts], axis=1)

    # Average in non-overlapping temporal chunks of PHASE_BINS_S seconds; the
    # incomplete chunk at the end of the recording is dropped.
    bin_wnd = int(PHASE_BINS_S * exp.fn)  # binning window, in timepoints units
    n_time_bins = n_pts // bin_wnd
    shifted_all_binned = (
        shifted[: n_time_bins * bin_wnd]
        .reshape(n_time_bins, bin_wnd, n_sel)
        .mean(axis=1)
    )

    all_profiles.append(shifted_all_binned)
    all_phases.append(neuron_phases)

In [None]:
# Average phase-aligned profile of each fish (mean over time bins):
all_mean_profiles = [np.mean(p, 0) for p in all_profiles]

all_interpd = []
# Common normalized network-position grid so fish with different numbers of
# neurons can be compared and stacked:
new_x = np.arange(0, 1, 0.01)

fig, axs = plt.subplots(2, 1, constrained_layout=True, figsize=(4, 5), sharex=True)

for mean_p, p in zip(all_mean_profiles, all_profiles):
    std_p = np.std(p, 0)
    n_rois = len(mean_p)

    # Circularly rotate mean and std so the peak of the mean profile lands at
    # index n_rois // 4.
    # NOTE(review): this places the peak near x = 0.25, while the reference
    # line drawn below is at x = 0.5 — confirm the intended alignment.
    roll_pts = n_rois // 4 - np.argmax(mean_p)
    circ_perm_mn = np.roll(mean_p, roll_pts)
    circ_perm_std = np.roll(std_p, roll_pts)

    # Resample onto the common grid:
    x = np.arange(n_rois) / n_rois
    interp_mn = interpolate(x, circ_perm_mn, new_x)
    interp_std = interpolate(x, circ_perm_std, new_x)

    # Shade mean +/- std across time bins, then draw the mean:
    axs[1].fill_between(
        new_x, interp_mn - interp_std, interp_mn + interp_std, lw=0, alpha=0.1
    )
    axs[1].plot(new_x, interp_mn)

    all_interpd.append(interp_mn)

axs[1].axvline(0.5, lw=1.0, c=(0.2,) * 3)
axs[1].set(xlabel="Norm. network position", ylabel="mean dF")

# One row per fish, aligned profiles as an image:
axs[0].imshow(all_interpd, aspect="auto", extent=[0, 1, len(all_interpd), 0])
axs[0].set(ylabel="Fish n.")
# Layout is handled by constrained_layout=True; an extra plt.tight_layout()
# call would conflict with it, so none is made here.