In [5]:
from numpy import mean
from numpy import ones
from numpy import zeros
from numpy import expand_dims 
from numpy.random import randn
from numpy.random import randint
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.initializers import RandomNormal
import matplotlib.pyplot as plt
from keras import backend
from keras.optimizers import RMSprop
from keras.datasets.mnist import load_data
from keras.constraints import Constraint


class ClipConstraint(Constraint):
  def __init__(self, clip_val):
    self.clip_val = clip_val

  def __call__(self, weights):
    return backend.clip(weights, -self.clip_val, self.clip_val)

  def get_config(self):
    return {'clip_value' : self.clip_val} 

def wasserstein_loss(y_true, y_pred):
  return backend.mean(y_true * y_pred)


#critic model
def define_critic(in_shape = (28, 28, 1)):
  w_init = RandomNormal(stddev = 0.02)
  clip_val = ClipConstraint(0.01) #avoids explosion of grads
  model = Sequential()
  model.add(Conv2D(64, (4, 4), strides = (2, 2), padding = 'same', kernel_initializer=w_init,
                   kernel_constraint = clip_val, input_shape = in_shape))
  model.add(BatchNormalization())
  model.add(LeakyReLU(alpha = 0.2))


  model.add(Conv2D(64, (4, 4), strides = (2, 2), padding = 'same', kernel_initializer=w_init,
                   kernel_constraint = clip_val))
  model.add(BatchNormalization())
  model.add(LeakyReLU(alpha = 0.2))  

  model.add(Flatten())
  model.add(Dense(1))  #can also use the linear actn for score, but same op
  optim = RMSprop(lr = 0.00005)
  model.compile(loss = wasserstein_loss, optimizer = optim)
  return model


def define_generator(latent_dims):
  w_init = RandomNormal(stddev=0.02)
  model = Sequential()
  n_nodes = 128 * 7 * 7
  model.add(Dense(n_nodes, kernel_initializer= w_init, input_dim = latent_dims))
  model.add(LeakyReLU(alpha = 0.2))
  model.add(Reshape((7, 7, 128)))

  model.add(Conv2DTranspose(128, (4, 4), strides = (2, 2), padding = 'same', kernel_initializer= w_init))
  model.add(BatchNormalization())
  model.add(LeakyReLU(alpha = 0.2))

  model.add(Conv2DTranspose(128, (4, 4), strides = (2, 2), padding = 'same', kernel_initializer= w_init))
  model.add(BatchNormalization())
  model.add(LeakyReLU(alpha = 0.2))

  model.add(Conv2D(1, (7, 7), activation = 'tanh', padding = 'same', kernel_initializer= w_init))
  return model 

def define_WGAN(generator, critic):
  critic.trainable = False
  model = Sequential()
  model.add(generator)
  model.add(critic)
  optim = RMSprop(lr = 0.00005)
  model.compile(loss = wasserstein_loss, optimizer = optim)
  return model

def load_real_samples():
  (X_train, y_train), (X_test, y_test) = load_data()
  selected_idx = y_train == 5
  X = X_train[selected_idx]
  X = expand_dims(X, axis = -1)
  X = X.astype('float32')
  X = (X - 127.5)/127.5
  return X

def generate_real_samples(data, n_samples):
  idx = randint(0, data.shape[0], n_samples)
  X = data[idx]
  y = -ones((n_samples, 1))
  return X, y


def latent_points(latent_dims, n_samples):
  x_in = randn(latent_dims * n_samples)
  x_in = x_in.reshape(n_samples, latent_dims)
  return x_in


def generate_fake_samples(generator, latent_dims, n_samples):
  x_in = latent_points(latent_dims, n_samples)
  X = generator.predict(x_in)
  y = ones((n_samples, 1))
  return X, y

