In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every bare expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

import os
import tensorflow as tf
import numpy as np
from numpy.random import randn, randint
import matplotlib.pyplot as plt
from matplotlib import pyplot
from IPython import display

SEED = 7091998
tf.random.set_seed(SEED)
# Latent vectors and real-sample indices are drawn with numpy's global RNG
# (randn / randint imported above), so seed it too; seeding TF alone does not
# make those draws reproducible.
np.random.seed(SEED)

# Get current working directory
cwd = os.getcwd()

In [None]:
# Mount Google Drive: the dataset archive is copied from it, and model
# checkpoints / sample plots are written back to it later in the notebook.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
from tensorflow.keras.layers import Dense, BatchNormalization, Conv2DTranspose, LeakyReLU, Reshape, Conv2D, LeakyReLU, Dropout, MaxPool2D, GlobalAveragePooling2D, Flatten, Activation, BatchNormalization, UpSampling2D, Input
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import Model
from PIL import Image
from keras.initializers import RandomNormal
from datetime import datetime
import pickle

In [None]:
# Copy the dataset archive from Drive to local disk, then extract it into
# /content/data/ quietly (-q), overwriting existing files without prompting (-o).
!cp "/content/drive/My Drive/datasets/std_dataset.zip" "/content/std_dataset.zip"
!unzip -oq "/content/std_dataset.zip" -d "/content/data/"

## Hyperparameters

In [None]:
# Generator seed feature map: the latent vector is projected to in_h x in_w x start_f.
in_h, in_w = 4, 4
img_h, img_w = 64, 64
start_f = 256
channels = 3
input_shape = (in_h, in_w, start_f)
img_shape = (img_h, img_w, channels)
latent_dim = 100

# Dropout rates. NOTE(review): dropout_g does not appear to be used by the generator builder.
dropout_g = 0.4
dropout_d = 0.2

# The generator upsamples net_depth + 1 times; the critic downsamples net_depth times.
net_depth = 3
leaky_relu_slope = 0.18
# Convolutions per up/down-sampling stage.
n_conv = 1

# Kernel, stride and pooling sizes.
g_kernel_size = (5, 5)
c_kernel_size = (5, 5)
strides_size = (2, 2)
pool_size = (2, 2)

# WGAN weight clipping: critic weights are kept in [-clip_value, clip_value].
clip_value = 0.01
# Critic updates per generator update.
n_critic = 7

# Sanity check: net_depth + 1 x2 upsamplings of the seed map must reach the image size.
assert in_h * 2 ** (net_depth + 1) == img_h and in_w * 2 ** (net_depth + 1) == img_w

## Data preparation

In [None]:
# Image folder extracted by the unzip step (presumably the archive's top-level
# folder — confirm if the zip layout changes).
dataset_dir = "/content/data/std_dataset"

In [None]:
class CustomDataset(tf.keras.utils.Sequence):
  """Images in a flat directory, loaded one at a time as float32 RGB arrays in [-1, 1].

  NOTE: __getitem__ returns a single image, not a batch, so this is not a
  conventional keras Sequence; in this notebook it is only consumed via getnparr().

  Args:
    dataset_dir: directory containing the image files.
    img_size: (width, height) passed to PIL's Image.resize; defaults to (64, 64).
  """

  def __init__(self, dataset_dir, img_size=(64, 64)):
    super().__init__()
    self.dataset_dir = dataset_dir
    self.img_size = img_size
    # os.listdir order is arbitrary; sort so the index -> image mapping (and
    # therefore seeded sampling) is reproducible. Subdirectories are skipped.
    self.img_list = sorted(
        name for name in os.listdir(dataset_dir)
        if os.path.isfile(os.path.join(dataset_dir, name)))

  def __len__(self):
    return len(self.img_list)

  def __getitem__(self, index):
    path = os.path.join(self.dataset_dir, self.img_list[index])
    # The context manager closes the underlying file once the pixels are loaded.
    with Image.open(path) as img:
      img = img.convert('RGB').resize(self.img_size)
    img_arr = np.asarray(img, dtype=np.float32)
    # Normalize to [-1, 1]
    return (img_arr - 127.5) / 127.5

  def getnparr(self):
    """Load every image and stack them into one float32 array."""
    return np.asarray([self[i] for i in range(len(self))])

