In [1]:
import numpy as np
import tensorflow as tf
from sklearn import datasets
from sklearn.preprocessing import minmax_scale, LabelBinarizer

In [2]:
ds = datasets.load_iris()
X_train, y_train = minmax_scale(ds.data), LabelBinarizer().fit_transform(ds.target)
training_data = np.concatenate([X_train, y_train], axis = 1)

In [3]:
n_X_samples = X_train.shape[0]
n_X_features = X_train.shape[1]
n_classes = y_train.shape[1]
n_Z_features = 10

In [4]:
y = tf.placeholder(tf.float32, shape=[None, n_classes])

X = tf.placeholder(tf.float32, shape=[None, n_X_features], name='X')
D_W1 = tf.Variable(tf.random_uniform([n_X_features + n_classes, 6]), name='D_W1')
D_b1 = tf.Variable(tf.random_uniform([6]), name='D_b1')
D_W2 = tf.Variable(tf.random_uniform([6, 1]), name='D_W2')
D_b2 = tf.Variable(tf.random_uniform([1]), name='D_b2')
D_parameters = [D_W1, D_W2, D_b1, D_b2]
def D_logit(X, y):
    inputs = tf.concat(axis=1, values=[X, y])
    D_h1 = tf.nn.tanh(tf.matmul(inputs, D_W1) + D_b1)
    return tf.matmul(D_h1, D_W2) + D_b2

Z = tf.placeholder(tf.float32, shape=[None, n_Z_features], name='Z')
G_W1 = tf.Variable(tf.random_uniform([n_Z_features + n_classes, 6]), name='G_W1')
G_b1 = tf.Variable(tf.random_uniform([6]), name='G_b1')
G_W2 = tf.Variable(tf.random_uniform([6, n_X_features]), name='G_W2')
G_b2 = tf.Variable(tf.random_uniform([n_X_features]), name='G_b2')
G_parameters = [G_W1, G_W2, G_b1, G_b2]
def G_logit(Z, y):
    inputs = tf.concat(axis=1, values=[Z, y])
    G_h1 = tf.nn.tanh(tf.matmul(inputs, G_W1) + G_b1)
    return tf.matmul(G_h1, G_W2) + G_b2

def sample_Z(n_samples, n_features):
    return np.random.uniform(-1., 1., size=[n_samples, n_features]).astype(np.float32)

def sample_y(n_samples, n_classes, class_label):
    output = np.zeros(shape=[n_samples, n_classes]).astype(np.float32)
    output[:, class_label] = 1.
    return output

In [5]:
D_logit_data = D_logit(X, y)
D_logit_generated = D_logit(tf.nn.sigmoid(G_logit(Z, y)), y)

D_loss_data = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_data, labels=tf.ones_like(D_logit_data)))
D_loss_generated = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_generated, labels=tf.zeros_like(D_logit_generated)))
D_loss = D_loss_data + D_loss_generated

G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_generated, labels=tf.ones_like(D_logit_generated)))

In [6]:
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=D_parameters)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=G_parameters)

In [8]:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

batch_size = 5
n_batches = int(n_X_samples / batch_size)
for epoch in range(500):
    epoch_permuted_indices = np.random.permutation(range(n_X_samples))
    X_epoch = X_train[epoch_permuted_indices]
    y_epoch = y_train[epoch_permuted_indices]
    for batch_index in range(n_batches):
        start_index = batch_index * batch_size
        end_index = start_index + batch_size
        X_batch = X_epoch[start_index:end_index]
        y_batch = y_epoch[start_index:end_index]
        _, D_loss_value = sess.run([D_solver, D_loss], feed_dict={X: X_batch, Z: sample_Z(batch_size, n_Z_features), y: y_batch})
        _, G_loss_value = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(batch_size, n_Z_features), y: y_batch})
    print('Epoch: {}, discriminator loss: {}, generator loss: {}'.format(epoch, D_loss_value, G_loss_value))

Epoch: 0, discriminator loss: 4.202922344207764, generator loss: 0.01588446833193302
Epoch: 1, discriminator loss: 3.9094738960266113, generator loss: 0.018823150545358658
Epoch: 2, discriminator loss: 3.8059606552124023, generator loss: 0.022607285529375076
Epoch: 3, discriminator loss: 3.639249801635742, generator loss: 0.027837589383125305
Epoch: 4, discriminator loss: 3.427393674850464, generator loss: 0.03814699873328209
Epoch: 5, discriminator loss: 3.221202850341797, generator loss: 0.04424634948372841
Epoch: 6, discriminator loss: 2.9936392307281494, generator loss: 0.05537525936961174
Epoch: 7, discriminator loss: 2.7353363037109375, generator loss: 0.0749804899096489
Epoch: 8, discriminator loss: 2.517587661743164, generator loss: 0.09529682993888855
Epoch: 9, discriminator loss: 2.1044423580169678, generator loss: 0.1592247188091278
Epoch: 10, discriminator loss: 1.7798351049423218, generator loss: 0.24169805645942688
Epoch: 11, discriminator loss: 1.6325829029083252, genera

In [18]:
X_generated = tf.nn.sigmoid(G_logit(sample_Z(10, n_Z_features), sample_y(10, n_classes, 2)))

In [19]:
sess.run(X_generated)

array([[ 0.40602136,  0.27141324,  0.8307209 ,  0.57854921],
       [ 0.46065405,  0.30945   ,  0.78602254,  0.65364873],
       [ 0.44015408,  0.277385  ,  0.80347764,  0.60472214],
       [ 0.38755029,  0.28218868,  0.82850897,  0.5852614 ],
       [ 0.4099665 ,  0.25503406,  0.8052249 ,  0.59090215],
       [ 0.31517851,  0.22329536,  0.83608657,  0.50601941],
       [ 0.46828321,  0.30253586,  0.78813523,  0.65890127],
       [ 0.51689166,  0.41220194,  0.83162063,  0.72102773],
       [ 0.67201138,  0.53954577,  0.79447925,  0.80900848],
       [ 0.33292016,  0.24067263,  0.85144645,  0.53288835]], dtype=float32)

In [17]:
X_train[ds.target == 2]

array([[ 0.55555556,  0.54166667,  0.84745763,  1.        ],
       [ 0.41666667,  0.29166667,  0.69491525,  0.75      ],
       [ 0.77777778,  0.41666667,  0.83050847,  0.83333333],
       [ 0.55555556,  0.375     ,  0.77966102,  0.70833333],
       [ 0.61111111,  0.41666667,  0.81355932,  0.875     ],
       [ 0.91666667,  0.41666667,  0.94915254,  0.83333333],
       [ 0.16666667,  0.20833333,  0.59322034,  0.66666667],
       [ 0.83333333,  0.375     ,  0.89830508,  0.70833333],
       [ 0.66666667,  0.20833333,  0.81355932,  0.70833333],
       [ 0.80555556,  0.66666667,  0.86440678,  1.        ],
       [ 0.61111111,  0.5       ,  0.69491525,  0.79166667],
       [ 0.58333333,  0.29166667,  0.72881356,  0.75      ],
       [ 0.69444444,  0.41666667,  0.76271186,  0.83333333],
       [ 0.38888889,  0.20833333,  0.6779661 ,  0.79166667],
       [ 0.41666667,  0.33333333,  0.69491525,  0.95833333],
       [ 0.58333333,  0.5       ,  0.72881356,  0.91666667],
       [ 0.61111111,  0.