In [None]:
from pathlib import Path
from dataclasses import dataclass
from typing import Union
from omegaconf import OmegaConf
import pandas as pd
# Directory whose numbered sub-directories hold the individual runs analysed
# in this notebook.
dir_ = Path("..", "experiment_results", "mnist", "2021-12-16", "13-07-51")


In [None]:
@dataclass
class Run:
    """Handle on one run directory.

    Attributes:
        dir: Location of the run directory. A ``str`` is accepted and is
            converted to a ``Path`` when the instance is created.
    """

    dir: Union[str, Path]

    def __post_init__(self):
        # The annotation admits ``str``, but ``cfg`` and the code using this
        # class join paths with ``/``, which raises TypeError on a plain
        # string. Normalising once here makes both spellings work.
        self.dir = Path(self.dir)

    @property
    def cfg(self):
        """Load ``<dir>/.hydra/config.yaml`` with ``OmegaConf.load``.

        The file is read again on every access; the result is not cached.
        """
        return OmegaConf.load(self.dir / ".hydra" / "config.yaml")

# Sort the matched directories so that the order of `runs` (and therefore
# `runs[0]` and the order in which per-run frames are concatenated) does not
# depend on the order in which the filesystem happens to list them.
runs = [Run(path) for path in sorted(dir_.glob("[01]/"))]

In [None]:
def _resampling_curve(run):
    """Read one run's ``sample_resampling_curve.json`` into a frame indexed
    by ``(sampler, n_sampled)`` and sorted by that index."""
    curve = pd.read_json(run.dir / "sample_resampling_curve.json")
    curve = curve.rename_axis(index=["n_sampled"])
    # Label the rows with the last dotted component of the sampler target.
    sampler_name = run.cfg.inference.sampler._target_.split(".")[-1]
    curve = curve.assign(sampler=sampler_name)
    curve = curve.set_index("sampler", append=True)
    return curve.reorder_levels(["sampler", "n_sampled"]).sort_index()


curves = pd.concat([_resampling_curve(run) for run in runs])


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
sns.set_style("white")
ax = sns.lineplot(
    x="n_sampled",
    y="error_rate",
    hue="sampler",
    data=curves.reset_index(),
)
sns.despine()
# Rebuild the legend from its existing entries, only dropping the frame.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, frameon=False)
# Leave the lower y-limit untouched and set the upper limit to 0.02.
ax.set_ylim(None, 0.02)

In [None]:
import torch
# Load the first run's temperature samples into `a`.
# NOTE(review): `a` is not referenced by any other cell visible here.
samples_path = runs[0].dir / "temperature_samples.pt"
a = torch.load(samples_path)

In [None]:
import torch

def _temperature_samples(run):
    """Read one run's ``temperature_samples.pt`` into a frame indexed by
    ``(sampler, parameter, step)``, keeping every 50th step only."""
    samples = pd.DataFrame.from_dict(
        torch.load(run.dir / "temperature_samples.pt"),
        orient="index",
    )
    samples = samples.rename_axis(index=["step", "parameter"])
    # Keep only the rows whose step is a multiple of 50.
    on_grid = samples.index.get_level_values("step") % 50 == 0
    samples = samples.loc[on_grid]
    # Label the rows with the last dotted component of the sampler target.
    sampler_name = run.cfg.inference.sampler._target_.split(".")[-1]
    samples = samples.assign(sampler=sampler_name)
    samples = samples.set_index("sampler", append=True)
    return samples.reorder_levels(["sampler", "parameter", "step"])


temperatures = pd.concat([_temperature_samples(run) for run in runs])

In [None]:
import numpy as np
from scipy.stats import chi2
def plot_chi2(df, **kwargs):
    """Draw a chi-squared density in black across the current axes' x-range.

    Args:
        df: Object supporting ``.iloc``; its first element is used as the
            degrees of freedom of the distribution.
        **kwargs: Accepted and ignored.
    """
    ax = plt.gca()
    lo, hi = ax.get_xlim()
    # Evaluate the density on 300 evenly spaced points between the limits.
    grid = np.linspace(lo, hi, 300)
    density = chi2(df.iloc[0]).pdf(grid)
    ax.plot(grid, density, color="black")

# One KDE panel per parameter (three panels per row), one curve per sampler;
# the panels do not share their x or y axes.
fg = sns.displot(
    temperatures.reset_index(),
    kind="kde",
    x="temperature_sum",
    hue="sampler",
    common_norm=False,
    col="parameter",
    col_wrap=3,
    facet_kws=dict(sharex=False, sharey=False),
)
# Overlay the chi-squared curve in each panel, fed with that panel's
# "n_params" column.
fg.map(plot_chi2, "n_params")