In [None]:
%matplotlib widget
%load_ext autoreload
from ipywidgets import interact, interact_manual, widgets

In [None]:
from pathlib import Path

import flammkuchen as fl
import numpy as np
import pandas as pd
import seaborn as sns
from bouter import utilities
from ec_code.analysis_utils import bout_nan_traces
from ec_code.plotting_utils import cols, despine, shade_error
from matplotlib import pyplot as plt
from scipy.interpolate import interp1d
from tqdm import tqdm

sns.set(style="ticks", palette="deep")
from matplotlib.backends.backend_pdf import PdfPages
from scipy.stats import ttest_ind

In [None]:
# Root folder of the preprocessed dataset (machine-specific path; edit to run elsewhere).
master_path = Path(r"/Users/luigipetrucco/Google Drive/data/ECs_E50")

exp_df = fl.load(master_path / "exp_df.h5")
trials_df = fl.load(master_path / "trials_df.h5")
cells_df = fl.load(master_path / "cells_df.h5")
traces_df = fl.load(master_path / "traces_df.h5")
bouts_df = fl.load(master_path / "bouts_df.h5")

# Copy of traces_df in which each fish's traces are NaN-ed around its bouts by
# bout_nan_traces (window of wnd_pre / wnd_post around each onset, presumably in
# imaging frames -- TODO confirm in ec_code.analysis_utils).
nanned_traces_df = traces_df.copy()
for fid in tqdm(exp_df.index):
    fish_cids = cells_df.index[(cells_df["fid"] == fid).values]
    # Bout onsets converted to imaging-frame indexes with this fish's dt.
    # Plain `int` is used because the `np.int` alias was removed in NumPy 1.24.
    start_idxs = (
        np.round(bouts_df.loc[bouts_df["fid"] == fid, "t_start"] / exp_df.loc[fid, "dt"])
        .astype(int)
        .values
    )
    nanned_traces_df.loc[:, fish_cids] = bout_nan_traces(
        traces_df.loc[:, fish_cids].values,
        start_idxs,
        wnd_pre=3,
        wnd_post=25,
    )

# NOTE(review): bout times are NOT shifted by exp_df["offset"] here; a correction
# `bouts_df.loc[bouts_df["fid"] == fid, "t_start"] -= exp_df.loc[fid, "offset"]`
# was previously left commented out inside the loop -- confirm whether it is needed.

In [None]:
# Analysis parameters:
dt = 0.2  # dt of the imaging #TODO have this in exp dictionary
pre_int_s = 2  # time before bout for the crop, secs
post_int_s = 6  # time after the bout for the crop, secs

In [None]:
def crop_trace(trace, timepoints, dt, pre_int_s, post_int_s, normalize=False):
    """Crop a trace around timepoints, with the crop interval given in seconds.

    Parameters
    ----------
    trace : np.ndarray
        Trace sampled every `dt` seconds, passed to `bouter.utilities.crop`.
    timepoints : array-like
        Event times in seconds; converted to sample indexes using `dt`.
    dt : float
        Sampling interval of `trace`, in seconds.
    pre_int_s, post_int_s : float
        Time kept before / after each timepoint, in seconds.
    normalize : bool
        If True, subtract from each crop its nan-mean over the pre-event window.

    Returns
    -------
    np.ndarray
        Output of `utilities.crop`. The baseline indexing below implies time on the
        first axis and events on the second.
    """
    # Use a plain ndarray so the crop does not depend on any pandas index of
    # `timepoints`. `int` replaces `np.int`, which was removed in NumPy 1.24.
    start_idxs = np.round(np.asarray(timepoints) / dt).astype(int)
    # Round rather than truncate the window lengths: e.g. int(0.6 / 0.2) == 2, not 3.
    pre_int = int(round(pre_int_s / dt))
    post_int = int(round(post_int_s / dt))

    cropped = utilities.crop(trace, start_idxs, pre_int=pre_int, post_int=post_int)
    if normalize:
        cropped = cropped - np.nanmean(cropped[:pre_int, :], 0)

    return cropped


