In [None]:
%reload_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42 # for pdfs
import matplotlib.pyplot as plt
from matplotlib import cm as cmap
import seaborn as sns
from pathlib import Path
import flexiznam as flz
from cottage_analysis.analysis import spheres, common_utils
from cottage_analysis.pipelines import pipeline_utils
from cottage_analysis.plotting import basic_vis_plots
from v1_depth_analysis.v1_manuscript_2023 import depth_selectivity, get_session_list, depth_decoder
from v1_depth_analysis.v1_manuscript_2023 import common_utils as plt_common_utils

In [None]:
# Manuscript figure version; it selects the versioned folder that SAVE_ROOT points at.
VERSION = 6
SAVE_ROOT = Path(f"/camp/lab/znamenskiyp/home/shared/presentations/v1_manuscript_2023/ver{VERSION}")
# Create the folder (and any missing parents); a no-op if it already exists.
SAVE_ROOT.mkdir(exist_ok=True, parents=True)

In [None]:
# Load the per-session running-speed results saved under SAVE_ROOT/supp.
supp_results_path = SAVE_ROOT / "supp" / "results_all_rs_supp.pickle"
results_all = pd.read_pickle(supp_results_path)

# The commented-out code below calls depth_selectivity.get_rs_stats_all_sessions on
# an explicit session list; presumably it is how the pickle above was generated
# (TODO confirm). Uncomment it to recompute instead of loading.
# project = "hey2_3d-vision_foodres_20220101"
# flexilims_session = flz.get_flexilims_session(project)

# session_list = ["PZAH6.4b_S20220419", "PZAH6.4b_S20220426", "PZAH8.2h_S20230113", "PZAH8.2h_S20230116"]
# results_all = depth_selectivity.get_rs_stats_all_sessions(
#     flexilims_session,
#     session_list,
#     nbins=60,
#     rs_thr_min=None,
#     rs_thr_max=None,
#     still_only=False,
#     still_time=1,
#     corridor_length=6,
#     blank_length=3,
#     overwrite=False,
# )

In [None]:
# Keep only the sessions whose name contains "PZAH6.4b" or "PZAG3.4f"; these are
# paired with a 5-value depth list below.
# results_all["session"] is 2-D here (hence .iloc[:, 0]); the first column is used.
session_names = results_all["session"].iloc[:, 0]
# regex=False makes the match literal. With the default regex=True the "." in
# the names is a wildcard and would also match e.g. "PZAH6x4b".
is_5depths = (
    session_names.str.contains("PZAH6.4b", regex=False)
    | session_names.str.contains("PZAG3.4f", regex=False)
)
results = results_all_5depths = results_all[is_5depths]
depth_list = np.geomspace(0.06, 6, 5)
nbins = 60
# Stack every row of every session's rs_psth_closedloop entry, then reshape to
# (n_sessions, n_depths + 1, nbins). The extra "+1" row is presumably the blank
# period between corridors -- TODO confirm.
psth = np.vstack(
    [row for session_psth in results.rs_psth_closedloop.values for row in session_psth]
).reshape(len(results), len(depth_list) + 1, nbins)

In [None]:
# Divide each depth's PSTH by its depth value; the extra final row is divided by 1,
# i.e. left unscaled. The divisor is shaped (1, n_depths + 1, 1) so it broadcasts
# over sessions and bins.
# NOTE(review): the figure cell wraps this same ratio in np.degrees before plotting;
# this cell does not, so the two `psth_of` values differ in scale -- confirm which
# is intended.
depth_divisors = np.append(depth_list, 1)[np.newaxis, :, np.newaxis]
psth_of = psth / depth_divisors

In [None]:
# Plot RS PSTH for all sessions
# Figure layout: one column per session group (5-depth, 8-depth) and four rows,
# from top to bottom: running-speed PSTH, mean running speed per depth,
# optic-flow PSTH (log y-axis), mean optic-flow speed per depth.
# plt.rcParams["font.family"] = "Arial"
fontsize_dict = {"title": 7, "label": 7, "tick": 5, "legend": 5}
cm = 1 / 2.54  # centimetres -> inches, since figsize is given in inches
of_threshold = 0.01  # floor applied to optic-flow values before log-scale plotting
nbins = 60
fig = plt.figure(figsize=(18 * cm, 18 * cm))

