In [None]:
import functools

from matplotlib import pyplot as plt
from matplotlib.patches import Ellipse
import torch

from calibration import data
from calibration import dists
from calibration import plot
from calibration import vae
from calibration import utils
from calibration import uci

In [None]:
# Interactive (ipympl) backend: the pick / button-press callbacks connected
# further down only fire with an interactive backend, not with static inline output.
%matplotlib widget

## Variational Auto-Encoder

In [None]:
# -- Configuration ------------------------------------------------------------
BINS: int = 10            # number of PIT histogram bins; also the VAE input_dim
SAMPLES: int = 1_000      # sample count passed to data.PITHistDataset
DEVICE: torch.device = torch.device("cpu")  # device the test set is created on

In [None]:
# Seed before building the test set so the sampled histograms are reproducible.
# NOTE(review): assumes utils.seed() seeds the RNG PITHistDataset draws from — confirm.
utils.seed()
# Presumably SAMPLES PIT histograms of BINS bins each, created on DEVICE — confirm in data.PITHistDataset.
testset = data.PITHistDataset(SAMPLES, BINS, DEVICE)
len(testset)

In [None]:
# Pretrained VAE: BINS-dimensional PIT histograms -> 2-D latent space.
model = vae.VAE(input_dim=BINS,
                n_hiddens=4,
                n_neurons=32,
                embed_dim=2,
                epsilon=None)
# map_location remaps stored tensors onto DEVICE, so a checkpoint saved from a
# GPU run still loads on a CPU-only machine.
checkpoint = torch.load("models/fluent-mountain-382.pt", map_location=DEVICE)
model.load_state_dict(checkpoint)
# Keep the model on the same device as the test set (built on DEVICE).
# Assumes vae.VAE is a torch.nn.Module (it exposes load_state_dict).
model.to(DEVICE)
# Inference only: disable training-mode behaviour (dropout / batch-norm updates, if any).
model.eval()
model

In [None]:
# Embed every test histogram. no_grad keeps the outputs free of an autograd
# graph, so they can be handed to matplotlib directly in the cells below.
# mu/sigma are presumably the latent mean and scale per sample — confirm in vae.VAE.encode.
with torch.no_grad():
    mu_test, sigma_test = model.encode(testset.X)

In [None]:
# Sanity check: encode a flat (uniform) PIT histogram and decode its latent mean.
uniform = torch.full((BINS, ), 1 / BINS)
# no_grad so the decoded tensor can be plotted without .detach().
with torch.no_grad():
    mu_uniform, sigma_uniform = model.encode(uniform.unsqueeze(0))  # add batch dim
    uniform_pred = model.decode(mu_uniform).squeeze()

_, ax = plt.subplots()
plot.pit_hist(ax, uniform, BINS, label="org")
plot.pit_hist(ax, uniform_pred, BINS, label="rec")
ax.set_title("Uniform PIT histogram: original vs. reconstruction")
ax.legend();

In [None]:
fig, ax, ax_pick, ax_press = plot.get_grid()

# Plain numpy copies for matplotlib: tensors that carry autograd state cannot be
# converted to arrays implicitly.
mu_np = mu_test.detach().cpu().numpy()
sigma_np = sigma_test.detach().cpu().numpy()

# One translucent ellipse per test point, width/height = ELLIPSE_SCALE * sigma
# along the two latent axes.
# NOTE(review): Ellipse width/height are full diameters, so a scale of 3 spans
# +-1.5 sigma around the mean — confirm this is the intended extent.
ELLIPSE_SCALE = 3
for mu_i, sigma_i in zip(mu_np, sigma_np):
    ellipse = Ellipse(xy=(mu_i[0], mu_i[1]),
                      width=ELLIPSE_SCALE * sigma_i[0],
                      height=ELLIPSE_SCALE * sigma_i[1],
                      facecolor="k",
                      alpha=0.5)
    # add_patch (unlike add_artist) includes the ellipse in the axes' data limits
    # and clips it to the axes.
    ax.add_patch(ellipse)

# picker=True makes the scatter points emit pick events; point order matches testset.
ax.scatter(mu_np[:, 0], mu_np[:, 1], picker=True)

plot_fn = functools.partial(plot.pit_hist, n_bins=BINS)
# Picking a point is handled by plot.on_pick (drawing into ax_pick); any mouse
# press by plot.on_button_press (drawing into ax_press). Keep the connection ids
# so the callbacks can be disconnected later.
pick_cid = fig.canvas.mpl_connect("pick_event",
                                  functools.partial(plot.on_pick,
                                                    ax=ax_pick,
                                                    dataset=testset,
                                                    model=model,
                                                    plot_fn=plot_fn))
press_cid = fig.canvas.mpl_connect("button_press_event",
                                   functools.partial(plot.on_button_press,
                                                     ax=ax_press,
                                                     model=model,
                                                     plot_fn=plot_fn))

## Protein Data Set

In [None]:
# Load the protein data; keeps the second element of the first returned pair
# (presumably the evaluation split — confirm in uci.get_dataset).
(_, proteinset), _ = uci.get_dataset("protein", seed=50, validation=False, preparation=True)
x_protein, y_protein = proteinset.tensors

# Pretrained regressor: NLL loss, 1 hidden layer of 64 neurons, checkpoint "nll-1-1-5432".
# NOTE: the name `nn` shadows the conventional torch.nn alias.
nn = uci.NeuralNetwork(x_protein.shape[-1],
                       {"loss": "nll", "neurons": 64, "hiddens": 1})
nn.load("nll-1-1-5432")
y_pred = nn.predict(proteinset)
# y_pred is unpacked into normal_pit, presumably (mean, std) of a Normal — confirm.
pit_values = uci.normal_pit(*y_pred, y_protein.cpu())[:, 0]
pit_hist = data.pit_hist(pit_values, BINS)

# Embed the protein PIT histogram and reconstruct it from its latent mean;
# no_grad keeps the plotted tensors free of autograd state.
with torch.no_grad():
    mu_protein, sigma_protein = model.encode(pit_hist)
    protein_rec = model.decode(mu_protein).squeeze()

_, ax = plt.subplots()
plot.pit_hist(ax, pit_hist.squeeze(), BINS, label="org")
plot.pit_hist(ax, protein_rec, BINS, label="rec")
ax.set_title("Protein PIT histogram: original vs. reconstruction")
ax.legend();