In [1]:
import os
import pickle
import numpy as np
import numpy.linalg as la
import PIL.Image
import PIL.ImageSequence
import dnnlib
import dnnlib.tflib as tflib
from IPython.display import display, clear_output
import moviepy
import moviepy.editor
import math
import glob
import csv
from functools import partial
import time
import collections

import tensorflow as tf

from tensorflow import keras
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input

from sklearn.linear_model import LinearRegression, Lasso

import colorsys
import requests
import re
import copy

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets






In [2]:
# Widen the notebook container to the full browser width.
# `display` and `HTML` come from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [3]:
##
# Load network snapshot
##
# Defines the notebook-wide globals `_G`, `_D` and `Gs` that the later cells use.

# Alternative snapshots are kept below as commented-out assignments; exactly one
# `input_sg_name` assignment should be active.
#input_sg_name = "2019-02-09-stylegan-danbooru2017-faces-network-snapshot-007841.pkl"

# From https://mega.nz/#!vOgj1QoD!GD3E37BroNnZaIR_nic2zVxBtKfAqlvbEC8uBK8-4co
#input_sg_name = "cache/2019-02-10-stylegan-asuka.pkl"
#input_sg_name = "cache/2019-02-18-stylegan-faces-network-02041-011095.pkl"
#input_sg_name = "cache/2019-05-03-stylegan-malefaces.pkl"
input_sg_name = "cache/2019-03-08-stylegan-animefaces-network.pkl"

# Must run before the snapshot is unpickled and before any tf.get_default_session() call below.
tflib.init_tf()

# Load pre-trained network.
# SECURITY: pickle.load executes arbitrary code contained in the file - only load snapshots
# obtained from a source you trust.
with open(input_sg_name, 'rb') as f:
    # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
    # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
    # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.    
    _G, _D, Gs = pickle.load(f)
        
# Print network details.
Gs.print_layers()
_D.print_layers()



Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

Gs                            Params    OutputShape         WeightShape     
---                           ---       ---                 ---             
latents_in                    -         (?, 512)            -               
labels_in                     -         (?, 0)              -               
lod                           -         ()                  -               
dlatent_avg                   -         (512,)              -               
G_mapping/latents_in          -         (?, 512)            -               
G_mapping/labels_in           -         (?, 0)              -               
G_mapping/PixelNorm           -         (?, 512)            -               
G_mapping/Dense0              262656    (?, 512)            (512, 512)      
G_mapping/Dense1              262656    (?, 512)            (512, 512)      
G_mapping/Dense2              262656    (?, 512)        

Grow_lod2            -         (?, 256, 64, 64)    -               
64x64/Conv0          590080    (?, 256, 64, 64)    (3, 3, 256, 256)
64x64/Conv1_down     1180160   (?, 512, 32, 32)    (3, 3, 256, 512)
Downscale2D_3        -         (?, 3, 32, 32)      -               
FromRGB_lod4         2048      (?, 512, 32, 32)    (1, 1, 3, 512)  
Grow_lod3            -         (?, 512, 32, 32)    -               
32x32/Conv0          2359808   (?, 512, 32, 32)    (3, 3, 512, 512)
32x32/Conv1_down     2359808   (?, 512, 16, 16)    (3, 3, 512, 512)
Downscale2D_4        -         (?, 3, 16, 16)      -               
FromRGB_lod5         2048      (?, 512, 16, 16)    (1, 1, 3, 512)  
Grow_lod4            -         (?, 512, 16, 16)    -               
16x16/Conv0          2359808   (?, 512, 16, 16)    (3, 3, 512, 512)
16x16/Conv1_down     2359808   (?, 512, 8, 8)      (3, 3, 512, 512)
Downscale2D_5        -         (?, 3, 8, 8)        -               
FromRGB_lod6         2048      (?, 512, 8, 8)   

In [4]:
##
# Build things on top for encoding
# Unfortunately this works only Kind Of Okay
# Based on https://github.com/Puzer/stylegan
##
def create_stub(name, batch_size):
    """Build an empty placeholder input: a float32 constant of shape (batch_size, 0).

    Used below as a `custom_inputs` entry for inputs that carry no data.

    Args:
        name: Accepted for the `custom_inputs` calling convention; not used.
        batch_size: Size of the leading dimension of the returned tensor.

    Returns:
        The tensor created by `tf.constant`.
    """
    empty_shape = (batch_size, 0)
    return tf.constant(0, shape=empty_shape, dtype='float32')

