In [1]:
import os
from datetime import datetime

# Project layout.
# The root directory defaults to the original author's checkout, but can be
# overridden through the TROUBLED_LIFE_PROJECT_ROOT environment variable so the
# notebook can run on another machine without editing this cell.
PROJECT_ROOT_DIR = os.environ.get(
    "TROUBLED_LIFE_PROJECT_ROOT",
    "/Users/gopora/MyStuff/Dev/Workspaces/Sandbox/TroubledLife")
DATASETS_DIR = os.path.join(PROJECT_ROOT_DIR, "data")
TF_LOG_DIR = os.path.join(PROJECT_ROOT_DIR, "tf_logs")
MODEL_CHECKPOINTS_DIR = os.path.join(PROJECT_ROOT_DIR, "model_checkpoints")

# File names (relative to DATASETS_DIR) of the train and test data sets.
TRAINING_SET_DATA_FILE = "troubled_life_policy_train_data.csv"
TEST_SET_DATA_FILE = "troubled_life_policy_test_data.csv"

# Each execution of this cell gets its own timestamped TensorBoard log directory.
now = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "{}/run-{}/".format(TF_LOG_DIR, now)


import data_preparation as dp

#dp.generate_troubled_life_policy_data(no_of_policies=10000, runtime=5, file_path=os.path.join(DATASETS_DIR, TRAINING_SET_DATA_FILE))

#dp.generate_troubled_life_policy_data(no_of_policies=2000, runtime=5, file_path=os.path.join(DATASETS_DIR, TEST_SET_DATA_FILE))

# Load the train and test policy histories from their CSV files in DATASETS_DIR.
policy_histories_train = dp.load_troubled_life_policy_data(
    file_path=os.path.join(DATASETS_DIR, TRAINING_SET_DATA_FILE))

policy_histories_test = dp.load_troubled_life_policy_data(
    file_path=os.path.join(DATASETS_DIR, TEST_SET_DATA_FILE))

# dp.get_policy_history_lengths returns a pair, unpacked here as
# (history lengths, maximum history length) for each data set.
(policy_histories_length_train,
 max_policy_history_length_train) = dp.get_policy_history_lengths(
    policy_histories=policy_histories_train)

(policy_histories_length_test,
 max_policy_history_length_test) = dp.get_policy_history_lengths(
    policy_histories=policy_histories_test)

# Larger of the two maxima, shared by train and test in the cells below.
max_length_policy_history = max(max_policy_history_length_train,
                                max_policy_history_length_test)



# Pad the histories up to the maximum length found across both the train and the test set

# policy_histories_train = \
#     dp.pad_troubled_life_policy_histories(policy_histories=policy_histories_train,
#                                           policy_histories_lengths=policy_histories_length_train,
#                                           max_policy_history_length=max_length_policy_history)
# 
# policy_histories_test = \
#     dp.pad_troubled_life_policy_histories(policy_histories=policy_histories_test,
#                                           policy_histories_lengths=policy_histories_length_test,
#                                           max_policy_history_length=max_length_policy_history)
# 
# # # Save padded data, since always generating and padding takes too long
# policy_histories_train.to_csv(path_or_buf=os.path.join(DATASETS_DIR, TRAINING_SET_DATA_FILE))
# policy_histories_test.to_csv(path_or_buf=os.path.join(DATASETS_DIR, TEST_SET_DATA_FILE))

# Turn both data sets into (labels, features, sequence lengths) as numpy.ndarray(s).
binary_classification = True

# Arguments that are identical for the train and the test extraction.
_shared_extraction_kwargs = dict(max_policy_history_length=max_length_policy_history,
                                 binary_classification=binary_classification)

train_labels, train_features, train_seq_lengths = dp.prepare_labels_features_lengths(
    policy_histories=policy_histories_train,
    policy_histories_lengths=policy_histories_length_train,
    **_shared_extraction_kwargs)

test_labels, test_features, test_seq_lengths = dp.prepare_labels_features_lengths(
    policy_histories=policy_histories_test,
    policy_histories_lengths=policy_histories_length_test,
    **_shared_extraction_kwargs)

# Wrap the training arrays in the batch provider used by the training cells.
train_data = dp.TrainDataSet(train_labels=train_labels,
                             train_features=train_features,
                             train_seq_lengths=train_seq_lengths)

In [2]:
import tensorflow as tf
import gan as gan


# Learning rates handed to the generator / discriminator trainers.
learning_rate_g = 0.01
learning_rate_d = 0.0001

