# Calibration

In [None]:
from matplotlib import pyplot as plt
import pandas as pd
from sklearn.metrics.pairwise import euclidean_distances
import torch
import wandb

from calibration import data
from calibration import dist
from calibration import method
from calibration import pit
from calibration import plot

In [None]:
def param2pdf(weight, mean, variance):
    """Turn Gaussian-mixture parameters into a one-argument density function.

    The returned callable evaluates
    ``dist.pdf_gaussian_mixture(x, weight, mean, variance)`` at ``x``.
    The parameters are captured by reference, not copied.
    """
    params = (weight, mean, variance)
    return lambda x: dist.pdf_gaussian_mixture(x, *params)

def y2pdf(w, s, v1, v2):
    """Density of the two-component Gaussian mixture described by an annotation.

    Args:
        w: weight of the first component; the second component gets ``1 - w``.
        s: distance between the two component means, which sit at ``-s/2``
            and ``s/2``.
        v1, v2: variances of the first and second component.

    Returns:
        A callable ``pdf(x)`` built with ``param2pdf``.
    """
    half_gap = s / 2
    return param2pdf(
        torch.tensor([w, 1 - w]),
        torch.tensor([-half_gap, half_gap]),
        torch.tensor([v1, v2]),
    )

## Interpreter

In [None]:
checkpoint = torch.load("models/light-aardvark-162.pt")
hyperparams = checkpoint["hyperparams"]
# This checkpoint stores the network's input size under "bins" (presumably
# the number of PIT-histogram bins), not under the "inputs" key that
# load_model reads. That is why the MDN is constructed explicitly here.
architecture = {
    "inputs": hyperparams["bins"],
    "neurons": hyperparams["neurons"],
    "components": hyperparams["components"],
}
interpreter = method.MDN(**architecture)
interpreter.load_state_dict(checkpoint["model_state_dict"])
# Display the module summary.
interpreter

## Synthetic data set

In [None]:
# Number of synthetic test examples.
TESTS = 1000
# Seed torch's global RNG first so the test set below is reproducible
# (assumes PITDataset draws its samples from torch's RNG — TODO confirm).
torch.manual_seed(78)
testset = pit.PITDataset(TESTS)

In [None]:
# Reference set at its default resolution.
# It is searched for nearest neighbours below; this cell displays its size.
refset = pit.PITReference()
len(refset)

In [None]:
# Interpreter predictions for every test histogram. They are unpacked below
# as the mixture parameters (weight, mean, variance).
pred_mdn = method.predict(interpreter, testset.X)

In [None]:
# Visual sanity check on the first PLOTS test examples. Each plot overlays:
# - the true data-generating mixture;
# - the MDN interpreter's prediction;
# - the mixture annotated on the nearest reference histogram (Euclidean
#   distance in histogram space).
PLOTS = 1
# Neighbours are needed only for the plotted examples. Computing them for
# those rows alone avoids a TESTS x len(refset) distance matrix.
js = euclidean_distances(testset.X[:PLOTS], refset.X).argmin(axis=1)
# zip stops at its shortest input, i.e. after PLOTS examples.
for y, weight, mean, variance, neighbour in zip(testset.annotation,
                                               *pred_mdn,
                                               refset.annotation[js]):
    _, ax = plt.subplots()
    plot.density(ax, y2pdf(*y), label="data generating")
    plot.density(ax, param2pdf(weight, mean, variance), label="MDN")
    plot.density(ax, y2pdf(*neighbour), label="nearest neighbor")
    ax.legend()
    plt.show()
    plt.close()

In [None]:
# Mean test NLL of the targets under the interpreter's predicted mixtures.
# nll_mdn stays a tensor; .item() displays it as a Python number.
nll_mdn = dist.nll_gaussian_mixture(testset.y, *pred_mdn).mean()
nll_mdn.item()

