# GAN with KL Divergence & mode coverage metrics
# Fashion-MNIST (read with the TensorFlow MNIST loader)

In [20]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import keras.models
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os, sys
import itertools
import functools
import time
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops

%matplotlib inline

## Load Data

In [2]:
mb_size = 128    # minibatch size requested from mnist.train for every GAN training step
Z_dim = 100      # length of each noise vector produced by sample_Z for the generator
n_classes = 10   # number of classes in the uniform reference distribution of kl_divergence

In [3]:
# Load the dataset with one-hot labels; files are cached in the local 'mnist' directory.
# NOTE(review): the cache directory is called 'mnist', but source_url points at a
# Fashion-MNIST mirror and later cells write to run_fmnist_* directories — presumably
# the data is Fashion-MNIST; confirm which files actually sit in ./mnist.
# NOTE(review): source_url is an absolute path from the original environment and will
# need adjusting elsewhere.
mnist = input_data.read_data_sets('mnist', source_url='file:///workspace/fashion_mnist_src/', one_hot=True)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting mnist/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting mnist/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting mnist/t10k-images-idx3-ubyte.gz
Extracting mnist/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


## Metrics

### Pre-trained classifier

In [4]:
# Train on the training split exactly once per image.
# read_data_sets holds back part of the 60,000 training images for validation
# (validation_size=5000 by default), so asking next_batch for 60,000 rows wrapped
# into a second epoch and handed ~5,000 images to the classifier twice.
x_train, y_train = mnist.train.images, mnist.train.labels

# Small CNN used only as a fixed "judge" for the generated samples:
# one strided conv layer, one hidden dense layer, softmax over the 10 classes.
classifier = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(16, (4,4), strides=2, padding='same', activation=tf.nn.relu),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
classifier.compile(optimizer = 'adam', 
             loss='categorical_crossentropy',
             metrics=['accuracy'])