def crop_f_beh_at_times(
    cid,
    timepoints,
    pre_int_s,
    post_int_s,
    cells_df,
    traces_df,
    exp_df,
    normalize=True,
    behavior_key="vigor",
    crop_behavior=False,
):
    """Crop a cell's fluorescence trace, and optionally its fish's behavior, at given times.

    Parameters
    ----------
    cid : hashable
        Cell index in `cells_df`; also the column of `traces_df` holding its trace.
    timepoints : array-like
        Event times in seconds.
    pre_int_s, post_int_s : float
        Seconds kept before / after each event.
    cells_df : pd.DataFrame
        Cell table; its "fid" column gives the fish of `cid`.
    traces_df : pd.DataFrame
        Fluorescence traces, one column per cell.
    exp_df : pd.DataFrame
        Experiment table; its "dt" column gives each fish's imaging interval.
    normalize : bool
        Baseline-subtract the fluorescence crops (see `crop_trace`).
    behavior_key : str
        Column of the resampled behavior table to crop (used only with `crop_behavior`).
    crop_behavior : bool
        If True, also crop the behavior of the cell's fish, loaded from
        `master_path / "resamp_beh_dict.h5"` under key "/<fid>".

    Returns
    -------
    cropped : np.ndarray
        Fluorescence crops.
    cropped_be : np.ndarray
        Behavior crops (never normalized); returned only when `crop_behavior` is True,
        as the second element of a tuple.
    """
    fid = cells_df.loc[cid, "fid"]
    cropped = crop_trace(
        traces_df[cid].values,
        timepoints,
        exp_df.loc[fid, "dt"],
        pre_int_s,
        post_int_s,
        normalize=normalize,
    )

    if not crop_behavior:
        return cropped

    beh_trace = fl.load(master_path / "resamp_beh_dict.h5", f"/{fid}")
    # Behavior sampling interval from the first index steps (assumes the index holds
    # sample times in seconds -- TODO confirm).
    dt_beh = np.diff(beh_trace.index[:5]).mean()
    cropped_be = crop_trace(
        beh_trace[behavior_key].values,
        timepoints,
        dt_beh,
        pre_int_s,
        post_int_s,
        normalize=False,
    )
    return cropped, cropped_be


def plot_crop(data_mat, f=None, bound_box=None, vlim=3, r=0.65, pre_s=None, post_s=None):
    """Plot cropped data as a heatmap above its individual and average traces.

    Parameters
    ----------
    data_mat : np.ndarray
        Crops with time on the first axis and events on the second.
    f : matplotlib.figure.Figure, optional
        Figure to draw into; a new one is created if None.
    bound_box : tuple, optional
        (left, bottom, width, height), in figure fractions, of the area shared by the
        two axes. Defaults to (0.1, 0.1, 0.6, 0.8).
    vlim : float
        Heatmap color limits are (-vlim, vlim).
    r : float
        Fraction of the bounding-box height given to the heatmap (top part).
    pre_s, post_s : float, optional
        Seconds covered before / after the event. Default to the notebook-level
        `pre_int_s` / `post_int_s`, read at call time.

    Returns
    -------
    ax, ax1 : matplotlib.axes.Axes
        Heatmap axes and traces axes.
    """
    if f is None:
        f = plt.figure()
    if bound_box is None:
        bound_box = (0.1, 0.1, 0.6, 0.8)
    if pre_s is None:
        pre_s = pre_int_s
    if post_s is None:
        post_s = post_int_s

    hp, vp, w, h = bound_box
    time_axis = np.linspace(-pre_s, post_s, data_mat.shape[0])

    # Heatmap, one row per event, in the top part of the box.
    ax = f.add_axes((hp, vp + h * (1 - r), w, h * r))
    ax.imshow(
        data_mat.T,
        aspect="auto",
        extent=(-pre_s, post_s, 0, data_mat.shape[1]),
        cmap="RdBu_r",
        vmin=-vlim,
        vmax=vlim,
    )
    despine(ax, ticks=True)

    # Thin blue single-event traces and their thick red nan-mean, in the bottom part.
    ax1 = f.add_axes((hp, vp, w, h * (1 - r)))
    ax1.axvline(0, linewidth=0.5, c=(0.6,) * 3)
    ax1.plot(time_axis, data_mat, linewidth=0.1, c="b")
    ax1.plot(time_axis, np.nanmean(data_mat, 1), linewidth=1.5, c="r")
    despine(ax1, spare=["left", "bottom"])
    ax1.set_xlim(-pre_s, post_s)
    ax1.set_xlabel("time (s)")

    return ax, ax1

