In [None]:
#%matplotlib widget

import numpy as np
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

from pathlib import Path

from matplotlib import pyplot as plt
from matplotlib import gridspec

from lotr import LotrExperiment, A_FISH, FIGURES_LOCATION
from lotr.pca import pca_and_phase
from lotr.plotting import COLS, color_stack, despine, add_scalebar, plot_arrow, add_cbar

def color_by_activation(roi_stack, exp, idx):
    """Color an ROI label stack by the activity of the ROIs in ``exp.hdn_indexes`` at one timepoint.

    Parameters
    ----------
    roi_stack : array
        ROI label stack to color; forwarded to ``color_stack``.
    exp : LotrExperiment
        Experiment providing ``n_rois``, ``hdn_indexes`` and ``traces``
        (indexed as ``traces[time, roi]``).
    idx : int
        Time index into the first axis of ``exp.traces``.

    Returns
    -------
    The output of ``color_stack`` with the ``COLS["dff_img"]`` color scheme.
    ROIs not listed in ``exp.hdn_indexes`` get a NaN activation.
    """
    # Start from NaN everywhere, then fill in only the selected ROIs.
    activations = np.full(exp.n_rois, np.nan)
    activations[exp.hdn_indexes] = exp.traces[idx, exp.hdn_indexes]
    # Color the stack that was passed in. The original silently read the global
    # `fixed_rois` and ignored this argument.
    return color_stack(roi_stack, activations, color_scheme=COLS["dff_img"])

In [None]:
# NOTE(review): this only creates an empty figure. It is presumably a leftover used to
# initialise the interactive backend (see the commented `%matplotlib widget`).
# It can be dropped if that backend is not used.
plt.figure()

In [None]:
# Load the example fish, then run `pca_and_phase` on the traces of the ROIs
# listed in `exp.hdn_indexes`.
exp = LotrExperiment(Path(A_FISH))

(pcaed, phase, pca, circle_params) = pca_and_phase(
    exp.traces[:, exp.hdn_indexes]
)

In [None]:
# Work on a copy so the experiment's own ROI label stack is left untouched.
fixed_rois = exp.rois_stack.copy()
# Voxel count of each ROI in `exp.hdn_indexes`, in the same order.
sizes = np.array([np.sum(fixed_rois == i) for i in exp.hdn_indexes])

# Remove 5-6 large ROIs that are clearly not cells. Visualization won't change much
# including them, as they have same activity as neighbor regions.
# They are relabelled as -1 (presumably the "no ROI" label for color_stack; TODO confirm).
max_roi_size = 300  # voxels
large_rois = np.asarray(exp.hdn_indexes)[sizes > max_roi_size]
fixed_rois[np.isin(fixed_rois, large_rois)] = -1

In [None]:
# Sample indexes at which each shown window starts: two example bouts per side.
# Bout i of side j ends up in figure column i*2 + j.
bouts_idxs = dict(rt=[1165, 1010], lf=[3490, 280])

# Number of timepoints shown, and size of each step in imaging samples units:
t_timepts = 3
step = 25

# Layout: a 5x4 grid. The top two rows hold the PCA trajectories and the bottom three
# hold the ROI maps. Each PCA panel carries a behavior inset hanging just below it.
fig = plt.figure(figsize=(8, 6))
gs = gridspec.GridSpec(5, 4, figure=fig)
pca_axs = [fig.add_subplot(gs[:2, i]) for i in range(4)]
img_axs = [fig.add_subplot(gs[2:, i]) for i in range(4)]
beh_axs = [inset_axes(ax, width="100%", height="100%",
                      bbox_to_anchor=(0.3, -0.1, .7, .3),
                      bbox_transform=ax.transAxes, loc=2, borderpad=0)
           for ax in pca_axs]

y_crop = (80, 260)  # range kept along axis 1 of each colored stack
y_span = y_crop[1] - y_crop[0]  # height of one cropped frame, in pixels
pca_t_sl = slice(0, 4000)  # samples of the full trajectory drawn in gray
side_labels = dict(rt="Right turn", lf="Left turn")