# Number of sequences per training batch (also the fixed batch dimension of Z).
size_batch = 200

# Per-time-step input/output widths of the two networks.
n_inputs_g = n_outputs_g = n_inputs_d = 2

# Discriminator output width depends on the labelling mode chosen in the data cell.
if binary_classification:
    n_outputs_d = 2
else:
    n_outputs_d = max_length_policy_history

# Build the GAN graph from scratch: clear the default graph (so re-running this
# cell does not add ops to a previously built graph) and fix the graph-level seed.
# NOTE: statement order matters here; with a graph-level seed, op creation order
# influences the per-op seeds TensorFlow derives.
tf.reset_default_graph()
tf.set_random_seed(42)

# Generator input: float tensor of shape [size_batch, max_length_policy_history, n_inputs_g].
# The batch dimension is fixed to size_batch, so every feed must contain exactly size_batch rows.
Z = tf.placeholder(tf.float32, [size_batch, max_length_policy_history, n_inputs_g], name='Z')
# One int per generated sequence, passed to the generator as seq_length.
seq_length_z = tf.placeholder(tf.int32, [None], name='seq_length_z')

# Real data input: [batch, max_length_policy_history, n_inputs_d], with integer labels y
# and one sequence length per row.
X = tf.placeholder(tf.float32, [None, max_length_policy_history, n_inputs_d], name="X")
y = tf.placeholder(tf.int32, [None], name="y")
seq_length_x = tf.placeholder(tf.int32, [None], name="seq_length_x")

# Generator output for the noise input Z.
Gz = gan.generator(Z=Z, n_outputs=n_outputs_g, seq_length=seq_length_z, leaky=False)

# Discriminator applied to real data (Dx) and to generated data (Dg).
# NOTE(review): reuse=True on the second call presumably shares the variables
# created by the first call - confirm in gan.discriminator.
Dx, y_pred_x = gan.discriminator(X=X, seq_length=seq_length_x, n_outputs=n_outputs_d, leaky=False)
Dg, y_pred_g = gan.discriminator(X=Gz, seq_length=seq_length_z, n_outputs=n_outputs_d, leaky=False, reuse=True)

# Loss / accuracy pairs for the generator and for the discriminator on real and generated input.
loss_g, accuracy_g = gan.generator_loss(size_batch=size_batch, Dg=Dg)
loss_real_d, accuracy_real_d = gan.discriminator_loss_real(Dx=Dx, y=y)
loss_fake_d, accuracy_fake_d = gan.discriminator_loss_fake(size_batch=size_batch, Dg=Dg)

# Training ops. Despite the helper's "_real" suffix, the discriminator trainer is
# given the SUM of the real and fake losses, so running d_trainer requires feeding
# X, y, seq_length_x as well as Z and seq_length_z.
g_trainer = gan.generator_trainer(learning_rate=learning_rate_g, loss=loss_g)
d_trainer = gan.discriminator_trainer_real(learning_rate_d, loss_real_d + loss_fake_d)

# Scalar summaries for TensorBoard.
tf.summary.scalar('Generator_loss', loss_g)
tf.summary.scalar('Discriminator_loss_real', loss_real_d)
tf.summary.scalar('Discriminator_loss_fake', loss_fake_d)

# NOTE(review): `merged` is not evaluated by the training cells visible in this
# notebook, so the FileWriter below only records the graph definition.
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())

tvars = tf.trainable_variables()

# Separate savers so generator and discriminator can be checkpointed independently.
# Variables are selected by a substring match on the variable name.
g_saver = tf.train.Saver(var_list=[var for var in tvars if "t_generator" in var.name])
d_saver = tf.train.Saver(var_list=[var for var in tvars if "t_discriminator" in var.name])

In [3]:
import numpy as np

# Print NumPy floats with two decimals (used when dumping generated samples).
np.set_printoptions(formatter={'float_kind': (lambda x: "%.2f" % x)})

# Seed NumPy's global RNG. The training cells draw the generator's latent noise
# with np.random.normal, so seeding TensorFlow alone (tf.set_random_seed(42) in
# the graph cell) leaves the runs non-reproducible. Same seed value as there.
np.random.seed(42)

# One session shared by all following training / checkpoint cells.
sess = tf.Session()

# Initialise global and local variables of the graph built in the previous cell.
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

In [4]:
n_epochs = 100

