# Variational autoencoders for structure parameterization

## Setup

In [None]:
import sys

if "google.colab" in sys.modules:
    from google.colab import output

    output.enable_custom_widget_manager()
    if "mlatint" not in sys.modules:
        !sudo apt install libcairo2-dev pkg-config python3-dev
        !pip install git+https://github.com/yaugenst/mlatint

In [None]:
# Interactive (Jupyter-widget) matplotlib backend: the widget callbacks below
# update figures that are already displayed.
%matplotlib ipympl
import torch
import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from mlatint import VAE, Sampler

# Only encode/decode calls follow and no gradients are needed, so turn autograd
# off globally. The trailing ";" stops IPython from displaying the return value.
torch.set_grad_enabled(False);

## Visualizing geometry samplers

We will start by visualizing two types of geometries that we can generate using our pre-defined samplers.

In [None]:
# One geometry sampler per shape family.
blobs, triangles = Sampler("blobs"), Sampler("triangles")

In [None]:
plt.close("all")
fig, ax = plt.subplots(1, 2, figsize=(9, 4), sharey=True)

ax[0].set_title("Blobs")
ax[1].set_title("Triangles")
im1 = ax[0].imshow(blobs.sample(), cmap="gray_r")
im2 = ax[1].imshow(triangles.sample(), cmap="gray_r")
plt.show()


def update(b=None, *, _im_blobs=im1, _im_triangles=im2):
    """Draw a fresh sample from each sampler into this cell's two images.

    The images are bound as keyword-only defaults at definition time: later
    cells rebind the global names ``im1``/``im2`` to other figures' images,
    and this cell's button must keep updating its own plot.
    """
    _im_blobs.set_data(blobs.sample())
    _im_triangles.set_data(triangles.sample())


button = widgets.Button(description="Sample")
button.on_click(update)
button

## Geometry reconstruction with VAEs

Now, we will test our VAEs to see how well they can reconstruct input geometries. Play around with different combinations of samplers and VAEs!

In [None]:
# One VAE per geometry family, selected by the same names as the samplers.
vae_blobs, vae_triangles = VAE("blobs"), VAE("triangles")

In [None]:
plt.close("all")
fig, ax = plt.subplots(1, 2, figsize=(9, 4), sharey=True)

ax[0].set_title("Input")
ax[1].set_title("Reconstruction")

_init = blobs.sample()
im1 = ax[0].imshow(_init, cmap="gray_r")
im2 = ax[1].imshow(vae_blobs(_init), cmap="gray_r")
plt.show()

radio_sampler = widgets.RadioButtons(
    options=["Blobs", "Triangles"], description="Geometry", disabled=False
)

radio_model = widgets.RadioButtons(
    options=["Blobs", "Triangles"], description="VAE model", disabled=False
)


def update(_, *, _im_in=im1, _im_out=im2):
    """Sample a geometry and show it next to the selected VAE's reconstruction.

    The sampler and the VAE are chosen independently, so a VAE can be fed
    geometries from the other family. The two images are bound as defaults at
    definition time because later cells rebind the globals ``im1``/``im2``.
    """
    geom_val = radio_sampler.get_interact_value()
    model_val = radio_model.get_interact_value()

    sampler = blobs if geom_val == "Blobs" else triangles
    model = vae_blobs if model_val == "Blobs" else vae_triangles

    sample = sampler.sample()

    _im_in.set_data(sample)
    _im_out.set_data(model(sample))


button = widgets.Button(description="Sample")
button.on_click(update)

widgets.HBox([radio_sampler, radio_model, button])

## Latent space representation

A VAE's defining feature is its latent space, so let's look at what the latent representation of a geometry looks like: its individual latent values and their histogram.

In [None]:
# Encode one blob sample and decode it back; the next cell plots all three.
x = blobs.sample()
x_encoded = vae_blobs.encode(x)  # latent representation of x
x_decoded = vae_blobs.decode(x_encoded)  # reconstruction from the latent code

In [None]:
plt.close("all")
fig = plt.figure(figsize=(8, 3), tight_layout=True)
gs = gridspec.GridSpec(2, 8)

