In [None]:
%matplotlib widget

import warnings
from pathlib import Path

import flammkuchen as fl
import numpy as np
import pandas as pd
import seaborn as sns
from fimpylab import LightsheetExperiment
from matplotlib import pyplot as plt
from tqdm import tqdm

# Global seaborn style; `cols` keeps the default palette for per-plot colors.
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()
import ipywidgets as widgets
from circle_fit import hyper_fit
from lotr import A_FISH, LotrExperiment
from lotr.pca import (
    fictive_heading_and_fit,
    fit_phase_neurons,
    get_fictive_heading,
    pca_and_phase,
    qap_sorting_and_phase,
)
from lotr.plotting import COLS
from lotr.utils import zscore

In [None]:
plt.close("all")

# Experiment folder.
# NOTE(review): absolute path on a shared volume; consider a config cell.
path = Path(
    "/Volumes/Shared/experiments/E0040_motions_cardinal/v26/220420_f2_2amp_2per"
)
suite2p_file = path / "data_from_suite2p_unfiltered.h5"

# Detrended activity traces (NaNs are replaced by 0 further down).
traces = fl.load(path / "filtered_traces.h5", "/detr")

# Per-ROI correlations: traces vs. motor regressor, and d(traces)/dt vs.
# regressor (as labelled on the selection plot axes).
reg_df = fl.load(path / "motor_regressors.h5")
cc_motor = reg_df["all_bias_abs"].values
cc_motor_integr = reg_df["all_bias_abs_dfdt"].values

coords = fl.load(suite2p_file, "/coords")
anat = fl.load(suite2p_file, "/anatomy_stack")
np.copyto(traces, 0, where=np.isnan(traces))

df = fl.load(path / "bouts_df.h5")  # exp.get_bout_properties()
exp = LotrExperiment(path)
# Integer copy of exp.fn, used for sample-count arithmetic (presumably the
# imaging frame rate; behavior times are multiplied by it to get samples).
fn = int(exp.fn)
beh_df = exp.behavior_log

