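# Analysis

Accuracy (CRPS) versus inference time for each inference configuration, and MCMC convergence diagnostics (`n_eff`, `r_hat`) versus inference time, for the `short_uni_synth` and `long_uni_synth` experiments.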

In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt
from runner import short_uni_synth, long_uni_synth

# Render matplotlib figures inline in the notebook, using SVG output.
%matplotlib inline
%config InlineBackend.figure_formats = ['svg']

In [None]:
results = list(short_uni_synth.results)
print(len(results))
# Inspect the result schema of the first MCMC run. Fail loudly if there is
# none, instead of silently inspecting whichever result a loop ended on.
r = next((result for result in results if result["args"].infer == "mcmc"), None)
if r is None:
    raise ValueError("short_uni_synth.results contains no MCMC run to inspect")
print(r.keys())
print(r["times"].keys())
print(r["evaluate"].keys())
print(r["evaluate"]["R0"].keys())
print(r["infer"].keys())
print(r["infer"]["R0"].keys())

In [None]:
def plot_accuracy(variable, metric, experiment):
    """Scatter-plot an accuracy metric against inference time, one series per config.

    Results are grouped by ``(args.infer, args.num_bins, args.svi_steps)`` and
    each group is drawn as its own series with a distinct marker, on a log-scaled
    time axis.

    Args:
        variable: key into each result's ``"evaluate"`` dict, e.g. ``"R0"``.
        metric: key into ``result["evaluate"][variable]``, e.g. ``"crps"``;
            its upper-cased name is used as the y-axis label.
        experiment: object exposing ``results`` (iterable of result dicts with
            ``"args"``, ``"times"`` and ``"evaluate"`` entries) and ``__name__``
            (used in the plot title).

    Raises:
        AssertionError: if there are more configurations than available markers.
    """
    view = defaultdict(list)
    for result in experiment.results:
        args = result['args']
        view[args.infer, args.num_bins, args.svi_steps].append(result)
    markers = ["o", "d", "s", "<", "v", "^", ">"]
    assert len(view) <= len(markers), (
        f"{len(view)} configurations but only {len(markers)} markers")

    plt.figure(figsize=(6, 5)).patch.set_color("white")
    for (key, value), marker in zip(sorted(view.items()), markers):
        algo, num_bins, svi_steps = key
        if algo == "svi":
            label = f"SVI steps={svi_steps}"
        elif algo == "mcmc" and num_bins == 1:
            label = "MCMC relaxed"
        elif algo == "mcmc":
            label = f"MCMC num_bins={num_bins}"
        else:
            # Without this fallback an unrecognized algorithm would silently
            # reuse the previous series' label (or raise NameError if first).
            label = str(algo)
        X = [v["times"]["infer"] for v in value]
        Y = [v["evaluate"][variable][metric] for v in value]
        plt.scatter(X, Y, marker=marker, label=label, alpha=0.8)
    plt.ylim(0, None)
    plt.xscale("log")
    plt.xlabel("inference time (sec)")
    plt.ylabel(metric.upper())
    plt.title(f"{variable} accuracy ({experiment.__name__})")
    plt.legend(loc="best", prop={'size': 8})
    plt.tight_layout()

In [None]:
# CRPS accuracy of every inferred quantity for the short synthetic run.
for name in ("R0", "rho", "obs", "I"):
    plot_accuracy(name, "crps", short_uni_synth)

In [None]:
# CRPS accuracy of every inferred quantity for the long synthetic run.
for name in ("R0", "rho", "obs", "I"):
    plot_accuracy(name, "crps", long_uni_synth)

In [None]:
def plot_convergence(variable, experiment, metrics=("n_eff", "r_hat")):
    """Plot MCMC convergence diagnostics against inference time.

    Draws one vertically stacked panel per metric (shared x axis, log-log
    scales) and one marker per ``args.num_bins`` value. Non-MCMC results are
    skipped.

    Args:
        variable: key into each result's ``"infer"`` dict, e.g. ``"R0"``.
        experiment: object exposing ``results`` (iterable of result dicts with
            ``"args"``, ``"times"`` and ``"infer"`` entries) and ``__name__``
            (used in the title).
        metrics: keys into ``result["infer"][variable]``, one panel each. The
            ``"r_hat"`` panel's y axis is anchored at 1.

    Raises:
        AssertionError: if there are more ``num_bins`` groups than markers.
    """
    view = defaultdict(list)
    # Read the results of the experiment being plotted, not a notebook global.
    for result in experiment.results:
        args = result['args']
        if args.infer == "mcmc":
            view[args.num_bins].append(result)
    markers = ["o", "d", "s"]
    assert len(view) <= len(markers), (
        f"{len(view)} num_bins groups but only {len(markers)} markers")

    # squeeze=False always yields a 2-D grid, so the indexing below also
    # works when only a single metric is requested.
    fig, axes = plt.subplots(len(metrics), 1, figsize=(6, 5), sharex=True, squeeze=False)
    axes = axes[:, 0]
    fig.patch.set_color("white")
    for (num_bins, value), marker in zip(sorted(view.items()), markers):
        label = "MCMC relaxed" if num_bins == 1 else f"MCMC num_bins={num_bins}"
        X = [v["times"]["infer"] for v in value]
        for metric, ax in zip(metrics, axes):
            Y = [v["infer"][variable][metric] for v in value]
            ax.scatter(X, Y, marker=marker, label=label, alpha=0.8)
    # Configure each panel once, after all data is drawn, so set_ylim keeps
    # the autoscaled upper limit.
    for metric, ax in zip(metrics, axes):
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_ylabel(metric)
        if metric == "r_hat":
            ax.set_ylim(1, None)
    axes[0].set_title(f"{variable} convergence ({experiment.__name__})")
    axes[-1].legend(loc="best", prop={'size': 8})
    axes[-1].set_xlabel("inference time (sec)")
    plt.subplots_adjust(hspace=0)

In [None]:
# MCMC convergence diagnostics for selected latent sites, short synthetic run.
for name in ("R0", "rho", "auxiliary_haar_split_0"):
    plot_convergence(name, short_uni_synth)

In [None]:
# MCMC convergence diagnostics for selected latent sites, long synthetic run.
for name in ("R0", "rho", "auxiliary_haar_split_0"):
    plot_convergence(name, long_uni_synth)