In [None]:
%matplotlib widget
import os
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

from lotr.experiment_class import LotrExperiment
from bouter.utilities import crop

In [None]:
# Root folder of the full-ring dataset. Override with the LOTR_FULL_RING_DIR
# environment variable so the notebook runs on machines other than the author's.
master_path = Path(
    os.environ.get(
        "LOTR_FULL_RING_DIR",
        "/Users/luigipetrucco/Desktop/all_source_data/full_ring",
    )
)
# Entries one level down named like "<...digit>_f<...>" (e.g. 210314_f1/210314_f1_natmov).
# Sorted, because Path.glob yields entries in filesystem-dependent order.
file_list = sorted(master_path.glob("*/*[0-9]_f*"))

In [None]:
# Load the single example fish recording analysed in the rest of the notebook.
fish_dir = master_path / "210314_f1"
path = fish_dir / "210314_f1_natmov"
exp = LotrExperiment(path)

In [None]:
traces = exp.traces
bouts_df = exp.bouts_df
regr_df = exp.motor_regressors

# Thresholds for calling a cell direction-selective from its motor-regressor scores
# (metric as defined by LotrExperiment.motor_regressors — presumably a fit/correlation score).
THR = 0.0  # minimum score on the preferred side
OFF = 0.4  # required margin of preferred-side score over the other side
directions = "lf", "rt"

right_score = regr_df["right_1"]
left_score = regr_df["left_1"]

# Boolean masks over cells, keyed by preferred bout direction (rt first, then lf).
bout_sel = {
    "rt": (right_score > left_score + OFF) & (right_score > THR),
    "lf": (left_score > right_score + OFF) & (left_score > THR),
}

In [None]:
# Left panel: every cell in motor-regressor space (right_1 vs left_1 score).
# Right panel: the same cells at their exp.coords positions (columns 1 and 2).
# Grey = all cells; coloured = cells in each `bout_sel` group.
f, axs = plt.subplots(1, 2, figsize=(8, 3))

panels = [
    (regr_df["right_1"], regr_df["left_1"], "right_1 score", "left_1 score"),
    (exp.coords[:, 1], exp.coords[:, 2], "coords[:, 1]", "coords[:, 2]"),
]
for ax, (x, y, x_label, y_label) in zip(axs, panels):
    # Use `color=` for a single RGB tuple: passed to `c=`, matplotlib cannot tell it
    # apart from three values to colour-map.
    ax.scatter(x, y, color=(0.6,) * 3)
    for k, sel in bout_sel.items():
        ax.scatter(x[sel], y[sel], label=k)
    ax.set(xlabel=x_label, ylabel=y_label)
    ax.axis("equal")

# Identity line: points above it have left_1 > right_1. The selection additionally
# requires a margin of OFF, which is not drawn here.
axs[0].plot([0, 0.5], [0, 0.5], "k")
axs[0].legend(title="selective for");

In [None]:
# Peri-bout window around each bout, in seconds before / after the event.
pre_wnd_s = 10
post_wnd_s = 20

# Same windows in samples; exp.fn is presumably the imaging frame rate (Hz).
# Computed once so the crop window and the baseline window are guaranteed identical.
pre_int = int(pre_wnd_s * exp.fn)
post_int = int(post_wnd_s * exp.fn)

# traces_resps[direction]: (time, cells) average response to bouts in that direction,
# baseline-subtracted per cell.
traces_resps = dict()
for direction in directions:
    # Imaging-frame index of every bout in this direction.
    idx = exp.bouts_df.loc[exp.bouts_df["direction"] == direction, "idx_imaging"]
    # Assumes crop returns (time, bout, cell) — TODO confirm against bouter.
    cropped = crop(exp.traces, idx, pre_int=pre_int, post_int=post_int)
    # Average across bouts, ignoring NaNs.
    cropped = np.nanmean(cropped, 1)

    # Subtract each cell's mean over the pre-bout window.
    cropped = cropped - np.nanmean(cropped[:pre_int, :], 0)

    traces_resps[direction] = cropped

In [None]:
# Cell ordering for the raster plots: each selectivity group (a `bout_sel` key) is
# ranked by its mean response over frames 50-79 to bouts in the *opposite*
# direction (rt-selective cells by lf bouts, and vice versa).
# NOTE(review): confirm the cross-pairing is intended. The 50:80 window is hard-coded;
# bout onset is presumably at frame int(pre_wnd_s * exp.fn), so where this window
# falls relative to the bout depends on exp.fn.
sortings = dict()
for bout_dir, group in zip(directions, reversed(directions)):
    window_mean = traces_resps[bout_dir][50:80, bout_sel[group]].mean(axis=0)
    sortings[group] = np.argsort(window_mean)

In [None]:
# Peri-bout response rasters. Columns = bout direction, rows = selectivity group.
# Each row keeps the same cell order (sortings[sel]) in both columns, so the same
# cells can be compared across bout directions.
f, axs = plt.subplots(2, 2, figsize=(8, 3))
for i, d in enumerate(directions):
    for j, sel in enumerate(directions):
        ax = axs[j, i]
        ax.imshow(traces_resps[d][:, bout_sel[sel]][:, sortings[sel]].T)
        ax.set_title(f"{d} bouts, {sel}-selective cells", fontsize="small")
        if i == 0:
            ax.set_ylabel("cell")
        if j == len(directions) - 1:
            ax.set_xlabel("frame")
f.tight_layout()

In [None]:
# Anatomical map: all cells in grey; the cells selective for each direction are
# coloured by their mean response (frames 50-79) to bouts in that same direction.
# NOTE(review): the sortings cell pairs each group with the *opposite* bout
# direction, while this plot uses the same one — confirm which is intended.
f, axs = plt.subplots(1, 2, figsize=(8, 3))
for ax, d in zip(axs, directions):
    # `color=` (not `c=`) for a single grey RGB tuple, to avoid colour-map ambiguity.
    ax.scatter(exp.coords[:, 1], exp.coords[:, 2], color=(0.6,) * 3)

    sel_cells = bout_sel[d]
    resp = traces_resps[d][50:80, sel_cells].mean(0)
    pts = ax.scatter(exp.coords[sel_cells, 1], exp.coords[sel_cells, 2], c=resp)
    f.colorbar(pts, ax=ax, label="mean response (frames 50-79)")

    ax.set_title(f"{d}-selective cells, {d} bouts")
    ax.axis("equal")