for i in range(2):
    for j, side in enumerate(["rt", "lf"]):
        i_col = i*2 + j
        start_idx = bouts_idxs[side][i]
        end_idx = start_idx + step*t_timepts  # exclusive end of the shown window

        pca_ax, beh_ax, img_ax = pca_axs[i_col], beh_axs[i_col], img_axs[i_col]

        # PCA plot
        # --------
        pca_ax.text(4, 17, side_labels[side], ha="center", fontsize=8)

        # Whole trajectory in light gray as reference:
        pca_ax.plot(pcaed[pca_t_sl, 0], -pcaed[pca_t_sl, 1], lw=1, c=(0.9,)*3)

        # Shown window as a colored arrow:
        pca_seg = pcaed[start_idx:end_idx, :].copy()
        pca_seg[:, 1] = -pca_seg[:, 1]  # invert one axis to match anatomy
        plot_arrow(pca_seg, ax=pca_ax, col=COLS["sides"][side], s=8)

        pca_ax.axis("equal")

        # Behavior plot:
        # --------------
        # Sample indexes are converted to behavior-log time by dividing by exp.fn.
        # The matching rows are selected once and reused below.
        beh_t = exp.behavior_log.t
        beh_seg = exp.behavior_log[(beh_t > start_idx / exp.fn) & (beh_t < end_idx / exp.fn)]

        # Downsample to reduce number of plot points - maybe rasterize in the future
        beh_ax.plot(beh_seg.t[::3], beh_seg.tail_sum[::3], lw=1, c=COLS["beh"])
        beh_ax.set(ylim=(-np.pi, np.pi))

        # Scale-bar labels only on the first column.
        xlabel_bh, ylabel_bh = (None, "tail") if i_col == 0 else ("", "")
        add_scalebar(ax=beh_ax, xlen=3, ylen=2, ypos=-1.5, xlabel=xlabel_bh, ylabel=ylabel_bh,
                     xunits="s", yunits="rad", text_params=dict(fontsize=7))

        # ROI maps plot:
        # --------------
        # One colored, cropped frame per timepoint, concatenated along axis 1
        # (vertically in the displayed image).
        colored = [
            color_by_activation(fixed_rois, exp, start_idx + i_t*step)[:, y_crop[0]:y_crop[1], :, :]
            for i_t in range(t_timepts)
        ]
        stacked = np.concatenate(colored, axis=1)

        # The max projection over axis 0 leaves a trailing channel axis.
        # NOTE(review): imshow ignores `cmap` for RGB(A) data, so "Greens" only affects the
        # colorbar built from `im_plot` below. Confirm that it matches COLS["dff_img"].
        im_plot = img_ax.imshow(stacked.max(0), aspect="equal", cmap="Greens")

        if i_col == 0:
            # Label each stacked frame with its time from the window start. Round rather
            # than truncate, so e.g. 4.98 s shows as "5 s" instead of "4 s".
            for i_t in range(t_timepts):
                img_ax.text(-10, i_t*y_span - 10, f"t={i_t*step/exp.fn:.0f} s", fontsize=8)

        for panel_ax in (pca_ax, img_ax, beh_ax):
            despine(panel_ax, "all")

        if i_col == 0:
            add_scalebar(ax=pca_ax, xlen=7, ylen=7, xlabel="PC1",
                         ylabel="PC2", text_params=dict(fontsize=8))
            # Anchor near the bottom of the stacked maps, whose total height is
            # y_span * t_timepts pixels.
            # NOTE(review): 83 px is a magic length, and its physical size is not stated here.
            add_scalebar(ax=img_ax, ylen=-83, xlen=83, ypos=y_span*t_timepts - 20, xpos=20,
                         xlabel="lf.-rt.", ylabel="post.-ant.",
                         text_params=dict(fontsize=8))
        elif i_col == 3:
            cbar = add_cbar(im_plot, img_ax, (.9, 0.05, .05, .15), ticks=[], ticklabels=[])
            cbar.set_label("ΔF (Z sc.)", fontsize=8)

# save if necessary
fig.savefig(FIGURES_LOCATION / "network_evolution_bouts.pdf", dpi=300)

In [None]:
# NOTE(review): this displays the repr of the last `img_ax` left over from the figure loop.
# It looks like leftover debugging and can probably be removed.
img_ax