ax_in = fig.add_subplot(gs[:, :2])
ax_out = fig.add_subplot(gs[:, -2:])
ax_latent = fig.add_subplot(gs[0, 2:-2])
ax_hist = fig.add_subplot(gs[1, 2:-2])

ax_in.set_title("Input")
ax_out.set_title("Reconstruction")
ax_in.axis("off")
ax_out.axis("off")

ax_latent.set_title("Latent values")
ax_hist.set_title("Latent hist")

im1 = ax_in.imshow(x, cmap="gray_r")
im2 = ax_out.imshow(x_decoded, cmap="gray_r")
ax_latent.plot(x_encoded, "k.")
ax_hist.hist(x_encoded, 32)
plt.show()

radio = widgets.RadioButtons(
    options=["Blobs", "Triangles"], description="Geometry", disabled=False
)


def update(
    _,
    *,
    _radio=radio,
    _im_in=im1,
    _im_out=im2,
    _ax_latent=ax_latent,
    _ax_hist=ax_hist,
):
    """Encode/decode a new sample with the matching VAE and redraw the figure.

    The widget, images and axes are bound as defaults at definition time
    because later cells rebind the globals ``radio``, ``im1``, ``im2`` and
    ``ax_latent``; without this, this cell's button would read another cell's
    radio buttons and draw into another cell's figure.
    """
    geom_val = _radio.get_interact_value()

    sampler = blobs if geom_val == "Blobs" else triangles
    model = vae_blobs if geom_val == "Blobs" else vae_triangles

    x = sampler.sample()
    x_encoded = model.encode(x)
    x_decoded = model.decode(x_encoded)

    _im_in.set_data(x)
    _im_out.set_data(x_decoded)
    # cla() wipes the titles too, so they are re-applied after replotting.
    _ax_latent.cla()
    _ax_hist.cla()
    _ax_latent.plot(x_encoded, "k.")
    _ax_hist.hist(x_encoded, 32)
    _ax_latent.set_title("Latent values")
    _ax_hist.set_title("Latent hist")


button = widgets.Button(description="Sample")
button.on_click(update)

widgets.HBox([radio, button])

## Direct latent space sampling

The cool thing about these latent spaces is that, to generate new geometries, we don't need the encoder at all!
We can just sample a latent vector directly and have the decoder create a geometry from it.

In [None]:
# 256 i.i.d. normal entries (mean 0.8, std 0.4) to feed straight to a decoder.
latent_vector = torch.empty(256, dtype=torch.float32).normal_(mean=0.8, std=0.4)

In [None]:
plt.close("all")
fig, ax = plt.subplots(1, 2, figsize=(9, 4))

ax[0].set_title("Latent hist")
ax[1].set_title("Decoded")

ax[0].hist(latent_vector, 32)
im1 = ax[1].imshow(vae_blobs.decode(latent_vector), cmap="gray_r")
plt.show()

radio = widgets.RadioButtons(
    options=["Blobs", "Triangles"], description="VAE model", disabled=False
)


def update(_, *, _radio=radio, _ax_hist=ax[0], _im=im1):
    """Decode a freshly drawn latent vector with the selected VAE.

    No encoder is involved: the vector is drawn directly from the same normal
    distribution as ``latent_vector``. The widget, axis and image are bound as
    defaults at definition time because later cells rebind the globals
    ``radio`` and ``im1``.
    """
    model = vae_blobs if _radio.get_interact_value() == "Blobs" else vae_triangles

    z = torch.empty(256, dtype=torch.float32).normal_(0.8, 0.4)

    # cla() wipes the title too, so it is re-applied before replotting.
    _ax_hist.cla()
    _ax_hist.set_title("Latent hist")
    _ax_hist.hist(z, 32)
    _im.set_data(model.decode(z))


button = widgets.Button(description="Sample")
button.on_click(update)

widgets.HBox([radio, button])

## Latent space interpolation

Latent vectors enable us to smoothly interpolate between geometries without an explicit parameterization.
Here, we will sample two geometries, encode them into latent vectors, and interpolate linearly between those vectors.

