In [1]:
import os
from datetime import datetime


# Project layout. The root can be overridden with the TROUBLED_LIFE_ROOT
# environment variable so the notebook also runs outside the author's machine;
# the original absolute path is kept as the fallback.
PROJECT_ROOT_DIR = os.environ.get("TROUBLED_LIFE_ROOT",
                                  "/Users/gopora/MyStuff/Dev/Workspaces/Sandbox/TroubledLife")
DATASETS_DIR = os.path.join(PROJECT_ROOT_DIR, "data")
TF_LOG_DIR = os.path.join(PROJECT_ROOT_DIR, "tf_logs")
TRAINING_SET_DATA_FILE = "troubled_life_policy_train_data.csv"
TEST_SET_DATA_FILE = "troubled_life_policy_test_data.csv"

# One TensorBoard log directory per run, timestamped so runs do not overwrite
# each other. The final "" component keeps the trailing path separator that the
# previous "{}/run-{}/" format produced.
now = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join(TF_LOG_DIR, "run-" + now, "")

In [2]:
import data_preparation as dp

# To (re)generate the synthetic data sets instead of loading them, uncomment:
# dp.generate_troubled_life_policy_data(no_of_policies=10000, runtime=5,
#                                       file_path=os.path.join(DATASETS_DIR, TRAINING_SET_DATA_FILE))
# dp.generate_troubled_life_policy_data(no_of_policies=2000, runtime=5,
#                                       file_path=os.path.join(DATASETS_DIR, TEST_SET_DATA_FILE))

# Load both splits from the data directory.
train_csv_path = os.path.join(DATASETS_DIR, TRAINING_SET_DATA_FILE)
test_csv_path = os.path.join(DATASETS_DIR, TEST_SET_DATA_FILE)

policy_histories_train = dp.load_troubled_life_policy_data(file_path=train_csv_path)
policy_histories_test = dp.load_troubled_life_policy_data(file_path=test_csv_path)

# dp.get_policy_history_lengths yields (per-policy lengths, longest length)
# for each split, judging by the unpacked names -- confirm in data_preparation.
(policy_histories_length_train,
 max_policy_history_length_train) = dp.get_policy_history_lengths(policy_histories=policy_histories_train)
(policy_histories_length_test,
 max_policy_history_length_test) = dp.get_policy_history_lengths(policy_histories=policy_histories_test)

# Train and test are shaped to one common sequence length downstream.
max_policy_history_length = max(max_policy_history_length_train,
                                max_policy_history_length_test)

In [3]:
# Padding: histories must be padded to the maximum length over BOTH the train
# and test sets. In the original workflow this was done once with
# dp.pad_troubled_life_policy_histories(policy_histories=...,
#                                       policy_histories_lengths=...,
#                                       max_policy_history_length=max_policy_history_length)
# and the padded frames were written back over the same CSV files with
# DataFrame.to_csv, because generating and padding on every run took too long.
# NOTE(review): presumably the CSVs on disk are therefore already padded --
# repeat that step whenever the CSVs are regenerated.

# Extract labels, features and sequence lengths as numpy.ndarray(s) for each
# split, all aligned to the shared max_policy_history_length.
_split_inputs = (
    (policy_histories_train, policy_histories_length_train),
    (policy_histories_test, policy_histories_length_test),
)
((train_labels, train_features, train_seq_lengths),
 (test_labels, test_features, test_seq_lengths)) = [
    dp.prepare_labels_features_lengths(policy_histories=histories,
                                       policy_histories_lengths=lengths,
                                       max_policy_history_length=max_policy_history_length)
    for histories, lengths in _split_inputs
]

# Batching wrapper used by the training loops (next_batch / num_examples).
train_data = dp.TrainDataSet(train_labels=train_labels,
                             train_features=train_features,
                             train_seq_lengths=train_seq_lengths)

In [4]:
import tensorflow as tf
import gan as gan