t_start_s = 150
# (start, stop) windows per protocol, keyed by the last "_" token of the
# folder name (the t_lims cell has this lookup commented out).
# NOTE(review): bounds built from exp.fn are not cast to int.
skip_both_ends = (t_start_s * exp.fn, exp.n_pts - t_start_s * exp.fn)
until_2000_s = (t_start_s * exp.fn, 2000 * exp.fn)
time_slices_dict = {
    "natmov": skip_both_ends,
    "clol": skip_both_ends,
    "jumps": skip_both_ends,
    "cl": skip_both_ends,
    "cwccw": (500, exp.n_pts // 2),
    "2dvr": until_2000_s,
    "f6": until_2000_s,
    "2d": until_2000_s,
    "spont": skip_both_ends,
    "gainmod": skip_both_ends,
}

In [None]:
# Analysis window from t_start_s to 2000 s, expressed in samples.
# The bounds are cast to int because exp.fn is not guaranteed to be an integer
# (the loading cell keeps int(exp.fn) separately), and non-integer slice
# bounds make NumPy indexing with `t_slice` raise TypeError.
# Alternative: t_lims = time_slices_dict[path.name.split("_")[-1]]
t_lims = (
    int(t_start_s * exp.fn),
    int(2000 * exp.fn),
)
t_slice = slice(*t_lims)

In [None]:
# Distribution of bout bias (histogram) with each bout's `peak_vig` overlaid.
fig_bias, ax_bias = plt.subplots(figsize=(5, 2))
ax_bias.hist(df["bias"], 10, density=True, zorder=-100)
ax_bias.scatter(df["bias"], df["peak_vig"], 100, c=cols[1])
ax_bias.set_ylim(0.0, 3)

In [None]:
# Column bounds (image pixels) used by the next cell to exclude ROIs.
s1, s2 = 100, 270

# Mean anatomy projection with the two bounds drawn as vertical lines.
plt.figure(figsize=(3, 3))
plt.imshow(anat.mean(0), vmin=0, vmax=100)
plt.axvline(s1)
plt.axvline(s2)

In [None]:
# Exclude ROIs whose coords[:, 2] falls outside [s1, s2]: zero their traces and
# set their regressor correlations to NaN (NaN fails every threshold test).
sel_to_nan = np.logical_or(coords[:, 2] < s1, coords[:, 2] > s2)
traces[:, sel_to_nan] = 0
for cc_arr in (cc_motor, cc_motor_integr):
    cc_arr[sel_to_nan] = np.nan

In [None]:
# Pairwise ROI correlation, averaged over consecutive windows of
# `cc_wnd * fn` samples spanning `t_slice`.
cc_wnd = 2000
i_array = np.arange(t_slice.start, t_slice.stop, cc_wnd * fn)
cc_mats = np.zeros((traces.shape[1], traces.shape[1], len(i_array)))

# ROIs zeroed in the exclusion cell have zero variance, so np.corrcoef gives
# NaN for them and emits RuntimeWarnings (as do the nan-reductions below).
# Silence those here only, instead of disabling all warnings for the session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    for n, i in enumerate(i_array):
        cc_mats[:, :, n] = np.corrcoef(traces[i : i + cc_wnd * fn, :].T)
    corr_mat = np.nanmean(cc_mats, 2)
    # Most negative correlation of each ROI with any other ROI. It does not
    # depend on the sliders, so compute it once rather than on every callback.
    min_corr = np.nanmin(corr_mat, 0)

# corr_mat = np.corrcoef(traces[t_lims[0] : t_lims[1], :].T)

# Interactive ROI selection in (cc_motor, cc_motor_integr) space.
selection_arr = np.zeros(traces.shape[1])

f = plt.figure(figsize=(3, 3))
# Dedicated names: `update` looks these up at callback time, and later cells
# rebind `x` and `s` as loop variables.
x_line = np.arange(-0.2, np.nanmax(cc_motor), 0.05)
sel_scatter = plt.scatter(
    cc_motor, cc_motor_integr, s=10, c=selection_arr, vmin=0, vmax=1
)

l_plot = plt.plot(x_line, x_line * 0.2 + 0.15)
l_max = plt.axvline(1)
l_min = plt.axhline(0)


@widgets.interact(
    c=(0.05, 2, 0.05),
    o=(-0.5, 1, 0.02),
    mot_max=(0, 1, 0.05),
    integr_min=(0, 1, 0.02),
    max_corr=(-1, 0, 0.05),
)
def update(o=0.3, c=0.2, mot_max=1, integr_min=0, max_corr=-0.75):
    """Recompute the ROI selection from the slider values and redraw.

    An ROI is selected when |cc_motor| < mot_max and
    |cc_motor_integr| > integr_min and, in addition, either
    cc_motor_integr > c * cc_motor + o (above the plotted line) or its most
    negative correlation with another ROI is below max_corr.
    The result is written in place into the global `selection_arr` as 0/1.
    """
    l_plot[0].set_data(x_line, x_line * c + o)

    in_bounds = (np.abs(cc_motor) < mot_max) & (np.abs(cc_motor_integr) > integr_min)
    above_line = cc_motor_integr > cc_motor * c + o
    anticorrelated = min_corr < max_corr
    selection_arr[:] = in_bounds & (above_line | anticorrelated)

    # Line2D.set_xdata/set_ydata expect sequences; scalars are deprecated
    # since Matplotlib 3.7.
    l_max.set_xdata([mot_max, mot_max])
    l_min.set_ydata([integr_min, integr_min])

    sel_scatter.set_array(selection_arr)


plt.ylim(-0.15, 0.4)
plt.xlim(-0.3, 1.01)
plt.xlabel("cc. traces - motor regressor")
plt.ylabel("cc. d(traces)/dt - regressor")
sns.despine()

In [None]:
# Indices of ROIs currently selected by the widget (nonzero entries).
selected = np.flatnonzero(selection_arr)
print(selected.size)

In [None]:
# PCA with ROIs as samples: fit on the selected ROIs' activity in t_slice
# (transposed), then project the selected ROIs and, separately, all ROIs.
# selected = fl.load(path / "selected.h5")
pcaed_t, phase_t, _, _ = pca_and_phase(
    traces[t_slice, selected].T, traces[t_slice, selected].T
)
# Circle fit in PC space; entries 0, 1, 2 are used below as center x,
# center y and radius.
hf_c = hyper_fit(pcaed_t)
pcaed_t_all, _, _, _ = pca_and_phase(traces[t_slice, selected].T, traces[t_slice, :].T)

plt.figure(figsize=(7, 3))
# Selected ROIs farther than `thr` from the PC origin (used only for coloring).
thr = 35
sel = (pcaed_t[:, 0] ** 2 + pcaed_t[:, 1] ** 2) ** (1 / 2) > thr
plt.scatter(pcaed_t[:, 0], pcaed_t[:, 1], c=sel)
# All ROIs in the same PC space, as hollow markers.
plt.scatter(
    pcaed_t_all[:, 0], pcaed_t_all[:, 1], edgecolor="k", facecolor="none", lw=0.2
)
plt.axis("equal")

# selected = selected[sel]
# Time-domain PCA of the selected ROIs.
# NOTE(review): uses `selected` before it is refined at the end of this cell;
# `pcaed`/`phase` are recomputed or overwritten by later cells.
pcaed, phase, _, _ = pca_and_phase(traces[t_slice, selected], traces[:, selected])

# Outline of the fitted circle.
circle_theta = np.linspace(0, 2 * np.pi, 100)
x1 = hf_c[2] * np.cos(circle_theta) + hf_c[0]
x2 = hf_c[2] * np.sin(circle_theta) + hf_c[1]
plt.plot(x1, x2)

# Signed distance of every ROI from the fitted circle (negative = inside).
# Keep ROIs on the ring (within 20 units) or outside it.
ring_residual = (
    np.sqrt((pcaed_t_all[:, 0] - hf_c[0]) ** 2 + (pcaed_t_all[:, 1] - hf_c[1]) ** 2)
    - hf_c[2]
)
new_selection_arr = (np.abs(ring_residual) < 20) | (ring_residual > 0)
selected = np.flatnonzero(new_selection_arr)

In [None]:
# Selected ROI traces (offset by 4) over the tail signal, in samples.
fig_traces, ax_traces = plt.subplots(figsize=(7, 2.5))
ax_traces.plot(traces[:, selected] + 4)
print(len(selected))
ax_traces.plot(beh_df["t"] * fn, beh_df["tail_sum"])
plt.show()

# NOTE(review): the lines below switch to a different experiment (E0071 spont)
# and rebind `path`, `exp`, `perm`, `phase` and `selected`, while `traces`,
# `coords` and `df` still come from the first experiment. Later cells mix the
# two, so ROI indices may not correspond — consider a separate section.
path = Path(
    "/Volumes/Shared/experiments/E0071_lotr/full_ring/211118_f0/211118_f0b_spont"
)

exp = LotrExperiment(path, selected=None)
perm = np.argsort(exp.rpc_angles)
phase = exp.network_phase
selected = exp.hdn_indexes

In [None]:
# Replace the experiment's head-direction ROI set and angles with the manual
# selection, then order the ROIs by angle.
# NOTE(review): phase_t must have one entry per element of `selected`; it comes
# from a PCA cell run on a (possibly different) ROI set — verify lengths.
# Assigning `selected` here makes exp._hdn_indexes alias the same array.
exp._hdn_indexes = selected
exp._rpc_angles = phase_t
perm = np.argsort(phase_t)

In [None]:
# Phase-sorted correlation matrix (left) and raster (right) of selected ROIs.
l = 2  # symmetric gray-scale limit for the raster
f, axs = plt.subplots(1, 2, figsize=(7, 3), sharey=True)
if perm is not None:
    corr_sorted = np.corrcoef(exp.traces[t_slice, selected].T)[np.ix_(perm, perm)]
    axs[0].imshow(corr_sorted, vmax=1, vmin=-1, cmap="RdBu_r", aspect="auto")

    raster_sorted = exp.traces[:, selected[perm]].T
    axs[1].imshow(
        raster_sorted,
        cmap="gray_r",
        interpolation="none",
        aspect="auto",
        vmin=-l,
        vmax=l,
    )

In [None]:
# Manually drop ROIs from `selected`, given as positions in the phase-sorted
# order (`perm`). Fill in positions to exclude, e.g. [23, 64, 82].
rm_from_selected = np.array([], dtype=int)  # integer dtype: usable as an index

# Drop entries with a keep-mask rather than writing a -1 sentinel:
# - `selected` may be the same array object as exp._hdn_indexes, so in-place
#   writes would also change the experiment's indexes;
# - a `selected > 0` filter would also discard the valid ROI index 0.
keep = np.ones(len(selected), dtype=bool)
keep[perm[rm_from_selected]] = False
selected = selected[keep]

In [None]:
# Unwrap the network phase and fit a fictive heading trajectory to the bouts.
# NOTE(review): `phase`/`exp` come from the reloaded E0071 spont experiment,
# while `df` is the bout table of the first path — confirm they match.
unwrapped_phase = np.unwrap(phase)
traj, params = fictive_heading_and_fit(unwrapped_phase, df, min_bias=0.6)
print(params)

plt.figure(figsize=(7, 3))
# Bout bias at bout onset; t_start is scaled by exp.fs (assumed to map it onto
# the same sample axis as the phase — verify).
plt.scatter(df["t_start"] * exp.fs, df["bias"], s=2)
# z-scored unwrapped phase, colored by the wrapped phase.
sample_idx = np.arange(len(traj))
plt.scatter(sample_idx, zscore(unwrapped_phase), c=phase, cmap="twilight", s=2)
plt.plot(-zscore(traj), c=cols[1])

In [None]:
# All ROIs (gray) with head-direction ROIs colored by their rpc angle.
f, ax = plt.subplots(1, 1, figsize=(3, 3))

all_coords = exp.coords
hdn_coords = all_coords[exp.hdn_indexes]
ax.scatter(all_coords[:, 1], all_coords[:, 2], color=(0.5,) * 3)
ax.scatter(
    hdn_coords[:, 1],
    hdn_coords[:, 2],
    c=exp.rpc_angles,
    cmap=COLS["phase"],
)
ax.axis("equal")
ax.axis("off")

In [None]:
# Anatomical map (first experiment's coords): ROIs with coords[:, 0] > 0 in
# gray, manually selected ROIs colored by angle.
# plt.subplots(1, 1) returns a single Axes (not an array), so use it directly.
f, ax = plt.subplots(1, 1, figsize=(3, 3))
s = coords[:, 0] > 0
selection = np.full(coords.shape[0], False)
selection[selected] = True

ax.scatter(coords[s, 1], coords[s, 2], c=(0.5,) * 3)
ax.scatter(
    coords[s, :][selection[s], 1],
    coords[s, :][selection[s], 2],
    # NOTE(review): points follow ROI-index order (boolean mask) while
    # exp.rpc_angles follows the order of `selected`; these agree only if
    # `selected` is sorted and lies entirely within `s` — verify.
    c=exp.rpc_angles,
    cmap="twilight",
)
ax.axis("equal")
ax.axis("off")

In [None]:
# ROI position vs. angular order. `perm` is an argsort, so np.argsort(perm)
# gives each ROI's rank; ranks are mapped onto evenly spaced angles.
sel = coords[selected, 2] < 167  # NOTE(review): hard-coded coords[:, 2] cut
rank_angle = np.linspace(-np.pi, np.pi, perm.max() + 1)[np.argsort(perm)]

plt.figure(figsize=(8, 3.5))
plt.subplot(221)
plt.scatter(coords[selected, 1], rank_angle)
plt.subplot(222)
# NOTE(review): this plots `perm` itself (sorting indices), not per-ROI ranks;
# np.argsort(perm) may have been intended here and in panel 224.
plt.scatter(coords[selected, 1], perm)

plt.subplot(223)
plt.scatter(coords[selected[sel], 1], rank_angle[sel])
plt.subplot(224)
plt.scatter(coords[selected[sel], 1], perm[sel])
print(perm[sel])

In [None]:
# Persist the current ROI selection next to the data.
# NOTE(review): `path` was rebound to the E0071 spont experiment by an earlier
# cell; an existing selected.h5 there is overwritten without warning.
selected_file = path / "selected.h5"
fl.save(selected_file, selected)

In [None]:
# Time-domain PCA of the selected ROIs (fit on t_slice, projected on all
# samples) and the PC1/PC2 trajectory colored by phase, in three windows.
pcaed, phase, _, _ = pca_and_phase(traces[t_slice, selected], traces[:, selected])
# Second half of the recording.
mot_t_slice = slice(traces.shape[0] // 2, traces.shape[0])
f, axs = plt.subplots(1, 3, figsize=(9.0, 4.0), sharex=True, sharey=True)

# NOTE(review): panels 1 and 3 show the same window (t_slice); confirm the
# third panel was not meant to show a different one.
panel_slices = [t_slice, mot_t_slice, t_slice]
for i, s in enumerate(panel_slices):
    ax = axs[i]
    ax.plot(pcaed[s, 0], pcaed[s, 1], c=(0.6,) * 3, lw=0.5, zorder=-100)
    ax.scatter(pcaed[s, 0], pcaed[s, 1], c=phase[s], lw=0.5, s=5, cmap="twilight")
sns.despine()

In [None]:
from sklearn.decomposition import PCA

# PCA with ROIs as samples (traces transposed), circle fit in the PC plane,
# and the phase of each ROI around the fitted center.
comp0, comp1 = 0, 1  # PC components forming the plane

# NOTE(review): fixed sample window 2000:8000, independent of t_slice.
traces_fit = traces[2000:8000, selected].T
traces_transform = traces_fit

# Compute PCA and transform traces:
pca = PCA(n_components=5).fit(traces_fit)
pcaed_t = pca.transform(traces_transform)

# Fit the circle in the same PC space (`pcaed_t`) whose phase is computed below.
hf_c = hyper_fit(pcaed_t[:, [comp0, comp1]])

# Phase of each ROI after subtracting the circle center.
phase_t = np.angle(
    (pcaed_t[:, comp0] - hf_c[0]) + 1j * (pcaed_t[:, comp1] - hf_c[1])
)

plt.figure(figsize=(7, 3))
plt.scatter(pcaed_t[:, comp0], pcaed_t[:, comp1], c=phase_t, cmap="twilight")
plt.axis("equal")

In [None]:
# Bout-triggered PC1/PC2 trajectories. Each segment is translated to start at
# the origin and rotated by its displacement after `rot_wnd_s` seconds, for
# left bouts, right bouts and shuffled triggers.
# NOTE(review): idx_l, idx_r, random_trig and add_cbar are not defined in this
# notebook; they must come from a deleted or external cell.
plot_t_s = 10
plot_t_pts = int(plot_t_s * fn)

rot_wnd_s = 1
rot_wnd_pts = int(rot_wnd_s * fn)

f, axs = plt.subplots(1, 3, figsize=(8, 3), sharex=True, sharey=True)

cbars = []
# Loop names avoid `x`/`l`, which are used as globals by other cells.
for i_panel, ax, idx_list, cmap, title in zip(
    np.arange(3),
    axs,
    [idx_l, idx_r, random_trig],
    ["Blues", "Reds", "gray_r"],
    ["Left bouts", "Right bouts", "Shuffle"],
):
    for i_start in idx_list:
        try:
            crop_seg = pcaed[i_start : i_start + plot_t_pts, :2]
            crop_seg = crop_seg - crop_seg[0, :]
            # Row 0 is now the origin, so this is the angle of the displacement
            # after rot_wnd_pts samples (arctan2(x, y) argument order kept).
            th = np.arctan2(crop_seg[rot_wnd_pts, 0], crop_seg[rot_wnd_pts, 1])

            rot_mat = np.array([[np.cos(th), np.sin(th)], [-np.sin(th), np.cos(th)]]).T
            crop_seg = (rot_mat @ crop_seg.T).T

            decimated = crop_seg[::3]
            ax.plot(
                decimated[:, 0], decimated[:, 1], c=(0.4,) * 3, lw=0.3, zorder=-100
            )
            # Keep the mappable so the colorbar below uses this panel's cmap.
            cp = ax.scatter(
                decimated[:, 0],
                decimated[:, 1],
                c=np.arange(len(decimated)),
                cmap=cmap,
                s=1,
            )
        except IndexError:
            # Segment too short for the rotation window (or empty): skip it.
            pass
    cbars.append(
        add_cbar(
            (0.93, 0.8 + 0.026 * i_panel, 0.06, 0.023),
            cp,
            label="time (s)" if i_panel == 2 else "",
            ticks=[],
            orientation="horizontal",
        )
    )
    ax.set_title(title)
    ax.axvline(0, lw=0.5, c=(0.4,) * 3)
    ax.axhline(0, lw=0.5, c=(0.4,) * 3)

plt.tight_layout()
sns.despine()

# NOTE(review): tick positions assume a particular color scale of `cp`;
# check they span the colorbar.
cbars[0].set_ticks([0, 0.95])
cbars[0].set_ticklabels([0, plot_t_s])

In [None]:
# Bout-triggered PC1/PC2 trajectories (not re-centered) for left bouts, right
# bouts and shuffled triggers, colored by time since the trigger.
# NOTE(review): idx_l, idx_r, random_trig and add_cbar are not defined in this
# notebook; they must come from a deleted or external cell.
plt.close("all")
plot_t_s = 8
plot_t_pts = int(plot_t_s * fn)

f, axs = plt.subplots(1, 3, figsize=(8, 3), sharex=True, sharey=True)
for ax in axs:
    ax.set_xlabel("PC 1")
axs[0].set_ylabel("PC 2")

cbars = []
# Loop names avoid `x`/`l`, which are used as globals by other cells.
for i_panel, ax, idx_list, cmap, title in zip(
    np.arange(3),
    axs,
    [idx_l, idx_r, random_trig],
    ["Blues", "Reds", "gray_r"],
    ["Left bouts", "Right bouts", "Shuffle"],
):
    for i_start in idx_list:
        try:
            decimated = pcaed[i_start : i_start + plot_t_pts, :2][::3]
            ax.plot(
                decimated[:, 0], decimated[:, 1], c=(0.4,) * 3, lw=0.3, zorder=-100
            )
            # Keep the mappable so the colorbar below uses this panel's cmap.
            cp = ax.scatter(
                decimated[:, 0],
                decimated[:, 1],
                c=np.arange(len(decimated)),
                cmap=cmap,
                s=1,
            )
        except (IndexError, ValueError):
            # Best effort: skip triggers whose segment cannot be plotted.
            pass
    cbars.append(
        add_cbar(
            (0.92, 0.8 + 0.026 * i_panel, 0.06, 0.023),
            cp,
            label="time (s)" if i_panel == 2 else "",
            ticks=[],
            orientation="horizontal",
        )
    )
    ax.set_title(title)

plt.tight_layout()
sns.despine()

# NOTE(review): tick positions assume a particular color scale of `cp`;
# check they span the colorbar.
cbars[0].set_ticks([0, 0.8])
cbars[0].set_ticklabels([0, plot_t_s])

In [None]:
# Display the currently active experiment path (in notebook order it was
# rebound to the E0071 spont experiment by an earlier cell).
path

In [None]:
# Persist the (possibly edited) ROI selection next to the data.
# NOTE(review): overwrites an existing selected.h5 in `path` without warning.
selected_file = path / "selected.h5"
fl.save(selected_file, selected)

In [None]:
# Count fish folders under master_path that already contain a saved ROI
# selection (selected.h5).
master_path = Path(
    "/Volumes/Shared/experiments/E0040_motions_cardinal_old/v15_playback"
)
fish_pattern = "[0-9]*_f[0-9]*"
all_list = [fish_dir for fish_dir in master_path.glob(fish_pattern)]
all_valid = [sel_file for sel_file in master_path.glob(fish_pattern + "/selected.h5")]
n_valid, n_total = len(all_valid), len(all_list)
print(f"{n_valid}/{n_total}")

In [None]:
# Bout-triggered PC1/PC2 trajectories drawn with color_plot (time-colored
# lines) for left bouts, right bouts and shuffled triggers.
# NOTE(review): idx_l, idx_r, random_trig, color_plot and add_cbar are not
# defined or imported in this notebook; they must come from elsewhere.
plt.close("all")
plot_t_s = 8
plot_t_pts = int(plot_t_s * fn)

f, axs = plt.subplots(1, 3, figsize=(8, 3), sharex=True, sharey=True)
for ax in axs:
    ax.set_xlabel("PC 1")
axs[0].set_ylabel("PC 2")

cbars = []
# Loop names avoid `x`/`l`, which are used as globals by other cells.
for i_panel, ax, idx_list, cmap, title in zip(
    np.arange(3),
    axs,
    [idx_l, idx_r, random_trig],
    ["Blues", "Reds", "gray_r"],
    ["Left bouts", "Right bouts", "Shuffle"],
):
    for i_start in idx_list:
        try:
            crop_seg = pcaed[i_start : i_start + plot_t_pts, :2]
            cp = color_plot(crop_seg[:, 0], crop_seg[:, 1], ax=ax, cmap=cmap, lw=1)
        except IndexError:
            # Best effort: skip triggers whose segment cannot be plotted.
            pass
    cbars.append(
        add_cbar(
            (0.92, 0.8 + 0.026 * i_panel, 0.06, 0.023),
            cp,
            label="time (s)" if i_panel == 2 else "",
            ticks=[],
            orientation="horizontal",
        )
    )
    ax.set_title(title)

plt.tight_layout()
sns.despine()

cbars[0].set_ticks([0, 0.8])
cbars[0].set_ticklabels([0, plot_t_s])

In [None]:
# Load a raw acquisition file (0000.h5) for inspection.
raw_file = r"/Volumes/Shared/experiments/E0040_motions_cardinal/batch211204/tofix/211207_f0beyes_natomov/original/0000.h5"
data = fl.load(raw_file)

In [None]:
# NOTE(review): reloads the same 0000.h5 file as the previous cell; this cell
# looks redundant.
raw_file = r"/Volumes/Shared/experiments/E0040_motions_cardinal/batch211204/tofix/211207_f0beyes_natomov/original/0000.h5"
data = fl.load(raw_file)

In [None]:
# Display the shape of the "stack_4D" array in the loaded file.
data["stack_4D"].shape

In [None]:
# Load another raw acquisition file (0037.h5) from the same folder.
raw_file = r"/Volumes/Shared/experiments/E0040_motions_cardinal/batch211204/tofix/211207_f0beyes_natomov/original/0037.h5"
data = fl.load(raw_file)

In [None]:
# Display the shape of the "stack_4D" array in the loaded file.
data["stack_4D"].shape

In [None]:
# NOTE(review): scratch arithmetic left in the notebook (236 * 38 = 8968); the
# meaning of the two factors is not recorded here.
236 * 38