# Automatic Calibration Diagnosis: Interpreting Probability Integral Transform (PIT) Histograms

In [None]:
import functools
import math
import random

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import ticker
from matplotlib import gridspec
import pandas as pd
from scipy import stats
from sklearn.metrics.pairwise import euclidean_distances
import torch
import wandb

from calibration import data
from calibration import dist
from calibration import method
from calibration import pit
from calibration import plot

In [None]:
# Golden ratio and default figure width; both are reused for figure sizes
# throughout the notebook.
PHI = (1 + math.sqrt(5)) / 2
WIDTH = 5.5
# Global matplotlib defaults for every figure drawn below.
matplotlib.rcParams.update({
    "font.family": "Times New Roman",
    "font.size": 8,
    "axes.titlesize": 10,
    "figure.dpi": 300,
    "figure.figsize": (WIDTH, (PHI - 1) * WIDTH),
})

In [None]:
def param2pdf(weight, mean, variance):
    """Return a one-argument function evaluating the mixture density.

    The returned callable forwards its argument ``x`` together with the
    captured ``weight``, ``mean`` and ``variance`` to
    ``dist.pdf_gaussian_mixture``.
    """
    def pdf(x):
        return dist.pdf_gaussian_mixture(x, weight, mean, variance)

    return pdf

def y2pdf(w, s, v1, v2):
    """Build a two-component mixture density from a four-value annotation.

    The components get weights ``(w, 1 - w)``, means ``(-s / 2, s / 2)`` and
    variances ``(v1, v2)``; the result is the callable from ``param2pdf``.
    """
    return param2pdf(
        weight=torch.tensor([w, 1 - w]),
        mean=torch.tensor([-s / 2, s / 2]),
        variance=torch.tensor([v1, v2]),
    )

## PIT histogram

In [None]:
# Each case is (mean, variance, label) of the Gaussian that generates the
# observations; the predictive distribution is always `dist.pdf_gaussian`
# called without extra arguments.
bias = (torch.tensor(1.0), torch.tensor(1.0), "biased")
under = (torch.tensor(0.0), torch.tensor(2.0), "under-dispersed")
over = (torch.tensor(0.0), torch.tensor(0.5), "over-dispersed")

# One row per case: densities in the left column, PIT histogram in the right.
_, axes = plt.subplots(nrows=3, ncols=2,
                       sharex="col", constrained_layout=True)
for ax, (mean, variance, label) in zip(axes, (bias, under, over)):
    ax[0].set_title(f"{label} predictive distribution")
    # Sample observations from the observation-generating Gaussian.
    y = mean + torch.sqrt(variance) * torch.randn(pit.SAMPLES)
    handle_pred = plot.density(ax[0], dist.pdf_gaussian, color="C2")
    dist_obs = functools.partial(dist.pdf_gaussian, mean=mean, variance=variance)
    handle_obs = plot.density(ax[0], dist_obs, color="C3", linestyle="--")
    handle_pit = plot.pit_hist(ax[1], pit.pit_hist(pit.pit_gaussian(y)))
    for panel in ax:
        panel.set_ylabel("density")

# Only the bottom row carries x-axis labels.
axes[2, 0].set_xlabel("y", style="italic")
axes[2, 1].set_xlabel("PIT")

# The legend reuses the handles left over from the last loop iteration.
legend_entries = [
    (handle_pit, "PIT histogram"),
    (handle_pred, "predictive distribution"),
    (handle_obs, "observation-generating\ndistribution"),
]
axes[0, 1].legend([handle for handle, _ in legend_entries],
                  [text for _, text in legend_entries])
plt.savefig("figures/types.pdf")

## Automatically interpreting PIT histograms

In [None]:
# Concept figure: a 3x2 grid with one panel per stage of the method.
fig, axes = plt.subplots(nrows=3, ncols=2,
                         constrained_layout=True,
                         figsize=(WIDTH, (2 / 3) * WIDTH))
(ax_obs, ax_pre), (ax_nll, ax_pit), (ax_pro, ax_int) = axes

# Keyword sets shared by several of the calls below.
obs_style = dict(color="C3", linestyle="--")
rug_style = dict(color="C4", marker="|", label="sample")
text_style = dict(xycoords="data",
                  xytext=(0.5, 0.5), textcoords="data",
                  ha="center", va="center", fontsize=10)
italic, plain = dict(style="italic"), dict()

