In [None]:
%load_ext autoreload
%autoreload 2

from src.utils import EXPERIMENT_PATH, Run
from src.analysis.simulated import (
    get_exact_posterior, 
    plot_sampled_joint_bivariate,
    plot_sampled_distributions_pairs, 
    )
from src.analysis.sample_distribution import load_samples
from src.analysis.utils import get_variance_estimator
from src.analysis.gradient_variance_estimators import get_variance_estimates
from src.analysis.colors  import get_colors, get_color
from pathlib import Path
import matplotlib.pyplot as plt
import plotly.express as px
import pandas as pd
from hydra.utils import instantiate
import seaborn as sns


In [None]:
# Collect the per-run output directories of this sweep and wrap each in a Run.
# `Path.glob` yields entries in arbitrary, filesystem-dependent order, so the
# directories are sorted to keep `runs[0]` and the order of every downstream
# concat/plot reproducible across machines.
# NOTE: the "[0-9]" pattern only matches single-character directory names.
sweep_dir = EXPERIMENT_PATH / "sghmc_gradients" / "2021-11-29" / "09-07-46"
run_dirs = sorted(sweep_dir.glob("[0-9]"))
runs = [Run(x) for x in run_dirs]

# Fail here with a readable message instead of an IndexError / empty-concat
# error further down when the sweep directory is missing or empty.
if not runs:
    raise FileNotFoundError(f"No run directories matching '[0-9]' found in {sweep_dir}")

def eval_poly(x, coeffs):
    """Evaluate the polynomial ``sum(coeffs[i] * x**i)`` at ``x``.

    ``coeffs`` is indexed and sliced, with ``coeffs[0]`` as the constant term
    and ``coeffs[i]`` multiplying ``x**i``; it must hold at least one entry.
    """
    higher_order_terms = (
        coeff * x**power for power, coeff in enumerate(coeffs[1:], start=1)
    )
    return coeffs[0] + sum(higher_order_terms)

# Rebuild the dataset from the first run's Hydra config and unpack its full
# slice into X and Y.
# NOTE(review): assumes every run in the sweep shares this dataset config — confirm.
first_run_cfg = runs[0].cfg
dataset = instantiate(first_run_cfg.data.dataset)
X, Y = dataset[:]

In [None]:
def _indexed_variance_estimates(run):
    """Load one run's variance estimates, indexed by (estimator, step, name)."""
    estimates = get_variance_estimates(run._dir / "variance_estimates.pt")
    estimates = estimates.assign(estimator=get_variance_estimator(run))
    estimates = estimates.set_index("estimator", append=True)
    return estimates.reorder_levels(["estimator", "step", "name"]).sort_index()


# Stack the variance estimates of all runs into a single frame.
gradient_data = pd.concat([_indexed_variance_estimates(run) for run in runs])


In [None]:
# Mean relative error of the estimated vs. observed gradient variance,
# |estimated - observed| / observed, averaged per (estimator, parameter).
variance_by_name = gradient_data.stack("parameter").unstack("name")
abs_deviation = (
    variance_by_name["estimated_variance"] - variance_by_name["observed_variance"]
).abs()
rel_errs = (
    (abs_deviation / variance_by_name["observed_variance"])
    .rename("rel_err")
    .groupby(["estimator", "parameter"])
    .mean()
)
# Show the result as an estimator-by-parameter table.
rel_errs.unstack()

In [None]:
from src.analysis.utils import add_column_level_, embolden_, format_as_percent
# Write the mean relative errors as a LaTeX table: one row per estimator, one
# "$a_i$" column per parameter, grouped under a shared column header.
error_table = (
    rel_errs.apply(format_as_percent)
    .unstack("parameter")
    .rename(columns=lambda param: f"$a_{param}$")
    .rename_axis(index={"estimator": "Estimator"}, columns={"parameter": "Parameter"})
)
error_table = embolden_(error_table, "ExpWeightedEstimator")
error_table = add_column_level_(error_table, "Average relative error")
error_table.to_latex(
    "../thesis/Tables/simultated_variance_estimations.tex",
    escape=False,  # cell contents already contain LaTeX markup
    multicolumn_format="c",
    column_format="lcccc",
)

