In [9]:
import numpy as np
import tensorflow as tf
from sklearn import datasets
from sklearn.preprocessing import minmax_scale, LabelBinarizer
from keras.layers import Dense
from keras import backend as K

Using TensorFlow backend.


In [2]:
# Iris features min-max scaled into [0, 1] (the same range the generator's
# sigmoid output lives in); species labels one-hot encoded.
ds = datasets.load_iris()
X_train = minmax_scale(ds.data)
y_train = LabelBinarizer().fit_transform(ds.target)

In [3]:
# Rows / columns of the real data, and the width of the generator's noise input.
n_X_samples, n_X_features = X_train.shape
n_Z_features = 6

In [4]:
X = tf.placeholder(tf.float32, shape=[None, n_X_features], name='X')

# Discriminator weights: n_X_features -> 6 hidden units -> 1 logit.
# NOTE(review): tf.random_uniform's documented default range is [0, 1), so every
# weight and bias starts positive; a zero-centred init is more usual — confirm intent.
D_W1, D_b1, D_W2, D_b2 = [
    tf.Variable(tf.random_uniform(shape), name=var_name)
    for var_name, shape in (('D_W1', [n_X_features, 6]),
                            ('D_b1', [6]),
                            ('D_W2', [6, 1]),
                            ('D_b2', [1]))
]
# Variables the discriminator optimiser is allowed to update.
D_parameters = [D_W1, D_W2, D_b1, D_b2]
def D_logit(X):
    """Discriminator forward pass: one tanh hidden layer, then a linear logit.

    Returns raw (pre-sigmoid) logits, one column per input row, suitable for
    ``tf.nn.sigmoid_cross_entropy_with_logits``.
    """
    hidden = tf.nn.tanh(tf.add(tf.matmul(X, D_W1), D_b1))
    logits = tf.add(tf.matmul(hidden, D_W2), D_b2)
    return logits

Z = tf.placeholder(tf.float32, shape=[None, n_Z_features], name='Z')

# Generator weights: n_Z_features noise -> 6 hidden units -> n_X_features outputs.
# NOTE(review): same all-positive tf.random_uniform default init as the
# discriminator — confirm intent.
G_W1, G_b1, G_W2, G_b2 = [
    tf.Variable(tf.random_uniform(shape), name=var_name)
    for var_name, shape in (('G_W1', [n_Z_features, 6]),
                            ('G_b1', [6]),
                            ('G_W2', [6, n_X_features]),
                            ('G_b2', [n_X_features]))
]
# Variables the generator optimiser is allowed to update.
G_parameters = [G_W1, G_W2, G_b1, G_b2]
def G_logit(Z):
    """Generator forward pass: one tanh hidden layer, then a linear output.

    Returns pre-sigmoid values with n_X_features columns; callers apply
    ``tf.nn.sigmoid`` to map them into the [0, 1] data range.
    """
    hidden = tf.nn.tanh(tf.add(tf.matmul(Z, G_W1), G_b1))
    outputs = tf.add(tf.matmul(hidden, G_W2), G_b2)
    return outputs

def sample_Z(n_samples, n_features, low=-1., high=1., rng=None):
    """Draw generator input noise uniformly from [low, high).

    Args:
        n_samples: number of noise vectors (rows).
        n_features: dimensionality of each noise vector (columns).
        low, high: bounds of the uniform distribution; the defults keep the
            original [-1, 1) range.
        rng: optional numpy random source (``np.random.RandomState`` or
            ``np.random.Generator``) for reproducible draws. When omitted the
            global ``np.random`` state is used, exactly as before.

    Returns:
        float32 array of shape (n_samples, n_features), ready to feed ``Z``.
    """
    source = np.random if rng is None else rng
    # float32 to match the Z placeholder's dtype.
    return source.uniform(low, high, size=(n_samples, n_features)).astype(np.float32)

In [5]:
def _mean_sigmoid_xent(logits, labels_like):
    """Mean sigmoid cross-entropy of `logits` against a constant target built by `labels_like`."""
    return tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels_like(logits)))

# Generated samples are squashed into [0, 1] with a sigmoid (matching the
# min-max-scaled real data) before the discriminator scores them.
D_logit_data = D_logit(X)
D_logit_generated = D_logit(tf.nn.sigmoid(G_logit(Z)))

# Discriminator: real rows should be labelled 1, generated rows 0.
D_loss_data = _mean_sigmoid_xent(D_logit_data, tf.ones_like)
D_loss_generated = _mean_sigmoid_xent(D_logit_generated, tf.zeros_like)
D_loss = D_loss_data + D_loss_generated

