In [None]:
%load_ext autoreload
%autoreload 2

# from src.visualization.grad_variance_estimates import *
import matplotlib.pyplot as plt
import pandas as pd
from hydra.utils import instantiate
from typing import Tuple
import seaborn as sns
# from src.visualization.simulated import *
from src.utils import Run, EXPERIMENT_PATH
from src.analysis.sample_distribution import load_samples
from src.analysis.simulated import get_exact_posterior
from src.analysis.utils import get_variance_estimator

from src.analysis.colors import ColorPalette

In [None]:
# Collect the run directories of the "sghmc_stats" experiment started at
# 2021-11-29 15-42-37 and wrap each one in a Run.
# NOTE(review): the "[0-9]" pattern only matches names that are a single digit,
# so directories named "10" and above would be skipped -- confirm this is intended.
# Path.glob yields matches in arbitrary, filesystem-dependent order; sorting makes
# runs[0] (used to instantiate the dataset) and every order-dependent result
# further down reproducible across machines and reruns.
run_dirs = sorted(
    (EXPERIMENT_PATH / "sghmc_stats" / "2021-11-29" / "15-42-37").glob("[0-9]")
)
runs = [Run(run_dir) for run_dir in run_dirs]

In [None]:
def eval_poly(x, coeffs):
    """Evaluate ``coeffs[0] + coeffs[1]*x + coeffs[2]*x**2 + ...`` at ``x``.

    ``coeffs`` is indexed by power, i.e. ``coeffs[i]`` multiplies ``x**i``.
    It must hold at least one element; an empty sequence raises ``IndexError``.
    """
    higher_order_terms = 0
    for power in range(1, len(coeffs)):
        higher_order_terms = higher_order_terms + coeffs[power] * x**power
    return coeffs[0] + higher_order_terms

# Instantiate the dataset from the config stored with the first run
# (presumably all runs share the same dataset config -- TODO confirm),
# then unpack the full slice of it into X and Y.
dataset = instantiate(
    runs[0].cfg.data.dataset,
)
X, Y = dataset[:]

In [None]:
# Posterior returned by get_exact_posterior for the full (X, Y) data.
posterior = get_exact_posterior(X, Y)

# Stack the samples of every run into one frame indexed by (estimator, sample).
# The "sampler" and "batch_size" index levels are dropped and replaced by the
# run's variance-estimator label. Runs whose inference config has no "sampler"
# entry are left out.
sample_data = pd.concat(
    [
        load_samples(run)
        .reset_index(level=["sampler", "batch_size"], drop=True)
        .assign(estimator=get_variance_estimator(run))
        .set_index("estimator", append=True)
        .reorder_levels(["estimator", "sample"])
        for run in runs
        if "sampler" in run.cfg["inference"]
    ]
)

In [None]:
# i = 1
# j = 3

# joint_plots = {}

# xlims = (-2.5, 1.0)
# ylims = (0.1, 0.5)

# for run in runs:

#     color = pal.get_color(run)
#     var_est = get_variance_estimator(run)
#     plot_sampled_joint_bivariate(
#         sample_data.loc[var_est],
#         exact_posterior=posterior,
#         xlims=xlims,
#         ylims=ylims,
#         i=i,
#         j=j,
#         color=color,
#     )

#     # fg = sns.pairplot(
#     #     sample_data,
#     #     kind="hist",
#     #     diag_kws={"stat": "density", "bins": 50, "rasterized": True},
#     #     plot_kws={"bins": 50},
#     # )
#     # fg.map_diag(plot_univariate, posterior=posterior)
#     # fg.map_offdiag(plot_bivariate, posterior=posterior)


#     # plt.savefig("simulated_joint_SGHMCWithVarianceEstimator_5.pdf")


# Temperatures

In [None]:
from src.analysis.temperatures import *
from src.analysis.utils import get_variance_estimator

# Stack the temperature records of every run into one frame indexed by
# (estimator, parameter, step), then add the column
# T_k = temperature_sum / n_params.
temperature_samples = pd.concat(
    [
        load_temperatures(r)
        .assign(estimator=get_variance_estimator(r))
        .set_index("estimator", append=True)
        .reorder_levels(["estimator", "parameter", "step"])
        for r in runs
    ]
).assign(T_k=lambda df: df["temperature_sum"] / df["n_params"])

temperature_samples


In [None]:
# Key each run by its variance-estimator label (a later run with the same label
# replaces an earlier one), then ask the ColorPalette for the colours and hue
# order that the temperature plot uses.
pal = ColorPalette()
labeled_runs = dict((get_variance_estimator(run), run) for run in runs)
palette, hue_order = pal.get_colors(labeled_runs)

In [None]:
# Per-estimator KDE of temperature_sum, restricted to rows with T_k < 7.
# The handles and labels returned by plot_temperature_chi2 populate the legend,
# and the figure is written to the thesis figure folder.
with sns.color_palette(palette):
    fg = sns.displot(
        data=temperature_samples.loc[lambda df: df["T_k"] < 7],
        x="temperature_sum",
        hue="estimator",
        hue_order=hue_order,
        kind="kde",
        common_norm=False,
        height=3,
        aspect=1.6,
    )

    lines, texts = plot_temperature_chi2(fg, linestyle="dashed")
    plt.legend(lines, texts, frameon=False, title="Estimator")
    plt.tight_layout()
    plt.savefig("../thesis/Figures/temperature_sum_chi2_comp.pdf")


In [None]:
# Take the per-estimator result of get_frac_in_ci, format it with
# format_rate_with_95_ci, and write it as a two-column LaTeX table
# (Estimator | formatted rate) into the thesis table folder.
(
    get_frac_in_ci(temperature_samples, ["estimator"])
    .pipe(format_rate_with_95_ci, "frac_in_ci", "count")
    # Raw string: \E, \hat and \in are LaTeX commands, not Python escapes. In a
    # normal string literal they are invalid escape sequences, which emit a
    # warning today and are slated to become a syntax error.
    .rename(r"$\E[\hat{T}_K \in J_{T_K}(d, {0.99})]$")
    .rename_axis("Estimator")
    .reset_index()
    .to_latex(
        "../thesis/Tables/var-estimators-temperatures.tex",
        escape=False,  # keep the LaTeX math in the header/cells unescaped
        index=False,
        column_format="lc",
    )
)
