In [None]:
%matplotlib widget
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import flammkuchen as fl

import seaborn as sns
# Global plotting style: seaborn "ticks" theme with the "deep" palette, plus
# compact matplotlib defaults for small multi-panel figures.
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()  # palette colours reused for highlighted lines below

plt.rcParams.update({
    "figure.constrained_layout.use": True,
    "font.family": "sans-serif",
    "font.sans-serif": ["Libertinus Sans"],
    "axes.linewidth": 0.5,
    "axes.labelsize": 10,
    "legend.fontsize": 8,
})
# Same tick geometry and label size on both axes.
for axis_name in ("x", "y"):
    plt.rcParams.update({
        f"{axis_name}tick.major.size": 3,
        f"{axis_name}tick.labelsize": 8,
        f"{axis_name}tick.major.width": 0.5,
    })

from lotr.experiment_class import LotrExperiment
from bouter.utilities import crop

from lotr.plotting import add_cbar, color_plot, despine
from lotr.pca import get_fictive_trajectory, pca_and_phase, \
                     linear_regression, fit_phase_neurons
from lotr.utils import zscore

In [None]:
# Root of the source data; recordings are the second-level entries matching
# the glob "*/*[0-9]_f*".
# NOTE(review): absolute, user-specific path - point this at your local copy.
master_path = Path("/Users/luigipetrucco/Desktop/all_source_data/full_ring")
# Path.glob yields entries in arbitrary filesystem order; sort so that the
# recordings selected later (e.g. file_list[:1]) are the same on every machine.
file_list = sorted(master_path.glob("*/*[0-9]_f*"))

In [None]:
# Per-recording summary figure: anatomy of the selected ROIs coloured by
# preferred phase, PCA trajectory, phase-sorted activity, phase-aligned
# activity profile, and network phase vs. fictive heading.

FN = 5  # imaging frame rate (Hz): time (s) = frame index / FN
UM_PER_PX = 0.6  # scale factor applied to ROI coordinates (µm per pixel)
PCA_FIT_FRAMES = slice(2000, 8000)  # frame window passed to pca_and_phase as fit data
PHASE_JUMP_THR = 0.1  # rad/frame; larger phase steps are blanked in the overlaid phase line
PROFILE_BIN = 50  # frames averaged per row of the phase-aligned profile
N_FISH = 1  # number of recordings to plot; None plots all of file_list
SAVE_FIGURE = False  # True writes summary_plots.pdf into each recording folder

