Notebook to try out progressive GANs.

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load IPython's autoreload extension and enable mode 2, which re-imports
# modules before each cell runs so edits to local packages are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Alternative inline figure formats (disabled); uncomment one to switch.
#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import scipy.stats as stats

import kbrgan
import kbrgan.glo as glo

In [None]:
# Global matplotlib styling for every figure drawn in this notebook.
# Only the font size is overridden; family and weight stay at their defaults.
font = dict(size=18)

matplotlib.rc('font', **font)
matplotlib.rc('lines', linewidth=2)

# Use font type 42 for both PDF and PostScript output
# (per the matplotlib rcParams docs, 42 selects TrueType instead of Type 3).
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

## Load a progressive GAN model

In [None]:
import pickle
import numpy as np
import tensorflow as tf
import PIL.Image

# Initialize the TensorFlow session.
# An InteractiveSession installs itself as the default session on construction,
# so only create one when no default session exists yet. This keeps the cell
# idempotent: re-running it reuses the existing session instead of stacking a
# new one on every execution.
if tf.get_default_session() is None:
    tf.InteractiveSession()

# Select the pre-trained progressive GAN snapshot to load.
# Alternative snapshots (swap in one of these paths to try them):
# model_path = glo.prob_model_folder('progan', 'karras2018iclr-celebahq-1024x1024.pkl')
# model_path = glo.prob_model_folder('progan', 'karras2018iclr-lsun-airplane-256x256.pkl')
# model_path = glo.prob_model_folder('progan', 'karras2018iclr-lsun-livingroom-256x256.pkl')

# `fname_prefix` names the LSUN category in the snapshot file name and is
# reused as the prefix of the PNG files written at the end of the notebook.
fname_prefix = 'churchoutdoor'
model_path = glo.prob_model_folder('progan', 'karras2018iclr-lsun-{}-256x256.pkl'.format(fname_prefix))

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load snapshots obtained from a trusted source.
with open(model_path, 'rb') as file:
    G, D, Gs = pickle.load(file)
    # G = Instantaneous snapshot of the generator, mainly useful for resuming a previous training run.
    # D = Instantaneous snapshot of the discriminator, mainly useful for resuming a previous training run.
    # Gs = Long-term average of the generator, yielding higher-quality results than the instantaneous snapshot.

In [None]:
# Display the instantaneous generator snapshot `G` as this cell's output.
G

In [None]:
# Draw 10 standard-normal latent vectors from a fixed-seed RandomState so the
# same latents (and hence the same images) are produced on every run.
# Each latent takes the shape of the generator's first input, minus its leading dimension.
latent_dims = Gs.input_shapes[0][1:]
latents = np.random.RandomState(2005).randn(10, *latent_dims)
# Hand-picked top-10 selection (disabled); its indices go up to 887, so it
# only applies when 1000 latents are drawn above.
# latents = latents[[477, 56, 83, 887, 583, 391, 86, 340, 341, 415]]

# Build all-zero dummy labels, one row per latent (not used by the official networks).
label_dims = Gs.input_shapes[1][1:]
labels = np.zeros([latents.shape[0]] + label_dims)

In [None]:
# Run the generator to produce a set of images, one per latent vector.
# `Gs` is the long-term-average generator loaded from the pickle.
# NOTE(review): the result is presumably a float array in NCHW layout with
# values in [-1, 1], since the conversion step later in this notebook maps
# [-1,1] => [0,255] and NCHW => NHWC — confirm against Gs.run.
images = Gs.run(latents, labels)

In [None]:
# Convert images to PIL-compatible format.
# The conversion rebinds `images` to its own transformed value, so it is
# guarded on dtype: re-running this cell without re-running the generator cell
# must not rescale and transpose the already-converted uint8 NHWC array again.
if images.dtype != np.uint8:
    images = np.clip(np.rint((images + 1.0) / 2.0 * 255.0), 0.0, 255.0).astype(np.uint8) # [-1,1] => [0,255]
    images = images.transpose(0, 2, 3, 1) # NCHW => NHWC

# Save images as PNG in the current working directory, named
# '<fname_prefix><k>.png' with k counting from 1.
for idx in range(images.shape[0]):
    PIL.Image.fromarray(images[idx], 'RGB').save('{}{}.png'.format(fname_prefix, idx+1))