In [None]:
# Select the interactive widget backend; the pick / button-press callbacks
# registered on the figure canvases further down rely on interactive figures.
%matplotlib widget

In [None]:
import functools
import importlib
import math

from matplotlib import pyplot as plt
from matplotlib.patches import Ellipse
import torch
from torchvision import datasets

from calibration import data
from calibration import dists
from calibration import plot
from calibration import vae
from calibration import utils
from calibration import uci

In [None]:
# Torch device handed to the dataset constructor below; fixed to the CPU here.
DEVICE = torch.device("cpu")

## Data Set

In [None]:
# Number of bins per PIT histogram; also used as the VAE input width (inputs=BINS).
BINS = 20
# First argument to data.PITHistDataset — presumably the number of synthetic
# histograms to generate (TODO confirm against the dataset class).
SAMPLES = 10000

In [None]:
# Seed before building the dataset — presumably so the generated histograms are
# reproducible (utils.seed is project code; confirm which RNGs it seeds).
utils.seed()
testset = data.PITHistDataset(SAMPLES, BINS, DEVICE)
# Show the dataset size as the cell output.
len(testset)

## Variational Auto-Encoder

In [None]:
# Build the VAE with the architecture the checkpoint is loaded into
# (BINS inputs, 3 latent dimensions) and restore its weights.
model = vae.VAE(inputs=BINS, hiddens=1, neurons=16, embeds=3, epsilon=None)
# map_location keeps the load working when the checkpoint was saved from a
# different device (e.g. a GPU) than the one this notebook runs on.
checkpoint = torch.load("models/whole-snowflake-373.pt", map_location=DEVICE)
model.load_state_dict(checkpoint)
# The notebook only runs inference; eval mode is a no-op unless the model
# contains layers that behave differently while training.
model.eval()
model

In [None]:
# Encode the whole test set, reconstruct from the latent means and compute the
# reconstruction loss. This is inference only, so autograd is disabled: no
# graph is built for the full test set, and the resulting tensors can be handed
# directly to matplotlib / NumPy in the plotting cells below.
with torch.no_grad():
    mu_test, sigma_test = model.encode(testset.X)
    X_pred = model.decode(mu_test)
    rec = model.loss_rec(X_pred, testset.X)

### Uniform PIT histogram

In [None]:
# A perfectly uniform PIT histogram: every bin holds 1 / BINS of the mass.
uniform = torch.full((BINS,), 1 / BINS)

# Encode / decode without autograd so the reconstruction is a plain tensor that
# can be plotted. The encoder is fed a batch of one (hence unsqueeze), and the
# batch dimension is dropped again from the reconstruction.
with torch.no_grad():
    mu_uniform, sigma_uniform = model.encode(uniform.unsqueeze(0))
    uniform_pred = model.decode(mu_uniform).squeeze()

# Original vs. reconstructed histogram on the same axes.
_, ax = plt.subplots()
plot.pit_hist(ax, uniform, BINS, label="org")
plot.pit_hist(ax, uniform_pred, BINS, label="rec")
ax.legend()

### Protein Data Set

In [None]:
# Load the UCI "protein" split and a previously trained regression network,
# then turn its predictions into a PIT histogram with the same BINS as the
# synthetic data. (uci.* helpers are project code; the split/return layout is
# taken from how the results are unpacked here.)
(_, proteinset), _ = uci.get_dataset("protein", seed=50, validation=False, preparation=True)
X_protein, y_protein = proteinset.tensors
nn = uci.NeuralNetwork(X_protein.shape[-1], {"loss": "nll", "neurons": 64, "hiddens": 1})
nn.load("nll-1-1-5432")
y_pred_protein = nn.predict(proteinset)
pit_values_protein = uci.normal_pit(*y_pred_protein, y_protein.cpu()).flatten()
pit_hist_protein = data.pit_hist(pit_values_protein, BINS)

# Encode / decode without autograd so the results can be plotted directly.
# The histogram is encoded unbatched (later cells index mu_protein[0], [1], [2]),
# while the decoder is given a batch of one that is squeezed away again.
with torch.no_grad():
    mu_protein, sigma_protein = model.encode(pit_hist_protein)
    protein_rec = model.decode(mu_protein.unsqueeze(0)).squeeze()

# Original vs. reconstructed protein PIT histogram.
_, ax = plt.subplots()
plot.pit_hist(ax, pit_hist_protein, BINS, label="org")
plot.pit_hist(ax, protein_rec, BINS, label="rec")
ax.legend()

### 2-D Projection

In [None]:
# Boolean mask over the test set: True where the third label column equals 1.5.
# NOTE(review): no cell visible here reads `colors` (the scatter plots colour by
# testset.y[:, 2] directly) — confirm whether it is still needed.
colors = (testset.y[:, 2] == 1.5)

In [None]:
fig, ax, ax_pick, ax_press = plot.get_grid()

# One translucent ellipse per test sample, centred on the first two latent
# means and spanning 3 * sigma along each of those two axes. The latent space
# has three dimensions, so the 2-D centre is selected explicitly, and plain
# Python floats are passed so matplotlib never has to do arithmetic on tensors.
for center, spread in zip(mu_test, sigma_test):
    ellipse = Ellipse(
        xy=(float(center[0]), float(center[1])),
        width=3 * float(spread[0]),
        height=3 * float(spread[1]),
    )
    ax.add_artist(ellipse)
    ellipse.set_clip_box(ax.bbox)
    ellipse.set_alpha(0.1)
    ellipse.set_facecolor("k")

