In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
from src.experiments.common import get_run_from_path, EXPERIMENT_PATH, set_directory
from src.experiments.wrangle import get_multirun_statistics

In [None]:
from optuna import Study
from pathlib import Path
from datetime import datetime
import seaborn as sns


# Build a name -> Study mapping, one entry per file found in optuna_storages/.
# NOTE(review): set_directory("..") presumably switches the working directory
# for the duration of the block, so the glob runs one level up — confirm.
studies = {}
with set_directory(".."):
    optuna_storages = [*Path("optuna_storages/").glob("*")]
    for storage in optuna_storages:
        # The file stem doubles as the study name inside the SQLite storage.
        study_name = storage.stem
        studies[study_name] = Study(study_name, storage=f"sqlite:///{storage}")


In [None]:
from functools import cache
from src.experiments.common import EXPERIMENT_PATH


In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Example: EventAccumulator(str(run.path / "metrics"))


In [None]:
@cache
def get_tb_statistics(path: Path, tag="err/val"):
    """Load one scalar series from a TensorBoard event location.

    The result is memoised with ``functools.cache``: ``path`` and ``tag`` must
    be hashable, and repeated calls with the same arguments return the very
    same ``Series`` object without re-reading the events (so callers should
    not mutate it in place).

    Args:
        path: Location handed to ``EventAccumulator`` after ``str()`` conversion.
        tag: Name of the scalar to extract. Defaults to ``"err/val"``.

    Returns:
        A ``pd.Series`` named ``tag`` containing the ``value`` of every scalar
        event, indexed by the events' ``step`` (index name ``"step"``).
    """
    accumulator = EventAccumulator(str(path))
    accumulator.Reload()
    # Honour the caller-supplied tag instead of hard-coding "err/val", and
    # fetch the events once so steps and values come from the same read.
    events = accumulator.Scalars(tag)
    steps = [event.step for event in events]
    values = [event.value for event in events]
    return pd.Series(values, name=tag, index=pd.Index(steps, name="step"))


In [None]:
from dataclasses import dataclass


@dataclass
class Sweep:
    """Tabular views over the trials of a single Optuna study."""

    study: Study  # study whose ``trials`` are read by the methods below

    def summary(self) -> pd.DataFrame:
        """Return one row per trial with its final value and parameters.

        Returns:
            A frame indexed by ``trial`` (the trial number) with an
            ``"err/val"`` column holding ``trial.value`` and one column per
            entry of ``trial.params``, sorted ascending by ``"err/val"``.
        """
        records = (
            {"trial": trial.number, "err/val": trial.value, **trial.params}
            for trial in self.study.trials
        )
        return pd.DataFrame.from_records(records, index="trial").sort_values(
            "err/val"
        )

    def loss(self) -> pd.DataFrame:
        """Return the intermediate values of every trial as one long frame.

        Returns:
            A frame with a single ``"err/val"`` column and a two-level index
            ``(step, trial)``. ``step`` is the key under which each value is
            stored in ``trial.intermediate_values`` (not its position), so
            sparse or offset reporting keeps its real step; ``trial`` is the
            trial number.

        Raises:
            ValueError: Propagated from ``pd.concat`` when the study has no
                trials to concatenate.
        """
        curves = []
        for trial in self.study.trials:
            reported = trial.intermediate_values
            curve = pd.DataFrame(
                {"err/val": list(reported.values())},
                # Index by the reported step; the explicit dtype keeps the
                # level integer-typed even for trials with no reports.
                index=pd.Index(list(reported.keys()), dtype="int64", name="step"),
            )
            curves.append(
                curve.assign(trial=trial.number).set_index("trial", append=True)
            )
        return pd.concat(curves)


In [None]:
# Stack the loss curves of every study into one frame; each study's rows are
# tagged with its name as an extra index level before concatenation.
_per_study_loss = [
    Sweep(study).loss().assign(study=label).set_index("study", append=True)
    for label, study in studies.items()
]
# Final index order: study -> trial -> step.
combined_loss_data = pd.concat(_per_study_loss).reorder_levels(
    ["study", "trial", "step"]
)


In [None]:
# Last recorded "err/val" of each (study, trial) pair ...
_final_err = combined_loss_data["err/val"].groupby(level=["study", "trial"]).last()
# ... and, per study, the index label of the trial where that value is smallest.
best_runs = _final_err.groupby("study").idxmin()
pd.DataFrame(best_runs)

In [None]:
# Keep only the rows of each study's best trial: pivot steps into columns,
# select the best (study, trial) labels, then pivot back to long form.
_best_curves = (
    combined_loss_data.unstack(level="step")
    .loc[best_runs]
    .stack(level="step")
    .reset_index()
)
# One line per study; the y-axis is clipped to the fixed range (0.01, 0.05).
sns.relplot(
    data=_best_curves, x="step", y="err/val", hue="study", kind="line"
).set(ylim=(0.01, 0.05))


In [None]:
# Count of non-null entries per trial for the "mnist-sghmc" study.
_sghmc_loss = combined_loss_data.loc["mnist-sghmc"]
t = _sghmc_loss.groupby("trial").count()

In [None]:
# Inspect every trial of the "mnist-sghmc" study.
a: Study = studies["mnist-sghmc"]
a.trials

In [None]:
# Rebind `a` to the "mnist-sghmc-var-est" study and show its best parameters.
a: Study = studies["mnist-sghmc-var-est"]
a.best_params