In [None]:
# Scatter estimated against observed gradient variance: one facet per
# estimator (the constant estimator is left out), one hue per parameter.
is_not_constant = (
    gradient_data.index.get_level_values("estimator") != "ConstantEstimator"
)
variance_long = (
    gradient_data.loc[is_not_constant]
    .unstack(level="name")
    .stack("parameter")
    .reset_index()
)
fg: sns.FacetGrid = sns.relplot(
    data=variance_long,
    x="observed_variance",
    y="estimated_variance",
    hue="parameter",
    col="estimator",
    facet_kws={"sharex": False, "sharey": False},
    palette="crest",
    rasterized=True,
    marker=".",
    s=12,
    height=2.5,
    edgecolors=None,
    aspect=0.8,
)


# Shared axis range for both observed and estimated variance.
variance_range = (1, 1e5)

# Tidy every facet: add the identity reference line and shorten the title.
for ax in fg.axes.flat:
    # The identity line is specified by two points rather than `slope=1`:
    # Matplotlib documents `slope` as only meaningful on linear scales, and
    # these axes are switched to log-log below.
    ax.axline(
        (variance_range[0], variance_range[0]),
        (variance_range[1], variance_range[1]),
        color="red",
        linestyle="--",
    )
    # seaborn's default facet title reads "<column variable> = <value>";
    # keep only the value part.
    new_title = ax.title.get_text().split(" = ")[-1]
    ax.set_title(new_title)

fg.set(
    xlim=variance_range,
    ylim=variance_range,
    yscale="log",
    xscale="log",
    xlabel="Observed variance",
    ylabel="Estimated variance",
)
fg.tight_layout()
plt.savefig("../thesis/Figures/simulated_sghmc_gradient_variance_estimations.pdf")


# Plotting the resulting distributions

In [None]:
def _samples_indexed_by_estimator(run):
    """Load one run's samples, indexed by (estimator, sample)."""
    samples = load_samples(run).reset_index(level=["sampler", "batch_size"], drop=True)
    samples = samples.assign(estimator=get_variance_estimator(run))
    return samples.set_index("estimator", append=True).reorder_levels(
        ["estimator", "sample"]
    )


# Exact posterior for the dataset, used as the reference in the plots below.
posterior = get_exact_posterior(X, Y)

# Only runs whose inference config has a "sampler" entry contribute samples.
sample_data = pd.concat(
    [
        _samples_indexed_by_estimator(r)
        for r in runs
        if "sampler" in r.cfg["inference"]
    ]
)

In [None]:
# For every non-constant variance estimator, save a pair plot and a joint
# bivariate plot of its samples together with the exact posterior.
for run in runs:
    var_est = get_variance_estimator(run)

    # Constant estimators are not plotted.
    if "Constant" in var_est:
        continue

    # `sample_data` is built only from runs with a "sampler" in their inference
    # config; apply the same filter so the lookup below cannot hit a KeyError
    # for a run that contributed no samples.
    if "sampler" not in run.cfg["inference"]:
        continue

    subset = sample_data.loc[var_est]
    color = get_color(run)
    kwargs = {"exact_posterior": posterior, "sample_data": subset, "color": color}

    # Each plotting helper is followed by a savefig of the current figure.
    plot_sampled_distributions_pairs(**kwargs)
    plt.savefig(f"../thesis/Figures/simulated_var_est_pairs_{var_est}.pdf")

    # NOTE(review): i/j presumably select the two parameters to plot and the
    # limits are hand-tuned for this experiment — confirm.
    plot_sampled_joint_bivariate(i=1, j=3, xlims=(-2.5, 1.0), ylims=(0.1, 0.5), **kwargs)
    plt.savefig(f"../thesis/Figures/simulated_var_est_joint_{var_est}.pdf")
