In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import pandas as pd
from hydra.utils import instantiate
from typing import Tuple
import seaborn as sns
from src.utils import Run, EXPERIMENT_PATH
from src.analysis.sample_distribution import load_samples
from src.analysis.simulated import (
    get_exact_posterior, 
    plot_sampled_joint_bivariate,
    plot_sampled_distributions_pairs, 
    )
import torch
from src.analysis.colors import get_color, get_colors


In [None]:
def eval_poly(x, coeffs):
    """Evaluate ``coeffs[0] + coeffs[1]*x + coeffs[2]*x**2 + ...`` at ``x``.

    ``coeffs`` is ordered by increasing power. ``x`` may be a plain number or a
    tensor, in which case the polynomial is evaluated elementwise.
    """
    # Accumulate the non-constant terms first (starting from 0, like ``sum``),
    # then add the constant term, so the arithmetic order matches exactly.
    higher_order = 0
    for power, coeff in enumerate(coeffs[1:], start=1):
        higher_order = higher_order + coeff * x**power
    return coeffs[0] + higher_order
    
# Runs of the (presumably Hydra multirun) sweep live in numbered sub-directories.
# Sort them numerically so `runs` has a deterministic order (glob order is
# filesystem-dependent) and multi-digit run ids such as "10" are not skipped.
sweep_dir = EXPERIMENT_PATH / "simulated" / "2021-12-16" / "13-27-43"
run_dirs = sorted(
    (path for path in sweep_dir.iterdir() if path.is_dir() and path.name.isdigit()),
    key=lambda path: int(path.name),
)
runs = [Run(x) for x in run_dirs]
# NOTE(review): presumably every run uses the same simulated dataset; the first
# run's config is used to rebuild it.
dataset = instantiate(runs[0].cfg.data.dataset)
X, Y = dataset[:]

# True polynomial (dashed) and the simulated observations.
# X[:, 1] is plotted as the raw input x — assumes X is a design matrix with
# columns [1, x, x**2, ...]; confirm against the dataset class.
xx = torch.linspace(-3, 3, 100)  # `steps` is required in current PyTorch; 100 was the old default
fig, ax = plt.subplots()
ax.plot(xx, eval_poly(xx, dataset.coeffs), "--k")
ax.scatter(X[:, 1], Y[:, 0])
ax.set_xlabel("$x$")
ax.set_ylabel("$P(x)$")
sns.despine()
fig.savefig("../thesis/Figures/pol_model.pdf")

In [None]:
# Exact posterior from `get_exact_posterior` for the simulated data (X, Y); it is
# the reference every sampler / variational result below is compared against.
posterior = get_exact_posterior(X, Y)
# NOTE(review): `ColorPalette` is not imported anywhere in the visible notebook
# (only `get_color` / `get_colors` come from src.analysis.colors), so on a fresh
# kernel this line raises NameError — import it from where it is defined.
pal = ColorPalette()

In [None]:
for run in runs:
    # Only runs whose inference config has a "sampler" entry produce samples.
    if "sampler" in run.cfg["inference"]:
        sample_data = load_samples(run)
        # The first two index levels name the sampler and its batch size.
        sampler, batch_size = sample_data.index[0][:2]
        color = pal.get_color(run)
        plot_kwargs = dict(
            exact_posterior=posterior,
            sample_data=sample_data,
            color=color,
        )

        # All parameter pairs: sampled distribution vs. exact posterior.
        plot_sampled_distributions_pairs(**plot_kwargs)
        plt.savefig(f"../thesis/Figures/simulated_pairs_{sampler}_{batch_size}.pdf")

        # Zoomed joint view of parameters 1 and 3.
        plot_sampled_joint_bivariate(
            i=1, j=3, xlims=(-2.5, 1.0), ylims=(0.1, 0.5), **plot_kwargs
        )
        plt.savefig(f"../thesis/Figures/simulated_joint_{sampler}_{batch_size}.pdf")