# Observation-generating Gaussian and a sample of 100 observations from it.
ax_obs.set_title("step 1 and 2")
mean = torch.tensor(1.0)
variance = torch.tensor(1.0)
pdf_obs = functools.partial(dist.pdf_gaussian, mean=mean, variance=variance)
plot.density(ax_obs, pdf_obs,
             label="observation-\ngenerating\ndistribution", **obs_style)
y = mean + torch.sqrt(variance) * torch.randn(100)
ax_obs.scatter(y, torch.zeros_like(y), **rug_style)

# Predictive CDF with the same sample marked along the x-axis.
ax_pre.set_title("step 3")
cdf_pre = dist.cdf_gaussian
plot.cumulative_density(ax_pre, cdf_pre,
                        color="C2",
                        label="predictive\ndistribution")
ax_pre.scatter(y, torch.zeros_like(y), **rug_style)

# PIT histogram of the sample.
ax_pit.set_title("step 4 and 5")
plot.pit_hist(ax_pit, pit.pit_hist(pit.pit_gaussian(y)),
              label="PIT histogram")

# Text-only panel: one arrow with its head at the text, one with its head
# at the left edge of the panel.
ax_int.set_axis_off()
for tip, arrowstyle in (((0.5, 1), "<-"), ((0, 0.5), "->")):
    ax_int.annotate("interpreter",
                    xy=tip,
                    arrowprops=dict(arrowstyle=arrowstyle),
                    **text_style)

# The same density function is drawn again as the interpreter's output.
plot.density(ax_pro, pdf_obs,
             label="predicted\nobservation-\ngenerating\ndistribution",
             **obs_style)

# Text-only panel with arrows towards the panels above and below it.
ax_nll.set_axis_off()
for tip in ((0.5, 1), (0.5, 0)):
    ax_nll.annotate("mean negative log-likelihood",
                    xy=tip,
                    arrowprops=dict(arrowstyle="->"),
                    **text_style)

# Legends and axis labels of the four data panels.
for panel, xlabel, xlabel_style, ylabel in (
        (ax_obs, "y", italic, "density"),
        (ax_pre, "y", italic, "cumulative density"),
        (ax_pit, "PIT", plain, "density"),
        (ax_pro, "y", italic, "density")):
    panel.legend(loc="upper left")
    panel.set_xlabel(xlabel, **xlabel_style)
    panel.set_ylabel(ylabel)

plt.savefig("figures/concept.pdf")

## Interpreter

In [None]:
# Restore the trained interpreter from its checkpoint.
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source.
# map_location="cpu" lets a checkpoint saved on a GPU be opened on a CPU-only
# machine; load_state_dict copies the values into the freshly constructed
# parameters either way.
checkpoint = torch.load("models/eternal-smoke-51.pt", map_location="cpu")
hyperparams = checkpoint["hyperparams"]
# The interpreter's input size is stored under "bins".
interpreter = method.MDN(inputs=hyperparams["bins"],
                         neurons=hyperparams["neurons"],
                         components=hyperparams["components"])
interpreter.load_state_dict(checkpoint["model_state_dict"])
interpreter

## Synthetic data set

In [None]:
# Size argument passed to the synthetic test set.
TESTS = 1000
# Seed torch's RNG before building the data set
# (presumably PITDataset samples from it; TODO confirm).
torch.manual_seed(78)
testset = pit.PITDataset(TESTS)

In [None]:
# Reference set built with default arguments; the nearest-neighbour look-up
# further down searches it. The cell displays its length.
refset = pit.PITReference()
len(refset)

In [None]:
# Interpreter predictions for the test inputs `testset.X`; later cells index
# the result as (weight, mean, variance).
pred_mdn = method.predict(interpreter, testset.X)

In [None]:
# Display the three test indices that the comparison plots below draw when
# they reseed with the same value.
random.seed(54)
random.sample(range(len(testset)), k=3)

In [None]:
# For every test input find the index of the closest reference input.
distances = euclidean_distances(testset.X, refset.X)
js = distances.argmin(axis=1)
# Annotation of each test example's nearest reference entry.
neighbours = refset.annotation[js]