# Evaluate the generator's own "dlatent_avg" variable once in the default session; the result is
# used below as the truncation anchor (create_variable_for_generator) and in the input_loss term.
dlatent_avg = tf.get_default_session().run(Gs.own_vars["dlatent_avg"])
def create_variable_for_generator(name, batch_size):
    """Create the learnable dlatent variable and return its truncated form.

    A zero-initialised float32 variable named 'learnable_dlatents' with shape
    (batch_size, 16, 512) is created, then interpolated between the module-level
    `dlatent_avg` and itself: layers with index < 8 use coefficient 0.7, the
    remaining layers use 1.0 (i.e. are passed through unchanged).

    Args:
        name: Accepted for the `custom_inputs` calling convention; not used.
        batch_size: Leading dimension of the created variable. Previously this
            argument was ignored and the batch dimension was hard-coded to 1,
            which is the only value the notebook passes.

    Returns:
        The result of `tflib.lerp(dlatent_avg, variable, coefs)`.
    """
    print("create_variable_for_generator called")  # added, so I know it was called. Todo: figure out why it's crashing
    truncation_psi_encode = 0.7
    # Per-layer interpolation coefficients, shaped (1, 16, 1) so they broadcast over batch and feature axes.
    layer_idx = np.arange(16)[np.newaxis, :, np.newaxis]
    ones = np.ones(layer_idx.shape, dtype=np.float32)
    coefs = tf.where(layer_idx < 8, truncation_psi_encode * ones, ones)
    dlatent_variable = tf.get_variable(
        'learnable_dlatents',
        shape=(batch_size, 16, 512),
        dtype='float32',
        initializer=tf.initializers.zeros()
    )
    dlatent_variable_trunc = tflib.lerp(dlatent_avg, dlatent_variable, coefs)
    return dlatent_variable_trunc

# Generation-from-disentangled-latents part
# The synthesis network is run once with custom inputs. The call is made for its effect on the
# default graph: the lines after it look up the 'learnable_dlatents' variable and the
# 'images_out' tensor by name, so the statement order in this cell must be kept.
initial_dlatents = np.zeros((1, 16, 512))
Gs.components.synthesis.run(
    initial_dlatents,
    randomize_noise = True, # Turns out this should not be off ever for trying to learn dlatents, who knew
    minibatch_size = 1,
    custom_inputs = [
        partial(create_variable_for_generator, batch_size=1),
        partial(create_stub, batch_size = 1)],
    structure = 'fixed', num_gpus=1
)

# Handles into the graph built above: the learnable variable, the raw generator output, and the
# generator output converted to image range (kept as float for the losses, plus a uint8 version).
dlatent_variable = next(v for v in tf.global_variables() if 'learnable_dlatents' in v.name)
generator_output = tf.get_default_graph().get_tensor_by_name('G_synthesis_1/_Run/G_synthesis/images_out:0')
generated_image = tflib.convert_images_to_uint8(generator_output, nchw_to_nhwc=True, uint8_cast=False)
generated_image_uint8 = tf.saturate_cast(generated_image, tf.uint8)

# Loss part
# Perceptual features are the output of VGG16 layer index 9 (no classifier head, 512x512x3 input).
vgg16 = VGG16(include_top=False, input_shape=(512, 512, 3))
perceptual_model = keras.Model(vgg16.input, vgg16.layers[9].output)
generated_img_features = perceptual_model(preprocess_input(generated_image, mode="tf"))
# Variables holding the reference image and its features; encode_image / finetune_image
# overwrite them via tf.assign before optimizing.
ref_img = tf.get_variable(
    'ref_img', 
    shape = generated_image.shape,
    dtype = 'float32', 
    initializer = tf.zeros_initializer()
)
ref_img_features = tf.get_variable(
    'ref_img_features', 
    shape = generated_img_features.shape,
    dtype = 'float32', 
    initializer = tf.zeros_initializer()
)
tf.get_default_session().run([ref_img.initializer, ref_img_features.initializer])
# basic_loss: pixel-space MSE; perceptual_loss: MSE between the VGG16 feature maps.
basic_loss = tf.losses.mean_squared_error(ref_img, generated_image)
perceptual_loss = tf.losses.mean_squared_error(ref_img_features, generated_img_features)