# Load the whole dataset into memory: one float32 array of 64x64 RGB images in [-1, 1].
dataset = CustomDataset(dataset_dir).getnparr()

In [None]:
def generate_real_samples(dataset, n_samples):
  """Draw n_samples random real images (with replacement) and their critic labels.

  Labels are -1 for 'real', the Wasserstein convention used by the training
  helpers further down (real = -1, fake = +1). With
  wasserstein_loss = mean(y_true * y_pred), a +1 'real' label would push real
  and fake scores in the same direction.

  Returns:
    X: dataset[ix] for n_samples random indices.
    y: float array of shape (n_samples, 1) filled with -1.
  """
  # choose random instances (independent draws, i.e. with replacement)
  ix = randint(0, len(dataset), n_samples)
  # retrieve selected images
  X = dataset[ix]
  # 'real' class labels (-1)
  y = -np.ones((n_samples, 1))
  return X, y

## Custom classes and functions

In [None]:
import keras.backend as K
from keras.constraints import Constraint

def wasserstein_loss(y_true, y_pred):
    """Wasserstein loss: mean(y_true * y_pred).

    With the labels used for training (-1 real, +1 fake), minimising this
    raises critic scores on real images and lowers them on fakes; the
    generator is trained with -1 on its fakes to raise their scores.
    """
    # tf.reduce_mean instead of keras.backend.mean: same reduction over all
    # elements, but the latter is not exported by keras.backend in Keras 3.
    return tf.reduce_mean(y_true * y_pred)

# NOTE(review): this module-level optimizer does not appear to be used; the
# critic and the WGAN stack each build their own RMSprop.
# Use `learning_rate=`: TF 2.11+ optimizers ignore `lr` with only a warning
# (silently training at the default rate) and Keras 3 rejects it.
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.00005)

# clip model weights to a given hypercube
class ClipConstraint(Constraint):
    """Weight constraint clipping every element to [-clip_value, clip_value].

    This is the weight-clipping form of the WGAN critic constraint; Keras
    applies a layer's kernel constraint after each gradient update.
    """

    def __init__(self, clip_value):
        # Half-width of the allowed weight interval.
        self.clip_value = clip_value

    def __call__(self, weights):
        # tf.clip_by_value instead of keras.backend.clip, which Keras 3's
        # keras.backend no longer exports.
        return tf.clip_by_value(weights, -self.clip_value, self.clip_value)

    def get_config(self):
        # Lets load_model rebuild the constraint (passed via custom_objects).
        return {'clip_value': self.clip_value}

## Generator

In [None]:
def make_generator_model(batchnorm=True):
    """Build the (uncompiled) generator; it is trained through the stacked WGAN model.

    Projects a latent vector of size `latent_dim` to an (in_h, in_w, start_f)
    map, then applies net_depth + 1 stages of
    UpSampling2D -> n_conv x [Conv2D -> optional BatchNorm -> ReLU],
    halving the filter count each stage, and ends with a tanh Conv2D that
    produces `channels` channels in [-1, 1].

    Args:
        batchnorm: insert BatchNormalization(momentum=0.8) after each hidden conv.

    Returns:
        An uncompiled tf.keras.Sequential named 'GenToon'.
    """
    model = tf.keras.Sequential(name='GenToon')
    n_nodes = in_h * in_w * start_f
    init = RandomNormal(stddev=0.02)

    # Input and reshape
    model.add(Dense(n_nodes, input_dim=latent_dim))
    model.add(Reshape((in_h, in_w, start_f)))

    # Upsampling: each stage doubles H and W and halves the filter count.
    for i in range(net_depth + 1):
        model.add(UpSampling2D())
        # Integer division: Conv2D expects an int filter count, `/` gave a float.
        filters = start_f // 2 ** (i + 1)
        for _ in range(n_conv):
            model.add(Conv2D(filters=filters, kernel_size=g_kernel_size, padding='same', kernel_initializer=init))
            if batchnorm:
                model.add(BatchNormalization(momentum=0.8))
            model.add(Activation('relu'))

    # Output layer: one filter per image channel; tanh matches the [-1, 1] data scaling.
    model.add(Conv2D(filters=channels, kernel_size=(3, 3), activation='tanh', padding='same', kernel_initializer=init))

    return model

# Build a throwaway generator only to print its layer summary.
make_generator_model().summary()

