In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import pandas as pd
import seaborn as sns
from optuna import Study
from src.utils import Run, Sweep, set_directory


# Sweeps

In [None]:
# Load every MNIST Optuna study found under ``optuna_storages/`` (looked up
# relative to the directory entered via ``set_directory("..")``), keyed by the
# storage file's stem. Storages whose name contains "corr" are skipped.
studies = {}
with set_directory(".."):
    optuna_storages = list(Path("optuna_storages/").glob("mnist*"))
    for storage in optuna_storages:
        if "corr" in storage.stem:
            continue
        # Resolve to an absolute path while the directory is still in effect:
        # a relative SQLite URL is re-interpreted against the current working
        # directory on every new connection, so it could point at a different
        # (or non-existent) file once this ``with`` block has been left.
        storage_url = f"sqlite:///{storage.resolve()}"
        studies[storage.stem] = Study(storage.stem, storage=storage_url)


In [None]:
def _stack_studies(extract, level_order):
    """Concatenate one frame per study, tagging each with its study name.

    ``extract`` receives a ``Sweep`` built from the study and returns a frame;
    the study name is appended as an index level called "study", and the index
    levels of the concatenated result are reordered to ``level_order``.
    """
    tagged_frames = []
    for study_name, study in studies.items():
        frame = extract(Sweep(study))
        tagged_frames.append(
            frame.assign(study=study_name).set_index("study", append=True)
        )
    return pd.concat(tagged_frames).reorder_levels(level_order)


# Loss curves of all studies, indexed by (study, trial, step).
combined_loss_data = _stack_studies(
    lambda sweep: sweep.loss(), ["study", "trial", "step"]
)
# Per-trial summaries of all studies, indexed by (study, trial).
combined_summaries_data = _stack_studies(
    lambda sweep: sweep.summary(), ["study", "trial"]
)

In [None]:
# For each study, the (study, trial) index label of the trial with the lowest
# validation error.
val_errors_by_study = combined_summaries_data["err/val"].groupby("study")
best_runs = val_errors_by_study.idxmin()
# Shown as a one-column frame for nicer notebook rendering.
pd.DataFrame(best_runs)

In [None]:
# Select the loss curves of each study's best trial: pivot the steps into
# columns so rows are keyed by (study, trial), pick the best rows, then pivot
# back to long format for plotting.
best_run_curves = (
    combined_loss_data
    .unstack(level="step")
    .loc[best_runs]
    .stack(level="step")
    .reset_index()
)

# One validation-error line per study, saved as a thesis figure.
val_curve_grid = sns.relplot(
    data=best_run_curves,
    x="step",
    y="err/val",
    hue="study",
    kind="line",
    aspect=1.6,
)
val_curve_grid.set(ylim=(0.01, 0.03))
val_curve_grid.savefig("../thesis/Figures/mnist-best-runs-val-curves.pdf")


In [None]:
import math

def rename_cols(x):
    r"""Map a summary column name to its LaTeX table header.

    ``"err/val"`` becomes ``"val. error"``. Any other name is reduced to the
    part after its last ``.``, has underscores escaped as ``\_`` and is wrapped
    in ``\texttt{...}``.
    """
    if x == "err/val":
        return "val. error"

    leaf = x.split(".")[-1]
    # "\\_" instead of "\_": the latter is an invalid escape sequence in a
    # normal string literal and triggers a SyntaxWarning on recent Pythons.
    return f"\\texttt{{{leaf}}}".replace("_", "\\_")

def format_sctf(float_number):
    r"""Format a number as LaTeX scientific notation, e.g. ``$2.5\times10^{-4}$``.

    The mantissa is rounded to at most two decimals. Zero is rendered as
    ``$0$`` and non-finite values (NaN, inf) as their plain string form, since
    neither has a meaningful base-10 exponent.
    """
    if float_number == 0:
        return "$0$"
    if not math.isfinite(float_number):
        return f"${float_number}$"

    # Let the ``e`` format do the decomposition: it rounds the mantissa instead
    # of truncating it (0.0003 -> 3.0, not 2.99), carries into the exponent
    # when rounding reaches 10, and handles negative numbers.
    mantissa_text, _, exponent_text = f"{float_number:.2e}".partition("e")
    # float() drops redundant trailing zeros ("2.50" -> 2.5, "3.00" -> 3.0).
    mantissa = float(mantissa_text)
    return "${0}\\times10^{{{1}}}$".format(mantissa, int(exponent_text))