# Run the discriminator once with the generator output substituted as its image input, then
# fetch its score tensor by name (same build-then-look-up pattern as for the generator above).
_D.run(np.zeros((1, 3, 512, 512)), None, custom_inputs = [
    lambda x: generator_output,
    partial(create_stub, batch_size = 1),
])
discriminator_output = tf.get_default_graph().get_tensor_by_name('D/_Run/D/scores_out:0')

# Attempt at making encoding better: Bias towards mean ("truncation loss", essentially)
# dlatent_avg is tiled to shape (-1, 16, 512) so it can be compared with dlatent_variable.
# NOTE(review): input_loss / combined_loss are not used by any function visible in this notebook.
dlatent_avg_full = dlatent_avg.reshape(-1, 512).repeat(16, axis = 0).reshape(-1, 16, 512)
input_loss = tf.losses.mean_squared_error(dlatent_variable, dlatent_avg_full)
combined_loss = input_loss + perceptual_loss

# We have a discriminator network, why not use it?
# softplus(-score): decreases as the discriminator score increases.
discriminator_loss = tf.nn.softplus(-discriminator_output)

create_variable_for_generator called
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


In [5]:
# Gradient descent in latent space towards something that is similar to the input image
def encode_image(image, iterations = 1024, learning_rate = 0.1, reset_dlatents = True, custom_initial_dlatents = None):
    """Optimize `dlatent_variable` so the generated image matches `image` perceptually.

    Minimizes the module-level `perceptual_loss` with a freshly created Adadelta
    optimizer (an alternative that was tried is plain GradientDescentOptimizer).

    Args:
        image: Reference image as an array; it is cast to float32 and given a
            leading batch axis before being fed to the feature model.
        iterations: Number of optimizer steps.
        learning_rate: Learning rate passed to the optimizer.
        reset_dlatents: When True, the variable is (re)assigned before optimizing;
            otherwise optimization continues from its current value.
        custom_initial_dlatents: Optional start value (reshaped to (-1, 16, 512));
            when None the module-level `initial_dlatents` is used. Only consulted
            when `reset_dlatents` is True.

    Returns:
        Tuple `(dlatents, generated_image)`: the first entry of the optimized
        variable and the image produced for it by `generate_images_from_dlatents`.
    """
    session = tf.get_default_session()

    # A new optimizer and minimize op are created on every call.
    # NOTE(review): no initializer is run here for any variables the optimizer creates;
    # presumably they are initialized as a side effect of the Keras predict call further
    # down - keep this statement order until that is confirmed.
    adadelta = tf.train.AdadeltaOptimizer(learning_rate = learning_rate)
    train_step = adadelta.minimize(perceptual_loss, var_list=[[dlatent_variable]])

    # Optionally move the variable to the requested starting point
    if reset_dlatents == True:
        if custom_initial_dlatents is None:
            start_value = initial_dlatents
        else:
            start_value = custom_initial_dlatents.reshape(-1, 16, 512)
        session.run(tf.assign(dlatent_variable, start_value))

    # Compute the features of the reference image and store them as the target
    reference_batch = np.array([image.astype("float32")])
    reference_features = perceptual_model.predict_on_batch(preprocess_input(reference_batch, mode="tf"))
    session.run(tf.assign(ref_img_features, reference_features))

    # Optimize, reporting the loss every 100 steps
    for step in range(iterations):
        _, loss_value = session.run([train_step, perceptual_loss])
        if step % 100 == 0:
            print("i: {}, l: {}".format(step, loss_value))

    # Also render the image that goes with the resulting dlatents for quick testing
    encoded_dlatents = session.run(dlatent_variable)[0]
    return encoded_dlatents, generate_images_from_dlatents(encoded_dlatents)

