In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import pandas as pd
import seaborn as sns
from optuna import Study
from src.utils import Run, Sweep, set_directory


# Sweeps

In [None]:
# Load every MNIST Optuna study (skipping the "corr" variants), keyed by the storage
# file's stem. The storage paths are relative to the parent directory, hence set_directory("..").
studies = {}
with set_directory(".."):
    # sorted(): Path.glob yields entries in filesystem order; a fixed order keeps the
    # study order (and therefore plot hues / table order downstream) reproducible.
    optuna_storages = sorted(Path("optuna_storages/").glob("mnist*"))
    for storage in optuna_storages:
        if "corr" in storage.stem:
            continue
        # resolve(): an absolute sqlite URL does not depend on the working directory
        # that is in effect when the storage is queried later (outside this `with`).
        # NOTE(review): assumes each file holds one study named after the file stem.
        studies[storage.stem] = Study(storage.stem, storage=f"sqlite:///{storage.resolve()}")


In [None]:
# Combine the per-study loss curves and trial summaries into single frames whose
# outermost index level is the study name.
# Each Sweep is built once and reused for both tables.
sweeps = {name: Sweep(study) for name, study in studies.items()}

# `pd.concat` on a dict adds the keys as a new outer index level named by `names`;
# reorder_levels also validates that each sweep's frame carries the expected level names.
combined_loss_data = pd.concat(
    {name: sweep.loss() for name, sweep in sweeps.items()},
    names=["study"],
).reorder_levels(["study", "trial", "step"])
combined_summaries_data = pd.concat(
    {name: sweep.summary() for name, sweep in sweeps.items()},
    names=["study"],
).reorder_levels(["study", "trial"])

In [None]:
# (study, trial) index label of the trial with the lowest validation error in each study.
val_errors = combined_summaries_data["err/val"]
best_runs = val_errors.groupby(level="study").idxmin()
best_runs.to_frame()

In [None]:
# Validation-error curves of each study's best trial (selected via `best_runs`).
# Unstacking "step" turns every (study, trial) curve into one row, so `.loc[best_runs]`
# selects whole curves; stacking back restores the long format seaborn expects.
best_run_curves = (
    combined_loss_data
    .unstack(level="step")
    .loc[best_runs]
    .stack(level="step")
    .reset_index()
)
best_runs_grid = sns.relplot(
    data=best_run_curves,
    x="step",
    y="err/val",
    hue="study",
    kind="line",
    aspect=1.6,
)
best_runs_grid.set(ylim=(0.01, 0.03))  # fixed y-range used for the thesis figure
best_runs_grid.savefig("../thesis/Figures/mnist-best-runs-val-curves.pdf")


In [None]:
import math

def rename_cols(x):
    r"""Map a sweep-summary column name to its LaTeX table header.

    ``"err/val"`` becomes ``"val. error"``. Any other name is reduced to its last
    dot-separated component, wrapped in ``\texttt{...}``, with underscores escaped
    as ``\_`` so LaTeX does not treat them as subscripts.
    """
    if x == "err/val":
        return "val. error"
    # "\\_" rather than "\_": the latter is an invalid escape sequence (SyntaxWarning
    # on Python 3.12+); both produce the same two characters.
    return f"\\texttt{{{x.split('.')[-1]}}}".replace("_", "\\_")

def format_sctf(float_number, digits=2):
    r"""Format a number in LaTeX scientific notation, e.g. ``$3.00\times10^{-4}$``.

    Args:
        float_number: Value to format.
        digits: Decimals kept in the mantissa (default 2, i.e. three significant digits).

    Returns:
        A LaTeX math string. Non-finite values (nan/inf) are returned as ``str(value)``.
    """
    if not math.isfinite(float_number):
        return str(float_number)
    # The `e` format rounds the mantissa (slicing str(x / 10**exp) truncated the
    # floating-point quotient, e.g. 3e-4 -> "2.99"). It also carries rounding into the
    # exponent (9.999e-4 -> 1.00e-03) and handles 0 and negatives without log10 errors.
    mantissa, exponent = f"{float_number:.{digits}e}".split("e")
    return f"${mantissa}\\times10^{{{int(exponent)}}}$"


def to_latex(data: pd.DataFrame, study_key=None):
    r"""Write ``data`` to ``../thesis/Tables/{study_key}-hparams.tex`` as a LaTeX table.

    Cells and headers are not escaped, because the headers already contain LaTeX
    markup (see ``rename_cols``). The ``\texttt{lr}`` column is formatted with
    ``format_sctf``. The layout is a left-aligned index column followed by one
    fixed-width 2.3cm paragraph column per data column.

    Args:
        data: Table to write.
        study_key: Study name used in the output file name. If omitted, the
            notebook-global ``key`` is used (the original, implicit behaviour).
    """
    if study_key is None:
        # Backward-compatible fallback; prefer passing study_key explicitly.
        study_key = globals()["key"]
    n_cols = len(data.columns)
    return data.to_latex(
        f"../thesis/Tables/{study_key}-hparams.tex",
        escape=False,
        formatters={r"\texttt{lr}": format_sctf},
        column_format="l" + n_cols * r"p{2.3cm}",
    )