# -1 lets the row count follow the size of the training split instead of hardcoding it.
classifier.fit(x_train.reshape(-1, 28, 28, 1), y_train, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fcbc6a9cf28>

In [5]:
# Score the classifier on 5000 validation images; the value of the final call
# (loss followed by the 'accuracy' metric) is what the cell displays.
x_val, y_val = mnist.validation.next_batch(5000)
classifier.evaluate(x=x_val.reshape(5000, 28, 28, 1), y=y_val)



[0.29273532178625467, 0.9126]

### Convergence Metrics

In [6]:
distr = tf.contrib.distributions

def classifier_predict(img_batch):
    """Run the pre-trained classifier on a batch of images.

    The batch is reshaped to (N, 28, 28, 1), with N taken from img_batch.shape[0],
    and the result of classifier.predict on that array is returned.
    """
    n_images = img_batch.shape[0]
    as_images = img_batch.reshape(n_images, 28, 28, 1)
    return classifier.predict(as_images)

def prob_distr(img_batch):
    """Aggregate class distribution of a batch.

    Sums the classifier's per-image predictions over the batch axis and rescales
    the per-class totals so that they add up to 1.
    """
    class_mass = np.sum(classifier_predict(img_batch), axis=0)
    return class_mass / np.sum(class_mass)

def kl_divergence(img_batch):
    """Compare the class distribution of img_batch with a uniform distribution.

    Returns:
        A tuple (kl, rev_kl, gen_pd, modes_covered):
        kl            -- sum(gen * log(gen / data))
        rev_kl        -- sum(data * log(data / gen))
        gen_pd        -- un-smoothed class distribution returned by prob_distr
        modes_covered -- how many entries of gen_pd are strictly positive
        where gen is gen_pd + 1e-5 and data is a uniform vector over n_classes
        with the same 1e-5 added.
    """
    smoothing = 1e-5

    # class distribution of the samples
    gen_pd = prob_distr(img_batch)

    # NOTE(review): gen_pd is an average of softmax outputs, which are rarely exactly
    # zero, so this count will presumably equal n_classes almost always. Counting
    # distinct argmax classes may be what was intended — confirm before relying on it.
    modes_covered = int(np.count_nonzero(gen_pd > 0))

    # smoothed distributions used inside the logarithms
    uniform_pd = np.full((n_classes,), (1/n_classes)) + smoothing
    gen_smoothed = gen_pd + smoothing

    forward = np.sum(gen_smoothed * np.log(gen_smoothed / uniform_pd))
    reverse = np.sum(uniform_pd * np.log(uniform_pd / gen_smoothed))

    return forward, reverse, gen_pd, modes_covered

In [7]:
'''
From https://github.com/tsc2017/Inception-Score
Code derived from https://github.com/openai/improved-gan/blob/master/inception_score/mode.py and https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
Usage:
    Call get_inception_score(images, splits=10)
Args:
    images: A numpy array that classifier_predict can reshape to [N, 28, 28, 1], e.g. generator output of shape [N, 784] with values in (0, 1).
            (The upstream Inception-based code expected [N, 3, HEIGHT, WIDTH] images with values from 0 to 255, ideally np.uint8; that does not apply to this adaptation.)
    splits: The number of splits of the images, default is 10.
Returns:
    Mean and standard deviation of the Inception Score across the splits.
'''

def preds2score(preds, splits=10):
    """Turn class-probability predictions into an Inception-style score.

    The rows of preds are cut into `splits` consecutive chunks. For each chunk the
    score is exp(mean_i KL(p_i || p_mean)), where p_mean is the chunk's mean row.

    Args:
        preds: 2-D array, one row of class probabilities per sample.
        splits: number of chunks to score separately.

    Returns:
        (mean, std) of the per-chunk scores.
    """
    scores = []
    n_preds = preds.shape[0]
    for i in range(splits):
        part = preds[(i * n_preds // splits):((i + 1) * n_preds // splits), :]
        if part.shape[0] == 0:
            # More splits than rows: an empty chunk has no score and would only
            # inject NaN into the mean/std.
            continue
        marginal = np.mean(part, 0, keepdims=True)
        # Use the convention 0 * log(0) = 0. A probability that is exactly zero
        # (e.g. an underflowed softmax output) otherwise gives 0 * -inf = NaN and
        # poisons the whole score. Wherever part > 0 the marginal is > 0 as well,
        # so the substituted 1.0 values only ever appear in terms multiplied by 0.
        safe_part = np.where(part > 0, part, 1.0)
        safe_marginal = np.where(marginal > 0, marginal, 1.0)
        kl = part * (np.log(safe_part) - np.log(safe_marginal))
        kl = np.mean(np.sum(kl, 1))
        scores.append(np.exp(kl))
    return np.mean(scores), np.std(scores)

def get_inception_score(images, splits=10):
    """Score images with the notebook's own classifier, Inception-Score style.

    Prints the image/split counts and the elapsed time, and returns the
    (mean, std) pair produced by preds2score on the classifier's predictions.
    """
    n_images = images.shape[0]
    print('Calculating "Inception" Score with %i images in %i splits' % (n_images, splits))
    started = time.time()
    score_mean, score_std = preds2score(classifier_predict(images), splits)
    print('Inception Score calculation time: %f s' % (time.time() - started))
    # The upstream code quoted reference values of 11.34 (49984 CIFAR-10 training
    # images) and mean=11.31, std=0.08 in 10 splits; presumably those do not carry
    # over to the classifier used here.
    return score_mean, score_std

In [18]:
regr_buffer = 10
def regr_slope(z):
    """Least-squares slope of the values in z against their positions 0..len(z)-1.

    This is the same quantity as r * (std_y / std_x) with r the Pearson
    correlation, but computed as cov(x, y) / var(x) so that a perfectly flat
    window yields a slope of 0 instead of NaN (the correlation is undefined when
    y has zero variance).

    Args:
        z: 1-D sequence of numbers.

    Returns:
        The slope as a float.

    Raises:
        ValueError: if z holds fewer than two values.
    """
    y = np.asarray(z, dtype=float)
    if y.size < 2:
        raise ValueError("regr_slope needs at least 2 values, got %d" % y.size)
    x = np.arange(y.size, dtype=float)
    x_centered = x - x.mean()
    y_centered = y - y.mean()
    return np.dot(x_centered, y_centered) / np.dot(x_centered, x_centered)

## Model 

### Weights and Placeholders

In [21]:
def xavier_init(size):
    """Random-normal initial values of shape `size`.

    The standard deviation is 1 / sqrt(size[0] / 2), i.e. it depends only on the
    first entry of `size`.
    """
    fan_in = size[0]
    stddev = 1. / tf.sqrt(fan_in / 2.)
    return tf.random_normal(shape=size, stddev=stddev)

In [22]:
# Discriminator input: flattened images with 784 values per row.
X = tf.placeholder(tf.float32, shape=[None, 784])

# Discriminator parameters: 784 -> 128 -> 1
D_W1 = tf.Variable(xavier_init([784, 128]))
D_b1 = tf.Variable(tf.zeros(shape=[128]))

D_W2 = tf.Variable(xavier_init([128, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

theta_D = [D_W1, D_W2, D_b1, D_b2]

# Generator input: noise vectors. The width comes from Z_dim (rather than a second
# hardcoded 100) because every sample_Z(..., Z_dim) call feeds this placeholder.
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])

# Generator parameters: Z_dim -> 128 -> 784
G_W1 = tf.Variable(xavier_init([Z_dim, 128]))
G_b1 = tf.Variable(tf.zeros(shape=[128]))

G_W2 = tf.Variable(xavier_init([128, 784]))
G_b2 = tf.Variable(tf.zeros(shape=[784]))

theta_G = [G_W1, G_W2, G_b1, G_b2]

# Learning rates handed to the two Adam optimizers. The training cells override
# these values per run by passing lr_D / lr_G in feed_dict.
lr_D = tf.Variable(0.001)
lr_G = tf.Variable(0.001)

### Architecture

In [23]:
def sample_Z(m, n):
    """Draw an m-by-n array of noise from np.random.uniform between -1 and 1."""
    return np.random.uniform(low=-1., high=1., size=(m, n))

def generator(z):
    """Map noise z to generated samples.

    One ReLU hidden layer (G_W1, G_b1) followed by a linear layer (G_W2, G_b2)
    whose output is squashed with a sigmoid.
    """
    hidden = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    logits = tf.matmul(hidden, G_W2) + G_b2
    return tf.nn.sigmoid(logits)

def discriminator(x):
    """Score inputs x with the discriminator.

    One ReLU hidden layer (D_W1, D_b1) followed by a linear layer (D_W2, D_b2).

    Returns:
        (probability, logit): the sigmoid of the final linear output, and that
        linear output itself.
    """
    hidden = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    logit = tf.matmul(hidden, D_W2) + D_b2
    return tf.nn.sigmoid(logit), logit

### Loss functions

In [24]:
# Wire the graph: generated samples from noise, and discriminator outputs for both
# the real batch (X) and the generated batch. Both discriminator calls use the same
# D_* variables.
G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)

# Losses written directly on the probabilities (kept for reference, not used):
# D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
# G_loss = -tf.reduce_mean(tf.log(D_fake))

# Alternative losses:
# -------------------
# Losses actually used, computed from the logits via sigmoid cross-entropy:
# the discriminator is pushed towards label 1 on real data and label 0 on generated
# data; the generator is pushed towards label 1 on its own samples.
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

# One Adam optimizer per network; var_list restricts each update to that network's
# own parameters, and the learning rates come from lr_D / lr_G.
D_solver = tf.train.AdamOptimizer(learning_rate=lr_D).minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer(learning_rate=lr_G).minimize(G_loss, var_list=theta_G)

# Disabled in-graph KL computation (prob_distr goes through classifier.predict,
# which takes arrays rather than tensors):
#KL = distr.kl_divergence(distr.Categorical(probs=prob_distr(G_sample)), distr.Categorical(probs= np.full((n_classes,),(1/n_classes))), allow_nan_stats=True)

## Training Session

In [25]:
def plot(samples):
    """Draw samples as 28x28 greyscale tiles on a 10x10 grid.

    Each sample is reshaped to (28, 28) and placed in the grid cell matching its
    position in `samples`. Callers in this notebook pass 100 samples.

    Returns:
        The matplotlib figure holding the grid.
    """
    fig = plt.figure(figsize=(10, 10))
    grid = gridspec.GridSpec(10, 10)
    grid.update(wspace=0.1, hspace=0.1)

    for cell_idx, flat_image in enumerate(samples):
        tile = plt.subplot(grid[cell_idx])
        plt.axis('off')
        # strip tick labels and keep the tiles square
        tile.set_xticklabels([])
        tile.set_yticklabels([])
        tile.set_aspect('equal')
        image = flat_image.reshape(28, 28)
        plt.imshow(image, cmap='Greys_r')

    return fig

In [26]:
# Learning rates to sweep. lr_pairs holds every (first, second) combination in
# row-major order; the training cells unpack each pair as (lr_d, lr_g).
lr = [0.001, 0.0005, 0.0001]
lr_pairs = [(first_rate, second_rate) for first_rate in lr for second_rate in lr]

### Early Termination

In [29]:
# Sweep over every (discriminator lr, generator lr) pair, training a fresh GAN per
# pair and stopping a run early once the reverse-KL metric trends upwards.
run = 0
best_run = -1
best_is = -1
et_start = time.time()
print("Start time: ", end='')
print(et_start)

# Build these once. Constructing tf.train.Saver() or calling
# tf.global_variables_initializer() inside the loop adds a new set of ops to the
# default graph for every learning-rate pair.
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()

# training loop settings (the same for every run)
training_steps = 500000
sample_every = 1000
# a run stops once the regression slope over the last regr_buffer reverse-KL
# values reaches this threshold
rev_kl_slope_limit = 0.1

for (lr_d, lr_g) in lr_pairs:
    print('\n------------ ({}, {}) -------------'.format(lr_d, lr_g))

    run += 1
    run_dir = '/mnt/p/run_fmnist_%d' % run
    if not os.path.exists(run_dir):
        os.makedirs(run_dir)

    g_losses, d_losses, rev_kls, pds, mcs, rev_kl_buffer = [], [], [], [], [], []

    # New session per run, closed in the finally block so nine sessions do not pile up.
    # Deliberately not `with tf.Session() as sess:` — that would also install sess as
    # the default session, and classifier.predict (tf.keras) may then run inside this
    # freshly re-initialised session rather than the one holding the trained
    # classifier weights. NOTE(review): confirm against the installed TF version.
    sess = tf.Session()
    try:
        # initialize
        sess.run(init_op)

        i = 0
        for it in range(training_steps):

            # every nth iteration (and on the last one), save a grid of generated images
            if (it % sample_every == 0) or (it == training_steps-1):
                samples = sess.run(G_sample,
                                   feed_dict={Z: sample_Z(100, Z_dim),
                                              lr_D: lr_d,
                                              lr_G: lr_g})
                fig = plot(samples)
                plt.savefig('{}/{}.png'.format(run_dir, str(i).zfill(3)), bbox_inches='tight')
                i += 1
                plt.close(fig)

            # training step: one discriminator update, then one generator update
            X_mb, _ = mnist.train.next_batch(mb_size)
            _, D_loss_curr = sess.run([D_solver, D_loss],
                                      feed_dict={X: X_mb,
                                                 Z: sample_Z(mb_size, Z_dim),
                                                 lr_D: lr_d,
                                                 lr_G: lr_g})
            _, G_loss_curr = sess.run([G_solver, G_loss],
                                      feed_dict={Z: sample_Z(mb_size, Z_dim),
                                                 lr_D: lr_d,
                                                 lr_G: lr_g})

            # every nth iteration, save metric info
            if it % sample_every == 0:

                # get metrics from a fresh batch of 200 generated samples
                kl_samples = sess.run(G_sample,
                                      feed_dict={Z: sample_Z(200, Z_dim),
                                                 lr_D: lr_d,
                                                 lr_G: lr_g})
                kl, rev_kl, gen_pd, mc = kl_divergence(kl_samples)

                # progress marker
                print('...{}'.format(it), end='')

                # Save. The generator loss goes to g_losses and the discriminator
                # loss to d_losses so the columns match the "GLoss,DLoss" CSV header
                # (they were previously appended to the opposite lists).
                g_losses.append(G_loss_curr)
                d_losses.append(D_loss_curr)
                rev_kls.append(rev_kl)
                rev_kl_buffer.append(rev_kl)
                rev_kl_buffer = rev_kl_buffer[-regr_buffer:]
                pds.append(gen_pd)
                mcs.append(mc)

                # If we've collected enough samples, test the termination condition
                if len(rev_kl_buffer) == regr_buffer:
                    rev_kl_slope = regr_slope(rev_kl_buffer)
                    if rev_kl_slope >= rev_kl_slope_limit:
                        print(". EARLY TERMINATION AT %d"%it)
                        break

        # save all metrics for the run, one column per metric
        run_summary = np.column_stack([g_losses, d_losses, rev_kls, mcs])
        np.savetxt('{}/metrics_D{}_G{}.csv'.format(run_dir, lr_d, lr_g), run_summary, fmt='%1.4e', delimiter=',', header="GLoss,DLoss,RevKLDivergence,ModesCovered")
        np.savetxt('{}/prob_distr.csv'.format(run_dir), pds, fmt='%1.4e', delimiter=',')

        ## "Inception" score (not technically, since it uses a different classifier)
        generated_images = sess.run(G_sample,
                                    feed_dict={Z: sample_Z(1000, Z_dim),
                                               lr_D: lr_d,
                                               lr_G: lr_g})
        generated_images = generated_images.reshape(1000,28,28,1)
        score_mean, score_std = get_inception_score(generated_images)
        np.savetxt('{}/score.csv'.format(run_dir), [score_mean, score_std], fmt='%1.4e', delimiter=',')

        if score_mean > best_is:
            best_is = score_mean
            best_run = run

        ## Save model and weights
        save_path = saver.save(sess, '{}/model.ckpt'.format(run_dir))
    finally:
        sess.close()

et_end = time.time()
print("TOTAL ELAPSED TIME: ", end='')
print(et_end - et_start)

Start time: 1543872773.654608

------------ (0.001, 0.001) -------------
...0...1000...2000...3000...4000...5000...6000...7000...8000...9000...10000...11000...12000...13000...14000...15000...16000...17000...18000...19000...20000...21000...22000...23000...24000...25000...26000...27000...28000...29000...30000...31000...32000...33000...34000...35000...36000...37000...38000...39000...40000...41000...42000...43000...44000...45000...46000...47000...48000...49000...50000...51000...52000...53000...54000...55000...56000...57000...58000...59000...60000...61000...62000...63000...64000...65000...66000...67000...68000...69000...70000...71000...72000...73000...74000...75000...76000...77000...78000...79000...80000...81000...82000...83000...84000...85000...86000...87000...88000...89000...90000...91000...92000...93000...94000...95000...96000...97000...98000...99000...100000...101000...102000...103000...104000...105000...106000...107000...108000...109000...110000...111000...112000...113000...114000...11



Inception Score calculation time: 0.024864 s

------------ (0.001, 0.0005) -------------
...0...1000...2000...3000...4000...5000...6000...7000...8000...9000...10000...11000...12000...13000...14000...15000...16000...17000...18000...19000...20000...21000...22000...23000...24000...25000...26000...27000...28000...29000...30000...31000...32000...33000...34000...35000...36000...37000...38000...39000...40000...41000...42000...43000...44000...45000...46000...47000...48000...49000...50000...51000...52000...53000...54000...55000...56000...57000...58000...59000...60000...61000...62000...63000...64000...65000...66000...67000...68000...69000...70000...71000...72000...73000...74000...75000...76000...77000...78000...79000...80000...81000...82000...83000...84000...85000...86000...87000...88000...89000...90000...91000...92000...93000...94000...95000...96000...97000...98000...99000...100000...101000...102000...103000...104000...105000...106000...107000...108000...109000...110000...111000...112000...1130

In [31]:
# Run index with the highest classifier-based "Inception" score in the sweep above,
# followed by that score.
print(best_run, best_is)

7 6.6486692


## No Early Termination

In [None]:
# Baseline sweep: same learning-rate grid as the early-termination cell, but every
# run trains for the full number of steps and no KL metrics are computed on the way.
run = 0
best_run = -1
best_is = -1
full_start = time.time()
print("Start time: ", end='')
print(full_start)

# Build these once. Constructing tf.train.Saver() or calling
# tf.global_variables_initializer() inside the loop adds a new set of ops to the
# default graph for every learning-rate pair.
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()

# training loop settings (the same for every run)
training_steps = 500000
sample_every = 1000

for (lr_d, lr_g) in lr_pairs:
    print('\n------------ ({}, {}) -------------'.format(lr_d, lr_g))

    run += 1
    run_dir = '/mnt/p/run_fmnist_full_%d' % run
    if not os.path.exists(run_dir):
        os.makedirs(run_dir)

    g_losses, d_losses = [], []

    # New session per run, closed in the finally block so nine sessions do not pile up.
    # Deliberately not `with tf.Session() as sess:` — that would also install sess as
    # the default session, and classifier.predict (tf.keras) may then run inside this
    # freshly re-initialised session rather than the one holding the trained
    # classifier weights. NOTE(review): confirm against the installed TF version.
    sess = tf.Session()
    try:
        # initialize
        sess.run(init_op)

        i = 0
        for it in range(training_steps):

            # every nth iteration (and on the last one), save a grid of generated images
            if (it % sample_every == 0) or (it == training_steps-1):
                samples = sess.run(G_sample,
                                   feed_dict={Z: sample_Z(100, Z_dim),
                                              lr_D: lr_d,
                                              lr_G: lr_g})
                fig = plot(samples)
                plt.savefig('{}/{}.png'.format(run_dir, str(i).zfill(3)), bbox_inches='tight')
                i += 1
                plt.close(fig)

            # training step: one discriminator update, then one generator update
            X_mb, _ = mnist.train.next_batch(mb_size)
            _, D_loss_curr = sess.run([D_solver, D_loss],
                                      feed_dict={X: X_mb,
                                                 Z: sample_Z(mb_size, Z_dim),
                                                 lr_D: lr_d,
                                                 lr_G: lr_g})
            _, G_loss_curr = sess.run([G_solver, G_loss],
                                      feed_dict={Z: sample_Z(mb_size, Z_dim),
                                                 lr_D: lr_d,
                                                 lr_G: lr_g})

            # every nth iteration, record the losses
            if it % sample_every == 0:

                # progress marker
                print('...{}'.format(it), end='')

                # The generator loss goes to g_losses and the discriminator loss to
                # d_losses so the columns match the "GLoss,DLoss" CSV header (they
                # were previously appended to the opposite lists).
                g_losses.append(G_loss_curr)
                d_losses.append(D_loss_curr)

        # save all metrics for the run, one column per metric
        run_summary = np.column_stack([g_losses, d_losses])
        np.savetxt('{}/metrics_D{}_G{}.csv'.format(run_dir, lr_d, lr_g), run_summary, fmt='%1.4e', delimiter=',', header="GLoss,DLoss")

        ## "Inception" score (not technically, since it uses a different classifier)
        generated_images = sess.run(G_sample,
                                    feed_dict={Z: sample_Z(1000, Z_dim),
                                               lr_D: lr_d,
                                               lr_G: lr_g})
        generated_images = generated_images.reshape(1000,28,28,1)
        score_mean, score_std = get_inception_score(generated_images)
        np.savetxt('{}/score.csv'.format(run_dir), [score_mean, score_std], fmt='%1.4e', delimiter=',')

        # This cell never collected class distributions during training, so the old
        # prob_distr.csv write reused whatever `pds` the early-termination cell left
        # behind (or failed with NameError on a fresh kernel). Write this run's own
        # distribution instead: a single row computed from the 1000 images scored above.
        pds = [prob_distr(generated_images)]
        np.savetxt('{}/prob_distr.csv'.format(run_dir), pds, fmt='%1.4e', delimiter=',')

        if score_mean > best_is:
            best_is = score_mean
            best_run = run

        ## Save model and weights
        save_path = saver.save(sess, '{}/model.ckpt'.format(run_dir))
    finally:
        sess.close()

full_end = time.time()
print("TOTAL ELAPSED TIME: ", end='')
print(full_end - full_start)

Start time: 1543930332.0702457

------------ (0.001, 0.001) -------------
...0...1000...2000...3000...4000...5000...6000...7000...8000...9000...10000...11000...12000

In [None]:
# Run index with the highest classifier-based "Inception" score in the full-length
# sweep above, followed by that score.
print(best_run, best_is)