# Same idea as encode_image, but starts from given dlatents and uses plain MSE loss instead of vgg16
def finetune_image(dlatents, image, iterations = 32, learning_rate = 0.0001):
    """Refine `dlatents` by minimizing the pixel-space `basic_loss` against `image`.

    Args:
        dlatents: Starting dlatents; wrapped in a leading batch axis and assigned
            to the module-level `dlatent_variable`.
        image: Reference image as an array; cast to float64, given a leading batch
            axis and assigned to the module-level `ref_img` variable.
        iterations: Number of gradient-descent steps.
        learning_rate: Learning rate for the GradientDescentOptimizer.

    Returns:
        Tuple `(dlatents, generated_image)`: the first entry of the optimized
        variable and the image produced for it by `generate_images_from_dlatents`.
    """
    session = tf.get_default_session()

    # Start from the supplied dlatents
    session.run(tf.assign(dlatent_variable, np.array([dlatents])))

    # Store the reference image the loss compares against
    reference_batch = np.array([image.astype("float64")])
    session.run(tf.assign(ref_img, reference_batch))

    # Plain gradient descent on the pixel loss; a new optimizer is created on every call
    sgd = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
    train_step = sgd.minimize(basic_loss, var_list=[[dlatent_variable]])

    # Loss is reported every 100 steps (so only step 0 with the default 32 iterations)
    for step in range(iterations):
        _, loss_value = session.run([train_step, basic_loss])
        if step % 100 == 0:
            print("i: {}, l: {}".format(step, loss_value))

    # Also render the image that goes with the resulting dlatents for quick testing
    tuned_dlatents = session.run(dlatent_variable)[0]
    return tuned_dlatents, generate_images_from_dlatents(tuned_dlatents)

# Tune image in the direction of being considered more likely by the discriminator
def tune_with_discriminator(dlatents, iterations = 32, learning_rate = 1.0):
    """Nudge `dlatents` by gradient descent on the module-level `discriminator_loss`.

    Args:
        dlatents: Starting dlatents; wrapped in a leading batch axis and assigned
            to the module-level `dlatent_variable`.
        iterations: Number of gradient-descent steps.
        learning_rate: Learning rate for the GradientDescentOptimizer.

    Returns:
        The first entry of the optimized `dlatent_variable`.
    """
    # Get session and assign initial dlatents
    sess = tf.get_default_session()
    sess.run(tf.assign(dlatent_variable, np.array([dlatents])))

    # Gradient descent; a new optimizer is created on every call
    optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)
    min_op = optimizer.minimize(discriminator_loss, var_list=[[dlatent_variable]])

    for i in range(iterations):
        # Report the loss that is actually being minimized. This used to fetch
        # basic_loss, which depends on whatever reference image was assigned last
        # and says nothing about the progress of this optimization.
        _, loss = sess.run([min_op, discriminator_loss])
        if i % 100 == 0:
            print("i: {}, l: {}".format(i, loss))

    return sess.run(dlatent_variable)[0]

# We have to do truncation ourselves, since we're not using the combined network
def truncate(dlatents, truncation_psi, maxlayer = 8):
    """Interpolate dlatents towards the generator's average dlatent.

    Layers with index < `maxlayer` are moved to
    `avg + (dlatents - avg) * truncation_psi`; the remaining layers are returned
    unchanged. The interpolation is done in NumPy: the previous implementation
    added new `tf.where` / lerp ops to the default graph on every call (this
    function runs on every slider change), so the graph kept growing.

    Args:
        dlatents: Array broadcastable against shape (1, 16, 512), e.g. (16, 512).
        truncation_psi: Interpolation factor for the truncated layers.
        maxlayer: Number of leading layers (out of 16) to truncate.

    Returns:
        float32 array holding the truncated dlatents; for a (16, 512) input the
        broadcast result has shape (1, 16, 512).
    """
    # Read the current average from the generator (a plain variable read, no new graph ops).
    avg = np.asarray(tf.get_default_session().run(Gs.own_vars["dlatent_avg"]), dtype=np.float32)
    # Per-layer coefficients shaped (1, 16, 1) so they broadcast over batch and feature axes.
    layer_idx = np.arange(16)[np.newaxis, :, np.newaxis]
    coefs = np.where(layer_idx < maxlayer, truncation_psi, 1.0).astype(np.float32)
    delta = np.asarray(np.asarray(dlatents) - avg, dtype=np.float32)
    return avg + delta * coefs