# Hyper-parameters.
g_learning_rate = 0.001
d_real_learning_rate = 0.0001
d_fake_learning_rate = 0.0001
noise_dimensions = 100
batch_size = 200
g_n_outputs = 2   # per-time-step output width of the generator (assumed; see check below)
d_n_inputs = 2    # per-time-step input width of the discriminator (last dim of X)

# The discriminator is applied to both X and the generator output with shared
# weights (reuse=True below), so the two feature widths must agree. This check
# creates no graph ops, so it does not disturb op-level seeds.
# NOTE(review): assumes gan.generator's n_outputs is the last dim of g_data -- confirm.
if g_n_outputs != d_n_inputs:
    raise ValueError("g_n_outputs ({}) must equal d_n_inputs ({}): generated data is fed "
                     "to the same discriminator as real data".format(g_n_outputs, d_n_inputs))

# Start from an empty graph; the graph-level seed belongs to the default graph,
# so it must be set after the reset.
tf.reset_default_graph()
tf.set_random_seed(42)

# Generator input: one noise vector per time step, plus per-example sequence lengths.
noise = tf.placeholder(tf.float32, [None, max_policy_history_length, noise_dimensions], name='noise')
noise_seq_length = tf.placeholder(tf.int32, [None], name='noise_seq_length')

# Real data: padded policy histories, one integer label per example, and lengths.
X = tf.placeholder(tf.float32, [None, max_policy_history_length, d_n_inputs], name="X")
y = tf.placeholder(tf.int32, [None], name="y")
X_seq_length = tf.placeholder(tf.int32, [None], name="X_seq_length")

g_data = gan.generator(noise=noise, n_outputs=g_n_outputs, seq_length=noise_seq_length)

# Same discriminator on real and on generated data (second call reuses the variables).
logits_d_for_X, y_pred_X = gan.discriminator(X=X, seq_length=X_seq_length, n_outputs=max_policy_history_length)
logits_d_for_g, y_pred_g = gan.discriminator(X=g_data, seq_length=noise_seq_length, n_outputs=max_policy_history_length, reuse=True)

# NOTE(review): batch_size is baked into the losses, so presumably every fed batch
# must contain exactly batch_size examples -- confirm in gan.
g_loss, accuracy_g = gan.generator_loss(batch_size=batch_size, logits=logits_d_for_g)
d_loss_real, accuracy_d_real = gan.discriminator_loss_real(batch_size=batch_size, logits=logits_d_for_X, y=y)
d_loss_fake, accuracy_d_fake = gan.discriminator_loss_fake(batch_size=batch_size, logits=logits_d_for_g)

g_trainer = gan.generator_trainer(g_learning_rate, g_loss)
d_trainer_real = gan.discriminator_trainer_real(d_real_learning_rate, d_loss_real)
d_trainer_fake = gan.discriminator_trainer_fake(d_fake_learning_rate, d_loss_fake)

# Scalar loss summaries. They only reach TensorBoard if `merged` is evaluated
# and passed to writer.add_summary; the FileWriter below records just the graph.
tf.summary.scalar('Generator_loss', g_loss)
tf.summary.scalar('Discriminator_loss_real', d_loss_real)
tf.summary.scalar('Discriminator_loss_fake', d_loss_fake)

merged = tf.summary.merge_all()
writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())

In [6]:
import numpy as np


# Fresh session. Initialising here resets every variable, so re-running this
# cell discards all training done so far, including by later cells.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

# tf.set_random_seed only seeds TensorFlow ops; the generator noise below is
# drawn with numpy, so seed numpy too for reproducible runs.
np.random.seed(42)

n_epochs = 20
n_batches = train_data.num_examples // batch_size
if n_batches == 0:
    # Without this the inner loop never runs and the print fails with NameError.
    raise ValueError("batch_size ({}) exceeds the number of training examples ({})"
                     .format(batch_size, train_data.num_examples))

# Generated sequences always use the full padded length; constant, so built once.
noise_seq_length_batch = np.full(batch_size, max_policy_history_length)