In [None]:
# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
  """Return an (n_samples, latent_dim) float array of standard-normal latent vectors."""
  # randn(rows, cols) draws the same values, in the same order, as
  # randn(rows * cols).reshape(rows, cols).
  return randn(n_samples, latent_dim)

# use the generator to generate n fake examples, with class labels
def generate_fake_samples(g_model, latent_dim, n_samples):
  """Generate n_samples images with g_model, labelled +1 ('fake').

  +1 matches the Wasserstein convention of the training helpers (real = -1,
  fake = +1). A 0 label would make wasserstein_loss = mean(y_true * y_pred)
  identically zero on fakes, giving the critic no signal from them.

  Returns:
    X: g_model's predictions for n_samples random latent vectors.
    y: float array of shape (n_samples, 1) filled with 1.
  """
  # generate points in latent space
  x_input = generate_latent_points(latent_dim, n_samples)
  # verbose=0: called many times per training step; recent Keras versions
  # otherwise print a progress bar on every predict call.
  X = g_model.predict(x_input, verbose=0)
  # 'fake' class labels (+1)
  y = np.ones((n_samples, 1))
  return X, y

## Discriminator / Critic

In [None]:
def make_discriminator_model(batchnorm=False):
  """Build and compile the WGAN critic.

  Takes images of shape `img_shape` and outputs one unbounded score per image
  (linear Dense(1)). Every conv and the final Dense are weight-clipped to
  [-clip_value, clip_value] via ClipConstraint, the weight-clipping form of the
  WGAN Lipschitz constraint.

  Layout: a stride-1 conv stem (+ LeakyReLU + Dropout), then net_depth stages of
  n_conv convs whose first conv uses `strides_size`, doubling filters each stage.

  Args:
    batchnorm: insert BatchNormalization after each stage conv.

  Returns:
    A tf.keras.Sequential named 'ToonCritic', compiled with wasserstein_loss and
    RMSprop(learning_rate=5e-5).
  """
  model = tf.keras.Sequential(name='ToonCritic')
  filters = start_f // 2 ** net_depth

  # weight constraint shared by every trainable kernel
  const = ClipConstraint(clip_value=clip_value)
  init = RandomNormal(stddev=0.02)

  # Input stem
  model.add(Conv2D(filters=filters, kernel_size=c_kernel_size, padding='same', input_shape=img_shape, kernel_initializer=init, kernel_constraint=const))
  model.add(LeakyReLU(alpha=leaky_relu_slope))
  model.add(Dropout(dropout_d))
  filters *= 2

  # Downsampling stages: only the first conv of each stage is strided.
  for _ in range(net_depth):
    for j in range(n_conv):
      strides = strides_size if j == 0 else (1, 1)
      model.add(Conv2D(filters=filters, strides=strides, kernel_size=c_kernel_size, padding='same', kernel_initializer=init, kernel_constraint=const))
      if batchnorm:
        model.add(BatchNormalization())
      model.add(LeakyReLU(alpha=leaky_relu_slope))
    filters *= 2

  # FC part. The output layer is clipped too: left unconstrained, its weights
  # can grow without bound and the critic's Lipschitz constant with them.
  model.add(Flatten())
  model.add(Dense(1, kernel_constraint=const))

  opt = tf.keras.optimizers.RMSprop(learning_rate=0.00005)
  model.compile(loss=wasserstein_loss, optimizer=opt)

  return model

# Build (and compile) a throwaway critic only to print its layer summary.
make_discriminator_model().summary()

## WGAN

