In [None]:
from matplotlib import pyplot as plt
import pandas as pd
from sklearn import model_selection
from sklearn import preprocessing
from sklearn.metrics import pairwise_distances_argmin
from sklearn.metrics.pairwise import euclidean_distances
import torch

from calibration import data
from calibration import dist
from calibration import method
from calibration import pit
from calibration import plot

In [None]:
def param2pdf(weight, mean, variance):
    """Freeze Gaussian-mixture parameters into a one-argument density function.

    Args:
        weight, mean, variance: mixture parameters, forwarded unchanged to
            ``dist.pdf_gaussian_mixture`` (presumably one entry per component).

    Returns:
        A callable ``pdf(x)`` that evaluates
        ``dist.pdf_gaussian_mixture(x, weight, mean, variance)``.
    """
    def pdf(x):
        return dist.pdf_gaussian_mixture(x, weight, mean, variance)
    return pdf

def y2pdf(w, s, v1, v2):
    """Density of the two-component mixture described by a synthetic annotation.

    The annotation ``(w, s, v1, v2)`` encodes weights ``(w, 1 - w)``, means
    ``(-s / 2, s / 2)`` (``s`` is the distance between the two means) and
    variances ``(v1, v2)``.

    Returns:
        The mixture density as a callable, built with ``param2pdf``.
    """
    half_gap = s / 2
    return param2pdf(torch.tensor([w, 1 - w]),
                     torch.tensor([-half_gap, half_gap]),
                     torch.tensor([v1, v2]))

## Synthetic data set

In [None]:
# Synthetic test set. Seed first so the generated cases are reproducible
# (assuming PITDataset draws from torch's global RNG — TODO confirm).
torch.manual_seed(78)
TESTS = 1_000  # number of synthetic test cases
testset = pit.PITDataset(TESTS)

In [None]:
# Default-resolution reference set for the nearest-neighbour baseline; show its size.
refset = pit.PITReference()
len(refset)

In [None]:
# Load the trained interpreter: an MDN that takes a PIT histogram (`bins` inputs)
# and predicts the parameters of a Gaussian mixture with `m` components.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict copies the tensors into the model's own parameters anyway.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load("models/skilled-monkey-956.pt", map_location="cpu")
hyperparams = checkpoint["hyperparams"]
interpreter = method.MDN(inputs=hyperparams["bins"],
                         neurons=hyperparams["neurons"],
                         components=hyperparams["m"])
interpreter.load_state_dict(checkpoint["model_state_dict"])
interpreter

In [None]:
# Interpreter predictions for every synthetic test case: a tuple that unpacks into
# the mixture (weight, mean, variance), as consumed by the cells below.
pred_mdn = method.predict(interpreter, testset.X)

In [None]:
# Qualitative check on a few synthetic test cases. Each figure overlays:
# - the data-generating density;
# - the interpreter's reconstruction;
# - the density given by the annotation of the nearest reference entry
#   (Euclidean distance on `X`).
PLOTS = 10  # one figure per case; plotting all TESTS cases floods the notebook
# Nearest reference index per test case, without materialising the full distance matrix.
js = pairwise_distances_argmin(testset.X, refset.X)
# Slicing the first iterable is enough: zip stops at the shortest input.
for y, weight, mean, variance, neighbour in zip(testset.annotation[:PLOTS],
                                                *pred_mdn,
                                                refset.annotation[js]):
    fig, ax = plt.subplots()
    plot.density(ax, y2pdf(*y), label="data-generating")
    plot.density(ax, param2pdf(weight, mean, variance), label="MDN")
    plot.density(ax, y2pdf(*neighbour), label="nearest neighbor")
    ax.legend()
    plt.show()
    plt.close(fig)  # close this figure explicitly so figures don't accumulate

In [None]:
# Mean NLL of the synthetic test targets under the interpreter's predicted mixtures.
# Kept as a tensor (used as the MDN reference line in the k-NN plot); shown as a float.
nll_mdn = dist.nll_gaussian_mixture(testset.y, *pred_mdn).mean()
nll_mdn.item()

In [None]:
# k-NN baseline on reference sets of increasing resolution. Each test case takes the
# annotation of its nearest reference entry (Euclidean distance on `X`); that annotation
# is turned into a two-component mixture with the same parameterisation as `y2pdf`
# and scored by the NLL of the test targets.
# The loop uses `refset_s` so that the default `refset` from above is not overwritten.
# Otherwise, re-running the plotting cell afterwards would silently use steps=17.
steps = range(5, 18)
size = []  # actual number of reference entries at each resolution
nll_neighbour = []
for s in steps:
    refset_s = pit.PITReference(steps=s)
    # Argmin computed in chunks instead of a full (TESTS x len(refset_s)) distance matrix.
    js = pairwise_distances_argmin(testset.X, refset_s.X)
    # Annotation columns: weight w, distance between means, variance 1, variance 2.
    neighbour = refset_s.annotation[js]
    weight = torch.stack((neighbour[:, 0], 1 - neighbour[:, 0]), dim=1)
    mean = torch.stack((-neighbour[:, 1] / 2, neighbour[:, 1] / 2), dim=1)
    variance = neighbour[:, 2:]
    nll = dist.nll_gaussian_mixture(testset.y, weight, mean, variance)
    size.append(len(refset_s))
    nll_neighbour.append(nll.mean().item())
    print(f"{s:2d} {nll_neighbour[-1]:f} {size[-1]}")

fig, ax = plt.subplots()
ax.scatter(size, nll_neighbour, label="$k$-NN")
ax.axhline(float(nll_mdn), label="our MDN")  # plain float, not a 0-d tensor
ax.set_xlabel("size of reference set")
ax.set_ylabel("NLL")
ax.legend()
nll_neighbour

In [None]:
# Save the k-NN baseline (reference-set resolution, its size, and test NLL) to CSV.
df = (
    pd.DataFrame({"steps": steps, "size": size, "nll": nll_neighbour})
    .set_index("steps")
)
df.to_csv("data/neighbour.csv")
df

## Uniform PIT histogram

In [None]:
# Sanity check: a flat PIT histogram is what a calibrated predictive produces.
# The interpreter's reconstruction is therefore expected to resemble the predictive
# itself (dist.pdf_gaussian with its defaults), which is labelled "data-generating" here.
pit_hist_uniform = torch.full((pit.BINS,), 1 / pit.BINS)
pred_uniform = method.predict(interpreter, pit_hist_uniform)

fig, ax = plt.subplots()
plot.density(ax, dist.pdf_gaussian, label="data-generating")
plot.density(ax, param2pdf(*pred_uniform), label="predicted data-generating")
ax.legend()

## Physicochemical Properties of *Protein* Tertiary Structure

In [None]:
# Protein data splits:
# - hold out 10% as a test split (not used yet);
# - use 10% of the remainder for validation.
# Features are standardized with training statistics only. Targets stay in original
# units; `y_scaler` is fitted here and used later to map model outputs back to them.
X_full, y_full = data.protein()
X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X_full, y_full, test_size=0.1, random_state=33)
X_train, X_valid, y_train, y_valid = model_selection.train_test_split(
    X_train, y_train, test_size=0.1, random_state=79)
X_scaler, y_scaler = preprocessing.StandardScaler(), preprocessing.StandardScaler()
X_train = X_scaler.fit_transform(X_train)
X_valid = X_scaler.transform(X_valid)
y_scaler.fit(y_train)  # fit() returns the scaler itself; no reassignment needed
# Convert the training *targets* (previously overwritten with the validation features).
X_train, y_train = data.array2tensor(X_train), data.array2tensor(y_train)
X_valid, y_valid = data.array2tensor(X_valid), data.array2tensor(y_valid)
X_train.shape, X_valid.shape

In [None]:
def inverse_transform(y, scaler):
    """Map predicted Gaussian parameters from standardized back to original units.

    For a target standardized as ``(t - mean_) / scale_``, a predicted mean ``m``
    maps to ``m * scale_ + mean_``, and a predicted variance ``v`` maps to
    ``v * var_``. For a ``StandardScaler`` with non-zero variance,
    ``var_ == scale_ ** 2``.

    Args:
        y: sequence whose last two entries are the predicted means and variances.
            Any leading entries (e.g. mixture weights) are passed through untouched.
        scaler: fitted scaler exposing ``scale_``, ``mean_`` and ``var_``
            (e.g. ``sklearn.preprocessing.StandardScaler``).

    Returns:
        A new list equal to ``y`` except for the rescaled last two entries.
        The input sequence itself is not modified.
    """
    scale, shift, var = (data.array2tensor(stat)
                         for stat in (scaler.scale_, scaler.mean_, scaler.var_))
    params = list(y)
    params[-2] = params[-2] * scale + shift
    params[-1] = params[-1] * var
    return params

### Density network

In [None]:
# Density network: an MDN with a single mixture component on the 9 protein features.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict copies the tensors into the model's own parameters anyway.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load("models/driven-serenity-33.pt", map_location="cpu")
hyperparams = checkpoint["hyperparams"]
dn = method.MDN(inputs=9, neurons=hyperparams["neurons"], components=1)
dn.load_state_dict(checkpoint["model_state_dict"])
dn

In [None]:
# The single-component network also predicts a (trivial) mixture weight. Keep only
# the mean and variance, and map them back to original target units.
y_dn = inverse_transform(method.predict(dn, X_valid)[1:], y_scaler)
# Validation PIT histogram of the DN, and the interpreter's reconstruction of the
# data-generating mixture behind it.
pit_hist_dn = pit.pit_hist(pit.pit_gaussian(y_valid, *y_dn))
pred_dn = method.predict(interpreter, pit_hist_dn)

fig, (ax_pit, ax_pdf) = plt.subplots(nrows=2)
# Top panel: observed PIT histogram vs. the one implied by sampling the reconstructed
# mixture. This uses pit_gaussian's default predictive — presumably standard normal;
# TODO confirm.
plot.pit_hist(ax_pit, pit_hist_dn, label="true")
plot.pit_hist(ax_pit,
              pit.pit_hist(pit.pit_gaussian(dist.sample_gaussian_mixture(*pred_dn))),
              label="predicted")
ax_pit.legend()
# Bottom panel: the predictive density vs. the reconstructed data-generating density.
plot.density(ax_pdf, dist.pdf_gaussian, label="predictive")
plot.density(ax_pdf, param2pdf(*pred_dn), label="predicted data-generating")
ax_pdf.legend()

### Deep Ensemble

In [None]:
# Deep ensemble with `m` members.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict copies the tensors into the model's own parameters anyway.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load("models/crimson-firebrand-43.pt", map_location="cpu")
hyperparams = checkpoint["hyperparams"]
de = method.DE(inputs=hyperparams["inputs"],
               neurons=hyperparams["neurons"],
               members=hyperparams["m"])
de.load_state_dict(checkpoint["model_state_dict"])
de

In [None]:
# Ensemble predictions, assumed to be (mean, variance) like `y_dn`, mapped back to
# original target units.
y_de = inverse_transform(method.predict(de, X_valid), y_scaler)
# Validation PIT histogram of the DE, and the interpreter's reconstruction of the
# data-generating mixture behind it.
pit_hist_de = pit.pit_hist(pit.pit_gaussian(y_valid, *y_de))
pred_de = method.predict(interpreter, pit_hist_de)

fig, (ax_pit, ax_pdf) = plt.subplots(nrows=2)
# Top panel: observed PIT histogram vs. the one implied by sampling the reconstructed
# mixture. This uses pit_gaussian's default predictive — presumably standard normal;
# TODO confirm.
plot.pit_hist(ax_pit, pit_hist_de, label="true")
plot.pit_hist(ax_pit,
              pit.pit_hist(pit.pit_gaussian(dist.sample_gaussian_mixture(*pred_de))),
              label="predicted")
ax_pit.legend()
# Bottom panel: the predictive density vs. the reconstructed data-generating density.
plot.density(ax_pdf, dist.pdf_gaussian, label="predictive")
plot.density(ax_pdf, param2pdf(*pred_de), label="predicted data-generating")
ax_pdf.legend()

### Mixture density network

In [None]:
# Mixture density network with 5 components on the 9 protein features.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict copies the tensors into the model's own parameters anyway.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load("models/dainty-feather-49.pt", map_location="cpu")
hyperparams = checkpoint["hyperparams"]
mdn = method.MDN(inputs=9, neurons=hyperparams["neurons"], components=5)
mdn.load_state_dict(checkpoint["model_state_dict"])
mdn

In [None]:
# MDN predictions (weight, mean, variance). inverse_transform rescales the means and
# variances and passes the weights through.
y_mdn = inverse_transform(method.predict(mdn, X_valid), y_scaler)
pit_hist_mdn = pit.pit_hist(pit.pit_gaussian_mixture(y_valid, *y_mdn))

# Compare the validation PIT histograms of all three models on one axis.
fig, ax = plt.subplots()
for hist, label in ((pit_hist_dn, "DN"),
                    (pit_hist_de, "DE"),
                    (pit_hist_mdn, "MDN (m = 5)")):
    plot.pit_hist(ax, hist, label=label)
ax.legend()

## Metrics

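**TODO:** the metrics below are computed on the validation split (`y_valid`). Evaluate on the held-out test split (`X_test`, `y_test`) instead; note that `X_test` has not been standardized yet.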

In [None]:
# Validation-set scores of the three protein models (NLL and CRPS; lower is better).
methods = ["density network", "deep ensemble", "mixture density network"]
# (NLL, CRPS) scoring functions and predicted parameters, in the order of `methods`.
# DN and DE predict a single Gaussian; the MDN predicts a mixture.
gaussian_scores = (dist.nll_gaussian, dist.crps_gaussian)
mixture_scores = (dist.nll_gaussian_mixture, dist.crps_gaussian_mixture)
evaluations = [(gaussian_scores, y_dn),
               (gaussian_scores, y_de),
               (mixture_scores, y_mdn)]
nll = [nll_fn(y_valid, *params).mean().item()
       for (nll_fn, _), params in evaluations]
crps = [crps_fn(y_valid, *params).mean().item()
        for (_, crps_fn), params in evaluations]

df = pd.DataFrame({"method": methods, "nll": nll, "crps": crps}).set_index("method")
df