In [None]:
# 1-nearest-neighbour baseline over increasingly fine reference grids.
# For each grid with `s` steps, match every test PIT histogram to its
# closest (Euclidean) reference histogram. Then score the matched
# annotation's mixture by the mean test NLL.
steps = range(5, 18)
size = []
nll_neighbour = []
for s in steps:
    # Cell-local name: reassigning `refset` here would silently replace the
    # default-resolution reference set used by the nearest-neighbour plot.
    reference = pit.PITReference(steps=s)
    # Take argmin immediately so the (TESTS x len(reference)) distance
    # matrix is freed each iteration instead of living on in a global.
    js = euclidean_distances(testset.X, reference.X).argmin(axis=1)
    neighbour = reference.annotation[js]
    # Annotation columns are (w, s, v1, v2), the arguments of y2pdf. They map to:
    # - mixture weights (w, 1 - w);
    # - means (-s/2, s/2);
    # - variances (v1, v2).
    weight = torch.stack([neighbour[:, 0], 1 - neighbour[:, 0]], dim=1)
    mean = torch.stack([-neighbour[:, 1] / 2, neighbour[:, 1] / 2], dim=1)
    variance = neighbour[:, 2:]
    nll = dist.nll_gaussian_mixture(testset.y, weight, mean, variance)
    nll_neighbour.append(nll.mean().item())
    # Record the actual number of reference histograms instead of assuming
    # the grid has exactly s ** 4 points.
    size.append(len(reference))
    print(f"{s:2d} {nll_neighbour[-1]:f} {size[-1]}")

_, ax = plt.subplots()
ax.scatter(size, nll_neighbour, label="$k$-NN")
# float() converts a 0-d tensor to a plain number. Matplotlib can always
# handle that, even when the tensor requires grad or lives on a GPU.
ax.axhline(float(nll_mdn), label="our MDN")
ax.set_xlabel("size of reference set")
ax.set_ylabel("NLL")
ax.legend()
nll_neighbour

In [None]:
# Tabulate the k-NN sweep: grid steps (index), reference-set size and
# mean test NLL. Save it to CSV and display it.
df = pd.DataFrame({"steps": steps, "size": size, "nll": nll_neighbour}).set_index("steps")
df.to_csv("data/neighbour.csv")
df

## Uniform PIT histogram

In [None]:
# Feed a perfectly flat PIT histogram to the interpreter (each of the
# pit.BINS bins holds 1 / pit.BINS). Overlay its predicted data-generating
# density on dist.pdf_gaussian.
# A flat PIT histogram is presumably what a calibrated predictive yields.
# That is why dist.pdf_gaussian is labelled "data-generating" here, although
# plot_interpretation labels it "predictive".
pit_hist_uniform = torch.full((pit.BINS, ), 1.0 / pit.BINS)
pred_uniform = method.predict(interpreter, pit_hist_uniform)
_, ax = plt.subplots()
plot.density(ax, dist.pdf_gaussian, label="data-generating")
plot.density(ax, param2pdf(*pred_uniform), label="predicted data-generating")
ax.legend()

In [None]:
# Show the flat histogram that was fed to the interpreter above.
_, ax = plt.subplots()
plot.pit_hist(ax, pit_hist_uniform)

## UCI ML repository data sets

TODO list:
1. Which hyperparameters should we use?
1. Which other methods should we compare against: MC dropout, concrete dropout, Gaussian processes?

In [None]:
def load_model(modelfile, Model, keys, map_location=None):
    """Rebuild a model from a checkpoint saved with ``torch.save``.

    The checkpoint must be a dict holding a ``"hyperparams"`` dict and a
    ``"model_state_dict"``; these are the only entries read.

    Args:
        modelfile: path or file-like object accepted by ``torch.load``.
        Model: model class. It is called positionally with the hyperparameters
            ``inputs``, ``neurons`` and then each name in ``keys``, in order.
        keys: iterable (tuple, list, ...) of extra hyperparameter names.
        map_location: forwarded to ``torch.load``. For example, ``"cpu"``
            loads a checkpoint saved on a GPU on a CPU-only machine. The
            default ``None`` matches ``torch.load``'s own default.

    Returns:
        A ``Model`` instance with the checkpoint's weights loaded.

    Raises:
        KeyError: if the checkpoint lacks ``"hyperparams"``,
            ``"model_state_dict"`` or one of the required hyperparameters.
    """
    checkpoint = torch.load(modelfile, map_location=map_location)
    hyperparams = checkpoint["hyperparams"]
    # Unpacking accepts any iterable of extra names. Tuple concatenation
    # rejected lists with a TypeError.
    names = ("inputs", "neurons", *keys)
    model = Model(*(hyperparams[name] for name in names))
    model.load_state_dict(checkpoint["model_state_dict"])
    # NOTE(review): the model is returned in whatever train/eval mode Model
    # constructs it in. Confirm that method.predict switches to eval() if the
    # architecture has dropout/batch-norm.
    return model