# Compare three randomly chosen test examples (fixed seed): the annotated
# density, the interpreter's prediction and the nearest reference entry.
random.seed(54)
for i in random.sample(range(len(testset)), k=3):
    y = testset.annotation[i]
    weight, mean, variance = pred_mdn[0][i], pred_mdn[1][i], pred_mdn[2][i]
    neighbour = neighbours[i]
    curves = (
        (y2pdf(*y), "data generating"),
        (param2pdf(weight, mean, variance), "MDN"),
        (y2pdf(*neighbour), "nearest neighbor"),
    )
    _, ax = plt.subplots()
    for curve_pdf, curve_label in curves:
        plot.density(ax, curve_pdf, label=curve_label)
    ax.legend()
    plt.show()
    plt.close()

In [None]:
# Mean negative log-likelihood of `testset.y` under the interpreter's
# predicted mixtures; the cell displays it as a Python float.
nll_mdn = dist.nll_gaussian_mixture(testset.y, *pred_mdn).mean()
nll_mdn.item()

In [None]:
# Nearest-neighbour baseline: for each `steps` value build a reference set,
# match every test input to its closest reference input and score the
# matched annotation by mean negative log-likelihood.
# NOTE: this loop rebinds `refset`, `distances` and `js` from earlier cells.
steps = range(5, 18)
nll_neighbour = []
for s in steps:
    refset = pit.PITReference(steps=s)
    distances = euclidean_distances(testset.X, refset.X)
    js = distances.argmin(axis=1)
    neighbour = refset.annotation[js]
    # Columns 0 and 1 expand to mixture weights and means the same way
    # y2pdf does; the remaining columns are used as the variances.
    first_weight, separation = neighbour[:, 0], neighbour[:, 1]
    weight = torch.stack((first_weight, 1 - first_weight), dim=1)
    mean = torch.stack((-separation / 2, separation / 2), dim=1)
    variance = neighbour[:, 2:]
    nll = dist.nll_gaussian_mixture(testset.y, weight, mean, variance)
    nll_neighbour += [nll.mean().item()]
    print(f"{s:2d} {nll_neighbour[-1]:f} {s ** 4}")

In [None]:
# Nearest-neighbour NLL against `steps ** 4`, with the interpreter's NLL as a
# horizontal reference line. The figure size is derived from the shared WIDTH
# constant (instead of a repeated literal) so it follows the other figures.
_, ax = plt.subplots(figsize=(WIDTH, (2 / 3) * (PHI - 1) * WIDTH),
                     constrained_layout=True)
# Same quantity the loop above prints next to each NLL (presumably the number
# of entries in the reference set; TODO confirm).
size = [s ** 4 for s in steps]
ax.scatter(size, nll_neighbour, marker="+", label="nearest neighbour algorithm")
# The line is drawn at the NLL rounded to three decimals, the same precision
# as the y-axis tick labels.
ax.axhline(round(nll_mdn.item(), 3), ls="--", label="our interpreter")
ax.set_xlabel("training set size of nearest neighbour algorithm")
ax.set_ylabel("negative log-likelihood")
ax.legend()
ax.yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:.3f}"))
plt.savefig("figures/neighbour.pdf")

In [None]:
# Tabulate the nearest-neighbour results, one row per `steps` value.
columns = {"steps": steps, "size": size, "nll": nll_neighbour}
df = pd.DataFrame(columns).set_index("steps")
df

## Uniform PIT histogram

In [None]:
# Feed a flat PIT histogram (every one of the `pit.BINS` bins equal to 1) to
# the interpreter and compare its prediction with `dist.pdf_gaussian`.
pit_hist_uniform = torch.ones(pit.BINS)
pred_uniform = method.predict(interpreter, pit_hist_uniform)
_, ax = plt.subplots()
for density_fn, legend_label in (
        (dist.pdf_gaussian, "data-generating"),
        (param2pdf(*pred_uniform), "predicted data-generating")):
    plot.density(ax, density_fn, label=legend_label)
ax.legend()

## UCI ML repository data sets

In [None]:
def load_model(modelfile, Model, keys):
    """Rebuild a model from a checkpoint file.

    Args:
        modelfile: path (or file object) accepted by ``torch.load``. The
            checkpoint must hold a ``"hyperparams"`` mapping and a
            ``"model_state_dict"``.
        Model: callable that constructs the model; it is called positionally
            with the hyperparameters ``"inputs"``, ``"neurons"`` and then
            each name in ``keys``.
        keys: sequence of extra hyperparameter names (tuple or list).

    Returns:
        The constructed model with the stored state loaded into it.

    Raises:
        KeyError: if the checkpoint lacks one of the required entries.
    """
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    # map_location="cpu" allows checkpoints saved on a GPU to be opened on a
    # CPU-only machine. load_state_dict copies the values into the newly
    # constructed parameters, so the model's device is unaffected.
    checkpoint = torch.load(modelfile, map_location="cpu")
    hyperparams = checkpoint["hyperparams"]
    names = ("inputs", "neurons") + tuple(keys)
    model = Model(*(hyperparams[name] for name in names))
    model.load_state_dict(checkpoint["model_state_dict"])
    return model

