In [None]:
from pathlib import Path
import pandas as pd
import seaborn as sns
from optuna import Study
from src.utils import Run, Sweep, set_directory, INFERENCE_LABELS
import matplotlib.pyplot as plt
import torch
from src.models.base import ErrorRate
from src.analysis.inference import *

# Sweeps

In [None]:
# Collect one Sweep per Optuna SQLite storage whose file name starts with
# "cifar-small", keyed by the storage file's stem (which is also used as the
# study name).  Everything happens inside set_directory(".."), so the relative
# "optuna_storages/" path and the sqlite URLs are built while that context is
# active.
sweeps = {}
with set_directory(".."):
    optuna_storages = [*Path("optuna_storages/").glob("cifar-small*")]
    for db_path in optuna_storages:
        study_name = db_path.stem
        study = Study(study_name, storage=f"sqlite:///{db_path}")
        sweeps[study_name] = Sweep(study)


In [None]:
# Stack the per-sweep frames into two combined frames.  Each frame is tagged
# with the name of the study it came from, and that tag is appended to the
# index as a "study" level before concatenation.

# Loss curves: index levels end up ordered as (study, trial, step).
loss_frames = [
    sweep.loss().assign(study=name).set_index("study", append=True)
    for name, sweep in sweeps.items()
]
combined_loss_data = pd.concat(loss_frames).reorder_levels(
    ["study", "trial", "step"]
)

# Trial summaries: additionally carry the sweep's runs in a "run" column;
# index levels end up ordered as (study, trial).
summary_frames = [
    sweep.summary()
    .assign(study=name, run=sweep.runs())
    .set_index("study", append=True)
    for name, sweep in sweeps.items()
]
combined_summaries_data = pd.concat(summary_frames).reorder_levels(
    ["study", "trial"]
)

In [None]:
# For every study, keep the single summary row with the lowest "err/val".
val_err_and_run = combined_summaries_data[["err/val", "run"]]
best_index_per_study = val_err_and_run["err/val"].groupby("study").idxmin()
best_runs = val_err_and_run.loc[best_index_per_study]
best_runs

In [None]:
def get_label(x: Run):
    """Return the ``inference_label`` attribute of the run ``x``."""
    label = x.inference_label
    return label


# Label of each study's best run.  The "trial" index level is dropped and the
# series is named "inference", so it can be joined onto other frames by study.
best_run_labels = best_runs["run"].apply(get_label)
inference_labels = best_run_labels.reset_index("trial", drop=True).rename(
    "inference"
)

In [None]:
# Validation-error curves of the best run of each study, coloured by
# inference label, saved as a PDF figure for the thesis.

# Select the loss rows of the best (study, trial) pairs: go wide over "step",
# pick the rows by index, go long again, then attach the inference label.
best_curves = (
    combined_loss_data.unstack(level="step")
    .loc[best_runs.index.values]
    .stack(level="step")
    .reset_index()
    .join(inference_labels, on="study")
)

val_curve_grid = sns.relplot(
    data=best_curves,
    x="step",
    y="err/val",
    hue="inference",
    kind="line",
    aspect=1.6,
    hue_order=INFERENCE_LABELS.values(),
).set(xlim=(0, 200))
val_curve_grid.savefig("../thesis/Figures/cifar-small-best-runs-val-curves.pdf")


In [None]:
import math

def rename_cols(x):
    r"""Map a summary column name to the header used in the LaTeX tables.

    Args:
        x: Original column name.

    Returns:
        ``"val. error"`` for the ``"err/val"`` column.  For any other column,
        the part after the last ``.`` wrapped in ``\texttt{...}``, with every
        underscore escaped as ``\_``.
    """
    if x == "err/val":
        return "val. error"

    leaf = x.split(".")[-1]
    # "\\_" spells the same two characters the old "\_" produced, but without
    # relying on an invalid escape sequence (a warning on recent Pythons).
    return f"\\texttt{{{leaf}}}".replace("_", "\\_")


