# 3362 Anime Character Generator

### To play with the generator, make sure that you satisfy the following conditions:
1. "stylegan2" directory exists in the current working directory.
2. You have prepared the virtual environment for stylegan2. Please refer to https://github.com/y1huac00/3362AnimeCharacterGenerator for details of requirements if you have not done so.
3. "tagged_dlatents" directory exists, which stores the tag direction files.
4. Use pip to install any of the packages imported below that you are missing.
5. Make sure you have downloaded our best stylegan2 model, 70knetwork-snapshot-002800.pkl, from the GitHub Releases page, and placed it in the "models" directory inside the working directory.

In [1]:
import os
import pickle
import numpy as np
import PIL.Image
import PIL.ImageSequence
import sys
sys.path.append("stylegan2")
import dnnlib
import dnnlib.tflib as tflib
import tensorflow as tf
import copy
from IPython.display import display, clear_output
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


### You can choose the pretrained tag directions below by specifying the method and the dataset size.

In [2]:
# Pick one option from each tuple by changing the trailing index.
# `method` and `dataset_size` together select which pretrained tag-direction
# files are loaded further down.
method = ('logistic', 'mean')[0]
dataset_size = ('65', '76', '2000', '8000', '30000', '80000')[0]

In [3]:
# Pickled StyleGAN2 network snapshot, read from the ./models directory.
stylegan2model = "./models/70knetwork-snapshot-002800.pkl"
# Pickle holding the tag directions for the chosen method / dataset size.
tag_direction = f"./tagged_dlatents/tag_dirs_cont-{method}-{dataset_size}.pkl"
# Pickle holding the tags that get a slider in the UI for the same method / dataset size.
tag_use = f"./tagged_dlatents/tags_use-{method}-{dataset_size}.pkl"

In [4]:
# Set up TensorFlow through dnnlib before any network is unpickled.
tflib.init_tf()

# Unpickle the pre-trained snapshot; it holds three networks, of which only
# Gs is used by the cells below.
# NOTE: pickle can execute arbitrary code on load -- only open trusted files.
with open(stylegan2model, 'rb') as model_file:
    _G, _D, Gs = pickle.load(model_file)

# Start from an empty mapping, then replace it with the pickled tag directions.
tag_directions = {}
with open(tag_direction, 'rb') as directions_file:
    tag_directions = pickle.load(directions_file)

Setting up TensorFlow plugin "fused_bias_act.cu": Preprocessing... Loading... Done.
Setting up TensorFlow plugin "upfirdn_2d.cu": Preprocessing... Loading... Done.


In [5]:
# Draw the starting latent vector from a standard normal distribution.
# Seed the generator first if every run should produce the same sample.

rnd = np.random
latents_a = rnd.standard_normal(size=(1, Gs.input_shape[1]))

In [6]:
# StyleGAN2 model functions to truncate and generate images.

def truncate(dlatents, truncation_psi, maxlayer = 10):
    """Pull dlatents towards the generator's average dlatent.

    Layers whose index is below ``maxlayer`` are linearly interpolated
    between the average dlatent and ``dlatents`` with weight
    ``truncation_psi``; layers at or above ``maxlayer`` keep weight 1.0 and
    are therefore returned unchanged.

    Args:
        dlatents: Array broadcastable against a (1, 10, 1) coefficient
            array, e.g. shape (10, dim) or (batch, 10, dim).
        truncation_psi: Interpolation weight. 1.0 leaves the dlatents
            untouched, 0.0 collapses the truncated layers onto the average.
        maxlayer: Number of leading layers (out of 10) to truncate.

    Returns:
        NumPy array holding the truncated dlatents. The result always has a
        leading batch axis because the coefficients are shaped (1, 10, 1).
    """
    # The running average is a network variable, so it still has to be
    # fetched through the TensorFlow session.
    dlatent_avg = tf.get_default_session().run(Gs.own_vars["dlatent_avg"])

    # Per-layer interpolation weight: truncation_psi below maxlayer, else 1.
    layer_idx = np.arange(10)[np.newaxis, :, np.newaxis]
    ones = np.ones(layer_idx.shape, dtype=np.float32)
    coefs = np.where(layer_idx < maxlayer, truncation_psi * ones, ones)

    # Same formula as tflib.lerp (a + (b - a) * t), evaluated in NumPy.
    # Building tf.where / lerp ops here would add new nodes to the default
    # graph on every call, i.e. on every slider change in the UI.
    return dlatent_avg + (dlatents - dlatent_avg) * coefs

