In [None]:
import os
from dataclasses import asdict

import sbi
import torch
from sbi.inference import SNLE, SNPE
from sbi.inference.posteriors import (
    DirectPosteriorParameters,
    MCMCPosteriorParameters,
    VIPosteriorParameters,
)
from utils import (
    benchmark_sample_from_inference,
    eval_samples,
    query,
    save_results,
    train_inference,
)

import matplotlib.pyplot as plt
import seaborn as sns

# Seed torch's global RNG so repeated runs of the notebook start from the same state.
torch.manual_seed(0)

# Apply the matplotlib style file located two directories above this notebook.
plt.style.use("../../.matplotlibrc")

In [None]:
# Print the installed sbi version so the benchmark results can be tied to a library release.
print(sbi.__version__)

## NPE

With NPE, all three sampling methods (`direct`, `mcmc`, and `vi`) can be evaluated. The cell below benchmarks `direct` sampling only.

In [None]:
# Benchmark setup: SNPE evaluated with direct posterior sampling at default parameters.
method = SNPE
dimensions = [2, 5, 10, 20]
posterior_parameters = DirectPosteriorParameters()

for d in dimensions:
    # The observation is the d-dimensional zero vector.
    x_o = torch.zeros(d)

    # `train_inference` returns the inference object and a callable that is
    # evaluated at `x_o` below to obtain the reference for `eval_samples`.
    inf, true_posterior = train_inference(method, d, num_simulations=5000)

    # Direct sampling, repeated once per seed.
    samples, times = benchmark_sample_from_inference(
        inf,
        1000,
        x_o,
        sample_with="direct",
        posterior_parameters=posterior_parameters,
        seeds=[1, 2, 3, 4, 5],
    )

    # Compare the drawn samples against the reference and store everything.
    # NOTE(review): `c2sts` presumably holds C2ST scores -- confirm in utils.eval_samples.
    c2sts = eval_samples(samples, true_posterior(x_o))
    save_results(
        "SNPE",
        d,
        "direct",
        asdict(posterior_parameters),
        times,
        c2sts,
    )

## SNLE with MCMC and VI

In [None]:
# Benchmark setup: SNLE evaluated with both VI and MCMC posterior sampling.
method = SNLE
dimensions = [2, 5, 10, 20]
mcmc_parameters = MCMCPosteriorParameters(
    num_chains=100,
    thin=1,
    method="slice_np_vectorized",
)
vi_parameters = VIPosteriorParameters()

# Seeds shared by both sampling methods; a fresh list is built for every call.
benchmark_seeds = (1, 2, 3, 4, 5)

for d in dimensions:
    # The observation is the d-dimensional zero vector.
    x_o = torch.zeros(d)

    # One trained inference object per dimension is reused for VI and MCMC.
    inf, true_posterior = train_inference(method, d, num_simulations=5000)

    # --- VI sampling ---
    samples_vi, times_vi = benchmark_sample_from_inference(
        inf,
        1000,
        x_o,
        sample_with="vi",
        posterior_parameters=vi_parameters,
        seeds=list(benchmark_seeds),
    )
    c2sts_vi = eval_samples(samples_vi, true_posterior(x_o))
    save_results(
        "SNLE",
        d,
        "vi",
        asdict(vi_parameters),
        times_vi,
        c2sts_vi,
    )

    # --- MCMC sampling ---
    samples_mcmc, times_mcmc = benchmark_sample_from_inference(
        inf,
        1000,
        x_o,
        sample_with="mcmc",
        posterior_parameters=mcmc_parameters,
        seeds=list(benchmark_seeds),
    )
    c2sts_mcmc = eval_samples(samples_mcmc, true_posterior(x_o))
    save_results("SNLE", d, "mcmc", asdict(mcmc_parameters), times_mcmc, c2sts_mcmc)

In [None]:
# Hex colours, one per posterior sampling method.
direct_color, MCMC_color, VI_color = "#6495ED", "#FFA07A", "#7b2e8b"

# Lookup from sampling-method label to colour; passed as `palette` to seaborn
# when plots use the sampling method as hue.
color_palette = dict(direct=direct_color, mcmc=MCMC_color, vi=VI_color)

In [None]:
# Benchmark sampling times: load the stored results through `utils.query`.
# The plotting cells below read the `dimension`, `times` and `sampling_method`
# columns of `df`.
# NOTE(review): presumably this returns everything written by `save_results` -- confirm in utils.

df = query()

In [None]:
# Point plot of sampling time against parameter dimension, one series per
# sampling method, on a linear y-axis.
fig, ax = plt.subplots(figsize=(1.8, 1.0))

# Distance (in points) by which the left/bottom spines are moved away from the data.
outward = 10

sns.pointplot(
    data=df,
    x="dimension",
    y="times",
    hue="sampling_method",
    log_scale=False,
    legend=True,
    markersize=3,
    palette=color_palette,
    clip_on=False,
    ax=ax,
)
ax.set_ylabel("time (s)", labelpad=-5)
ax.set_xlabel(r"dim ($\theta$)")

# Fixed y-range with ticks only at the two ends.
ax.set_ylim(0, 240)
ax.set_yticks([0, 240])

# Categorical x positions are 0..3 for the four benchmarked dimensions.
ax.set_xlim(0, 3)

for line in ["left", "bottom"]:
    ax.spines[line].set_position(("outward", outward))

# Place the legend above the axes so it does not cover the small plot area.
ax.legend(
    title="sampling method",
    loc="upper center",
    bbox_to_anchor=(0.5, 1.5),
    ncol=3,
    handlelength=1,
)

# `exist_ok=True` replaces the separate existence check and is safe on re-runs.
os.makedirs("./fig", exist_ok=True)
fig.savefig("./fig/default_sampling_times.svg", bbox_inches="tight")
fig.savefig("./fig/default_sampling_times.png", bbox_inches="tight")

In [None]:
# Box plot of sampling time per parameter dimension and sampling method,
# drawn with seaborn's `log_scale=True`.
fig, ax = plt.subplots(figsize=(3, 2))
sns.boxplot(
    data=df,
    x="dimension",
    y="times",
    hue="sampling_method",
    log_scale=True,
    legend=True,
    linewidth=0.5,
    # Same method -> colour mapping as the point plot, so both figures agree.
    palette=color_palette,
    ax=ax,
)
ax.set_ylabel("time (s)")
ax.set_xlabel(r"dim ($\theta$)")
ax.legend(
    title="sampling method",
    loc="upper center",
    bbox_to_anchor=(0.5, 1.3),
    ncol=3,
)

# Make sure the output directory exists so this cell does not depend on an
# earlier plotting cell having created it.
os.makedirs("./fig", exist_ok=True)
fig.savefig("./fig/default_sampling_times_log_scale.svg", bbox_inches="tight")
fig.savefig("./fig/default_sampling_times_log_scale.png", bbox_inches="tight")