# Number of full batches per epoch; a trailing partial batch is dropped because
# the Z placeholder has a fixed batch dimension of size_batch.
n_batches = train_data.num_examples // size_batch
if n_batches == 0:
    # Without this guard the loop body never runs and the print below fails
    # with a confusing NameError (or silently reports stale values).
    raise ValueError(
        "Training set has fewer examples than size_batch={}; "
        "no full batch can be formed.".format(size_batch))

# Make sure the checkpoint directory exists before the first save.
os.makedirs(MODEL_CHECKPOINTS_DIR, exist_ok=True)
discriminator_ckpt_path = os.path.join(MODEL_CHECKPOINTS_DIR, "discriminator.ckpt")

# Pre-train discriminator
for epoch in range(n_epochs):
    for batch in range(n_batches):
        # Latent noise for the generator; every generated sequence uses the full length.
        Z_batch = np.random.normal(0, 1, size=[size_batch, max_length_policy_history, n_inputs_g])
        seq_length_z_batch = np.full(size_batch, max_length_policy_history)

        y_batch, X_batch, seq_length_x_batch = train_data.next_batch(size_batch)

        # d_trainer minimises loss_real_d + loss_fake_d, hence both real and noise feeds.
        _, lossRealD, lossFakeD, accRealD, accFakeD = \
            sess.run([d_trainer, loss_real_d, loss_fake_d, accuracy_real_d, accuracy_fake_d],
                     {X: X_batch, y: y_batch, seq_length_x: seq_length_x_batch,
                      Z: Z_batch, seq_length_z: seq_length_z_batch})

    # Metrics of the last batch of the epoch.
    print("Epoch:", epoch, "lossRealD:", lossRealD, "accRealD:", accRealD,
          "lossFakeD:", lossFakeD, "accFakeD:", accFakeD)

    # Checkpoint after every epoch rather than once after the whole loop, so that
    # interrupting the (long) pre-training still leaves an up-to-date checkpoint
    # for the next cell to restore instead of a stale one from an earlier run.
    d_saver.save(sess, discriminator_ckpt_path)

Epoch: 0 lossRealD: 0.560439 accRealD: 0.935 lossFakeD: 0.542731 accFakeD: 1.0


Epoch: 1 lossRealD: 0.531858 accRealD: 0.935 lossFakeD: 0.000115605 accFakeD: 1.0


Epoch: 2 lossRealD: 1.96263 accRealD: 0.61 lossFakeD: 1.07759e-05 accFakeD: 1.0


Epoch: 3 lossRealD: 0.724709 accRealD: 0.94 lossFakeD: 6.35264e-06 accFakeD: 1.0


Epoch: 4 lossRealD: 0.266904 accRealD: 0.975 lossFakeD: 3.8576e-06 accFakeD: 1.0


Epoch: 5 lossRealD: 1.12307 accRealD: 0.9 lossFakeD: 2.2012e-06 accFakeD: 1.0


Epoch: 6 lossRealD: 0.890741 accRealD: 0.95 lossFakeD: 1.0711e-06 accFakeD: 1.0


Epoch: 7 lossRealD: 0.275393 accRealD: 0.955 lossFakeD: 8.08835e-07 accFakeD: 1.0


KeyboardInterrupt: 

In [5]:
# Start from the pre-trained discriminator weights stored on disk.
d_saver.restore(sess, os.path.join(MODEL_CHECKPOINTS_DIR, "discriminator.ckpt"))
#g_saver.restore(sess, os.path.join(MODEL_CHECKPOINTS_DIR, "generator.ckpt"))

# Train generator and discriminator together
n_epochs = 200

# Loop-invariant pieces: shape of the latent noise and the (constant) lengths
# of the generated sequences.
noise_shape = [size_batch, max_length_policy_history, n_inputs_g]
seq_length_z_batch = np.full(size_batch, max_length_policy_history)

# What each of the two alternating steps evaluates.
d_fetches = [d_trainer, loss_real_d, accuracy_real_d]
g_fetches = [g_trainer, loss_fake_d, accuracy_fake_d, loss_g, accuracy_g, Gz]

