## Motor-selective neurons vs. heading direction neurons
In this notebook, we compare the activity and location of neurons that are selective for left or right swim bouts (motor-selective neurons) with the activity of heading-direction-selective neurons (HDNs).

#### TODO
 - [ ] Make a map including all fish after morphing.

In [None]:
%matplotlib widget
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

from lotr import DATASET_LOCATION, LotrExperiment
from bouter.utilities import crop

In [None]:
# Root dataset folder; `file_list` collects every experiment folder whose name
# matches the "<...digit>_f<...>" glob pattern one level below the root.
master_path = Path(DATASET_LOCATION)
file_list = [exp_folder for exp_folder in master_path.glob("*/*[0-9]_f*")]

In [None]:
# Single example experiment analysed throughout this notebook.
EXAMPLE_FISH = "210314_f1"
path = master_path / EXAMPLE_FISH / f"{EXAMPLE_FISH}_natmov"
exp = LotrExperiment(path)

In [None]:
# Per-neuron correlations with the motor regressors ("left_1" / "right_1" columns).
bouts_df = exp.bouts_df
regr_df = exp.motor_regressors

THR = 0.4   # minimum correlation with the preferred-side regressor
OFF = 0.45  # minimum margin by which the preferred side must beat the opposite side
directions = ("lf", "rt")


def _side_selective(regressors, preferred, opposite, thr=THR, off=OFF):
    """Boolean mask of neurons whose `preferred` regressor column exceeds
    `thr` and also exceeds the `opposite` column by more than `off`."""
    pref = regressors[preferred]
    return (pref > (regressors[opposite] + off)) & (pref > thr)


bout_sel = dict(rt=_side_selective(regr_df, "right_1", "left_1"),
                lf=_side_selective(regr_df, "left_1", "right_1"))

In [None]:
# Left panel: each neuron's correlation with the right (x) vs. left (y) bout
# regressors. Right panel: the same neurons at exp.coords columns 1 and 2
# (presumably two anatomical axes — TODO confirm).
# Grey = all neurons, outlined = HDNs, coloured = side-selective neurons.
fig, axs = plt.subplots(1, 2, figsize=(6, 3))
marker_size = 15
grey_light, grey_dark = (0.8,) * 3, (0.3,) * 3
# Regressor columns are converted to arrays so that `exp.hdn_indexes` selects
# positionally, as it does on `exp.coords` (a Series would index by label).
panels = [[regr_df["right_1"].to_numpy(), regr_df["left_1"].to_numpy()],
          [exp.coords[:, 1], exp.coords[:, 2]]]
for ax, (x, y) in zip(axs, panels):
    ax.scatter(x, y, c=grey_light, s=marker_size)
    ax.scatter(x[exp.hdn_indexes], y[exp.hdn_indexes], fc=grey_light, ec=grey_dark,
               lw=0.5, s=marker_size, label="HDNs")

    for k, sel in bout_sel.items():
        ax.scatter(x[sel], y[sel], s=marker_size, label=k + "_sel")
    ax.axis("equal")

axs[0].plot([0, 0.5], [0, 0.5], "k", lw=0.5)  # identity line
axs[0].legend(frameon=False, fontsize=7)
# x holds right_1 and y holds left_1 (see `panels`), so label them in that order.
axs[0].set(xlabel="Right bouts corr.", ylabel="Left bouts corr.")
axs[1].axis("off")
fig.tight_layout()

In [None]:
# Inspect the indices of the heading-direction neurons (HDNs) in this experiment.
exp.hdn_indexes

In [None]:
pre_wnd_s = 10   # seconds before each bout included in the crop
post_wnd_s = 20  # seconds after each bout included in the crop

# Window sizes in imaging frames (exp.fn is used as the frame rate).
n_pre_pts = int(pre_wnd_s * exp.fn)
n_post_pts = int(post_wnd_s * exp.fn)

traces_resps = dict()
for direction in directions:
    # Imaging-frame index of every bout in this direction.
    bout_frames = exp.bouts_df.loc[exp.bouts_df["direction"] == direction, "idx_imaging"]
    # Crop all traces around each bout, then average over axis 1
    # (presumably the bout axis — TODO confirm against `crop`).
    bout_avg = np.nanmean(crop(exp.raw_traces, bout_frames,
                               pre_int=n_pre_pts, post_int=n_post_pts), 1)
    # Baseline-subtract each neuron using the mean of the pre-bout window.
    traces_resps[direction] = bout_avg - np.nanmean(bout_avg[:n_pre_pts, :], 0)

In [None]:
# Order each selective population by its mean response in frames 50:80 of the
# cropped window to bouts of the *opposite* direction
# (rt-selective neurons <- lf bouts, lf-selective neurons <- rt bouts).
# NOTE(review): frames 50:80 are hard-coded; they start at bout onset only if
# int(pre_wnd_s * exp.fn) == 50 — TODO confirm.
sortings = {}
for d, sel in zip(directions, reversed(directions)):
    mean_resp = traces_resps[d][50:80, bout_sel[sel]].mean(0)
    sortings[sel] = np.argsort(mean_resp)

In [None]:
# Grid of mean bout-triggered traces: rows = selective population,
# columns = bout direction; neurons are ordered by `sortings`.
fig, axs = plt.subplots(2, 2, figsize=(8, 3))
for col, d in enumerate(directions):
    for row, sel in enumerate(directions):
        ax = axs[row, col]
        ax.plot(traces_resps[d][:, bout_sel[sel]][:, sortings[sel]])
        ax.set_title(f"{sel}-sel. neurons, {d} bouts", fontsize=8)
for ax in axs[-1, :]:
    ax.set_xlabel("Frame (cropped window)")
for ax in axs[:, 0]:
    ax.set_ylabel("Baseline-sub. act.")
fig.tight_layout()

In [None]:
# Map of exp.coords columns 1 and 2: all neurons in grey, each direction's
# selective neurons coloured by their mean response in frames 80:100 of the
# cropped window to bouts of that same direction.
# NOTE(review): the original loop also paired each direction with the opposite
# one but never used it — confirm whether opposite-direction neurons were intended.
fig, axs = plt.subplots(1, 2, figsize=(8, 3))
for ax, d in zip(axs, directions):
    ax.scatter(exp.coords[:, 1], exp.coords[:, 2], c=(0.6,) * 3)
    points = ax.scatter(exp.coords[bout_sel[d], 1], exp.coords[bout_sel[d], 2],
                        c=traces_resps[d][80:100, bout_sel[d]].mean(0))
    fig.colorbar(points, ax=ax, label=f"Mean resp. to {d} bouts")
    ax.set_title(f"{d}-selective neurons")
    ax.axis("equal")

In [None]:
from lotr.pca import pca_and_phase

In [None]:
# Run `pca_and_phase` on the traces of all left- or right-selective neurons;
# the third returned value is discarded.
motor_sel_mask = bout_sel["lf"] | bout_sel["rt"]
pcaed, phase, _ = pca_and_phase(exp.raw_traces[:, motor_sel_mask])

In [None]:
# Trajectory of the motor-selective population in its first two output
# dimensions of `pca_and_phase` (presumably PC 1 and PC 2).
fig, ax = plt.subplots()
ax.plot(pcaed[:, 0], pcaed[:, 1])
ax.set(xlabel="PC 1", ylabel="PC 2", title="Motor-selective neurons: PCA trajectory");