# Generate image with disentangled latents as input
def generate_images_from_dlatents(dlatents, truncation_psi = 1.0, randomize_noise = True):
    """Run the synthesis network on `dlatents` and return the first output image.

    Args:
        dlatents: Array that can be reshaped to (-1, 16, 512).
        truncation_psi: Factor passed to `truncate`; pass None to skip truncation.
        randomize_noise: Forwarded to the synthesis network.

    Returns:
        The first element of the synthesis output, after the
        `tflib.convert_images_to_uint8` output transform with `nchw_to_nhwc=True`.
    """
    synthesis_input = dlatents if truncation_psi is None else truncate(dlatents, truncation_psi)

    # Output transform applied by the network to the generated images
    to_uint8_nhwc = {"func": tflib.convert_images_to_uint8, "nchw_to_nhwc": True}
    images = Gs.components.synthesis.run(
        synthesis_input.reshape((-1, 16, 512)),
        randomize_noise = randomize_noise,
        minibatch_size = 1,
        output_transform = to_uint8_nhwc,
        num_gpus = 0,
    )
    return images[0]

# Sequence of learning steps while reducing lr followed by finetune
def encode_and_tune(image, iters_per_step = 1024):
    """Encode `image` with a decreasing learning-rate schedule, then fine-tune.

    Starts from the dlatents of a random latent vector, runs `encode_image` six
    times with learning rates 100, 10, 1, 0.1, 0.01 and 0.001 (each stage
    continuing from the previous one), and finally runs `finetune_image` for 128
    iterations starting from the result of the last stage. Previously the
    fine-tune step started from the fifth stage's dlatents, so the work of the
    sixth stage was discarded even though its result was returned.

    Args:
        image: Reference image, forwarded to `encode_image` / `finetune_image`.
        iters_per_step: Number of iterations for each `encode_image` stage.

    Returns:
        Tuple `(tuned_dlatents, tuned_image, encoded_dlatents)`: the fine-tuned
        dlatents and their image, plus the dlatents from the last encode stage.
    """
    start_latents = np.random.randn(1, Gs.input_shape[1])
    start_dlatents = Gs.components.mapping.run(start_latents, None)[0]

    learning_rates = (100.0, 10.0, 1.0, 0.1, 0.01, 0.001)
    encoded_dlatents = None
    for stage, stage_lr in enumerate(learning_rates):
        if stage == 0:
            # Only the first stage resets the learnable variable (to the random start point).
            encoded_dlatents, _ = encode_image(image, iterations = iters_per_step, learning_rate = stage_lr, custom_initial_dlatents = start_dlatents)
        else:
            encoded_dlatents, _ = encode_image(image, iterations = iters_per_step, learning_rate = stage_lr, reset_dlatents = False)

    tuned_dlatents, tuned_image = finetune_image(encoded_dlatents, image, iterations = 128)
    return tuned_dlatents, tuned_image, encoded_dlatents

In [6]:
# Interactive modification!
hair_eyes_only = False # Set to true for fewer tags
lock_updates = False # While True, modify_and_sample returns immediately without redrawing
def modify_and_sample(psi, truncate_pre, truncate_post, **kwargs):
    """Apply the current slider state to `dlatents_gen` and display the resulting image.

    Args:
        psi: Truncation factor taken from the psi slider.
        truncate_pre: When True, `dlatents_gen` is truncated with `psi` before
            the tag directions are added.
        truncate_post: When True, `psi` is passed as `truncation_psi` when the
            image is generated; otherwise None is passed (no truncation).
        **kwargs: Maps tag name -> strength; each tag's direction from
            `tag_directions` is scaled by its strength and added to the dlatents.

    Side effects:
        Updates the labels in `value_widgets` and displays the generated image.
        Does nothing while the module-level `lock_updates` flag is True.
    """
    global lock_updates
    if lock_updates == True:
        return

    # Always work on a copy so the base sample itself is never modified
    dlatents_mod = copy.deepcopy(dlatents_gen)
    if truncate_pre == True:
        dlatents_mod = truncate(dlatents_mod, psi)

    for tag, strength in kwargs.items():
        dlatents_mod += tag_directions[tag] * strength
    value_widgets["psi"].value = str(round(psi, 2))

    # Label per tag: slider strength, then the projection of the result onto the
    # tag direction (divided by the direction's norm) minus the slider strength
    flat_dlatents = dlatents_mod.flatten()
    for tag, strength in kwargs.items():
        projection = np.dot(flat_dlatents, tag_directions[tag].flatten()) / tag_len[tag]
        value_widgets[tag].value = str(strength) + " | " + str(round(projection - strength, 2))

    display_psi = psi if truncate_post == True else None
    sample = generate_images_from_dlatents(dlatents_mod, truncation_psi = display_psi)
    display(PIL.Image.fromarray(sample, 'RGB'))