for epoch in range(n_epochs):
    for batch in range(train_data.num_examples // size_batch):
        # Discriminator step: one noise batch plus one batch of real data.
        Z_batch = np.random.normal(0, 1, size=noise_shape)
        y_batch, X_batch, seq_length_x_batch = train_data.next_batch(size_batch)

        d_feed = {
            X: X_batch,
            y: y_batch,
            seq_length_x: seq_length_x_batch,
            Z: Z_batch,
            seq_length_z: seq_length_z_batch,
        }
        _, lossRealD, accRealD = sess.run(d_fetches, d_feed)

        # Generator step: fresh noise, no real data required.
        Z_batch = np.random.normal(0, 1, size=noise_shape)
        g_feed = {Z: Z_batch, seq_length_z: seq_length_z_batch}
        _, lossFakeD, accFakeD, lossG, accG, gData = sess.run(g_fetches, feed_dict=g_feed)

    # Report the metrics of the epoch's final batch ...
    print("Epoch:", epoch,
          "lossRealD:", lossRealD, "accRealD:", accRealD,
          "lossFakeD:", lossFakeD, "accFakeD:", accFakeD,
          "lossG:", lossG, "accG:", accG)

    # ... and show the first generated sequence of that batch.
    print(gData[0])

INFO:tensorflow:Restoring parameters from /Users/gopora/MyStuff/Dev/Workspaces/Sandbox/TroubledLife/model_checkpoints/discriminator.ckpt


Epoch: 0 lossRealD: 0.401523 accRealD: 0.98 lossFakeD: 0.55525 accFakeD: 1.0 lossG: 0.853194 accG: 0.0
[[0.92 0.85]
 [0.90 1.62]
 [2.35 3.48]
 [1.43 3.55]
 [3.20 5.13]
 [1.59 4.29]
 [2.89 4.27]
 [1.22 3.11]
 [2.76 4.30]
 [1.47 3.74]
 [3.36 5.42]
 [1.66 4.46]
 [3.01 4.46]
 [1.26 3.27]
 [2.55 3.79]]


Epoch: 1 lossRealD: 0.0536714 accRealD: 0.995 lossFakeD: 0.753085 accFakeD: 0.13 lossG: 0.642449 accG: 0.87
[[1.05 1.05]
 [1.68 2.39]
 [2.47 4.71]
 [4.44 7.41]
 [4.56 11.19]
 [9.22 16.12]
 [7.26 21.19]
 [15.45 28.07]
 [8.88 31.47]
 [20.06 37.97]
 [8.46 37.41]
 [19.94 39.19]
 [6.20 35.32]
 [16.33 32.85]
 [4.21 27.43]]


Epoch: 2 lossRealD: 0.0139686 accRealD: 0.99 lossFakeD: 0.308361 accFakeD: 1.0 lossG: 1.37213 accG: 0.0
[[2.08 1.91]
 [4.99 5.87]
 [7.51 14.33]
 [9.75 18.11]
 [5.89 21.54]
 [5.98 17.48]
 [7.69 10.34]
 [9.88 19.12]
 [21.98 28.28]
 [33.65 56.65]
 [25.45 84.41]
 [29.99 74.73]
 [9.11 56.25]
 [8.81 32.35]
 [31.74 19.84]]


Epoch: 3 lossRealD: 0.050236 accRealD: 0.99 lossFakeD: 0.169732 accFakeD: 1.0 lossG: 2.27809 accG: 0.0
[[3.49 3.65]
 [13.06 13.75]
 [26.03 17.72]
 [20.22 29.29]
 [24.76 42.09]
 [29.27 61.90]
 [24.21 77.49]
 [22.75 94.50]
 [37.48 111.93]
 [18.22 115.56]
 [34.83 123.48]
 [-9.43 30.06]
 [12.50 -10.28]
 [9.21 22.96]
 [1.15 26.38]]


Epoch: 4 lossRealD: 0.530823 accRealD: 0.95 lossFakeD: 0.290456 accFakeD: 1.0 lossG: 1.3894 accG: 0.0
[[4.51 4.15]
 [8.44 11.16]
 [8.35 16.73]
 [9.05 20.50]
 [10.45 26.62]
 [12.21 31.67]
 [10.51 32.07]
 [14.58 36.33]
 [11.53 31.60]
 [12.52 30.75]
 [8.82 20.56]
 [9.54 23.05]
 [7.93 19.67]
 [8.45 21.37]
 [8.18 19.02]]


KeyboardInterrupt: 

In [5]:
# Persist the current generator weights; the returned checkpoint path is the cell's output.
generator_ckpt_path = os.path.join(MODEL_CHECKPOINTS_DIR, "generator.ckpt")
g_saver.save(sess, generator_ckpt_path)

'/Users/gopora/MyStuff/Dev/Workspaces/Sandbox/TroubledLife/model_checkpoints/generator.ckpt'