Thank you to the following for the code on which this program is based.

For iWGAN implementation
https://github.com/Shaofanl/Keras-GAN/blob/master/GAN/models/GAN.py

For Layernorm implementation
https://github.com/DingKe/nn_playground/blob/master/layernorm/layer_norm_layers.py

To Francois Chollet for his Keras framework and his book Deep Learning with Python.
Pages 308-311 give some basic GAN code to get started with.

The paper by Ishaan Gulrajani, Faruk Ahmed, Martin Arjovsky, Vincent Dumoulin, Aaron Courville which introduced the iWGAN is here. 
https://arxiv.org/abs/1704.00028

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib.lines as mlines


In [None]:
import keras 
from keras import layers
from keras.layers import Lambda
from keras.layers import Input
from keras.layers import Activation
from keras import backend as K
from keras.constraints import max_norm
import numpy as np
from tqdm import tqdm
import inspect
import contextlib
import time
from time import sleep
from tqdm import trange
import pydot
import graphviz
from keras.utils import plot_model
import os
from keras.preprocessing import image
import csv
from keras.engine import Layer, InputSpec
from keras import initializers, regularizers
from keras.utils.generic_utils import serialize_keras_object
from keras.utils.generic_utils import deserialize_keras_object
from keras.legacy import interfaces
from keras import backend as K
from keras.engine.topology import Layer
import numpy as np

In [None]:
latent_dim = 64
height = 32
width = 32
channels = 3
#base path where you wish to save output files
save_dir = './'

# Create the output directories written to throughout the notebook
# (hyperparameters, model plots, weights, and generated images); the
# open(..., 'a') below fails if 'out/model' does not already exist.
for _out_subdir in ('out/model', 'out/gen'):
    _out_path = os.path.join(save_dir, _out_subdir)
    if not os.path.isdir(_out_path):
        os.makedirs(_out_path)

#save hyper parameter details to a txt file.
with open(os.path.join(save_dir, 'out/model/hyperparameters.txt'), 'a') as f:
    f.write('\n' + 'latent_dim = ' + str(latent_dim))
    f.write('\n' + 'height = ' + str(height))
    f.write('\n' + 'width = ' + str(width))
    f.write('\n' + 'channels = ' + str(channels))

In [None]:
### Layer normalization implementation from the following source.
### https://github.com/DingKe/nn_playground/blob/master/layernorm/layer_norm_layers.py
### This is used in the discriminator architecture.

def to_list(x):
    """Return ``x`` as a new list.

    An exact ``list`` or ``tuple`` is copied into a fresh list; anything else
    (including subclasses of list/tuple) is wrapped as a one-element list.
    """
    if type(x) in (list, tuple):
        return list(x)
    return [x]

def LN(x, gamma, beta, epsilon=1e-6, axis=-1):
    """Normalize ``x`` over ``axis``, then scale by ``gamma`` and shift by ``beta``.

    ``epsilon`` is added both under the square root of the variance and to the
    resulting standard deviation before dividing (as in the upstream source).
    """
    mu = K.mean(x, axis=axis, keepdims=True)
    variance = K.var(x, axis=axis, keepdims=True)
    sigma = K.sqrt(variance + epsilon)
    centered_scaled = (x - mu) / (sigma + epsilon)
    return gamma * centered_scaled + beta