def load_mdn(modelfile):
    """Load a checkpointed ``method.MDN``.

    Besides ``inputs`` and ``neurons``, the constructor receives the
    checkpoint's ``components`` hyperparameter.
    """
    return load_model(modelfile, Model=method.MDN, keys=("components", ))

def load_dn(modelfile):
    """Load a ``dn`` checkpoint.

    Delegates to ``load_mdn``: the checkpoint is rebuilt as a ``method.MDN``
    from its ``components`` hyperparameter (presumably a single component for
    a plain density network — TODO confirm).
    """
    return load_mdn(modelfile)

def load_de(modelfile):
    """Load a checkpointed ``method.DE``.

    Besides ``inputs`` and ``neurons``, the constructor receives the
    checkpoint's ``members`` hyperparameter.
    """
    return load_model(modelfile, Model=method.DE, keys=("members", ))

def plot_interpretation(ax, interpretation):
    """Overlay the predictive and the interpreter-predicted densities on ``ax``.

    Two curves are drawn and a legend is added:
    - ``dist.pdf_gaussian``, labelled "predictive";
    - the mixture given by ``interpretation`` (weight, mean, variance, as
      passed to ``param2pdf``), labelled "predicted data-generating".
    """
    plot.density(ax, dist.pdf_gaussian, label="predictive")
    interpreted_pdf = param2pdf(*interpretation)
    plot.density(ax, interpreted_pdf, label="predicted data-generating")
    ax.legend()

def pit_hist(model, dataset):
    """Return the PIT histogram of ``model``'s predictions on ``dataset``.

    ``model`` is evaluated on ``dataset.X``. The three predicted mixture
    parameters give PIT values for the targets ``dataset.y``, which are then
    binned by ``pit.pit_hist``.
    """
    # Unpack exactly three tensors, so a prediction with the wrong arity fails
    # here. NOTE(review): the third parameter is called "variance" by
    # param2pdf in this notebook; confirm pit_gaussian_mixture expects the
    # same quantity.
    weights, means, spreads = method.predict(model, dataset.X)
    values = pit.pit_gaussian_mixture(dataset.y, weights, means, spreads)
    return pit.pit_hist(values)

def diagnose(pit_hist, interpreter=interpreter):
    """Show how the interpreter reads a PIT histogram, in a two-row figure.

    * Top row: ``plot_interpretation`` of the interpreter's prediction for
      ``pit_hist`` (predictive vs. predicted data-generating density).
    * Bottom row: ``pit_hist`` ("true") against the PIT histogram obtained by
      sampling the predicted data-generating mixture and passing the draws
      through ``pit.pit_gaussian`` ("predicted").

    NOTE(review): the default ``interpreter`` is bound once, when this
    ``def`` executes. Re-loading the interpreter later does not affect it
    unless this cell is re-run too. Inside this body the parameter
    ``pit_hist`` shadows the module-level function of the same name.
    """
    mixture = method.predict(interpreter, pit_hist)
    draws = dist.sample_gaussian_mixture(*mixture)
    reconstructed = pit.pit_hist(pit.pit_gaussian(draws))
    _, axes = plt.subplots(nrows=2)
    top, bottom = axes
    plot_interpretation(top, mixture)
    for hist, label in ((pit_hist, "true"), (reconstructed, "predicted")):
        plot.pit_hist(bottom, hist, label=label)
    bottom.legend()

### Metrics

In [None]:
api = wandb.Api()
runs = api.runs("podondra/calibration")

