In [611]:
%matplotlib widget

In [612]:
from fimpylab import LightsheetExperiment

In [613]:
from pathlib import Path
import numpy as np
from matplotlib import  pyplot as plt
import flammkuchen as fl
import pandas as pd

from fimpylab import autoload_experiment

import pandas as pd

import seaborn as sns
from tqdm import tqdm
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()
import ipywidgets as widgets

plt.rcParams['figure.constrained_layout.use'] = True
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Libertinus Sans']

# %autoreload
from lotr.utils import zscore
from lotr.data_loading import preprocess_traces
from lotr.pca import pca_and_phase, get_fictive_trajectory, fictive_trajectory_and_fit

import scipy.optimize as soo

def qap_sorting_and_phase(traces, t_lims=None):
    n_pts, n = traces.shape
    
    if t_lims is None:
        t_lims = (0, n_pts)
    
    distance = np.corrcoef(traces[t_lims[0]:t_lims[1], :].T)

    flow = np.zeros((n, n))
    toshift = np.cos(np.linspace(-np.pi, np.pi, n))
    for i in range(n):
        flow[i, :] = np.roll(toshift, i)

    options = {"P0": "randomized"}
    res = min([soo.quadratic_assignment(flow, distance, method="faq", options = options) 
                 for i in range(1000)], key=lambda x: x.fun)

    options = {"partial_guess": np.array([np.arange(n), res.col_ind]).T}
    res = soo.quadratic_assignment(flow, distance, method="2opt", options=options)

    perm = res["col_ind"]
    
    traces_sorted = traces[:, perm]

    base = np.linspace(0, 2*np.pi, traces_sorted.shape[1])
    com_phase = np.arctan2(np.sum(np.sin(base) * traces_sorted, 1), 
                           np.sum(np.cos(base) * traces_sorted, 1)) 
    
    return perm, com_phase

In [763]:
plt.close("all")
path = Path("/Users/luigipetrucco/Desktop/source_data_batch1/210728_f1_cwccw")
traces = fl.load(path / "filtered_traces.h5", "/detr")

reg_df = fl.load(path / "motor_regressors.h5")
cc_motor = reg_df["all_bias_abs"].values
cc_motor_integr = reg_df["all_bias_abs_dfdt"].values

df = fl.load(path / "bouts_df.h5")# exp.get_bout_properties()
exp = LightsheetExperiment(path)
fn = int(exp.fn)
beh_df = exp.behavior_log

In [764]:
sel = (df["t_start"] > 10) & (df["t_start"] < 2020)
plt.figure(figsize=(5, 2))
plt.hist(df["bias"], 10, density=True, zorder=-100)
plt.scatter(df["bias"], df["peak_vig"], 100, c=cols[1])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

*c* argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with *x* & *y*.  Please use the *color* keyword-argument or provide a 2D array with a single row if you intend to specify the same RGB or RGBA value for all points.


<matplotlib.collections.PathCollection at 0x7fdff3c1f220>

In [765]:
t_lims = (1000, 5100)
t_slice = slice(*t_lims)

corr_mat = np.corrcoef(traces[t_slice, :].T)
selection_arr = np.zeros(traces.shape[1])

f = plt.figure(figsize=(3, 3))
x = np.arange(-0.2, cc_motor.max(), 0.05)
s = plt.scatter(cc_motor, cc_motor_integr, s=10, c=selection_arr, vmin=0, vmax=1)

l = plt.plot(x, x*0.2 + 0.15)
l_max = plt.axvline(1)
l_min = plt.axhline(0)

@widgets.interact(c=(0.05, 2, 0.05), o=(-0.5, 1, 0.02), mot_max=(0, 1, 0.05),
                 integr_min=(0, 1, 0.05), max_corr=(-1, 0, 0.05))
def update(o=0.15, c=0.2, mot_max=1, integr_min=0, max_corr=-1):
    l[0].set_data(x, x*c + o)
    print(cc_motor_integr.shape, cc_motor.shape, (cc_motor*c + o).shape)
    selection_arr[:] = (cc_motor_integr > cc_motor*c + o) & \
                       (np.abs(cc_motor) < mot_max) & \
                       (np.abs(cc_motor_integr)> integr_min) | \
                       ((np.min(corr_mat, 0) < max_corr) & (np.abs(cc_motor) < mot_max)) 
    l_max.set_xdata(mot_max)
    l_min.set_ydata(integr_min)

    
    s.set_array(selection_arr)

