In [None]:
%matplotlib widget

In [None]:
import numpy as np
from scipy.ndimage import gaussian_filter
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import seaborn as sns
#sns.set(palette="deep")
#cols = sns.color_palette()
# Two RGB colours (0-1 floats), hard-coded in place of the commented-out seaborn
# palette lookup: cols[0] is used for right bouts, cols[1] for left bouts.
cols = [
    [0.0, 0.62352941, 0.88627451],
    [0.83529412, 0.36470588, 0.28235294],
]
from pathlib import Path

from matplotlib import pyplot as plt
from matplotlib import gridspec

from lotr import LotrExperiment, DATASET_LOCATION
from lotr.pca import pca_and_phase
from lotr.plotting import color_stack, despine, add_scalebar, plot_arrow, add_cbar

def color_by_activation(roi_stack, exp, idx):
    """Color the ROIs of ``roi_stack`` by their activity at a single time point.

    Only ROIs listed in ``exp.hdn_indexes`` receive a value (their trace at row
    ``idx``); every other ROI gets NaN. The resulting per-ROI vector is passed to
    ``color_stack`` with the "Greens" color scheme.

    Parameters
    ----------
    roi_stack : array
        Labelled ROI stack to color. The previous version ignored this argument
        and read the global ``fixed_rois`` instead. All callers in this notebook
        pass ``fixed_rois``, so their output is unchanged.
    exp : LotrExperiment
        Provides ``n_rois``, ``hdn_indexes`` and ``traces`` (indexed here as
        ``traces[time_point, roi]``).
    idx : int
        Row of ``exp.traces`` (time point) whose activity is shown.

    Returns
    -------
    Whatever ``color_stack`` returns. Callers in this notebook slice it with
    four indices, i.e. treat it as a 4-D array.
    """
    activations = np.full(exp.n_rois, np.nan)
    activations[exp.hdn_indexes] = exp.traces[idx, exp.hdn_indexes]
    # Use the argument, not the notebook-global fixed_rois (hidden-state bug).
    return color_stack(roi_stack, activations, color_scheme="Greens")

In [None]:
# List all experiments
master_path = Path(DATASET_LOCATION)

path = master_path / "210314_f1" / "210314_f1_natmov"# "210926_f0" / "210926_f0_gainmod" 

exp = LotrExperiment(path)
anatomy = exp.anatomy_stack
rois_stack = exp.rois_stack

pcaed, phase, pca, circle_params = pca_and_phase(exp.traces[:, exp.hdn_indexes])

In [None]:
# Copy of the ROI stack in which ROIs from exp.hdn_indexes that cover more than
# MAX_ROI_SIZE elements are relabelled -1. NOTE(review): presumably -1 makes
# color_stack treat them as background; confirm.
MAX_ROI_SIZE = 300  # number of stack elements carrying a given ROI label

fixed_rois = exp.rois_stack.copy()
# Sizes are measured on the untouched copy, one entry per index in exp.hdn_indexes.
sizes = np.array([np.sum(fixed_rois == i) for i in exp.hdn_indexes])

# Relabel every oversized ROI in a single vectorized assignment instead of
# re-masking the whole stack once (previously twice) per ROI.
oversized_ids = np.asarray(exp.hdn_indexes)[sizes > MAX_ROI_SIZE]
fixed_rois[np.isin(fixed_rois, oversized_ids)] = -1

In [None]:
# Scratch arithmetic: 50 / 0.6 ≈ 83.3, matching the 83-unit image scale bar in the
# figure cell. NOTE(review): presumably 50 µm at 0.6 µm per pixel — confirm pixel size.
scalebar_um = 50
um_per_px = 0.6
scalebar_um / um_per_px

In [None]:
# Network activity around individual bouts. For two right and two left bouts,
# draw the PCA trajectory, the tail trace and ROI activation maps at `t_timepts`
# time points spaced `step` frames apart, then save the figure as a PDF.
bouts_idxs = dict(rt=[1165, 1010], lf=[3490, 280])  # start index of each bout, per side

t_timepts = 3  # activation maps per bout
step = 25  # index spacing between consecutive maps

# Output file. Built from the home directory instead of a hard-coded absolute
# user path, so the cell also runs on other machines.
fig_path = Path.home() / "Desktop" / "network_evolution.pdf"

fig = plt.figure(figsize=(8, 6))
gs = gridspec.GridSpec(5, 4, figure=fig)
pca_axs = [fig.add_subplot(gs[:2, i]) for i in range(4)]
img_axs = [fig.add_subplot(gs[2:, i]) for i in range(4)]
beh_axs = [inset_axes(ax, width="100%", height="100%",
                      bbox_to_anchor=(0.3, -0.1, .7, .3),
                      bbox_transform=ax.transAxes, loc=2, borderpad=0) for ax in pca_axs]
cbar_ax = inset_axes(img_axs[-1], width="100%", height="100%",
                     bbox_to_anchor=(.9, 0.05, .05, .15),
                     bbox_transform=img_axs[-1].transAxes, loc=2, borderpad=0)