class LayerNormalization(Layer):
    """Layer normalization with a learnable scale (gamma) and shift (beta).

    Each input is normalized over ``axis`` by ``LN`` and then transformed as
    ``gamma * x_normed + beta``.

    Args:
        axis: int or list/tuple of ints. The axes normalized over; gamma and
            beta have full size along these axes and size 1 elsewhere.
        gamma_init, beta_init: initializer identifiers for gamma / beta.
        gamma_regularizer, beta_regularizer: optional weight regularizers.
        epsilon: small constant added for numerical stability.
    """

    def __init__(self, axis=-1,
                 gamma_init='one', beta_init='zero',
                 gamma_regularizer=None, beta_regularizer=None,
                 epsilon=1e-6, **kwargs):
        super(LayerNormalization, self).__init__(**kwargs)

        self.axis = to_list(axis)
        self.gamma_init = initializers.get(gamma_init)
        self.beta_init = initializers.get(beta_init)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.epsilon = epsilon

        self.supports_masking = True

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        # Broadcastable parameter shape: full size on the normalized axes,
        # 1 on every other axis.
        shape = [1 for _ in input_shape]
        for i in self.axis:
            shape[i] = input_shape[i]
        self.gamma = self.add_weight(shape=shape,
                                     initializer=self.gamma_init,
                                     regularizer=self.gamma_regularizer,
                                     name='gamma')
        self.beta = self.add_weight(shape=shape,
                                    initializer=self.beta_init,
                                    regularizer=self.beta_regularizer,
                                    name='beta')
        self.built = True

    def call(self, inputs, mask=None):
        # The mask is accepted (supports_masking) but not used.
        return LN(inputs, gamma=self.gamma, beta=self.beta,
                  axis=self.axis, epsilon=self.epsilon)

    def get_config(self):
        config = {'epsilon': self.epsilon,
                  'axis': self.axis,
                  'gamma_init': initializers.serialize(self.gamma_init),
                  'beta_init': initializers.serialize(self.beta_init),
                  'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
                  # Must serialize beta's own regularizer; previously the gamma
                  # regularizer was stored here, corrupting reloaded layers.
                  'beta_regularizer': regularizers.serialize(self.beta_regularizer)}
        base_config = super(LayerNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

## The Generator


In [None]:
# Generator input: one latent vector of length latent_dim per sample.
latent_shape = (latent_dim,)
generator_input = keras.Input(shape=latent_shape)

In [None]:
### Basic clipping function; wrap it in a Lambda layer to clip the Generator's output.
def clipLayer(x):
    """Clip every element of tensor ``x`` into the closed range [-1, 1]."""
    lower, upper = -1, 1
    return keras.backend.clip(x, lower, upper)

In [None]:
### Setup initializer values for keras RandomNormal which will be used in initialization of weights.
mean = 0.0
stddev = 0.02
# Keep a reference to the initializer. The previous bare call built a
# RandomNormal and immediately discarded it, so it had no effect.
# NOTE: a layer only uses these values if it is given this object (or an
# equivalent RandomNormal(mean=mean, stddev=stddev)) as kernel_initializer;
# the string 'random_normal' uses Keras's default RandomNormal (stddev=0.05).
weight_initializer = keras.initializers.RandomNormal(mean=mean, stddev=stddev, seed=None)
with open(os.path.join(save_dir, 'out/model/hyperparameters.txt'), 'a') as f:
    f.write('\n' + 'initialisers mean = ' + str(mean))
    f.write('\n' + 'initialisers stddev = ' + str(stddev))

In [None]:
# Weight initializer for every generator layer, built from the `mean` and
# `stddev` hyperparameters set and logged in the initializer cell. The string
# 'random_normal' used previously gives Keras's default RandomNormal
# (stddev=0.05), so the logged stddev was never actually applied.
generator_init = keras.initializers.RandomNormal(mean=mean, stddev=stddev)

# First, transform the input into a 4 x 4 x 1024-channel feature map.
x = layers.Dense(1024 * 4 * 4, kernel_initializer=generator_init,
                 bias_initializer='zeros')(generator_input)
x = Activation('relu')(x)
x = layers.BatchNormalization()(x)

x = layers.Reshape((4, 4, 1024))(x)


# Then add a convolution layer.
x = layers.Conv2D(1024, 5, padding='same', kernel_initializer=generator_init,
                  bias_initializer='zeros')(x)
x = Activation('relu')(x)
x = layers.BatchNormalization()(x)


# Upsample to 8 x 8 and reduce the number of filters.
x = layers.UpSampling2D()(x)
x = layers.Conv2D(512, 3, padding='same', kernel_initializer=generator_init,
                  bias_initializer='zeros')(x)
x = Activation('relu')(x)
x = layers.BatchNormalization()(x)

# Upsample to 16 x 16 and reduce the number of filters.
x = layers.UpSampling2D()(x)
x = layers.Conv2D(256, 5, padding='same', kernel_initializer=generator_init,
                  bias_initializer='zeros')(x)
x = Activation('relu')(x)
x = layers.BatchNormalization()(x)

# Upsample to 32 x 32 and reduce the number of filters.
x = layers.UpSampling2D()(x)
x = layers.Conv2D(128, 5, padding='same', kernel_initializer=generator_init,
                  bias_initializer='zeros')(x)
x = Activation('relu')(x)
x = layers.BatchNormalization()(x)

# Produce a 32 x 32 feature map with `channels` channels.
x = layers.Conv2D(channels, 7, padding='same', kernel_initializer=generator_init,
                  bias_initializer='zeros')(x)

### Decide whether to use tanh-alone, BN-with-tanh or BN-with-clipping
### by uncommenting the appropriate lines below.
#x = layers.BatchNormalization()(x)
x = Activation('tanh')(x)
#x=layers.Lambda(clipLayer, tuple(list((32,32,3))))(x)
generator = keras.models.Model(generator_input, x)

# Save a plot of the model to the output directory.
plot_model(generator, show_shapes=True, to_file=os.path.join(save_dir, 'out/model/generator.png'))

# Show a summary of the Generator network by uncommenting the next line.
#generator.summary()


# Initializing the weights of the final BN Layer
If you wish to set the initial weights of the final BatchNorm layer, you can do so by uncommenting the line below.
Keras stores a BatchNormalization layer's weights in the order gamma (scale), beta (offset), moving mean, moving variance.
I have initialized gamma to the per-channel (RGB) standard deviation and beta to the per-channel (RGB) mean of the target distribution.
The final two lists are the moving mean and moving variance, which we set to 0 and 1 respectively.
The Keras BatchNorm implementation is here for reference:
https://github.com/keras-team/keras/blob/master/keras/layers/normalization.py
If you change the architecture, you will have to determine the correct layer number.
Using generator.summary(), you can count your way down to the appropriate BN layer (with the architecture above it is index 21, provided the final BatchNormalization line is uncommented).

In [None]:

#generator.layers[21].set_weights(np.array([[0.45776676431247115, 0.4371233749533978, 0.440839876713472],[-0.0598859, -0.123213, -0.309562], [0,0,0], [1,1,1]]))

### you may wish to freeze the weights of the final BN Layer. If so you can do that with the following line.
#generator.layers[21].trainable = False


In [None]:
### The iWGAN enforces the Lipschitz constraint with a soft penalty on the norm of the gradient.
class GradNorm(Layer):
    """Per-sample L2 norm of the gradient of ``target`` with respect to ``wrt``.

    Called on a two-element list ``[target, wrt]``. The gradient is flattened
    per sample, so the output has shape ``(batch, 1)``.
    """

    def __init__(self, **kwargs):
        super(GradNorm, self).__init__(**kwargs)

    def build(self, input_shapes):
        super(GradNorm, self).build(input_shapes)

    def call(self, inputs):
        target, wrt = inputs
        gradient_list = K.gradients(target, wrt)
        assert len(gradient_list) == 1
        g = gradient_list[0]

        # Sum of squared gradient entries per sample, then square root.
        squared_flat = K.batch_flatten(K.square(g))
        return K.sqrt(K.sum(squared_flat, axis=1, keepdims=True))

    def compute_output_shape(self, input_shapes):
        # Batch dimension is taken from the `wrt` input.
        batch_dim = input_shapes[1][0]
        return (batch_dim, 1)

## The Discriminator

In [None]:
def discriminator_loss(pred, label):
    """Mean of the element-wise product of the two loss arguments.

    Keras calls losses as ``loss(y_true, y_pred)``, so despite the names,
    ``pred`` receives the targets and ``label`` the model output. The product
    is symmetric, so the argument order does not affect the result.
    """
    weighted = pred * label
    return K.mean(weighted)

In [None]:
# Weight initializer for every critic layer, built from the `mean` and
# `stddev` hyperparameters set and logged in the initializer cell. The string
# 'random_normal' used previously gives Keras's default RandomNormal
# (stddev=0.05), so the logged stddev was never actually applied.
discriminator_init = keras.initializers.RandomNormal(mean=mean, stddev=stddev)

discriminator_input = layers.Input(shape=(height, width, channels))
x = layers.Conv2D(128, 3, kernel_initializer=discriminator_init,
                  bias_initializer='zeros')(discriminator_input)
x = layers.LeakyReLU(0.2)(x)
x = LayerNormalization()(x)

x = layers.Conv2D(256, 4, strides=2, kernel_initializer=discriminator_init,
                  bias_initializer='zeros')(x)
x = layers.LeakyReLU(0.2)(x)
x = LayerNormalization()(x)

x = layers.Conv2D(512, 4, strides=2, kernel_initializer=discriminator_init,
                  bias_initializer='zeros')(x)
x = layers.LeakyReLU(0.2)(x)
x = LayerNormalization()(x)

x = layers.Conv2D(1024, 4, strides=2, kernel_initializer=discriminator_init,
                  bias_initializer='zeros')(x)
x = layers.LeakyReLU(0.2)(x)
x = LayerNormalization()(x)

x = layers.Flatten()(x)

# Single unbounded critic score per image (no output activation).
x = layers.Dense(1, kernel_initializer=discriminator_init,
                 bias_initializer='zeros')(x)

discriminator = keras.models.Model(discriminator_input, x)
plot_model(discriminator, show_shapes=True, to_file=os.path.join(save_dir, 'out/model/discriminator.png'))
#discriminator.summary()

In [None]:
discriminator_optimizer = keras.optimizers.Adam(lr=0.0001, beta_1=0.9, beta_2=0.99)
discriminator.compile(optimizer=discriminator_optimizer, loss=discriminator_loss)

### Update the hyperparameters file (closed even if a write fails).
with open(os.path.join(save_dir, 'out/model/hyperparameters.txt'), 'a') as f:
    f.write('\n' + 'discriminator optimizer is Adam(lr=0.0001,beta_1=0.9, beta_2=0.99) ' )

## The adversarial network

In [None]:
def gan_loss(pred, label):
    """Loss for the stacked generator->critic model: mean(y_true * y_pred).

    Keras passes ``(y_true, y_pred)``, so ``pred`` holds the targets here.
    """
    product = pred * label
    return K.mean(product)

In [None]:
def mean_loss(pred, label):
    """Average of ``pred * label`` over all elements.

    With targets of ones, this is simply the mean of the model output.
    """
    elementwise = label * pred
    return K.mean(elementwise)

In [None]:
# Set discriminator weights to non-trainable before compiling `gan`,
# so that only the generator is updated through this model.
discriminator.trainable = False

gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)