def load_mdn(modelfile):
    """Load a ``method.MDN`` checkpoint; ``"components"`` is its extra hyperparameter."""
    extra_keys = ("components", )
    return load_model(modelfile, method.MDN, extra_keys)

def load_dn(modelfile):
    """Load a density-network checkpoint by delegating to ``load_mdn``."""
    # Same loader as the mixture density network (presumably a density
    # network is stored as an MDN with one component; TODO confirm).
    model = load_mdn(modelfile)
    return model

def load_de(modelfile):
    """Load a ``method.DE`` checkpoint; ``"members"`` is its extra hyperparameter."""
    extra_keys = ("members", )
    return load_model(modelfile, method.DE, extra_keys)

def plot_interpretation(ax, interpretation):
    """Draw the predictive density and the interpreted mixture density on ``ax``.

    ``interpretation`` is unpacked into ``param2pdf``. Returns the pair
    ``(predictive, interpreted)`` of whatever ``plot.density`` returns.
    """
    predictive = plot.density(ax, dist.pdf_gaussian, color="C2")
    interpreted_pdf = param2pdf(*interpretation)
    interpreted = plot.density(ax, interpreted_pdf, color="C3", ls="--")
    return predictive, interpreted

def pit_hist(model, dataset):
    """Return the PIT histogram of ``model``'s predictions on ``dataset``.

    The three outputs of ``method.predict`` for ``dataset.X`` are used as
    mixture parameters to compute PIT values for ``dataset.y``, which
    ``pit.pit_hist`` then turns into a histogram.
    """
    prediction = method.predict(model, dataset.X)
    first, second, third = prediction
    return pit.pit_hist(
        pit.pit_gaussian_mixture(dataset.y, first, second, third)
    )

def diagnose(pit_hist, ax_pit, ax_dist, interpreter=interpreter):
    """Interpret a PIT histogram and draw the outcome on two axes.

    The interpreter's prediction for ``pit_hist`` is used as mixture
    parameters: a sample drawn from it is passed through ``pit.pit_gaussian``
    and histogrammed, giving the "predicted" PIT histogram. ``ax_pit``
    receives the given (filled) and predicted histograms; ``ax_dist``
    receives the densities drawn by ``plot_interpretation``.

    Note that the default ``interpreter`` is bound when this function is
    defined, not when it is called.

    Returns:
        ``(hist_true, hist_pred, density_predictive, density_interpret)``
        plot handles.
    """
    predicted = method.predict(interpreter, pit_hist)
    sample = dist.sample_gaussian_mixture(*predicted)
    predicted_hist = pit.pit_hist(pit.pit_gaussian(sample))
    true_handle = plot.pit_hist(ax_pit, pit_hist, fill=True)
    predicted_handle = plot.pit_hist(ax_pit, predicted_hist)
    return (true_handle, predicted_handle,
            *plot_interpretation(ax_dist, predicted))

def visualise(pit_hist_dn, pit_hist_de, pit_hist_mdn):
    """Draw the diagnosis figure for three models on the current data set.

    Rows are the density network, deep ensemble and mixture density network.
    The first two rows are diagnosed with ``diagnose`` (PIT histograms on the
    left, densities on the right). The third row only shows the given PIT
    histogram, and its right-hand panel holds the shared legend.

    Args:
        pit_hist_dn: PIT histogram of the density network.
        pit_hist_de: PIT histogram of the deep ensemble.
        pit_hist_mdn: PIT histogram of the mixture density network.
    """
    _, axes = plt.subplots(3, 2,
                           constrained_layout=True,
                           figsize=(WIDTH, (PHI - 1) * WIDTH))
    titles = ("density network", "deep ensemble", "mixture density network")
    for row, title in enumerate(titles):
        axes[row, 0].set_title(title)

    # Handles from the first row are reused for the shared legend. The debug
    # print of the density handle that used to follow has been removed.
    _, hist_pred, density_pred, density_interpret = diagnose(
        pit_hist_dn, axes[0, 0], axes[0, 1])
    diagnose(pit_hist_de, axes[1, 0], axes[1, 1])
    hist_mdn = plot.pit_hist(axes[2, 0], pit_hist_mdn, fill=True)

    # The top row carries no tick labels on the x-axis.
    for ax in axes[0]:
        ax.set_xticklabels([])

    # The bottom-right panel carries only the legend.
    axes[2, 1].set_axis_off()
    axes[2, 1].legend([hist_mdn, hist_pred, density_pred, density_interpret],
                      ["true PIT histogram",
                       "predicted PIT histogram",
                       "predictive distribution",
                       "predicted observation-generating distribution"],
                      loc="center")

    axes[1, 0].set_xlabel("PIT")
    axes[2, 0].set_xlabel("PIT")
    axes[1, 1].set_xlabel("y", style="italic")
    for ax in axes.flat:
        ax.set_ylabel("density")

