In [None]:
from functools import partial

from matplotlib import gridspec
from matplotlib import pyplot as plt

from pit import *

In [None]:
# Interactive (ipympl) backend: the figures below rely on canvas callbacks
# (button_press_event / pick_event), which need a live, interactive canvas.
%matplotlib widget

In [None]:
def get_grid():
    """Build a 2x2-grid figure: one wide panel on top, two panels underneath.

    Returns:
        tuple: ``(fig, ax, ax_true, ax_pred)`` -- the figure, the top axes
        spanning both columns of row 0, the bottom-left axes and the
        bottom-right axes, in that order.
    """
    figure = plt.figure(tight_layout=True)
    layout = gridspec.GridSpec(2, 2)
    # Subplots are created in the order top, bottom-left, bottom-right.
    top, bottom_left, bottom_right = (
        figure.add_subplot(cell) for cell in (layout[0, :], layout[1, 0], layout[1, 1])
    )
    return figure, top, bottom_left, bottom_right

def plot_pred(event, ax, model, plot_function):
    """Decode the clicked 2-D point with ``model`` and plot the reconstruction.

    Meant to be bound with ``functools.partial`` and registered as a
    matplotlib ``button_press_event`` callback.

    Args:
        event: matplotlib mouse event. ``event.xdata`` / ``event.ydata`` are
            ``None`` when the click lands outside every axes.
        ax: axes that receives the reconstruction (cleared first).
        model: object exposing ``decode(tensor)``; it is called with a
            ``(1, 2)`` tensor holding the clicked coordinate and the first
            element of its result is plotted.
        plot_function: callable ``(ax, values, label=...)`` that draws
            ``values`` on ``ax``.
    """
    # Check for None *before* converting: float(None) raises TypeError, so a
    # click outside the axes used to crash instead of being ignored.
    if event.xdata is None or event.ydata is None:
        return
    x, y = float(event.xdata), float(event.ydata)
    ax.clear()
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        reconstruction = model.decode(torch.tensor([[x, y]]))[0]
    plot_function(ax, reconstruction, label=f"({x:.4f}, {y:.4f})")
    ax.legend()
    # Redraw the figure that owns `ax`. The notebook-level `fig` is rebound by
    # later cells, so relying on it would redraw the wrong figure.
    ax.figure.canvas.draw_idle()

def plot_true(event, ax, data_set, plot_function):
    """Plot the data-set item behind the picked scatter point.

    Meant to be bound with ``functools.partial`` and registered as a
    matplotlib ``pick_event`` callback on a scatter with ``picker=True``.

    Args:
        event: matplotlib pick event; ``event.ind[0]`` is taken as the index
            of the picked point (the first one when several overlap).
        ax: axes that receives the plot (cleared first).
        data_set: object with indexable ``X`` and ``y``; ``X[idx]`` is plotted
            and ``repr(y[idx])`` is used as the legend label.
        plot_function: callable ``(ax, values, label=...)`` that draws
            ``values`` on ``ax``.
    """
    idx = event.ind[0]
    ax.clear()
    x, y = data_set.X[idx], data_set.y[idx]
    plot_function(ax, x, label=repr(y))
    ax.legend()
    # Redraw the figure that owns `ax`. The notebook-level `fig` is rebound by
    # later cells, so relying on it would redraw the wrong figure.
    ax.figure.canvas.draw_idle()

In [None]:
def pit_hist(ax, x, n_bins, **kwargs):
    """Histogram raw values ``x`` on ``ax`` using ``n_bins`` equal bins over [0, 1].

    Any extra keyword arguments (e.g. ``label``) are forwarded to ``Axes.hist``.
    """
    bin_options = dict(range=(0, 1), bins=n_bins)
    ax.hist(x, **bin_options, **kwargs)

def pit_stairs(ax, x, n_bins, **kwargs):
    """Draw already-binned values ``x`` as a step outline on ``ax``.

    The steps use ``n_bins`` equal-width bins covering [0, 1], so ``x`` is
    expected to hold ``n_bins`` values. Extra keyword arguments (e.g.
    ``label``) are forwarded to ``Axes.stairs``.
    """
    edges = np.linspace(0.0, 1.0, num=n_bins + 1)
    ax.stairs(x, edges, **kwargs)