# Pre-train discriminator: only the discriminator's real- and fake-data
# objectives are updated here; the generator is not trained in this cell.
# The printed metrics are those of the last batch of each epoch.
for epoch in range(n_epochs):
    for batch in range(n_batches):
        noise_batch = np.random.normal(0, 1, size=[batch_size, max_policy_history_length, noise_dimensions])

        y_batch, X_batch, X_seq_length_batch = train_data.next_batch(batch_size)

        # `_` for ignored fetches: IPython reserves `__` for its output history.
        _, _, dLossReal, dLossFake, acc_real_train, acc_fake_train = \
            sess.run([d_trainer_real, d_trainer_fake, d_loss_real, d_loss_fake, accuracy_d_real, accuracy_d_fake],
                     {X: X_batch, y: y_batch, X_seq_length: X_seq_length_batch,
                      noise: noise_batch, noise_seq_length: noise_seq_length_batch})

    print("Epoch:", epoch, "dLossReal:", dLossReal, "acc_real_train:", acc_real_train,
          "dLossFake:", dLossFake, "acc_fake_train:", acc_fake_train)

Epoch: 0 dLossReal: 51.1334 acc_real_train: 0.745 dLossFake: 0.409864 acc_fake_train: 1.0


Epoch: 1 dLossReal: 6.80715 acc_real_train: 0.79 dLossFake: 0.000215351 acc_fake_train: 1.0


Epoch: 2 dLossReal: 4.47854 acc_real_train: 0.795 dLossFake: 0.000124225 acc_fake_train: 1.0


Epoch: 3 dLossReal: 1.88812 acc_real_train: 0.865 dLossFake: 8.85739e-05 acc_fake_train: 1.0


Epoch: 4 dLossReal: 5.08818 acc_real_train: 0.705 dLossFake: 0.000221516 acc_fake_train: 1.0


Epoch: 5 dLossReal: 5.6872 acc_real_train: 0.825 dLossFake: 1.53113e-05 acc_fake_train: 1.0


Epoch: 6 dLossReal: 3.30148 acc_real_train: 0.86 dLossFake: 3.69425e-06 acc_fake_train: 1.0


Epoch: 7 dLossReal: 7.55907 acc_real_train: 0.895 dLossFake: 7.15254e-07 acc_fake_train: 1.0


Epoch: 8 dLossReal: 3.51052 acc_real_train: 0.8 dLossFake: 4.01735e-07 acc_fake_train: 1.0


Epoch: 9 dLossReal: 3.05524 acc_real_train: 0.85 dLossFake: 2.87294e-07 acc_fake_train: 1.0


Epoch: 10 dLossReal: 2.72455 acc_real_train: 0.82 dLossFake: 1.36494e-07 acc_fake_train: 1.0


Epoch: 11 dLossReal: 0.739919 acc_real_train: 0.915 dLossFake: 8.64267e-08 acc_fake_train: 1.0


Epoch: 12 dLossReal: 1.84591 acc_real_train: 0.875 dLossFake: 4.11272e-08 acc_fake_train: 1.0


Epoch: 13 dLossReal: 1.39556 acc_real_train: 0.83 dLossFake: 5.78164e-08 acc_fake_train: 1.0


Epoch: 14 dLossReal: 2.68751 acc_real_train: 0.71 dLossFake: 1.18017e-07 acc_fake_train: 1.0


Epoch: 15 dLossReal: 1.6193 acc_real_train: 0.895 dLossFake: 3.9339e-08 acc_fake_train: 1.0


Epoch: 16 dLossReal: 1.16044 acc_real_train: 0.925 dLossFake: 4.35113e-08 acc_fake_train: 1.0


Epoch: 17 dLossReal: 1.13205 acc_real_train: 0.88 dLossFake: 2.02656e-08 acc_fake_train: 1.0


Epoch: 18 dLossReal: 0.873133 acc_real_train: 0.91 dLossFake: 1.90735e-08 acc_fake_train: 1.0


Epoch: 19 dLossReal: 0.656374 acc_real_train: 0.92 dLossFake: 9.53674e-09 acc_fake_train: 1.0