In [None]:
def create_wgan(generator, critic, opt_critic_weights=None, opt_wgan_weights=None):
  """Stack generator -> frozen critic into the model used to train the generator.

  Args:
    generator: generator model (first layer of the stack).
    critic: critic model. If `opt_critic_weights` is given it is recompiled with
      that optimizer; otherwise, if it has no optimizer yet (e.g. a checkpoint
      saved without one), it is compiled with a fresh RMSprop.
    opt_critic_weights: optional output of tf.keras.optimizers.serialize for the
      critic's optimizer (despite the name: an optimizer config, not weights).
    opt_wgan_weights: same, for the stacked model's optimizer.

  Returns:
    A tf.keras.Sequential named 'WGAN', compiled with wasserstein_loss.
  """
  # Compile the critic while it is still trainable. In tf.keras (Keras 2) the
  # trainable flag is captured at compile time, so compiling it after the
  # freeze below would stop critic.train_on_batch from updating its weights.
  if opt_critic_weights is not None:
    critic.compile(loss=wasserstein_loss, optimizer=tf.keras.optimizers.deserialize(opt_critic_weights))
  elif getattr(critic, 'optimizer', None) is None:
    critic.compile(loss=wasserstein_loss, optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.00005))

  # Freeze the critic inside the stack so wgan.train_on_batch updates only the generator.
  # NOTE(review): Keras 3 reportedly honours `trainable` at train time rather than
  # compile time, which would also freeze critic.train_on_batch — confirm the Keras version.
  critic.trainable = False

  model = tf.keras.Sequential(name='WGAN')
  model.add(generator)
  model.add(critic)

  # Optional WGAN optimizer config
  if opt_wgan_weights is not None:
    opt = tf.keras.optimizers.deserialize(opt_wgan_weights)
  else:
    opt = tf.keras.optimizers.RMSprop(learning_rate=0.00005)
  model.compile(loss=wasserstein_loss, optimizer=opt)

  return model

In [None]:
# True: build new, untrained models. False: resume from the checkpoint files
# named below (the naming scheme written by plot_and_save into this directory).
fresh_model = True
models_path = '/content/drive/My Drive/Colab Notebooks/wgan_models'
gen_name = 'generator_model_100.h5'
critic_name = 'critic_model_100.h5'
critic_optimizer_name = 'optimizer_critic_100.pkl'
wgan_optimizer_name = 'optimizer_wgan_100.pkl'

if fresh_model:
  print('Creating a fresh model...')
  generator = make_generator_model(True)
  critic = make_discriminator_model(False)

  wgan = create_wgan(generator, critic)
else:
  # Load Generator
  generator_path = os.path.join(models_path, gen_name)
  print('Loading Generator ...')
  generator = tf.keras.models.load_model(generator_path)
  # Load Critic; custom_objects lets Keras rebuild the clip constraint and the custom loss.
  critic_path = os.path.join(models_path, critic_name)
  print('Loading Critic ...')
  critic = tf.keras.models.load_model(critic_path, custom_objects={'ClipConstraint': ClipConstraint, 'wasserstein_loss': wasserstein_loss})
  # Critic optimizer: the pickle holds the output of tf.keras.optimizers.serialize
  # written at checkpoint time (presumably the optimizer config only, not its slot state).
  # NOTE(review): pickle.load can execute arbitrary code; only load files you wrote yourself.
  optimizer_path = os.path.join(models_path, critic_optimizer_name)
  print('Loading optimizer for Critic ...')
  with open(optimizer_path, 'rb') as f:
    critic_optimizer = pickle.load(f)
  # WGAN optimizer (same pickle format as the critic's)
  optimizer_path = os.path.join(models_path, wgan_optimizer_name)
  print('Loading optimizer for WGAN ...')
  with open(optimizer_path, 'rb') as f:
    wgan_optimizer = pickle.load(f)
  
  wgan = create_wgan(generator, critic, opt_critic_weights=critic_optimizer, opt_wgan_weights=wgan_optimizer)

wgan.summary()

## Train the model

In [None]:
# create and save a plot of generated images
def save_plot(examples, epoch, n=3, out_dir='/content/drive/My Drive/Colab Notebooks/WGAN_results'):
  """Save an n x n grid of generated images as generated_plot_e<epoch+1>.png.

  Args:
    examples: generated images in [-1, 1] (tanh range), indexed along axis 0.
      At most n*n are drawn; missing cells are left blank.
    epoch: zero-based epoch index; the file name uses epoch+1, zero-padded to 3 digits.
    n: grid side length.
    out_dir: destination directory, created if missing.
  """
  # scale from [-1,1] to [0,1]; clip guards imshow against tiny overshoots
  examples = np.clip((examples + 1) / 2.0, 0.0, 1.0)
  # Own figure (not the implicit current one) so it can be closed explicitly.
  fig, axes = pyplot.subplots(n, n, squeeze=False)
  for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i < len(examples):
      ax.imshow(examples[i])
  filename = 'generated_plot_e%03d.png' % (epoch+1)
  os.makedirs(out_dir, exist_ok=True)
  fig.savefig(os.path.join(out_dir, filename))
  pyplot.close(fig)

