In [None]:
%matplotlib widget
%load_ext autoreload
from ipywidgets import interact, interact_manual

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
import flammkuchen as fl
import local_utils
from bouter import utilities, decorators, bout_stats
from bouter.angles import reduce_to_pi
from scipy.interpolate import interp1d
from tqdm import tqdm


from matplotlib import pyplot as plt
import seaborn as sns
sns.set(style="ticks", palette="deep")
cols = sns.color_palette()

In [None]:
# Root folder of the dataset.
# NOTE(review): absolute, machine-specific path - adjust when running elsewhere.
master_path = Path(r"/Users/luigipetrucco/Google Drive/data/ECs_E50")

# Load the experiment, trial, cell, bout and trace tables (in this order)
# from the .h5 files found in the dataset folder:
exp_df, trials_df, cells_df, bouts_df, traces_df = (
    fl.load(master_path / fname)
    for fname in (
        "exp_df.h5",
        "trials_df.h5",
        "cells_df.h5",
        "bouts_df.h5",
        "traces_df.h5",
    )
)

In [None]:
min_dist_s = 2  # Minimum distance between 2 bouts for inclusion, secs

# Per-bout boolean flags used by the selections below:
#  - mindist_included: both "after_interbout" and "inter_bout" exceed min_dist_s
#  - g0 / g1: bouts with negative base_vel and gain 0 / gain 1
#  - spont: bouts with base_vel > -10
bouts_df["mindist_included"] = (bouts_df["after_interbout"] > min_dist_s) & (bouts_df["inter_bout"] > min_dist_s)
bouts_df["g0"] = (bouts_df["base_vel"] < 0) & (bouts_df["gain"] == 0)
bouts_df["g1"] = (bouts_df["base_vel"] < 0) & (bouts_df["gain"] == 1)
bouts_df["spont"] = bouts_df["base_vel"] > -10

##############################################
# Match bouts by duration & temporal proximity
# Here we select, for each fish, a subset of bouts in closed and open loop
# that had similar duration and occurred reasonably close to each other
# in the experiment, to make sure we can compare responses with and w/o visual
# reafference.

# Maximum difference in "duration" for two bouts to be matched
# (same units as the "duration" column):
bout_length_similarity_thr = 0.05
# Maximum difference in "t_start" for two bouts to be matched
# (same units as the "t_start" column - presumably seconds, TODO confirm):
bout_max_timedistance = 600


# Reset the flag so that re-running the cell always starts from scratch:
bouts_df["matched"] = False
for fid in tqdm(exp_df.index):
    # Bouts of this fish that are well separated from others and not spontaneous:
    common_sel = (bouts_df["fid"] == fid) & bouts_df["mindist_included"] & ~bouts_df["spont"]
    for b in bouts_df.loc[(bouts_df["gain"] == 1) & common_sel].index:

        # Time distance of every bout from the current gain-1 bout. It is computed
        # on the whole dataframe so that all the masks combined below share the
        # same index (bouts from other fish are excluded by common_sel anyway):
        time_distances = np.abs(bouts_df["t_start"] - bouts_df.loc[b, "t_start"])

        # Candidate bouts to match: gain 0, not matched yet,
        # with minimum spacing from other bouts, and not too far in time:
        selection = (bouts_df["gain"] == 0) & ~bouts_df["matched"] \
                    & (time_distances < bout_max_timedistance) & common_sel

        # Calculate all duration differences
        diffs = np.abs(bouts_df.loc[selection, "duration"] - bouts_df.loc[b, "duration"])

        # If we have a valid candidate, match the one closest in duration
        # (on ties, idxmin picks the first one in dataframe order):
        if not diffs.empty and diffs.min() < bout_length_similarity_thr:
            bouts_df.loc[diffs.idxmin(), "matched"] = True
            bouts_df.loc[b, "matched"] = True

In [None]:
# Per-fish column sums over the bouts that are both matched and gain-0 (g0):
matched_g0_sel = bouts_df["matched"] & bouts_df["g0"]
bouts_df.loc[matched_g0_sel].groupby("fid").sum()

# Test distributions of reliability indexes

In [None]:
# Analysis parameters:
dt = 0.2  # dt of the imaging #TODO have this in exp dictionary
pre_int_s = 2  # time before bout for the crop, secs
post_int_s = 4  # time after the bout for the crop, secs
amplitude_percent = 90  # percentile for the calculation of the response amplitude

In [None]:
# Take the last fish of the experiment table and pull out the rows of its cells
# together with a copy of the corresponding trace columns:
fid = exp_df.index[-1]
fish_cells_mask = cells_df["fid"] == fid
cells_fsel = cells_df.loc[fish_cells_mask, :]
traces = traces_df.loc[:, cells_fsel.index].copy()

In [None]:
# Scratch check: draw one random row index of `traces` (value is only displayed,
# nothing is stored).
np.random.randint(traces.shape[0])

In [None]:
# Reliability scores obtained when cropping around random time points instead
# of real bouts, as a function of the number of cropped events.
fid = exp_df.index[-1]
cells_fsel = cells_df.loc[cells_df["fid"] == fid, :]
traces = traces_df.loc[:, cells_fsel.index].copy()

