In [None]:
from functools import partial

from matplotlib import gridspec
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KDTree
import torch
import wandb

from pit import *

In [None]:
# Interactive (ipympl) backend: the button_press/pick callbacks registered with
# fig.canvas.mpl_connect further down only fire on an interactive canvas.
%matplotlib widget

In [None]:
# Scratch check of float semantics: inf * 0 evaluates to nan, inf * 1e-6 to inf.
# NOTE(review): leftover exploration; the output appears unused by later cells.
torch.inf * 0, torch.inf * 1e-6

In [None]:
# Number of histogram bins used by the PIT plots, by bin_data and as the encoder's
# input_dim — presumably must match what the saved encoder checkpoint was trained with.
BINS = 8

In [None]:
def get_grid():
    """Create a figure with one wide axes on top and two side-by-side axes below.

    Returns:
        (fig, ax, ax_true, ax_pred): the figure, the top axes spanning both
        columns, the bottom-left axes and the bottom-right axes.
    """
    figure = plt.figure(tight_layout=True)
    layout = gridspec.GridSpec(2, 2)
    top = figure.add_subplot(layout[0, :])
    # Bottom row, left column first, then right column.
    bottom_left, bottom_right = (figure.add_subplot(layout[1, col]) for col in (0, 1))
    return figure, top, bottom_left, bottom_right

def plot_pred(event, ax, model, plot_function, source_ax=None):
    """Decode the clicked 2-D point with ``model`` and draw the result on ``ax``.

    Intended to be bound with ``functools.partial`` and registered as a
    ``"button_press_event"`` callback.

    Args:
        event: matplotlib mouse event; ``xdata``/``ydata`` are ``None`` when the
            click lands outside every axes, in which case nothing happens.
        ax: axes that is cleared and redrawn with the decoded values.
        model: object with a ``decode`` method taking a ``(1, 2)`` tensor of
            embedding coordinates; the first row of its output is plotted.
        plot_function: ``f(ax, values, label=...)`` used to draw the values.
        source_ax: optional axes to listen to; when given, clicks inside any
            other axes (e.g. the output panels) are ignored.
    """
    # Check for None before converting: float(None) would raise TypeError.
    if event.xdata is None or event.ydata is None:
        return
    if source_ax is not None and event.inaxes is not source_ax:
        return
    x = float(event.xdata)
    y = float(event.ydata)
    # Inference only; no autograd graph is needed for plotting.
    with torch.no_grad():
        reconstruction = model.decode(torch.tensor([[x, y]], device=DEVICE))[0]
    ax.clear()
    plot_function(ax, reconstruction, label=f"({x:.4f}, {y:.4f})")
    ax.legend()
    # Redraw the figure that owns ``ax``; the global ``fig`` is rebound every
    # time get_grid() is called in a later cell.
    ax.figure.canvas.draw_idle()

def plot_true(event, ax, data, plot_function):
    """Draw the data set behind a picked scatter point on ``ax``.

    Intended to be bound with ``functools.partial`` and registered as a
    ``"pick_event"`` callback on the scatter of embeddings.

    Args:
        event: matplotlib pick event; ``event.ind`` holds the indices of the
            picked points and only the first one is used.
        ax: axes that is cleared and redrawn.
        data: indexable of ``(values, params)`` pairs aligned with the scatter
            points.
        plot_function: ``f(ax, values, label=...)`` used to draw ``values``.
    """
    idx = event.ind[0]
    values, params = data[idx]
    # NOTE(review): the legend label is carved out of repr(params) by dropping
    # the first 11 and last 2 characters and putting each "), "-separated field
    # on its own line; this depends on the exact repr format — confirm against
    # the params class.
    label = "\n".join(repr(params)[11:-2].split("), "))
    ax.clear()
    plot_function(ax, values, label=label)
    ax.legend()
    # Redraw the figure that owns ``ax``; the global ``fig`` is rebound every
    # time get_grid() is called in a later cell.
    ax.figure.canvas.draw_idle()

