In [None]:
# Exploratory cells (e.g. the T5 distribution plot below) run only when this is True.
DEBUG = True
"""Toggle this to run exploratory cells in the notebook."""
# NOTE(review): `display()` is called with no objects, so it appears to render nothing;
# likely a leftover. Confirm before removing.
display()

# Metrics

## Setup

### Imports

In [None]:
# # This check is flaky in a notebook
# pyright: reportUnnecessaryTypeIgnoreComment=none
from contextlib import contextmanager
import json
from pathlib import Path
from shutil import copy
from typing import Any
import warnings

from IPython.display import Markdown
import janitor  # pyright: ignore [reportUnusedImport]  # adds methods to pd.DataFrame
from matplotlib import pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import seaborn as sns
from uncertainties import ufloat

from boilerdata.axes_enum import AxesEnum as A  # noqa: N814
from boilerdata.models.enums import Joint
from boilerdata.models.project import Project
from boilerdata.stages.common import get_tcs, get_trial, per_run
from boilerdata.stages.notebooks.common import FLOAT_SPEC  # type: ignore  # magic
from boilerdata.stages.notebooks.common import (
    add_units,
    chdir_to_nearest_git_root_and_get_project,
    tex_wrap,
)

### Notebook and plot formatting

In [None]:
# Set IPython's float display format from the shared `FLOAT_SPEC`; IPython expands
# `$FLOAT_SPEC`, so the magic receives `%<spec>` as its format string.
%precision %$FLOAT_SPEC

In [None]:
# Presumably (per the helper's name) moves to the repository root so relative project
# paths resolve, then returns the project (`proj`) whose `axes`, `dirs`, and `params`
# are used throughout this notebook.
proj = chdir_to_nearest_git_root_and_get_project()

### Parameters and inputs

#### Data

In [None]:
meta = [axis.name for axis in proj.axes.meta]
errors = proj.params.free_errors
fits = proj.params.free_params

# Trial and run form the row MultiIndex; both levels are parsed as dates.
index = [A.trial, A.run]
column_dtypes = {axis.name: axis.dtype for axis in proj.axes.cols}
df_in = pd.read_csv(
    proj.dirs.file_results,
    index_col=index,
    parse_dates=index,
    dtype=column_dtypes,
)

#### Plotting

In [None]:
# This warning fires unnecessarily when Seaborn or Pandas plots are placed in existing
# axes. It can't be caught in a `warnings.catch_warnings()` context because it fires
# *after* a cell finishes executing, so it has to be disabled globally.
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    message="This figure includes Axes that are not compatible with",
)

# Seaborn theme first, then the project's style file applied over it.
sns.set_theme(context="notebook", style="whitegrid", palette="deep", font="sans-serif")
plt.style.use(style=proj.dirs.file_style)

# Sequential colormaps used by the heatmaps below.
cmap_blues, cmap_reds = mpl.colormaps["Blues"], mpl.colormaps["Reds"]  # type: ignore  # matplotlib
display(cmap_blues, cmap_reds)

### Functions

In [None]:
@contextmanager
def manual_subplot_spacing():
    """Temporarily disable `figure.autolayout` so subplot spacing can be set by hand.

    Figures created inside the `with` block are not automatically laid out, so manual
    spacing (e.g. `fig.subplots_adjust(...)`) is not overridden. `mpl.rc_context`
    restores the previous rcParams on exit, including when the block raises, so no
    extra `try`/`finally` is needed here.
    """
    with mpl.rc_context({"figure.autolayout": False}):
        yield


# Shorthand for `pd.IndexSlice`, for selecting on MultiIndex levels inside `.loc[...]`.
# NOTE(review): not referenced elsewhere in this notebook.
idxs = pd.IndexSlice
"""Use to slice pd.MultiIndex indices"""


def display_named(*args: tuple[Any, str]):
    """Show each object beneath a level-5 Markdown heading carrying its name.

    Args:
        *args: `(obj, name)` pairs, rendered in the order given.
    """
    for pair in args:
        obj, title = pair
        heading = Markdown(f"##### {title}")
        display(heading)
        display(obj)
        # Blank line separating consecutive entries.
        print()


def get_default_aggs(columns: list[str]) -> dict[str, pd.NamedAgg]:
    """Select the project's default aggregations for the given columns.

    Reads `proj.axes.aggs` from the notebook-level `proj` and keeps only entries whose
    column name appears in `columns`, preserving the order of `proj.axes.aggs`.
    """
    selected = {}
    for name, named_agg in proj.axes.aggs.items():
        if name in columns:
            selected[name] = named_agg
    return selected


