In [None]:
import os
import time
import numpy as np
import tensorflow as tf

import config
import tfutil
import misc
import random
import string
from PIL import Image
import glob

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

import matplotlib.pyplot as plt
import matplotlib
%matplotlib notebook

In [None]:
# Initialise the TensorFlow session using the project-level configuration.
tfutil.init_tf(config.tf_config)

# Path to the trained network snapshot to load. Point this at your own run;
# the expected layout is:
#   "./results/path-to-training/network-snapshot-??????.pkl"
resume_network_pkl = "./results/023-pgan-mnist-cond-preset-v2-1gpu-fp32-nogrowing-VERBOSE/network-snapshot-000468.pkl"

# NOTE(review): misc.load_pkl presumably unpickles the file, which can execute
# arbitrary code -- only load snapshots from a source you trust.
with tf.device('/gpu:0'):
    G, D, Gs = misc.load_pkl(resume_network_pkl)

# Last dimension of the generator output; used below as the side length that
# mask images are resized to before being fed to the network.
imsize = Gs.output_shape[-1]

# Sample Random Shapes

In [None]:
# Paths of all candidate mask images. glob returns matches in arbitrary
# (filesystem-dependent) order, so sort them: index-based sampling from this
# list is then reproducible for a given random seed.
mask_list = sorted(glob.glob("./dataset/masks/*.png"))

def get_random_mask(batch_size):
    """Sample `batch_size` masks (with replacement) from `mask_list`.

    Each selected image is loaded as 8-bit grayscale, resized to
    `imsize` x `imsize` and rescaled from [0, 255] to [-1, 1].

    Parameters
    ----------
    batch_size : int
        Number of masks to draw.

    Returns
    -------
    numpy.ndarray
        float32 array of shape (batch_size, 1, imsize, imsize).

    Raises
    ------
    ValueError
        If `mask_list` is empty (no mask files were found).
    """
    if not mask_list:
        raise ValueError("mask_list is empty: no mask images found to sample from")

    ix = np.random.randint(len(mask_list), size=(batch_size,))

    # Pre-allocating also makes batch_size == 0 return an empty batch instead
    # of failing when stacking an empty list.
    random_masks = np.empty((len(ix), 1, imsize, imsize), dtype=np.float32)
    for slot, i in enumerate(ix):
        # The context manager closes the underlying file handle promptly.
        with Image.open(mask_list[i]) as img:
            # Force a single 8-bit channel so RGB/RGBA/palette/1-bit PNGs all
            # produce an (imsize, imsize) array with values in [0, 255].
            temp = img.convert("L").resize((imsize, imsize))
        random_masks[slot, 0] = (np.float32(temp) - 127.5) / 127.5

    return random_masks

def get_random_color(batch_size):
    """Draw `batch_size` random 3-component colors, uniform in [-1, 1).

    Returns an array of shape (batch_size, 3).
    """
    unit_samples = np.random.rand(batch_size, 3)  # uniform in [0, 1)
    return unit_samples * 2 - 1

def convert_to_image(x):
    """Turn a 4-D network output into displayable images in [0, 1].

    Axis 1 is moved to the last position (presumably NCHW -> NHWC), values
    are clipped to [-1, 1] and then mapped linearly onto [0, 1].
    """
    channels_last = x.transpose((0, 2, 3, 1))
    clipped = channels_last.clip(-1, 1)
    return clipped * 0.5 + 0.5

# Interactive Color/Texture Adjustment

In [None]:
%matplotlib notebook

# Fixed inputs for the interactive demo: one random mask and the two latent
# vectors that the slider interpolates between.
selected_shape = get_random_mask(1)
z1 = misc.random_latents(1, Gs)
z2 = misc.random_latents(1, Gs)

# Build a 1 x N figure. `ax` holds the axes and `art` the image artists so
# the callback can update pixels in place instead of re-plotting.
N = 2
fig = plt.figure(figsize=(N * 6, 6))
ax = []
art = []
for i in range(N):
    panel = fig.add_subplot(1, N, i + 1)
    panel.axis('off')
    ax.append(panel)
    # 1x1 placeholder image; real data is set by the callback.
    art.append(panel.imshow(np.zeros((1, 1))))

