In [None]:
%matplotlib widget
%load_ext autoreload
from ipywidgets import interact, interact_manual

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
import flammkuchen as fl
from bouter import utilities
import local_utils
from bouter.angles import reduce_to_pi
from scipy.interpolate import interp1d
from tqdm import tqdm


from matplotlib import pyplot as plt
import seaborn as sns
# Global seaborn appearance for the figures below: "ticks" axes style and "deep" palette.
sns.set(style="ticks", palette="deep")
# Colors of the currently active palette (not referenced by the cells visible here).
cols = sns.color_palette()

In [None]:
# Folder holding the pooled dataframes of the dataset.
# NOTE(review): absolute, machine-specific path - adjust when running elsewhere.
master_path = Path(r"/Users/luigipetrucco/Google Drive/data/ECs_E50")

# Each table is read exactly once (trials_df used to be loaded twice).
exp_df = fl.load(master_path / "exp_df.h5")
trials_df = fl.load(master_path / "trials_df.h5")
cells_df = fl.load(master_path / "cells_df.h5")
traces_df = fl.load(master_path / "traces_df.h5")
bouts_df = fl.load(master_path / "bouts_df.h5")

# Time step of the traces: event times are divided by dt to obtain sample indexes
# (presumably seconds, given the "_s" parameters used with it - confirm).
dt = 0.2

# Make sure the per-cell result columns exist, NaN-filled, before they get
# assigned fish by fish; columns that are already present are left untouched.
for new_col in ["forward_rel", "backward_rel", "forward_amp", "backward_amp"]:
    if new_col not in cells_df.columns:
        cells_df[new_col] = np.nan

# Compute all indexes

In [None]:
# Crop window around each trial start, in seconds, for each stimulus type:
# pre_int_s is the time kept before the start, post_int_s the time kept after it.
crop_params_dict = {
    "forward": {"pre_int_s": 1, "post_int_s": 9},
    "backward": {"pre_int_s": 1, "post_int_s": 5},
}

# Window (in seconds) before / after each bout start that is handed to
# local_utils.bout_nan_traces (presumably the samples that get replaced with NaN).
wnd_pre_bout_nan_s = 0.4
wnd_post_bout_nan_s = 6

In [None]:
# Percentile handed to local_utils.max_amplitude_resp to estimate response amplitude.
amplitude_percent = 80

# Bout-masking windows converted from seconds to samples.
pre_wnd_bout_nan = int(wnd_pre_bout_nan_s / dt)
post_wnd_bout_nan = int(wnd_post_bout_nan_s / dt)

# For every stimulus type and every fish, crop the traces around the trial starts and
# store one reliability and one amplitude value per cell in cells_df.
for stim in ["forward", "backward"]:
    pre_int_s = crop_params_dict[stim]["pre_int_s"]
    post_int_s = crop_params_dict[stim]["post_int_s"]
    # Crop window in samples; it depends only on the stimulus, so compute it once per stimulus.
    pre_int = int(pre_int_s / dt)
    post_int = int(post_int_s / dt)

    for fid in tqdm(exp_df.index):
        # Boolean mask over cells_df selecting the cells of this fish.
        cells_fsel = cells_df["fid"] == fid

        # Sample indexes of the bout starts of this fish.
        # Builtin int replaces np.int, which was removed from NumPy in 1.24.
        bout_start_idxs = np.round(
            bouts_df.loc[bouts_df["fid"] == fid, "t_start"] / dt
        ).values.astype(int)

        # Traces matrix (one column per cell of this fish), and bout-nanned version:
        traces = traces_df.loc[:, cells_df[cells_fsel].index].copy()
        traces_mat = traces.values
        traces_mat_nanbouts = local_utils.bout_nan_traces(
            traces_mat,
            bout_start_idxs,
            wnd_pre=pre_wnd_bout_nan,
            wnd_post=post_wnd_bout_nan,
        )

        # Trials of this fish:
        ftrials_df = trials_df.loc[trials_df["fid"] == fid, :]
        # Optional: exclude trials with a leading bout too close to stimulus onset:
        # ftrials_df = ftrials_df.loc[np.isnan(ftrials_df["lead_bout_latency"]), :]

        # Trials of the current stimulus type: start sample indexes, and the order
        # that sorts those trials by bout latency.
        is_stim_trial = ftrials_df["trial_type"] == stim
        start_idxs = np.round(ftrials_df.loc[is_stim_trial, "t_start"] / dt)
        bout_lat_sort = np.argsort(ftrials_df.loc[is_stim_trial, "bout_latency"].values)

        # Crop around each trial start. The indexing below treats the result as
        # (time, trial, cell): axis 1 is reordered by bout latency, and the mean over
        # the first pre_int samples of axis 0 is subtracted as a baseline.
        cropped_nan = utilities.crop(
            traces_mat_nanbouts,
            start_idxs,
            pre_int=pre_int,
            post_int=post_int,
        )
        cropped_nan = cropped_nan[:, bout_lat_sort, :]
        cropped_nan = cropped_nan - np.nanmean(cropped_nan[:pre_int, :, :], 0)

        # One reliability value per cell, as returned by bouter's utilities.reliability.
        reliabilities = utilities.reliability(cropped_nan)

        # Mean response across trials for all cells:
        mean_resps = np.nanmean(cropped_nan, 1)

        # Amplitude of the response, taken at the amplitude_percent percentile of the
        # mean response (the response is baseline-subtracted above).
        amplitudes = local_utils.max_amplitude_resp(
            mean_resps, percentile=amplitude_percent
        )

        cells_df.loc[cells_fsel, f"{stim}_rel"] = reliabilities
        cells_df.loc[cells_fsel, f"{stim}_amp"] = amplitudes