for key, study in studies.items():
    (
        Sweep(study)
        .summary()
        .drop(columns="datetime_start")
        # The table caption advertises the "top 10", so rank by validation error first.
        # A stable sort leaves an already-ranked summary unchanged.
        .sort_values("err/val", kind="stable")
        .head(10)
        .rename(columns=rename_cols)
        .pipe(to_latex, study_key=key)
    )


In [None]:
# Print a LaTeX `table` environment for each hyperparameter table written above, ready
# to paste into the thesis. The \resizebox shrinks a table only when it is wider than
# \columnwidth; otherwise it keeps its natural width.
# TODO(review): "INFERENCE" in the caption looks like a placeholder to fill in per study.
for key in studies:
    print(
f"""
\\begin{{table}}[htbp]
    \\centering
    \\resizebox{{
        \\ifdim\\width>\\columnwidth
        \\columnwidth
      \\else
        \\width
      \\fi
    }}{{!}}{{\\small
    \\input{{Tables/{key}-hparams}}
    }}
    \\caption{{Top 10 hyperparameters for INFERENCE according to optuna sweep.}}
    \\label{{tab:{key}-hparams}}
\\end{{table}}
"""
)

# Test errors


In [None]:
# Load the evaluated runs from two experiment directories: `mcmc_dir` (the MCMC runs,
# sub-directories 0-1) and `other_dir` (the remaining runs, sub-directories 0-2).
mcmc_dir = Path("../experiment_results/mnist/2021-12-16/13-07-51/")
# sorted(): glob order is filesystem-dependent; a fixed order keeps table rows stable.
mcmc_runs = [Run(run_dir) for run_dir in sorted(mcmc_dir.glob("[01]/"))]

other_dir = Path("../experiment_results/mnist/2021-12-17/11-01-32/")
other_runs = [Run(run_dir) for run_dir in sorted(other_dir.glob("[012]/"))]

# Fail early on a moved or missing results directory. Otherwise it only surfaces
# later as an IndexError (`all_runs[0]`) or an incomplete table.
assert mcmc_runs, f"no runs found in {mcmc_dir}"
assert other_runs, f"no runs found in {other_dir}"

all_runs = other_runs + mcmc_runs

In [None]:
import hydra
# Instantiate the data module from the first run's config to obtain the test-set size.
# NOTE(review): `n_test` is applied to every run's CI, which assumes all runs share
# the same test split — confirm.
reference_data_cfg = all_runs[0].cfg.data
dm = hydra.utils.instantiate(reference_data_cfg)
dm.setup()
n_test = len(dm.test_data)

In [None]:
import numpy as np
from math import sqrt


def get_err_incl_ci(error: float, n: "int | None" = None, z: float = 1.96) -> str:
    r"""Format an error rate with a normal-approximation (Wald) confidence interval.

    The half width is ``z * sqrt(error * (1 - error) / n)``; the default ``z=1.96``
    gives a 95% interval.

    Args:
        error: Error rate, expected in [0, 1]. NaN is passed through as "nan".
        n: Number of test examples. Defaults to the notebook-level ``n_test``.
        z: Standard-normal quantile for the desired confidence level.

    Returns:
        LaTeX math string ``$<error> \pm <half width>$``, with the error shown to
        3 and the half width to 2 significant digits.

    Raises:
        ValueError: if ``error`` is outside [0, 1] or ``n`` is not positive.
    """
    if n is None:
        n = n_test
    # Written as two comparisons so that NaN passes through, as it did before.
    if error < 0 or error > 1:
        raise ValueError(f"error must be in [0, 1], got {error}")
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    half_width = z * sqrt(error * (1 - error) / n)
    return f"${error:.3} \\pm {half_width:.2}$"


# Test error of every run with its 95% CI, written as a LaTeX table for the thesis.
test_errors = pd.Series(
    [run.get_scalar("err/test").iloc[0] for run in all_runs],
    index=[run.inference_label for run in all_runs],
    name="err/test",
)
# A dict keyed by label would silently keep only the last of two runs sharing a label.
assert test_errors.index.is_unique, "duplicate inference labels in all_runs"
(
    test_errors
    # Explicit element-wise map. DataFrame.apply with a dict routes through .agg,
    # whose handling of plain callables is version-dependent.
    .map(get_err_incl_ci)
    .to_frame("Test error incl. 95\\% CI")
    .to_latex("../thesis/Tables/mnist_test_err.tex", escape=False)
)
