In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import polars as pls
from src.experiments.common import set_directory


In [None]:
from optuna import Study
from pathlib import Path
from datetime import datetime
import seaborn as sns


from optuna import load_study

# Load every MNIST sweep found under ``optuna_storages/`` (looked up relative to
# the directory entered by ``set_directory("..")``), keyed by the storage file's
# stem, which is also used as the Optuna study name.
studies = {}
with set_directory(".."):
    # Sorted so that the dict order (and hence plot/table order) is deterministic.
    optuna_storages = sorted(Path("optuna_storages/").glob("mnist*"))
    for storage in optuna_storages:
        # The URL is made absolute while we are still inside the context: the
        # studies are queried later, after the working directory has been
        # restored, and a relative SQLite path would then point somewhere else.
        studies[storage.stem] = load_study(
            study_name=storage.stem,
            storage=f"sqlite:///{storage.resolve()}",
        )


In [None]:
from functools import cache
from src.experiments.common import EXPERIMENT_PATH

In [None]:
# @cache
# def get_tb_statistics(path: Path, tag="err/val"):
#     accumulator = EventAccumulator(str(path))
#     accumulator.Reload()
#     tag = "err/val"
#     index = [x.step for x in accumulator.Scalars(tag)]
#     values = [x.value for x in accumulator.Scalars(tag)]
#     statistics = pd.Series(values, name=tag, index=pd.Index(index, name="step"))
#     return statistics


In [None]:
from dataclasses import dataclass


@dataclass
class Sweep:
    """Tabular views over the trials of a single Optuna study.

    Both methods read ``self.study.trials`` on every call; nothing is cached.
    """

    # Study whose trials are summarised.
    study: Study

    def summary(self) -> pd.DataFrame:
        """Return one row per trial, sorted by ascending validation error.

        Returns:
            A frame indexed by ``trial`` (the trial number) with the columns
            ``datetime_start``, ``err/val`` (the trial's ``value``) and one
            column per key found in ``trial.params``.
        """
        records = (
            {
                "trial": trial.number,
                "datetime_start": trial.datetime_start,
                "err/val": trial.value,
                **trial.params,
            }
            for trial in self.study.trials
        )
        return pd.DataFrame.from_records(records, index="trial").sort_values("err/val")

    def loss(self) -> pd.DataFrame:
        """Return the intermediate values of every trial in long format.

        Returns:
            A frame with a single float column ``err/val`` and a
            ``(step, trial)`` MultiIndex. ``step`` is the key under which the
            value is stored in ``trial.intermediate_values`` (not its position),
            and the rows of each trial are ordered by step. Trials without
            intermediate values contribute no rows; if no trial has any, an
            empty frame with the same column and index names is returned.
        """
        curves = [
            # Building the Series from the mapping keeps the recorded step
            # numbers as the index instead of renumbering them 0..n-1.
            pd.Series(trial.intermediate_values, name="err/val", dtype=float)
            .sort_index()
            .rename_axis("step")
            .to_frame()
            .assign(trial=trial.number)
            .set_index("trial", append=True)
            for trial in self.study.trials
            if trial.intermediate_values
        ]
        if not curves:
            # pd.concat refuses an empty list; hand back an empty frame that
            # has the same layout as the regular result.
            no_rows = pd.MultiIndex.from_arrays([[], []], names=["step", "trial"])
            return pd.DataFrame(columns=["err/val"], index=no_rows, dtype=float)
        return pd.concat(curves)


In [None]:
def _stack_studies(extract, level_order):
    """Concatenate ``extract(Sweep(study))`` over all loaded studies.

    Each per-study frame gets the study name appended as an extra ``study``
    index level before concatenation; the levels of the combined index are
    then put into ``level_order``.
    """
    tagged_frames = [
        extract(Sweep(study)).assign(study=name).set_index("study", append=True)
        for name, study in studies.items()
    ]
    return pd.concat(tagged_frames).reorder_levels(level_order)


# Learning curves of all studies, indexed by (study, trial, step).
combined_loss_data = _stack_studies(Sweep.loss, ["study", "trial", "step"])
# Per-trial summaries of all studies, indexed by (study, trial).
combined_summaries_data = _stack_studies(Sweep.summary, ["study", "trial"])

In [None]:
# For every study, the index label of the trial with the lowest "err/val".
validation_error = combined_summaries_data["err/val"]
best_runs = validation_error.groupby("study").idxmin()
pd.DataFrame(best_runs)


In [None]:
# Summary rows of the "mnist-sghmc-var-est" study, ordered by index.
var_est_summary = combined_summaries_data.loc["mnist-sghmc-var-est"]
t = var_est_summary.sort_index()

In [None]:
# Validation-error curve of the best trial of each study, saved as a thesis figure.
# Going through a wide (one column per step) layout lets the rows be selected
# with the labels stored in ``best_runs``; the result is then made long again.
best_run_curves = (
    combined_loss_data.unstack(level="step")
    .loc[best_runs]
    .stack(level="step")
    .reset_index()
)
best_runs_grid = sns.relplot(
    data=best_run_curves,
    x="step",
    y="err/val",
    hue="study",
    kind="line",
    aspect=1.6,
)
best_runs_grid.set(ylim=(0.01, 0.03))
best_runs_grid.savefig("../thesis/Figures/mnist-best-runs-val-curves.pdf")