# Render an image from disentangled latents (dlatents).
def generate_images_from_dlatents(dlatents, truncation_psi = 1.0, randomize_noise = True):
    """Run the synthesis network on ``dlatents`` and return the first image.

    Args:
        dlatents: Disentangled latents; reshaped to (-1, 10, 512) before
            being fed to the synthesis network.
        truncation_psi: Weight handed to ``truncate`` before synthesis.
            Pass ``None`` to skip truncation altogether.
        randomize_noise: Forwarded unchanged to the synthesis network.

    Returns:
        The first image of the synthesised batch, after conversion by
        ``tflib.convert_images_to_uint8`` with ``nchw_to_nhwc=True``.
    """
    synthesis_input = (
        dlatents if truncation_psi is None else truncate(dlatents, truncation_psi)
    )

    # Ask the network to post-process its raw output into uint8 images.
    output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.components.synthesis.run(
        synthesis_input.reshape((-1, 10, 512)),
        randomize_noise=randomize_noise,
        minibatch_size=1,
        output_transform=output_transform,
    )
    return images[0]

In [7]:
##
# UI
##
# The code of this UI is referenced from https://colab.research.google.com/github/halcy/AnimeFaceNotebooks/blob/master/colab/Stylegan2_Playground.ipynb
# When True, only tags containing "_hair", "_eyes" or "_mouth" get a slider.
hair_eyes_only = False

# Euclidean length of every (flattened) tag direction; used to normalise
# slider offsets and the projected values shown next to each slider.
tag_len = {
    tag: np.linalg.norm(tag_directions[tag].flatten())
    for tag in tag_directions
}

# Working copy of the random latents, and the dlatents that the mapping
# network produces for them (first entry of the batch).
mod_latents = copy.deepcopy(latents_a)
dlatents_gen = Gs.components.mapping.run(mod_latents, None)[0]
def modify_and_sample(psi_pre, psi_post, truncate_pre, truncate_post, **kwargs):
    """Apply the slider offsets to the current dlatents and display the image.

    Args:
        psi_pre: Truncation weight applied before the tag offsets.
        psi_post: Truncation weight applied when the image is generated.
        truncate_pre: Whether to truncate before adding the tag offsets.
        truncate_post: Whether to truncate when generating the image.
        **kwargs: Slider value per tag name; each tag must be a key of
            ``tag_directions``, ``tag_len`` and ``value_widgets``.
    """
    base = copy.deepcopy(dlatents_gen)
    dlatents_mod = truncate(base, psi_pre) if truncate_pre == True else base

    # Move along every tag direction, scaled to unit length times 25.
    for tag, strength in kwargs.items():
        dlatents_mod += tag_directions[tag] * strength / tag_len[tag] * 25.0

    # Echo the rounded truncation weights next to their sliders.
    for name, psi in (("psi_pre", psi_pre), ("psi_post", psi_post)):
        value_widgets[name].value = str(round(psi, 2))

    # Show "<slider value> | <projection of the result onto the tag direction>".
    flat_dlatents = dlatents_mod.flatten()
    for tag, strength in kwargs.items():
        projection = np.dot(flat_dlatents, tag_directions[tag].flatten()) / tag_len[tag]
        value_widgets[tag].value = " | ".join((str(strength), str(round(projection, 2))))

    display_psi = psi_post if truncate_post == True else None
    pixels = generate_images_from_dlatents(dlatents_mod, truncation_psi=display_psi)
    display(PIL.Image.fromarray(pixels, 'RGB').resize(size=(128, 128)))

