# Variational Auto-Encoder

In [None]:
from functools import partial

from matplotlib import gridspec
from matplotlib import pyplot as plt
from matplotlib.patches import Ellipse

from vae import *

In [None]:
# Interactive widget backend: the pick / button-press callbacks registered
# with `fig.canvas.mpl_connect` further down need an interactive figure.
%matplotlib widget

In [None]:
def pit_hist(ax, x, n_bins, **kwargs):
    """Draw a histogram of the values *x* on *ax*.

    The values (presumably PIT values) are binned into *n_bins* equal-width
    bins spanning [0, 1]; any extra keyword arguments are forwarded to
    ``ax.hist``.
    """
    unit_interval = (0, 1)
    ax.hist(x, bins=n_bins, range=unit_interval, **kwargs)

def pit_stairs(ax, x, n_bins, **kwargs):
    """Draw precomputed bin heights *x* as a step curve on *ax*.

    The *n_bins* bins are equal-width and together span [0, 1]; extra
    keyword arguments are forwarded to ``ax.stairs``.
    """
    edges = np.linspace(0, 1, num=n_bins + 1)
    ax.stairs(x, edges, **kwargs)

def get_grid():
    """Create a figure with one wide axes on top and two axes below it.

    Returns ``(fig, ax, ax_true, ax_pred)``: the figure, the axes spanning the
    whole first row of a 2x2 grid, and the left / right axes of the second
    row.
    """
    figure = plt.figure(tight_layout=True)
    grid = gridspec.GridSpec(2, 2)
    # Cells in creation order: full top row, bottom-left, bottom-right.
    top, bottom_left, bottom_right = (
        figure.add_subplot(grid[cell])
        for cell in ((0, slice(None)), (1, 0), (1, 1))
    )
    return figure, top, bottom_left, bottom_right

def plot_pred_press(event, ax, model, plot_function):
    """Decode the clicked point as a 2-D latent vector and plot it on *ax*.

    Intended as a ``button_press_event`` callback. The click's data
    coordinates ``(xdata, ydata)`` form the latent vector, which is decoded
    with ``model.decoder.decode`` and drawn with
    ``plot_function(ax, values, label=...)``.

    Clicks that carry no data coordinates (``xdata``/``ydata`` is ``None``,
    e.g. outside every axes) are ignored.
    """
    # Test for None *before* converting: float(None) raises TypeError.
    if event.xdata is None or event.ydata is None:
        return
    x, y = float(event.xdata), float(event.ydata)
    ax.clear()
    # The decoded tensor is only plotted, so no autograd graph is needed.
    with torch.no_grad():
        x_pred = model.decoder.decode(torch.tensor([[x, y]]))[0]
    plot_function(ax, x_pred, label=f"({x:.4f}, {y:.4f})")
    ax.legend()
    # Redraw the figure that owns *ax* instead of relying on a global `fig`.
    ax.figure.canvas.draw_idle()

def plot_true_pick(event, ax, dataset, model, plot_function):
    """Plot a picked dataset sample and its reconstruction on *ax*.

    Intended as a ``pick_event`` callback. The first picked index selects
    ``dataset.X[idx]`` and ``dataset.y[idx]``. The input is drawn labelled
    with ``repr(y)``. It is then encoded with ``model.encode``, the first
    returned value is decoded with ``model.decoder.decode``, and the result
    is drawn labelled ``"reconstruction"``.
    """
    # Several overlapping markers can be picked at once; show the first.
    idx = event.ind[0]
    x, y = dataset.X[idx], dataset.y[idx]
    ax.clear()
    plot_function(ax, x, label=repr(y))
    # Only the first output of `encode` (named `mu`, presumably the latent
    # mean) is decoded; the second (sigma) is not needed for the plot.
    with torch.no_grad():
        mu, _ = model.encode(x.unsqueeze(0))
        x_pred = model.decoder.decode(mu)[0]
    plot_function(ax, x_pred, label="reconstruction")
    ax.legend()
    # Redraw the figure that owns *ax* instead of relying on a global `fig`.
    ax.figure.canvas.draw_idle()

In [None]:
# Arguments forwarded to `generate_data` (defined in the `vae` module).
REPEATS = 1
SAMPLES = 1000

# Call `seed` before drawing any data (presumably to fix the RNG state), and
# keep the train draw before the test draw.
seed()
data_train = generate_data(REPEATS, SAMPLES)
data_test = generate_data(1, SAMPLES)

# Marker colour per training entry, used by the latent-space scatter plot:
# red when the entry's second field is exactly a `Normal`, green otherwise.
colors = [
    "red" if type(entry[1]) is Normal else "green"
    for entry in data_train[1]
]

trainset = PITHistDataset(*data_train, BINS)
testset = PITHistDataset(*data_test, BINS)
len(trainset), len(testset)

In [None]:
vae = VAE(input_dim=BINS, n_hiddens=1, n_neurons=16, epsilon=None)
# map_location="cpu" lets a checkpoint saved from a CUDA run load on a
# CPU-only machine; load_state_dict copies the weights into the model's own
# parameters either way.
vae.load_state_dict(torch.load("models/winter-grass-2.pt", map_location="cpu"))
# The model is only used for inference below. Switch off training-mode
# layers (if VAE has any, e.g. dropout) and gradient tracking, so the latent
# tensors can be passed straight to matplotlib.
vae.eval()
with torch.no_grad():
    mu_train, sigma_train = vae.encode(trainset.X)
vae

In [None]:
fig, ax, ax_true, ax_pred = get_grid()

# One translucent black ellipse per training example, centred on its encoded
# mean, with width and height 3 * sigma.
# NOTE(review): assumes sigma_train[idx] is a scalar per example — confirm.
for idx in range(len(trainset)):
    extent = 3 * sigma_train[idx]
    blob = Ellipse(xy=mu_train[idx], width=extent, height=extent)
    ax.add_artist(blob)
    blob.set_clip_box(ax.bbox)
    blob.set_alpha(0.1)
    blob.set_facecolor("k")

ax.scatter(mu_train[:, 0], mu_train[:, 1], c=colors, marker="x", picker=True)
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)

# Picking a marker shows that sample and its reconstruction (bottom left);
# clicking anywhere decodes the clicked point (bottom right).
plot_function = partial(pit_stairs, n_bins=BINS)
on_pick = partial(
    plot_true_pick,
    ax=ax_true,
    dataset=trainset,
    model=vae,
    plot_function=plot_function,
)
on_press = partial(
    plot_pred_press,
    ax=ax_pred,
    model=vae,
    plot_function=plot_function,
)
fig.canvas.mpl_connect("pick_event", on_pick)
fig.canvas.mpl_connect("button_press_event", on_press)

In [None]:
# Encode a perfectly flat histogram (each of the BINS bins holds 1 / BINS)
# and compare it with its reconstruction. The bin count must match BINS:
# the VAE was built with input_dim=BINS and pit_stairs draws BINS bins.
pit_hist_uni = torch.full((1, BINS), 1 / BINS)
_, ax = plt.subplots()
pit_stairs(ax, pit_hist_uni[0], BINS, label="true")
# Inference only: no autograd graph is needed for plotting.
with torch.no_grad():
    mu_uni, _ = vae.encode(pit_hist_uni)
    reconstruction_uni = vae.decoder.decode(mu_uni)[0]
pit_stairs(ax, reconstruction_uni, BINS, label="reconstruction")
ax.legend()