n_reps = 20  # independent repetitions of the whole test series
n_tests = 20  # test i uses (i + 1) * 10 random events

# Local, seeded generator: the shuffle is reproducible across kernel restarts
# and does not touch the global numpy random state.
shuffle_seed = 0
shuffle_rng = np.random.RandomState(shuffle_seed)

# Crop window in samples (loop invariant):
crop_pre_int = int(pre_int_s / dt)
crop_post_int = int(post_int_s / dt)

# all_shuf[j, i, c]: reliability of cell c in repetition j of test i
all_shuf = np.zeros((n_reps, n_tests, traces.shape[1]))
for j in tqdm(range(n_reps)):
    for i in range(n_tests):
        n_events = (i + 1) * 10

        # Random "event" positions along the time axis of the traces:
        sel_start_idxs = shuffle_rng.randint(0, traces.shape[0], n_events)

        # Crop cell responses around the random events:
        cropped = utilities.crop(traces,
                                 sel_start_idxs,
                                 pre_int=crop_pre_int,
                                 post_int=crop_post_int)

        all_shuf[j, i, :] = utilities.reliability(cropped)

In [None]:
# Histogram of the shuffled reliability scores for each test (i.e. for each
# number of random events), pooling all repetitions and all cells.
bins = np.arange(-0.05, 1, 0.01)
all_hist = np.zeros((n_tests, len(bins) - 1))
for i in range(n_tests):
    # all_shuf is (n_reps, n_tests, n_cells): the test index is on axis 1,
    # so select it there and pool over repetitions and cells.
    counts, _ = np.histogram(all_shuf[:, i, :], bins)
    all_hist[i, :] = counts


In [None]:
# Median and 5-95 percentile band (across cells) of the shuffled reliability,
# for up to the first 10 repetitions, against the number of random events.
plt.figure(figsize=(4,3))

# Test i was run with (i + 1) * 10 random events, so the x axis starts at 10:
n_shuffled = (np.arange(n_tests) + 1) * 10

for i in range(min(10, all_shuf.shape[0])):
    # Percentiles over the cell axis for repetition i, one value per test:
    p05, p50, p95 = np.percentile(all_shuf[i, :, :], [5, 50, 95], axis=1)
    plt.fill_between(n_shuffled, p05, p95, alpha=0.1)
    plt.plot(n_shuffled, p50, c=cols[i%10])
plt.xlabel("number of shuffled responses")
plt.ylabel("Reliability score")
sns.despine()
plt.tight_layout()

In [None]:
# Display the histogram counts (one row per test, one column per bin).
all_hist

In [None]:
# Compute, for every cell and for several bout selections, the reliability and
# the amplitude of the bout-triggered response; results are written to
# cells_df as "<selection>_rel" and "<selection>_amp" columns.

# Analysis parameters:
dt = 0.2  # dt of the imaging #TODO have this in exp dictionary
pre_int_s = 2  # time before bout for the crop, secs
post_int_s = 4  # time after the bout for the crop, secs
amplitude_percent = 90  # percentile for the calculation of the response amplitude

# Window for nanning out the bout artefacts
wnd_pre_bout_nan_s = 0.2
wnd_post_bout_nan_s = 0.2

# Bouts far enough from the neighbouring ones (min_dist_s comes from the
# bout-matching cell above):
min_distance_exclusion = (bouts_df["after_interbout"] > post_int_s) & (bouts_df["inter_bout"] > min_dist_s)

# Boolean bout masks, one per response type to be characterized:
selections_dict = dict(motor=min_distance_exclusion,
                       motor_g0=min_distance_exclusion & (bouts_df["base_vel"] < 0) & (bouts_df["gain"] == 0),
                       motor_g1=min_distance_exclusion & (bouts_df["base_vel"] < 0) & (bouts_df["gain"] == 1),
                       motor_spont=min_distance_exclusion & (bouts_df["base_vel"] > -10))

# Make sure all the output columns exist in cells_df:
for val in ["rel", "amp"]:
    for sel in selections_dict.keys():
        column_id = f"{sel}_{val}"
        if column_id not in cells_df.columns:
            cells_df[column_id] = np.nan


# Windows converted from seconds to samples (loop invariant):
pre_wnd_bout_nan = int(wnd_pre_bout_nan_s / dt)
post_wnd_bout_nan = int(wnd_post_bout_nan_s / dt)
crop_pre_int = int(pre_int_s / dt)
crop_post_int = int(post_int_s / dt)