In [7]:
# Train generator and discriminator together. Each batch does one discriminator
# update on real and generated data, then one generator update on fresh noise.
# Continues from the current weights in `sess`.
n_epochs = 200
n_batches = train_data.num_examples // batch_size
if n_batches == 0:
    # Without this the inner loop never runs and the logging below fails with NameError.
    raise ValueError("batch_size ({}) exceeds the number of training examples ({})"
                     .format(batch_size, train_data.num_examples))

# Generated sequences always use the full padded length; constant, so built once.
noise_seq_length_batch = np.full(batch_size, max_policy_history_length)

for epoch in range(n_epochs):
    for batch in range(n_batches):
        noise_batch = np.random.normal(0, 1, size=[batch_size, max_policy_history_length, noise_dimensions])

        y_batch, X_batch, X_seq_length_batch = train_data.next_batch(batch_size)

        # Train discriminator on both real and fake data. `merged` is fetched in
        # this run because it is the one that feeds every placeholder the loss
        # summaries depend on. `_` (not `__`) avoids clobbering IPython's history.
        # NOTE(review): assumes no summary defined inside `gan` needs another placeholder.
        _, _, summary, dLossReal, dLossFake, acc_real_train, acc_fake_train = \
            sess.run([d_trainer_real, d_trainer_fake, merged, d_loss_real, d_loss_fake,
                      accuracy_d_real, accuracy_d_fake],
                     {X: X_batch, y: y_batch, X_seq_length: X_seq_length_batch,
                      noise: noise_batch, noise_seq_length: noise_seq_length_batch})

        # Train generator on a fresh noise sample.
        noise_batch = np.random.normal(0, 1, size=[batch_size, max_policy_history_length, noise_dimensions])

        _, gLoss, acc_g = sess.run([g_trainer, g_loss, accuracy_g],
                                   feed_dict={noise: noise_batch, noise_seq_length: noise_seq_length_batch})

    # Write the loss summaries (last batch of the epoch) so they appear in TensorBoard.
    writer.add_summary(summary, global_step=epoch)
    writer.flush()

    print("Epoch:", epoch, "dLossReal:", dLossReal, "acc_real_train:", acc_real_train,
          "dLossFake:", dLossFake, "acc_fake_train:", acc_fake_train,
          "gLoss:", gLoss, "acc_g:", acc_g)

Epoch: 0 dLossReal: 289.359 acc_real_train: 0.52 dLossFake: 0.0087164 acc_fake_train: 1.0 gLoss: 6.62249 acc_g: 0.0


Epoch: 1 dLossReal: 146.469 acc_real_train: 0.32 dLossFake: 0.00421564 acc_fake_train: 1.0 gLoss: 8.24388 acc_g: 0.0


Epoch: 2 dLossReal: 534.464 acc_real_train: 0.0 dLossFake: 0.00100036 acc_fake_train: 1.0 gLoss: 12.2682 acc_g: 0.0


Epoch: 3 dLossReal: 127.121 acc_real_train: 0.595 dLossFake: 0.273507 acc_fake_train: 1.0 gLoss: 2.5628 acc_g: 0.0


Epoch: 4 dLossReal: 146.882 acc_real_train: 0.835 dLossFake: 0.04842 acc_fake_train: 1.0 gLoss: 5.12982 acc_g: 0.0


Epoch: 5 dLossReal: 28.913 acc_real_train: 0.815 dLossFake: 0.00274184 acc_fake_train: 1.0 gLoss: 7.38817 acc_g: 0.0


Epoch: 6 dLossReal: 47.278 acc_real_train: 0.76 dLossFake: 0.0956574 acc_fake_train: 1.0 gLoss: 3.96549 acc_g: 0.0


Epoch: 7 dLossReal: 90.3273 acc_real_train: 0.305 dLossFake: 0.845973 acc_fake_train: 0.96 gLoss: 1.58679 acc_g: 0.0


Epoch: 8 dLossReal: 119.156 acc_real_train: 0.75 dLossFake: 0.274606 acc_fake_train: 1.0 gLoss: 1.70859 acc_g: 0.0


