In [None]:
%matplotlib widget
%load_ext autoreload
from ipywidgets import interact, interact_manual

In [None]:
from pathlib import Path

import flammkuchen as fl
import numpy as np
import pandas as pd
import seaborn as sns
from bouter import bout_stats, decorators, utilities
from bouter.angles import reduce_to_pi
from bouter.utilities import crop
from ec_code.analysis_utils import bout_nan_traces, max_amplitude_resp
from ec_code.file_utils import get_dataset_location
from ec_code.plotting_utils import *
from matplotlib import pyplot as plt
from scipy.interpolate import interp1d
from scipy.stats import ttest_ind
from tqdm import tqdm

# Global seaborn theme for every figure below; `cols` keeps the resulting colour cycle.
sns.set(palette="deep", style="ticks")
cols = sns.color_palette()


def crop_trace(trace, timepoints, dt, pre_int_s, post_int_s, normalize=False):
    """Crop a trace around a set of timepoints, with windows given in seconds.

    Parameters
    ----------
    trace : array
        Array passed to ``utilities.crop``. In this notebook it is the
        ``.values`` of a traces DataFrame, so presumably (time, cells) — confirm
        for other callers.
    timepoints : array-like
        Event times in seconds (Series, list or array); converted to sample
        indices by rounding ``timepoints / dt``.
    dt : float
        Sampling interval of ``trace`` in seconds.
    pre_int_s, post_int_s : float
        Seconds to keep before / after each timepoint.
    normalize : bool
        If True, subtract from every crop the nan-mean of its first
        ``pre_int_s`` seconds (the pre-event baseline).

    Returns
    -------
    The array returned by ``utilities.crop`` (time along the first axis),
    baseline-subtracted when ``normalize`` is True.
    """
    # Round instead of truncating: e.g. int(0.6 / 0.2) == 2 because of float error.
    pre_int = int(round(pre_int_s / dt))
    post_int = int(round(post_int_s / dt))

    # Builtin `int` instead of the `np.int` alias, which was removed in NumPy 1.24.
    start_idxs = np.round(np.asarray(timepoints) / dt).astype(int)
    cropped = utilities.crop(trace, start_idxs, pre_int=pre_int, post_int=post_int)

    if normalize:
        # Baseline = mean over the pre-event samples, ignoring NaNs.
        cropped = cropped - np.nanmean(cropped[:pre_int], 0)

    return cropped

In [None]:
master_path = get_dataset_location("fb_effect")

# Load every table of the dataset, in the same order as the unpacked names:
exp_df, trials_df, cells_df, bouts_df, traces_df = (
    fl.load(master_path / table_file)
    for table_file in (
        "exp_df.h5",
        "trials_df.h5",
        "cells_df.h5",
        "bouts_df.h5",
        "traces_df.h5",
    )
)

In [None]:
# Example fish id: the 7th entry of exp_df's index.
# NOTE(review): the loops in the following cells rebind `fid` to every fish in turn.
fid = exp_df.index[6]

In [None]:
# Analysis parameters:
dt = 0.2  # dt of the imaging #TODO have this in exp dictionary
pre_int_s = 2  # time before bout for the crop, secs
post_int_s = 6  # time after the bout for the crop, secs
amplitude_percent = 90  # percentile for the calculation of the response amplitude

min_dist_s = 2  # minimum "inter_bout" value (s) for a bout to be kept

# Window for nanning out the bout artefacts
wnd_pre_bout_nan_s = 0.2
wnd_post_bout_nan_s = 0.2

# Despite its name this is an inclusion mask: True for bouts whose "after_interbout"
# exceeds the post-bout crop and whose "inter_bout" exceeds min_dist_s
# (presumably time to the next / since the previous bout — confirm).
min_distance_exclusion = (bouts_df["after_interbout"] > post_int_s) & (
    bouts_df["inter_bout"] > min_dist_s
)

selections_dict = dict(
    motor=min_distance_exclusion,
    motor_g0=min_distance_exclusion
    & (bouts_df["base_vel"] < 0)
    & (bouts_df["gain"] == 0),
    motor_g1=min_distance_exclusion
    & (bouts_df["base_vel"] < 0)
    & (bouts_df["gain"] == 1),
    motor_spont=min_distance_exclusion & (bouts_df["base_vel"] > -10),
)

# Make sure every output column exists (NaN until filled below):
for val in ["rel", "amp"]:
    for sel in selections_dict.keys():
        column_id = f"{sel}_{val}"
        if column_id not in cells_df.columns:
            cells_df[column_id] = np.nan

