## Interactive Digit Generator

Use the sliders to modify the latent vector $z$ and watch how the reconstructed digit changes.  
This requires the [ipywidgets](https://ipywidgets.readthedocs.io/en/latest/user_install.html) notebook extension.

In [16]:
import torch
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from models import ConvVAE

In [17]:
# Load the saved, trained VAE model.
# map_location='cpu' lets weights that were saved from a GPU session load on
# a CPU-only machine; the model itself is built on the CPU below.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model_weights = torch.load('models/mnist_dc_vae.pt', map_location='cpu')
model = ConvVAE(latent_size=3)
model.load_state_dict(model_weights)
# The model is only used for inference here; eval() switches any
# train/eval-dependent layers (e.g. dropout, batch-norm) to inference mode.
model.eval()
print(model)

ConvVAE(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2))
  (conv3): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1))
  (encoder_fc_1): Linear(in_features=1600, out_features=512, bias=True)
  (encoder_mu): Linear(in_features=512, out_features=3, bias=True)
  (encoder_logvar): Linear(in_features=512, out_features=3, bias=True)
  (decoder_fc_1): Linear(in_features=3, out_features=512, bias=True)
  (decoder_fc_2): Linear(in_features=512, out_features=1600, bias=True)
  (deconv1): ConvTranspose2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (deconv2): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), output_padding=(1, 1))
  (deconv3): ConvTranspose2d(16, 1, kernel_size=(3, 3), stride=(1, 1))
)


In [45]:
%matplotlib inline
print('Change latent variable z by controlling the sliders:')

# The latent vector z of the model has three dimensions
z_slider = [widgets.FloatSlider(value=0, min=-4, max=4, step=0.02, continuous_update=True) for i in range(3)]
ui = widgets.HBox(z_slider)

def reconstruct(z_0, z_1, z_2, vae=None):
    """Decode the latent vector (z_0, z_1, z_2) into a 28x28 image.

    Args:
        z_0, z_1, z_2: The three components of the latent vector z.
        vae: Object with a ``decode`` method taking a (1, 3) float32 tensor.
            Defaults to the module-level ``model`` loaded in this notebook.

    Returns:
        A (28, 28) NumPy array holding the decoder output.
    """
    if vae is None:
        vae = model
    # A batch containing a single latent vector: shape (1, 3).
    z = torch.tensor([[z_0, z_1, z_2]], dtype=torch.float32)
    # Inference only: don't build an autograd graph on every slider move.
    with torch.no_grad():
        x_hat = vae.decode(z)
    # reshape (unlike view) also works if the decoder output is non-contiguous.
    return x_hat.reshape(28, 28).numpy()

def slider_change(z_0, z_1, z_2):
    """Draw the digit decoded from latent vector (z_0, z_1, z_2) in grayscale."""
    digit = reconstruct(z_0, z_1, z_2)
    plt.imshow(digit, cmap='gray')

# Bind slider i to keyword argument z_i of slider_change; interactive_output
# calls slider_change with the slider values and shows the result in `out`.
out = widgets.interactive_output(
    slider_change,
    {'z_{}'.format(idx): slider for idx, slider in enumerate(z_slider)},
)
display(ui, out)

Change latent variable z by controlling the sliders:


HBox(children=(FloatSlider(value=0.0, max=4.0, min=-4.0, step=0.02), FloatSlider(value=0.0, max=4.0, min=-4.0,…

Output()