# Metrics


## Setup


In [None]:
import json
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from boilerdata.axes_enum import AxesEnum as A  # noqa: N814
from boilerdata.models.project import Project
from boilerdata.stages.common import per_trial
from boilerdata.stages.modelfun import model_with_uncertainty
from boilerdata.stages.notebooks.common import set_format
from boilerdata.stages.notebooks.metrics import add_units, plot_new_fits, tex_wrap

### Data


In [None]:
# Project configuration drives everything below: column metadata, the free
# parameters and their error columns, and the location of the pipeline results.
proj = Project.get_project()
meta = [axis.name for axis in proj.axes.meta]
errors = proj.params.free_errors
fits = proj.params.free_params

# Results are indexed by (trial, run) and both index levels are parsed as dates.
# Columns listed in the project's axes are read with their declared dtypes.
index = [A.trial, A.run]
df_in = pd.read_csv(
    proj.dirs.file_results,
    index_col=index,
    parse_dates=index,
    dtype={axis.name: axis.dtype for axis in proj.axes.cols},
)

# (lower, upper) plot bounds for selected columns. The error plots widen each axis
# so that it covers at least this range.
limits_in = {
    A.T_5: (50, 300),
    A.T_s_err: (0, 10),
    A.q_s_err: (0, 5),
    A.k_err: (0, 10),
    A.h_a_err: (0, 40),
    A.h_w_err: (0, 40),
}

### Notebook

In [None]:
# Apply the shared notebook formatting from `boilerdata.stages.notebooks.common`.
set_format()
# `display()` is called with no objects, so it should render nothing and return None.
# Cells later in the notebook end with a bare `_` (this None) so they produce no
# implicit output.
# NOTE(review): this relies on IPython not overwriting a user-assigned `_` with cell
# results; confirm if the IPython version changes.
_ = display()
"""Suppress implicit outputs. Unlike `;`, doesn't get formatted out by `black`."""
_

### Plotting


In [None]:
"""This warning fires unnecessarily when Seaborn or Pandas plots are placed in existing
axes. This warning can't be caught in context of `warnings.catch_warnings()` because
it fires *after* a cell finishes executing. So we have to disable this globally."""
# Installed globally (not in a `catch_warnings` block); see the note just above.
warnings.filterwarnings(
    "ignore",
    message="This figure includes Axes that are not compatible with",
    category=UserWarning,
)

# The Seaborn theme is applied first and the project's Matplotlib style sheet second,
# so the style sheet's settings win wherever the two overlap.
sns.set_theme(
    context="notebook", style="whitegrid", palette="bright", font="sans-serif"
)
plt.style.use(proj.dirs.mpl_base)

## Metrics


In [None]:
# Fit failures show up as non-finite errors (inf or NaN); count them on the first
# error column.
first_error = df_in[errors[0]]
failed = first_error.isin([np.inf, -np.inf]) | first_error.isna()

# Summarise the successful fits only. Otherwise a single failure drives `max` to
# `inf`, which adds nothing beyond the failure rate, and which `json.dumps` can only
# emit as the non-standard token `Infinity`, making the metrics file invalid JSON.
finite_errors = df_in[errors].replace([np.inf, -np.inf], np.nan)
summary = (
    finite_errors.agg(["median", "max"]).rename(axis="index", mapper={"median": "med"})
).to_dict()

# Statistics that are still undefined (e.g. every fit failed) and the failure rate of
# an empty table are written as JSON `null` instead of `NaN`.
metrics = {
    error: {stat: None if pd.isna(value) else value for stat, value in stats.items()}
    for error, stats in summary.items()
}
metrics["fit_failure_rate"] = None if failed.empty else float(failed.mean())

# `allow_nan=False` guarantees strict JSON: it raises rather than writing NaN/Infinity.
Path(proj.dirs.file_pipeline_metrics).write_text(
    json.dumps(metrics, allow_nan=False), encoding="utf-8"
)

_

## Plots


### Error and temperature

In [None]:
if proj.params.do_plot:
    cols = [A.joint, A.T_5, *errors]
    df, col_to_unitcol = add_units(df_in, proj)
    df = df[[col_to_unitcol[col] for col in cols]]
    df, unitcol_to_texunitcol = tex_wrap(df)
    # Compose the two renames key-by-key: raw column -> column with units -> TeX label.
    # Zipping the keys of one mapping with the values of the other only lines up when
    # both share an order, but `col_to_unitcol` follows `add_units` while `df` was
    # just reordered to `cols`.
    # NOTE(review): assumes `unitcol_to_texunitcol` is keyed by the unit-column names,
    # as its name says.
    c = {col: unitcol_to_texunitcol[col_to_unitcol[col]] for col in cols}

    # Marker for each joint type; the dict order also serves as the hue order.
    joints = dict(
        paste=".",
        epoxy="^",
        solder="s",
        none="D",
    )
    # Configured axis limits re-keyed by plotted label; unplotted columns are dropped.
    limits = {c[k]: v for k, v in limits_in.items() if k in c}

    # One figure per error column; `strict=True` requires exactly one path for each.
    for error, path in zip(
        errors,
        [
            proj.dirs.plot_error_T_s,
            proj.dirs.plot_error_q_s,
            proj.dirs.plot_error_h_a,
        ],
        strict=True,
    ):
        jg = sns.JointGrid()
        # Keep only the y-marginal, which holds a histogram of the error values.
        jg.ax_marg_x.remove()

        # Both panels plot the same data with the error on the y-axis.
        common = dict(
            data=df,
            y=c[error],
        )
        sns.scatterplot(
            ax=jg.ax_joint,
            **common,
            x=c[A.T_5],
            hue=c[A.joint],
            style=c[A.joint],
            markers=joints,  # type: ignore  # seaborn
            edgecolor="gray",
            hue_order=joints.keys(),
            alpha=0.9,
        )
        sns.histplot(
            ax=jg.ax_marg_y,
            **common,
            stat="count",
            bins=16,  # type: ignore  # seaborn
            color="gray",
        )

        # Widen (never shrink) each axis so it spans both the data range and the
        # configured limits.
        xlo, xhi = zip(jg.ax_joint.get_xlim(), limits[c[A.T_5]], strict=True)
        jg.ax_joint.set_xlim((min(xlo), max(xhi)))

        ylo, yhi = zip(jg.ax_joint.get_ylim(), limits[c[error]], strict=True)
        jg.ax_joint.set_ylim((min(ylo), max(yhi)))  # type: ignore  # seaborn

        jg.figure.savefig(path, dpi=300)

### New model fits


In [None]:
# Only plot when plotting is enabled for this run. `per_trial` presumably applies
# `plot_new_fits` to each trial's rows of `df_in` (confirm in `boilerdata.stages.common`).
plot_enabled = proj.params.do_plot
if plot_enabled:
    per_trial(df_in, plot_new_fits, proj, model_with_uncertainty)