In [None]:
# `seed` and `generate_data` come from `pit`; seeding first presumably makes
# the generated data reproducible (TODO confirm in `pit`).
seed()
n_repeats = 10
n_samples = 1000
train_data = generate_data(n_repeats, n_samples)
# One repeat for the held-out test data, same sample count as training.
test_data = generate_data(n_repeats=1, n_samples=n_samples)
# Colour per entry of train_data[1]: red when the entry's second element is
# exactly a `Normal` (not a subclass), green otherwise.
colors = [
    "red" if type(entry[1]) is Normal else "green"
    for entry in train_data[1]
]

## Embedder

In [None]:
train_set = PITValuesDataset(*train_data)
test_set = PITValuesDataset(*test_data)

embedder = EmbedderDecoder(len(train_set), embed_dim=2, n_hiddens=10, output_dim=100)
# The module is never moved off the CPU in this notebook, so load the
# checkpoint tensors straight onto the CPU. This also lets a GPU-saved
# checkpoint load on a CPU-only machine.
embedder.load_state_dict(torch.load("models/lemon-vortex-535.pt", map_location="cpu"))
embedder

In [None]:
# Settings passed to `new_data_set` when fitting the test data (presumably
# batch size, learning rate and early-stopping patience; confirm in `pit`).
hyperparams = dict(bs=32, lr=0.1, patience=100)
# One embedding per training item, looked up by index.
train_embeds = embedder.embed(torch.arange(len(train_set)))
test_embedder = embedder.new_data_set(test_set, hyperparams)
test_embeds = test_embedder.embed(torch.arange(len(test_set)))

In [None]:
fig, ax, ax_true, ax_pred = get_grid()
ax.scatter(train_embeds[:, 0], train_embeds[:, 1], c=colors, alpha=0.5, picker=True)
ax.scatter(test_embeds[:, 0], test_embeds[:, 1], marker="*", s=100, c="black")
ax.set(
    title="Embedder: train (red = Normal, green = other), test (black *)",
    xlabel="embedding dim 0",
    ylabel="embedding dim 1",
)

# Click anywhere to decode that point into ax_pred; pick a training point to
# show its data in ax_true. `partial` binds the current objects, so later
# cells that rebind these names do not affect this figure's callbacks.
plot_function = partial(pit_hist, n_bins=20)
fig.canvas.mpl_connect(
    "button_press_event",
    partial(plot_pred, ax=ax_pred, model=embedder, plot_function=plot_function))
fig.canvas.mpl_connect(
    "pick_event",
    partial(plot_true, ax=ax_true, data_set=train_set, plot_function=plot_function));

## Encoder

In [None]:
# Number of histogram bins per data set: used as the encoder's input_dim and
# as the bin count passed to PITHistDataset below.
BINS = 20

In [None]:
encoder = EncoderDecoder(input_dim=BINS, n_hiddens=10, embed_dim=2)
# The module is never moved off the CPU in this notebook, so load the
# checkpoint tensors straight onto the CPU. This also lets a GPU-saved
# checkpoint load on a CPU-only machine.
encoder.load_state_dict(torch.load("models/stilted-snowflake-536.pt", map_location="cpu"))
encoder

In [None]:
# Rebind the data sets as BINS-bin histograms to match the encoder's input_dim.
# The Embedder figure's pick callback keeps its own reference (bound via
# functools.partial), so it still uses the previous data set.
train_set = PITHistDataset(*train_data, BINS)
test_set = PITHistDataset(*test_data, BINS)

In [None]:
# Encode every training and test histogram into the encoder's 2-D space.
train_embeds = encoder.embed(train_set.X)
test_embeds = encoder.embed(test_set.X)
(train_embeds.shape, test_embeds.shape)

In [None]:
fig, ax, ax_true, ax_pred = get_grid()
ax.scatter(train_embeds[:, 0], train_embeds[:, 1], c=colors, alpha=0.5, picker=True)
ax.scatter(test_embeds[:, 0], test_embeds[:, 1], marker="*", s=100, c="black")
ax.set(
    title="Encoder: train (red = Normal, green = other), test (black *)",
    xlabel="embedding dim 0",
    ylabel="embedding dim 1",
)

# Use BINS rather than a literal. Axes.stairs needs exactly len(values) + 1
# edges, and the plotted histograms have BINS values, so a hardcoded 20 would
# break as soon as BINS changes.
plot_function = partial(pit_stairs, n_bins=BINS)
fig.canvas.mpl_connect(
    "button_press_event",
    partial(plot_pred, ax=ax_pred, model=encoder, plot_function=plot_function))
fig.canvas.mpl_connect(
    "pick_event",
    partial(plot_true, ax=ax_true, data_set=train_set, plot_function=plot_function));