In [None]:
# Two random blob geometries (the interpolation endpoints), their latent codes
# and their reconstructions.
x1, x2 = blobs.sample(), blobs.sample()
x1_encoded, x2_encoded = (vae_blobs.encode(g) for g in (x1, x2))
x1_decoded, x2_decoded = (vae_blobs.decode(z) for z in (x1_encoded, x2_encoded))

In [None]:
plt.close("all")
fig = plt.figure(figsize=(8, 3), tight_layout=True)
gs = gridspec.GridSpec(2, 8)

ax_in = fig.add_subplot(gs[0, :2])
ax_out = fig.add_subplot(gs[1, :2])
ax_interp = fig.add_subplot(gs[:, -2:])
ax_latent = fig.add_subplot(gs[:, 2:-2])

ax_in.set_title("Input")
ax_out.set_title("Target")
ax_interp.set_title("Interpolation")
ax_in.axis("off")
ax_out.axis("off")
ax_interp.axis("off")


def _plot_interp_latents(ax, z_input, z_target, z_interp):
    """Redraw the input, target and interpolated latent vectors on ``ax``.

    Clears ``ax`` first, so it can be called repeatedly from widget callbacks.
    """
    ax.cla()
    ax.plot(
        z_input, c="tab:red", ls=" ", marker=".", ms=3, alpha=0.6, label="Input"
    )
    ax.plot(
        z_target, c="tab:blue", ls=" ", marker=".", ms=4, alpha=0.6, label="Target"
    )
    ax.plot(
        z_interp, c="k", ls=" ", marker="x", ms=4, alpha=0.8, label="Interpolation"
    )
    ax.legend(ncols=3)
    ax.set_title("Latent values")


im1 = ax_in.imshow(x1, cmap="gray_r")
im2 = ax_out.imshow(x2, cmap="gray_r")
# At ratio 0 the interpolation panel shows decode(x1_encoded), i.e. the
# reconstruction, so start from that instead of the raw input. Sharing im1's
# norm keeps the color scaling autoscaled from x1, as the panel had before.
im3 = ax_interp.imshow(x1_decoded, cmap="gray_r", norm=im1.norm)
_plot_interp_latents(ax_latent, x1_encoded, x2_encoded, x1_encoded)
plt.show()

radio = widgets.RadioButtons(
    options=["Blobs", "Triangles"], description="Geometry", disabled=False
)

slider = widgets.FloatSlider(
    value=0,
    min=0,
    max=1,
    step=0.01,
    description="Ratio",
    disabled=False,
    continuous_update=True,
    orientation="horizontal",
    readout=True,
    readout_format=".2f",
)

# Shared callback state: the current pair of latent codes and the VAE that
# produced them. update() replaces these on every new sample.
_state = {
    "x1_encoded": x1_encoded,
    "x2_encoded": x2_encoded,
    "model": vae_blobs,
}


def interpolate(_):
    """Decode the point at the slider's ratio on the line between both codes.

    Ratio 0 is the input's latent code, ratio 1 the target's.
    """
    ratio = float(slider.get_interact_value())

    interp = (1 - ratio) * _state["x1_encoded"] + ratio * _state["x2_encoded"]
    im3.set_data(_state["model"].decode(interp))
    _plot_interp_latents(
        ax_latent, _state["x1_encoded"], _state["x2_encoded"], interp
    )


def update(_):
    """Sample a new input/target pair of the selected geometry and redraw.

    Both samples are encoded with the matching VAE; the interpolation panel and
    latent plot are then redrawn at the slider's current ratio.
    """
    geom_val = radio.get_interact_value()

    sampler = blobs if geom_val == "Blobs" else triangles
    model = vae_blobs if geom_val == "Blobs" else vae_triangles

    x1 = sampler.sample()
    x2 = sampler.sample()

    _state["x1_encoded"] = model.encode(x1)
    _state["x2_encoded"] = model.encode(x2)
    _state["model"] = model

    im1.set_data(x1)
    im2.set_data(x2)
    interpolate(None)


# names="value": react only to the slider's value, not to every trait change.
slider.observe(interpolate, names="value")
button = widgets.Button(description="Sample")
button.on_click(update)

widgets.HBox([radio, button, slider])