# Windows converted from seconds to samples (rounded, so float error in the
# division cannot drop a sample):
pre_wnd_bout_nan = int(round(wnd_pre_bout_nan_s / dt))
post_wnd_bout_nan = int(round(wnd_post_bout_nan_s / dt))
pre_int = int(round(pre_int_s / dt))
post_int = int(round(post_int_s / dt))

# Loop over fish. The bout-nanned traces do not depend on the selection, so they
# are computed once per fish and reused for every selection criterion:
for fid in tqdm(exp_df.index):
    cells_fsel = cells_df.loc[cells_df["fid"] == fid, :]
    fish_bouts = bouts_df["fid"] == fid

    # Nan all bouts (on a copy, so traces_df itself is never modified):
    traces = traces_df.loc[:, cells_fsel.index].copy()
    start_idxs = np.round(bouts_df.loc[fish_bouts, "t_start"] / dt).astype(int)
    traces = bout_nan_traces(
        traces.values,
        start_idxs,
        wnd_pre=pre_wnd_bout_nan,
        wnd_post=post_wnd_bout_nan,
    )

    # Loop over criteria for the different reliabilities:
    for selection, selection_mask in selections_dict.items():
        rel_col, amp_col = f"{selection}_rel", f"{selection}_amp"
        sel_bouts = bouts_df[fish_bouts & selection_mask]

        if len(sel_bouts) == 0:
            # No bouts pass this selection for this fish: mark as undefined rather
            # than averaging over an empty set (or keeping stale values).
            cells_df.loc[cells_fsel.index, [rel_col, amp_col]] = np.nan
            continue

        sel_start_idxs = np.round(sel_bouts["t_start"] / dt).astype(int)

        # Crop cell responses around bouts:
        cropped = utilities.crop(
            traces, sel_start_idxs, pre_int=pre_int, post_int=post_int
        )

        # Subtract pre-bout baseline:
        cropped = cropped - np.nanmean(cropped[:pre_int, :, :], 0)

        # Calculate reliability indexes:
        reliabilities = utilities.reliability(cropped)

        # Mean response of every cell across the selected bouts (nan-mean over axis 1):
        mean_resps = np.nanmean(cropped, 1)

        # Response amplitude via max_amplitude_resp at the `amplitude_percent`
        # percentile (responses are baseline-subtracted pre-bout):
        amplitudes = max_amplitude_resp(mean_resps, percentile=amplitude_percent)

        cells_df.loc[cells_fsel.index, rel_col] = reliabilities
        cells_df.loc[cells_fsel.index, amp_col] = amplitudes

In [None]:
#############################
# Calculate ol vs cl pvalues:
wnd_s = 2  # Window of average response over which calculate pval
wnd = int(round(wnd_s / dt))
perc_excluding_shortbouts = 20  # bouts shorter than this duration percentile are excluded
n_pval_intervals = 4  # number of window positions tested
step_pval_intervals = 1  # shift (s) between consecutive window positions

# Output columns; reset to NaN so fish/cells without a result stay NaN:
for column_id in ["pval_clol", "int0_clol", "int1_clol"]:
    cells_df[column_id] = np.nan

for fid in tqdm(exp_df.index):
    cell_idxs = cells_df[cells_df["fid"] == fid].index

    sel = (bouts_df["fid"] == fid) & bouts_df["matched"]
    if not sel.any():
        # No matched bouts for this fish: nothing to compare (and np.percentile
        # would fail on an empty selection).
        continue

    # Exclude short bouts from p val calculation:
    min_dur = np.percentile(bouts_df.loc[sel, "duration"], perc_excluding_shortbouts)
    sel = sel & (bouts_df["duration"] >= min_dur)

    # Crop bouts (baseline-subtracted):
    timepoints = bouts_df.loc[sel, "t_start"]
    cropped = crop_trace(
        traces_df[cell_idxs].values,
        timepoints,
        dt,
        pre_int_s,
        post_int_s,
        normalize=True,
    )
    gains = bouts_df.loc[sel, "gain"].values

    # t-test gain 0 vs gain 1 for each window position, for all cells at once
    # (pvals has shape n_pval_intervals x n_cells):
    pvals = np.full((n_pval_intervals, len(cell_idxs)), np.nan)
    for i in range(n_pval_intervals):
        # The crop starts pre_int_s before the bout, so this window starts
        # i * step_pval_intervals seconds after bout onset:
        t_start = pre_int_s + i * step_pval_intervals
        i_start = int(round(t_start / dt))
        wnd_mean = np.nanmean(cropped[i_start : i_start + wnd], 0)  # bouts x cells
        pvals[i] = ttest_ind(wnd_mean[gains == 0], wnd_mean[gains == 1], axis=0).pvalue

    # Keep, for each cell, the window position with the smallest p value:
    best_p_idx = np.argmin(pvals, 0)
    best_t_start = best_p_idx * step_pval_intervals

    cells_df.loc[cell_idxs, "pval_clol"] = pvals[best_p_idx, np.arange(len(cell_idxs))]
    cells_df.loc[cell_idxs, "int0_clol"] = best_t_start
    cells_df.loc[cell_idxs, "int1_clol"] = best_t_start + wnd_s