# Truncation sliders (pre- and post-modification); values are shown in a
# separate label, hence readout=False.
psi_slider = widgets.FloatSlider(
    min=0.0, max=1.0, step=0.01, value=0.5,
    continuous_update=False, readout=False,
)
psi_slider_post = widgets.FloatSlider(
    min=0.0, max=1.0, step=0.01, value=0.5,
    continuous_update=False, readout=False,
)

# Decide which tags get a slider: either a substring filter over all known
# directions, or the pickled list that ships with the tag directions.
if hair_eyes_only:
    modify_tags = [
        tag for tag in tag_directions
        if any(part in tag for part in ("_hair", "_eyes", "_mouth"))
    ]
else:
    with open(tag_use, "rb") as tag_file:
        modify_tags = pickle.load(tag_file)

tag_widgets = {
    tag: widgets.FloatSlider(
        min=-2.0, max=10.0, step=0.005,
        continuous_update=False, readout=False,
    )
    for tag in modify_tags
}

# Truncation sliders first, then the tag sliders in alphabetical order.
sorted_widgets = [("psi_pre", psi_slider), ("psi_post", psi_slider_post)]
sorted_widgets += sorted(tag_widgets.items(), key=lambda item: item[0])

# One row per slider: name label, the slider itself, and a value label.
value_widgets = {}
all_widgets = []
for name, slider in sorted_widgets:
    label_widget = widgets.Label(name, layout=widgets.Layout(width="170px"))
    value_widget = widgets.Label("0.0+100.0", layout=widgets.Layout(width="150px"))
    value_widgets[name] = value_widget

    tag_hbox = widgets.HBox(
        [label_widget, slider, value_widget],
        layout=widgets.Layout(width="320px"),
    )
    all_widgets.append(tag_hbox)

refresh = widgets.Button(description="New Sample")
reset_sliders = widgets.Button(description="Reset")

def new_sample(b):
    """Button callback: draw fresh latents and trigger a redraw.

    Args:
        b: The clicked button (unused).
    """
    global mod_latents, dlatents_gen
    mod_latents = np.random.randn(1, Gs.input_shape[1])
    dlatents_gen = Gs.components.mapping.run(mod_latents, None)[0]

    # Nudge the slider by a negligible amount so its value changes; step
    # downwards only when it already sits at 1.0, the slider's maximum.
    nudge = 1e-11
    if psi_slider.value == 1.0:
        psi_slider.value -= nudge
    else:
        psi_slider.value += nudge
    
def reset(b):
    """Button callback: set every tag slider back to 0.0.

    The truncation sliders and toggle buttons also live in ``tag_widgets``
    and are left untouched.

    Args:
        b: The clicked button (unused).
    """
    control_names = ("truncate_pre", "truncate_post", "psi_pre", "psi_post")
    for name, control in tag_widgets.items():
        if name in control_names:
            continue
        control.value = 0.0
            
# Toggles that switch the two truncation steps on and off.
truncate_pre = widgets.ToggleButton(value=True, description='Truncate Pre')
truncate_post = widgets.ToggleButton(value=True, description='Truncate Post')

# Hook the buttons up to their callbacks.
refresh.on_click(new_sample)
reset_sliders.on_click(reset)

# Lay out all slider rows followed by the buttons in a wrapping row.
ui = widgets.Box(all_widgets + [refresh, reset_sliders, truncate_pre, truncate_post])
ui.layout.flex_flow = 'row wrap'
ui.layout.display = 'inline-flex'

# The truncation controls become arguments of modify_and_sample as well, so
# they are registered alongside the tag sliders.
tag_widgets.update({
    "psi_pre": psi_slider,
    "psi_post": psi_slider_post,
    "truncate_pre": truncate_pre,
    "truncate_post": truncate_post,
})

# Re-run modify_and_sample whenever any registered control changes.
out = widgets.interactive_output(modify_and_sample, tag_widgets)

In [8]:
# Show the control panel followed by the output area of modify_and_sample.
display(ui, out)

Box(children=(HBox(children=(Label(value='psi_pre', layout=Layout(width='170px')), FloatSlider(value=0.5, cont…

Output()