def format_sctf(float_number):
    r"""Format a number as LaTeX scientific notation, e.g. ``$2.5\times10^{3}$``.

    The mantissa is truncated (not rounded) to the first four characters of
    its string form, as before.

    Args:
        float_number: The value to format.

    Returns:
        A LaTeX math string.  Zero is rendered as ``$0$``, negative values get
        a leading minus sign on the mantissa, and non-finite values (NaN,
        +/-inf) are returned as their plain ``str()``.
    """
    # log10 is undefined for these inputs; previously they raised.
    if not math.isfinite(float_number):
        return str(float_number)
    if float_number == 0:
        return "$0$"

    sign = "-" if float_number < 0 else ""
    magnitude = abs(float_number)

    exponent = math.floor(math.log10(magnitude))
    mantissa = magnitude / 10 ** exponent
    # Guard against log10 rounding across a power of ten, which would leave
    # the mantissa just outside [1, 10).
    if mantissa >= 10:
        mantissa /= 10
        exponent += 1
    elif mantissa < 1:
        mantissa *= 10
        exponent -= 1

    mantissa_format = str(mantissa)[0:4]
    return "${0}{1}\\times10^{{{2}}}$".format(sign, mantissa_format, str(int(exponent)))


def to_latex(data: pd.DataFrame, name=None):
    r"""Write ``data`` to ``../thesis/Tables/<name>-hparams.tex`` as a LaTeX table.

    Cell contents are not escaped, values in the ``\texttt{lr}`` column are
    rendered with ``format_sctf``, and the column spec is one ``l`` column
    followed by one ``p{2.3cm}`` column per data column.

    Args:
        data: Frame to export; its column names are used verbatim as headers.
        name: File-name prefix of the table.  When omitted, the module-level
            variable ``key`` is used, which keeps the old call style
            ``to_latex(frame)`` working.

    Returns:
        Whatever ``DataFrame.to_latex`` returns for a file target.
    """
    if name is None:
        # Legacy behaviour: rely on the loop variable of the enclosing cell.
        name = key

    n_cols = len(data.columns)
    return data.to_latex(
        f"../thesis/Tables/{name}-hparams.tex",
        escape=False,
        formatters={r"\texttt{lr}": format_sctf},
        column_format="l" + n_cols * r"p{2.3cm}",
    )


# Export the first 10 summary rows of every sweep as a hyperparameter table.
# The sweep name is passed explicitly so the output path does not depend on
# whichever value the global `key` happens to hold.
for key, sweep in sweeps.items():
    (
        sweep
        .summary()
        .drop(columns="datetime_start")
        .head(10)
        .rename(columns=rename_cols)
        .pipe(to_latex, name=key)
    )


In [None]:
# Print, for every entry of INFERENCE_LABELS, the LaTeX table environment
# that \input's the corresponding "cifar-small-<key>-hparams" table file.
# All backslashes are written as "\\" so the f-string contains no invalid
# escape sequences (the closing "\end" previously relied on one).
for key, value in INFERENCE_LABELS.items():
    print(
        f"""
\\begin{{table}}[H]
    \\centering
    \\resizebox{{
        \\ifdim\\width>\\columnwidth
        \\columnwidth
      \\else
        \\width
      \\fi
    }}{{!}}{{\\small
    \\input{{Tables/cifar-small-{key}-hparams}}
    }}
    \\caption{{Top 10 hyperparameters for {value} on CIFAR10 dataset according to optuna sweep.}}
    \\label{{tab:cifar-small-{key}-hparams}}
\\end{{table}}"""
    )


# Test errors


In [None]:
# Locate the final experiment runs on disk.  Both result directories live
# under the same dated folder; sub-directories "0"/"1" of the first and
# "0"/"1"/"2" of the second are each wrapped in a Run.
final_results_root = Path("../experiment_results/cifar10_small/2021-12-26")

mcmc_dir = final_results_root / "19-13-33"
mcmc_runs = [Run(run_dir) for run_dir in mcmc_dir.glob("[01]/")]

other_dir = final_results_root / "19-13-36"
other_runs = [Run(run_dir) for run_dir in other_dir.glob("[012]/")]

# Non-MCMC runs first, then the MCMC runs.
all_runs = [*other_runs, *mcmc_runs]

In [None]:
# Validation-error figure of the final runs (y-axis capped at 0.32) ...
val_err_plot = plot_val_err(all_runs).set(ylim=(None, 0.32))
val_err_plot.savefig("../thesis/Figures/cifar-small-final-runs-val.pdf")

# ... and the matching test-error table, written without index or escaping.
test_err_table = get_test_err_table(all_runs)
test_err_table.to_latex(
    "../thesis/Tables/cifar-small-test-err.tex",
    escape=False,
    index=False,
    column_format="lc",
)

## Downsampling MCMC samples



In [None]:
# Draw the MCMC downsampling plot for all final runs and save the current
# matplotlib figure as a PDF.
downsampling_figure_path = "../thesis/Figures/cifar-small-downsampling.pdf"
plot_mcmc_downsampling(all_runs)
plt.savefig(downsampling_figure_path)

