In [None]:
%matplotlib widget
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import flammkuchen as fl

import seaborn as sns
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()

from lotr.experiment_class import LotrExperiment
from bouter.utilities import crop

from lotr.plotting import add_cbar, color_plot, despine
from lotr.pca import get_fictive_trajectory, pca_and_phase, \
                     linear_regression, fit_phase_neurons
from lotr.utils import zscore, interpolate

from tqdm import tqdm

In [None]:
# Root directory of the "full_ring" dataset. Every experiment folder matched by
# the glob below (presumably one per fish) contains a `selected.h5` file; we keep
# the folder itself, which is what `LotrExperiment` is constructed from.
master_path = Path("/Users/luigipetrucco/Desktop/all_source_data/full_ring")
# NOTE(review): glob order is filesystem-dependent, yet later cells refer to fish
# by position in this list (`bad`) — TODO confirm the order is stable.
file_list = list(map(lambda selected_h5: selected_h5.parent,
                     master_path.glob("*/*[0-9]_f*/selected.h5")))

In [None]:
# Microns per pixel (not used by the loop below).
UM_PER_PX = 0.6

# Frames per non-overlapping time bin when averaging the aligned profiles.
BIN_PTS = 50


def _ring_sort(hdn_traces, neuron_phases):
    """Return a column order for `hdn_traces` that follows the ring network.

    Neurons are sorted by their fitted phase; the sequence is then rotated so
    that it starts right after the largest jump in correlation structure
    between two phase-adjacent neurons.

    Parameters
    ----------
    hdn_traces : np.ndarray
        (n_pts, n_sel) activity of the selected neurons (time along axis 0).
    neuron_phases : np.ndarray
        (n_sel,) fitted phase of each neuron.

    Returns
    -------
    np.ndarray
        (n_sel,) integer indices into the columns of `hdn_traces`.
    """
    cc = np.corrcoef(hdn_traces.T)
    np.fill_diagonal(cc, np.nan)  # exclude self-correlation from the means below
    phase_order = np.argsort(neuron_phases)
    # Mean absolute change of the correlation row between consecutive neurons
    # in phase order; the biggest change is used as the starting cut.
    jumps = np.nanmean(np.abs(cc[phase_order[1:], :] - cc[phase_order[:-1], :]), 1)
    return np.roll(phase_order, -np.argmax(jumps) - 1)


def _phase_aligned_profile(sorted_traces, phase, bin_pts=BIN_PTS):
    """Align each frame's population vector to the network phase and bin in time.

    Every row of `sorted_traces` is circularly rolled by the bin index of its
    phase, so that the activity bump lands at a fixed position. The aligned
    rows are then averaged in non-overlapping bins of `bin_pts` frames;
    trailing frames that do not fill a full bin are dropped.

    Parameters
    ----------
    sorted_traces : np.ndarray
        (n_pts, n_sel) activity, columns ordered along the ring.
    phase : np.ndarray
        (n_pts,) network phase per frame. Assumed in radians within
        [-pi, pi] — TODO confirm against `pca_and_phase`.
    bin_pts : int
        Number of frames averaged per output row.

    Returns
    -------
    np.ndarray
        (n_pts // bin_pts, n_sel) binned, phase-aligned profiles.
    """
    n_pts, n_sel = sorted_traces.shape
    # Map the circular phase onto n_sel equally spaced bins. Scaling by n_sel
    # (not n_sel - 1) keeps -pi and +pi in the same bin after the modulo below.
    phase_bins = np.round((phase / (2 * np.pi) + 0.5) * n_sel).astype(int)
    # Vectorized per-row np.roll: np.roll(a, k)[j] == a[(j - k) % n].
    col_idx = (np.arange(n_sel)[None, :] - phase_bins[:, None]) % n_sel
    shifted = np.take_along_axis(sorted_traces, col_idx, axis=1)

    n_bins = n_pts // bin_pts
    return shifted[:n_bins * bin_pts].reshape(n_bins, bin_pts, n_sel).mean(axis=1)


all_profiles = []
all_phases = []
for path in tqdm(file_list):
    exp = LotrExperiment(path)  # load each experiment only once
    hdn_traces = exp.traces[:, exp.hdn_indexes]

    _, phase, _ = pca_and_phase(exp.traces[exp.pca_t_slice, exp.hdn_indexes],
                                hdn_traces)
    neuron_phases, _ = fit_phase_neurons(hdn_traces, phase, disable_bar=True)

    ring_order = _ring_sort(hdn_traces, neuron_phases)
    all_profiles.append(_phase_aligned_profile(hdn_traces[:, ring_order], phase))
    all_phases.append(neuron_phases)

In [None]:
# Positions (in `all_profiles`, i.e. in `file_list` order) of fish excluded from
# the summary. NOTE(review): these are positional indices, only valid for the
# exact glob order of `file_list` — TODO exclude by experiment path instead.
bad = [3, 17, 33, 36, 38, 19]

# Per-fish time-averaged profile and its spread across time bins.
# Despite the name, `all_median_profiles` holds means (np.mean), as before.
all_median_profiles = [np.mean(p, 0) for p in all_profiles]
all_stds_profiles = [np.std(p, 0) for p in all_profiles]

all_interpd = []
new_x = np.arange(0, 1, 0.01)  # common normalized-position grid

fig, axs = plt.subplots(2, 1, constrained_layout=True, figsize=(4, 5), sharex=True)

for i, (med_p, std_p) in enumerate(zip(all_median_profiles, all_stds_profiles)):
    if i in bad:
        continue

    n_rois = len(med_p)
    # Circularly shift so the profile peak sits at a quarter of the network.
    # NOTE(review): the reference line below is drawn at 0.5, not 0.25 — confirm.
    roll_pts = n_rois // 4 - np.argmax(med_p)
    circ_perm_mn = np.roll(med_p, roll_pts)
    circ_perm_std = np.roll(std_p, roll_pts)

    # Resample each fish (different n_rois) onto the shared grid.
    x = np.arange(n_rois) / n_rois
    interp_mn = interpolate(x, circ_perm_mn, new_x)
    interp_std = interpolate(x, circ_perm_std, new_x)

    # Take the band color from the line so each fish's fill and line always match.
    line, = axs[1].plot(new_x, interp_mn)
    axs[1].fill_between(new_x, interp_mn - interp_std, interp_mn + interp_std,
                        lw=0, alpha=0.1, color=line.get_color())

    all_interpd.append(interp_mn)

axs[1].axvline(0.5, lw=1., c=(0.2,)*3)
axs[1].set(xlabel="Norm. network position", ylabel="mean dF")
axs[0].imshow(all_interpd, extent=[0, 1, 0, len(all_interpd)], aspect="auto")
# Layout is handled by constrained_layout; calling plt.tight_layout() here would
# replace it (with a warning), so it is intentionally not called.
axs[0].set(ylabel="Fish n.");