In [None]:
base_path = '/content/drive/My Drive/Colab Notebooks/wgan_models'


def _pickle_optimizer_config(optimizer, filename):
  """Pickle tf.keras.optimizers.serialize(optimizer) to base_path/filename."""
  # NOTE(review): serialize() captures the optimizer's class and config; it
  # presumably does not include accumulated slot variables, so a resumed run
  # restarts those from scratch — confirm for the TF version in use.
  with open(os.path.join(base_path, filename), 'wb') as f:
    pickle.dump(tf.keras.optimizers.serialize(optimizer), f)


def plot_and_save(epoch, g_model, d_model, wgan, dataset, latent_dim, n_samples=9):
  """Save a sample grid plus generator/critic checkpoints and optimizer configs.

  Checkpoint files go to `base_path` and are suffixed with epoch+1 (3 digits).

  Args:
    epoch: zero-based epoch index.
    g_model: generator; saved with its full model file.
    d_model: critic; saved without optimizer, its optimizer config pickled separately.
    wgan: stacked model; only its optimizer config is pickled.
    dataset: unused, kept for call compatibility.
    latent_dim: latent vector size fed to the generator.
    n_samples: number of generated images drawn for the sample grid.
  """
  os.makedirs(base_path, exist_ok=True)

  # Sample grid (labels are not needed here)
  x_fake, _ = generate_fake_samples(g_model, latent_dim, n_samples)
  save_plot(x_fake, epoch)

  # Generator model
  g_model.save(os.path.join(base_path, 'generator_model_%03d.h5' % (epoch+1)))

  # Critic model
  d_model.save(os.path.join(base_path, 'critic_model_%03d.h5' % (epoch+1)), include_optimizer=False)

  # Optimizer configs. Use the critic passed in (d_model), not a global.
  _pickle_optimizer_config(d_model.optimizer, 'optimizer_critic_%03d.pkl' % (epoch+1))
  _pickle_optimizer_config(wgan.optimizer, 'optimizer_wgan_%03d.pkl' % (epoch+1))

In [None]:
# Functions for training

def generate_real_samples(dataset, n_samples):
  """Pick n_samples random rows of `dataset` (with replacement), labelled -1 ('real')."""
  picks = randint(0, dataset.shape[0], n_samples)
  labels = np.full((n_samples, 1), -1.0)
  return dataset[picks], labels

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
  """Sample latent vectors: float array (n_samples, latent_dim), entries ~ N(0, 1)."""
  total = n_samples * latent_dim
  flat = randn(total)
  return flat.reshape((n_samples, latent_dim))

# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
  """Generate n_samples images with `generator`, labelled +1 ('fake').

  Returns:
    X: the generator's predictions for n_samples random latent vectors.
    y: float array of shape (n_samples, 1) filled with 1.
  """
  # generate points in latent space
  x_input = generate_latent_points(latent_dim, n_samples)
  # verbose=0: train() calls this n_critic times per step; recent Keras
  # versions otherwise print a progress bar on every predict call.
  X = generator.predict(x_input, verbose=0)
  # create class labels with 1.0 for 'fake'
  y = np.ones((n_samples, 1))
  return X, y

