In [None]:
%reload_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42 # for pdfs
matplotlib.rcParams['svg.fonttype'] = 'none' # for ps
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
from matplotlib import cm

import flexiznam as flz
from cottage_analysis.analysis import spheres, common_utils
from cottage_analysis.plotting import basic_vis_plots
from cottage_analysis.pipelines import pipeline_utils
from v1_depth_map.figure_utils import depth_selectivity, closed_loop_rsof, get_session_list, rf, roi_location
from v1_depth_map.figure_utils import common_utils as plt_common_utils
from sklearn.metrics.pairwise import pairwise_distances

In [None]:
# Flexilims project and the session used for every database query in this notebook.
project = "hey2_3d-vision_foodres_20220101"
flexilims_session = flz.get_flexilims_session(project)

# Version numbers of the manuscript folders that READ_ROOT and SAVE_ROOT point to.
READ_VERSION = 10
VERSION = 10
READ_ROOT, SAVE_ROOT = (
    flz.get_data_root("processed", flexilims_session=flexilims_session)
    / "v1_manuscript_2023"
    / f"ver{version}"
    for version in (READ_VERSION, VERSION)
)

# Create the output folders up front so later save calls cannot fail on a missing path.
SAVE_ROOT.mkdir(parents=True, exist_ok=True)
(SAVE_ROOT / "rf_supp").mkdir(parents=True, exist_ok=True)

# True: recompute the pairwise distances from the sessions; False: load the saved pickle.
reload = False

In [None]:
import warnings
# NOTE(review): with no category/module argument this hides *every* warning raised
# after this cell runs (including deprecation warnings from pandas/numpy), for the
# whole kernel session. Consider narrowing the filter to the warnings that are
# actually expected.
warnings.filterwarnings("ignore")
# calculate pairwise distance for roi centers, rf azimuth, elevation, and depth
def _upper_triangle_distances(points):
    """Return the Euclidean distance between every unordered pair of rows of `points`.

    Args:
        points: array-like of shape (n_samples, n_features), n_samples >= 1.

    Returns:
        np.ndarray: 1D array holding the strict upper triangle (k=1) of the pairwise
            distance matrix in row-major order, i.e. n_samples * (n_samples - 1) / 2
            values.
    """
    dist_matrix = pairwise_distances(points, metric="euclidean")
    return dist_matrix[np.triu_indices(dist_matrix.shape[0], k=1)]