def get_params_mapping_with_uncertainties(
    grp: pd.Series, proj: Project
) -> dict[str, Any]:
    """Map parameter names to `ufloat`s pairing each fitted value with its error.

    Args:
        grp: One run's results as a Series indexed by column name (as produced by
            `grp.squeeze()` in `plot_new_fits`). A DataFrame does not work here:
            iterating `grp[params]` would then yield column labels, not values.
        proj: Project whose `params` lists the model parameters and their errors.

    Returns:
        `{name: ufloat(value, error, name)}`, keyed by the leading entries of
        `proj.params.params_and_errors`.
    """
    model_params_and_errors = proj.params.params_and_errors
    # Reason: pydantic: use_enum_values
    params: list[str] = proj.params.model_params  # type: ignore
    param_errors: list[str] = proj.params.model_errors
    # `zip` stops at its shortest input, so only the first len(params) names are used.
    # NOTE(review): this assumes `params_and_errors` lists the model params first, in
    # the same order as `model_params`. Confirm against `Project`.
    return {
        name: ufloat(value, error, name)
        for value, error, name in zip(
            grp[params], grp[param_errors], model_params_and_errors
        )
    }


def get_params_mapping(grp: pd.Series, params: list[Any]) -> dict[str, Any]:
    """Map each parameter name to its value in a single run's results.

    Args:
        grp: One run's results as a Series indexed by column name (as produced by
            `grp.squeeze()` in `plot_new_fits`). A DataFrame does not work here:
            iterating `grp[params]` would then yield column labels, not values.
        params: Names of the parameters to extract, in the desired key order.

    Returns:
        `{name: value}` for each name in `params`.
    """
    return dict(zip(params, grp[params]))


def model_with_error(model, x, u_params):
    """Evaluate `model` at exact x positions and return the result with its spread.

    Args:
        model: Called as `model(u_x, **u_params)` with `u_x` a list of zero-uncertainty
            `ufloat`s; must return an iterable of objects with `nominal_value` and
            `std_dev`.
        x: Positions at which to evaluate the model.
        u_params: Keyword arguments for `model` (e.g. `ufloat` parameters).

    Returns:
        `(y, y_min, y_max)`, where `y` holds the nominal values and the bounds are
        `y` minus/plus each output's `std_dev`.
    """
    exact_x = [ufloat(value, 0, "x") for value in x]
    outputs = model(exact_x, **u_params)
    nominal = np.array([out.nominal_value for out in outputs])
    spread = [out.std_dev for out in outputs]
    return nominal, nominal - spread, nominal + spread

## Plots

### Error distribution by joint type

In [None]:
# Units are not attached here; `.pipe(add_units, proj)` is left disabled.
df = df_in[[A.joint, *errors]]
max_error_by_joint = df.groupby(A.joint).max()
display_named((max_error_by_joint, "Max error by joint type"))

#### Show errors by joint type

In [None]:
# Neither the selected columns nor the per-error scale factors depend on the
# aggregation, so compute them once instead of on every loop iteration.
df = df_in[[A.joint, *errors]]
error_scale = pd.Series(proj.params.error_scale)[errors]
for agg, path in dict(
    median=proj.dirs.plot_median_error_by_joint,
    max=proj.dirs.plot_max_error_by_joint,
).items():
    fig, ax = plt.subplots()
    err = df.groupby(A.joint).agg(agg)
    # Cell colors come from the scaled errors so different error columns share one
    # colormap; the annotations show the raw (unscaled) error values.
    err_scaled = err.div(error_scale)
    sns.heatmap(
        ax=ax,
        data=err_scaled.pipe(tex_wrap),
        annot=err,
        cmap=cmap_blues,
        square=True,
    )
    ax.set_title(f"{agg.title()} Error")
    ax.set_xlabel("Error")
    ax.set_ylabel("Joint Type")
    fig.savefig(path)  # type: ignore  # matplotlib

#### Show errors by temperature range