# Generator: make the discriminator label generated rows 1
# (the non-saturating form of the generator objective).
G_loss = _mean_sigmoid_xent(D_logit_generated, tf.ones_like)

In [6]:
# One Adam optimiser (default hyper-parameters) per network, each restricted to
# that network's own variables so a D step never moves G's weights and vice versa.
D_solver, G_solver = [
    tf.train.AdamOptimizer().minimize(loss, var_list=params)
    for loss, params in ((D_loss, D_parameters), (G_loss, G_parameters))
]

In [7]:
# The session stays open: later cells evaluate the generator with it.
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

n_epochs = 50
batch_size = 10
# Ceiling division so a trailing partial batch is still used when n_X_samples
# is not a multiple of batch_size (iris: 150 / 10 divides evenly).
n_batches = -(-n_X_samples // batch_size)
assert n_batches > 0, 'no training samples'
for epoch in range(n_epochs):
    # Reshuffle rows every epoch so the batches differ between epochs.
    X_epoch = X_train[np.random.permutation(n_X_samples)]
    for batch_index in range(n_batches):
        start_index = batch_index * batch_size
        end_index = start_index + batch_size
        X_batch = X_epoch[start_index:end_index]
        # Discriminator step: real batch vs. an equally sized generated batch.
        _, D_loss_value = sess.run(
            [D_solver, D_loss],
            feed_dict={X: X_batch, Z: sample_Z(len(X_batch), n_Z_features)})
        # Generator step: G_loss depends only on Z, so X is not fed.
        _, G_loss_value = sess.run(
            [G_solver, G_loss],
            feed_dict={Z: sample_Z(batch_size, n_Z_features)})
    # Reported losses are those of the epoch's last batch.
    print('Epoch: {}, discriminator loss: {}, generator loss: {}'.format(epoch, D_loss_value, G_loss_value))

Epoch: 0, discriminator loss: 2.6091184616088867, generator loss: 0.07768820226192474
Epoch: 1, discriminator loss: 2.5885040760040283, generator loss: 0.09816193580627441
Epoch: 2, discriminator loss: 2.534332275390625, generator loss: 0.09133169800043106
Epoch: 3, discriminator loss: 2.477931261062622, generator loss: 0.0961957573890686
Epoch: 4, discriminator loss: 2.406693696975708, generator loss: 0.11907561868429184
Epoch: 5, discriminator loss: 2.318744659423828, generator loss: 0.11431384086608887
Epoch: 6, discriminator loss: 2.238431453704834, generator loss: 0.13139823079109192
Epoch: 7, discriminator loss: 2.1927709579467773, generator loss: 0.14901892840862274
Epoch: 8, discriminator loss: 2.1741034984588623, generator loss: 0.16706955432891846
Epoch: 9, discriminator loss: 2.073620080947876, generator loss: 0.18236283957958221
Epoch: 10, discriminator loss: 2.0146782398223877, generator loss: 0.18602879345417023
Epoch: 11, discriminator loss: 1.974812626838684, generator 

In [8]:
# Ten noise vectors pushed through the generator; the sigmoid maps outputs into
# [0, 1], the range of the min-max-scaled training data.
# NOTE(review): the noise array is embedded in the graph as a constant, so
# repeated evaluations of X_generated return the same samples.
Z_demo = sample_Z(10, n_Z_features)
X_generated = tf.nn.sigmoid(G_logit(Z_demo))

In [9]:
# Evaluate the generated samples in the trained session (displayed as cell output).
X_generated.eval(session=sess)

array([[ 0.24513245,  0.54048389,  0.26786584,  0.87239295],
       [ 0.28050452,  0.74936378,  0.51870602,  0.90157622],
       [ 0.23966016,  0.62389594,  0.3239021 ,  0.92718315],
       [ 0.29626796,  0.70880497,  0.40506876,  0.96753514],
       [ 0.40586182,  0.64257467,  0.41495639,  0.89226812],
       [ 0.27027369,  0.68068832,  0.3591398 ,  0.96063471],
       [ 0.24787085,  0.78308898,  0.44295961,  0.98795205],
       [ 0.37847656,  0.65923154,  0.40360287,  0.92068249],
       [ 0.50756919,  0.46224296,  0.33950633,  0.23162611],
       [ 0.31925187,  0.72746146,  0.43038481,  0.96071488]], dtype=float32)