def calculate_pairwise_distance_per_session(flexilims_session,
                                            session_name,):
    """Compute pairwise distances between the selected neurons of one session.

    A neuron is selected when it is depth tuned in closed loop (Spearman r > 0.1 and
    p < 0.05), is flagged in column 0 of suite2p's ``iscell.npy`` and has
    ``rf_sig == 1``. For every unordered pair of selected neurons the function
    computes the distance between ROI centres, the absolute difference in RF azimuth
    and elevation, and the absolute difference in log2 preferred depth (raw and
    corrected).

    Args:
        flexilims_session: flexilims session used to look up the datasets.
        session_name (str): name of the session. Sessions whose name contains
            "PZAH6.4b" or "PZAG3.4f" use 5 depths, all others 8.

    Returns:
        pd.DataFrame: one row per neuron pair with columns ``roi_distance`` (in the
            units of ``center_x``/``center_y``), ``rf_azi_distance``,
            ``rf_ele_distance``, ``preferred_depth_closedloop_distance``,
            ``preferred_depth_corrected_distance`` and ``session``. The frame has the
            same columns but no rows when fewer than two neurons are selected.
    """
    # find_ndepths
    five_depth_mice = ("PZAH6.4b", "PZAG3.4f")
    ndepths = 5 if any(mouse in session_name for mouse in five_depth_mice) else 8

    # load neurons_df and the suite2p outputs
    neurons_ds = pipeline_utils.create_neurons_ds(
        session_name=session_name,
        flexilims_session=flexilims_session,
        conflicts="skip",
    )
    neurons_df = pd.read_pickle(neurons_ds.path_full)

    suite2p_ds = flz.get_datasets(
        flexilims_session=flexilims_session,
        origin_name=session_name,
        dataset_type="suite2p_rois",
        filter_datasets={"anatomical_only": 3},
        allow_multiple=False,
        return_dataseries=False,
    )
    plane_dir = suite2p_ds.path_full / "plane0"
    stat = np.load(plane_dir / "stat.npy", allow_pickle=True)
    # column 0 of iscell.npy is the cell / not-cell flag
    iscell = np.load(plane_dir / "iscell.npy", allow_pickle=True)[:, 0]

    neurons_df["iscell"] = iscell
    # vectorised: NaN r/p values compare as False, exactly like the scalar test
    neurons_df["depth_tuned"] = (
        neurons_df["depth_tuning_test_spearmanr_rval_closedloop"] > 0.1
    ) & (neurons_df["depth_tuning_test_spearmanr_pval_closedloop"] < 0.05)

    # find rf and roi centers (both helpers are called for their effect on neurons_df)
    rf.find_rf_centers(
        neurons_df,
        ndepths=ndepths,
        frame_shape=(16, 24),
        is_closed_loop=1,
        resolution=5,
    )
    roi_location.find_roi_centers(neurons_df, stat)

    # correct preferred depth: divide by sqrt(sin^2(azi) * cos^2(ele) + sin^2(ele)),
    # with the RF azimuth/elevation treated as degrees
    azi_rad = np.deg2rad(neurons_df["rf_azi"])
    ele_rad = np.deg2rad(neurons_df["rf_ele"])
    neurons_df["preferred_depth_corrected"] = neurons_df[
        "preferred_depth_closedloop"
    ] / np.sqrt(
        (np.sin(azi_rad) ** 2) * (np.cos(ele_rad) ** 2) + (np.sin(ele_rad) ** 2)
    )

    # find significant rfs
    coef = np.stack(neurons_df["rf_coef_closedloop"].values)
    coef_ipsi = np.stack(neurons_df["rf_coef_ipsi_closedloop"].values)
    if coef_ipsi.ndim == 3:
        # move the neuron axis from first to last before testing significance
        sig, sig_ipsi = spheres.find_sig_rfs(
            np.swapaxes(np.swapaxes(coef, 0, 2), 0, 1),
            np.swapaxes(np.swapaxes(coef_ipsi, 0, 2), 0, 1),
            n_std=6,
        )
        neurons_df["rf_sig"] = sig
        neurons_df["rf_sig_ipsi"] = sig_ipsi
    elif "rf_sig" not in neurons_df.columns:
        # Significance cannot be assessed without 3D ipsi coefficients. Previously the
        # selection below crashed on the missing "rf_sig" column; treat the session as
        # having no significant RFs instead so that it contributes no pairs.
        print(
            f"{session_name}: rf_coef_ipsi_closedloop is not 3D, "
            "no neuron is marked as having a significant RF"
        )
        neurons_df["rf_sig"] = False

    # select the neurons entering the pairwise comparison
    select_neurons = (
        neurons_df["depth_tuned"]
        & neurons_df["iscell"].astype(bool)
        & (neurons_df["rf_sig"] == 1)
    )
    selected = neurons_df[select_neurons]

    feature_cols = [
        "rf_azi",
        "rf_ele",
        "preferred_depth_closedloop",
        "preferred_depth_corrected",
    ]
    if len(selected) == 0:
        # keep the same columns as a populated session, just without rows
        distances = {"roi_distance": np.empty(0)}
        distances.update({f"{col}_distance": np.empty(0) for col in feature_cols})
    else:
        # pairwise distance between roi centers
        distances = {
            "roi_distance": _upper_triangle_distances(
                selected[["center_x", "center_y"]].to_numpy()
            )
        }
        # pairwise distance for rf azimuth, elevation, and depth
        for col in feature_cols:
            values = selected[col].to_numpy(dtype=float)
            if "preferred_depth" in col:
                # depths are compared on a log2 scale
                values = np.log2(values)
            distances[f"{col}_distance"] = _upper_triangle_distances(
                values.reshape(-1, 1)
            )

    session_df = pd.DataFrame(distances)
    session_df["session"] = session_name
    return session_df


def calculate_pairwise_distance_all_sessions(
    session_list,
    flexilims_session,
):
    """Compute the pairwise-distance table of every session and stack them.

    Args:
        session_list: iterable of session names.
        flexilims_session: flexilims session passed on to
            `calculate_pairwise_distance_per_session`.

    Returns:
        pd.DataFrame: row-wise concatenation of the per-session tables (the
            per-session indices are kept as they are, so index values repeat across
            sessions). An empty DataFrame when `session_list` is empty.
    """
    per_session_dfs = []
    for session_name in session_list:
        per_session_dfs.append(
            calculate_pairwise_distance_per_session(
                flexilims_session=flexilims_session,
                session_name=session_name,
            )
        )
        print(f"Finished {session_name}")

    # pd.concat raises on an empty list, whereas the loop-based version returned an
    # empty frame; keep that behaviour.
    if not per_session_dfs:
        return pd.DataFrame()
    # Concatenate once: growing the frame inside the loop re-copies all earlier pairs
    # for every new session.
    return pd.concat(per_session_dfs, axis=0)

In [None]:
if not reload:
    # Reuse the table written by an earlier run of the recompute branch below.
    session_df_all = pd.read_pickle(
        SAVE_ROOT / "rf_supp" / "pairwise_distance_all_sessions.pkl"
    )
else:
    # Recompute the pairwise distances for all V1 sessions and save them.
    project = "hey2_3d-vision_foodres_20220101"
    flexilims_session = flz.get_flexilims_session(project)

    session_df_all = calculate_pairwise_distance_all_sessions(
        session_list=get_session_list.get_sessions(
            flexilims_session=flexilims_session,
            exclude_openloop=False,
            exclude_pure_closedloop=False,
            v1_only=True,
        ),
        flexilims_session=flexilims_session,
    )
    session_df_all.to_pickle(
        SAVE_ROOT / "rf_supp" / "pairwise_distance_all_sessions.pkl"
    )