In [None]:
# Write cells_df back to the file it was loaded from, overwriting it, so that the
# *_rel / *_amp columns assigned above are stored with the dataset.
fl.save(master_path / "cells_df.h5", cells_df)

# Explore

In [None]:

def browse_cells(i=(0, len(reliability) - 1)):
    """Redraw the 2x2 figure `axs` for one cell, selected with an interact slider.

    The tuple default is the ipywidgets abbreviation for an integer slider over
    that range; the slider position is mapped to a cell index through `idxs`.

    Row 0 shows `cropped`, row 1 shows `cropped_nan`. The left column shows the
    single-trial traces (thin black) with their NaN-ignoring mean (red); the
    right column shows the same trials as an image, one row per trial.

    NOTE(review): `reliability` is only assigned in a later cell, and `cropped`
    and `idxs` are not assigned in any cell visible in this notebook, so this
    cell depends on leftover kernel state - confirm where they should come from.
    """
    i = idxs[i]
    for j, mat in enumerate([cropped, cropped_nan]):
        # Time axis in seconds; the fixed -1 offset presumably corresponds to the
        # 1 s pre-stimulus interval in crop_params_dict - confirm.
        time_s = np.arange(mat.shape[0]) * dt - 1

        ax = axs[j, 0]
        ax.cla()
        ax.axvline(0, zorder=-100)
        ax.plot(time_s, mat[:, :, i], linewidth=0.1, c="k")
        ax.plot(time_s, np.nanmean(mat[:, :, i], 1), linewidth=2, c="r")
        ax.set_ylim(-1, 2.5)
        sns.despine()
        ax.set_xlabel("Time from bout (s)")

        ax = axs[j, 1]
        # Clear the previous image first: without this, every slider change
        # stacks one more image artist on the axes and redraws get slower.
        ax.cla()
        ax.imshow(mat[:, :, i].T, aspect="auto", vmin=-1, vmax=2.5)

    axs[0, 0].set_title(f"{i}, {reliability[i]}")


f, axs = plt.subplots(2, 2, figsize=(9, 8))
interact(browse_cells)

In [None]:
# Give every cell the genotype of its fish: the "fid" of each cell is looked up
# in the index of the exp_df "genotype" column.
genotype_by_fid = exp_df["genotype"]
cells_df["genotype"] = cells_df["fid"].map(genotype_by_fid)

In [None]:
# Two figures per stimulus type: reliability by genotype for all cells, then
# response amplitude by genotype restricted to cells above the reliability threshold.
# NOTE(review): s=1 looks like a leftover from the swarmplot variant below - confirm
# that the installed seaborn version accepts it in violinplot.
rel_threshold = 0.1
for d in ("forward", "backward"):
    rel_col = f"{d}_rel"
    amp_col = f"{d}_amp"

    plt.figure()
    plt.title(f"{d} reliability")
    sns.violinplot(data=cells_df, x="genotype", hue="genotype", y=rel_col, s=1)

    reliable_cells = cells_df.loc[cells_df[rel_col] > rel_threshold, :]
    plt.figure()
    plt.title(d)
    sns.violinplot(data=reliable_cells, x="genotype", hue="genotype", y=amp_col, s=1)
    # sns.swarmplot(data=cells_df.loc[cells_df[f"{d}_rel"] > 0.05, :], x="genotype", hue="genotype", y=f"{d}_amp", s=1)

In [None]:
# NOTE(review): `cropped` is not assigned in any cell visible in this notebook, and
# the `browse_cells` cell further up already reads `reliability` - confirm the
# intended cell order and where `cropped` should be computed.
reliability = utilities.reliability(cropped)

In [None]:
# Display the cells table to inspect its current columns.
cells_df

In [None]:
# Shows the value left in `fid` by the last iteration of the per-fish loop.
fid