In [None]:
temp_range = "Temperature Range"
# Everything below up to the loop is independent of the aggregation, so it is set up
# once rather than rebuilt on each iteration.
columns = [A.T_5, *errors]
temperature_bin_left_edges = dict(
    Under=-np.inf,
    Low=100,
    Med=110,
    High=120,
    Over=130,
)
temperature_bin_edges = list(temperature_bin_left_edges.values()) + [np.inf]
labels = list(temperature_bin_left_edges.keys())
error_scale = pd.Series(proj.params.error_scale)[errors]
for agg, path in dict(
    median=proj.dirs.plot_median_error_by_range,
    max=proj.dirs.plot_max_error_by_range,
).items():
    fig, ax = plt.subplots()
    df = (
        df_in[columns]
        .assign(
            **{
                temp_range: lambda df: pd.cut(
                    df[A.T_5],
                    bins=temperature_bin_edges,
                    labels=labels,
                )
            }
        )
        # Grouping by a categorical: `observed=False` keeps every bin (including
        # empty ones) as before, and avoids pandas' warning about the changing default.
        .groupby(temp_range, observed=False)
        .agg(agg)
        .drop(axis="columns", labels=A.T_5)
    )
    err = df
    # Colors from scaled errors, annotations from raw errors.
    err_scaled = err.div(error_scale)

    sns.heatmap(
        ax=ax,
        data=err_scaled.pipe(tex_wrap),
        annot=err,
        cmap=cmap_reds,
        square=True,
    )
    ax.set_title(f"{agg.title()} Error")
    ax.set_xlabel("Error")
    fig.savefig(path)  # type: ignore  # matplotlib

### Determine the T5 distribution

In [None]:
cols = [A.T_5]
xmin, xmax = 100, 270
ymin, ymax = 0, 1
quantiles = (0.00, 0.20, 0.55, 0.90, 1.00)
labels = ("Under", "Low", "High", "Over")

if DEBUG:
    fig, ax = plt.subplots()
    _ = sns.kdeplot(ax=ax, data=df_in[cols], cumulative=True, legend=False)

    ax.set_title("Cumulative KDE of T5")
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    # Freeze the limits so the reference lines and patches below don't rescale them.
    ax.set_autoscale_on(False)
    _ = ax.hlines(quantiles, xmin, xmax, colors=[0.5] * 3)  # type: ignore  # matplotlib

    # Shade the outermost quantile bands (lowest and highest bins) in red.
    for band_start, band_end in (
        (quantiles[0], quantiles[1]),
        (quantiles[-1], quantiles[-2]),
    ):
        ax.add_patch(
            mpl.patches.Rectangle(  # type: ignore  # matplotlib
                xy=(xmin, band_start),
                width=xmax - xmin,
                height=band_end - band_start,
                color="red",
                alpha=0.5,
            )
        )

### Bin the data by useful T5 values

In [None]:
t5_quantile_bins = pd.qcut(df_in[A.T_5], quantiles)
# Show the interval edges implied by `quantiles`; the fixed `bins` below were
# presumably read off this output.
display(t5_quantile_bins.cat.categories)

In [None]:
# Bin edges determined from the exploration above and fixed. Five edges give four bins.
bins = [75.756, 101.476, 117.001, 161.608, 264.268]
# One label per bin, defined next to the fixed edges. Relying on `labels` from another
# cell is fragile: the temperature-range heatmap cell rebinds `labels` to five entries,
# which would make `pd.cut` fail here if cells were run out of order.
bin_labels = ("Under", "Low", "High", "Over")

cols = [A.joint, A.T_5, *errors]
temp_range = "Temperature Range"
df = (
    df_in[cols]
    .assign(
        **{
            # Bin each row by its own T5 value (the frame passed to the lambda). Rows
            # outside the edges, or in the removed "Under"/"Over" bins, become NaN and
            # are dropped below.
            temp_range: lambda df: (
                pd.cut(
                    x=df[A.T_5],
                    bins=bins,
                    labels=bin_labels,
                    right=True,
                )
            ).cat.remove_categories(["Under", "Over"]),
        },
    )
    .dropna(axis="index", subset=[temp_range])
    .drop(axis="columns", labels=A.T_5)
)
display_named(
    (df[temp_range].value_counts().to_frame(), "Counts"),
    (pd.DataFrame(index=df.columns), "Columns"),
)

In [None]:
# One boxplot per error column: its distribution by joint type, split by temperature range.
for error in errors:
    fig, ax = plt.subplots()
    sns.boxplot(data=df, x=error, y=A.joint, hue=temp_range, ax=ax)
    # Put the legend just outside the right edge of the axes.
    sns.move_legend(ax, loc="upper left", bbox_to_anchor=(1, 1))