gan_optimizer = keras.optimizers.Adam(lr=0.00001, beta_1=0.9, beta_2=0.99)
gan.compile(optimizer=gan_optimizer, loss=gan_loss)

# Record hyperparameters (file closed even if the write fails).
with open(os.path.join(save_dir, 'out/model/hyperparameters.txt'), 'a') as f:
    f.write('\n' + 'GAN optimizer is Adam(lr=0.00001,beta_1=0.9, beta_2=0.99)' )

In [None]:
from keras.layers.merge import _Merge

class Subtract(_Merge):
    """Merge layer computing ``inputs[0] - inputs[1] - ... - inputs[-1]``."""

    def _merge_function(self, inputs):
        difference = inputs[0]
        for subtrahend in inputs[1:]:
            difference = difference - subtrahend
        return difference

In [None]:
from keras.layers.merge import _Merge

class Square(_Merge):
    """Merge layer returning the element-wise square of its first input tensor.

    Keras's ``_Merge.call`` hands ``_merge_function`` the list of input
    tensors, so the previous ``input * input`` multiplied a Python list by a
    list and raised ``TypeError``. ``_Merge`` also requires at least two
    inputs, so the intended call is ``Square()([x, x])``; only the first
    tensor is used.
    """

    def _merge_function(self, inputs):
        # Accept either the list Keras passes or a bare tensor.
        x = inputs[0] if isinstance(inputs, (list, tuple)) else inputs
        return x * x