In [None]:
def monster_plot(cid):
    """Draw a one-page summary of a cell's bout- and grating-locked activity.

    Panels (placed by hand with ``f.add_axes``):
      * first three columns, top two rows: vigor (top) and dF/F (middle) cropped around
        matched bouts flagged "spont", "g0" and "g1", sorted by bout duration;
      * last column: mean +/- error of dF/F around gain 0 / gain 1 bouts, and bout
        duration vs. mean dF/F inside the cell's ``int0_clol``-``int1_clol`` window;
      * bottom row: dF/F around forward-grating starts (bout-NaN-ed traces), no-bout
        forward-grating stops and no-bout backward-grating starts.

    Reads the notebook-level ``exp_df``, ``cells_df``, ``trials_df``, ``bouts_df``,
    ``traces_df``, ``nanned_traces_df``, ``pre_int_s`` and ``post_int_s``, and updates
    global matplotlib rcParams.

    Parameters
    ----------
    cid : hashable
        Index of the cell in ``cells_df``.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    # Panel geometry, in figure-fraction units.
    w, h = 0.2, 0.25
    hpad, vpad = 0.05, 0.2
    hoff = (1 - hpad * 2) / 4
    voff = (1 - vpad * 2) / 3
    ylim_percent = 2  # percentile clipped at each end when choosing shared y limits

    # Style is still set globally, but now BEFORE the figure exists, so the first call
    # draws exactly like every later one.
    plt.rcParams.update(
        {
            "figure.constrained_layout.use": True,
            "font.family": "sans-serif",
            "font.sans-serif": ["Libertinus Sans"],
            "xtick.labelsize": 10,
            "ytick.labelsize": 10,
            "axes.labelsize": 12,
            "axes.linewidth": 0.5,
            "axes.edgecolor": "0.3",
            "xtick.major.width": 0.5,
            "ytick.major.width": 0.5,
        }
    )
    # Every panel is positioned manually with add_axes, which the constrained layout
    # engine does not arrange, so layout is disabled for this figure only.
    f = plt.figure(figsize=(11, 9), constrained_layout=False)

    fid = cells_df.loc[cid, "fid"]
    # Sampling interval these crops are made with (crop_f_beh_at_times uses
    # exp_df["dt"]); the global `dt` is only a nominal value.
    fish_dt = exp_df.loc[fid, "dt"]

    # --- Crops around bouts (fluorescence and behavior) ---
    fish_bouts = bouts_df.loc[bouts_df["fid"] == fid, :].copy()
    bt_crop_f, bt_crop_be = crop_f_beh_at_times(
        cid,
        fish_bouts["t_start"],
        pre_int_s,
        post_int_s,
        cells_df,
        traces_df,
        exp_df,
        crop_behavior=True,
    )

    # --- Crops around gratings ---
    of_fish = trials_df["fid"] == fid
    forward = trials_df["trial_type"] == "forward"
    backward = trials_df["trial_type"] == "backward"
    no_bout = trials_df["bout_n"] == 0

    # All forward trials, aligned to grating start, on the bout-NaN-ed traces.
    trials = trials_df.loc[of_fish & forward, :]
    fw_crop_f = crop_f_beh_at_times(
        cid,
        trials["t_start"],
        pre_int_s,
        post_int_s,
        cells_df,
        nanned_traces_df,
        exp_df,
    )

    def _crop_raw_or_none(times):
        # Crop the raw traces at `times`, or return None when there are no such trials.
        if len(times) == 0:
            return None
        return crop_f_beh_at_times(
            cid, times, pre_int_s, post_int_s, cells_df, traces_df, exp_df
        )

    fwe_crop_f = _crop_raw_or_none(trials_df.loc[of_fish & forward & no_bout, "t_end"])
    bw_crop_f = _crop_raw_or_none(trials_df.loc[of_fish & backward & no_bout, "t_start"])

    # Shared y limits for all dF/F trace panels, ignoring the extreme percentiles.
    all_crop = np.concatenate(
        [
            a.flatten()
            for a in (bt_crop_f, fw_crop_f, fwe_crop_f, bw_crop_f)
            if a is not None
        ]
    )
    y_bounds = (
        np.nanpercentile(all_crop, ylim_percent),
        np.nanpercentile(all_crop, 100 - ylim_percent),
    )

    # --- Bout-locked panels, one column per bout category ---
    for i, (k, lab) in enumerate(
        zip(["spont", "g0", "g1"], ["spontaneous", "gain 0", "gain 1"])
    ):
        selected = fish_bouts["matched"] & fish_bouts[k]
        idxs = np.flatnonzero(selected.values)
        if len(idxs) == 0:
            continue
        durations = fish_bouts.loc[selected, "duration"].values
        # Order the selected bouts by duration.
        idxs_sort = idxs[np.argsort(durations)]
        col_left = hpad - 0.012 + hoff * i

        ax0f, ax1f = plot_crop(
            bt_crop_f[:, idxs_sort],
            f=f,
            bound_box=(col_left, 0.27 + voff * 1, w, h),
        )
        ax1f.set_xlabel("Time from bout (s)")
        ax1f.set_ylim(y_bounds)
        # Faint span per bout: overlapping spans show the distribution of durations.
        for duration in durations:
            ax1f.axvspan(
                0, duration, linewidth=0, facecolor=(0.6,) * 3, alpha=0.01
            )

        ax0v, ax1v = plot_crop(
            bt_crop_be[:, idxs_sort],
            f=f,
            bound_box=(col_left, 0.27 + voff * 2, w, h),
            vlim=1,
        )
        ax0v.set_title(f"Bouts, {lab}, n={len(idxs_sort)}")
        ax1v.set_visible(False)

        if i > 0:
            ax1f.set_yticklabels([])
        else:
            ax0f.set_ylabel("dF/F")
            ax1f.set_ylabel("dF/F")
            ax0v.set_ylabel("Vigor")

    # --- Gain 0 vs gain 1 comparison: shade plot and scatterplot ---
    ax_shade = f.add_axes((hpad + hoff * 3.15, 0.27 + voff * 2.2, w, h * 0.6))
    ax_scatter = f.add_axes((hpad + hoff * 3.15, 0.27 + voff * 1, w, h * 0.6))

    # Per-cell comparison window (s from bout onset) and its sample indexes in the crop.
    comparison_int = [cells_df.loc[cid, f"int{i}_clol"] for i in [0, 1]]
    avg_f_interval = [int((pre_int_s + d) / fish_dt) for d in comparison_int]
    time_axis = np.arange(bt_crop_f.shape[0]) * fish_dt - pre_int_s
    ax_shade.axvspan(*comparison_int, linewidth=0, facecolor=(0.9,) * 3, zorder=-100)

    for g in [0, 1]:
        selected = fish_bouts["matched"] & fish_bouts[f"g{g}"]
        idxs = np.flatnonzero(selected.values)
        shade_error(
            bt_crop_f[:, idxs],
            ax=ax_shade,
            xarr=time_axis,
            c=cols[g],
            label=f"gain {g}",
        )
        # Mean dF/F of each bout inside the comparison window vs. its duration.
        fluo_intensities = np.nanmean(
            bt_crop_f[avg_f_interval[0] : avg_f_interval[1], idxs], 0
        )
        ax_scatter.scatter(
            fish_bouts.loc[selected, "duration"],
            fluo_intensities,
            c=cols[g],
            s=8,
            label=f"gain {g}",
        )

    despine(ax_shade, spare=["left", "bottom"])
    ax_shade.set_ylabel("dF/F")
    ax_shade.set_xlabel("Time from bout (s)")
    ax_shade.legend(frameon=False, fontsize=10)
    despine(ax_scatter, spare=["left", "bottom"])
    ax_scatter.set_ylabel(f"Avg. dF/F ({comparison_int[0]}-{comparison_int[1]} s)")
    ax_scatter.set_xlabel("Bout duration (s)")

    # --- Grating-locked panels ---
    # All forward trials, sorted by latency of the first bout; grey = grating on (0-5 s).
    idxs = np.argsort(trials["bout_latency"].values)
    ax0, ax1 = plot_crop(fw_crop_f[:, idxs], f=f, bound_box=(hpad, 0.1, w, h))
    ax1.axvspan(0, 5, linewidth=0, facecolor=(0.9,) * 3, zorder=-100)
    ax0.set_title("All fw trials start")
    ax0.set_ylabel("dF/F")
    ax1.set_xlabel("Time from gratings start (s)")
    ax1.set_ylim(y_bounds)
    ax1.set_ylabel("dF/F")

    # No-bout fw stops and bw starts. Grey spans mark when gratings are on: the fw span
    # assumes the same 5 s grating as above (it previously started at the meaningless
    # -5 + pre_int_s); the bw span is kept at 0-1 s.
    other_panels = [
        (fwe_crop_f, "No-bout fw trials stop", "Time from gratings stop (s)", (-5, 0)),
        (bw_crop_f, "No-bout bw trials start", "Time from gratings start (s)", (0, 1)),
    ]
    for i, (cropped, title, xlabel, (span_start, span_end)) in enumerate(other_panels):
        if cropped is None:
            continue
        ax0, ax1 = plot_crop(cropped, f=f, bound_box=(hpad + hoff * (i + 1), 0.1, w, h))
        ax0.set_title(title)
        ax1.set_xlabel(xlabel)
        ax1.set_ylim(y_bounds)
        ax1.set_yticklabels([])
        ax1.axvspan(
            span_start, span_end, linewidth=0, facecolor=(0.9,) * 3, zorder=-100
        )

    f.suptitle(
        f"{cid} ({exp_df.loc[fid, 'genotype']}); p={cells_df.loc[cid, 'pval_clol']:1.2e}"
    )
    return f

In [None]:
# Inspect one cell. `cid_override` pins a hand-picked cell; set it to None to pick
# instead, in fish exp_df.index[39], the cell with motor_rel > 0.1,
# backward_rel < 0.02 and forward_rel < 0.02 that has the lowest pval_clol.
cid_override = "201015_f3_IO_75"
if cid_override is not None:
    cid = cid_override
else:
    fid = exp_df.index[39]
    cid = (
        cells_df.loc[
            (cells_df["fid"] == fid)
            & (cells_df["motor_rel"] > 0.1)
            & (cells_df["backward_rel"] < 0.02)
            & (cells_df["forward_rel"] < 0.02),
            "pval_clol",
        ]
        .sort_values()
        .index[0]
    )
monster_plot(cid)
cells_df.loc[cid, :]

In [None]:
# Generate PDFs
with PdfPages("/Users/luigipetrucco/Desktop/all_best_candidates_nofilt.pdf") as pdf:
    for fid in tqdm(exp_df.index):
        try:
            cid = (
                cells_df.loc[
                    (cells_df["fid"] == fid) & (cells_df["motor_rel"] > 0.1),
                    "pval_clol",
                ]
                .sort_values()
                .index[0]
            )
            monster_plot(cid)
        except IndexError:
            plt.figure()
        pdf.savefig()  # saves the current figure into a pdf page
        plt.close()