In [None]:
import os
import pickle
import numpy as np
import PIL.Image
import PIL.ImageSequence
import sys
sys.path.append("stylegan2")
import dnnlib
import dnnlib.tflib as tflib
import tensorflow as tf
import copy
from IPython.display import display, clear_output
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

In [2]:
# Location of the pickled StyleGAN2 training snapshot to explore.
stylegan2model = "./models/70knetwork-snapshot-002800.pkl"

# Set up the default TensorFlow session that the dnnlib networks run in.
tflib.init_tf()

# Load pre-trained network. The pickle unpacks into three networks; only Gs is
# used by the cells below (_G / _D are kept for interactive inspection).
# NOTE(review): pickle.load executes code embedded in the file — only open
# snapshots from a trusted source.
with open(stylegan2model, mode="rb") as snapshot_file:
    _G, _D, Gs = pickle.load(snapshot_file)

Setting up TensorFlow plugin "fused_bias_act.cu": Preprocessing... Loading... Done.
Setting up TensorFlow plugin "upfirdn_2d.cu": Preprocessing... Loading... Done.


In [11]:
def truncate(dlatents, truncation_psi, maxlayer = 10, dlatent_avg = None):
    """Apply the truncation trick to W latents.

    Layers whose index is below ``maxlayer`` are pulled toward the average
    latent via ``avg + (dlatents - avg) * truncation_psi``. Deeper layers are
    returned unchanged (coefficient 1).

    The interpolation is done in NumPy rather than by building
    ``tf.where`` / ``tflib.lerp`` ops. In graph-mode TF (``tflib`` with a
    default session), building ops on every call keeps adding nodes to the
    default graph. This function runs on every slider change.

    Args:
        dlatents: W latents. The layer axis is taken to be the second-to-last
            axis, e.g. (10, 512) or (1, 10, 512). 1-D input falls back to the
            previous fixed count of 10 layers.
        truncation_psi: interpolation factor. 1.0 leaves latents unchanged;
            0.0 replaces the truncated layers with the average.
        maxlayer: number of leading layers that get truncated.
        dlatent_avg: average W latent. If None, it is read from
            ``Gs.own_vars["dlatent_avg"]`` in the default TF session.

    Returns:
        NumPy array broadcast from ``dlatent_avg``, ``dlatents`` and per-layer
        coefficients of shape (1, num_layers, 1). For (num_layers, dim) input
        this is (1, num_layers, dim).
    """
    if dlatent_avg is None:
        dlatent_avg = tf.get_default_session().run(Gs.own_vars["dlatent_avg"])
    dlatent_avg = np.asarray(dlatent_avg)
    dlatents = np.asarray(dlatents)

    num_layers = dlatents.shape[-2] if dlatents.ndim >= 2 else 10
    layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
    ones = np.ones(layer_idx.shape, dtype=np.float32)
    coefs = np.where(layer_idx < maxlayer, truncation_psi * ones, ones)

    # Same formula as tflib.lerp(a, b, t) == a + (b - a) * t.
    return dlatent_avg + (dlatents - dlatent_avg) * coefs

# Generate image with disentangled latents as input
def generate_images_from_dlatents(dlatents, truncation_psi = 1.0, randomize_noise = True):
    """Render one RGB image from W (disentangled) latents.

    Args:
        dlatents: W latents, reshaped to (-1, 10, 512) before synthesis.
        truncation_psi: passed to ``truncate`` before synthesis; None skips
            truncation entirely.
        randomize_noise: forwarded to the synthesis network's ``run``.

    Returns:
        The first image of the synthesized batch, converted by
        ``tflib.convert_images_to_uint8`` with NCHW -> NHWC reordering.
    """
    if truncation_psi is None:
        synthesis_input = dlatents
    else:
        synthesis_input = truncate(dlatents, truncation_psi)

    # Ask the network for uint8, channels-last output so PIL can consume it.
    to_uint8 = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.components.synthesis.run(
        synthesis_input.reshape((-1, 10, 512)),
        randomize_noise=randomize_noise,
        minibatch_size=1,
        output_transform=to_uint8,
    )
    return images[0]

# Init randomizer, load latents.
# Set LATENT_SEED to an int to make the initial samples reproducible across
# kernel restarts; None keeps drawing from NumPy's unseeded global generator.
LATENT_SEED = None
rnd = np.random
if LATENT_SEED is not None:
    rnd.seed(LATENT_SEED)
