In [1]:
import os

import ipywidgets as widgets
import numpy as np
import PIL.Image

import dnnlib
import dnnlib.tflib as tflib
import pretrained_networks

# Conditional StyleGAN2 snapshot; only the long-term average generator `Gs` is used below.
network_pkl = 'results/00008-stylegan2-afhq_labeld-2gpu-cond-config-e/network-snapshot-005532.pkl'
_G, _D, Gs = pretrained_networks.load_networks(network_pkl)

# Keyword arguments for every `Gs.components.synthesis.run` call in this notebook:
# convert the output to uint8 NHWC images, use random noise inputs, and run
# `batch_size` samples per minibatch.
batch_size = 1
Gs_syn_kwargs = dnnlib.EasyDict(
    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
    randomize_noise=True,
    minibatch_size=batch_size,
)

Setting up TensorFlow plugin "fused_bias_act.cu": Preprocessing... Loading... Done.
Setting up TensorFlow plugin "upfirdn_2d.cu": Preprocessing... Loading... Done.


In [2]:
def display_sample_conditional(cat, dog, wild, seed, truncation, return_img=False):
    batch_size = 1
    l1 = np.zeros((1,3))
    l1[0][0] = cat
    l1[0][1] = dog
    l1[0][2] = wild

    all_seeds = [seed] * batch_size
    all_z = np.stack([np.random.RandomState(seed).randn(*Gs.input_shape[1:]) for seed in all_seeds])
    all_w = Gs.components.mapping.run(all_z, np.tile(l1, (batch_size, 1))) # [minibatch, layer, component]
    if truncation != 1:
        w_avg = Gs.get_var('dlatent_avg')
        all_w = w_avg + (all_w - w_avg) * truncation # [minibatch, layer, component]
    all_images = Gs.components.synthesis.run(all_w, **Gs_syn_kwargs)
    if return_img:
        return PIL.Image.fromarray(np.median(all_images, axis=0).astype(np.uint8))
    else:
        display(PIL.Image.fromarray(np.median(all_images, axis=0).astype(np.uint8)))

## Conditional generation of animals

In [3]:
# Single-class generation: the dropdown picks one label index, turned into a
# one-hot (cat, dog, wild) vector by display_animal.
animal = widgets.Dropdown(
    options=[('Cat', 0), ('Dog', 1), ('Wild', 2)],  # (shown text, label index)
    value=0,
    description='Animal: '
)

# `seed` and `truncation` are module-level widgets that the later cells also
# place in their UIs, so all three sections share the same controls.
seed = widgets.IntSlider(min=0, max=100000, step=1, value=0, description='Seed: ')
# Truncation psi; a value of 1 disables truncation in display_sample_conditional.
truncation = widgets.FloatSlider(min=0, max=1, step=0.1, value=1, description='Truncation: ')

top_box = widgets.HBox([animal])
bot_box = widgets.HBox([seed, truncation])
ui = widgets.VBox([top_box, bot_box])

def display_animal(animal, seed, truncation):
    """Display the sample for a single class chosen by its label index.

    ``animal`` is 0 (cat), 1 (dog) or 2 (wild); it is expanded into a one-hot
    triple of bools before being handed to display_sample_conditional.
    """
    cat, dog, wild = (animal == label_index for label_index in range(3))
    display_sample_conditional(cat, dog, wild, seed, truncation)

# Re-render the sample whenever one of the three controls changes.
out = widgets.interactive_output(
    display_animal,
    dict(animal=animal, seed=seed, truncation=truncation),
)
display(ui, out)

VBox(children=(HBox(children=(Dropdown(description='Animal: ', options=(('Cat', 0), ('Dog', 1), ('Wild', 2)), …

Output()

## Mixed generation of animals

In [4]:
# One slider per label weight; the weights go to the generator as-is (no normalisation).
cat, dog, wild = (
    widgets.FloatSlider(min=0, max=1, step=0.05, value=start, description=caption)
    for start, caption in ((1, 'Cat: '), (0, 'Dog: '), (0, 'Wild: '))
)

top_box = widgets.HBox([cat, dog, wild])
# `seed` / `truncation` are the widget objects created in the previous cell.
bot_box = widgets.HBox([seed, truncation])
ui = widgets.VBox([top_box, bot_box])

out = widgets.interactive_output(
    display_sample_conditional,
    dict(cat=cat, dog=dog, wild=wild, seed=seed, truncation=truncation),
)
display(ui, out)

VBox(children=(HBox(children=(FloatSlider(value=1.0, description='Cat: ', max=1.0, step=0.05), FloatSlider(val…

Output()

## Transition between labels

In [5]:
# Which pair of labels to interpolate between ('<start>2<end>'), and how far
# along that path to go (0 = start label only, 1 = end label only).
direction = widgets.Dropdown(
    options=['cat2wild', 'cat2dog', 'dog2wild'],
    value='cat2wild',
    description='Direction: '
)
value = widgets.FloatSlider(min=0, max=1, step=0.05, value=1, description='Value: ')


top_box = widgets.HBox([direction, value])
# `seed` / `truncation` are the shared widgets created in the first UI cell.
bot_box = widgets.HBox([seed, truncation])
ui = widgets.VBox([top_box, bot_box])

def display_transition(direction, value, truncation, seed, return_img=False):
    """Render a linear blend between two labels.

    ``direction`` is one of ``'cat2wild'``, ``'cat2dog'`` or ``'dog2wild'``.
    The end label gets weight ``value`` (the slider keeps it in [0, 1]), the
    start label ``1 - value`` and the remaining label 0.

    Returns the PIL image when ``return_img`` is true; otherwise the image is
    displayed and None is returned.

    Raises:
        ValueError: If ``direction`` is not one of the names above.
    """
    # (name, weight -> (cat, dog, wild)) pairs, checked in order with ==.
    transitions = (
        ('cat2wild', lambda t: (1 - t, 0, t)),
        ('cat2dog', lambda t: (1 - t, t, 0)),
        ('dog2wild', lambda t: (0, 1 - t, t)),
    )
    for name, to_weights in transitions:
        if direction == name:
            cat, dog, wild = to_weights(value)
            break
    else:
        raise ValueError('Wrong direction value')

    # display_sample_conditional already returns None when return_img is falsy.
    return display_sample_conditional(cat, dog, wild, seed, truncation, return_img)

# Re-render the blend whenever the direction, weight, seed or truncation changes.
out = widgets.interactive_output(
    display_transition,
    dict(direction=direction, value=value, seed=seed, truncation=truncation),
)

display(ui, out)

VBox(children=(HBox(children=(Dropdown(description='Animal: ', options=('cat2wild', 'cat2dog', 'dog2wild'), va…

Output()

## Save frames for an animation with ImageMagick

In [6]:
# Render the transition currently selected in the widgets above as evenly
# spaced frames (value 0 -> 1) and write them as zero-padded JPEGs.
# NOTE(review): Gs_syn_kwargs.randomize_noise is True, so the noise presumably
# differs per frame and the animation may flicker; confirm whether that is wanted.
NUM_FRAMES = 31
FRAME_DIR = 'animations/3'

# PIL's Image.save does not create missing directories.
os.makedirs(FRAME_DIR, exist_ok=True)

imgs = [display_transition(direction.value, t, truncation.value, seed.value, return_img=True)
        for t in np.linspace(0, 1, NUM_FRAMES)]

for i, im in enumerate(imgs):
    im.save(os.path.join(FRAME_DIR, f'{i:03}.jpg'))

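Assemble the saved frames into a GIF by running ImageMagick in a shell (or prefix the command with `!` in a code cell): `convert -delay 10 -layers optimize animations/3/*.jpg anim3.gif`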