In [None]:
%reload_ext autoreload
%autoreload 2
import os
from pathlib import Path

import numpy as np
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42 # for pdfs
import matplotlib.pyplot as plt
import flexiznam as flz
from v1_depth_analysis.v1_manuscript_2023 import get_session_list, depth_decoder

In [None]:
VERSION = 5
# Figures for this manuscript version are written to <figure root>/ver{VERSION}.
# The root defaults to the shared lab directory; set the environment variable
# V1_MANUSCRIPT_FIGURE_ROOT to write elsewhere (e.g. when running off-cluster).
FIGURE_ROOT = Path(
    os.environ.get(
        "V1_MANUSCRIPT_FIGURE_ROOT",
        "/camp/lab/znamenskiyp/home/shared/presentations/v1_manuscript_2023",
    )
)
SAVE_ROOT = FIGURE_ROOT / f"ver{VERSION}"
# Idempotent: safe to re-run this cell.
SAVE_ROOT.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the decoder outputs for every session of the project returned by
# get_sessions(closedloop_only=False, openloop_only=True), concatenated into
# a single table.
project = "hey2_3d-vision_foodres_20220101"
flexilims_session = flz.get_flexilims_session(project)
decoder_session_list = get_session_list.get_sessions(
    flexilims_session,
    closedloop_only=False,
    openloop_only=True,
)
decoder_results = depth_decoder.concatenate_all_decoder_results(
    flexilims_session,
    session_list=decoder_session_list,
    filename="decoder_results.pickle",
)

In [None]:
# Decoder outputs restricted to the two mice used for the "8 depths matched"
# supplementary comparison.
# NOTE(review): this uses get_all_sessions(project=...) rather than
# get_sessions(flexilims_session, ...) as in the cell above — confirm both
# helpers apply the same session selection.
matched_8_depth_mice = ["PZAH10.2d", "PZAH10.2f"]
matched_8_depth_sessions = get_session_list.get_all_sessions(
    project=project,
    mouse_list=matched_8_depth_mice,
    closedloop_only=False,
    openloop_only=True,
)
results_8_matched = depth_decoder.concatenate_all_decoder_results(
    flexilims_session,
    session_list=matched_8_depth_sessions,
    filename="decoder_results.pickle",
)

In [None]:
fontsize_dict = {"title": 5, "label": 7, "tick": 5, "legend": 5}

# Main decoder figure, 16 x 10 cm (matplotlib figsize is in inches).
fig = plt.figure(figsize=(16 / 2.54, 10 / 2.54))

# Left column: one averaged confusion-matrix panel per distinct value of
# "ndepths", stacked vertically (each next panel 0.45 figure-heights lower).
for i, ndepths in enumerate(decoder_results["ndepths"].unique()):
    ndepths_results = decoder_results[decoder_results["ndepths"] == ndepths]
    conmat_mean = depth_decoder.calculate_average_confusion_matrix(ndepths_results)
    depth_decoder.plot_closed_open_conmat(
        conmat_mean,
        normalize=True,
        fig=fig,
        plot_x=0.05,
        plot_y=0.55 - i * 0.45,
        plot_width=0.38,
        plot_height=0.5,
        fontsize_dict=fontsize_dict,
    )

# Styling shared by the accuracy and error summary panels.
summary_panel_kwargs = dict(
    markersize=3,
    colors=["b", "g"],
    linewidth=0.5,
    xlabel=["Closed loop", "Open loop"],
    fontsize_dict=fontsize_dict,
)

# Middle panel: decoding accuracy, closed vs open loop (drawn into the axes
# just added, which becomes the current axes).
fig.add_axes([0.52, 0.2, 0.15, 0.65])
depth_decoder.decoder_accuracy(
    decoder_results,
    ylabel="Decoding accuracy",
    **summary_panel_kwargs,
)

# NOTE(review): the return value is discarded, so this presumably adds the
# error values to `decoder_results` in place before the error panel below
# reads them — confirm in depth_decoder.
depth_decoder.calculate_error_all_sessions(decoder_results)

# Right panel: decoding error, closed vs open loop (mode="error").
fig.add_axes([0.75, 0.2, 0.15, 0.65])
depth_decoder.decoder_accuracy(
    decoder_results,
    ylabel="Decoding error\n$|log_2$(predicted depth / true depth)$|$",
    mode="error",
    **summary_panel_kwargs,
)
# Relabels the current (error) axes.
# NOTE(review): ticks sit at 1, 2, 3 but are labelled 2, 4, 8 (i.e. 2**tick),
# while the y label names the |log2| quantity itself — confirm intended units.
plt.yticks([1, 2, 3], [2, 4, 8])

for figure_file in ("decoder.pdf", "decoder.png"):
    plt.savefig(SAVE_ROOT / figure_file, dpi=300)

In [None]:
# TODO: try to classify from the average dF/F
# TODO: try to decode from OF only

### Supplementary figures

In [None]:
# Supplementary panel: closed- vs open-loop decoding accuracy for the
# "8 depths matched" sessions; the 1/8 baseline is presumably chance level
# for 8 depth classes.
# NOTE(review): unlike the main decoder figure, this one is not saved to SAVE_ROOT.
supp_fontsize_dict = {"title": 15, "label": 10, "tick": 10, "legend": 5}
fig = plt.figure()
depth_decoder.dot_plot(
    fig=fig,
    group1=results_8_matched["accuracy_closedloop"],
    group2=results_8_matched["accuracy_openloop"],
    group3=None,
    group4=None,
    labels=["8 depths matched"],
    baselines=[1 / 8],
    colors=["g"],
    markersize=8,
    linewidth=3,
    plot_x=0,
    plot_y=1,
    plot_width=0.3,
    plot_height=1,
    fontsize_dict=supp_fontsize_dict,
)