def to_latex(data: pd.DataFrame, name=None):
    r"""Write ``data`` as a LaTeX table to ``../thesis/Tables/{name}-hparams.tex``.

    Parameters
    ----------
    data:
        Frame whose columns are already LaTeX-ready (``escape=False`` is used).
        A column named ``\texttt{lr}`` is rendered with ``format_sctf``.
    name:
        Stem of the output file. When omitted, the module-level ``key`` is
        used, which keeps older call sites working that relied on the loop
        variable being visible as a global.

    Returns
    -------
    Whatever ``DataFrame.to_latex`` returns for a file target.
    """
    if name is None:
        # Legacy fallback; prefer passing ``name`` explicitly.
        name = key

    n_cols = len(data.columns)
    return data.to_latex(
        f"../thesis/Tables/{name}-hparams.tex",
        escape=False,
        formatters={r"\texttt{lr}": format_sctf},
        # Left-aligned index column followed by fixed-width paragraph columns.
        column_format="l" + n_cols * r"p{2.3cm}",
    )


# Export the first ten rows of every study's summary as a hyperparameter table.
# NOTE(review): presumably ``summary()`` is sorted best-first, so these are the
# top ten trials. Confirm against ``Sweep.summary``.
for key, study in studies.items():
    (
        Sweep(study)
        .summary()
        .drop(columns="datetime_start")
        .head(10)
        .rename(columns=rename_cols)
        .pipe(to_latex, name=key)
    )


In [None]:
# Print, for every study, the LaTeX ``table`` environment that \input's the
# generated ``Tables/{key}-hparams`` file, for pasting into the thesis.
# The \resizebox shrinks the table to the column width only when it is wider.
# All backslashes are doubled: a lone "\e" (as in "\end") is an invalid escape
# sequence in a normal f-string and triggers a SyntaxWarning on recent Pythons.
for key in studies:
    print(
f"""
\\begin{{table}}[htbp]
    \\centering
    \\resizebox{{
        \\ifdim\\width>\\columnwidth
        \\columnwidth
      \\else
        \\width
      \\fi
    }}{{!}}{{\\small
    \\input{{Tables/{key}-hparams}}
    }}
    \\caption{{Top 10 hyperparameters for INFERENCE according to optuna sweep.}}
    \\label{{tab:{key}-hparams}}
\\end{{table}}
"""
)

# Test errors


In [None]:
# Result directories of the evaluated runs. Each numbered sub-directory holds
# one run and is wrapped in a ``Run``.
# Alternative MCMC results: ../experiment_results/mnist/2021-12-16/13-07-51/
mcmc_dir = Path("../experiment_results/mnist/2021-12-17/15-58-26/")
mcmc_runs = [Run(run_dir) for run_dir in mcmc_dir.glob("[01]/")]

other_dir = Path("../experiment_results/mnist/2021-12-17/11-01-32/")
other_runs = [Run(run_dir) for run_dir in other_dir.glob("[012]/")]

# Non-MCMC runs first, then the MCMC runs.
all_runs = [*other_runs, *mcmc_runs]

In [None]:
import hydra
# Instantiate the data module from the Hydra data config of the first run.
# NOTE(review): presumes all runs in `all_runs` share the same data config — confirm.
dm = hydra.utils.instantiate(all_runs[0].cfg.data)
dm.setup()
# Number of test examples; `get_err_incl_ci` reads this to size the confidence interval.
n_test = len(dm.test_data)

In [None]:
import numpy as np
from math import sqrt


def get_err_incl_ci(error: float, n: "int | None" = None, z: float = 1.96) -> str:
    r"""Format an error rate with a normal-approximation confidence interval.

    The half-width is ``z * sqrt(error * (1 - error) / n)``.

    Parameters
    ----------
    error:
        Error rate as a fraction.
    n:
        Number of examples the error rate was measured on. Defaults to the
        module-level ``n_test`` when omitted.
    z:
        Critical value of the standard normal; 1.96 corresponds to a 95 % CI.

    Returns
    -------
    LaTeX math string ``$<error> \pm <half-width>$`` with 3 and 2 significant
    digits respectively.
    """
    if n is None:
        n = n_test
    pm = sqrt(error * (1 - error) / n) * z
    return f"${error:.3} \\pm {pm:.2}$"