In [None]:
import math


def rename_cols(x):
    """Map a summary column name to its LaTeX table header.

    Args:
        x: Column name, e.g. ``"err/val"`` or a dotted parameter path.

    Returns:
        ``"val. error"`` for ``"err/val"``; otherwise the part after the last
        ``"."`` wrapped in ``\\texttt{...}`` with every underscore escaped as
        ``\\_``.
    """
    if x == "err/val":
        return "val. error"
    leaf = x.split(".")[-1]
    # The backslash is written as "\\": a bare "\_" is an invalid escape
    # sequence that only works by accident and triggers a warning.
    return f"\\texttt{{{leaf}}}".replace("_", "\\_")


def format_tex(float_number):
    """Format a number as LaTeX scientific notation, e.g. ``$3.0\\times10^{-4}$``.

    The mantissa is rounded to two decimals and printed without superfluous
    trailing zeros (``1.0``, ``2.5``, ``1.23``).

    Args:
        float_number: The value to format. Zero and negative values are
            supported.

    Returns:
        The LaTeX math string, or ``str(float_number)`` for NaN and infinities,
        which have no mantissa/exponent form.
    """
    if not math.isfinite(float_number):
        return str(float_number)
    # Let the float formatter do the normalisation and rounding. Dividing by
    # 10 ** exponent and slicing the repr truncates instead of rounding
    # (0.0003 became "2.99") and cannot handle 0 or negative input.
    mantissa_text, exponent_text = f"{float_number:.2e}".split("e")
    mantissa_format = str(float(mantissa_text))
    return "${0}\\times10^{{{1}}}$".format(mantissa_format, int(exponent_text))


def to_latex(data: pd.DataFrame, name=None):
    """Write ``data`` as a LaTeX table to ``../thesis/Tables/<name>-hparams.tex``.

    Cell contents are not escaped, the column ``\\texttt{lr}`` is rendered with
    ``format_tex``, and the column format is one ``l`` column followed by one
    ``p{2.3cm}`` column per data column.

    Args:
        data: Table to export.
        name: Prefix of the output file name. When omitted, the module-level
            variable ``key`` is used, which is how this function obtained the
            file name before the parameter existed.

    Returns:
        Whatever ``DataFrame.to_latex`` returns when given a target path.
    """
    if name is None:
        # Backward-compatible fallback for callers that rely on the global.
        name = key
    n_cols = len(data.columns)
    return data.to_latex(
        f"../thesis/Tables/{name}-hparams.tex",
        escape=False,
        formatters={r"\texttt{lr}": format_tex},
        column_format="l" + n_cols * r"p{2.3cm}",
    )


# Export the ten best trials of every study as a hyperparameter table.
for key, study in studies.items():
    (
        Sweep(study)
        .summary()
        .drop(columns="datetime_start")
        .head(10)
        .rename(columns=rename_cols)
        # The study name is passed explicitly instead of being picked up from
        # the loop variable inside ``to_latex``.
        .pipe(to_latex, name=key)
    )


In [None]:
# Print one LaTeX ``table`` float per study that \input{}s the matching
# ``Tables/<study>-hparams`` file, ready to be pasted into the thesis source.
# NOTE(review): "INFERENCE" in the caption looks like a placeholder to be
# replaced by hand with the inference method's name - confirm.
for key in studies:
    print(
f"""
\\begin{{table}}[htbp]
    \\centering
    \\input{{Tables/{key}-hparams}}
    \\caption{{Top 10 hyperparameters for INFERENCE according to optuna sweep.}}
    \\label{{tab:{key}-hparams}}
\\end{{table}}
"""
)

In [None]:

# Ten best trials of the "mnist-sghmc" study, without the start timestamp.
sghmc_summary = Sweep(studies["mnist-sghmc"]).summary()
sghmc_summary.drop(columns="datetime_start").head(10)


In [None]:
# Number of recorded "err/val" values per trial of the "mnist-sghmc" study.
sghmc_loss = combined_loss_data.loc["mnist-sghmc"]
t = sghmc_loss.groupby("trial").count()

In [None]:
# Best hyperparameters of the "mnist-vi" study.
a: Study = studies["mnist-vi"]
a.best_params

In [None]:
# Best hyperparameters of the "mnist-sghmc-var-est" study.
a: Study = studies["mnist-sghmc-var-est"]
a.best_params

In [None]:
# Validation-error curve of every trial of the "mnist-sgd-map" study.
sgd_map_loss = Sweep(studies["mnist-sgd-map"]).loss()
sgd_map_grid = sns.relplot(
    data=sgd_map_loss,
    x="step",
    hue="trial",
    y="err/val",
    kind="line",
)
sgd_map_grid.set(ylim=(0.015, 0.04))

TODO:
- [ ] Test med test set
- [ ] Test med test set, vægtet ift val logprob