# Latent means coloured by the third label column; picker=True makes the points
# emit pick events for the callback registered below.
cb = ax.scatter(mu_test[:, 0], mu_test[:, 1], c=testset.y[:, 2], picker=True)
plt.colorbar(cb)

# Picking a point / pressing a mouse button in the figure is forwarded to the
# project callbacks, which draw into the dedicated side axes.
plot_fn = functools.partial(plot.pit_hist, n_bins=BINS)
fig.canvas.mpl_connect("pick_event", functools.partial(plot.on_pick, ax=ax_pick, dataset=testset, model=model, plot_fn=plot_fn))
fig.canvas.mpl_connect("button_press_event", functools.partial(plot.on_button_press, ax=ax_press, model=model, plot_fn=plot_fn))

In [None]:
# Overlay the latent means of the uniform and the protein histogram on the 2-D
# axes from the previous cell (`ax` must still refer to those axes).
# mu_uniform carries a batch dimension (encoded from a batch of one), whereas
# mu_protein was encoded unbatched and is indexed directly.
ax.scatter(mu_uniform[:, 0], mu_uniform[:, 1])
ax.scatter(mu_protein[0], mu_protein[1])

### 3-D Projection

In [None]:
fig, ax, ax_pick, ax_press = plot.get_grid(projection="3d")

# Scatter all three latent means, coloured by the third label column; the
# points are pickable so the callback below can react to clicks on them.
latent_x, latent_y, latent_z = mu_test[:, 0], mu_test[:, 1], mu_test[:, 2]
cb = ax.scatter(latent_x, latent_y, zs=latent_z, c=testset.y[:, 2], picker=True)
plt.colorbar(cb)

# Forward pick events to the project callback, which draws into `ax_pick`.
plot_fn = functools.partial(plot.pit_hist, n_bins=BINS)
on_pick = functools.partial(plot.on_pick, ax=ax_pick, dataset=testset, model=model, plot_fn=plot_fn)
fig.canvas.mpl_connect("pick_event", on_pick)

In [None]:
# Overlay the latent means of the uniform and the protein histogram on the 3-D
# axes from the previous cell (`ax` must still refer to those axes).
# mu_uniform carries a batch dimension; mu_protein was encoded unbatched.
ax.scatter(mu_uniform[:, 0], mu_uniform[:, 1], zs=mu_uniform[:, 2])
ax.scatter(mu_protein[0], mu_protein[1], zs=mu_protein[2])

### Reconstructions

In [None]:
# Upper bound on how many test samples get their own figure. Without a cap the
# loop would open one figure per sample of the whole test set, which exhausts
# memory and floods the notebook output. Raise it to inspect more samples.
MAX_RECONSTRUCTIONS = 10

# Grid on which the predictive / true densities are evaluated.
domain = torch.linspace(-4, 4, 128)
for i, (x, y) in enumerate(testset):
    if i >= MAX_RECONSTRUCTIONS:
        break
    _, (ax1, ax2) = plt.subplots(nrows=2)
    # Top panel: original PIT histogram vs. its VAE reconstruction.
    plot.pit_hist(ax1, x, BINS, label="orig")
    plot.pit_hist(ax1, X_pred[i], BINS, label="rec")
    ax1.legend()
    # Bottom panel: the two distributions encoded in the sample's label.
    dist_pred, dist_true = data.label2dists(*y)
    ax2.plot(domain, dist_pred.pdf(domain), label="predictive")
    # Labelled so that the legend lists both curves, not just the predictive one.
    ax2.plot(domain, dist_true.pdf(domain), label="true")
    ax2.legend()

## Nearest Neighbours

### Original Space

In [None]:
from sklearn.metrics.pairwise import euclidean_distances

In [None]:
# Distance from every synthetic test histogram to the protein histogram.
# The protein histogram is passed as a single-row matrix, so the result has one
# column and the flat argmin is the row index of the closest test histogram.
protein_query = pit_hist_protein.unsqueeze(0)
dist = euclidean_distances(testset.X, protein_query)
i = dist.argmin()
# Show the index of the nearest neighbour together with its distance.
i, dist.min()

In [None]:
# Label of the nearest synthetic histogram (row `i` found in the previous cell).
testset.y[i]

In [None]:
# Convert the nearest neighbour's label into a pair of distribution objects
# (predictive, true) — naming follows how the result is unpacked here;
# data.label2dists is project code.
dist_pred, dist_true = data.label2dists(*testset.y[i])
dist_pred, dist_true

In [None]:
_, (ax, ax2) = plt.subplots(nrows=2)

# Top panel: the protein PIT histogram next to its nearest synthetic neighbour.
plot.pit_hist(ax, pit_hist_protein, BINS, label="protein")
plot.pit_hist(ax, testset.X[i], BINS, label="nearest")
ax.legend()

# Bottom panel: the two distributions behind the nearest neighbour's label,
# evaluated on a fixed grid.
domain = torch.linspace(-4, 4, 128)
ax2.plot(domain, dist_pred.pdf(domain), label="predictive")
# Labelled so that the legend lists both curves, not just the predictive one.
ax2.plot(domain, dist_true.pdf(domain), label="true")
ax2.legend()