# Load up tags and tag directions
# SECURITY: pickle.load executes arbitrary code contained in the file - only load files you trust.
with open("tag_dirs.pkl", 'rb') as f:
    tag_directions = pickle.load(f)
    
# Norm of every (flattened) tag direction; modify_and_sample divides projections by it.
# NOTE: the loop variable `tag` stays defined in the notebook namespace after the loop.
tag_len = {}
for tag in tag_directions:
    tag_len[tag] = np.linalg.norm(tag_directions[tag].flatten())
    
# Starting sample: latents stored on disk, pushed through the mapping network to obtain dlatents.
mod_latents = np.load("mod_latents.npy")
dlatents_gen = Gs.components.mapping.run(mod_latents, None)[0]  
    
# Choose which tags get a slider: either every tag whose name contains "_hair", "_eyes" or
# "_mouth" plus "realistic", or the list stored in tags_use.pkl.
if hair_eyes_only:
    modify_tags = [tag for tag in tag_directions if "_hair" in tag or "_eyes" in tag or "_mouth" in tag]
    modify_tags.append("realistic")
else:
    with open("tags_use.pkl", "rb") as f:
        modify_tags = pickle.load(f)

# Build UI
# One slider for psi (0..1, initial 0.7) and one slider per tag (-2..2). The sliders have
# continuous_update disabled and no built-in readout; a separate label shows the values.
psi_slider = widgets.FloatSlider(min = 0.0, max = 1.0, step = 0.01, value = 0.7, continuous_update = False, readout = False)
tag_widgets = {}
for tag in modify_tags:
    tag_widgets[tag] = widgets.FloatSlider(min = -2.0, max = 2.0, step = 0.01, continuous_update = False, readout = False)
all_widgets = []

# Rows are laid out with psi first, followed by the tags in alphabetical order.
sorted_widgets = sorted(tag_widgets.items(), key = lambda x: x[0])
sorted_widgets = [("psi", psi_slider)] + sorted_widgets
# value_widgets maps slider name -> the label that modify_and_sample writes its readout into.
value_widgets = {}
for widget in sorted_widgets:
    # `widget` is a (name, slider) pair; each row is: name label | slider | value label.
    label_widget = widgets.Label(widget[0])
    label_widget.layout.width = "140px"
    
    # Initial placeholder text; replaced the first time modify_and_sample runs.
    value_widget = widgets.Label("0.0+100.0")
    value_widget.layout.width = "150px"
    value_widgets[widget[0]] = value_widget
    
    tag_hbox = widgets.HBox([label_widget, widget[1], value_widget])
    tag_hbox.layout.width = "320px"
    
    all_widgets.append(tag_hbox)

# Action buttons; their callbacks are defined and attached further down in this cell.
refresh = widgets.Button(description="New Sample")
modify = widgets.Button(description="Mutate")
reset = widgets.Button(description="Reset Tags")

def new_sample(b):
    """Button callback: replace the current sample with a freshly drawn random one.

    Overwrites the module-level `mod_latents` and `dlatents_gen`.

    Args:
        b: Argument supplied by the button's on_click callback; not used.
    """
    global mod_latents, dlatents_gen
    latent_size = Gs.input_shape[1]
    mod_latents = np.random.randn(1, latent_size)
    dlatents_gen = Gs.components.mapping.run(mod_latents, None)[0]
    # Nudge the psi slider by a negligible amount so the interactive output is redrawn
    # (workaround; the original author did not know a proper way to trigger a refresh)
    psi_slider.value += 0.00000000001
    