In [None]:
# NOTE(review): dead, commented-out draft. The per-run loop earlier in this
# notebook already calls plot_sampled_joint_bivariate with the same i=1, j=3 and
# the same xlims/ylims. This draft references `PLOT_COLORS`, which is not defined
# in this notebook, and assumes `sample_data` holds every sampler run at once.
# Consider deleting this cell.
# i = 1
# j = 3
# joint_plots = {}

# xlims = (-2.5, 1.0)
# ylims = (0.1, 0.5)

# # for key, data in sample_data[[i, j]].groupby(level=[0, 1]):
#     with (sns.color_palette(PLOT_COLORS[key]["color_palette"])):
#         joint_plots[key] = plot_sampled_joint_bivariate(
#             data, exact_posterior=posterior, xlims=xlims, ylims=ylims
#         )
#         plt.savefig(f"../thesis/Figures/simulated_joint_{'_'.join(map(str, key))}.pdf")


In [None]:
from torch.distributions import MultivariateNormal

def load_var_params(file_name):
    """Load the Gaussian variational parameters of ``model.linear.weight`` from a checkpoint.

    Args:
        file_name: path to a checkpoint whose ``"state_dict"`` contains
            ``model.linear.variational_parameters.weight.mu`` and ``.rho``.

    Returns:
        DataFrame with one row per (flattened) weight, index named ``"parameter"``,
        and columns ``mu`` (mean) and ``sigma`` (standard deviation,
        ``softplus(rho)``).
    """
    # map_location="cpu": checkpoints written by GPU runs must load on CPU-only
    # machines, and `.numpy()` below only works on CPU tensors.
    # NOTE: torch.load unpickles; only use it on trusted checkpoint files.
    state_dict = torch.load(file_name, map_location="cpu")["state_dict"]
    prefix = "model.linear.variational_parameters.weight"
    mu = state_dict[f"{prefix}.mu"].flatten().numpy()
    rho = state_dict[f"{prefix}.rho"]
    # sigma = log(1 + exp(rho)); F.softplus is the numerically stable form
    # (exp(rho) overflows to inf for large rho).
    sigma = torch.nn.functional.softplus(rho).flatten().numpy()
    return pd.DataFrame({"mu": mu, "sigma": sigma}).rename_axis("parameter")

# Pick the first run whose inference target is a variational method and rebuild
# its fitted variational distribution.
vi_run = next((r for r in runs if "Variational" in r.cfg.inference._target_), None)
if vi_run is None:
    raise LookupError("no run with a 'Variational' inference target in `runs`")
# NOTE(review): `_dir` is a private attribute of Run. If several checkpoints exist,
# the first one returned by glob (filesystem order) is used.
vi_ckpt = next(vi_run._dir.glob("**/*.ckpt"), None)
if vi_ckpt is None:
    raise FileNotFoundError(f"no *.ckpt checkpoint found under {vi_run._dir}")
var_params = load_var_params(vi_ckpt)
# Diagonal covariance with variances sigma**2 (independent Gaussians per parameter).
var_distribution = MultivariateNormal(
    torch.tensor(var_params["mu"].to_numpy()),
    torch.diag(torch.tensor(var_params["sigma"].to_numpy()).square()),
)
color = pal.get_color(vi_run)

In [None]:
from torch.distributions import Normal, MultivariateNormal
from itertools import product
from src.analysis.simulated import (
    draw_bi_gaussian,
    draw_uni_gaussian,
    get_marginal,
    get_lims,
)

# Exact posterior (black) vs. variational approximation (run colour) for every
# pair of parameters: univariate densities on the diagonal, bivariate density
# contours off the diagonal.
n_params = var_distribution.event_shape[0]
levels = [1e-2, 1e-1, 1e0, 1e1, 1e2]  # contour levels passed to draw_bi_gaussian

# squeeze=False keeps `axes` 2-D even when there is a single parameter.
fig, axes = plt.subplots(n_params, n_params, figsize=(10, 8), squeeze=False)
legend_lines = []  # artists from a diagonal panel, reused for the figure legend