plt.ylim(-0.15, 0.4)
plt.xlim(-0.3, 1.01)
plt.xlabel("cc. traces - motor regressor")
plt.ylabel("cc. d(traces)/dt - regressor")
sns.despine()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

interactive(children=(FloatSlider(value=0.15, description='o', max=1.0, min=-0.5, step=0.02), FloatSlider(valu…

In [774]:
selected = np.argwhere(selection_arr)[:, 0]
#selected = fl.load(path / "selected.h5")

pcaed_t, phase_t = pca_and_phase(traces[t_slice, selected].T, traces[t_slice, selected].T)

plt.figure(figsize=(7, 3))
thr = 35
sel = (pcaed_t[:, 0]**2+pcaed_t[:, 1]**2)**(1/2) > thr
plt.scatter(pcaed_t[:, 0], pcaed_t[:, 1], c=sel)
plt.axis("equal")

#selected = selected[sel]
# pcaed, phase = pca_and_phase(traces[t_slice, selected], traces[:, selected])
#pcaed_spont, phase_spont = pca_and_phase(traces[t_slice, selected], traces[t_slice, selected])
pcaed, phase = pca_and_phase(traces[2000:8000, selected], traces[:, selected])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [775]:
plt.figure(figsize=(7, 2.5))
plt.plot(traces[:, selected] + 4)
print(len(selected))
plt.plot(beh_df["t"]*5, beh_df["tail_sum"])
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

106


In [776]:
# plt.close("all")
mot_t_slice = slice(5100, 10200)
f, axs = plt.subplots(1, 3, figsize=(7., 3.), sharex=True, sharey=True)
# plt.subplot(1,2,1)
# phase = np.angle((pcaed[:, 0]) + 1j * (pcaed[:, 1]))
axs[0].plot(pcaed[t_slice, 0], pcaed[t_slice, 1], 
         c=(0.6,)*3, lw=0.5, zorder=-100) # , c=phase, cmap="twilight", lw=3)
axs[0].scatter(pcaed[t_slice, 0], pcaed[t_slice, 1], 
                 c=phase[t_slice], lw=0.5, s=5, cmap="twilight",) 

axs[1].plot(pcaed[mot_t_slice, 0], pcaed[mot_t_slice, 1], 
         c=(0.6,)*3, lw=0.5, zorder=-100) # , c=phase, cmap="twilight", lw=3)
axs[1].scatter(pcaed[mot_t_slice, 0], pcaed[mot_t_slice, 1], 
                 c=phase[mot_t_slice], lw=0.5, s=5, cmap="twilight",) 

axs[2].plot(pcaed[:, 0], pcaed[:, 1], 
         c=(0.6,)*3, lw=0.5, zorder=-100) # , c=phase, cmap="twilight", lw=3)
axs[2].scatter(pcaed[:, 0], pcaed[:, 1], 
                 c=phase[:], lw=0.5, s=5, cmap="twilight",) 
sns.despine()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [777]:
perm, com_phase = qap_sorting_and_phase(traces[:, selected], t_lims=(1000, 5100))

In [778]:
plt.figure(figsize=(7, 3))
plt.subplot(121)
plt.imshow(np.corrcoef(traces[1000:5100, selected].T)[perm, :][:, perm], 
           vmax=1, vmin=-1, cmap="RdBu_r")

plt.subplot(122)
plt.imshow(traces[:, selected[perm]].T, cmap="gray_r", aspect="auto")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x7fdff1f171f0>

In [781]:
unwrapped_phase = np.unwrap(phase)
unwrapped_com_phase = np.unwrap(com_phase)

traj, params = fictive_trajectory_and_fit(unwrapped_phase, df, min_bias=0.01)
print(params)

plt.figure(figsize=(7, 3))
#plt.scatter(np.arange(len(traj[:])), -zscore(unwrapped_phase), 
#            c=phase[:], cmap="twilight", s=2)
plt.scatter(np.arange(len(traj[:])), -zscore(unwrapped_com_phase), 
            c=com_phase[:], cmap="twilight", s=0.2)
plt.plot(zscore(traj), c=cols[1])

[13.48404596 -2.068904  ]


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

[<matplotlib.lines.Line2D at 0x7fdff0a48790>]

In [782]:
coords = fl.load(path / "data_from_suite2p_unfiltered.h5", "/coords")

In [783]:
plt.figure(figsize=(4, 4))
plt.scatter(coords[:, 1], coords[:, 2], c=(0.9,)*3)
plt.scatter(coords[selected, 1], coords[selected, 2],
            c=np.linspace(-np.pi, np.pi, perm.max()+1)[np.argsort(perm)] , cmap="twilight")
plt.axis("equal")
plt.axis("off")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

*c* argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with *x* & *y*.  Please use the *color* keyword-argument or provide a 2D array with a single row if you intend to specify the same RGB or RGBA value for all points.


(7.40472074468085, 291.37586436170216, 5.819523809523808, 283.56777777777774)

In [366]:
fl.save(path / "selected.h5", selected)

In [77]:
from sklearn.decomposition import PCA
from circle_fit import hyper_fit
comp0, comp1 = 0, 1

traces_fit = traces[2000:8000, selected].T
traces_transform = traces_fit
if traces_transform is None:
    traces_transform = traces_fit

# Compute PCA and transform traces:
pca = PCA(n_components=5).fit(traces_fit)
pcaed_t = pca.transform(traces_transform)

# Fit circle:
hf_c = hyper_fit(pcaed[:, [comp0, comp1]])

# Compute phase, after subtracting center of the circle
phase_t = np.angle((pcaed_t[:, 0] - hf_c[0]) + 1j * (pcaed_t[:, 1] - hf_c[1]))

plt.figure(figsize=(7, 3))
plt.scatter(pcaed_t[:, 0], pcaed_t[:, 1], c=phase_t, cmap="twilight")
plt.axis("equal")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

(-89.43453865051269, 93.88069801330566, -60.719933319091794, 75.67397384643554)

In [78]:
# plt.close("all")
plot_t_s = 10
plot_t_pts = int(plot_t_s * fn)

rot_wnd_s = 1
rot_wnd_pts = int(rot_wnd_s * fn)

f, axs = plt.subplots(1, 3, figsize=(8, 3), sharex=True, sharey=True)
# [ax.set_xlabel("PC 1") for ax in axs]
# axs[0].set_ylabel("PC 2")

cbars = []
for x, ax, idx_list, cmap, title in zip(np.arange(3), 
                                     axs, 
                                     [idx_l, idx_r, random_trig], 
                                     ["Blues", "Reds", "gray_r"],
                                     ["Left bouts", "Right bouts", "Shuffle"]):   
    for l in idx_list[:]:
        try:
            crop_seg = pcaed[l:l+plot_t_pts, :2]
            crop_seg = crop_seg - crop_seg[0, :]
            th = np.arctan2(crop_seg[rot_wnd_pts, 0] - crop_seg[0, 0], 
                            crop_seg[rot_wnd_pts, 1] - crop_seg[0, 1])

            rot_mat = np.array([[np.cos(th), np.sin(th)], [-np.sin(th), np.cos(th)]]).T 
            crop_seg = (rot_mat @ crop_seg[:, :].T).T
            ax.plot(crop_seg[::3, 0], crop_seg[::3, 1], c=(0.4,)*3, lw=0.3, zorder=-100)
            ax.scatter(crop_seg[::3, 0], crop_seg[::3, 1], c=np.arange(len(crop_seg[::3, 1])), cmap=cmap, s=1)
            #cp = color_plot(crop_seg[::3, 0], crop_seg[::3, 1], ax=ax,
             #          cmap=cmap, lw=1)
        except IndexError:
            pass
    cbars.append(add_cbar((0.93, 0.8 + 0.026*x, 0.06, 0.023), cp, label="time (s)" if x==2 else "", 
                          ticks=[], orientation="horizontal"))
    ax.set_title(title)
    # elif x == 1:
    # elif x == 2:
    ax.axvline(0, lw=0.5, c=(0.4,)*3)
    ax.axhline(0, lw=0.5, c=(0.4,)*3)

plt.tight_layout()
sns.despine()

cbars[0].set_ticks([0, 0.95])
cbars[0].set_ticklabels([0, plot_t_s])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

NameError: name 'idx_l' is not defined

In [730]:
plt.close("all")
plot_t_s = 8
plot_t_pts = int(plot_t_s * fn)

f, axs = plt.subplots(1, 3, figsize=(8, 3), sharex=True, sharey=True)
[ax.set_xlabel("PC 1") for ax in axs]
axs[0].set_ylabel("PC 2")

cbars = []
for x, ax, idx_list, cmap, title in zip(np.arange(3), 
                                     axs, 
                                     [idx_l, idx_r, random_trig], 
                                     ["Blues", "Reds", "gray_r"],
                                     ["Left bouts", "Right bouts", "Shuffle"]):   
    for l in idx_list:
        try:
            crop_seg = pcaed[l:l+plot_t_pts, :2]
            # crop_seg = crop_seg - crop_seg[0, :]
            #cp = color_plot(crop_seg[::3, 0], crop_seg[::3, 1], ax=ax,
            #           cmap=cmap, lw=1)
            ax.plot(crop_seg[::3, 0], crop_seg[::3, 1], c=(0.4,)*3, lw=0.3, zorder=-100)
            ax.scatter(crop_seg[::3, 0], crop_seg[::3, 1], c=np.arange(len(crop_seg[::3, 1])), cmap=cmap, s=1)
            
        except (IndexError, ValueError):
            pass
    cbars.append(add_cbar((0.92, 0.8 + 0.026*x, 0.06, 0.023), cp, label="time (s)" if x==2 else "", 
                          ticks=[], orientation="horizontal"))
    ax.set_title(title)
    # elif x == 1:
    # elif x == 2:
        
plt.tight_layout()
sns.despine()

cbars[0].set_ticks([0, 0.8])
cbars[0].set_ticklabels([0, plot_t_s])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  plt.tight_layout()
  plt.tight_layout()


In [501]:
path

PosixPath('/Volumes/Shared/experiments/E0040_motions_cardinal/v15_playback/210511_f3b_natmov')

In [121]:
fl.save(path / "selected_new.h5", selected)

In [None]:
master_path = Path("/Volumes/Shared/experiments/E0040_motions_cardinal_old/v15_playback")
all_list = list(master_path.glob("[0-9]*_f[0-9]*"))
all_valid = list(master_path.glob("[0-9]*_f[0-9]*/selected.h5"))
print(f"{len(all_valid)}/{len(all_list)}")

In [1160]:
plt.close("all")
plot_t_s = 8
plot_t_pts = int(plot_t_s * fn)

f, axs = plt.subplots(1, 3, figsize=(8, 3), sharex=True, sharey=True)
[ax.set_xlabel("PC 1") for ax in axs]
axs[0].set_ylabel("PC 2")

cbars = []
for x, ax, idx_list, cmap, title in zip(np.arange(3), 
                                     axs, 
                                     [idx_l, idx_r, random_trig], 
                                     ["Blues", "Reds", "gray_r"],
                                     ["Left bouts", "Right bouts", "Shuffle"]):   
    for l in idx_list:
        try:
            crop_seg = pcaed[l:l+plot_t_pts, :2]
            # crop_seg = crop_seg - crop_seg[0, :]
            cp = color_plot(crop_seg[:, 0], crop_seg[:, 1], ax=ax,
                       cmap=cmap, lw=1)
        except IndexError:
            pass
    cbars.append(add_cbar((0.92, 0.8 + 0.026*x, 0.06, 0.023), cp, label="time (s)" if x==2 else "", 
                          ticks=[], orientation="horizontal"))
    ax.set_title(title)
    # elif x == 1:
    # elif x == 2:
        
plt.tight_layout()
sns.despine()

cbars[0].set_ticks([0, 0.8])
cbars[0].set_ticklabels([0, plot_t_s])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  plt.tight_layout()
  plt.tight_layout()