Epoch: 9 dLossReal: 264.747 acc_real_train: 0.78 dLossFake: 0.0831777 acc_fake_train: 1.0 gLoss: 3.39014 acc_g: 0.0


Epoch: 10 dLossReal: 24.3267 acc_real_train: 0.62 dLossFake: 0.00386207 acc_fake_train: 1.0 gLoss: 6.04579 acc_g: 0.0


Epoch: 11 dLossReal: 13.2893 acc_real_train: 0.825 dLossFake: 0.00301028 acc_fake_train: 1.0 gLoss: 6.3843 acc_g: 0.0


Epoch: 12 dLossReal: 28.918 acc_real_train: 0.765 dLossFake: 0.00402868 acc_fake_train: 1.0 gLoss: 5.72852 acc_g: 0.0


Epoch: 13 dLossReal: 19.9743 acc_real_train: 0.535 dLossFake: 0.00150587 acc_fake_train: 1.0 gLoss: 7.08392 acc_g: 0.0


Epoch: 14 dLossReal: 44.4186 acc_real_train: 0.715 dLossFake: 4.11272e-08 acc_fake_train: 1.0 gLoss: 19.584 acc_g: 0.0


Epoch: 15 dLossReal: 166.652 acc_real_train: 0.29 dLossFake: 0.110001 acc_fake_train: 0.99 gLoss: 9.59638 acc_g: 0.0


Epoch: 16 dLossReal: 285.719 acc_real_train: 0.64 dLossFake: 1.01188 acc_fake_train: 0.395 gLoss: 4.26449 acc_g: 0.0


Epoch: 17 dLossReal: 258.54 acc_real_train: 0.26 dLossFake: 0.0264506 acc_fake_train: 1.0 gLoss: 3.85334 acc_g: 0.0


Epoch: 18 dLossReal: 83.5938 acc_real_train: 0.725 dLossFake: 0.0562896 acc_fake_train: 1.0 gLoss: 4.19819 acc_g: 0.0


Epoch: 19 dLossReal: 90.2265 acc_real_train: 0.085 dLossFake: 2.59216 acc_fake_train: 0.035 gLoss: 3.66936 acc_g: 0.0


Epoch: 20 dLossReal: 165.035 acc_real_train: 0.765 dLossFake: 0.108814 acc_fake_train: 1.0 gLoss: 3.4669 acc_g: 0.0


Epoch: 21 dLossReal: 23.4775 acc_real_train: 0.815 dLossFake: 0.0401281 acc_fake_train: 1.0 gLoss: 3.96012 acc_g: 0.0


Epoch: 22 dLossReal: 53.6933 acc_real_train: 0.82 dLossFake: 0.0460031 acc_fake_train: 1.0 gLoss: 4.15747 acc_g: 0.0


Epoch: 23 dLossReal: 54.0943 acc_real_train: 0.87 dLossFake: 0.0612983 acc_fake_train: 1.0 gLoss: 5.18882 acc_g: 0.0


Epoch: 24 dLossReal: 361.451 acc_real_train: 0.46 dLossFake: 11.1722 acc_fake_train: 0.0 gLoss: 26.2064 acc_g: 0.0


Epoch: 25 dLossReal: 98.217 acc_real_train: 0.775 dLossFake: 0.081565 acc_fake_train: 1.0 gLoss: 4.11268 acc_g: 0.0


Epoch: 26 dLossReal: 105.79 acc_real_train: 0.74 dLossFake: 0.424782 acc_fake_train: 1.0 gLoss: 2.27827 acc_g: 0.0


Epoch: 27 dLossReal: 89.6001 acc_real_train: 0.83 dLossFake: 1.66893e-08 acc_fake_train: 1.0 gLoss: 51.3917 acc_g: 0.0


Epoch: 28 dLossReal: 58.5684 acc_real_train: 0.805 dLossFake: 0.0103195 acc_fake_train: 1.0 gLoss: 6.60099 acc_g: 0.0


Epoch: 29 dLossReal: 47.2634 acc_real_train: 0.805 dLossFake: 0.134533 acc_fake_train: 1.0 gLoss: 4.44234 acc_g: 0.0