for i, j in product(range(n_params), repeat=2):
    # The draw_* helpers plot onto the current axes.
    plt.sca(axes[i, j])
    if i == j:
        true_marg = get_marginal(posterior, i)
        var_marg = get_marginal(var_distribution, i)
        xlim = get_lims(true_marg.mean, true_marg.stddev)
        legend_lines = []
        legend_lines += draw_uni_gaussian(true_marg, xlim=xlim, color="black")
        legend_lines += draw_uni_gaussian(var_marg, xlim=xlim, color=color)
        plt.xlim(xlim)
    else:
        true_marg = get_marginal(posterior, i, j)
        var_marg = get_marginal(var_distribution, i, j)
        # Limits come from the exact posterior so both distributions share a frame.
        xlim = get_lims(true_marg.mean[0], true_marg.stddev[0])
        ylim = get_lims(true_marg.mean[1], true_marg.stddev[1])
        draw_bi_gaussian(true_marg, xlim, ylim, colors="black", levels=levels)
        draw_bi_gaussian(var_marg, xlim, ylim, colors=[color], levels=levels)
        plt.xlim(xlim)
        plt.ylim(ylim)

    # Label only the outer edge of the grid.
    if i == n_params - 1:
        plt.xlabel(j)
    if j == 0:
        plt.ylabel(i)

fig.legend(
    legend_lines,
    ["True posterior", "Variational posterior"],
    loc="lower center",
    ncol=2,
    labelspacing=0,
    frameon=False,
)
fig.tight_layout()
fig.subplots_adjust(bottom=0.09)  # leave room for the legend below the grid
sns.despine()

fig.savefig("../thesis/Figures/vi-simulated.pdf")


In [None]:
# Per-sample predictions X @ w for every posterior parameter sample w.
# Pool the samples of every sampler run: `sample_data` left over from the
# plotting loop only holds the last run that loop visited.
all_samples = pd.concat(
    load_samples(r) for r in runs if "sampler" in r.cfg["inference"]
)
# (S, P, 1) parameter samples. The dtype is matched to X so the batched matmul
# does not fail on a float32/float64 mismatch.
param_samples = torch.tensor(all_samples.to_numpy(), dtype=X.dtype).view(
    -1, X.shape[-1], 1
)
predidictions = pd.DataFrame(
    # X (N, P) broadcasts against (S, P, 1) -> (S, N, 1). Drop only the trailing
    # axis so a single sample (S == 1) or data point (N == 1) keeps its dimension.
    (X @ param_samples).squeeze(-1).numpy()
).set_index(all_samples.index)


In [None]:

def ci_low(x):
    """Lower end of the central 90% interval of ``x`` (its 5% quantile)."""
    return x.quantile(q=0.05)


def ci_high(x):
    """Upper end of the central 90% interval of ``x`` (its 95% quantile)."""
    return x.quantile(q=0.95)


def median(x):
    """Median of ``x`` (its 50% quantile)."""
    return x.quantile(q=0.5)

# Summary statistics of the predictions X @ w per data point, for each
# (sampler, batch_size) combination.
# Pool the samples of every sampler run: the groupby below needs all of them,
# whereas `sample_data` left over from the plotting loop holds only the last run.
all_samples = pd.concat(
    load_samples(r) for r in runs if "sampler" in r.cfg["inference"]
)
# (S, P, 1) parameter samples. The dtype is matched to X so the batched matmul
# does not fail on a float32/float64 mismatch.
param_samples = torch.tensor(all_samples.to_numpy(), dtype=X.dtype).view(
    -1, X.shape[-1], 1
)
predidictions = (
    # (S, N, 1) -> (S, N). Squeeze only the trailing axis so S == 1 or N == 1
    # keeps its dimension.
    pd.DataFrame((X @ param_samples).squeeze(-1).numpy())
    .set_index(all_samples.index)
    .groupby(level=["sampler", "batch_size"])
    .agg(["mean", "std", ci_low, "median", ci_high])
)


In [None]:
# Display the exact posterior object used as the reference throughout.
posterior

