In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import pandas as pd
import seaborn as sns
from optuna import Study
from src.utils import Run, Sweep, set_directory
from src.analysis.colors import get_color, get_colors
import matplotlib.pyplot as plt
from src.analysis.inference import *

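# Sweeps

Load every MNIST Optuna hyperparameter sweep (except the `corr` studies), compare the validation-error curves of each study's best trial, and export the top-10 hyperparameter tables for the thesis.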

In [None]:
# Load every MNIST Optuna study (file stem starting with "mnist"), skipping
# the "corr" studies. Paths are resolved inside set_directory("..").
# Path.glob yields files in arbitrary, filesystem-dependent order; sorting
# makes the dict order, and therefore the study order in every figure and
# table below, reproducible across machines and runs.
sweeps = {}
with set_directory(".."):
    optuna_storages = sorted(Path("optuna_storages/").glob("mnist*"))
    for storage in optuna_storages:
        if "corr" in storage.stem:
            continue
        # The study name passed to optuna is the storage file stem.
        sweeps[storage.stem] = Sweep(Study(storage.stem, storage=f"sqlite:///{storage}"))

In [None]:
# Inspect the loaded sweeps: study name (storage file stem) -> Sweep.
sweeps

In [None]:
def _concat_by_study(frames_by_study, index_order):
    """Stack per-study frames into one frame with an extra "study" index level.

    ``frames_by_study`` yields ``(study_name, frame)`` pairs. Each frame gets a
    ``study`` column that is appended to its index, and the levels of the
    concatenated result are then reordered to ``index_order``.
    """
    return pd.concat(
        frame.assign(study=study_name).set_index("study", append=True)
        for study_name, frame in frames_by_study
    ).reorder_levels(index_order)


# Per-step losses indexed by (study, trial, step).
combined_loss_data = _concat_by_study(
    ((name, sweep.loss()) for name, sweep in sweeps.items()),
    ["study", "trial", "step"],
)
# One summary row per trial, indexed by (study, trial).
combined_summaries_data = _concat_by_study(
    ((name, sweep.summary()) for name, sweep in sweeps.items()),
    ["study", "trial"],
)


In [None]:
# For every study, the (study, trial) index label of the trial with the
# lowest "err/val" in the sweep summaries; shown as a one-column frame.
best_runs = combined_summaries_data["err/val"].groupby(level="study").idxmin()
best_runs.to_frame()


In [None]:
# Validation-error curves of the best trial of every study (best = lowest
# "err/val" in the sweep summaries, see best_runs).
# Rows are selected with a boolean mask on the (study, trial) part of the
# index instead of an unstack -> .loc -> stack round trip. This avoids
# reshaping and NaN-padding the whole loss table, avoids the legacy
# DataFrame.stack implementation (FutureWarning since pandas 2.1), and does
# not raise if a study's best label is missing. sort_index gives the same
# study-then-step row order, and hence the same hue order, as .loc[best_runs].
(
    combined_loss_data
    .loc[lambda df: df.index.droplevel("step").isin(best_runs.tolist())]
    .sort_index(level="study")
    .reset_index()
    .pipe(
        (sns.relplot, "data"),
        x="step",
        y="err/val",
        hue="study",
        kind="line",
        aspect=1.6,
    )
    .set(ylim=(0.01, 0.03))
    .savefig("../thesis/Figures/mnist-best-runs-val-curves.pdf")
)


In [None]:
import math


def rename_cols(x):
    r"""Map a sweep-summary column name to a LaTeX table header.

    ``"err/val"`` becomes ``"Val. error"``. Any other name is reduced to its
    last dot-separated component, wrapped in ``\texttt{...}``, and has its
    underscores escaped as ``\_`` so LaTeX does not read them as subscripts.

    Args:
        x: original column name.

    Returns:
        The LaTeX-ready header string.
    """
    if x == "err/val":
        return "Val. error"
    # Raw string: "\_" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    return f"\\texttt{{{x.split('.')[-1]}}}".replace("_", r"\_")


def format_sctf(float_number, significant_digits=3):
    r"""Format a number in LaTeX scientific notation, e.g. ``$2.35\times10^{-4}$``.

    The mantissa is *rounded* to ``significant_digits`` significant digits via
    Python's ``e`` format spec, which also normalises it to [1, 10). Zero and
    negative values are handled as well. Non-finite values (NaN, inf) are
    returned as ``str(float_number)`` because they have no exponent.

    Args:
        float_number: value to format.
        significant_digits: number of significant digits in the mantissa
            (default 3, e.g. ``2.35``). Must be at least 1.

    Returns:
        A math-mode LaTeX string such as ``$2.35\times10^{-4}$``.

    Raises:
        ValueError: if ``significant_digits`` is smaller than 1.
    """
    if significant_digits < 1:
        raise ValueError("significant_digits must be >= 1")
    if not math.isfinite(float_number):
        return str(float_number)
    # Rounding via the format spec avoids the truncation of slicing
    # str(mantissa) (2.3456e-4 -> "2.34") and the math-domain error that
    # math.log10 raises for 0 or negative input.
    mantissa, exponent = f"{float_number:.{significant_digits - 1}e}".split("e")
    return f"${mantissa}\\times10^{{{int(exponent)}}}$"


def to_latex(data: pd.DataFrame, key=None):
    r"""Write ``data`` as a LaTeX table to ``../thesis/Tables/{key}-hparams.tex``.

    Cell contents and headers are written unescaped (headers are expected to
    be LaTeX already). The ``\texttt{lr}`` column is rendered with
    ``format_sctf``. The column layout is a left-aligned index column followed
    by one fixed-width ``p{2.3cm}`` column per data column.

    Args:
        data: table to export.
        key: sweep name used in the output filename. If omitted, falls back
            to the notebook-level ``key`` variable, which this function used
            to read implicitly; prefer passing it explicitly.

    Returns:
        Whatever ``DataFrame.to_latex`` returns when given a path.
    """
    if key is None:
        # Backward compatibility with calls that relied on the leaked loop
        # variable `key` from the notebook namespace.
        key = globals()["key"]
    n_cols = len(data.columns)
    return data.to_latex(
        f"../thesis/Tables/{key}-hparams.tex",
        escape=False,
        formatters={r"\texttt{lr}": format_sctf},
        column_format="l" + n_cols * r"p{2.3cm}",
    )


# Export the 10 best trials of every sweep as a LaTeX table.
for key, sweep in sweeps.items():
    (
        sweep
        .summary()
        .drop(columns="datetime_start")
        # Sort explicitly so head(10) really is the 10 lowest validation
        # errors regardless of summary()'s row order (stable: ties keep order).
        .sort_values("err/val", kind="stable")
        .head(10)
        .rename(columns=rename_cols)
        .pipe(to_latex, key=key)
    )


In [None]:
# Print a LaTeX table float for every exported hyperparameter table, ready to
# paste into the thesis. The \ifdim test only shrinks tables wider than
# \columnwidth; narrower ones keep their natural width.
# NOTE(review): "INFERENCE" in the caption looks like a placeholder to be
# replaced by hand per sweep — confirm before pasting.
# `\\end` (not `\end`): "\e" is an invalid escape sequence in a non-raw
# string; the printed text is unchanged.
for key in sweeps:
    print(
        f"""
\\begin{{table}}[htbp]
    \\centering
    \\resizebox{{
        \\ifdim\\width>\\columnwidth
        \\columnwidth
      \\else
        \\width
      \\fi
    }}{{!}}{{\\small
    \\input{{Tables/{key}-hparams}}
    }}
    \\caption{{Top 10 hyperparameters for INFERENCE according to optuna sweep.}}
    \\label{{tab:{key}-hparams}}
\\end{{table}}
"""
    )


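# Test errors

Load the final runs (the MCMC runs and the other methods, each from its own results directory), then plot their validation error and tabulate their test error.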


In [None]:
# Final experiment runs: sub-directories "0" and "1" of the MCMC result
# directory, and "0"-"2" of the directory holding the other methods.
# Glob results are sorted so run order (and thus row/legend order in the
# outputs below) does not depend on filesystem enumeration order.
mcmc_dir = Path("../experiment_results/mnist/2021-12-20/15-52-53")
mcmc_runs = [Run(run_dir) for run_dir in sorted(mcmc_dir.glob("[01]/"))]

other_dir = Path("../experiment_results/mnist/2021-12-21/07-58-11")
other_runs = [Run(run_dir) for run_dir in sorted(other_dir.glob("[012]/"))]

# A wrong path would otherwise silently yield empty figures/tables.
assert mcmc_runs, f"no MCMC runs found in {mcmc_dir}"
assert other_runs, f"no runs found in {other_dir}"

all_runs = other_runs + mcmc_runs


In [None]:
# Validation-error curves of all final runs, y-axis capped at 0.03.
plot_val_err(all_runs).set(ylim=(None, 0.03)).savefig(
    "../thesis/Figures/mnist-final-runs-val.pdf"
)

# Test-error table for the thesis: raw LaTeX cells, no index column,
# layout "lc" (one left-aligned and one centred column).
get_test_err_table(all_runs).to_latex(
    "../thesis/Tables/mnist-test-err.tex",
    column_format="lc",
    index=False,
    escape=False,
)



## Downsampling MCMC samples



In [None]:
# Downsampling plot for the MCMC runs only, y-axis capped at 0.022.
plot_mcmc_downsampling(mcmc_runs).set(
    ylim=(None, 0.022)
).savefig("../thesis/Figures/mnist-downsampling.pdf")


# Calibration


In [None]:
# Calibration plot of all final runs. plt.savefig saves the *current*
# figure, presumably the one plot_calibration just drew.
plot_calibration(all_runs)
plt.savefig("../thesis/Figures/mnist-calibration.pdf")

# ECE table: raw LaTeX cells, no index column, "lc" column layout.
get_ece_table(all_runs).to_latex(
    "../thesis/Tables/mnist-ece.tex",
    column_format="lc",
    escape=False,
    index=False,
)


## Checking SGHMC assumptions

In [None]:
# Temperature plot for the MCMC runs; plt.savefig saves the current figure,
# presumably the one plot_temperatures just drew.
plot_temperatures(mcmc_runs)
plt.savefig("../thesis/Figures/mnist-temperatures.pdf")

# Companion table from get_temp_ci_table: raw LaTeX cells, no index column.
get_temp_ci_table(mcmc_runs).to_latex(
    "../thesis/Tables/mnist-temperatures.tex",
    column_format="lc",
    index=False,
    escape=False,
)
