In [1]:
import torch
import pandas as pd
%matplotlib widget
import matplotlib.pyplot as plt
import matplotlib
from tqdm import tqdm, trange
import ipywidgets as widgets
from ipywidgets import HBox, VBox, interactive
import numpy as np
from IPython.display import display
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
import gc
import seaborn as sns

from feature_vae import FeatureVAE

ModuleNotFoundError: No module named 'ipympl'

In [None]:
# Load the complete pickled FeatureVAE module (not a state_dict); unpickling
# needs the `feature_vae` module to be importable.
# - map_location='cpu' lets the notebook run on machines without a GPU and
#   matches the `.cpu()` / numpy-based plotting done in later cells.
# - weights_only=False is required on torch>=2.6, whose weights-only default
#   refuses to unpickle whole nn.Module objects. This executes pickle code, so
#   only load checkpoints you trust.
MODEL_PATH = 'models/1-featurevae-20dim'
model = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
model.eval()  # inference mode for the encoding / generation cells below
z_dim = model.z_dim
num_features = model.num_features
shape_z = model.random_z  # starting values for the per-dimension latent sliders

In [None]:
valid = torch.load('valid_shapes.pt')
valid_meta = pd.read_csv('valid_meta.csv')
# The encoding cell slices `valid` and `valid_meta` with the same indices, so
# both must describe the same samples in the same order.
assert len(valid) == len(valid_meta), (
    f'{len(valid)} shapes but {len(valid_meta)} metadata rows'
)
valid_meta.head()

In [None]:
# Encode the validation set in mini-batches, conditioning each sample on its
# RGB colour from the metadata. The last element of the model's output tuple is
# kept as the latent code used by the t-SNE cell.
# NOTE(review): assumes model(...) returns a 4-tuple ending in z — confirm in FeatureVAE.
BATCH_SIZE = 100
colors = valid_meta[['r', 'g', 'b']].to_numpy(dtype=np.float32)
z_batches = []
with torch.no_grad():
    for start in range(0, len(valid), BATCH_SIZE):
        stop = start + BATCH_SIZE
        batch = valid[start:stop].contiguous()
        labels = torch.as_tensor(colors[start:stop], dtype=torch.float32)
        _, _, _, z = model(batch, labels)
        # Keep results on the CPU so t-SNE / numpy can consume them directly.
        z_batches.append(z.cpu())
valid_z = torch.vstack(z_batches)
valid_z.shape

In [None]:
# Project the latent codes to 2-D with t-SNE and colour each point by its
# metadata colour name, drawn with matplotlib's "xkcd:" named colours.
TSNE_SEED = 0  # fixed so the embedding is reproducible across runs
try:
    # scikit-learn >= 1.5 names the iteration budget `max_iter`
    # (`n_iter` was deprecated in 1.5 and later removed).
    tsne = TSNE(n_components=2, max_iter=300, random_state=TSNE_SEED)
except TypeError:
    # Older scikit-learn only accepts `n_iter`.
    tsne = TSNE(n_components=2, n_iter=300, random_state=TSNE_SEED)
tsne_results = tsne.fit_transform(valid_z.cpu().numpy())

color_dict = {color: 'xkcd:' + color for color in valid_meta.color.unique()}
fig, ax = plt.subplots(figsize=(10, 10))
sns.scatterplot(
    x=tsne_results[:, 0], y=tsne_results[:, 1],
    hue=valid_meta.color,  # alternative view: hue=valid_meta['shape']
    palette=color_dict,
    ax=ax,
)
ax.set(title='A T-SNE plot of Latent Dims');

In [None]:
# Interactive controls for exploring the decoder, grouped by the tab they live in.

# "Color Picker" tab: the colour fed into the first entries of the latent vector.
color_picker = widgets.ColorPicker(
    concise=False,
    description='Pick a color',
    value='black',
    disabled=False,
)

# "Initial Point" tab: one vertical slider per latent dimension, seeded from
# `shape_z`. Each slider is labelled with its index in the combined vector,
# i.e. offset by `num_features` because the colour entries come first.
z_slides = [
    widgets.FloatSlider(
        value=shape_z[dim_idx],
        min=-5,
        max=5,
        step=0.1,
        description=f'dim {dim_idx+num_features}',
        readout_format='.1f',
        orientation='vertical',
        continuous_update=False,
    )
    for dim_idx in range(valid_z.size(1))
]

