In [None]:
%load_ext autoreload
%autoreload 2
from src.utils import Run, EXPERIMENT_PATH
from src.inference.mcmc.example_distribution import Example
import torch
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

In [None]:
# Output directory of the synthetic experiment; each run lives in a
# numerically named sub-directory (0, 1, 2, ...).
dir_ = EXPERIMENT_PATH / "synthetic" / "2021-11-25" / "10-56-20"
# Sort numerically so ``runs[0]`` and the sampler order used downstream are
# deterministic (directory listing order is filesystem-dependent), and accept
# multi-digit names so runs 10, 11, ... are not silently skipped.
runs = [
    Run(run_dir)
    for run_dir in sorted(
        (path for path in dir_.iterdir() if path.is_dir() and path.name.isdecimal()),
        key=lambda path: int(path.name),
    )
]


In [None]:
def plot_distribution(samples, bins, *args, ax=None, **kwargs):
    """Plot a normalised histogram of ``samples`` as a line through the bin centres.

    Parameters
    ----------
    samples : array_like
        Values to histogram (flattened by ``np.histogram``).
    bins : int or sequence of scalars
        Forwarded to ``np.histogram``: a number of bins or monotonically
        increasing bin edges (uniform or not).
    *args
        Positional format arguments forwarded to ``ax.plot`` (e.g. ``"r--"``).
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current axes (``plt.gca()``).
    **kwargs
        Keyword arguments forwarded to ``ax.plot`` (e.g. ``label=...``).
    """
    # ``ax`` is keyword-only, so a caller's axes arrive here directly and are
    # never part of ``kwargs``; only fall back to the current axes if omitted.
    if ax is None:
        ax = plt.gca()

    yy, edges = np.histogram(samples, bins, density=True)
    # Use the edges numpy actually used: this handles an integer ``bins`` and
    # unevenly spaced edges, unlike offsetting by the first bin width.
    xx = (edges[:-1] + edges[1:]) / 2
    ax.plot(xx, yy, *args, **kwargs)

In [None]:
# Peek at the first run's configuration; later cells read its "inference"
# and "legend" entries.
first_run = runs[0]
first_run.cfg

In [None]:
from dataclasses import asdict
import pandas as pd


def load_run_samples(run):
    """Return one run's stored samples as a frame indexed by (Sampler, sample).

    Reads ``run._dir / "samples.pt"`` with ``torch.load`` and puts the values
    in a single column named ``"value"``; the ``Sampler`` index level holds
    the run's ``cfg["legend"]`` label.
    """
    # torch.load unpickles the file: only use it on trusted, locally produced runs.
    stored = torch.load(run._dir / "samples.pt")
    return (
        pd.Series(stored, name="value")
        .rename_axis(index="sample")
        .to_frame()
        .assign(Sampler=run.cfg["legend"])
        .set_index("Sampler", append=True)
        .reorder_levels(["Sampler", "sample"])
    )


# Keep only runs whose inference config has a "sampler" entry
# (presumably the ones that wrote samples.pt — confirm).
sampler_runs = [run for run in runs if "sampler" in run.cfg["inference"]]
# pd.concat on an empty selection fails with an opaque error; say why instead.
assert sampler_runs, f"no sampler runs found among {len(runs)} runs"
sample_data = pd.concat(map(load_run_samples, sampler_runs))

In [None]:
# Shared histogram grid: n_bins equal-width bins spanning [-3, 3] ...
n_bins = 60
bins = np.linspace(start=-3, stop=3, num=n_bins + 1)
# ... and a finer 200-point grid over the same window for smooth curves.
xx = np.linspace(*bins[[0, -1]], 200)


def get_x_y(data, bin_edges=None):
    """Histogram ``data`` into a tidy (``$x$``, ``Density``) frame for plotting.

    Parameters
    ----------
    data : array_like
        Samples (any shape; flattened). Typically one sampler's group.
    bin_edges : sequence of float, optional
        Monotonically increasing bin edges; defaults to the notebook-level
        ``bins``.

    Returns
    -------
    pandas.DataFrame
        One row per bin: the bin centre in ``$x$`` and the estimated density
        in ``Density``.
    """
    edges = bins if bin_edges is None else np.asarray(bin_edges, dtype=float)
    values = np.asarray(data, dtype=float).ravel()
    counts, _ = np.histogram(values, bins=edges)
    widths = np.diff(edges)
    # Normalise by *all* samples, not only those inside the window (which is
    # what ``density=True`` does); otherwise mass outside [edges[0], edges[-1]]
    # inflates the curve relative to the true density it is compared with.
    if values.size:
        density = counts / (values.size * widths)
    else:
        density = np.zeros_like(widths, dtype=float)
    centres = (edges[:-1] + edges[1:]) / 2
    return pd.DataFrame({"$x$": centres, "Density": density})


# Analytic target density, evaluated on the fine grid ``xx`` that spans the
# histogram window, so the reference curve follows ``bins`` rather than a
# separately hard-coded coarse range.
true_dist = pd.DataFrame({"$x$": xx}).assign(
    Density=lambda frame: Example.density(frame["$x$"]),
    Sampler="True distribution",
)