In [None]:
# Build the scale mapping here so this cell also runs on a fresh kernel (previously it
# read `free_error_scale` before the cell below defined it). Same definition as below.
free_error_scale = {
    k: v for k, v in proj.params.error_scale.items() if k in proj.params.free_errors  # type: ignore  # pydantic: use_enum_values
}
free_error_scale

In [None]:
# Shows the last error column name left over from the boxplot loop above;
# presumably leftover debug output.
error

In [None]:
# Scale factor for each free error column.
free_error_scale = {
    k: v for k, v in proj.params.error_scale.items() if k in proj.params.free_errors  # type: ignore  # pydantic: use_enum_values
}

# Divide each free error column by its own scale. The columns are computed eagerly: a
# `lambda df: df[err].div(scale)` created inside the comprehension late-binds `err`
# and `scale`, so every column would receive the *last* error divided by the *last*
# scale.
df_scaled = df.assign(
    **{err: df[err].div(scale) for err, scale in free_error_scale.items()}
)

for error in errors:
    fig, ax = plt.subplots()
    sns.boxplot(
        ax=ax,
        data=df_scaled,
        x=error,
        y=A.joint,
        hue=temp_range,
    )
    sns.move_legend(obj=ax, loc="upper left", bbox_to_anchor=(1, 1))

# WIP: write metrics and plot model fits

In [None]:
def write_metrics(df: pd.DataFrame, proj: Project, errors, fits):
    """Summarize the model-fit errors and write them as JSON.

    Written to `proj.dirs.file_pipeline_metrics`:
    - `fit_failure_rate`: fraction of rows whose first column in `fits` is NaN.
    - `<name>_err_ratio_{median,std}`: error columns divided by the fitted columns.
    - `<name>_err_norm_{median,std}`: error columns divided by their own maximum.
    `<name>` is the error column name with its "_err" suffix removed. Any metric that
    comes out NaN is written as 0.

    Args:
        df: Results containing the `fits` and `errors` columns.
        proj: Project providing the output path.
        errors: Error column names. Assumed to be `fits` names plus an "_err" suffix,
            so the columns align after stripping it (TODO confirm).
        fits: Fitted-parameter column names.
    """
    # sourcery skip: merge-dict-assign
    failure_column = fits[0]

    def drop_err_suffix(frame: pd.DataFrame) -> pd.DataFrame:
        """Rename columns, removing a trailing "_err"."""
        return frame.rename(axis="columns", mapper=lambda name: name.removesuffix("_err"))

    # Reason: pydantic: use_enum_values
    error_frames = {
        # Division aligns on the stripped column names.
        "err_ratio": drop_err_suffix(df[errors]) / df[fits],
        "err_norm": drop_err_suffix(df[errors] / df[errors].max()),
    }

    metrics: dict[str, float] = {
        "fit_failure_rate": df[failure_column].isna().sum() / len(df)
    }
    for err_tag, err_df in error_frames.items():
        for agg in ("median", "std"):
            for name, value in err_df.agg(agg).to_dict().items():
                metrics[f"{name}_{err_tag}_{agg}"] = value

    nan_keys = [key for key, value in metrics.items() if np.isnan(value)]
    for key in nan_keys:
        metrics[key] = 0

    proj.dirs.file_pipeline_metrics.write_text(json.dumps(metrics, indent=2))