In [None]:
lmbd = 10  # weight of the gradient-penalty term
shape = discriminator.get_input_shape_at(0)[1:]
gen_input, real_input, interpolation = keras.Input(shape), keras.Input(shape), keras.Input(shape)

# Wasserstein term: D(fake) - D(real), per sample.
sub = Subtract()([discriminator(gen_input), discriminator(real_input)])
# Per-sample L2 norm of dD/dx at the interpolated images, shape (batch, 1).
norm = GradNorm()([discriminator(interpolation), interpolation])

# `val` is the target gradient norm (the training loop feeds ones).
val = keras.Input(norm.get_shape()[1:])
normal = Subtract()([norm, val])
# Per-sample gradient penalty (||grad|| - val)^2.
normal_sq = keras.layers.multiply([normal, normal])

dis2batch = keras.models.Model([real_input, gen_input, interpolation, val], [sub, normal_sq])
discriminator.trainable = True
# Both outputs are already the quantities to minimise, so both use mean_loss
# (mean(target * output) with target 1). Using 'mse' on normal_sq against a
# target of 1 computed ((||grad|| - 1)^2 - 1)^2, which is minimised at
# ||grad|| = 0 or 2, not at the intended ||grad|| = 1.
dis2batch.compile(optimizer=discriminator_optimizer, loss=[mean_loss, mean_loss], loss_weights=[1.0, lmbd])

