In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
import pandas as pd
import seaborn as sns
from optuna import Study
from datetime import datetime
from src.utils import Run, Sweep, set_directory, INFERENCE_LABELS
from src.analysis.inference import *

# Test errors


In [None]:
def add_dropout_amount(x: pd.Series):
    """Append the dropout rate to the label of plain-SGD runs.

    Written to be applied row-wise (``DataFrame.apply(..., axis=1)``) to a
    frame that has a ``run`` and a ``label`` entry per row.

    Args:
        x: One row. ``x.run.cfg`` must expose ``inference._target_``,
            ``inference.get(...)`` and ``model.dropout``.

    Returns:
        ``x`` itself, untouched, when the inference target does not contain
        ``"SGD"`` or when the inference config's ``use_map`` is truthy.
        Otherwise a copy of ``x`` whose ``label`` is
        ``"SGD (dropout=<model.dropout>)"``.
    """
    inference_cfg = x.run.cfg.inference
    is_plain_sgd = "SGD" in inference_cfg._target_ and not inference_cfg.get(
        "use_map", False
    )
    if not is_plain_sgd:
        return x

    # Work on a copy: the original aliased the row (`y = x`) and overwrote its
    # label, i.e. it mutated the object handed in by `DataFrame.apply`, which
    # pandas documents as unsupported.
    y = x.copy()
    # Item assignment (rather than `y.label = ...`) always writes the Series
    # entry instead of possibly creating a plain Python attribute.
    y["label"] = f"SGD (dropout={x.run.cfg.model.dropout})"
    return y

def _run_record(run_dir):
    """Build one table row for the run stored in ``run_dir``.

    The names of the two directories above ``run_dir`` are parsed together as
    ``%Y-%m-%d/%H-%M-%S`` to obtain the multirun start time; the name of
    ``run_dir`` itself is converted to the integer run index.
    """
    run = Run(run_dir)
    started = datetime.strptime(
        f"{run_dir.parents[1].stem}/{run_dir.parents[0].stem}",
        r"%Y-%m-%d/%H-%M-%S",
    )
    return {"run": run, "multirun_start": started, "index": int(run_dir.stem)}


# One row per run directory matching <results>/<*>/<*>/<0|1|2>.
runs = pd.DataFrame(
    [
        _run_record(run_dir)
        for run_dir in Path("../experiment_results/cifar10_densenet/").glob("*/*/[012]")
    ]
)

In [None]:
def _inference_label(run):
    """Return the ``inference_label`` attribute of ``run``."""
    return run.inference_label


def _learning_rate(run):
    """Return the sampler's ``lr`` when the inference config has a sampler, else the inference ``lr``."""
    if "sampler" in run.cfg.inference:
        return run.cfg.inference.sampler.lr
    return run.cfg.inference.lr


def _final_val_err(run):
    """Return the last entry of the run's "err/val" scalar."""
    return run.get_scalar("err/val").iloc[-1]


latest_runs = (
    runs.assign(
        label=lambda df: df["run"].map(_inference_label),
        lr=lambda df: df["run"].map(_learning_rate),
    )
    # Row-wise relabelling of the plain-SGD runs.
    .apply(add_dropout_amount, axis=1)
    # Sort by (label, multirun_start, index), then leave only multirun_start
    # in the index so that `keep="last"` retains, for every (label, index)
    # pair, the row that sorts last, i.e. the most recent multirun.
    .set_index(["label", "multirun_start", "index"])
    .sort_index()
    .reset_index(level=["label", "index"])
    .drop_duplicates(["label", "index"], keep="last")
    .set_index(["index"], append=True)
    # Validation error is only read for the rows that survived deduplication.
    .assign(val_err=lambda df: df["run"].map(_final_val_err))
)
latest_runs

In [None]:
# Per label, keep the row whose validation error is the smallest.
best_runs = latest_runs.loc[lambda df: df.groupby("label")["val_err"].idxmin()]

# Displayed only (the result is not stored): the selection with each run's
# "err/test" scalar attached as `test_err`.
best_runs.assign(
    test_err=lambda df: df["run"].map(
        lambda run: run.get_scalar("err/test").item()
    )
)

In [None]:
# Every selected run together with its label.
all_runs = best_runs["run"].tolist()
all_labels = best_runs["label"].tolist()

# The subset of selected runs whose label mentions "SGHMC".
mcmc_runs = best_runs.loc[best_runs["label"].str.contains("SGHMC"), "run"].tolist()
mcmc_labels = best_runs.loc[best_runs["label"].str.contains("SGHMC"), "label"].tolist()

In [None]:

# Validation-error plot of the selected runs, with the upper y-limit set to
# 0.32, saved as a thesis figure.
val_err_plot = plot_val_err(all_runs, all_labels).set(ylim=(None, 0.32))
val_err_plot.savefig("../thesis/Figures/cifar-densenet-final-runs-val.pdf")

# Test-error table of the same runs, written out as LaTeX.
test_err_table = get_test_err_table(all_runs, all_labels)
test_err_table.to_latex(
    "../thesis/Tables/cifar-densenet-test-err.tex",
    escape=False,
    index=False,
    column_format="lc",
)


# Calibration


In [None]:
# Calibration plot of the selected runs, saved as a thesis figure.
plot_calibration(all_runs, all_labels, legend_cols=4)
plt.savefig("../thesis/Figures/cifar10-densenet-calibration.pdf")

# ECE table of the same runs, written out as LaTeX.
ece_table = get_ece_table(all_runs, all_labels)
ece_table.to_latex(
    "../thesis/Tables/cifar-densenet-ece.tex",
    index=False,
    escape=False,
    column_format="lc",
)

## Checking SGHMC assumptions

In [None]:
# Temperature plot for the SGHMC runs; the bottom margin is adjusted before
# the figure is saved for the thesis.
plot_temperatures(mcmc_runs, mcmc_labels)
plt.subplots_adjust(bottom=0.13)
plt.savefig("../thesis/Figures/cifar-densenet-temperatures.pdf")

# Table produced by `get_temp_ci_table` for the same runs, written as LaTeX.
temp_ci_table = get_temp_ci_table(mcmc_runs, mcmc_labels)
temp_ci_table.to_latex(
    "../thesis/Tables/cifar-densenet-temperatures.tex",
    escape=False,
    index=False,
    column_format="lc",
)