# Config fields copied from each run; the two test metrics are appended.
keys = ["dataname", "seed", "method", "neurons", "lr", "patience"]
dicts, names = [], []
for run in runs:
    try:
        row = {key: run.config[key] for key in keys}
        row.update(nll=run.summary["test.nll"],
                   crps=run.summary["test.crps"])
    except KeyError:
        # Skip runs that lack any of these config fields or test metrics.
        continue
    dicts.append(row)
    names.append(run.name)

# One row per kept run, indexed by run name.
df = pd.DataFrame(data=dicts, index=names)
df

In [None]:
# Group runs by (dataname, method). Each group pools runs across the other
# config fields (seed, neurons, lr, patience).
gdf = df.groupby(["dataname", "method"])

In [None]:
# Test NLL per (dataname, method): mean and standard error of the mean.
gdf["nll"].agg(["mean", "sem"])

In [None]:
# Test CRPS per (dataname, method): mean and standard error of the mean.
gdf["crps"].agg(["mean", "sem"])

### Power

In [None]:
# Keep only the third part returned by data.split (presumably the test set).
# seed=4 is assumed to reproduce the split the models below were trained on
# — TODO confirm.
_, _, powerset = data.split(*data.power(), seed=4)

In [None]:
# PIT histogram of the checkpoint loaded with load_dn on the Power split,
# followed by the interpreter's diagnosis of it.
pit_hist_dn_power = pit_hist(load_dn("models/wise-glitter-174.pt"), powerset)
diagnose(pit_hist_dn_power)

In [None]:
# PIT histogram of the checkpoint loaded with load_de on the Power split,
# followed by the interpreter's diagnosis of it.
pit_hist_de_power = pit_hist(load_de("models/major-sunset-175.pt"), powerset)
diagnose(pit_hist_de_power)

In [None]:
# MDN checkpoint on the Power split: only its PIT histogram is plotted
# (diagnose is not applied here).
pit_hist_mdn_power = pit_hist(load_mdn("models/smooth-universe-176.pt"), powerset)
plot.pit_hist(plt.subplots()[1], pit_hist_mdn_power)

### Protein

In [None]:
# Keep only the third part returned by data.split (presumably the test set).
# seed=4 is assumed to reproduce the split the models below were trained on
# — TODO confirm.
_, _, proteinset = data.split(*data.protein(), seed=4)

In [None]:
# PIT histogram of the checkpoint loaded with load_dn on the Protein split,
# followed by the interpreter's diagnosis of it.
pit_hist_dn_protein = pit_hist(load_dn("models/denim-dust-171.pt"), proteinset)
diagnose(pit_hist_dn_protein)

In [None]:
# PIT histogram of the checkpoint loaded with load_de on the Protein split,
# followed by the interpreter's diagnosis of it.
pit_hist_de_protein = pit_hist(load_de("models/upbeat-sound-173.pt"), proteinset)
diagnose(pit_hist_de_protein)

In [None]:
# MDN checkpoint on the Protein split: only its PIT histogram is plotted
# (diagnose is not applied here).
pit_hist_mdn_protein = pit_hist(load_mdn("models/fanciful-water-172.pt"), proteinset)
plot.pit_hist(plt.subplots()[1], pit_hist_mdn_protein)

### Year

In [None]:
# Keep only the third part returned by data.split (presumably the test set).
# seed=4 is assumed to reproduce the split the models below were trained on
# — TODO confirm.
_, _, yearset = data.split(*data.year(), seed=4)

In [None]:
# PIT histogram of the checkpoint loaded with load_dn on the Year split,
# followed by the interpreter's diagnosis of it.
pit_hist_dn_year = pit_hist(load_dn("models/lunar-plant-177.pt"), yearset)
diagnose(pit_hist_dn_year)

In [None]:
# PIT histogram of the checkpoint loaded with load_de on the Year split,
# followed by the interpreter's diagnosis of it.
pit_hist_de_year = pit_hist(load_de("models/lemon-eon-178.pt"), yearset)
diagnose(pit_hist_de_year)

In [None]:
# MDN checkpoint on the Year split: only its PIT histogram is plotted
# (diagnose is not applied here).
pit_hist_mdn_year = pit_hist(load_mdn("models/swept-moon-179.pt"), yearset)
plot.pit_hist(plt.subplots()[1], pit_hist_mdn_year)