In [None]:
def plot_new_fits(grp: pd.DataFrame, proj: Project, model):
    """Plot and save the model fit for one run, but only if its trial is marked new.

    The figure shows the thermocouple measurements with error bars, the model fit with
    the band returned by `model_with_error`, and the extrapolated value `ser[A.T_s]`
    at x = 0. It is saved under `proj.dirs.new_fits`, named after the run timestamp,
    and then closed.

    Args:
        grp: Results for a single run, as handed over by `per_run`. It is squeezed to
            a Series, so it is expected to hold a single row.
        proj: Project configuration (directories and model parameter names).
        model: Model function, called as `model(x, **params)`.

    Returns:
        `grp`, unchanged, whether or not a plot was made.
    """

    trial = get_trial(grp, proj)
    if not trial.new:
        return grp

    ser = grp.squeeze()
    tcs, tc_errors = get_tcs(trial)
    x_unique = list(trial.thermocouple_pos.values())
    y_unique = ser[tcs]

    # Plot setup
    fig, ax = plt.subplots(layout="constrained")

    # The last index level is the run timestamp. ":" is swapped for "-" in the file
    # name, presumably for filesystem compatibility.
    run = ser.name[-1].isoformat()
    run_file = proj.dirs.new_fits / f"{run.replace(':', '-')}.png"

    ax.margins(0, 0)
    ax.set_title(f"{run = }")
    ax.set_xlabel("x (m)")
    ax.set_ylabel("T (C)")

    # Initial plot boundaries: an invisible line ("none" format) from x = 0 to the
    # `T_1` thermocouple position, so autoscaling covers the extrapolation region.
    x_bounds = np.array([0, trial.thermocouple_pos[A.T_1]])

    y_bounds = model(x_bounds, **get_params_mapping(ser, proj.params.model_params))
    ax.plot(
        x_bounds,
        y_bounds,
        "none",
    )

    # Measurements
    measurements_color = [0.2, 0.2, 0.2]
    ax.plot(
        x_unique,
        y_unique,
        ".",
        label="Measurements",
        color=measurements_color,
        markersize=10,
    )
    ax.errorbar(
        x=x_unique,
        y=y_unique,
        yerr=ser[tc_errors],
        fmt="none",
        color=measurements_color,
    )

    # Confidence interval: evaluate the model on a grid extending 2.5% of the x-range
    # past each side of the current x-limits.
    (xlim_min, xlim_max) = ax.get_xlim()
    pad = 0.025 * (xlim_max - xlim_min)
    x_padded = np.linspace(xlim_min - pad, xlim_max + pad)

    y_padded, y_padded_min, y_padded_max = model_with_error(
        model, x_padded, get_params_mapping_with_uncertainties(ser, proj)
    )
    ax.plot(
        x_padded,
        y_padded,
        "--",
        label="Model Fit",
    )
    # NOTE(review): the band is y ± the propagated `std_dev`. The "95% CI" label holds
    # only if the fitted `*_err` values are already 95% half-widths; confirm upstream.
    ax.fill_between(
        x=x_padded,
        y1=y_padded_min,
        y2=y_padded_max,  # pyright: ignore [reportGeneralTypeIssues]  # matplotlib
        color=[0.8, 0.8, 0.8],
        edgecolor=[1, 1, 1],
        label="95% CI",
    )

    # Extrapolation
    ax.plot(
        0,
        ser[A.T_s],
        "x",
        label="Extrapolation",
        color=[1, 0, 0],
    )

    # Finishing
    ax.legend()
    fig.savefig(
        run_file,  # pyright: ignore [reportGeneralTypeIssues]  # matplotlib
        dpi=300,
    )
    # This runs once per new run. pyplot keeps every figure alive until it is closed,
    # so release it after saving to avoid accumulating figures.
    plt.close(fig)
    return grp

In [None]:
def plot_fits(df: pd.DataFrame, proj: Project, model) -> pd.DataFrame:
    """Plot fits for new runs, then copy a representative sample of the plots.

    Only when `proj.params.do_plot` is set: `plot_new_fits` is applied to every run.
    Then the first, middle, and last files (sorted by path) in `proj.dirs.new_fits`
    are copied to `plot_new_fit_0`, `plot_new_fit_1`, and `plot_new_fit_2`.

    Returns:
        `df`, unchanged.
    """
    if not proj.params.do_plot:
        return df
    per_run(df, plot_new_fits, proj, model)
    plot_files = sorted(proj.dirs.new_fits.iterdir())
    if not plot_files:
        return df
    middle = len(plot_files) // 2
    sources = (plot_files[0], plot_files[middle], plot_files[-1])
    destinations = (
        proj.dirs.plot_new_fit_0,
        proj.dirs.plot_new_fit_1,
        proj.dirs.plot_new_fit_2,
    )
    for source, destination in zip(sources, destinations):
        copy(source, destination)
    return df

In [None]:
# NOTE(review): this import sits here rather than in the imports cell. That is possibly
# deliberate (e.g. if the module depends on the working directory set by
# `chdir_to_nearest_git_root_and_get_project()` above); confirm before moving it.
from boilerdata.stages.modelfun import model_with_uncertainty

# `also` comes from pyjanitor (imported above for its DataFrame methods). It is used here
# to run `write_metrics` and `plot_fits` on `df_in` for their side effects; the chained
# result is discarded.
_ = df_in.also(write_metrics, proj, errors, fits).also(
    plot_fits, proj, model_with_uncertainty
)