# Split sessions into two groups by session name. regex=False makes the match
# literal: with the default regex=True the "." in the names is a wildcard.
# results_all["session"] is 2-D here (hence .iloc[:, 0]); the first column is used.
session_names = results_all["session"].iloc[:, 0]
is_5depths = (
    session_names.str.contains("PZAH6.4b", regex=False)
    | session_names.str.contains("PZAG3.4f", regex=False)
)
results_all_5depths = results_all[is_5depths]
results_all_8depths = results_all[~is_5depths]

# Each group is paired with the depth list used for its sessions.
session_groups = [
    (results_all_5depths, np.geomspace(0.06, 6, 5)),
    (results_all_8depths, np.geomspace(0.05, 6.4, 8)),
]

# Keyword arguments shared by both plot_PSTH calls.
psth_kwargs = dict(
    trials_df=None,
    roi=0,
    is_closed_loop=True,
    use_col="RS",
    corridor_length=6,
    blank_length=3,
    nbins=nbins,
    frame_rate=15,
    fontsize_dict=fontsize_dict,
    linewidth=1,
    legend_on=True,
    legend_loc="lower right",
    legend_bbox_to_anchor=(1.2, 0),
    show_ci=True,
)
# Keyword arguments shared by both plot_mean_running_speed_alldepths calls.
mean_kwargs = dict(
    linewidth=2,
    elinewidth=2,
    jitter=0.2,
    scatter_markersize=2,
    scatter_alpha=0.4,
    capsize=5,
    capthick=2,
)

for iplot, (results, depth_list) in enumerate(session_groups):
    left = 0.05 + iplot * 0.45  # left edge of this group's column (figure fraction)

    # Stack every row of every session's PSTH, then reshape to
    # (n_sessions, n_depths + 1, nbins).
    psth = np.vstack(
        [row for session_psth in results.rs_psth_closedloop.values for row in session_psth]
    ).reshape(len(results), len(depth_list) + 1, nbins)
    # Running speed divided by depth, converted with np.degrees; the extra final
    # row is divided by 1.
    psth_of = np.degrees(psth / np.hstack([depth_list, 1]).reshape(1, -1, 1))
    # Clip small values so they can be drawn on the log-scaled axis below.
    psth_of[psth_of < of_threshold] = of_threshold

    # NOTE(review): bottom=1 places this row above the figure's nominal [0, 1]
    # area, so it presumably relies on tight-bbox rendering to be visible -- confirm.
    ax = fig.add_axes([left, 1, 0.3, 0.2])
    depth_selectivity.plot_PSTH(
        depth_list=depth_list,
        psth=psth,
        ylim=(0, None),
        **psth_kwargs,
    )
    ax.set_ylabel("Running speed (m/s)", fontsize=fontsize_dict["label"])

    ax = fig.add_axes([left, 0.4, 0.3, 0.2])
    depth_selectivity.plot_PSTH(
        depth_list=depth_list,
        psth=psth_of,
        ylim=(None, None),
        **psth_kwargs,
    )
    ax.set_yscale("log")
    ax.set_ylabel("Optic flow speed (degrees/s)", fontsize=fontsize_dict["label"])


# Second pass kept separate so the axes are created in the same order as before.
for iplot, (results, depth_list) in enumerate(session_groups):
    left = 0.05 + iplot * 0.45

    ax = fig.add_axes([left, 0.7, 0.3, 0.2])
    depth_selectivity.plot_mean_running_speed_alldepths(
        results, depth_list, fontsize_dict, param="RS", **mean_kwargs
    )

    ax = fig.add_axes([left, 0.1, 0.3, 0.2])
    depth_selectivity.plot_mean_running_speed_alldepths(
        results, depth_list, fontsize_dict, param="OF", of_threshold=of_threshold, **mean_kwargs
    )