Epoch: 30 dLossReal: 976.843 acc_real_train: 0.015 dLossFake: 15.1613 acc_fake_train: 0.0 gLoss: 0.0201166 acc_g: 1.0


Epoch: 31 dLossReal: 304.101 acc_real_train: 0.325 dLossFake: 0.609966 acc_fake_train: 0.695 gLoss: 1.71747 acc_g: 0.295


Epoch: 32 dLossReal: 193.379 acc_real_train: 0.635 dLossFake: 1.19969 acc_fake_train: 0.5 gLoss: 1.23531 acc_g: 0.56


Epoch: 33 dLossReal: 474.366 acc_real_train: 0.765 dLossFake: 0.0 acc_fake_train: 1.0 gLoss: 297.971 acc_g: 0.0


Epoch: 34 dLossReal: 130.272 acc_real_train: 0.78 dLossFake: 0.908399 acc_fake_train: 0.675 gLoss: 3.18384 acc_g: 0.0


Epoch: 35 dLossReal: 3990.96 acc_real_train: 0.765 dLossFake: 312.687 acc_fake_train: 0.0 gLoss: 57.909 acc_g: 0.0


Epoch: 36 dLossReal: 1581.01 acc_real_train: 0.04 dLossFake: 3.0962e-06 acc_fake_train: 1.0 gLoss: 0.503501 acc_g: 0.955


Epoch: 37 dLossReal: 1128.94 acc_real_train: 0.745 dLossFake: 61.8128 acc_fake_train: 0.0 gLoss: 1.74677 acc_g: 0.415


Epoch: 38 dLossReal: 6871.71 acc_real_train: 0.04 dLossFake: 63.8726 acc_fake_train: 0.02 gLoss: 480.328 acc_g: 0.0


Epoch: 39 dLossReal: 779.72 acc_real_train: 0.74 dLossFake: 15.261 acc_fake_train: 0.0 gLoss: 1.32734 acc_g: 0.36


Epoch: 40 dLossReal: 1170.7 acc_real_train: 0.83 dLossFake: 73.0681 acc_fake_train: 0.0 gLoss: 2.40221e-06 acc_g: 1.0


Epoch: 41 dLossReal: 5090.41 acc_real_train: 0.03 dLossFake: 2.98023e-09 acc_fake_train: 1.0 gLoss: 65.6723 acc_g: 0.0


Epoch: 42 dLossReal: 1934.34 acc_real_train: 0.715 dLossFake: 5.17828 acc_fake_train: 0.805 gLoss: 65.0482 acc_g: 0.105


Epoch: 43 dLossReal: 613.794 acc_real_train: 0.465 dLossFake: 2.4775 acc_fake_train: 0.285 gLoss: 0.716929 acc_g: 0.775


Epoch: 44 dLossReal: 525.738 acc_real_train: 0.4 dLossFake: 23.3414 acc_fake_train: 0.0 gLoss: 0.172239 acc_g: 0.95


Epoch: 45 dLossReal: 7883.4 acc_real_train: 0.805 dLossFake: 0.0 acc_fake_train: 1.0 gLoss: 505.935 acc_g: 0.0


Epoch: 46 dLossReal: 558.54 acc_real_train: 0.81 dLossFake: 2.07889 acc_fake_train: 0.6 gLoss: 17.6285 acc_g: 0.0


Epoch: 47 dLossReal: 693.806 acc_real_train: 0.565 dLossFake: 47.1656 acc_fake_train: 0.0 gLoss: 0.186122 acc_g: 0.925


Epoch: 48 dLossReal: 1391.24 acc_real_train: 0.5 dLossFake: 202.649 acc_fake_train: 0.0 gLoss: 0.0 acc_g: 1.0


Epoch: 49 dLossReal: 2218.18 acc_real_train: 0.74 dLossFake: 232.218 acc_fake_train: 0.0 gLoss: 0.0951066 acc_g: 0.99


