In [0]:
from __future__ import print_function, division

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, concatenate
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D, Lambda
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categorical
import keras.backend as K

from keras.datasets import mnist

In [0]:
# Input image geometry: 28x28 pixels with a single channel, matching the
# MNIST arrays loaded in train() after a channel axis is appended.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)

# Number of label classes; also the length of the one-hot code that is
# concatenated onto the noise vector.
num_classes = 10

# Width of the generator input: 62 noise values (see
# sample_generator_input) followed by a one-hot label of length 10.
latent_dim = 72

In [0]:
def build_disk_and_q_net(img_shape=img_shape, num_classes=num_classes):
  """Build the discriminator and the recognition (Q) network.

  Both models read the same image input and run it through one shared
  convolutional stack; they differ only in the dense head applied to the
  flattened features.

  Args:
    img_shape: Shape of a single input image, used for the `Input` tensor
      and for the first convolution's `input_shape`.
    num_classes: Number of units in the recognition head's softmax output.

  Returns:
    A `(discriminator, recognition)` pair of `Model`s. The first maps an
    image to one sigmoid validity score, the second maps an image to a
    softmax over `num_classes` labels.
  """
  image_in = Input(shape=img_shape)

  # Feature extractor shared by the discriminator and the recognition net.
  # Four stride-2 convolutions, each followed by LeakyReLU and dropout;
  # batch normalisation is applied from the second convolution onwards.
  shared_trunk = Sequential([
      Conv2D(64, kernel_size=3, strides=2, input_shape=img_shape, padding="same"),
      LeakyReLU(alpha=0.2),
      Dropout(0.25),

      Conv2D(128, kernel_size=3, strides=2, padding="same"),
      ZeroPadding2D(padding=((0,1),(0,1))),
      LeakyReLU(alpha=0.2),
      Dropout(0.25),
      BatchNormalization(momentum=0.8),

      Conv2D(256, kernel_size=3, strides=2, padding="same"),
      LeakyReLU(alpha=0.2),
      Dropout(0.25),
      BatchNormalization(momentum=0.8),

      Conv2D(512, kernel_size=3, strides=2, padding="same"),
      LeakyReLU(alpha=0.2),
      Dropout(0.25),
      BatchNormalization(momentum=0.8),

      Flatten(),
  ])

  features = shared_trunk(image_in)

  # Discriminator head: a single real/fake probability.
  validity = Dense(1, activation='sigmoid')(features)

  # Recognition head: one hidden layer, then class probabilities.
  q_hidden = Dense(128, activation='relu')(features)
  label = Dense(num_classes, activation='softmax')(q_hidden)

  discriminator_model = Model(image_in, validity)
  recognition_model = Model(image_in, label)
  return discriminator_model, recognition_model


In [0]:
def build_generator(latent_dim=latent_dim, channels=channels):
  """Build the generator that maps a latent vector to an image.

  A dense layer produces a 7x7x128 feature map, which is upsampled twice
  (with a 3x3 convolution after each upsampling) and finally projected to
  `channels` output channels through a tanh activation.

  Prints the summary of the internal layer stack as a side effect.

  Args:
    latent_dim: Length of the input vector.
    channels: Number of channels in the generated image.

  Returns:
    A `Model` taking a `(latent_dim,)` input and returning the generated
    image tensor.
  """
  layer_stack = Sequential([
      # Project the latent vector and reshape it into a small feature map.
      Dense(128 * 7 * 7, activation="relu", input_dim=latent_dim),
      Reshape((7, 7, 128)),
      BatchNormalization(momentum=0.8),

      # First upsampling stage.
      UpSampling2D(),
      Conv2D(128, kernel_size=3, padding="same"),
      Activation("relu"),
      BatchNormalization(momentum=0.8),

      # Second upsampling stage.
      UpSampling2D(),
      Conv2D(64, kernel_size=3, padding="same"),
      Activation("relu"),
      BatchNormalization(momentum=0.8),

      # Output projection; tanh keeps pixel values in [-1, 1].
      Conv2D(channels, kernel_size=3, padding='same'),
      Activation("tanh"),
  ])

  latent_in = Input(shape=(latent_dim,))
  generated = layer_stack(latent_in)

  layer_stack.summary()

  return Model(latent_in, generated)

