In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns

from src.experiments.common import Experiment
from src.visualization.grad_variance_estimates import *
# Handle on the experiment whose gradient-variance estimates are analysed here.
EXPERIMENT_NAME = "sghmc_gradients"
experiment = Experiment(EXPERIMENT_NAME)

In [None]:
# Select one sub-run of the most recent multirun of this experiment.
RUN_INDEX = 2
multirun = experiment.latest_run()
run = multirun.runs[RUN_INDEX]

In [None]:
# Load both variance estimates for the selected run, in this order:
# inter-batch variance first, then the estimator's own variance estimate.
variance_inter_batch, variance_estimated = (
    load_estimates(run, estimate_name)
    for estimate_name in ("variance_inter_batch", "variance_estimated")
)

In [None]:
# Wide frame: one row per remaining index level (e.g. step), columns are
# (estimate name, parameter_index).
# fmt: off
data = (
    pd.concat([variance_inter_batch, variance_estimated], axis=1)
    .unstack("parameter_index")
)
# A parameter is dropped if it is identically zero under EITHER estimate: the
# scatter plots below use log axes, where zero values cannot be shown.
is_zero = (
    np.isclose(data["variance_inter_batch"], 0).all(0)
    | np.isclose(data["variance_estimated"], 0).all(0)
)
# fmt: on

# Index the mask against the same frame it was computed from, so the labels
# and the boolean mask are guaranteed to be aligned.
no_zero_columns = data["variance_inter_batch"].columns[~is_zero]

# Randomly pick a fixed, reproducible subset of parameters to plot; cap the
# sample size so this does not raise when few parameters survive the filter.
N_SAMPLED_PARAMETERS = 9
SAMPLE_SEED = 123
sampled_cols = pd.Series(no_zero_columns).sample(
    min(N_SAMPLED_PARAMETERS, len(no_zero_columns)), random_state=SAMPLE_SEED
)

In [None]:
# Long-format frame for the sampled parameters: one row per (step, parameter)
# with the two variance estimates as columns, ready for seaborn.
# `step_mod_110` is the step position within a 110-step cycle, used for colour.
# NOTE(review): the meaning of the 110 period is not visible here — confirm.
# The final shuffle randomises draw order so no step range is always plotted on
# top; it is seeded so the figure is reproducible across kernel restarts.
sampled_data = (
    data.reorder_levels((1, 0), axis=1)
    .loc[:, sampled_cols]
    .stack(level="parameter_index")
    .reset_index()
    .assign(step_mod_110=lambda frame: frame["step"] % 110)
    .sample(frac=1.0, random_state=123)
)


In [None]:
# Display the `variance_estimator` setting from the selected run's config.
run.config.variance_estimator

In [None]:
# One scatter panel per sampled parameter comparing the two variance estimates
# on log-log axes, coloured by `step_mod_110`. Points on the red y = x line are
# where both estimates agree.
fg = sns.relplot(
    data=sampled_data,
    x="variance_inter_batch",
    y="variance_estimated",
    col="parameter_index",
    col_wrap=3,
    hue="step_mod_110",
    facet_kws={"sharey": False, "sharex": False},
)
fg.set(xscale="log", yscale="log")

for ax in fg.axes.flat:
    # y = x reference line. Anchor on positive points: (0, 0) lies outside the
    # log-scale domain and only works via matplotlib's clipping of non-positive
    # values.
    ax.axline((1, 1), (10, 10), color="red")