# "Vary Settings" tab: which dimension to sweep, over what range, in how many images.
dim_select_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=z_dim-1,
    description='Latent Dim',
    continuous_update=False,
)
dim_range_slider = widgets.FloatRangeSlider(
    value=[-4., 4.],
    min=-10., max=10., step=0.1,
    description='Range',
    readout_format='.1f',
    continuous_update=False,
)
num_im_slider = widgets.IntSlider(
    value=5,
    min=2,
    max=11,
    description='Num Images',
    continuous_update=False,
)

# Assemble the tab container; titles are assigned by child position.
tab3 = HBox(children=[color_picker])
tab1 = HBox(children=z_slides)
tab2 = VBox(children=[dim_select_slider, dim_range_slider, num_im_slider])
tab = widgets.Tab(children=[tab3, tab1, tab2])
tab.set_title(0, 'Color Picker')
tab.set_title(1, 'Initial Point')
tab.set_title(2, 'Vary Settings')

# pyplot label for the traversal figure. Each redraw closes the previous figure
# with this label instead of leaving one open figure behind per widget change.
_TRAVERSAL_FIG_LABEL = 'latent-traversal'


def simple_plot(dim_range, dim_select, num_ims, color, **z_slides):
    """Plot a row of decoded images while sweeping one entry of the latent vector.

    The starting vector is the RGB value of ``color`` followed by the latent
    slider values. ``num_ims`` copies are made and column ``dim_select`` is
    replaced by an evenly spaced sweep over ``dim_range``. The middle image keeps
    the starting value, so it shows the initial point.

    Args:
        dim_range: ``(low, high)`` pair from the range slider.
        dim_select: index into the combined (colour + latent) vector to vary.
        num_ims: number of images in the sweep.
        color: any matplotlib colour spec (e.g. the picker's initial ``'black'``).
        **z_slides: latent slider values keyed ``'0'``, ``'1'``, ... in order.
    """
    with torch.no_grad():
        latent_values = [z_slides[f'{i}'] for i in range(len(z_slides))]
        start_vector = list(matplotlib.colors.to_rgb(color)) + latent_values
        init_z = torch.Tensor([start_vector] * num_ims)

        # Sweep the chosen column but keep the starting value in the middle slot.
        middle_idx = num_ims // 2
        middle = init_z[middle_idx, dim_select].item()
        init_z[:, dim_select] = torch.linspace(*dim_range, num_ims)
        init_z[middle_idx, dim_select] = middle

        # The first `num_features` columns are the colour conditioning; the rest is the latent code.
        colors, init_za = init_z[:, :num_features], init_z[:, num_features:]

        # NOTE(review): permute assumes generate() returns (N, C, H, W); imshow wants (H, W, C).
        ims = model.generate(init_za, colors).permute(0, 2, 3, 1).cpu()

    # Close the figure from the previous call. Otherwise backends that do not
    # auto-close on show (e.g. ipympl's `widget`) accumulate open figures.
    plt.close(_TRAVERSAL_FIG_LABEL)
    # squeeze=False keeps `axes` 2-D, so a single image also works.
    fig, axes = plt.subplots(
        1, num_ims, num=_TRAVERSAL_FIG_LABEL, figsize=(10, 2), squeeze=False
    )
    for i, ax in enumerate(axes[0]):
        ax.imshow(ims[i])
        ax.set_title(f'dim {dim_select}={init_z[i, dim_select].item():.2f}')
        ax.title.set_size(10)
    plt.show()

# Map each simple_plot parameter to the widget that drives it. The latent
# sliders are keyed by position ('0', '1', ...), which simple_plot reads back
# through its **z_slides argument.
d = {
    'dim_range': dim_range_slider,
    'dim_select': dim_select_slider,
    'num_ims': num_im_slider,
    'color': color_picker,
}
d.update({str(slider_idx): slider for slider_idx, slider in enumerate(z_slides)})
out = widgets.interactive_output(simple_plot, d)

In [None]:
# Show the live plot output stacked above the tabbed controls.
widgets.VBox(children=[out, tab])