### Metrics

In [None]:
# Collect configuration and test metrics of the runs in the W&B project.
api = wandb.Api()
runs = api.runs("podondra/calibration")

keys = ["dataname", "method", "seed", "neurons"]
dicts, names = [], []
for run in runs:
    try:
        dictionary = {
            **{k: run.config[k] for k in keys},
            "nll": run.summary["test.nll"],
            "crps": run.summary["test.crps"],
        }
    except KeyError:
        # Skip runs that lack any of the config keys or test metrics.
        continue
    else:
        dicts.append(dictionary)
        names.append(run.name)

# One row per run, indexed by run name, and grouped for the aggregates below.
df = pd.DataFrame(data=dicts, index=names)
gdf = df.groupby(["dataname", "method"])
df

In [None]:
# Mean and standard error of the mean of the test NLL per (dataname, method) group.
gdf["nll"].agg(["mean", "sem"])

In [None]:
# Mean and standard error of the mean of `crps` (read from `test.crps`) per
# (dataname, method) group.
gdf["crps"].agg(["mean", "sem"])

### Year

In [None]:
# Keep only the third part returned by data.split
# (presumably the test split; TODO confirm).
_, _, yearset = data.split(*data.year(), seed=4)

In [None]:
# PIT histograms on `yearset` for the density network, deep ensemble and
# mixture density network checkpoints.
pit_hist_dn_year = pit_hist(load_dn("models/rich-dragon-8.pt"), yearset)
pit_hist_de_year = pit_hist(load_de("models/generous-valley-7.pt"), yearset)
pit_hist_mdn_year = pit_hist(load_mdn("models/chocolate-sound-9.pt"), yearset)

In [None]:
# Draw the diagnosis figure for the year data and save it as a PDF.
visualise(pit_hist_dn_year, pit_hist_de_year, pit_hist_mdn_year)
plt.savefig("figures/year.pdf")

### Protein

In [None]:
# Keep only the third part returned by data.split
# (presumably the test split; TODO confirm).
_, _, proteinset = data.split(*data.protein(), seed=4)

In [None]:
# PIT histograms on `proteinset` for the density network, deep ensemble and
# mixture density network checkpoints.
pit_hist_dn_protein = pit_hist(load_dn("models/super-durian-4.pt"), proteinset)
pit_hist_de_protein = pit_hist(load_de("models/golden-snow-6.pt"), proteinset)
pit_hist_mdn_protein = pit_hist(load_mdn("models/lucky-moon-5.pt"), proteinset)

In [None]:
# Draw the diagnosis figure for the protein data and save it as a PDF.
visualise(pit_hist_dn_protein, pit_hist_de_protein, pit_hist_mdn_protein)
plt.savefig("figures/protein.pdf")

### Power

In [None]:
# Keep only the third part returned by data.split
# (presumably the test split; TODO confirm).
_, _, powerset = data.split(*data.power(), seed=4)

In [None]:
# PIT histograms on `powerset` for the density network, deep ensemble and
# mixture density network checkpoints.
pit_hist_dn_power = pit_hist(load_dn("models/polished-star-1.pt"), powerset)
pit_hist_de_power = pit_hist(load_de("models/elated-surf-3.pt"), powerset)
pit_hist_mdn_power = pit_hist(load_mdn("models/effortless-firefly-2.pt"), powerset)

In [None]:
# Draw the diagnosis figure for the power data and save it as a PDF.
visualise(pit_hist_dn_power, pit_hist_de_power, pit_hist_mdn_power)
plt.savefig("figures/power.pdf")