In [None]:
# train the generator and critic
def train(dataset, critic, generator, wgan, epochs, batch_size, start_epoch=0, save_every=10):
  """Run the WGAN training loop, printing per-epoch mean losses.

  Each step performs `n_critic` (global) critic updates, each on
  batch_size // 2 real images (label -1) and batch_size // 2 generated images
  (label +1), followed by one generator update through `wgan` on batch_size
  latent vectors labelled -1.

  Printed values are the raw losses averaged over the epoch:
  critic_loss = (mean D(fake) - mean D(real)) / 2, so more negative means a
  larger real/fake score gap; generator_loss = -mean D(G(z)).

  Args:
    dataset: array of real images indexed along axis 0.
    critic: compiled critic model.
    generator: generator used to produce fake images.
    wgan: compiled generator+critic stack used to update the generator.
    epochs: exclusive upper bound of the epoch index.
    batch_size: images per generator update; must be >= 2 and <= len(dataset).
    start_epoch: first epoch index (use when resuming from a checkpoint).
    save_every: call plot_and_save every this many epochs (0 or None disables).

  Raises:
    ValueError: if batch_size is smaller than 2 or larger than the dataset.
  """
  # Loop-invariant: depend only on dataset size and batch size.
  steps_per_epoch = dataset.shape[0] // batch_size
  half_batch = batch_size // 2
  if steps_per_epoch == 0 or half_batch == 0:
    raise ValueError('batch_size must be between 2 and %d, got %d' % (dataset.shape[0], batch_size))

  for epoch in range(start_epoch, epochs):
    start_time = datetime.now()
    c_losses, g_losses = [], []
    for _ in range(steps_per_epoch):

      # CRITIC TRAIN: n_critic updates per generator update
      for _ in range(n_critic):
        X_real, y_real = generate_real_samples(dataset, half_batch)
        c_loss_real = critic.train_on_batch(X_real, y_real)
        X_fake, y_fake = generate_fake_samples(generator, latent_dim, half_batch)
        c_loss_fake = critic.train_on_batch(X_fake, y_fake)
        c_losses.append((float(c_loss_real) + float(c_loss_fake)) / 2)

      # GENERATOR TRAIN: label fakes as 'real' (-1) so the generator raises their critic score
      X_gan = generate_latent_points(latent_dim, batch_size)
      y_gan = -np.ones((batch_size, 1))
      g_losses.append(float(wgan.train_on_batch(X_gan, y_gan)))

    # Summarize epoch results
    seconds = (datetime.now() - start_time).total_seconds()
    print('Epoch %d >> [critic_loss: %.2f] [generator_loss: %.2f] [time: %d]' % (epoch+1, np.mean(c_losses), np.mean(g_losses), seconds))

    if save_every and (epoch+1) % save_every == 0:
      plot_and_save(epoch, generator, critic, wgan, dataset, latent_dim)

In [None]:
# train model
# EPOCHS is the exclusive upper bound of the epoch index; when resuming, also
# pass start_epoch to train() so checkpoint/plot numbering (epoch+1) continues.
EPOCHS = 500
# Batch size: each critic update uses BS // 2 real and BS // 2 generated images.
BS = 256
train(dataset, critic, generator, wgan, EPOCHS, BS)

Epoch 1 >> [critic_loss: -48.39] [generator_loss: 34525.77] [time: 119]
Epoch 2 >> [critic_loss: -12.35] [generator_loss: 2147.63] [time: 120]
Epoch 3 >> [critic_loss: 4268.64] [generator_loss: -374.62] [time: 121]
Epoch 4 >> [critic_loss: 3669.42] [generator_loss: 22813.02] [time: 119]
Epoch 5 >> [critic_loss: 691.72] [generator_loss: 25963.17] [time: 119]
Epoch 6 >> [critic_loss: 2067.01] [generator_loss: 4736.51] [time: 119]
Epoch 7 >> [critic_loss: 2743.16] [generator_loss: 7541.40] [time: 118]
Epoch 8 >> [critic_loss: 711.56] [generator_loss: 10429.52] [time: 118]
Epoch 9 >> [critic_loss: -31.39] [generator_loss: 3048.88] [time: 118]
Epoch 10 >> [critic_loss: 1192.18] [generator_loss: 3188.65] [time: 118]
Epoch 11 >> [critic_loss: 1605.44] [generator_loss: 3419.38] [time: 118]
Epoch 12 >> [critic_loss: 356.66] [generator_loss: 1274.45] [time: 118]
Epoch 13 >> [critic_loss: 93.75] [generator_loss: 1303.90] [time: 118]
Epoch 14 >> [critic_loss: 221.53] [generator_loss: 538.80] [time

## Test time: sample images from a saved generator


In [None]:
gen_name = 'generator_model_160.h5'
n = 4  # show an n x n grid of samples

# The generator itself is never compiled in this notebook (only the critic and
# the stacked WGAN are), so skip compilation when loading it.
test_g = tf.keras.models.load_model(os.path.join(models_path, gen_name), compile=False)
samples, _ = generate_fake_samples(test_g, latent_dim, n * n)
# Scale the generator's tanh output from [-1, 1] to [0, 1] for imshow.
samples = np.clip((samples + 1) / 2.0, 0.0, 1.0)

# plot images; squeeze=False keeps `axes` 2-D even when n == 1
fig, axes = pyplot.subplots(n, n, squeeze=False, figsize=(2 * n, 2 * n))
for ax, img in zip(axes.flat, samples):
  ax.imshow(img)
  ax.axis('off')
pyplot.show()