In [None]:
import matplotlib.pyplot as plt
import torch

# (legend label, model identifier used in the result filenames) per approach,
# in plotting order.
approaches = list({
    "hands on": "VexprHandsOnLossModel",
    "one size fits all": "VexprFullyJointLossModel",
}.items())
# Values of n_cv to load and plot: 10, 20, ..., 120.
n_cvs = range(10, 120 + 1, 10)
# x-axis values for the aggregate plots (same object as n_cvs).
x = n_cvs

def get_result(approach, n_cv, results_dir="results/cv"):
    """Load one saved cross-validation result, mapping all tensors to the CPU.

    Args:
        approach: model identifier embedded in the filename,
            e.g. ``"VexprHandsOnLossModel"``.
        n_cv: the n_cv value embedded in the filename.
        results_dir: directory containing the ``cv-{approach}-{n_cv}.pt`` files.

    Returns:
        The object stored in the file; this notebook indexes it with
        ``"posterior"`` and ``"observed_Y"``.
    """
    filename = f"{results_dir}/cv-{approach}-{n_cv}.pt"
    cpu = torch.device("cpu")
    # The files store full Python objects (e.g. the posterior), not just
    # tensors, so torch >= 2.6's default ``weights_only=True`` refuses to load
    # them. Full unpickling can execute arbitrary code: only point this at
    # trusted, locally produced result files.
    try:
        return torch.load(filename, map_location=cpu, weights_only=False)
    except TypeError as exc:
        # torch < 1.13 has no ``weights_only`` argument (it always unpickles fully).
        if "weights_only" not in str(exc):
            raise
        return torch.load(filename, map_location=cpu)

# Aggregate comparison: one line per approach in each of three panels, with
# n_cv on the x-axis.
fig, ax = plt.subplots(1, 3, figsize=(10, 4))

for label, approach in approaches:
    norms = []
    mean_p_y = []
    mean_log_p_y = []
    for n_cv in n_cvs:
        result = get_result(approach, n_cv)
        posterior = result["posterior"]
        observed_Y = result["observed_Y"]

        # Pure evaluation: no autograd graph is needed for these metrics.
        with torch.no_grad():
            log_density = posterior.mvn.log_prob(observed_Y.squeeze(-1))
            density = log_density.exp()
            # Flatten both operands so a (n,) vs (n, 1) shape mismatch cannot
            # silently broadcast into an (n, n) residual matrix; ``reshape``
            # also works on non-contiguous tensors where ``view`` raises.
            residuals = posterior.mean.reshape(-1) - observed_Y.reshape(-1)

        # ||residuals|| / n, i.e. the RMS residual divided by sqrt(n).
        norms.append(residuals.norm().item() / residuals.numel())
        mean_p_y.append(density.mean().item())
        mean_log_p_y.append(log_density.mean().item())

    ax[0].plot(n_cvs, mean_p_y, label=label)
    ax[1].plot(n_cvs, mean_log_p_y, label=label)
    ax[2].plot(n_cvs, norms, label=label)

# Label every other n_cv value (20, 40, ..., 120 for the default grid).
xticks = list(n_cvs)[1::2]
for panel in ax:
    panel.set_xlabel("# of experiments")
    panel.set_xticks(xticks)

ax[0].set_title("mean p(held out result)")
ax[1].set_title("mean log p(held out result)")
ax[2].set_title("standard error of MLE prediction")
ax[2].set_ylabel("Δ loss")
# Explicit form of the previous ``plt.legend()``, whose current axes was ax[2].
ax[2].legend()
fig.tight_layout()
# fig.savefig("aggregate.svg")
plt.show()

In [None]:
import numpy as np


# Raw predictions per n_cv: every approach's posterior mean ±1 standard
# deviation for each held-out experiment, next to the observed value.
for n_cv in n_cvs:
    fig, raw_ax = plt.subplots(figsize=(20, 3))
    raw_ax.set_title(f"n_cv = {n_cv}")

    for i, (label, approach) in enumerate(approaches):
        result = get_result(approach, n_cv)
        posterior = result["posterior"]
        observed_Y = result["observed_Y"]
        positions = np.arange(observed_Y.numel())
        # Centre the approaches' error bars around each experiment's position
        # so they don't overlap (+0.1 / -0.1 for two approaches).
        offset = ((len(approaches) - 1) / 2 - i) * 0.2

        with torch.no_grad():
            raw_ax.errorbar(positions + offset,
                            posterior.mean.reshape(-1).numpy(),
                            yerr=posterior.variance.sqrt().reshape(-1).numpy(),
                            fmt=".",
                            label=label)

    # Observed values are taken from the last loaded result; this assumes every
    # approach was evaluated on the same held-out data — TODO confirm.
    raw_ax.plot(positions, observed_Y.reshape(-1).numpy(), "o",
                color="black", label="actual")

    raw_ax.set_xticks([])
    raw_ax.set_xlabel("experiment #")
    raw_ax.set_yticks([])
    raw_ax.set_ylabel("log loss")
    raw_ax.legend(loc="upper left")
    fig.tight_layout()
    # fig.savefig(f"raw-data-{n_cv}.svg")
    plt.show()
    # Release the figure so looping over many n_cv values doesn't accumulate them.
    plt.close(fig)