latents_a = rnd.randn(1, Gs.input_shape[1])
latents_b = rnd.randn(1, Gs.input_shape[1])
latents_c = rnd.randn(1, Gs.input_shape[1])

# Presumably a dict of tag name -> latent-space direction array (later code
# flattens each value and adds it to the W latents).
# NOTE(review): pickle.load executes code embedded in the file — only load
# direction files you produced or trust.
with open("tagged_dlatents/tag_dirs_cont1.pkl", 'rb') as f:
    tag_directions = pickle.load(f)

In [12]:
##
# UI
##
# When True, sliders are offered only for tags containing "_hair", "_eyes" or
# "_mouth"; otherwise the tag list is loaded from tagged_dlatents/tags_use1.pkl.
hair_eyes_only = False

# L2 norm of each flattened tag direction, used to normalise slider offsets.
tag_len = {
    tag: np.linalg.norm(direction.flatten())
    for tag, direction in tag_directions.items()
}

# Working copy of the initial z latent and the first item of its mapped W latents.
mod_latents = latents_a.copy()
dlatents_gen = Gs.components.mapping.run(mod_latents, None)[0]
def modify_and_sample(psi_pre, psi_post, truncate_pre, truncate_post, **kwargs):
    """Offset the current W latents along tag directions and display a preview.

    Used as the callback of ``widgets.interactive_output``, so it runs whenever
    any watched slider or toggle changes.

    Args:
        psi_pre: truncation psi applied before the tag offsets (if truncate_pre).
        psi_post: truncation psi applied at render time (if truncate_post).
        truncate_pre: whether to truncate ``dlatents_gen`` before offsetting.
        truncate_post: whether to truncate again when rendering.
        **kwargs: tag name -> slider amount. Each tag moves the latents by
            ``amount * step_scale`` along its unit-normalised direction.
    """
    # Distance moved along a unit-length tag direction per unit of slider value.
    step_scale = 25.0

    if truncate_pre:
        # truncate() returns a new array, so no defensive copy is needed here.
        dlatents_mod = truncate(dlatents_gen, psi_pre)
    else:
        # Copy so the in-place += below does not mutate the shared dlatents_gen.
        dlatents_mod = copy.deepcopy(dlatents_gen)

    for tag, amount in kwargs.items():
        length = tag_len[tag]
        if not length:
            # A zero direction cannot be normalised; skipping avoids NaN latents.
            continue
        dlatents_mod += tag_directions[tag] * amount / length * step_scale

    value_widgets["psi_pre"].value = str(round(psi_pre, 2))
    value_widgets["psi_post"].value = str(round(psi_post, 2))

    # Label each tag as "slider value | projection of the latents onto the tag direction".
    flat_mod = dlatents_mod.flatten()
    for tag, amount in kwargs.items():
        length = tag_len[tag]
        projection = np.dot(flat_mod, tag_directions[tag].flatten()) / length if length else 0.0
        value_widgets[tag].value = str(amount) + " | " + str(round(projection, 2))

    display_psi = psi_post if truncate_post else None
    image = generate_images_from_dlatents(dlatents_mod, truncation_psi = display_psi)
    display(PIL.Image.fromarray(image, 'RGB').resize(size=(128, 128)))

# Two identical truncation-psi sliders: "pre" (psi_slider) and "post" (psi_slider_post).
psi_slider, psi_slider_post = (
    widgets.FloatSlider(min = 0.0, max = 1.0, step = 0.01, value = 0.60, continuous_update = False, readout = False)
    for _ in range(2)
)

# Pick which tags get sliders. Despite the flag's name, "_mouth" tags are kept too.
if hair_eyes_only:
    modify_tags = [
        tag for tag in tag_directions
        if any(marker in tag for marker in ("_hair", "_eyes", "_mouth"))
    ]
else:
    # NOTE(review): pickle.load executes code embedded in the file — trusted files only.
    with open("tagged_dlatents/tags_use1.pkl", "rb") as f:
        modify_tags = pickle.load(f)

tag_widgets = {
    tag: widgets.FloatSlider(min = -2.0, max = 6.0, step = 0.005, continuous_update = False, readout = False)
    for tag in modify_tags
}