def mutate(b):
    """Button callback: perturb the current sample with small Gaussian noise.

    Adds `randn * 0.2` to the module-level `mod_latents` in place and recomputes
    the module-level `dlatents_gen` from it.

    Args:
        b: Argument supplied by the button's on_click callback; not used.
    """
    global mod_latents, dlatents_gen
    latent_size = Gs.input_shape[1]
    perturbation = 0.2 * np.random.randn(1, latent_size)
    mod_latents += perturbation
    dlatents_gen = Gs.components.mapping.run(mod_latents, None)[0]
    # Nudge the psi slider by a negligible amount so the interactive output is redrawn
    psi_slider.value += 0.00000000001

def reset_tags(b):
    """Button callback: set every tag slider back to 0.0 and redraw once.

    The module-level `lock_updates` flag is raised while the sliders are reset so
    that `modify_and_sample` returns early for each individual slider change; the
    psi nudge at the end then triggers a single redraw.

    Args:
        b: Argument supplied by the button's on_click callback; not used.
    """
    global lock_updates
    lock_updates = True
    try:
        for widget in real_tag_widgets.values():
            widget.value = 0.0
    finally:
        # Always release the lock: if it stayed set after an error, modify_and_sample
        # would return early on every later update and the UI would stop redrawing.
        lock_updates = False
    # Nudge the psi slider by a negligible amount so the interactive output is redrawn
    psi_slider.value += 0.00000000001

# Snapshot of the tag sliders only, taken before the psi / truncate controls are added to
# tag_widgets below; reset_tags and the later analysis cells iterate over this copy.
real_tag_widgets = copy.copy(tag_widgets)

truncate_pre = widgets.ToggleButton(value=True, description='Truncate Pre')
truncate_post = widgets.ToggleButton(value=False, description='Truncate Post')
# Attach the callbacks defined above to their buttons.
refresh.on_click(new_sample)
modify.on_click(mutate)
reset.on_click(reset_tags)

for button in [refresh, modify, truncate_pre, truncate_post, reset]:
    button.layout.width = "120px"

ui = widgets.Box(all_widgets + [refresh, modify, truncate_pre, truncate_post, reset])
# From here on tag_widgets doubles as the control mapping for interactive_output: its keys must
# match modify_and_sample's parameters (psi, truncate_pre, truncate_post, plus one per tag).
tag_widgets["psi"] = psi_slider

ui.layout.flex_flow = 'row wrap'
ui.layout.display = 'inline-flex'
tag_widgets["truncate_pre"] = truncate_pre
tag_widgets["truncate_post"] = truncate_post
out = widgets.interactive_output(modify_and_sample, tag_widgets)

# Lets go! (best used in Presentation Mode)
display(ui, out)

