In [1]:
import numpy as np
import tensorflow as tf
from sklearn import datasets
from sklearn.preprocessing import minmax_scale, LabelBinarizer

In [2]:
# Load Iris, rescale every feature column with min-max scaling, and
# one-hot encode the integer class labels.
ds = datasets.load_iris()
X_train = minmax_scale(ds.data)
y_train = LabelBinarizer().fit_transform(ds.target)

# Features and one-hot labels side by side, one row per sample.
training_data = np.hstack([X_train, y_train])

In [3]:
# Problem dimensions, read off the prepared arrays.
n_X_samples, n_X_features = X_train.shape
n_classes = y_train.shape[-1]

# Width of the noise vector fed to the generator.
n_Z_features = 6

In [4]:
# Discriminator: one linear layer turning a feature row into a single logit.
# X receives batches of data rows; the batch dimension is left open.
X = tf.placeholder(dtype=tf.float32, shape=(None, n_X_features), name='X')
D_W = tf.Variable(initial_value=tf.random_uniform(shape=(n_X_features, 1)), name='D_W')
D_b = tf.Variable(initial_value=tf.random_uniform(shape=(1,)), name='D_b')
# Variables the discriminator optimizer is allowed to update.
D_parameters = [D_W, D_b]
def D_logit(X):
    """Return the discriminator's raw (pre-sigmoid) score for each row of X.

    Computes the affine map X @ D_W + D_b using the module-level
    discriminator variables.
    """
    weighted = tf.matmul(X, D_W)
    return weighted + D_b

# Generator: one linear layer mapping a noise row to n_X_features outputs.
# Z receives batches of noise rows; the batch dimension is left open.
Z = tf.placeholder(dtype=tf.float32, shape=(None, n_Z_features), name='Z')
G_W = tf.Variable(initial_value=tf.random_uniform(shape=(n_Z_features, n_X_features)), name='G_W')
G_b = tf.Variable(initial_value=tf.random_uniform(shape=(n_X_features,)), name='G_b')
# Variables the generator optimizer is allowed to update.
G_parameters = [G_W, G_b]
def G_logit(Z):
    """Return the generator's raw (pre-sigmoid) output for each noise row of Z.

    Computes the affine map Z @ G_W + G_b using the module-level
    generator variables.
    """
    projected = tf.matmul(Z, G_W)
    return projected + G_b

def sample_Z(n_samples, n_features):
    """Draw generator noise from numpy's global RNG.

    Returns a float32 array of shape (n_samples, n_features) whose entries
    are sampled uniformly between -1 and 1.
    """
    noise_shape = [n_samples, n_features]
    noise = np.random.uniform(low=-1., high=1., size=noise_shape)
    return noise.astype(np.float32)

In [5]:
# Raw discriminator scores for data rows fed through X ...
D_logit_data = D_logit(X)
# ... and for generator output; the sigmoid keeps generated values in (0, 1).
G_sample = tf.nn.sigmoid(G_logit(Z))
D_logit_generated = D_logit(G_sample)

# Discriminator objective: cross-entropy with target 1 on data rows and
# target 0 on generated rows, each averaged over the batch and then summed.
D_targets_data = tf.ones_like(D_logit_data)
D_loss_data = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=D_targets_data, logits=D_logit_data)
)
D_targets_generated = tf.zeros_like(D_logit_generated)
D_loss_generated = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=D_targets_generated, logits=D_logit_generated)
)
D_loss = D_loss_data + D_loss_generated

# Generator objective: cross-entropy with target 1 on the discriminator's
# scores for generated rows, i.e. reward samples the discriminator accepts.
G_targets = tf.ones_like(D_logit_generated)
G_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=G_targets, logits=D_logit_generated)
)

In [6]:
# Each network gets its own Adam optimizer (default hyperparameters) and is
# restricted via var_list to updating only its own variables.
D_optimizer = tf.train.AdamOptimizer()
D_solver = D_optimizer.minimize(loss=D_loss, var_list=D_parameters)

G_optimizer = tf.train.AdamOptimizer()
G_solver = G_optimizer.minimize(loss=G_loss, var_list=G_parameters)

In [7]:
# Open a session and initialize every variable in the graph. The session is
# deliberately left open because later cells reuse `sess`.
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

batch_size = 10
# Whole batches only: floor division drops any trailing partial batch and
# avoids the float round-trip of int(a / b).
n_batches = n_X_samples // batch_size
for epoch in range(500):
    # Visit the training rows in a fresh random order every epoch.
    X_epoch = X_train[np.random.permutation(n_X_samples)]
    for batch_index in range(n_batches):
        start_index = batch_index * batch_size
        end_index = start_index + batch_size
        X_batch = X_epoch[start_index:end_index]
        # Discriminator step: one batch of data rows plus one batch of noise.
        _, D_loss_value = sess.run([D_solver, D_loss], feed_dict={X: X_batch, Z: sample_Z(batch_size, n_Z_features)})
        # Generator step: fresh noise only, since G_loss is built from Z alone.
        _, G_loss_value = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(batch_size, n_Z_features)})
    # Report the losses of the last batch of the epoch.
    print('Epoch: {}, discriminator loss: {}, generator loss: {}'.format(epoch, D_loss_value, G_loss_value))

Epoch: 0, discriminator loss: 2.3044686317443848, generator loss: 0.13100770115852356
Epoch: 1, discriminator loss: 2.1159284114837646, generator loss: 0.14579716324806213
Epoch: 2, discriminator loss: 2.2519259452819824, generator loss: 0.13755899667739868
Epoch: 3, discriminator loss: 2.256204605102539, generator loss: 0.14657092094421387
Epoch: 4, discriminator loss: 2.1598048210144043, generator loss: 0.17542780935764313
Epoch: 5, discriminator loss: 2.2429144382476807, generator loss: 0.150535449385643
Epoch: 6, discriminator loss: 2.1370797157287598, generator loss: 0.1514035165309906
Epoch: 7, discriminator loss: 2.2357523441314697, generator loss: 0.16957883536815643
Epoch: 8, discriminator loss: 2.0795035362243652, generator loss: 0.17474763095378876
Epoch: 9, discriminator loss: 2.0290422439575195, generator loss: 0.18798455595970154
Epoch: 10, discriminator loss: 1.9893221855163574, generator loss: 0.20700402557849884
Epoch: 11, discriminator loss: 1.92893648147583, generato

In [8]:
# Draw 10 noise rows once, at graph-construction time, and push them through
# the generator; the noise is a fixed numpy array, not a placeholder feed.
Z_sample = sample_Z(10, n_Z_features)
X_generated = tf.nn.sigmoid(G_logit(Z_sample))

In [9]:
# Evaluate the generated rows with the trained weights; as the last
# expression of the cell, the resulting array is displayed as its output.
sess.run(fetches=X_generated)

array([[ 0.3915219 ,  0.19932577,  0.42266262,  0.39455956],
       [ 0.65022224,  0.53885132,  0.84583104,  0.73324174],
       [ 0.28660852,  0.23787358,  0.4005262 ,  0.43196619],
       [ 0.40539929,  0.32836661,  0.56841642,  0.64633626],
       [ 0.41151437,  0.39764419,  0.59860539,  0.50286949],
       [ 0.21944591,  0.22783263,  0.304773  ,  0.38222897],
       [ 0.33324188,  0.26477072,  0.44035298,  0.16930234],
       [ 0.51923728,  0.38239533,  0.6946941 ,  0.63129276],
       [ 0.34728634,  0.27378714,  0.43921995,  0.11173232],
       [ 0.340556  ,  0.32398549,  0.42209229,  0.36254349]], dtype=float32)