# Convert ROI distances from pixels to micrometres.
pixel_size = 661 / 1024  # 2p FOV 661 um, 1024 pixels
session_df_all["roi_distance"] = pixel_size * session_df_all["roi_distance"]

In [None]:
# Plot RF azimuth / RF elevation / log2 preferred depth differences between neuron
# pairs as a function of the distance between their ROIs.
fig = plt.figure(figsize=(18 / 2.54, 18 / 2.54))  # 18 x 18 cm
fontsize_dict = {"title": 7, "label": 7, "tick": 5}
plot_CI = True
recompute_CI = True
plot_param = "mean"
if plot_param not in ("mean", "median"):
    raise ValueError(f"plot_param must be 'mean' or 'median', got {plot_param!r}")

distance_cols = ["rf_azi", "rf_ele", "preferred_depth_closedloop"]
ylabels = [
    f"{plot_param.capitalize()} |\u0394(RF azimuth)| (degrees)",
    f"{plot_param.capitalize()} |\u0394(RF elevation)| (degrees)",
    f"{plot_param.capitalize()} |\u0394log\u2082(preferred virtual depth)|",
]

select_pairs = session_df_all["roi_distance"] > 10  # remove any pairs with roi distance less than 10 um
bins = np.linspace(10, 860, 30)
bin_centers = (bins[1:] + bins[:-1]) / 2

# Bin the pairs once; every panel uses the same binning.
session_df_all["Binned"] = pd.cut(session_df_all["roi_distance"], bins=bins)
selected_pairs_df = session_df_all[select_pairs]
# Integer bin index of every selected pair (-1 for pairs outside the bins).
bin_codes = selected_pairs_df["Binned"].cat.codes.to_numpy()

if plot_CI:
    ci_shape = (len(distance_cols), len(bin_centers))
    # recompute_CI=False reuses CIs from a previous run of this cell; fall back to
    # computing them when none of the right shape exist (e.g. on a fresh kernel).
    have_cached_CI = all(
        np.shape(globals().get(name)) == ci_shape for name in ("CIs_low", "CIs_high")
    )
    if recompute_CI or not have_cached_CI:
        # NOTE(review): get_bootstrap_ci is called with its default arguments; confirm
        # that the statistic it bootstraps matches plot_param when plotting the median.
        CIs_low = np.full(ci_shape, np.nan)
        CIs_high = np.full(ci_shape, np.nan)
        for icol, col in enumerate(distance_cols):
            col_values = selected_pairs_df[f"{col}_distance"].to_numpy()
            for ibin in range(len(bin_centers)):
                values_in_bin = col_values[bin_codes == ibin]
                if values_in_bin.size == 0:
                    # nothing to bootstrap: leave the CI as NaN for this bin
                    continue
                CIs_low[icol, ibin], CIs_high[icol, ibin] = common_utils.get_bootstrap_ci(
                    values_in_bin
                )

for icol, (col, ylabel) in enumerate(zip(distance_cols, ylabels)):
    ax = fig.add_axes([0.1 + icol * 0.3, 0.1, 0.2, 0.18])
    # observed=False keeps empty bins (as NaN) so the result always lines up with
    # bin_centers, independently of the pandas default.
    binned_values = selected_pairs_df.groupby("Binned", observed=False)[f"{col}_distance"]
    if plot_param == "mean":
        mean_binned = binned_values.mean()
    else:
        mean_binned = binned_values.median()
    ax.plot(bin_centers, mean_binned.values, color='k', linewidth=1)
    if plot_CI:
        CI_low, CI_high = CIs_low[icol], CIs_high[icol]
        ax.fill_between(
            bin_centers,
            np.array(CI_low).flatten(),
            np.array(CI_high).flatten(),
            color='k',
            alpha=0.3,
            edgecolor='none',
        )
    ax.set_xticks(np.linspace(0, 860, 3))
    ax.set_xticklabels(np.linspace(0, 860, 3).astype("int"), fontsize=fontsize_dict["tick"])
    ax.set_xlim(0, 860)
    # RF panels are scaled to the upper CI bound, the depth panel to the plotted curve.
    if icol < 2 and plot_CI:
        ymax = np.ceil(np.nanmax(CI_high))
    else:
        ymax = plt_common_utils.ceil(np.nanmax(mean_binned), 1)
    ax.set_ylim(0, ymax)
    ax.set_yticks([0, ymax])
    ax.tick_params(axis='x', rotation=0, labelsize=fontsize_dict["tick"])
    ax.tick_params(axis='y', labelsize=fontsize_dict["tick"])
    ax.set_xlabel('Distance between cells (\u03bcm)', fontsize=fontsize_dict["label"])
    ax.set_ylabel(ylabel, fontsize=fontsize_dict["label"])
    sns.despine(ax=ax)
fig.savefig(SAVE_ROOT / "rf_supp" / "pairwise_distance_all_sessions.svg", bbox_inches='tight')