# Build a one-column table of test errors (one row per run, labelled by the
# run's inference label), format each value with its confidence interval and
# write the result to the thesis as a LaTeX table.
(
    pd.DataFrame.from_dict(
        {
            # First logged "err/test" value of the run.
            run.inference_label: {"err/test": run.get_scalar("err/test").iloc[0]}
            for run in all_runs
        },
        orient="index",
    )
    # Turn the numeric error into a "$err \pm ci$" string.
    .apply({"err/test" :get_err_incl_ci})
    .rename(columns={"err/test": "Test error incl. 95\\% CI"})
    # escape=False: the cells and the header already contain LaTeX markup.
    .to_latex("../thesis/Tables/mnist_test_err.tex", escape=False)
)


In [None]:
from textwrap import wrap

## Reweighting MCMC samples



In [None]:
# MCMC runs used for the reweighting experiment (sub-directories "0" and "1").
mcmc_dir_ = Path("../experiment_results/mnist/2021-12-19/20-34-46/")
mcmc_runs_ = [Run(run_dir) for run_dir in mcmc_dir_.glob("[01]/")]


In [None]:
# Plot the error rate against the number of samples for every MCMC run,
# coloured by sampler.
(
    pd.concat(
        # One frame per run, read from the run's "sample_resampling_curve.json".
        pd.read_json(r.dir / "sample_resampling_curve.json")
        .rename_axis(index=["n_sampled"])
        # Tag the rows with the run's inference label and move it into the index.
        .assign(sampler=r.inference_label)
        .set_index("sampler", append=True)
        .reorder_levels(["sampler", "n_sampled"])
        .sort_index()
        for r in mcmc_runs_
    )
    # Back to long format so seaborn can address the index levels as columns.
    .reset_index()
    .pipe((sns.relplot, "data"), x="n_sampled", y="error_rate", hue="sampler")
)


In [None]:
# from typing import Dict
# from torch import Tensor


# def get_records(data: Dict[int, Dict[int, Tensor]]):
#     return pd.DataFrame.from_dict(
#         {
#             i: pd.Series(
#                 {batch_idx: avg.item() for batch_idx, avg in x.items()},
#                 name="avg_likelihood",
#             ).rename_axis(index=["batch"])
#             for i, x in data.items()
#         }
#     ).rename_axis(columns="parameter_sample")


# avg_likelihoods = pd.concat(
#     get_records(torch.load(r.dir / "val_avg_likelihood.pt"))
#     .assign(sampler=r.inference_label)
#     .set_index("sampler", append=True)
#     .reorder_levels(["sampler", "batch"])
#     for r in mcmc_runs_
# )

# val_joint_logliks = (
#     pd.concat(
#         get_records(torch.load(r.dir / "val_joint_logliks.pt"))
#         .assign(sampler=r.inference_label)
#         .set_index("sampler", append=True)
#         .reorder_levels(["sampler", "batch"])
#         for r in mcmc_runs_
#     )
#     .stack("parameter_sample")
#     .groupby(level=["sampler", "parameter_sample"])
#     .sum()
#     .sort_index()
# )


In [None]:
# train_joint_logliks = (
#     pd.DataFrame(
#         {
#             r.inference_label: pd.Series(
#                 {
#                     i: x.item()
#                     for i, x in torch.load(r.dir / "train_joint_logliks.pt").items()
#                 }
#             )
#             for r in mcmc_runs_
#         }
#     )
#     .rename_axis(index=["parameter_sample"], columns=["sampler"])
#     .stack("sampler")
#     .reorder_levels(["sampler", "parameter_sample"])
#     .sort_index()
# )
# train_joint_logliks


In [None]:
# log_ratio = torch.tensor((val_joint_logliks/10_000 - train_joint_logliks/50_000).values)
# un_normalized_weights = log_ratio.exp()
# weights = un_normalized_weights / un_normalized_weights.sum()

Conclusion: validation and training are too different.

# Calibration


In [None]:
# Non-MCMC runs used for the calibration comparison (sub-directories "0"-"2").
other_dir_ = Path("../experiment_results/mnist/2021-12-19/20-34-23/")
other_runs_ = [Run(run_dir) for run_dir in other_dir_.glob("[012]/")]

In [None]:
import os
import subprocess

