In [None]:
%reload_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42 # for pdfs
matplotlib.rcParams['svg.fonttype'] = 'none' # for svgs
import matplotlib.pyplot as plt
from matplotlib import cm as cmap
import seaborn as sns
from pathlib import Path
import flexiznam as flz
from cottage_analysis.analysis import spheres, common_utils
from cottage_analysis.pipelines import pipeline_utils
from cottage_analysis.plotting import basic_vis_plots
from v1_depth_map.figure_utils import depth_selectivity, get_session_list, depth_decoder
from v1_depth_map.figure_utils import common_utils as plt_common_utils

In [None]:
# Project, flexilims session and output locations for this figure.
project = "hey2_3d-vision_foodres_20220101"
flexilims_session = flz.get_flexilims_session(project)
READ_VERSION = 10  # figure-data version that cached results are read from
VERSION = 10  # figure-data version that outputs are written to
# Query the processed-data root once and derive both the read and save roots from it.
FIGURE_ROOT = flz.get_data_root("processed", flexilims_session=flexilims_session) / "v1_manuscript_figures"
READ_ROOT = FIGURE_ROOT / f"ver{READ_VERSION}"
SAVE_ROOT = FIGURE_ROOT / f"ver{VERSION}"
SAVE_ROOT.mkdir(parents=True, exist_ok=True)

# False: read cached statistics from READ_ROOT. True: recompute them and cache under SAVE_ROOT.
reload = False

In [None]:
# Running-speed statistics for every session of the mice below.
# With reload=True they are recomputed (equivalent to running calculate_rs_stats_all_sessions.py)
# and cached under SAVE_ROOT. Otherwise the cached pickle is read from READ_ROOT.
MOUSE_NAMES = [
    "PZAH6.4b",
    "PZAG3.4f",
    "PZAH8.2h",
    "PZAH8.2i",
    "PZAH8.2f",
    "PZAH10.2d",
    "PZAH10.2f",
]
if reload:
    mouse_list = flz.get_entities("mouse", flexilims_session=flexilims_session)
    mouse_list = mouse_list[mouse_list.name.isin(MOUSE_NAMES)]
    session_list = get_session_list.get_sessions(
        flexilims_session,
        exclude_openloop=False,
        exclude_pure_closedloop=False,
        mouse_list=mouse_list,
    )
    # nbins / corridor_length / blank_length must match the values used when plotting.
    results_all = depth_selectivity.get_rs_stats_all_sessions(
        flexilims_session,
        session_list,
        nbins=60,
        rs_thr_min=None,
        rs_thr_max=None,
        still_only=False,
        still_time=1,
        corridor_length=6,
        blank_length=3,
        overwrite=False,
    )
    save_path = SAVE_ROOT / "supp" / "results_all_rs_supp.pickle"
    # Only SAVE_ROOT itself is created up front. Create the "supp" subfolder so to_pickle does not fail.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    results_all.to_pickle(save_path)
else:
    # NOTE: unpickling can execute arbitrary code. Only load pickles produced by this pipeline.
    results_all = pd.read_pickle(READ_ROOT / "supp" / "results_all_rs_supp.pickle")

In [None]:
# Plot RS PSTH for all sessions.
# Supplementary figure of running speed (RS) and optic-flow speed (OF).
# Left column: sessions of the 5-depth mice (PZAH6.4b, PZAG3.4f). Right column: all other sessions (8 depths).
# Rows, top to bottom: RS PSTH, mean RS per depth, OF PSTH, mean OF per depth.
fontsize_dict = {"title": 7, "label": 7, "tick": 5, "legend": 5}
cm = 1 / 2.54  # inches per centimetre (figsize is given in inches)
of_threshold = 0.01
# Binning / trial geometry. These must match the values used to compute `results_all`.
nbins = 60
blank_length = 3
corridor_length = 6
trial_length = blank_length * 2 + corridor_length
blank_ratio = blank_length / trial_length
corridor_ratio = corridor_length / trial_length

fig = plt.figure(figsize=(18 * cm, 18 * cm))

# Split sessions by protocol. regex=False makes the "." in the mouse names match literally.
# NOTE(review): results_all["session"] yields a DataFrame here (hence .iloc[:, 0]),
# presumably because the column label is duplicated. Confirm.
session_names = results_all["session"].iloc[:, 0]
is_5depths = session_names.str.contains("PZAH6.4b", regex=False) | session_names.str.contains(
    "PZAG3.4f", regex=False
)
results_all_5depths = results_all[is_5depths]
results_all_8depths = results_all[~is_5depths]

