## Network activation profile

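Here we compute the average shape of the network activation profile. For each fish, the traces of the selected network neurons are sorted by their preferred phase and circularly shifted at every time point according to the instantaneous network phase. The aligned traces are then averaged over time, and the resulting profiles are compared across fish.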

#### TODO
 - [ ] Make explanatory figures
 - [ ] Solve the fitting issue for the excluded fish

In [None]:
%matplotlib widget
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import flammkuchen as fl

import seaborn as sns
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()

from lotr import LotrExperiment, DATASET_LOCATION
from bouter.utilities import crop

from lotr.plotting import add_cbar, color_plot, despine
from lotr.pca import pca_and_phase, fit_phase_neurons
from lotr.utils import zscore, interpolate

from tqdm import tqdm

In [None]:
# Root folder of the dataset.
master_path = Path(DATASET_LOCATION)

# One experiment per folder two levels below the root whose name matches
# "*[0-9]_f*" and that contains a "selected.h5" file. Keep the folders
# (not the .h5 files) in sorted order so fish indices are reproducible.
file_list = sorted(
    h5_file.parent
    for h5_file in master_path.glob("*/*[0-9]_f*/selected.h5")
)

In [None]:
PHASE_BINS_S = 10  # length of the temporal bins used to average the aligned profiles, in seconds


def _roll_rows(traces, shifts):
    """Circularly roll every row of a 2D array by its own amount.

    Vectorized equivalent of ``np.roll(traces[i], shifts[i])`` applied to every
    row ``i``.

    Parameters
    ----------
    traces : 2D array, (n_rows, n_cols)
    shifts : 1D integer array, (n_rows,)

    Returns
    -------
    2D array, (n_rows, n_cols), with each row rolled.
    """
    n_cols = traces.shape[1]
    # np.roll(row, s)[j] == row[(j - s) % n_cols]
    col_idx = (np.arange(n_cols)[np.newaxis, :] - shifts[:, np.newaxis]) % n_cols
    return np.take_along_axis(traces, col_idx, axis=1)


def _bin_rows(arr, bin_wnd):
    """Average a 2D array over consecutive, non-overlapping windows of ``bin_wnd`` rows.

    Trailing rows that do not fill a complete window are discarded.
    Returns an (n_rows // bin_wnd, n_cols) array.
    """
    n_bins = arr.shape[0] // bin_wnd
    return arr[:n_bins * bin_wnd].reshape(n_bins, bin_wnd, arr.shape[1]).mean(axis=1)


all_profiles = []  # per fish: (n_time_bins, n_sel) phase-aligned, time-binned traces
all_phases = []  # per fish: preferred phase of each selected neuron
for path in tqdm(file_list):
    exp = LotrExperiment(path)
    n_pts, n_cells = exp.n_pts, exp.n_rois
    n_sel = len(exp.hdn_indexes)

    # Compute network phase with the usual fit of the ring over PCA space:
    pcaed, phase, _, _ = pca_and_phase(exp.traces[exp.pca_t_slice, exp.hdn_indexes],
                                       exp.traces[:, exp.hdn_indexes])

    # Compute preferred phase of each neuron:
    neuron_phases, _ = fit_phase_neurons(exp.traces[:, exp.hdn_indexes], phase, disable_bar=True)

    # Find sorting over neurons according to phases...
    sort = np.argsort(neuron_phases)
    # ...and apply it to sort the traces:
    sorted_traces = exp.traces[:, exp.hdn_indexes[sort]]

    # From the phase array, compute the shift in units of neurons.
    # A full turn of the phase (2*pi) must correspond to a shift of n_sel
    # positions, i.e. the identity on the ring. Scaling by n_sel (not n_sel - 1)
    # makes phases -pi and +pi give the same shift.
    # (np.int was removed in NumPy 1.24; use the builtin int.)
    phase_bins = np.round((phase / (2 * np.pi) + 0.5) * n_sel).astype(int) % n_sel

    # Shift every timepoint row of the sorted traces by the amount specified by the phase:
    shifted = _roll_rows(sorted_traces, phase_bins).astype(float, copy=False)

    # Bin in small temporal chunks of PHASE_BINS_S seconds.
    # exp.fn is presumably the sampling rate in Hz -- TODO confirm.
    bin_wnd = int(PHASE_BINS_S * exp.fn)  # binning window, in timepoints units
    shifted_all_binned = _bin_rows(shifted, bin_wnd)

    all_profiles.append(shifted_all_binned)
    all_phases.append(neuron_phases)

In [None]:
# Print the experiment folders at these indices of file_list (the same indices
# listed as "bad" in the next cell). A plain loop avoids building and
# displaying a throwaway list of Nones as the cell output.
for i in [28, 26, 14, 2, 6, 19, 23, 25]:
    print(file_list[i])

In [None]:
# Indices into file_list of fish that were presumably excluded because of the
# fitting issue listed in the TODO; kept for reference, not applied here:
# bad = [28, 26, 14, 2, 6, 19, 23, 25] # [3, 17, 33, 36, 38, 19]

# Time-averaged aligned profile of each fish:
all_mean_profiles = [np.mean(p, 0) for p in all_profiles]

all_interpd = []  # per-fish mean profile, interpolated on new_x
new_x = np.arange(0, 1, 0.01)  # common grid of normalized network positions

# The figure uses constrained_layout. Calling plt.tight_layout() afterwards
# would override (and disable) it and emit a layout-change warning.
fig, axs = plt.subplots(2, 1, constrained_layout=True, figsize=(4, 5), sharex=True)

for p in all_profiles:
    # p is one fish's (n_time_bins, n_neurons) phase-aligned, binned traces.
    mean_p = np.mean(p, 0)
    std_p = np.std(p, 0)
    n_rois = len(mean_p)

    # Circularly shift mean and std so the peak of the mean profile lands at
    # index n_rois // 4 (normalized position ~0.25).
    # NOTE(review): the reference line drawn below is at 0.5 -- confirm which
    # peak position is intended.
    roll_pts = n_rois // 4 - np.argmax(mean_p)
    circ_perm_mn = np.roll(mean_p, roll_pts)
    circ_perm_std = np.roll(std_p, roll_pts)

    # Interpolate on the common grid so that fish with different numbers of
    # neurons can be stacked and compared:
    x = np.arange(n_rois) / n_rois
    interp_mn = interpolate(x, circ_perm_mn, new_x)
    interp_std = interpolate(x, circ_perm_std, new_x)

    axs[1].fill_between(new_x, interp_mn - interp_std,
                        interp_mn + interp_std, lw=0, alpha=0.1)
    axs[1].plot(new_x, interp_mn)

    all_interpd.append(interp_mn)

axs[1].axvline(0.5, lw=1., c=(0.2,)*3)
axs[1].set(xlabel="Norm. network position", ylabel="mean dF")
axs[0].imshow(all_interpd, aspect="auto", extent=[0, 1, len(all_interpd), 0])
axs[0].set(ylabel="Fish n.")