# Collect the calibration data of all runs into a temporary CSV and hand it to
# the R plotting script, which is given the input and output locations through
# the IN_FILE and OUT_FILE environment variables.
in_file = Path("tmp.csv")
out_file = Path("../thesis/Figures/mnist-calibration.pdf")

# Argument list instead of a shell string: paths containing spaces or shell
# metacharacters are passed through untouched, and no shell is involved.
cmd = ["Rscript", "../src/visualization/calibration_curves.r"]

try:
    pd.concat(
        pd.read_csv(r.dir / "calibration_data.csv", index_col=0).assign(
            inference=r.inference_label
        )
        for r in mcmc_runs_ + other_runs_
    ).to_csv(in_file)

    subprocess.run(
        cmd,
        env={
            **os.environ,
            "IN_FILE": str(in_file.resolve()),
            "OUT_FILE": str(out_file.resolve()),
        },
        # Fail loudly if the R script errors, rather than silently leaving a
        # stale or missing figure behind.
        check=True,
    )
finally:
    # Always remove the temporary CSV, even when the export or the script fails.
    if in_file.exists():
        in_file.unlink()

## Checking SGHMC assumptions

In [None]:
import torch

# Temperature samples of every run in `mcmc_runs`, indexed by
# (Sampler, parameter, step).
# NOTE: torch.load unpickles the file; only load files you trust.
temperatures = pd.concat(
    pd.DataFrame.from_dict(
        torch.load(run.dir / "temperature_samples.pt"),
        orient="index",
    )
    .rename_axis(index=["step", "parameter"])
    # Thin the samples: keep only every 50th step.
    .loc[lambda x: x.index.get_level_values("step") % 50 == 0]
    # Wrap the label to lines of at most 12 characters so it fits in a legend.
    .assign(Sampler="\n".join(wrap(run.inference_label, width=12)))
    .set_index("Sampler", append=True)
    .reorder_levels(["Sampler", "parameter", "step"])
    for run in mcmc_runs
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import chi2

def plot_chi2(df, **kwargs):
    """Overlay a chi-squared density on the current axes.

    Intended as a ``FacetGrid.map`` callback: ``df`` is the facet's column of
    degrees of freedom, of which only the first entry is used. Extra keyword
    arguments are accepted and ignored. The curve spans the axes' current
    x-limits and is drawn in black with the label "true".

    The list of lines returned by ``plt.plot`` is stored in the module-level
    ``black_line`` so that it can be added to a legend afterwards.
    """
    global black_line
    x_min, x_max = plt.gca().axes.get_xlim()
    grid = np.linspace(x_min, x_max, 300)
    reference = chi2(df.iloc[0])
    black_line = plt.plot(grid, reference.pdf(grid), color="black", label="true")


# Kernel density estimate of the summed temperatures, one facet per parameter
# and one curve per sampler. Axes are not shared between facets.
fg = sns.displot(
    data=temperatures.reset_index(),
    x="temperature_sum",
    hue="Sampler",
    kind="kde",
    col="parameter",
    height=1.8,
    col_wrap=2,
    aspect=2,
    common_norm=False,
    facet_kws={"sharex": False, "sharey": False},
)
# Overlay the chi-squared reference density on every facet; the callback
# leaves the last drawn reference line in ``black_line``.
fg.map(plot_chi2, "n_params")

# Take handles and labels from seaborn's legend before removing it.
# ``Legend.legendHandles`` was renamed to ``legend_handles`` in Matplotlib 3.7
# and removed in 3.9, so try the new attribute first.
seaborn_legend = fg.legend
sampler_handles = getattr(seaborn_legend, "legend_handles", None)
if sampler_handles is None:
    sampler_handles = seaborn_legend.legendHandles
handles = list(sampler_handles) + list(black_line)
texts = [t.get_text() for t in seaborn_legend.texts] + ["True\ndistribution"]
seaborn_legend.remove()  # Remove seaborn legend

plt.subplots_adjust(bottom=0.2, right=1)
fg.set_xlabels("Temperature")
# Pass handles and labels explicitly. Handing over a single mapping makes
# Matplotlib treat it as a list of labels only and pair them with whatever
# artists it finds on the axes, which need not be in the same order.
plt.figlegend(
    handles,
    texts,
    title="Sampler",
    ncol=3,
    frameon=False,
    loc="lower center",
)
plt.savefig("../thesis/Figures/mnist-temperatures.pdf")