# One histogram line per sampler plus the true density, distinguished by
# colour and line style.
grid = (
    sample_data.groupby(level="Sampler")
    .apply(get_x_y)
    .reset_index("Sampler")
    .pipe(lambda hists: pd.concat([true_dist, hists]))
    .reset_index(drop=True)
    .pipe(
        (sns.relplot, "data"),
        x="$x$",
        y="Density",
        style="Sampler",
        hue="Sampler",
        kind="line",
        aspect=1.5,
        height=3,
    )
)
# relplot's FacetGrid already removes the top/right spines; only set limits.
grid.set(xlim=(-2.5, 2.5), ylim=(None, 0.6))
plt.savefig("../thesis/Figures/synthetic.pdf")


In [None]:
import seaborn as sns
# Overlaid per-sampler histograms of the raw samples. The single value column
# is taken positionally (so its name does not matter) and the hue uses the
# actual index level, "Sampler". histplot has no ``style`` parameter.
sns.histplot(
    data=(
        sample_data.iloc[:, 0]
        .rename("value")
        .reset_index("Sampler")
        .reset_index(drop=True)
    ),
    x="value",
    hue="Sampler",
    # Per-sampler densities, comparable with the density plot above even if
    # the samplers drew different numbers of samples.
    stat="density",
    common_norm=False,
)

In [None]:
# Quick check with the line-histogram helper: one curve per sampler, overlaid
# on the analytic target density. (Replaces a call to an undefined
# ``_plot_dist`` on an undefined ``samples`` tensor.)
for sampler_name, sampler_samples in sample_data.groupby(level="Sampler"):
    plot_distribution(sampler_samples.to_numpy().ravel(), bins, label=sampler_name)
plt.plot(xx, Example.density(xx), "k--", label="True distribution")
plt.legend();

In [None]:
# Record trajectories from a Hamiltonian sampler on the example target; the
# cells below read the "value", "momentum" and "accepted" columns of
# ``trace_data``. The seed is set first so the recorded traces are reproducible.
# NOTE(review): ``Hamiltonian`` and ``get_traces`` are not imported or defined
# anywhere in this notebook as shown — add the missing imports to the top cell
# so this cell survives Restart & Run All.
torch.manual_seed(10)
samplable = Example(grad_noise=0.)  # presumably disables gradient noise — confirm
sampler = Hamiltonian(n_steps=50, step_size=0.02).setup(samplable)
# presumably 50 = number of trajectories to record — TODO confirm against get_traces
trace_data = get_traces(sampler, 50)

In [None]:
# Per-trajectory summaries (index level 0 of trace_data identifies the
# trajectory): the first recorded state of every trajectory, and the last
# state of those whose proposal was accepted.
initial_states = trace_data.groupby(level=0).first()

is_accepted = trace_data.groupby(level=0).last().loc[:, "accepted"]
accepted_traces = is_accepted[is_accepted].index
accepted_states = (
    trace_data
    .loc[accepted_traces]
    .groupby(level=0)
    .last()
)

In [None]:
# Long-format segments joining each trajectory's final momentum to the next
# trajectory's initial momentum, both drawn at the final "value". Each trace
# contributes two rows (from/to); the last trace's "to" momentum is NaN.
by_trace = trace_data.groupby(level=0)
# shift(-1) pairs each trace with the *next* one positionally and aligns on
# the index, instead of label-slicing with .loc[1:], which assumed trace ids
# start at 0 and are consecutive.
next_initial_momentum = by_trace.first()["momentum"].shift(-1)
momentum_updates = (
    by_trace.last()
    .drop(columns="accepted")
    .rename(columns={"momentum": "momentum_from"})
    .assign(momentum_to=next_initial_momentum)
    .reset_index()
    .melt(id_vars=["trace", "value"], value_name="momentum")
    # Stable sort keeps each trace's "from" row before its "to" row.
    .sort_values("trace", kind="stable")
    .reset_index(drop=True)
    .drop(columns="variable")
)

In [None]:
# Phase-space view of the sampler: every trajectory in (value, momentum),
# each trajectory's start (o) and, for accepted proposals, its end (X), plus
# dashed grey segments from each trajectory's final momentum to the next
# trajectory's initial momentum.
fig, ax = plt.subplots(figsize=(10, 8))
sns.lineplot(
    x="value", y="momentum", data=trace_data,
    sort=False, units="trace", estimator=None, ax=ax,
)
# ``s`` sets a fixed marker area; ``sizes`` only takes effect together with a
# ``size`` mapping and was silently ignored.
sns.scatterplot(x="value", y="momentum", data=accepted_states, marker="X", s=20, ax=ax)
sns.scatterplot(x="value", y="momentum", data=initial_states, marker="o", ax=ax)
sns.lineplot(
    x="value", y="momentum", data=momentum_updates,
    color="grey", linestyle="dashed", units="trace", estimator=None, ax=ax,
)
ax.set(xlim=(-2, 2), ylim=(-3, 3))
plt.show()

# Final state of every trajectory (accepted or not); not used by the plot.
final_states = trace_data.groupby(level=0).last()