# Loop over fish. Everything that depends only on the fish (trace extraction,
# artefact nanning, file loading) is done once here instead of once per selection:
for fid in tqdm(exp_df.index):
    cells_fsel = cells_df.loc[cells_df["fid"]==fid, :]
    traces = traces_df.loc[:, cells_fsel.index].copy()

    # Nan all bouts (all the bouts of the fish, not only the selected ones).
    # The builtin int replaces the np.int alias, which was removed from NumPy:
    start_idxs = np.round(bouts_df.loc[bouts_df["fid"]==fid, "t_start"] / dt).astype(int)
    traces = local_utils.bout_nan_traces(traces.values, start_idxs,
                                          wnd_pre=pre_wnd_bout_nan,
                                          wnd_post=post_wnd_bout_nan)

    # NOTE(review): beh_df and stim_df are not used anywhere in this cell;
    # kept in case later cells rely on them - consider removing.
    beh_df = fl.load(master_path / "beh_dict.h5", f"/{fid}")
    stim_df = fl.load(master_path / "stim_dict.h5", f"/{fid}")

    # Loop over criteria for the different reliabilities:
    for selection in selections_dict.keys():
        sel_bouts = bouts_df[(bouts_df["fid"]==fid) & selections_dict[selection]]
        sel_start_idxs = np.round(sel_bouts["t_start"] / dt).astype(int)

        # Crop cell responses around bouts
        # (result indexed as time x bouts x cells, judging from its use below):
        cropped = utilities.crop(traces,
                                 sel_start_idxs,
                                 pre_int=crop_pre_int,
                                 post_int=crop_post_int)

        # Subtract pre-bout baseline:
        cropped = cropped - np.nanmean(cropped[:crop_pre_int, :, :], 0)

        # Calculate reliability indexes:
        reliabilities = utilities.reliability(cropped)

        # Calculate mean response for all cells:
        mean_resps = np.nanmean(cropped, 1)

        # Calculate amplitude of the response using the amplitude_percent
        # percentile of the mean response
        # (response is normalized at pre-stim onset):
        amplitudes = local_utils.max_amplitude_resp(mean_resps,
                                                    percentile=amplitude_percent)

        cells_df.loc[cells_fsel.index, f"{selection}_rel"] = reliabilities
        cells_df.loc[cells_fsel.index, f"{selection}_amp"] = amplitudes

# fl.save(master_path / "cells_df.h5", cells_df)

In [None]:
# Write bouts_df to the same "bouts_df.h5" file it was loaded from; at this point
# it also carries the columns added in this notebook
# (mindist_included, g0, g1, spont, matched).
fl.save(master_path / "bouts_df.h5", bouts_df)

# Plots

In [None]:
# Distribution of the motor_spont reliability across cells, split by genotype:
rel_violin_ax = plt.figure().gca()
sns.violinplot(x="genotype", y="motor_spont_rel", data=cells_df, ax=rel_violin_ax)

In [None]:
# Distribution of the motor_spont response amplitude across cells, split by genotype:
amp_violin_ax = plt.figure().gca()
sns.violinplot(x="genotype", y="motor_spont_amp", data=cells_df, ax=amp_violin_ax)

In [None]:
# Cell-by-cell comparison of two reliability scores.
# NOTE(review): "forward_rel" is not computed in this notebook; it is presumably
# already stored in the loaded cells_df - confirm.
plt.figure()
forward_rel = cells_df["forward_rel"]
motor_spont_rel = cells_df["motor_spont_rel"]
plt.scatter(forward_rel, motor_spont_rel, s=5)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import mplcursors

# Standalone demo of mplcursors hover annotations on 26 random points;
# it does not use any of the experimental data.
np.random.seed(42)
demo_x, demo_y = np.random.random((2, 26))

fig, ax = plt.subplots()
ax.scatter(demo_x, demo_y)
ax.set_title("Mouse over a point")

# Attach hover-activated cursors to the artists of the open figures:
mplcursors.cursor(hover=True)

plt.show()

In [None]:
# Reliability of the responses currently stored in `cropped`, i.e. whatever the
# last executed cell that assigned `cropped` left in the namespace.
reliability = utilities.reliability(cropped)

In [None]:
def browse_cells(i=(0, len(reliability) - 1)):
    """Plot the bout-triggered responses of one cell, browsing cells by reliability.

    Meant to be used with `interact`: the tuple default makes `i` an integer
    slider spanning all the cells. Draws on the global axes `ax` and reads the
    globals `cropped`, `reliability`, `dt` and `pre_int_s`.

    Parameters
    ----------
    i : int
        Rank of the cell to show, in increasing order of `reliability`
        (0 is the cell with the lowest score).
    """
    ax.cla()

    # Translate the rank into the index of the cell along the last axis of cropped:
    idxs = np.argsort(reliability)
    cell_idx = idxs[i]

    # Time axis in seconds, with 0 at the crop reference point (bout onset):
    time_s = np.arange(cropped.shape[0]) * dt - pre_int_s

    ax.axvline(0, zorder=-100)
    # Individual responses (thin black) and their average (red). The traces can
    # contain NaNs, so the average ignores them:
    ax.plot(time_s, cropped[:, :, cell_idx], linewidth=0.1, c="k")
    ax.plot(time_s, np.nanmean(cropped[:, :, cell_idx], 1), linewidth=2, c="r")
    ax.set_ylim(-1, 4)
    sns.despine()
    ax.set_xlabel("Time from bout (s)")

In [None]:
# Create the figure whose axes `ax` is drawn on by browse_cells (it reads `ax`
# as a global), then build the interactive widget for its argument.
f, ax = plt.subplots()
interact(browse_cells)

In [None]:
# Display the fish id currently stored in `fid`.
fid