In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from ipywidgets import *
import warnings

import torch
from torchvision.utils import make_grid

from models.semisup_vae import REVAE

warnings.filterwarnings("ignore")
%matplotlib inline

In [None]:
# Build the model and load the image tensor used throughout the notebook.
# NOTE(review): no checkpoint is passed here — presumably REVAE() restores
# pretrained weights itself; confirm, otherwise every result below comes from
# an untrained model.
revae = REVAE()
# map_location='cpu': images from `data` are handed straight to matplotlib
# later, which cannot read CUDA tensors, so keep them on the CPU regardless of
# the device they were saved from. torch.load unpickles the file, so only load
# trusted data. Presumably `data` is an (N, C, H, W) image tensor — confirm.
data = torch.load('./data/celeba.pt', map_location='cpu')

In [None]:
# Pick GRID_SIZE**2 distinct images at random and tile them into a labelled
# grid; the tick labels are the (row, column) coordinates the next cell asks for.
GRID_SIZE = 10  # tiles per row and per column
PADDING = 20    # pixels of padding make_grid puts around every tile


def _tile_centres(n_tiles, tile_px, pad_px, total_px, extent):
    """Return the centre of each tile along one grid axis, in imshow extent units.

    make_grid places tile k at pixel offset pad_px + k * (tile_px + pad_px), so
    its centre is that offset plus tile_px / 2; the whole axis spans total_px
    pixels, which imshow maps linearly onto [0, extent].
    """
    return [extent * (pad_px + k * (tile_px + pad_px) + tile_px / 2) / total_px
            for k in range(n_tiles)]


# torch.randperm draws without replacement, so no image appears twice
# (np.random.choice defaults to replace=True and could repeat images).
batch = data[torch.randperm(data.size(0))[:GRID_SIZE * GRID_SIZE]]
grid = make_grid(batch, nrow=GRID_SIZE, padding=PADDING, pad_value=1.)

# White spines for this figure only, instead of changing the global rcParams
# for every later plot (spine colours are read when the Axes is created).
with plt.rc_context({'axes.edgecolor': '#ffffff'}):
    fig, ax = plt.subplots(figsize=(14, 14))
ax.imshow(grid.permute(1, 2, 0), extent=[0, GRID_SIZE, 0, GRID_SIZE], origin='upper')

# Label every tile at its centre: columns 0.. left to right, rows 0.. top to
# bottom. origin='upper' puts row 0 at the top (y = GRID_SIZE), so the row
# centres measured from the top are flipped into y coordinates.
col_centres = _tile_centres(GRID_SIZE, batch.size(-1), PADDING, grid.size(-1), GRID_SIZE)
row_centres = _tile_centres(GRID_SIZE, batch.size(-2), PADDING, grid.size(-2), GRID_SIZE)
ax.set_xticks(col_centres)
ax.set_xticklabels(range(GRID_SIZE))
ax.set_yticks([GRID_SIZE - c for c in row_centres])
ax.set_yticklabels(range(GRID_SIZE))
ax.tick_params(axis='both', which='both', left=False, right=False,
               bottom=False, top=False, labelsize=14)

In [None]:
# get image coordinate from user
coords = input("Enter a comma seperated grid coordinate to obtain an image to intervene on, e.g 2,5: ")
img = batch[10 * int(coords[0]) + int(coords[2])]

# show selected image
%matplotlib inline
plt.clf()
plt.figure(figsize=(2,2))
plt.axis('off')
print('Chosen Image:')
plt.imshow(img.permute(1, 2, 0))

In [None]:
# get the reconstruction for the selected image
with torch.no_grad():
    recon = revae.reconstruct_img(img.unsqueeze(0))[0]

# setup interactive plot
%matplotlib notebook
fig = plt.figure(figsize = (10,4))
ax1 = fig.add_subplot(1, 2, 1)
ax2 = fig.add_subplot(1, 2, 2)
ax1.axis('off')
ax2.axis('off')
ax1.imshow(img.permute(1, 2, 0))
im = ax2.imshow(recon.permute(1, 2, 0))

# get initial latent value
z = revae._z_prior_fn(*revae.encoder_z(img.unsqueeze(0))).sample()

# define the sliders
spaced_vertical = Layout(display='flex', align_items='stretch', min_height='7ex', width='90%', margin='1ex')
widgets = []
for i in range(18):
    slider = FloatSlider(min=revae.lims[i][0], max=revae.lims[i][1], step=1.0, value=z[0, i].item(),
                        layout=spaced_vertical, style={'description_width': '20%'})
    widgets.append(slider)


def update(Arched_Eyebrows, Bags_Under_Eyes, Bangs,
         Black_Hair, Blond_Hair, Brown_Hair,
         Bushy_Eyebrows, Chubby, Eyeglasses,
         Heavy_Makeup, Male, No_Beard,
         Pale_Skin, Receding_Hairline, Smiling,
         Wavy_Hair, Wearing_Necktie, Young):
    """
    Write the slider values into the latent ``z`` and redraw its decoding.

    Called by ``interact`` with one keyword argument per attribute slider.
    The values overwrite ``z[0, :18]`` in place, so the global latent keeps
    the latest slider state. ``z`` is then decoded, and the result replaces
    the image shown by ``im`` in the interactive figure ``fig``.
    """
    slider_values = [Arched_Eyebrows, Bags_Under_Eyes, Bangs,
                     Black_Hair, Blond_Hair, Brown_Hair,
                     Bushy_Eyebrows, Chubby, Eyeglasses,
                     Heavy_Makeup, Male, No_Beard,
                     Pale_Skin, Receding_Hairline, Smiling,
                     Wavy_Hair, Wearing_Necktie, Young]
    # Build the values directly with z's dtype/device, so the in-place write
    # needs no implicit cast or cross-device copy.
    z[0, :len(slider_values)] = torch.tensor(slider_values, dtype=z.dtype, device=z.device)
    with torch.no_grad():
        # Select batch element 0 explicitly, rather than squeeze() dropping
        # every size-1 dim. Move to CPU because matplotlib cannot read GPU
        # tensors. Named `decoded` so the global selected `img` is not shadowed.
        decoded = revae.decoder(z)[0].cpu()
    im.set_data(decoded.permute(1, 2, 0))
    fig.canvas.draw()
    fig.canvas.flush_events()


# Pair each slider (in latent-dimension order) with the update() keyword
# argument of the same attribute name; interact() renders the sliders, labels
# them with those names, and calls update() whenever one moves.
interact(update, **dict(zip(
    ['Arched_Eyebrows', 'Bags_Under_Eyes', 'Bangs',
     'Black_Hair', 'Blond_Hair', 'Brown_Hair',
     'Bushy_Eyebrows', 'Chubby', 'Eyeglasses',
     'Heavy_Makeup', 'Male', 'No_Beard',
     'Pale_Skin', 'Receding_Hairline', 'Smiling',
     'Wavy_Hair', 'Wearing_Necktie', 'Young'],
    widgets,
)));