# Crop along axis 2 of the colored ROI stack, traversed from 260 down to 81 (flipped).
y_crop = (260, 80)
y_span = y_crop[0] - y_crop[1]  # length of one cropped map along the stacking axis
pca_t_sl = slice(0, 4000)  # samples drawn as the gray background trajectory

for i in range(2):
    for j, side in enumerate(["rt", "lf"]):
        i_col = i * 2 + j
        start_idx = bouts_idxs[side][i]
        end_idx = start_idx + step * t_timepts
        pca_ax, beh_ax, img_ax = pca_axs[i_col], beh_axs[i_col], img_axs[i_col]
        pca_ax.set_title(dict(rt="Right bouts", lf="Left bouts")[side], fontsize=8)

        # PCA plot: full trajectory in light gray, bout segment as a colored arrow.
        # PC2 is sign-flipped in both to match the anatomy orientation.
        pca_ax.plot(pcaed[pca_t_sl, 0], -pcaed[pca_t_sl, 1], lw=1, c=(0.9,)*3)
        pca_seg = pcaed[start_idx:end_idx, :].copy()
        pca_seg[:, 1] = -pca_seg[:, 1]
        plot_arrow(pca_seg, ax=pca_ax, col=cols[j], s=8)
        pca_ax.axis("equal")

        # Behavior plot: tail_sum strictly inside the bout window, with the index
        # bounds converted to behavior-log time by dividing by exp.fn.
        seg = (exp.behavior_log.t > (start_idx / exp.fn)) & \
              (exp.behavior_log.t < end_idx / exp.fn)
        beh_seg = exp.behavior_log[seg]  # select once, reuse for both columns
        # Downsample to reduce number of plot points - maybe rasterize in the future
        beh_ax.plot(beh_seg.t[::3], beh_seg.tail_sum[::3], lw=1, c=(0.4,)*3)
        beh_ax.set(ylim=(-np.pi, np.pi))

        # Scale-bar labels: None on the first panel (presumably add_scalebar's
        # default label — confirm), empty strings elsewhere.
        xlabel_bh, ylabel_bh = (None, None) if i_col == 0 else ("", "")
        add_scalebar(ax=beh_ax, xlen=3, ylen=2, ypos=-1.5, xlabel=xlabel_bh, ylabel=ylabel_bh,
                     xunits="s", yunits="rad", text_params=dict(fontsize=7))

        # ROI maps: one cropped activation map per time point, concatenated along
        # axis 2. After max(0) and swapaxes(0, 1) that axis becomes the image rows,
        # so time runs down the image.
        colored = []
        for i_t in range(t_timepts):
            idx = start_idx + i_t * step
            colored.append(color_by_activation(fixed_rois, exp, idx)[:, :, y_crop[0]:y_crop[1]:-1, :])
        stacked = np.concatenate(colored, axis=2)

        im_plot = img_ax.imshow(stacked.max(0).swapaxes(0, 1), aspect="equal", cmap="Greens")

        if i_col == 0:
            # Label each stacked map with its offset from the bout start
            # (i_lab * step indices divided by exp.fn, truncated to an int).
            for i_lab in range(t_timepts):
                img_ax.text(-10, i_lab * y_span - 10, f"t={int(i_lab*step/exp.fn)} s", fontsize=8)

        for panel_ax in (pca_ax, img_ax, beh_ax):
            despine(panel_ax, "all")

        if i_col == 0:
            add_scalebar(ax=pca_ax, xlen=7, ylen=7, xlabel="PC1",
                         ylabel="PC2", text_params=dict(fontsize=8))
            # 83 ≈ 50 / 0.6 — presumably 50 µm at 0.6 µm/px; TODO confirm.
            # ypos sits near the bottom of the t_timepts stacked maps.
            add_scalebar(ax=img_ax, ylen=-83, xlen=83, ypos=y_span*t_timepts-20, xpos=20,
                         xlabel="lf.-rt.", ylabel="post.-ant.",
                         text_params=dict(fontsize=8))
        elif i_col == 3:
            cbar = add_cbar(cbar_ax, im_plot, ticks=[], ticklabels=[])
            cbar.set_label("ΔF (Z sc.)", fontsize=8)

fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_path)

In [None]:
# Re-apply the colorbar label with the same font size used when the figure was
# built. Omitting fontsize here silently reverts it to the default.
cbar.set_label("ΔF (Z sc.)", fontsize=8)

In [None]:
# Inspect the first cropped activation map of the last bout drawn. colored[0] is
# 4-D (it is built by slicing with four indices), which imshow rejects. Apply the
# same max projection and axis swap used for the figure panels.
dbg_fig, dbg_ax = plt.subplots()
dbg_ax.imshow(colored[0].max(0).swapaxes(0, 1))

In [None]:
# Shape of the first cropped activation map of the last bout drawn.
np.shape(colored[0])

In [None]:
# Shape check of one cropped activation map, using the crop from the figure cell
# (y_crop) rather than a separate magic start index.
color_by_activation(fixed_rois, exp, idx)[:, :, y_crop[0]:y_crop[1]:-1, :].shape

In [None]:
# NOTE(review): scratch arithmetic whose purpose is not evident from the notebook;
# consider documenting or removing it.
7 / 11