In [None]:
def pit_hist(ax, x, **kwargs):
    """Histogram raw PIT values ``x`` on ``ax`` over [0, 1] using ``BINS`` bins.

    Extra keyword arguments (e.g. ``label``) are forwarded to ``ax.hist``.
    """
    hist_options = dict(range=(0, 1), bins=BINS)
    ax.hist(x, **hist_options, **kwargs)

def pit_stairs(ax, x, **kwargs):
    """Draw pre-binned PIT histogram heights ``x`` as a step plot on ``ax``.

    The bins are equal-width on [0, 1]. Their count is taken from ``len(x)``
    rather than ``BINS``, so the edges always match the heights
    (``ax.stairs`` requires ``len(edges) == len(values) + 1``).

    Args:
        ax: matplotlib axes to draw on.
        x: 1-D sequence of bin heights, e.g. a histogram from ``bin_data`` or
            one decoded row.
        **kwargs: forwarded to ``ax.stairs`` (e.g. ``label``).
    """
    edges = np.linspace(0, 1, len(x) + 1)
    ax.stairs(x, edges, **kwargs)

In [None]:
# Synthetic corpus: a list of pairs whose first element is plotted as PIT values
# and whose second exposes ``.pis`` (presumably the generating parameters).
# TODO: document the generate_data arguments here; it is not visibly seeded.
data = generate_data(10, 1000)
# Fixed random_state keeps the 80/20 split reproducible.
train_data, test_data = train_test_split(data, test_size=0.2, random_state=86)

In [None]:
# Scatter colour per training data set: red when the first entry of ``pis`` is 1
# (presumably a single-component mixture), green otherwise.
colors = ["red" if math.isclose(a.pis[0], 1.0) else "green" for _, a in train_data]

In [None]:
# Hand-made test case: PIT values of 5000 draws from dist_true evaluated under
# dist_pred. The two are identical here, so (assuming pit() evaluates the
# predictive CDF at the samples) the PIT values should look roughly uniform.
dist_true = Normal(0, 1)
dist_pred = Normal(0, 1)
# NOTE(review): this rebinds ``test_data`` and discards the held-out split from
# the train_test_split cell. The NaN presumably stands in for the unknown
# generating params, to mirror the (values, params) pairs of train_data.
# Sampling is not seeded here — TODO seed for reproducibility.
test_data = [(pit(dist_pred, dist_true.sample(5000)), torch.nan)]

## Embedder

In [None]:
# One learned 2-D embedding per training data set, decoded to 512 outputs;
# the constructor arguments must match the saved checkpoint.
embedder = EmbedderDecoder(len(train_data), embed_dim=2, hiddens=8, output_dim=512)
# map_location lets a checkpoint saved on another device (e.g. a GPU) load on this
# machine; load_state_dict then copies the tensors into the model's parameters.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
embedder.load_state_dict(torch.load("models/dainty-sweep-8.pt", map_location=DEVICE))
embedder

In [None]:
# Fit an embedding for the hand-made test case with the trained embedder.
# NOTE(review): presumably new_data_set trains only the new embedding(s) with the
# given batch size / learning rate / early-stopping patience — confirm in pit.
hyperparams = {"bs": 32, "lr": 0.1, "patience": 1000}
new_embedder = embedder.new_data_set(PITValuesDataset(test_data), hyperparams)
embedding = new_embedder.embed(torch.tensor([0], device=DEVICE))
reconstruction = new_embedder.decode(embedding.to(DEVICE))
_, (ax1, ax2) = plt.subplots(nrows=2)
ax1.set_title("test PIT values")
pit_hist(ax1, test_data[0][0])
# decode presumably returns one row per input embedding (plot_pred and the encoder
# cell also index [0]); plot that single row, because ax.hist splits a 2-D array
# into one dataset per column.
ax2.set_title("reconstruction")
pit_hist(ax2, reconstruction[0])

In [None]:
# Embed every training data set and find the 3 training embeddings closest to the
# test embedding; the cell output lists their params.
embeddings = embedder.embed(torch.arange(len(train_data), device=DEVICE))
tree = KDTree(embeddings, leaf_size=2, metric="euclidean")
# Indices of the k nearest neighbours of the single query point.
js = tree.query(embedding, k=3, return_distance=False)[0]
[train_data[j][1] for j in js]

