# Visualize the matching procedure

This notebook should be run after the "matched" column has been added to the bouts dataframe.

In [None]:
%matplotlib widget
from pathlib import Path

import flammkuchen as fl
import numpy as np
from ec_code.analysis_utils import crop_trace
from ec_code.file_utils import get_dataset_location

# from ec_code.plotting_utils import *
from matplotlib import pyplot as plt

# Location of the "fb_effect" dataset:
master_path = get_dataset_location("fb_effect")

# Read the bouts and experiments dataframes stored at that location
# (the two loads do not depend on each other):
bouts_df = fl.load(master_path / "bouts_df.h5")
exp_df = fl.load(master_path / "exp_df.h5")

In [None]:
# Plot parameters.
# Window cropped around each bout, in seconds:
PRE_INT_S = 1  # time kept before the bout
POST_INT_S = 3  # time kept after the bout

# dt of the imaging.
# TODO: read this from the exp dictionary instead of hard-coding it.
dt = 0.2

# Example fish: the entry at position 6 of the experiment dataframe index.
fid = exp_df.index[6]

In [None]:
# Keep only the bouts of the example fish, as a copy of the full table:
fish_bouts = bouts_df[bouts_df["fid"] == fid].copy()

# Reference time of each crop (the "t_start" column of the bouts):
timepoints = fish_bouts["t_start"]

# Load the entry of this fish from the resampled behavior file:
beh_trace = fl.load(master_path / "resamp_beh_dict.h5", f"/{fid}")

# Time step of the behavior trace, taken as the mean difference between the
# first five index values:
dt_beh = np.mean(np.diff(beh_trace.index[:5]))

# Crop the vigor trace around each timepoint, without normalization:
vigor_trace = beh_trace["vigor"].values
bt_crop_be = crop_trace(
    vigor_trace, timepoints, dt_beh, PRE_INT_S, POST_INT_S, normalize=False
)

In [None]:
# One figure with all the bouts, one restricted to the matched bouts:
for selection in ("all", "matched"):
    f, axs = plt.subplots(1, 2, figsize=(4, 2), tight_layout=True)

    # One panel for each of the two gain columns:
    for ax, gain in zip(axs, ["g0", "g1"]):
        # Boolean mask of the bouts of this gain, narrowed down to the
        # matched ones when required:
        gain_mask = fish_bouts[gain]
        bout_mask = (
            gain_mask & fish_bouts["matched"]
            if selection == "matched"
            else gain_mask
        )

        # Entries of the cropped matrix for the selected bouts
        # (bouts are indexed along the second axis):
        crops = bt_crop_be[:, bout_mask]

        # Ordering of the selected bouts by duration:
        order = np.argsort(fish_bouts[bout_mask]["duration"])

        # Horizontal axis spans the crop window, vertical axis the bout count:
        img_extent = [-PRE_INT_S, POST_INT_S, 0, crops.shape[1]]

        # Plot, one row of the image per bout:
        ax.imshow(
            crops[:, order].T,
            aspect="auto",
            vmin=0,
            vmax=0.4,
            extent=img_extent,
            cmap="gray_r",
        )
        ax.set_xlabel("Time from bout (s)")

    # The figure just created is the current one, so this titles it:
    f.suptitle(f"{fid} ({selection} bouts)", fontsize=10)