In [0]:
def mutual_info_loss(c, c_given_x):
  """Loss for the recognition output; lower is better.

  Sums two batch-averaged terms:
    * the cross-entropy of the prediction `c_given_x` against the code `c`
    * the entropy of the code `c` itself

  A small epsilon is added inside each logarithm to avoid log(0).

  Args:
    c: Target categorical code. Summed over axis 1, so presumably shaped
      (batch, num_classes) -- confirm against the caller.
    c_given_x: Predicted distribution over the code, same shape as `c`.

  Returns:
    A scalar backend tensor holding the sum of the two terms.
  """
  eps = 1e-8

  def _mean_neg_weighted_log(probs):
    # -sum_k c_k * log(probs_k + eps), averaged over the batch.
    return K.mean(- K.sum(K.log(probs + eps) * c, axis=1))

  conditional_entropy = _mean_neg_weighted_log(c_given_x)
  entropy = _mean_neg_weighted_log(c)

  return conditional_entropy + entropy

In [0]:
def sample_generator_input(batch_size, num_classes=num_classes, noise_dim=62):
  """Sample the two parts of a generator input batch.

  Args:
    batch_size: Number of samples to draw.
    num_classes: Number of categories; labels are drawn uniformly from
      `0 .. num_classes - 1` and one-hot encoded to this width.
    noise_dim: Length of each noise vector. Defaults to 62, the value that
      was previously hard-coded. `noise_dim + num_classes` must equal the
      generator's input width once the caller concatenates both parts.

  Returns:
    A `(sampled_noise, sampled_labels)` tuple:
      * `sampled_noise`: array of shape `(batch_size, noise_dim)` drawn
        from a standard normal distribution.
      * `sampled_labels`: one-hot array of shape `(batch_size, num_classes)`.
  """
  # Continuous noise part of the latent vector.
  sampled_noise = np.random.normal(0, 1, (batch_size, noise_dim))

  # Categorical part: random integer class ids, then one-hot encoded.
  sampled_labels = np.random.randint(0, num_classes, batch_size).reshape(-1, 1)
  sampled_labels = to_categorical(sampled_labels, num_classes=num_classes)

  return sampled_noise, sampled_labels

In [0]:
def train(epochs, batch_size=128, sample_interval=50):
  """Run the adversarial training loop on MNIST.

  Uses the module-level `generator`, `discriminator` and `combined` models,
  which must be built and compiled before this function is called.

  Args:
    epochs: Number of training iterations. Each iteration trains on one
      randomly drawn batch, not on a full pass over the dataset.
    batch_size: Number of real images and of generated images used per
      iteration.
    sample_interval: A grid of generated samples is saved via
      `sample_images` whenever the iteration index is a multiple of this.
  """
  # Only the images are needed: the categorical codes fed to the generator
  # are sampled at random, so the dataset labels are discarded.
  (X_train, _), (_, _) = mnist.load_data()

  # Rescale pixel values from 0..255 to [-1, 1] (the generator ends in
  # tanh) and append a trailing channel axis.
  X_train = (X_train.astype(np.float32) - 127.5) / 127.5
  X_train = np.expand_dims(X_train, axis=3)

  # Adversarial ground truths
  valid = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))

  progress = tqdm(range(epochs))
  for epoch in progress:

    # ---------------------
    #  Train Discriminator
    # ---------------------

    # Select a random batch of real images (sampled with replacement)
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs = X_train[idx]

    # Sample noise and categorical labels, joined into one latent vector
    sampled_noise, sampled_labels = sample_generator_input(batch_size)
    gen_input = np.concatenate((sampled_noise, sampled_labels), axis=1)

    # Generate a batch of new images
    gen_imgs = generator.predict(gen_input)

    # Train on real and generated data
    d_loss_real = discriminator.train_on_batch(imgs, valid)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)

    # Avg. loss
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ---------------------
    #  Train Generator and Q-network
    # ---------------------

    # The generator is rewarded when the discriminator labels its output
    # as valid and the recognition output matches the sampled labels.
    g_loss = combined.train_on_batch(gen_input, [valid, sampled_labels])

    # Show the current losses on the progress bar. np.ravel handles both a
    # scalar and a list return value; the first entry is taken as the loss.
    progress.set_postfix(d_loss=float(np.ravel(d_loss)[0]),
                         g_loss=float(np.ravel(g_loss)[0]))

    # If at save interval => save generated image samples
    if epoch % sample_interval == 0:
      sample_images(epoch)


