In [99]:
import torch
import sys
import numpy
import matplotlib.pyplot as plt
import ipywidgets as widgets

from ipywidgets.widgets.interaction import show_inline_matplotlib_plots

from scipy.stats import multivariate_normal

sys.path.append('../')

from src.cars.model import CarsConvVAE

In [21]:
# Set some hyperparameters
LATENT_DIM = 5                           # latent size; must match the architecture of the saved checkpoint
MODEL_NAME = "../vae_model_cars_gpu.h5"  # PyTorch state-dict file (read with torch.load despite the .h5 suffix)
NB_SIM = 9                               # number of latent vectors sampled
SEED = 0                                 # seeds the latent draws so Restart & Run All reproduces the figures

# Seed numpy's global generator, which the sampling cell uses.
numpy.random.seed(SEED)

# Latent samples are drawn from N(0, I).
MU = numpy.zeros(LATENT_DIM)
SIGMA = numpy.eye(LATENT_DIM)

Next, rebuild the VAE and load its trained weights from `MODEL_NAME`.

In [18]:
# Rebuild the architecture and restore the trained weights.
# The file name suggests the checkpoint was saved from a GPU run; map_location="cpu"
# lets it load on a CPU-only machine (nothing in this notebook moves the model
# to another device, and the latent samples below are CPU tensors).
model = CarsConvVAE(LATENT_DIM)
load_report = model.load_state_dict(torch.load(MODEL_NAME, map_location="cpu"))

# Inference mode: layers such as dropout / batch-norm, if the architecture has
# any (not visible here), must not use training-time behaviour while decoding.
model.eval()

load_report

<All keys matched successfully>

In [79]:
# Draw NB_SIM latent vectors z ~ N(MU, SIGMA) and convert them to a float32
# tensor; the widgets below shift one coordinate of these base samples.
simulations = torch.as_tensor(
    numpy.random.multivariate_normal(mean=MU, cov=SIGMA, size=NB_SIM),
    dtype=torch.float32,
)

torch.Size([9, 3, 224, 224])

In [81]:
def plot_examples(dt_decoded, nrows=3, ncols=3):
    """Plot a batch of decoded images on an ``nrows`` x ``ncols`` grid.

    Parameters
    ----------
    dt_decoded : torch.Tensor
        Batch of decoded images indexed along the first dimension. Each item is
        shown after swapping its dims 0 and 2 (the same axis swap the original
        code used). This assumes each item is a 3-D channel-first tensor whose
        swapped layout is what ``imshow`` expects -- TODO confirm against the
        training data loader.
    nrows, ncols : int, optional
        Grid shape. The defaults reproduce the original 3x3 layout.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, so callers can display, save, or close it.

    Notes
    -----
    The input tensor is not modified: the transpose is out-of-place. If there
    are fewer images than grid cells, the remaining axes are left blank.
    """
    fig, axs = plt.subplots(nrows, ncols, squeeze=False)
    axs = axs.ravel()

    n_images = min(len(dt_decoded), axs.size)
    for i, ax in enumerate(axs):
        ax.axis('off')
        if i < n_images:
            # Out-of-place transpose; detach/cpu so .numpy() works for any tensor.
            image = dt_decoded[i].detach().cpu().transpose(0, 2).numpy()
            ax.imshow(image)

    # Lay out after the axes are populated.
    fig.tight_layout()
    return fig

In [100]:
# Interactive explorer: pick one latent dimension, shift it by a constant offset
# for every sample, then decode and plot the result.
dropdown_choices = list(range(LATENT_DIM))

latent_dropdown = widgets.Dropdown(options=dropdown_choices)
latent_slider = widgets.FloatSlider(value=0, min=-5, max=5, step=0.1)

output = widgets.Output()

def change_noise(dim, noise):
    """Decode ``simulations`` with latent dimension ``dim`` shifted by ``noise``.

    Parameters
    ----------
    dim : int
        Index of the latent coordinate (column of ``simulations``) to shift.
    noise : float
        Constant added to that coordinate for every sample.

    The base ``simulations`` tensor is left unchanged. The decoded batch is
    plotted with ``plot_examples`` and rendered into the ``output`` widget.
    """
    # clone() copies all latent coordinates. Tensor.new(*size) allocates
    # uninitialised memory, which left every column except `dim` as garbage.
    new_simulations = simulations.clone()
    new_simulations[:, dim] += noise

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        decoded_simulations = model.decode(new_simulations)

    plot_examples(decoded_simulations)

    with output:
        # wait=True clears only when the new figure arrives, avoiding flicker.
        output.clear_output(wait=True)
        show_inline_matplotlib_plots()

def latent_dropdown_eventhandler(change):
    """Re-render when a different latent dimension is selected."""
    change_noise(change.new, latent_slider.value)

def latent_slider_eventhandler(change):
    """Re-render when the offset slider moves."""
    change_noise(latent_dropdown.value, change.new)

latent_dropdown.observe(latent_dropdown_eventhandler, names='value')
latent_slider.observe(latent_slider_eventhandler, names='value')

display(latent_dropdown, latent_slider, output)

# Show the unperturbed samples right away instead of an empty output area.
change_noise(latent_dropdown.value, latent_slider.value)

Dropdown(options=(0, 1, 2, 3, 4), value=0)

FloatSlider(value=0.0, max=5.0, min=-5.0)

Output()