In [None]:
# Notebook setup: auto-reload edited project modules, then import everything
# used by the cells below.
%reload_ext autoreload
%autoreload 2


import functools
# Shadow the built-in print so every call flushes immediately; output from
# long-running cells then shows up as it is produced.
print = functools.partial(print, flush=True)

# NOTE(review): several imports below (e.g. pickle, tqdm, scipy, cm, Path, pd,
# spheres, common_utils, pipeline_utils, depth_selectivity, plt_common_utils)
# appear unused in the cells of this notebook — confirm before pruning.
import os
import numpy as np
import pandas as pd
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42 # embed TrueType (Type 42) fonts so text in saved PDFs stays editable
import matplotlib.pyplot as plt
from matplotlib import cm
from pathlib import Path
import pickle
from tqdm import tqdm
import scipy

# Project / lab packages (flexilims metadata database, analysis helpers).
import flexiznam as flz
from cottage_analysis.analysis import spheres, common_utils
from cottage_analysis.pipelines import pipeline_utils
from v1_depth_analysis.v1_manuscript_2023 import depth_selectivity, get_session_list, depth_decoder
from v1_depth_analysis.v1_manuscript_2023 import common_utils as plt_common_utils

In [None]:
# Figure output location: one "ver<N>" folder per manuscript version, with a
# dedicated sub-folder for the decoder figures. Created up front so the
# savefig calls later in the notebook cannot fail on a missing directory.
VERSION = 3
SAVE_ROOT = (
    "/camp/lab/znamenskiyp/home/shared/presentations/v1_manuscript_2023/"
    f"ver{VERSION}/fig_decoder/"
)
os.makedirs(SAVE_ROOT, exist_ok=True)

In [None]:
# Load data: concatenate the saved per-session decoder results for each
# group of mice, restricted to the sessions returned with openloop_only=True.
project = "hey2_3d-vision_foodres_20220101"
flexilims_session = flz.get_flexilims_session(project)


def _load_openloop_decoder_results(mouse_list):
    """Gather decoder results for the open-loop sessions of the given mice.

    Sessions are listed with ``get_session_list.get_all_sessions``
    (``closedloop_only=False``, ``openloop_only=True``). The per-session
    ``decoder_results.pickle`` files are then combined by
    ``depth_decoder.concatenate_all_decoder_results``.

    Args:
        mouse_list: list of mouse names within ``project``.

    Returns:
        Whatever ``concatenate_all_decoder_results`` returns (used as a
        DataFrame-like table in the cells below).
    """
    sessions = get_session_list.get_all_sessions(
        project=project,
        mouse_list=mouse_list,
        closedloop_only=False,
        openloop_only=True,
    )
    return depth_decoder.concatenate_all_decoder_results(
        flexilims_session,
        session_list=sessions,
        filename="decoder_results.pickle",
    )


# Mice whose results are plotted as "5 depths".
results_5 = _load_openloop_decoder_results(["PZAH6.4b", "PZAG3.4f"])
# Mice whose results are plotted as "8 depths".
results_8 = _load_openloop_decoder_results(
    ["PZAH8.2h", "PZAH8.2i", "PZAH8.2f", "PZAH10.2d", "PZAH10.2f"]
)

In [None]:
# Decoder results for only PZAH10.2d and PZAH10.2f (a subset of the
# 8-depth mice), used for the supplementary "8 depths matched" panel.
matched_sessions = get_session_list.get_all_sessions(
    project=project,
    mouse_list=["PZAH10.2d", "PZAH10.2f"],
    closedloop_only=False,
    openloop_only=True,
)
results_8_matched = depth_decoder.concatenate_all_decoder_results(
    flexilims_session,
    session_list=matched_sessions,
    filename="decoder_results.pickle",
)

In [None]:
# Main decoder figure. Every panel is drawn onto the same figure at explicit
# figure-relative positions (plot_x / plot_y / plot_width / plot_height):
#   left   - decoding accuracy, closed loop vs open loop, for both mouse groups
#   middle - average confusion matrices (5 depths at plot_y=0, 8 depths at plot_y=0.6)
#   right  - decoding error, both groups pooled
fig = plt.figure()

# Decoding accuracy. Baselines are the chance levels for 5 and 8 depth classes.
depth_decoder.dot_plot(
    group1=results_5["accuracy_closedloop"],
    group2=results_5["accuracy_openloop"],
    labels=["5 depths", "8 depths"],
    baselines=[1 / 5, 1 / 8],
    fig=fig,
    group3=results_8["accuracy_closedloop"],
    group4=results_8["accuracy_openloop"],
    colors=["b", "g"],
    markersize=7,
    linewidth=3,
    plot_x=0,
    plot_y=0,
    plot_width=0.3,
    plot_height=1,
    fontsize_dict={"title": 15, "label": 10, "tick": 10, "legend": 5},
)

# Average (normalized) confusion matrices, closed loop vs open loop.
# calculate_average_confusion_matrix returns a results table, which is assigned
# back over the input, together with the mean matrix for the given suffix.
results_5, conmat_mean_closedloop = depth_decoder.calculate_average_confusion_matrix(
    results_5, "conmat", "_closedloop"
)
results_5, conmat_mean_openloop = depth_decoder.calculate_average_confusion_matrix(
    results_5, "conmat", "_openloop"
)
# NOTE(review): legend fontsize is 8 here but 5 in every other panel — confirm intended.
depth_decoder.plot_closed_open_conmat(
    conmat_closed=conmat_mean_closedloop,
    conmat_open=conmat_mean_openloop,
    normalize=True,
    fig=fig,
    plot_x=0.5,
    plot_y=0,
    plot_width=0.7,
    plot_height=0.5,
    fontsize_dict={"title": 15, "label": 10, "tick": 10, "legend": 8},
)

results_8, conmat_mean_closedloop = depth_decoder.calculate_average_confusion_matrix(
    results_8, "conmat", "_closedloop"
)
results_8, conmat_mean_openloop = depth_decoder.calculate_average_confusion_matrix(
    results_8, "conmat", "_openloop"
)
depth_decoder.plot_closed_open_conmat(
    conmat_closed=conmat_mean_closedloop,
    conmat_open=conmat_mean_openloop,
    normalize=True,
    fig=fig,
    plot_x=0.5,
    plot_y=0.6,
    plot_width=0.7,
    plot_height=0.5,
    fontsize_dict={"title": 15, "label": 10, "tick": 10, "legend": 5},
)

# Decoding error. calculate_error_all_sessions provides the error_closedloop /
# error_openloop columns. Per-session values from both groups are pooled
# into a single closed-loop vs open-loop comparison.
results_5 = depth_decoder.calculate_error_all_sessions(results_5)
results_8 = depth_decoder.calculate_error_all_sessions(results_8)
depth_decoder.dot_plot(
    group1=np.concatenate([results_5["error_closedloop"].values, results_8["error_closedloop"].values]),
    group2=np.concatenate([results_5["error_openloop"].values, results_8["error_openloop"].values]),
    labels=[None, None],
    baselines=[],
    fig=fig,
    group3=None,
    group4=None,
    colors=["k", "k"],
    errorbar=True,
    markersize=7,
    linewidth=3,
    ylim=[0, 200],
    ylabel="Average squared error of log depth",
    plot_x=1.4,
    plot_y=0,
    plot_width=0.3,
    plot_height=1,
    fontsize_dict={"title": 15, "label": 10, "tick": 10, "legend": 5},
)

# Save through the explicit figure handle rather than pyplot's implicit
# "current figure". Join the paths so SAVE_ROOT's trailing slash does not
# produce a double slash.
fig.savefig(os.path.join(SAVE_ROOT, "decoder.pdf"), bbox_inches="tight", transparent=True)
fig.savefig(
    os.path.join(SAVE_ROOT, "decoder.png"), bbox_inches="tight", transparent=True, dpi=300
)

In [None]:
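# TODO (ideas, not yet implemented):
# - try to classify depth from the average dF/F (the original note was cut off
#   after "when"; the intended condition was not recorded)
# - try decoding from OF only (OF presumably = optic flow; confirm)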

In [None]:
# Inspect the 8-depth sessions where open-loop decoding error is lower than closed-loop error.
openloop_better = results_8["error_openloop"] < results_8["error_closedloop"]
results_8.loc[openloop_better]

### Supplementary figures

In [None]:
# Supplementary panel: closed-loop vs open-loop decoding accuracy for the
# results_8_matched mice only, with the 1/8 chance baseline.
# NOTE(review): unlike the main decoder figure, this figure is never saved to SAVE_ROOT.
fig = plt.figure()
depth_decoder.dot_plot(
    fig=fig,
    group1=results_8_matched["accuracy_closedloop"],
    group2=results_8_matched["accuracy_openloop"],
    group3=None,
    group4=None,
    labels=["8 depths matched"],
    baselines=[1 / 8],
    colors=["g"],
    markersize=8,
    linewidth=3,
    plot_x=0,
    plot_y=1,
    plot_width=0.3,
    plot_height=1,
    fontsize_dict=dict(title=15, label=10, tick=10, legend=5),
)