def f(r, z):
    """Widget callback: regenerate the image for the chosen color and texture.

    Parameters
    ----------
    r : str
        Color from the color picker. Any color spec matplotlib understands
        is accepted: '#rrggbb', '#rgb', or a CSS color name such as 'red'.
    z : float
        Interpolation weight between the latent vectors `z1` (z=0) and
        `z2` (z=1).

    Raises
    ------
    ValueError
        If `r` is not a recognisable color.

    Side effects: updates the two image artists in `art` and the title of
    `ax[0]` with the measured inference time.
    """
    from matplotlib.colors import to_rgb

    # to_rgb yields floats in [0, 1]; map them to [-1, 1], which is the same
    # scaling as (byte - 127.5) / 127.5 on 0..255 channel values.
    selected_color = np.float32([to_rgb(r)]) * 2.0 - 1.0

    # Linear interpolation in latent space.
    selected_texture = z1 + z * (z2 - z1)

    # perf_counter is a monotonic high-resolution clock, suited to intervals.
    st = time.perf_counter()
    GI = Gs.run(selected_texture, selected_color, selected_shape)
    et = time.perf_counter()

    GI = convert_to_image(GI)

    art[0].set_array(GI[0])
    ax[0].set_title(f"Inference Time: {(et-st)*1000:.0f} ms")
    art[0].autoscale()  # rescale color limits to the new data

    art[1].set_array(selected_shape[0,0])
    art[1].autoscale()
    
# Bind the callback to a color picker (argument `r`) and a slider in [0, 1]
# (argument `z`).
interactive_plot = interactive(
    f,
    r=widgets.ColorPicker(
        concise=False,
        description='Pick a color',
        value='#aa00cc',
        disabled=False,
    ),
    z=widgets.FloatSlider(min=0.0, max=1.0, step=0.1, value=0.0),
)

# The last child of the container is its output area; give it a fixed height.
output = interactive_plot.children[-1]
output.layout.height = '500px'

interactive_plot

# Change Color

In [None]:
# Hold texture and shape fixed (one sample repeated three times) and vary
# only the color.
selected_textures = misc.random_latents(1, Gs).repeat(3, 0)
selected_shapes = get_random_mask(1).repeat(3, 0)
selected_colors = get_random_color(3)

fake_images = convert_to_image(
    Gs.run(selected_textures, selected_colors, selected_shapes)
)

# Panels 1-3: generated images; panel 4: the shared input mask.
plt.figure(figsize=(12, 4))
for panel_idx, image in enumerate(fake_images, start=1):
    plt.subplot(1, 4, panel_idx)
    plt.imshow(image)
    plt.axis('off')

plt.subplot(1, 4, 4)
plt.imshow(selected_shapes[0, 0], cmap='gray', vmin=0.0, vmax=1.0)
plt.axis('off')

# Change Texture

In [None]:
# Hold color and shape fixed (one sample repeated three times) and vary only
# the texture latent.
selected_textures = misc.random_latents(3, Gs)

selected_colors = get_random_color(1).repeat(3, 0)

selected_shapes = get_random_mask(1).repeat(3, 0)

fake_images = Gs.run(selected_textures, selected_colors, selected_shapes)

fake_images = convert_to_image(fake_images)

# Panels 1-3: generated images; panel 4: the shared input mask.
plt.figure(figsize=(12,4))
for i in range(3):
    plt.subplot(1,4,i+1)
    plt.imshow(fake_images[i])
    plt.axis('off')
plt.subplot(1,4,4)
# All three masks are identical, so index 0 explicitly rather than relying on
# the loop variable leaking out of the loop; render in grayscale with the
# same limits as the other comparison cells.
plt.imshow(selected_shapes[0, 0], cmap='gray', vmin=0.0, vmax=1.0)
plt.axis('off')

# Change Shape

In [None]:
# Hold texture and color fixed (one sample repeated three times) and vary
# only the shape mask.
selected_textures = misc.random_latents(1, Gs).repeat(3, 0)
selected_colors = get_random_color(1).repeat(3, 0)
selected_shapes = get_random_mask(3)

fake_images = convert_to_image(
    Gs.run(selected_textures, selected_colors, selected_shapes)
)

# Top row: generated images; bottom row: the mask used for each of them.
plt.figure(figsize=(12, 4))
for col, (image, mask) in enumerate(zip(fake_images, selected_shapes)):
    plt.subplot(2, 3, col + 1)
    plt.imshow(image)
    plt.axis('off')
    plt.subplot(2, 3, col + 4)
    plt.imshow(mask[0], cmap='gray', vmin=0.0, vmax=1.0)
    plt.axis('off')