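# GAN with KL Divergence & Mode-Coverage Metrics
# MNIST

A fully-connected GAN is trained on MNIST for every combination of discriminator and generator learning rates. A separately trained classifier labels the generated digits, so that mode coverage and the KL divergence between the generated and data class distributions can be tracked during training.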

In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import keras.models
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import itertools

from sklearn.neighbors import KNeighborsClassifier

%matplotlib inline

Using TensorFlow backend.


## Load Data

In [2]:
# Hyper-parameters shared by the cells below.
mb_size = 128    # GAN minibatch size (real images per training step)
Z_dim = 100      # dimensionality of the generator's latent input
n_classes = 10   # number of MNIST digit classes

# Fix the NumPy and TensorFlow graph-level seeds so latent samples, data
# shuffling and weight initialisation are reproducible across kernel restarts.
SEED = 42
np.random.seed(SEED)
tf.set_random_seed(SEED)

In [3]:
# Location of the raw MNIST .gz archives. Set the MNIST_SOURCE_URL environment
# variable to point elsewhere instead of editing this machine-specific path.
MNIST_SOURCE_URL = os.environ.get('MNIST_SOURCE_URL', 'file:///workspace/mnist_src/')
mnist = input_data.read_data_sets('mnist', source_url=MNIST_SOURCE_URL, one_hot=True)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting mnist/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting mnist/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting mnist/t10k-images-idx3-ubyte.gz
Extracting mnist/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


## Metrics

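### Pre-trained classifier

A small MLP is trained on real MNIST digits before the GAN is trained. It is used afterwards to assign class labels to generated images.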

In [4]:
# Use each split's arrays directly. With read_data_sets' defaults the splits
# hold 55,000 training and 5,000 validation images, so next_batch(60000) and
# next_batch(10000) would wrap into a freshly shuffled epoch and return
# duplicated examples (and advance the training iterator used by the GAN).
x_train, y_train = mnist.train.images, mnist.train.labels
x_val, y_val = mnist.validation.images, mnist.validation.labels

# Small MLP over flattened 784-pixel images; kl_divergence() later calls
# classifier.predict on generated samples.
classifier = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
classifier.compile(optimizer='adam',
                   loss='categorical_crossentropy',
                   metrics=['accuracy'])

classifier.fit(x_train, y_train, epochs=5)
classifier.evaluate(x_val, y_val)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.06470902684219182, 0.9816]

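### Convergence Metrics

Each batch of generated images is labelled by the classifier above. The batch's average predicted class distribution $p_{gen}$ is compared with a uniform distribution $p_{data}$ over the 10 digit classes, using both the forward divergence $D_{KL}(p_{gen} \| p_{data})$ and the reverse divergence $D_{KL}(p_{data} \| p_{gen})$. The number of classes covered by the samples is reported as mode coverage.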

In [5]:
# Alias for TF 1.x probability distributions (tf.contrib is removed in TF 2.x).
# NOTE(review): no live code in this section uses `distr`; the metrics below
# are computed with NumPy instead.
distr = tf.contrib.distributions

def prob_distr(img_batch):
    """Return the batch-level class distribution predicted by `classifier`.

    The classifier's per-image softmax vectors are summed over the batch and
    the totals are normalised to sum to 1, giving one probability per class.
    """
    return _batch_class_distribution(classifier.predict(img_batch))


def _batch_class_distribution(class_probs):
    """Sum per-sample class probabilities over axis 0 and normalise to sum to 1."""
    totals = np.sum(class_probs, axis=0)
    return totals / np.sum(totals)


def kl_divergence(img_batch, data_pd=None, tiny_e=1e-5):
    """Compare the generated class distribution with the data distribution.

    Args:
        img_batch: generated images, passed directly to `classifier.predict`.
        data_pd: reference class distribution. Defaults to uniform over the
            classes the classifier outputs (i.e. assumes balanced classes).
        tiny_e: smoothing constant added to both distributions (which are then
            renormalised) so the log ratios stay finite for empty classes.

    Returns:
        (kl, rev_kl, gen_pd, modes_covered):
            kl            -- KL(gen || data)
            rev_kl        -- KL(data || gen)
            gen_pd        -- unsmoothed class distribution of the batch
            modes_covered -- number of distinct classes that are the top
                             (argmax) prediction for at least one image
    """
    class_probs = classifier.predict(img_batch)
    gen_pd = _batch_class_distribution(class_probs)

    # A mode is covered when some sample is classified as that digit.
    # Testing the summed softmax mass for > 0 is true for practically every
    # class, so that test could never reveal mode collapse.
    modes_covered = len(np.unique(np.argmax(class_probs, axis=1)))

    if data_pd is None:
        data_pd = np.full(gen_pd.shape, 1. / gen_pd.size)

    # Smooth, then renormalise so both arguments are proper distributions.
    gen_smooth = gen_pd + tiny_e
    gen_smooth /= np.sum(gen_smooth)
    data_smooth = np.asarray(data_pd, dtype=float) + tiny_e
    data_smooth /= np.sum(data_smooth)

    kl = np.sum(gen_smooth * np.log(gen_smooth / data_smooth))
    rev_kl = np.sum(data_smooth * np.log(data_smooth / gen_smooth))

    return kl, rev_kl, gen_pd, modes_covered

## Model 

### Weights and Placeholders

In [6]:
def xavier_init(size):
    """Return a random-normal initial weight tensor of shape `size`.

    `size[0]` is treated as the fan-in. The standard deviation is
    1 / sqrt(fan_in / 2) == sqrt(2 / fan_in), i.e. He-style scaling despite
    the function's name.
    """
    return tf.random_normal(shape=size, stddev=1. / tf.sqrt(size[0] / 2.))

In [14]:
# Layer sizes: flattened 28x28 images and the hidden width shared by D and G.
X_dim = 784
h_dim = 128

# Discriminator parameters: X_dim -> h_dim -> 1 logit.
X = tf.placeholder(tf.float32, shape=[None, X_dim])

D_W1 = tf.Variable(xavier_init([X_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))

D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

theta_D = [D_W1, D_W2, D_b1, D_b2]

# Generator parameters: Z_dim (latent, from the config cell) -> h_dim -> X_dim.
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])

G_W1 = tf.Variable(xavier_init([Z_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))

G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))

theta_G = [G_W1, G_W2, G_b1, G_b2]

# Learning rates are fed per run via feed_dict. Scalar placeholders with a
# default are the intended mechanism for this, rather than variables whose
# value is overridden through feed_dict.
lr_D = tf.placeholder_with_default(0.001, shape=[], name='lr_D')
lr_G = tf.placeholder_with_default(0.001, shape=[], name='lr_G')

### Architecture

In [15]:
def sample_Z(m, n):
    """Draw an (m, n) batch of latent vectors uniformly from [-1, 1)."""
    return np.random.uniform(-1., 1., size=[m, n])

def generator(z):
    """Map latent vectors `z` to pixel intensities in (0, 1).

    One ReLU hidden layer (G_W1, G_b1) followed by a linear layer
    (G_W2, G_b2) squashed by a sigmoid.
    """
    hidden = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    logits = tf.matmul(hidden, G_W2) + G_b2
    return tf.nn.sigmoid(logits)

def discriminator(x):
    """Score images `x`; returns (probability, logit) of being real.

    One ReLU hidden layer (D_W1, D_b1) followed by a single-unit linear layer
    (D_W2, D_b2). The logit is returned alongside the sigmoid so the losses can
    use the numerically stable *_with_logits ops.
    """
    hidden = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    logit = tf.matmul(hidden, D_W2) + D_b2
    return tf.nn.sigmoid(logit), logit

### Loss functions

In [16]:
# One generator forward pass, and the discriminator applied to both real (X)
# and generated (G_sample) images; both discriminator calls share the D_* weights.
G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)


def _mean_sigmoid_xent(logits, labels):
    """Batch mean of the sigmoid cross-entropy between `logits` and `labels`."""
    return tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))


# Discriminator loss, equivalent to -mean(log D(x) + log(1 - D(G(z)))):
# real images are labelled 1, generated images 0.
D_loss_real = _mean_sigmoid_xent(D_logit_real, tf.ones_like(D_logit_real))
D_loss_fake = _mean_sigmoid_xent(D_logit_fake, tf.zeros_like(D_logit_fake))
D_loss = D_loss_real + D_loss_fake

# Non-saturating generator loss, equivalent to -mean(log D(G(z))): generated
# images are labelled 1.
G_loss = _mean_sigmoid_xent(D_logit_fake, tf.ones_like(D_logit_fake))

# Each optimiser only updates its own network's parameters.
D_solver = tf.train.AdamOptimizer(learning_rate=lr_D).minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer(learning_rate=lr_G).minimize(G_loss, var_list=theta_G)

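## Training Session

For every learning-rate pair in `lr_pairs`, the GAN is re-initialised and trained from scratch. Periodic sample grids and per-run metric CSVs are written to `out2/run_<n>/`.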

In [17]:
def plot(samples, nrows=10, ncols=10):
    """Draw flattened 28x28 images on an `nrows` x `ncols` grid.

    Args:
        samples: sequence of images, each reshapeable to (28, 28), drawn in
            row-major order. The default 10x10 grid holds 100 images.
        nrows, ncols: grid shape.

    Returns:
        The new matplotlib Figure; the caller is responsible for saving,
        showing and closing it.
    """
    fig = plt.figure(figsize=(10, 10))
    gs = gridspec.GridSpec(nrows, ncols)
    gs.update(wspace=0.1, hspace=0.1)

    for i, sample in enumerate(samples):
        # Attach axes to `fig` explicitly instead of relying on pyplot's
        # "current figure", which may change if other figures are open.
        ax = fig.add_subplot(gs[i])
        ax.axis('off')  # hides ticks, tick labels and the frame
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

In [24]:
# Candidate learning rates; every (discriminator, generator) combination is
# trained, with the discriminator rate varying slowest.
lr = [0.001, 0.0005, 0.0001]
lr_pairs = [(rate_d, rate_g) for rate_d in lr for rate_g in lr]

In [26]:
# Optimisation steps per (lr_D, lr_G) run, and the interval (in steps) at which
# an image grid and a metric snapshot are recorded.
training_steps = 500000
sample_every = 1000

run = 0
sess = None
for (lr_d, lr_g) in lr_pairs:
    print('------------ ({}, {}) -------------'.format(lr_d, lr_g))

    # Release the previous run's session before starting a new one. The
    # newest session is left open after the loop so later cells can use `sess`.
    if sess is not None:
        sess.close()
    sess = tf.Session()

    # Fresh initialisation of all weights and optimiser state for this run.
    sess.run(tf.global_variables_initializer())
    g_losses, d_losses, kls, rev_kls, pds, mcs = [], [], [], [], [], []
    run += 1
    run_dir = 'out2/run_{}'.format(run)
    os.makedirs(run_dir, exist_ok=True)

    i = 0  # index of the next snapshot image (000.png, 001.png, ...)
    for it in range(training_steps):

        # Every `sample_every` steps, save a grid of generated digits.
        if it % sample_every == 0:
            samples = sess.run(G_sample, feed_dict={Z: sample_Z(100, Z_dim)})
            fig = plot(samples)
            fig.savefig('{}/{}.png'.format(run_dir, str(i).zfill(3)), bbox_inches='tight')
            i += 1
            plt.close(fig)

        # One discriminator update followed by one generator update.
        X_mb, _ = mnist.train.next_batch(mb_size)
        _, D_loss_curr = sess.run([D_solver, D_loss],
                                  feed_dict={X: X_mb,
                                             Z: sample_Z(mb_size, Z_dim),
                                             lr_D: lr_d,
                                             lr_G: lr_g})
        _, G_loss_curr = sess.run([G_solver, G_loss],
                                  feed_dict={Z: sample_Z(mb_size, Z_dim),
                                             lr_D: lr_d,
                                             lr_G: lr_g})

        # Every `sample_every` steps, record losses and distribution metrics.
        if it % sample_every == 0:
            kl_samples = sess.run(G_sample, feed_dict={Z: sample_Z(200, Z_dim)})
            kl, rev_kl, pd, mc = kl_divergence(kl_samples)

            g_losses.append(G_loss_curr)
            d_losses.append(D_loss_curr)
            kls.append(kl)
            rev_kls.append(rev_kl)
            pds.append(pd)
            mcs.append(mc)

    # Final generator output for this run: save it first, then show it inline.
    samples = sess.run(G_sample, feed_dict={Z: sample_Z(100, Z_dim)})
    fig = plot(samples)
    fig.savefig('{}/final.png'.format(run_dir), bbox_inches='tight')
    plt.show()
    plt.close(fig)

    # Persist the metric history, one row per snapshot. column_stack does not
    # require training_steps to be a multiple of sample_every.
    run_summary = np.column_stack([g_losses, d_losses, kls, rev_kls])
    np.savetxt('{}/metrics_D{}_G{}.csv'.format(run_dir, lr_d, lr_g), run_summary, fmt='%1.4e', delimiter=',', header="GLoss,DLoss,KLDivergence,RevKLDivergence")
    np.savetxt('{}/prob_distr.csv'.format(run_dir), pds, fmt='%1.4e', delimiter=',')
    np.savetxt('{}/modes_covered.csv'.format(run_dir), mcs, fmt='%d', delimiter=',')

------------ (0.001, 0.001) -------------
Iter: 0
Iter: 1000
Iter: 2000
Iter: 3000
Iter: 4000
Iter: 5000
Iter: 6000
Iter: 7000
Iter: 8000
Iter: 9000
Iter: 10000
Iter: 11000
Iter: 12000
Iter: 13000
Iter: 14000
Iter: 15000
Iter: 16000
Iter: 17000
Iter: 18000
Iter: 19000
Iter: 20000
Iter: 21000
Iter: 22000
Iter: 23000
Iter: 24000
Iter: 25000
Iter: 26000
Iter: 27000
Iter: 28000
Iter: 29000
Iter: 30000
Iter: 31000
Iter: 32000
Iter: 33000
Iter: 34000
Iter: 35000
Iter: 36000
Iter: 37000
Iter: 38000
Iter: 39000
Iter: 40000
Iter: 41000
Iter: 42000
Iter: 43000
Iter: 44000
Iter: 45000
Iter: 46000
Iter: 47000
Iter: 48000
Iter: 49000
Iter: 50000
Iter: 51000
Iter: 52000
Iter: 53000
Iter: 54000
Iter: 55000
Iter: 56000
Iter: 57000
Iter: 58000
Iter: 59000
Iter: 60000
Iter: 61000
Iter: 62000
Iter: 63000
Iter: 64000
Iter: 65000
Iter: 66000
Iter: 67000
Iter: 68000
Iter: 69000
Iter: 70000
Iter: 71000
Iter: 72000
Iter: 73000
Iter: 74000
Iter: 75000
Iter: 76000
Iter: 77000
Iter: 78000
Iter: 79000
Iter: 80000