In [None]:
# Display the per-(sampler, batch_size) prediction summary: mean, std and the
# 5% / 50% / 95% quantiles for each data point.
predidictions

In [None]:
# NOTE(review): unfinished, commented-out draft of a PairGrid comparison of
# variational fits across batch sizes. It would not run as-is:
# - draw_uni_variational never draws the variational densities or returns anything.
# - It groups by level "batch_sizes", while the level it reads above is "batch_size".
# - `variational_models` is not defined anywhere in this notebook.
# The VI comparison actually saved is the pair-grid figure vi-simulated.pdf.
# Consider deleting this cell.
# from torch.distributions import Normal


# def draw_uni_variational(x, **kwargs):

#     i = x.name
#     marg = get_marginal(posterior, i)
#     xlims = marg.mean + torch.tensor([-5, 5]) * marg.stddev
#     xx = torch.linspace(*xlims)
#     yy = marg.log_prob(xx).exp()

#     batch_sizes = (
#         x.index.get_level_values(level="batch_size")
#         .unique()
#         .sort_values(ascending=False)
#     )
#     def get_densities(data):
#         print(data)

#     x.groupby(level="batch_sizes")

#     sns.lineplot(x=xx, y=yy, color="black")

#     mu = torch.tensor(x["mu"].values).view(2, 1)
#     sigma = torch.tensor(x["sigma"].values).view(2, 1)
#     var_dist = Normal(mu, sigma)

#     pd.DataFrame().assign(
#         batch_size=x.get,
#     )

#     d = x.copy()


# def draw_bi_variational(x, **kwargs):

#     i = x.name
#     marg = get_marginal(posterior, i)
#     xlims = marg.mean + torch.tensor([-5, 5]) * marg.stddev
#     xx = torch.linspace(*xlims)
#     yy = marg.log_prob(xx).exp()
#     sns.lineplot(x=xx, y=yy)


# pg: sns.PairGrid = (
#     variational_models.rename_axis("var_param", axis=1)
#     .unstack("batch_size")
#     .transpose()
#     .pipe((sns.PairGrid, "data"), vars=[0, 1, 2, 3])
# )
# pg.map_diag(draw_uni_variational, size="batch_size")


In [None]:

# Instantiate each run's model from its config.
# `run` at this point is only the leftover loop variable of the sampling-plot
# loop, so iterate the notebook-level `runs` list. Use `.cfg`, as every other
# cell does. pandas/seaborn are already imported in the first cell.
# NOTE(review): assumes every run config has a "model" entry — confirm.
models = [instantiate(r.cfg["model"]) for r in runs]
    


# NOTE(review): commented-out draft of a posterior-predictive quantile plot built
# from samples of the exact posterior. Before reviving it:
# - The `ç` line below is a garbled/lost step.
# - `y`, `coeffs` and `sgd_inf` are not defined in this notebook; `Y` and
#   `dataset.coeffs` are.
# - `torch.linspace` needs an explicit `steps` in current PyTorch when `xx` is rebuilt.
# n_samples = 10_000
# param_samples = posterior.sample((n_samples,))
# XX = torch.stack([torch.ones_like(xx), xx, xx**2, xx**3]).T
# predictions = XX @ param_samples.unsqueeze(-1)
# (
#     pd.DataFrame(predictions.squeeze().numpy(), columns=XX[:, 1].numpy())
#     .melt(
#         var_name="x",
#         value_name="y",
#     )
#     .groupby("x")
# NOTE(review): garbled line below; presumably a quantile aggregation such as
# `.quantile([...])`, given the later `var_name="quantile"` — confirm.
#     ç
#     .unstack()
#     .droplevel(0, axis="columns")
#     .reset_index()
#     .melt(
#         id_vars="x",
#         var_name="quantile",
#     )
#     .pipe((sns.relplot, "data"), x="x", y="value", hue="quantile", kind="line")
# )
# plt.scatter(X[:,1], y)
# plt.plot(xx, eval_poly(xx, coeffs))
# plt.plot(xx, eval_poly(xx, sgd_inf.model.linear.weight.detach()))