def summarize_performance(step, g_model, latent_dims, n_samples=100):
	X, _ = generate_fake_samples(g_model, latent_dims, n_samples)
	X = (X + 1) / 2.0
	for i in range(100):
		plt.subplot(10, 10, i+1)
		plt.axis('off')
		plt.imshow(X[i, :, :, 0], cmap='gray_r')
  
	filename1 = 'generated_plot_%04d.png' % (step+1)
	plt.savefig(filename1)
	plt.close()
 
	filename2 = 'model_%04d.h5' % (step+1)
	g_model.save(filename2)
	print('---> Saved: %s and %s' % (filename1, filename2))

def plot_history(c1, c2, g):
  plt.plot(c1, label = 'critic_real')
  plt.plot(c2, label = 'critic_fake')
  plt.plot(g, label = 'generated')
  plt.legend()
  plt.savefig('line_loss_plot.png')
  plt.close() 

def train(generator, critic, wgan, data, latent_dims, n_epochs = 1000, n_batch = 64, n_critic = 5):
  batches_per_epoch = int(data.shape[0] / n_batch)
  n_steps = batches_per_epoch * n_epochs
  half_batch = int(n_batch / 2)
  c1_hist, c2_hist, g_hist = list(), list(), list()
  for i in range(n_steps):
    c1_temp, c2_temp = list(), list()
    for j in range(n_critic):
      X_real, y_real = generate_real_samples(data, half_batch)
      c1_loss = critic.train_on_batch(X_real, y_real)
      c1_temp.append(c1_loss)
      X_fake, y_fake = generate_fake_samples(generator, latent_dims, half_batch)
      c2_loss = critic.train_on_batch(X_fake, y_fake)
      c2_temp.append(c2_loss)  
    c1_hist.append(mean(c1_temp))
    c2_hist.append(mean(c2_temp))
    X_g = latent_points(latent_dims, n_batch)
    y_g = -ones((n_batch, 1))  #real samples
    g_loss = wgan.train_on_batch(X_g, y_g)
    g_hist.append(g_loss) 
    print('-->%d, c1=%.3f, c2=%.3f g=%.3f' % (i+1, c1_hist[-1], c2_hist[-1], g_loss))    #loss for that batch alone
    if (i+1) % batches_per_epoch == 0:
      summarize_performance(i, generator, latent_dims)

  plot_history(c1_hist, c2_hist, g_hist)   

latent_dims = 50
critic = define_critic()
generator = define_generator(latent_dims)
wgan = define_WGAN(generator, critic)
data = load_real_samples()
print(data.shape)
train(generator, critic, wgan, data, latent_dims)




[1;30;43mStreaming output truncated to the last 5000 lines.[0m
-->44881, c1=-1065283.575, c2=-943353.762 g=-184854.031
-->44882, c1=-1064801.100, c2=-943654.562 g=-177905.938
-->44883, c1=-1065302.250, c2=-943639.700 g=-166213.797
-->44884, c1=-1065040.700, c2=-943573.125 g=-173755.031
-->44885, c1=-1064806.200, c2=-943602.825 g=-216491.156
-->44886, c1=-1065118.450, c2=-943807.812 g=-195012.781
-->44887, c1=-1065523.300, c2=-943916.512 g=-167025.312
-->44888, c1=-1065618.325, c2=-943944.500 g=-164391.203
-->44889, c1=-1064890.650, c2=-943564.350 g=-215219.344
-->44890, c1=-1065396.600, c2=-943840.012 g=-172861.219
-->44891, c1=-1065432.100, c2=-944035.275 g=-183413.875
-->44892, c1=-1065193.950, c2=-944036.387 g=-186622.453
-->44893, c1=-1065820.100, c2=-944027.088 g=-224846.344
-->44894, c1=-1065232.275, c2=-944051.637 g=-176997.125
-->44895, c1=-1065790.100, c2=-944169.713 g=-196398.656
-->44896, c1=-1066099.850, c2=-944488.662 g=-196995.781
-->44897, c1=-1065809.125, c2=-943590.7

KeyboardInterrupt: ignored

In [None]:
f