In [None]:
# Persist results: overwrites the cells/bouts tables loaded from master_path at the
# top of the notebook (cells_df now carries the *_rel, *_amp and *_clol columns).
# NOTE(review): bouts_df is not modified in the cells visible here — confirm this save is needed.
fl.save(master_path / "cells_df.h5", cells_df)
fl.save(master_path / "bouts_df.h5", bouts_df)

# Plots

In [None]:
# Distribution of spontaneous-motor reliability per genotype:
fig_rel, ax_rel = plt.subplots()
sns.violinplot(x="genotype", y="motor_spont_rel", data=cells_df, ax=ax_rel)

In [None]:
# Distribution of spontaneous-motor response amplitude per genotype:
fig_amp, ax_amp = plt.subplots()
sns.violinplot(x="genotype", y="motor_spont_amp", data=cells_df, ax=ax_amp)

In [None]:
# Per-cell "forward_rel" vs spontaneous-motor reliability.
# NOTE(review): "forward_rel" is not computed in this notebook; it must already be
# present in the loaded cells_df — confirm where it comes from.
plt.figure()
plt.scatter(cells_df["forward_rel"], cells_df["motor_spont_rel"], s=5)
plt.xlabel("forward_rel")
plt.ylabel("motor_spont_rel");

In [None]:
import matplotlib.pyplot as plt
import mplcursors
import numpy as np

# Hover-annotation demo on 26 random points (does not use the dataset).
np.random.seed(42)
demo_xy = np.random.random((2, 26))

fig, ax = plt.subplots()
ax.scatter(demo_xy[0], demo_xy[1])
ax.set_title("Mouse over a point")

mplcursors.cursor(hover=True)

plt.show()

In [None]:
# Reliability index of the `cropped` array left by the last iteration of the p-value
# loop above: only the last fish of exp_df.index, matched bouts, short bouts excluded.
# Used by `browse_cells` to rank cells.
# NOTE(review): depends on loop-leftover state; rerun the previous cell first.
reliability = utilities.reliability(cropped)

In [None]:
def browse_cells(i=(0, len(reliability) - 1)):
    """Plot the bout-triggered responses of one cell into the global `ax`.

    Cells are ranked by ascending `reliability`; `i` is the rank (0 = least
    reliable). The tuple default lets `interact` build an integer slider over all
    cells (the original default referenced an undefined name, `cells`).
    Single-bout crops are drawn in thin black and their nan-mean in red.
    """
    ax.cla()

    # Rank -> column of `cropped` (third axis indexes cells):
    cell_i = np.argsort(reliability)[i]
    # The crops start pre_int_s before bout onset:
    time_s = np.arange(cropped.shape[0]) * dt - pre_int_s

    ax.axvline(0, zorder=-100)
    ax.plot(time_s, cropped[:, :, cell_i], linewidth=0.1, c="k")
    ax.plot(time_s, np.nanmean(cropped[:, :, cell_i], 1), linewidth=2, c="r")
    ax.set_ylim(-1, 4)
    sns.despine(ax=ax)
    ax.set_xlabel("Time from bout (s)")
    ax.set_title(f"Cell {cell_i}, reliability {np.asarray(reliability)[cell_i]:.2f}")
    ax.figure.canvas.draw_idle()

In [None]:
# Axes that `browse_cells` draws into (it reads the global `ax`), plus the slider widget:
f, ax = plt.subplots(nrows=1, ncols=1)
interact(browse_cells)

In [None]:
# Show the current fish id (after the loops above: the last entry of exp_df.index).
fid