In [0]:
def sample_images(epoch, num_classes=num_classes):
  """Save a grid of generated images to "<epoch>.png".

  Column `i` of the grid is generated with the categorical code fixed to
  class `i`; each row uses a different random noise vector. Uses the
  module-level `generator`.

  Args:
    epoch: Iteration index, used only to name the output file.
    num_classes: Width of the one-hot label appended to the noise.
  """
  r, c = 10, 10

  fig, axs = plt.subplots(r, c)
  for i in range(c):
    # Draw exactly `r` noise vectors so they line up with the `r` label
    # rows built below; sampling `c` only worked while r == c.
    sampled_noise, _ = sample_generator_input(r)
    label = to_categorical(np.full(fill_value=i, shape=(r,1)), num_classes=num_classes)
    gen_input = np.concatenate((sampled_noise, label), axis=1)
    gen_imgs = generator.predict(gen_input)

    # Map generator output from [-1, 1] back to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5
    for j in range(r):
      axs[j,i].imshow(gen_imgs[j,:,:,0], cmap='gray')
      axs[j,i].axis('off')
  fig.savefig("%d.png" % epoch)
  # Close this specific figure so repeated calls do not accumulate figures.
  plt.close(fig)

In [84]:
# Model wiring. Statement order matters in this cell: the discriminator is
# compiled BEFORE its `trainable` flag is cleared for the combined model.

# One Adam optimizer instance (positional args 0.0002 and 0.5) is passed to
# all three compile() calls below.
optimizer = Adam(0.0002, 0.5)
# Losses for the combined model's two outputs, in output order:
# validity first, recognised label second.
losses = ['binary_crossentropy',mutual_info_loss]

# Build the discriminator and the recognition network (they share layers)
discriminator, auxilliary = build_disk_and_q_net()

# Compile the discriminator for stand-alone training on real/fake batches
discriminator.compile(loss=['binary_crossentropy'], optimizer=optimizer, metrics=['accuracy'])

# Compile the recognition network Q
auxilliary.compile(loss=[mutual_info_loss], optimizer=optimizer, metrics=['accuracy'])

# Build the generator
generator = build_generator()

# The generator takes a latent vector of length latent_dim (noise
# concatenated with the one-hot target label, see train()) and generates
# an image from it
gen_input = Input(shape=(latent_dim,))
img = generator(gen_input)

# Freeze the discriminator inside the combined model.
# NOTE(review): the flag is set on `discriminator` only; presumably the
# Dense layers that belong solely to `auxilliary` remain trainable in
# `combined` -- confirm.
discriminator.trainable = False

# The discriminator takes generated image as input and determines validity
valid = discriminator(img)
# The recognition network produces the label
target_label = auxilliary(img)

# The combined model (generator stacked with the discriminator and the
# recognition network); it maps a latent vector to [validity, label]
combined = Model(gen_input, [valid, target_label])
combined.compile(loss=losses, optimizer=optimizer)


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_32 (Dense)             (None, 6272)              457856    
_________________________________________________________________
reshape_8 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
batch_normalization_46 (Batc (None, 7, 7, 128)         512       
_________________________________________________________________
up_sampling2d_15 (UpSampling (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_54 (Conv2D)           (None, 14, 14, 128)       147584    
_________________________________________________________________
activation_22 (Activation)   (None, 14, 14, 128)       0         
_________________________________________________________________
batch_normalization_47 (Batc (None, 14, 14, 128)       512       
__________

In [0]:
# Run 4000 training iterations with batches of 128 images; a sample grid is
# written to "<iteration>.png" every 50 iterations.
train(epochs=4000, batch_size=128, sample_interval=50)

In [0]:
def save_model():
  """Write the module-level generator and discriminator to disk.

  For each model, the architecture is saved as JSON to "<name>.json" and
  the weights to "<name>_weights.hdf5", using paths relative to the
  current working directory.
  """

  def save(model, model_name):
    """Save one model's architecture and weights under `model_name`."""
    model_path = "%s.json" % model_name
    weights_path = "%s_weights.hdf5" % model_name

    # Use a context manager so the file handle is closed (and flushed)
    # even if serialisation or writing raises.
    with open(model_path, 'w') as arch_file:
      arch_file.write(model.to_json())
    model.save_weights(weights_path)

  save(generator, "generator")
  save(discriminator, "discriminator")

In [0]:
# Persist the generator and discriminator (architecture JSON + weights).
save_model()