# Import/Setup

In [3]:
import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config

In [4]:
url_ffhq        = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
url_celebahq    = 'https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf' # karras2019stylegan-celebahq-1024x1024.pkl
url_bedrooms    = 'https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF' # karras2019stylegan-bedrooms-256x256.pkl
url_cars        = 'https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3' # karras2019stylegan-cars-512x384.pkl
url_cats        = 'https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ' # karras2019stylegan-cats-256x256.pkl

# File name of a pre-downloaded copy of each pickle, looked up relative to the
# current working directory before falling back to a download.
_local_pkl_names = {
    url_ffhq:     'karras2019stylegan-ffhq-1024x1024.pkl',
    url_celebahq: 'karras2019stylegan-celebahq-1024x1024.pkl',
    url_bedrooms: 'karras2019stylegan-bedrooms-256x256.pkl',
    url_cars:     'karras2019stylegan-cars-512x384.pkl',
    url_cats:     'karras2019stylegan-cats-256x256.pkl',
}

# Shared keyword arguments for every Gs.components.synthesis.run call.
synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8)

# url -> unpickled Gs, so each network is loaded at most once per kernel.
_Gs_cache = dict()

def load_Gs(url):
    """Return the generator ``Gs`` stored in the pickle for ``url``, memoized.

    If the matching file from ``_local_pkl_names`` exists in the current
    working directory it is read directly. Otherwise the pickle is fetched
    with ``dnnlib.util.open_url`` using ``config.cache_dir`` as its cache.
    The unpickled ``Gs`` is stored in ``_Gs_cache`` under ``url``.

    NOTE: ``pickle.load`` can execute arbitrary code -- only load pickles
    from a trusted source.
    """
    if url not in _Gs_cache:
        local_path = _local_pkl_names.get(url)
        if local_path is not None and os.path.isfile(local_path):
            stream = open(local_path, "rb")
        else:
            stream = dnnlib.util.open_url(url, cache_dir=config.cache_dir)
        with stream as f:
            _G, _D, Gs = pickle.load(f)
        _Gs_cache[url] = Gs
    return _Gs_cache[url]

In [5]:
def draw_truncation_trick_figure(png, Gs, w, h, seeds, psis, style_ranges):
    """Render a truncation-trick x style-mixing grid for two seeds and save it to ``png``.

    Each seed's latent is mapped to a dlatent, then pulled towards
    ``Gs.get_var('dlatent_avg')`` by every factor in ``psis``:
    ``(dlatent - dlatent_avg) * psi + dlatent_avg``.

    Canvas layout (``len(psis) + 1`` tiles square, top-left tile left white):
      * top row:     ``seeds[0]`` at each psi
      * left column: ``seeds[1]`` at each psi
      * cell (row, col): ``seeds[1]`` at ``psis[row]``, with the layers in
        ``style_ranges[row]`` replaced by ``seeds[0]`` at ``psis[col]``

    Args:
        png: output path for the assembled canvas.
        Gs: generator network whose ``mapping``/``synthesis`` components are run.
        w, h: width and height in pixels of one tile on the canvas.
        seeds: exactly two ``np.random.RandomState`` seeds (top-row source,
            left-column source).
        psis: sequence of truncation factors; one grid column/row per entry.
        style_ranges: per-row index of dlatent layers (e.g. a ``range``) copied
            from the top-row dlatents; needs at least ``len(psis)`` entries.

    Raises:
        ValueError: if ``seeds`` does not contain exactly two seeds.

    Side effects: prints each row index as progress and also writes every grid
    tile to ``config.result_dir/<row>-<col>.png``.
    """
    if len(seeds) != 2:
        raise ValueError('draw_truncation_trick_figure needs exactly 2 seeds, got %d' % len(seeds))

    # np.stack requires a sequence; passing a generator is deprecated/rejected by NumPy.
    latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in seeds])
    dlatents = Gs.components.mapping.run(latents, None) # [seed, layer, component]
    dlatent_avg = Gs.get_var('dlatent_avg') # [component]

    # Truncation trick: one dlatent per psi for each seed -> [psi, layer, component].
    psi_column = np.reshape(psis, [-1, 1, 1])
    src_dlatents, dst_dlatents = [(dlatent[np.newaxis] - dlatent_avg) * psi_column + dlatent_avg
                                  for dlatent in dlatents]

    src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
    dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)

    n_psis = len(psis)
    canvas = PIL.Image.new('RGB', (w * (n_psis + 1), h * (n_psis + 1)), 'white')
    for col, src_image in enumerate(src_images):
        canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0))
    for row, dst_image in enumerate(dst_images):
        canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h))
        # Start every cell in this row from the left-column dlatent, then copy the
        # selected layers from the matching top-row dlatent.
        row_dlatents = np.stack([dst_dlatents[row]] * n_psis)
        row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]]
        row_images = Gs.components.synthesis.run(row_dlatents, randomize_noise=False, **synthesis_kwargs)
        print(row)
        for col, image in enumerate(row_images):
            tile = PIL.Image.fromarray(image, 'RGB')
            tile.save(os.path.join(config.result_dir, "%d-%d.png" % (row, col)))
            canvas.paste(tile, ((col + 1) * w, (row + 1) * h))
    canvas.save(png)


In [7]:
tflib.init_tf()

# Truncation factors: one grid row and one grid column per value.
PSIS = np.linspace(-1, 1, 18)
# Dlatent layers copied from the top-row image into every grid cell (layers 0-17).
STYLE_RANGE = range(0, 18)

Gs_ffhq = load_Gs(url_ffhq)
# style_ranges needs one entry per row, so size it from PSIS rather than a second literal.
draw_truncation_trick_figure(os.path.join(config.result_dir, 'testing.png'), Gs_ffhq,
                             w=1024, h=1024, seeds=[91, 388], psis=PSIS,
                             style_ranges=[STYLE_RANGE] * len(PSIS))


0
1
2
3
4
5
6
7
8
9
10
11
12