# Create plot of dis2batch model in output directory.
plot_model(dis2batch, show_shapes=True, to_file=os.path.join(save_dir, 'out/model/dis2batch.png'))
#dis2batch.summary()
with open(os.path.join(save_dir, 'out/model/hyperparameters.txt'), 'a') as f:
    # Record the actual lambda value, not the literal text 'lmbd'.
    f.write('\n' + 'loss weights are [1.0, ' + str(lmbd) + '] ')
    f.write('\n' + 'gradient penalty loss is mean((||grad|| - 1)^2)')


## Training the DCGAN

In [None]:
# Load the CIFAR-10 training split; the test split is discarded.
(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()

# Keep only the frog images (CIFAR-10 class 6).
frog_mask = y_train.flatten() == 6
x_train = x_train[frog_mask]

# Rescale pixel values from [0, 255] to [-1, 1], the range of the generator's tanh output.
x_train = x_train.reshape((x_train.shape[0], height, width, channels))
x_train = x_train.astype('float32') / 127.5 - 1


In [None]:
### This function can be called if you want to see a png of the critic loss.
### It is not called below, but it is useful to output it periodically during training,
### because opening the csv file while training writes to it can lead to file corruption.
### Be wary, though: it is slow when there are a lot of datapoints, so calling it often
### can slow everything down later in training.
def outputCriticPlot(losses=None):
    """Save a line plot of the critic loss to out/model/Critic_Loss.png.

    Args:
        losses: sequence of critic losses to plot. Defaults to the global
            ``d_loss_curve`` filled by the training loop. (The original
            referenced an undefined ``d_loss_short_curve`` and raised NameError.)
    """
    if losses is None:
        losses = d_loss_curve
    # Draw on a fresh figure and close it afterwards, so repeated calls
    # neither overlay old curves nor accumulate open figures.
    fig = plt.figure()
    plt.plot(losses, label='Discriminator Loss')
    plt.title('Critic loss Curves')
    plt.legend()
    plt.savefig(os.path.join(save_dir, 'out/model/Critic_Loss.png'))
    plt.close(fig)

In [None]:
batch_size = 20

d_loss_curve = []  # critic loss, appended once per generator iteration
# Start training loop
start = 0  # index of the next batch of real images in x_train
current_step = 0
gen_count = 0  # keep a count of how many iterations of the generator we have
dis_count = 0  # keep a count of how many iterations of the discriminator we have
iterations = 30001
t = trange(iterations, desc='')
initial_critic_schedule = 50
normal_critic_schedule = 5
# The context manager closes the file even if a write fails.
with open(os.path.join(save_dir, 'out/model/hyperparameters.txt'), 'a') as f:
    f.write('\n' + 'batch_size = ' + str(batch_size))
    f.write('\n' + 'intital_critic_schedule = ' + str(initial_critic_schedule) +"first 25 then every 500")
    f.write('\n' + 'normal_critic_schedule = ' + str(normal_critic_schedule))

# Initial values so these names exist before the first training step.
a_loss = 0
d_loss = 0
generated_images=0
interpolation=0
value=0

In [None]:

for step in t:

    # Train the critic many more times per generator step during the first 25
    # steps and on every 500th step; otherwise use the normal schedule.
    if current_step < 25 or current_step % 500 == 0:
        n_critic = initial_critic_schedule
    else:
        n_critic = normal_critic_schedule

    ##################  Critic/Discriminator Loop  ##############################################

    for dis_step in range(n_critic):
        # NOTE(review): per the Keras docs, changing `trainable` only takes
        # effect at the next compile(); dis2batch and gan are not recompiled
        # here, so this line (and its counterpart below) likely has no effect.
        discriminator.trainable = True
        dis_count = dis_count + 1

        # Sample random points in the latent space and decode them to fake images.
        random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
        generated_images = generator.predict(random_latent_vectors)

        # Take the next contiguous batch of real images.
        stop = start + batch_size
        real_images = x_train[start: stop]

        # Random points between real and fake images, where the gradient-norm
        # penalty is evaluated; one epsilon per sample, broadcast over H, W, C.
        epsilon = np.random.uniform(0, 1, size=(batch_size, 1, 1, 1))
        interpolation = epsilon * real_images + (1 - epsilon) * generated_images
        # Target gradient norm, fed to dis2batch's `val` input.
        value = np.ones((batch_size, 1))

        # Train the critic; both dis2batch outputs use targets of ones.
        d_loss, d_diff, d_norm = dis2batch.train_on_batch([real_images, generated_images, interpolation, value],
                                                          [np.ones((batch_size, 1))] * 2)

        # Advance through x_train, wrapping before a batch would run short.
        start += batch_size
        if start > len(x_train) - batch_size:
            start = 0

    ###################### Generator Single Iteration #########################################################

    gen_count = gen_count + 1
    # See the NOTE(review) on the counterpart line in the critic loop.
    discriminator.trainable = False

    # Sample random points in the latent space.
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))

    # Assemble labels that say "all real images" (the critic loss pushes real
    # images towards high scores; gan_loss with -1 targets is -mean(D(G(z)))).
    misleading_targets = -np.ones((batch_size, 1))

    # Train the generator via the gan model, where the discriminator weights are frozen.
    a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)

    ##################################################################################################

    # Record the critic loss from the last critic iteration of this step.
    d_loss_curve.append(d_loss)

    if current_step % 50 == 0:
        # Save one generated image (from the last critic batch), mapped back from [-1, 1] to [0, 255].
        img = image.array_to_img((generated_images[0] + 1) * 127.5, scale=False)
        img.save(os.path.join(save_dir, 'out/gen/generated_frog' + str(current_step) + '.png'))

        # Rewrite the full critic-loss history as a single CSV row.
        with open(os.path.join(save_dir, 'out/model/critic_loss.csv'), 'w') as myfile:
            wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
            wr.writerow(d_loss_curve)

    if current_step % 5000 == 0:
        gan.save_weights(os.path.join(save_dir, 'out/model/' + str(current_step) + 'gan.h5'))

    current_step = current_step + 1

with open(os.path.join(save_dir, 'out/model/hyperparameters.txt'), 'a') as f:
    f.write('\n' + 'Critic iterations = ' + str(dis_count))