Box(children=(HBox(children=(Label(value='psi', layout=Layout(width='140px')), FloatSlider(value=0.7, continuo…

Output()

### Some math for reverse-engineering the interactive values to obtain the latent representation

In [None]:
# Smallest element-wise difference between two saved latent vectors.
# NOTE(review): mod_latents_open / mod_latents_closed are not assigned in any cell shown here;
# presumably they were saved by hand during an interactive session - confirm before re-running.
(mod_latents_open - mod_latents_closed).min()

In [None]:
# Smallest element-wise difference between two saved dlatents.
# NOTE(review): dlatents_gen_open / dlatents_gen_closed are not assigned in any cell shown here;
# presumably they were saved by hand during an interactive session - confirm before re-running.
(dlatents_gen_open - dlatents_gen_closed).min()

In [None]:
# TODO: examine this direction more closely later, it has great potential.
# Shows the stored direction for the 'open_mouth' tag.
tag_directions['open_mouth']

In [None]:
# Copied from modify_and_sample, but modified so that no global state is changed: the slider
# values are read directly from the widgets and the readouts are printed instead of written
# into the value labels.
if truncate_pre.value == True:
    dlatents_mod_2 = truncate(copy.deepcopy(dlatents_gen), psi_slider.value)
else:
    dlatents_mod_2 = copy.deepcopy(dlatents_gen)
        
# Add every tag direction scaled by its current slider value.
# NOTE: the loop variable `tag` stays defined after these loops; a later cell reads it.
for tag in real_tag_widgets:
    dlatents_mod_2 += tag_directions[tag] * real_tag_widgets[tag].value
print("psi", str(round(psi_slider.value, 2)))

# Same readout as in modify_and_sample: slider value, then projection onto the tag direction
# (divided by the direction's norm) minus the slider value.
for tag in real_tag_widgets:
    tag_value = round((np.dot(dlatents_mod_2.flatten(), tag_directions[tag].flatten()) / tag_len[tag]) - real_tag_widgets[tag].value, 2)
    print(tag + " | " + str(real_tag_widgets[tag].value) + " | " + str(tag_value))

display_psi_2 = None
if truncate_post.value == True:
    display_psi_2 = psi_slider.value
display(PIL.Image.fromarray(generate_images_from_dlatents(dlatents_mod_2, truncation_psi = display_psi_2), 'RGB'))
# TODO: the interactive result is now replicated, so the next step is a way to save the two
# points between which to interpolate.
# Possibly three points (the variant with an ellipsis, but deterministic?), or just a straight
# line from one point to the other?
# Try both.

In [None]:
def calc_current_dlatents(vals_dict):
    """Compute the dlatents for the current UI state, with selected tags overridden.

    Args:
        vals_dict: Maps tag name -> strength. For tags present here the given
            strength is used instead of the slider value; all other tags in
            `real_tag_widgets` use their current slider value.

    Returns:
        Tuple `(dlatents, display_psi)`: the modified copy of `dlatents_gen` and
        the psi to use when generating the image (the psi slider value when the
        'Truncate Post' toggle is on, otherwise None).
    """
    # Work on a copy so the base sample is left untouched
    dlatents_mod_2 = copy.deepcopy(dlatents_gen)
    if truncate_pre.value == True:
        dlatents_mod_2 = truncate(dlatents_mod_2, psi_slider.value)

    for tag, slider in real_tag_widgets.items():
        strength = vals_dict[tag] if tag in vals_dict else slider.value
        dlatents_mod_2 += tag_directions[tag] * strength

    display_psi_2 = psi_slider.value if truncate_post.value == True else None
    return dlatents_mod_2, display_psi_2


In [None]:
# Current slider value for `tag`.
# NOTE(review): `tag` is not assigned in this cell; it is whatever value was left behind by the
# last `for tag in ...` loop that ran, so the result depends on cell execution order.
real_tag_widgets[tag].value

In [None]:
# Current UI state with the 'open_mouth' strength overridden to -2.
calc_current_dlatents({'open_mouth': -2})

In [None]:
# Current UI state with the 'open_mouth' strength overridden to 2.
calc_current_dlatents({'open_mouth': 2})

In [None]:
# todo: use interpolation scripts to do interpolation between those 2

In [5]:
# List the global variables of the default TensorFlow graph.
# NOTE(review): this cell's execution count (5) is lower than that of cells above it, so the
# output below may come from a different run than the rest of the notebook.
tf.global_variables()

[<tf.Variable 'G_synthesis/lod:0' shape=() dtype=float32_ref>,
 <tf.Variable 'G_synthesis/noise0:0' shape=(1, 1, 4, 4) dtype=float32_ref>,
 <tf.Variable 'G_synthesis/noise1:0' shape=(1, 1, 4, 4) dtype=float32_ref>,
 <tf.Variable 'G_synthesis/noise2:0' shape=(1, 1, 8, 8) dtype=float32_ref>,
 <tf.Variable 'G_synthesis/noise3:0' shape=(1, 1, 8, 8) dtype=float32_ref>,
 <tf.Variable 'G_synthesis/noise4:0' shape=(1, 1, 16, 16) dtype=float32_ref>,
 <tf.Variable 'G_synthesis/noise5:0' shape=(1, 1, 16, 16) dtype=float32_ref>,
 <tf.Variable 'G_synthesis/noise6:0' shape=(1, 1, 32, 32) dtype=float32_ref>,
 <tf.Variable 'G_synthesis/noise7:0' shape=(1, 1, 32, 32) dtype=float32_ref>,
 <tf.Variable 'G_synthesis/noise8:0' shape=(1, 1, 64, 64) dtype=float32_ref>,
 <tf.Variable 'G_synthesis/noise9:0' shape=(1, 1, 64, 64) dtype=float32_ref>,
 <tf.Variable 'G_synthesis/noise10:0' shape=(1, 1, 128, 128) dtype=float32_ref>,
 <tf.Variable 'G_synthesis/noise11:0' shape=(1, 1, 128, 128) dtype=float32_ref>,
 <t