Epoch: 50 dLossReal: 18408.5 acc_real_train: 0.215 dLossFake: 471.77 acc_fake_train: 0.005 gLoss: 1.8623e-06 acc_g: 1.0


Epoch: 51 dLossReal: 5461.01 acc_real_train: 0.085 dLossFake: 49.7633 acc_fake_train: 0.0 gLoss: 3.308 acc_g: 0.125


Epoch: 52 dLossReal: 2014.46 acc_real_train: 0.71 dLossFake: 133.607 acc_fake_train: 0.105 gLoss: 2.30017 acc_g: 0.8


Epoch: 53 dLossReal: 1503.75 acc_real_train: 0.715 dLossFake: 0.936866 acc_fake_train: 0.81 gLoss: 1.79226 acc_g: 0.17


Epoch: 54 dLossReal: 517.911 acc_real_train: 0.09 dLossFake: 3.05522 acc_fake_train: 0.025 gLoss: 0.443494 acc_g: 0.98


Epoch: 55 dLossReal: 683.692 acc_real_train: 0.725 dLossFake: 29.7128 acc_fake_train: 0.0 gLoss: 0.003016 acc_g: 1.0


Epoch: 56 dLossReal: 609.574 acc_real_train: 0.495 dLossFake: 63.4401 acc_fake_train: 0.0 gLoss: 7.77079e-06 acc_g: 1.0


Epoch: 57 dLossReal: 1602.71 acc_real_train: 0.695 dLossFake: 214.647 acc_fake_train: 0.0 gLoss: 5.60282e-08 acc_g: 1.0


Epoch: 58 dLossReal: 2380.69 acc_real_train: 0.52 dLossFake: 2694.65 acc_fake_train: 0.0 gLoss: 0.0 acc_g: 1.0


Epoch: 59 dLossReal: 18879.4 acc_real_train: 0.005 dLossFake: 11411.0 acc_fake_train: 0.0 gLoss: 0.0 acc_g: 1.0


Epoch: 60 dLossReal: 42277.8 acc_real_train: 0.745 dLossFake: 3508.36 acc_fake_train: 0.0 gLoss: 0.0 acc_g: 1.0


Epoch: 61 dLossReal: 20570.3 acc_real_train: 0.78 dLossFake: 6.44305e-07 acc_fake_train: 1.0 gLoss: 249.848 acc_g: 0.0


Epoch: 62 dLossReal: 164403.0 acc_real_train: 0.0 dLossFake: 53509.8 acc_fake_train: 0.0 gLoss: 8055.8 acc_g: 0.0


Epoch: 63 dLossReal: 83650.5 acc_real_train: 0.0 dLossFake: 0.206268 acc_fake_train: 0.995 gLoss: 274.853 acc_g: 0.02


Epoch: 64 dLossReal: 6606.36 acc_real_train: 0.76 dLossFake: 7.27321 acc_fake_train: 0.27 gLoss: 2.89245 acc_g: 0.75


Epoch: 65 dLossReal: 2210.96 acc_real_train: 0.72 dLossFake: 149.885 acc_fake_train: 0.04 gLoss: 10.334 acc_g: 0.905


Epoch: 66 dLossReal: 2022.11 acc_real_train: 0.335 dLossFake: 518.828 acc_fake_train: 0.0 gLoss: 0.0 acc_g: 1.0


Epoch: 67 dLossReal: 2078.09 acc_real_train: 0.745 dLossFake: 619.673 acc_fake_train: 0.0 gLoss: 0.0 acc_g: 1.0


Epoch: 68 dLossReal: 2056.71 acc_real_train: 0.81 dLossFake: 2066.78 acc_fake_train: 0.0 gLoss: 0.0 acc_g: 1.0


Epoch: 69 dLossReal: 109121.0 acc_real_train: 0.755 dLossFake: 1.42084e+06 acc_fake_train: 0.0 gLoss: 0.0 acc_g: 1.0


Epoch: 70 dLossReal: 219683.0 acc_real_train: 0.755 dLossFake: 5564.44 acc_fake_train: 0.035 gLoss: 3667.33 acc_g: 0.0


KeyboardInterrupt: 