for path in file_list[:N_FISH]:
    # --- Load data -----------------------------------------------------------
    selected = fl.load(next(path.glob("selected.h5")))  # integer indices of selected ROIs
    beh_df = fl.load(next(path.glob("*behavior_log*")), "/data")
    suite2p_file = path / "data_from_suite2p_unfiltered.h5"
    anatomy_stack = fl.load(suite2p_file, "/anatomy_stack")
    rois_stack = fl.load(suite2p_file, "/rois_stack")
    coords = fl.load(suite2p_file, "/coords")
    bouts_df = fl.load(path / "bouts_df.h5")
    traces = fl.load(path / "filtered_traces.h5", "/detr")  # (n_pts, n_cells)
    exp = LotrExperiment(path)
    n_pts, n_cells = traces.shape
    n_sel, = selected.shape

    time_array = np.arange(n_pts) / FN
    sel_traces = traces[:, selected]  # hoisted: reused by PCA, phase fit and corrcoef

    # --- Network phase -------------------------------------------------------
    # NOTE(review): presumably pca_and_phase fits on the first argument and
    # projects the second - confirm in lotr.pca.
    pcaed, phase, _ = pca_and_phase(sel_traces[PCA_FIT_FRAMES], sel_traces)
    phase_unwrapped = np.unwrap(phase)
    if "idx_imaging" not in bouts_df.columns:
        # Bout onset (s) -> nearest imaging frame. Multiply *before* rounding;
        # rounding seconds first snapped every onset to a multiple of FN frames.
        bouts_df["idx_imaging"] = np.round(bouts_df["t_start"] * FN).astype(int)

    traj = get_fictive_trajectory(len(phase_unwrapped), bouts_df, min_bias=0)
    # Fit the phase -> heading mapping on the first half of the recording only.
    half = len(traj) // 2
    params = linear_regression(phase_unwrapped[:half], traj[:half])

    neuron_phases, _ = fit_phase_neurons(sel_traces, phase)  # preferred phase per selected ROI

    # Blank frames where phase jumps by more than PHASE_JUMP_THR so the overlaid
    # line is broken at the wrap instead of drawing a vertical stroke.
    norm_phase = phase.copy()
    norm_phase[np.insert(np.abs(np.diff(phase)), 0, 0) > PHASE_JUMP_THR] = np.nan

    # --- Correlation matrix and circular ROI ordering ------------------------
    cc = np.corrcoef(sel_traces.T)
    sort = np.argsort(neuron_phases)
    # Rotate the circular phase ordering so that the largest change between
    # consecutive ROIs' correlation rows falls at the edge of the matrix.
    peak = np.argmax(np.nanmean(np.abs(cc[sort[1:], :] - cc[sort[:-1], :]), 1))
    sort = np.roll(sort, -peak - 1)
    sorted_traces = sel_traces[:, sort]

    # --- Phase-aligned activity profile --------------------------------------
    # Map each frame's phase to a ROI offset and circularly shift that frame's
    # phase-sorted activity by it. Assumes phase lies in [-π, π] -> [0, n_sel-1].
    phase_bins = np.round((phase / (2 * np.pi) + 0.5) * (n_sel - 1)).astype(int)
    assert phase_bins.min() >= 0 and phase_bins.max() < n_sel, \
        "phase outside [-π, π]: phase-to-ROI mapping is invalid"
    shifted = np.zeros((n_pts, n_sel))
    for i in range(n_pts):
        shifted[i, :] = np.roll(sorted_traces[i], phase_bins[i])
    bn = PROFILE_BIN
    shifted_all_bin = np.array([shifted[i * bn:(i + 1) * bn, :].mean(0)
                                for i in range(n_pts // bn)])

    # --- Anatomical coordinates (µm) -----------------------------------------
    coords_sel = coords[selected, 1:] * UM_PER_PX
    # coords[:, 1] indexes anatomy_stack axis 1 (see the imshow/scatter below);
    # convert the image midline to µm so it is comparable with coords_sel.
    midline = anatomy_stack.shape[1] // 2 * UM_PER_PX
    caud_lim = coords_sel.max(0)[1]
    coords_sel[:, 1] = coords_sel[:, 1] - caud_lim  # 0 at the maximum, negative elsewhere
    left = coords_sel[:, 0] < midline
    coords_linearized = coords_sel[:, 1].copy()
    # To mirror ROIs on the other side of the midline:
    # coords_linearized[~left] = -coords_linearized[~left]

    # --- Figure --------------------------------------------------------------
    fig = plt.figure(constrained_layout=True, figsize=(9, 5))
    gs = fig.add_gridspec(5, 4)

    # Anatomy with selected ROIs coloured by preferred phase.
    anat_ax = fig.add_subplot(gs[2:-1, 0])
    anat_ax.imshow(anatomy_stack.mean(0).T, origin="lower", cmap="gray_r", vmin=3, vmax=30)
    n_cb = anat_ax.scatter(coords[selected, 1], coords[selected, 2], c=neuron_phases, s=15,
                           cmap="twilight", alpha=0.8, vmin=-np.pi, vmax=np.pi, lw=0)
    anat_ax.axis("off")
    cb = add_cbar(fig.add_axes((0.05, 0.3, 0.05, 0.015)), n_cb, label="pref. ϕ", labelsize=10,
                  orientation="horizontal", ticks=[-np.pi + 0.2, np.pi - 0.2], ticklabels=["-π", "π"])

    # Preferred phase vs. linearised position.
    anat_quant = fig.add_subplot(gs[-1, 0])
    anat_quant.scatter(neuron_phases, coords_linearized, s=5)
    anat_quant.set(ylim=(-50, 1), xlabel="neuron phase (rad)", ylabel="pos. (μm)")
    despine(anat_quant)

    # Trajectory in the PC1/PC2 plane, coloured by network phase. vmin/vmax pin
    # the colour scale to [-π, π], matching the anatomy panel and the ±π ticks.
    pc_ax = fig.add_subplot(gs[:2, 0])
    pc_ax.plot(pcaed[:, 0], pcaed[:, 1], c=(0.6,) * 3, lw=0.5, zorder=-100)
    pc_sc = pc_ax.scatter(pcaed[:, 0], pcaed[:, 1], c=phase, lw=0.5, s=3, cmap="twilight",
                          vmin=-np.pi, vmax=np.pi)
    add_cbar(fig.add_axes((0.05, 0.95, 0.05, 0.015)), pc_sc, label='ϕ', labelsize=10,
             orientation="horizontal", ticks=[-np.pi + 0.2, np.pi - 0.2], ticklabels=["-π", "π"])

    # L-shaped scale bar (b_len PC units) just outside the lower-left of the data.
    b_len = 3
    bar_pos_x, bar_pos_y = pcaed[:, :2].min(0) - b_len
    pc_ax.plot([bar_pos_x, bar_pos_x, bar_pos_x + b_len],
               [bar_pos_y + b_len, bar_pos_y, bar_pos_y], lw=0.5, c=(0.3,) * 3)
    pc_ax.text(bar_pos_x, bar_pos_y + b_len / 2, "PC2", ha="right", va="center",
               rotation='vertical', fontsize=8)
    pc_ax.text(bar_pos_x + b_len / 2, bar_pos_y, "PC1", ha="center", va="top", fontsize=8)
    pc_ax.axis("off")

    # Correlation matrix in circular phase order.
    cc_ax = fig.add_subplot(gs[:2, -1])
    cc_cb = cc_ax.imshow(cc[sort, :][:, sort], cmap="RdBu_r", vmin=-1, vmax=1)
    cc_ax.set(xlabel="phase-sorted roi", xticks=[], yticks=[])
    despine(cc_ax, sides="all")
    cb = add_cbar(fig.add_axes((0.77, 0.87, 0.008, 0.1)), cc_cb, label="corr.", labelsize=10,
                  orientation="vertical")

    # Phase-aligned activity over time (binned) and its per-bin / average profiles.
    prof_ax = fig.add_subplot(gs[2:3, -1])
    prof_ax.imshow(shifted_all_bin, aspect="auto", cmap="gray", extent=(0, n_sel, time_array[-1], 0),
                   vmin=np.percentile(shifted_all_bin, 1), vmax=np.percentile(shifted_all_bin, 99))
    prof_ax.set(ylabel="time (s)", xticks=[])
    despine(prof_ax)

    prof_plot_ax = fig.add_subplot(gs[3:4, -1])
    prof_plot_ax.plot(shifted_all_bin.T, lw=0.3, c=(0.5,) * 3, rasterized=True)
    prof_plot_ax.plot(shifted_all_bin.mean(0), lw=2, c=cols[2])
    prof_plot_ax.set(xlabel="phase-shifted pos", ylabel="dF")
    despine(prof_plot_ax)

    # Network phase (linearly mapped via params) vs. fictive heading.
    traj_ax = fig.add_subplot(gs[3:, 1:-1])
    traj_ax.plot(time_array, phase_unwrapped * params[1] + params[0], label="network phase")
    traj_ax.plot(time_array, traj, label="fictive heading")
    traj_ax.legend(frameon=False)
    traj_ax.set(xlabel="time (s)", xlim=(0, time_array[-1]), title=path.name)
    despine(traj_ax)

    # Tail trace with "open_loop"/"bg" stimulus epochs shaded and labelled "playback".
    beh_ax = fig.add_subplot(gs[0, 1:-1])
    beh_ax.plot(beh_df["t"], beh_df["tail_sum"], label="tail sum", lw=0.7, rasterized=True)
    beh_ax.set(ylabel="tail sum", xticklabels=[], xlim=(0, time_array[-1]))
    for s in exp["stimulus"]["log"]:
        if s["name"] in ["open_loop", "bg"]:
            beh_ax.axvspan(s["t_start"], s["t_start"] + s["duration"], lw=0, fc=(0.8,) * 3, zorder=-100)
            beh_ax.text(s["t_start"] + 10, beh_ax.get_ylim()[1] - 0.1, "playback", ha="left", va="top",
                        c=(0.4,) * 3, fontsize=7)
    despine(beh_ax)

    # Phase-sorted activity raster with the (gap-blanked) network phase overlaid.
    traces_ax = fig.add_subplot(gs[1:3, 1:-1])
    tr_cb = traces_ax.imshow(sorted_traces.T, extent=(0, n_pts / FN, 0, n_sel), aspect="auto", cmap="gray",
                             vmin=np.percentile(sorted_traces, 1), vmax=np.percentile(sorted_traces, 99))
    traces_ax.plot(time_array, n_sel * (1 + norm_phase / np.pi) / 2, c=cols[3], lw=1)
    traces_ax.set(yticks=[], ylabel="phase-sorted roi", xticklabels=[])
    traces_ax.spines["left"].set_visible(False)
    despine(traces_ax, sides=["right", "top", "left"])
    cb = add_cbar(fig.add_axes((0.73, 0.7, 0.008, 0.1)), tr_cb, label="dF", labelsize=10,
                  orientation="vertical")

    if SAVE_FIGURE:
        fig.savefig(path / "summary_plots.pdf", dpi=300)