In [None]:
# Interactive view of the embedding space: training embeddings (coloured by
# ``colors``), the test embedding (black star) and its 3 nearest neighbours
# (yellow stars). Clicking inside an axes decodes that point into the
# bottom-right panel; picking a training point shows its data bottom-left.
fig, ax, ax_true, ax_pred = get_grid()
ax.scatter(embeddings[:, 0], embeddings[:, 1], c=colors, alpha=0.5, picker=True)
ax.scatter(embedding[:, 0], embedding[:, 1], marker="*", s=100, c="black")
ax.scatter(embeddings[js, 0], embeddings[js, 1], marker="*", s=100, c="yellow")

# NOTE(review): disabled hover variant; as written it would fail because
# plot_pred needs ax/model/plot_function bound via partial.
#fig.canvas.mpl_connect("motion_notify_event", plot_pred)
fig.canvas.mpl_connect("button_press_event", partial(plot_pred, ax=ax_pred, model=embedder, plot_function=pit_hist))
fig.canvas.mpl_connect("pick_event", partial(plot_true, ax=ax_true, data=train_data, plot_function=pit_hist))

## Encoder

In [None]:
# Histogram encoder/decoder: a BINS-bin histogram is encoded to a 2-D embedding
# and decoded back; the constructor arguments must match the saved checkpoint.
encoder = EncoderDecoder(input_dim=BINS, hiddens=64, embed_dim=2)
# map_location lets a checkpoint saved on another device (e.g. a GPU) load on this
# machine; load_state_dict then copies the tensors into the model's parameters.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
encoder.load_state_dict(torch.load("models/driven-sweep-439.pt", map_location=DEVICE))
encoder

In [None]:
# Histogram of the hand-made test case (first entry, first element — presumably
# the bin heights), its embedding, and the decoded reconstruction.
test_hist = bin_data(test_data, n_bins=BINS)[0][0]
embedding = encoder.encode(test_hist.unsqueeze(0).to(DEVICE))
reconstruction = encoder.decode(embedding.to(DEVICE))
# Top: input histogram; bottom: reconstruction of the single batch row.
_, (ax1, ax2) = plt.subplots(nrows=2)
pit_stairs(ax1, test_hist)
pit_stairs(ax2, reconstruction[0])

In [None]:
# (histogram, params) pairs for the training data sets — structure inferred from
# how the next cells unpack it; presumably in train_data order so ``colors`` aligns.
train_hists = bin_data(train_data, BINS)

In [None]:
# Encode every training histogram and find the 3 training embeddings closest to
# the test embedding; the cell output lists their params.
embeddings = encoder.encode(torch.stack([h for h, _ in train_hists]).to(DEVICE))
tree = KDTree(embeddings, leaf_size=2, metric="euclidean")
# Indices of the k nearest neighbours of the single query point.
js = tree.query(embedding, k=3, return_distance=False)[0]
[train_data[j][1] for j in js]

In [None]:
# Interactive view of the encoder's embedding space: training embeddings
# (coloured by ``colors``), the test embedding (black star) and its 3 nearest
# neighbours (yellow stars). Clicking inside an axes decodes that point into the
# bottom-right panel; picking a training point shows its histogram bottom-left.
fig, ax, ax_true, ax_pred = get_grid()
ax.scatter(embeddings[:, 0], embeddings[:, 1], c=colors, alpha=0.5, picker=True)
ax.scatter(embedding[:, 0], embedding[:, 1], marker="*", s=100, c="black")
ax.scatter(embeddings[js, 0], embeddings[js, 1], marker="*", s=100, c="yellow")

# NOTE(review): disabled hover variant; as written it would fail because
# plot_pred needs ax/model/plot_function bound via partial.
#fig.canvas.mpl_connect("motion_notify_event", plot_pred)
fig.canvas.mpl_connect("button_press_event", partial(plot_pred, ax=ax_pred, model=encoder, plot_function=pit_stairs))
fig.canvas.mpl_connect("pick_event", partial(plot_true, ax=ax_true, data=train_hists, plot_function=pit_stairs))