# One row per slider: psi sliders first, then tags in alphabetical order.
sorted_widgets = [("psi_pre", psi_slider), ("psi_post", psi_slider_post)] + sorted(
    tag_widgets.items(), key = lambda item: item[0]
)
all_widgets = []
value_widgets = {}
for row_name, row_slider in sorted_widgets:
    name_label = widgets.Label(row_name)
    name_label.layout.width = "170px"

    # Placeholder text; modify_and_sample rewrites it on every render.
    readout_label = widgets.Label("0.0+100.0")
    readout_label.layout.width = "150px"
    value_widgets[row_name] = readout_label

    row_box = widgets.HBox([name_label, row_slider, readout_label])
    row_box.layout.width = "320px"
    all_widgets.append(row_box)

refresh = widgets.Button(description="New Sample")
modify = widgets.Button(description="Mutate")
reset_sliders = widgets.Button(description="Reset")
dlatent_encode = widgets.Button(description="Project input.png")

def _force_redraw():
    """Nudge psi_slider by a negligible amount so interactive_output re-renders.

    The preview only refreshes when a watched widget's value changes. The
    button handlers change only global latents, so they jiggle the
    pre-truncation slider. The nudge goes downward at the slider's max (1.0).
    """
    nudge = 0.00000000001
    if psi_slider.value != 1.0:
        psi_slider.value += nudge
    else:
        psi_slider.value -= nudge


def new_sample(b):
    """Button handler: draw a fresh z latent, map it to W, and re-render."""
    global mod_latents, dlatents_gen
    mod_latents = np.random.randn(1, Gs.input_shape[1])
    dlatents_gen = Gs.components.mapping.run(mod_latents, None)[0]
    _force_redraw()


def mutate(b, scale = 0.1):
    """Button handler: add Gaussian noise (times ``scale``) to the W latents and re-render."""
    global dlatents_gen
    # Match the current latents' shape instead of hard-coding (10, 512).
    dlatents_gen += np.random.randn(*np.shape(dlatents_gen)) * scale
    _force_redraw()


def reset(b):
    """Button handler: set every tag slider to 0, leaving psi/truncation controls alone."""
    controls = {"truncate_pre", "truncate_post", "psi_pre", "psi_post"}
    for name, control in tag_widgets.items():
        if name not in controls:
            control.value = 0.0


def encode(b):
    """Button handler: project input.png to W latents with encode_image and re-render."""
    global dlatents_gen
    # encode_image is not defined anywhere in this notebook. Without this check a
    # click raises a NameError inside the widget callback, where it is easy to miss.
    encode_fn = globals().get("encode_image")
    if encode_fn is None:
        with out:
            print("Cannot project input.png: encode_image() is not defined in this notebook.")
        return
    # Close the file handle once encoding is done.
    with PIL.Image.open("input.png") as image:
        dlatents_gen = encode_fn(image, encode_iters.value)
    _force_redraw()
            
# Toggles controlling truncation before the tag offsets and at render time.
truncate_pre = widgets.ToggleButton(value=True, description='Truncate Pre')
truncate_post = widgets.ToggleButton(value=True, description='Truncate Post')

for button, handler in ((refresh, new_sample), (modify, mutate), (reset_sliders, reset), (dlatent_encode, encode)):
    button.on_click(handler)

# Value passed as the second argument to encode_image (labelled "Iterations").
encode_iters = widgets.IntSlider(min = 10, max = 2000, step = 1, value=200, continuous_update = False, readout = True)

controls = [refresh, modify, reset_sliders, truncate_pre, truncate_post, dlatent_encode, widgets.Label("Iterations"), encode_iters]
ui = widgets.Box(all_widgets + controls)
ui.layout.flex_flow = 'row wrap'
ui.layout.display = 'inline-flex'

# interactive_output passes these keys to modify_and_sample's named parameters;
# all remaining keys (the tag sliders) arrive through its **kwargs.
tag_widgets.update(
    psi_pre=psi_slider,
    psi_post=psi_slider_post,
    truncate_pre=truncate_pre,
    truncate_post=truncate_post,
)

out = widgets.interactive_output(modify_and_sample, tag_widgets)

In [13]:
# Show the control panel followed by the output area that modify_and_sample renders into.
display(ui, out)

Box(children=(HBox(children=(Label(value='psi_pre', layout=Layout(width='170px')), FloatSlider(value=0.6, cont…

Output(outputs=({'output_type': 'display_data', 'data': {'text/plain': '<PIL.Image.Image image mode=RGB size=1…