# (sessions, depth values) for each protocol, in plotting-column order.
protocol_groups = [
    (results_all_5depths, np.geomspace(0.06, 6, 5)),
    (results_all_8depths, np.geomspace(0.05, 6.4, 8)),
]

# Keyword arguments shared by every PSTH panel.
psth_kwargs = dict(
    trials_df=None,
    roi=0,
    is_closed_loop=True,
    frame_rate=15,
    fontsize_dict=fontsize_dict,
    linewidth=1,
    legend_on=True,
    legend_loc="lower right",
    legend_bbox_to_anchor=(1.2, 0),
    show_ci=True,
)
# Keyword arguments shared by every mean-speed panel.
mean_speed_kwargs = dict(
    linewidth=1.5,
    elinewidth=1.5,
    jitter=0.2,
    scatter_markersize=2,
    scatter_alpha=0.4,
    capsize=5,
    capthick=1.5,
)

for iplot, (results, depth_list) in enumerate(protocol_groups):
    # Per-session closed-loop RS PSTHs stacked into (sessions, depths + 1, bins).
    # NOTE(review): the extra row after the depths is presumably non-depth (e.g. blank). Confirm.
    psth = np.vstack([j for i in results.rs_psth_closedloop.values for j in i]).reshape(
        len(results), len(depth_list) + 1, nbins
    )
    # Optic-flow speed: divide each depth row by its depth (the extra row by 1), then convert radians to degrees.
    psth_of = np.degrees(psth / np.hstack([depth_list, 1]).reshape(1, -1, 1))

    # Top of the RS axis: peak of the session-mean PSTH, scaled x100 (presumably m/s -> cm/s,
    # per the axis label) and rounded up to the next multiple of 10.
    rs_ymax = (np.floor_divide(np.nanmax(np.nanmean(psth, axis=0)) * 100, 10) + 1) * 10

    ax = fig.add_axes([0.1 + iplot * 0.45, 0.72, 0.3, 0.1])
    depth_selectivity.plot_PSTH(
        depth_list=depth_list,
        psth=psth,
        use_col="RS",
        corridor_length=corridor_length,
        blank_length=blank_length,
        nbins=nbins,
        ylim=(0, rs_ymax),
        **psth_kwargs,
    )
    ax.set_ylabel("Running speed (cm/s)", fontsize=fontsize_dict["label"])
    ax.set_yticks([0, rs_ymax])

    # OF PSTH restricted to the within-trial (corridor) bins, padded by one bin on each side.
    # blank_length is one bin's length, presumably so the padding bins are drawn as blank.
    n_of_bins = psth_of.shape[2]
    first_bin = int(n_of_bins * blank_ratio - 1)
    last_bin = int(n_of_bins * (blank_ratio + corridor_ratio) + 1)
    ax = fig.add_axes([0.1 + iplot * 0.45, 0.27, 0.3, 0.1])
    depth_selectivity.plot_PSTH(
        depth_list=depth_list,
        psth=psth_of[:, :, first_bin:last_bin],
        use_col="OF",
        corridor_length=corridor_length,
        blank_length=trial_length / n_of_bins,
        nbins=int(n_of_bins * corridor_ratio) + 2,
        ylim=(1e0, 1e3),
        **psth_kwargs,
    )
    ax.set_yscale("log")
    ax.set_ylabel("Optic flow speed\n(degrees/s)", fontsize=fontsize_dict["label"])
    ax.set_yticks(np.geomspace(1e0, 1e3, 4))

# Mean RS and mean OF per depth. A second pass keeps the original axes-creation order.
for iplot, (results, depth_list) in enumerate(protocol_groups):
    ax = fig.add_axes([0.1 + iplot * 0.45, 0.49, 0.3, 0.1])
    depth_selectivity.plot_mean_running_speed_alldepths(
        results, depth_list, fontsize_dict, param="RS", ylim=(0, 120), **mean_speed_kwargs
    )
    ax = fig.add_axes([0.1 + iplot * 0.45, 0.05, 0.3, 0.1])
    depth_selectivity.plot_mean_running_speed_alldepths(
        results,
        depth_list,
        fontsize_dict,
        param="OF",
        of_threshold=of_threshold,
        ylim=(5e-1, 1e3),
        **mean_speed_kwargs,
    )

# Save this figure explicitly rather than whatever pyplot currently considers "current".
fig.savefig(SAVE_ROOT / "fig_supp_speeds.svg", bbox_inches="tight", dpi=300)