In [None]:
# Registry of test-error metrics, filled in by the cells below and keyed by
# inference label.
test_err_10_ensemble = dict()

In [None]:
from tqdm import tqdm

# Test error of the VI run evaluated with 10 particles, using the checkpoint
# at the step with the lowest validation error.
#
# NOTE(review): `hydra` and `dm` are not defined in the visible cells --
# presumably they come from the star import at the top or from an earlier
# session; confirm before relying on Restart & Run All.
vi_run = next(r for r in other_runs if "VI" in r.inference_label)
best_step = vi_run.get_scalar("err/val").idxmin()
# NOTE(review): the glob matches any checkpoint whose name merely *ends* in
# `best_step`; the first match is taken -- verify this cannot pick up a
# checkpoint from a different step.
best_ckpt = next(vi_run.dir.glob(f"**/*{best_step}.ckpt"))
inference = hydra.utils.instantiate(vi_run.cfg.inference, n_particles=10)
inference.load_state_dict(torch.load(best_ckpt, map_location="cpu")["state_dict"])
# NOTE(review): unlike the SGD cell, the module is not put into eval() mode
# here -- confirm that this is intended.

test_err_10_ensemble["VI"] = ErrorRate()
# Pure evaluation: no gradients are needed, so skip building autograd graphs.
with torch.no_grad():
    for x, y in tqdm(dm.test_dataloader()):
        # Stack the per-particle outputs and average the predictions over them.
        output = torch.stack(inference.forward_particles(x))
        preds = inference.model.predict_gvn_output(output).mean(0)
        test_err_10_ensemble["VI"].update(preds, y)

In [None]:
# Test error of each SGD run evaluated as an ensemble over all of its
# "epoch=*.ckpt" checkpoints: the softmax outputs of every checkpoint are
# averaged per batch before updating the metric.
sgd_runs = (r for r in other_runs if "SGD" in r.inference_label)
for sgd_run in sgd_runs:
    inference = hydra.utils.instantiate(sgd_run.cfg.inference)
    inference.eval()

    # Read every checkpoint from disk once per run instead of once per batch.
    state_dicts = [
        torch.load(ckpt, map_location="cpu")["state_dict"]
        for ckpt in sgd_run.dir.glob("**/epoch=*.ckpt")
    ]
    if not state_dicts:
        # Without this, torch.stack would fail on an empty list with a far
        # less helpful message.
        raise FileNotFoundError(
            f"no 'epoch=*.ckpt' checkpoints found under {sgd_run.dir}"
        )

    test_err_10_ensemble[sgd_run.inference_label] = ErrorRate()
    # Pure evaluation: no gradients are needed, so skip building autograd graphs.
    with torch.no_grad():
        for x, y in tqdm(dm.test_dataloader()):
            outputs = []
            for state_dict in state_dicts:
                inference.load_state_dict(state_dict)
                outputs.append(inference.model(x).softmax(-1))
            output = torch.stack(outputs).mean(0)
            test_err_10_ensemble[sgd_run.inference_label].update(output, y)



In [None]:
# Show the accumulated ensemble test error stored under "SGD (dropout)".
sgd_dropout_metric = test_err_10_ensemble["SGD (dropout)"]
sgd_dropout_metric.compute()

# Calibration


In [None]:
# Calibration figure for all final runs, saved from the current matplotlib
# figure ...
calibration_figure_path = "../thesis/Figures/cifar-small-calibration.pdf"
plot_calibration(all_runs)
plt.savefig(calibration_figure_path)

# ... and the ECE table, written without index or escaping.
ece_table = get_ece_table(all_runs)
ece_table.to_latex(
    "../thesis/Tables/cifar-small-ece.tex",
    index=False,
    escape=False,
    column_format="lc",
)

## Checking SGHMC assumptions

In [None]:
# Temperature figure for the MCMC runs; the bottom margin is widened before
# the current matplotlib figure is saved.
temperature_figure_path = "../thesis/Figures/cifar-small-temperatures.pdf"
plot_temperatures(mcmc_runs)
plt.subplots_adjust(bottom=0.13)
plt.savefig(temperature_figure_path)

# Companion table from get_temp_ci_table, written without index or escaping.
temperature_table = get_temp_ci_table(mcmc_runs)
temperature_table.to_latex(
    "../thesis/Tables/cifar-small-temperatures.tex",
    escape=False,
    index=False,
    column_format="lc",
)