In [1]:
import tensorflow as tf
import numpy as np
from discriminator.discriminator_wrapper import DiscriminatorWrapper
from generator.generator_wrapper import GeneratorWrapper, GeneratorSpec
from generator.generator_data import GeneratorData
from synthetic.target_data_generator import TargetDataGenerator
from synthetic.target_data import TargetData

In [2]:
def initialize_sess():
    """Initialize every global variable that has not been initialized yet.

    Variables that already hold a value are left untouched, so the helper can
    be called again each time new variables have been added to the graph.
    Reads the notebook-level ``sess``.
    """
    # Names reported as uninitialized; the comparison below strips the
    # ":<index>" suffix from each variable name and encodes it to bytes so
    # that it can be matched against these entries.
    uninitialized_names = set(sess.run(tf.report_uninitialized_variables()))

    pending = []
    for variable in tf.global_variables():
        op_name = variable.name.split(':')[0].encode('ascii')
        if op_name in uninitialized_names:
            pending.append(variable)

    # Operation.run() executes in the default session.
    tf.variables_initializer(pending).run()
    
def reset_sess():
    """Throw away the current graph and session and open a fresh session.

    The session is closed *before* the default graph is reset: TensorFlow
    documents that calling ``tf.reset_default_graph()`` while a session is
    still active results in undefined behavior.

    Rebinds the notebook-level ``sess`` to the new ``InteractiveSession``.
    Tensors, operations and variables created before the reset belong to the
    discarded graph and must be rebuilt.
    """
    global sess
    sess.close()
    tf.reset_default_graph()
    sess = tf.InteractiveSession()

def get_mean_reward(rewards):
    """Return the mean of the final (last non-zero) reward of each row.

    Args:
        rewards: 2-D array-like of rewards, one row per sample. Presumably the
            columns are time steps -- confirm against the caller.

    Returns:
        The mean over rows of the last non-zero entry in each row. Rows that
        contain no non-zero entry do not contribute. Returns 0.0 when there
        is no non-zero reward at all (instead of NaN plus a RuntimeWarning).
    """
    np_rewards = np.array(rewards)
    rows, columns = np.nonzero(np_rewards)
    if rows.shape[0] == 0:
        return 0.0

    # np.nonzero reports indices in row-major order, so all entries of one row
    # are contiguous. An entry is the last of its row when the next entry
    # belongs to a different row, or when it is the very last entry overall.
    # (The previous implementation used the second-to-last entry for the
    # trailing row.)
    is_last = np.zeros(rows.shape[0], dtype=bool)
    is_last[:-1] = rows[1:] != rows[:-1]
    is_last[-1] = True

    final_rewards = np_rewards[rows[is_last], columns[is_last]]
    return np.mean(final_rewards)
    

# Notebook-wide session: read by initialize_sess(), rebound by reset_sess(),
# and passed explicitly to the project helpers in the cells below.
sess = tf.InteractiveSession()

In [3]:
# Problem dimensions shared by the target data generator, the generator and
# the discriminator.
vocab_size = 5000
embedding_dim = 32
hidden_dim = 32
seq_len = 20
image_feature_dim = 32
data_set_size = 20  # number of generate_data() calls concatenated into the target set
batch_size = 512

# Seed NumPy and TensorFlow before the first random draw and before any graph
# op is created, so that Restart & Run All starts from the same random state.
# NOTE(review): whether the project classes sample through seeded NumPy / TF
# ops only is not visible from this notebook -- confirm.
random_seed = 0
np.random.seed(random_seed)
tf.set_random_seed(random_seed)

# Random embedding table; also used later to initialise the generator embedding.
embedding_arr = np.random.normal(size=[vocab_size, embedding_dim])

tdg = TargetDataGenerator(vocab_size, batch_size, embedding_dim, hidden_dim, seq_len)
initialize_sess()

# Draw data_set_size chunks from the target data generator and stack them
# along axis 0 into a single target data set.
raw_data = [tdg.generate_data(sess) for _ in range(data_set_size)]
td = TargetData(np.concatenate(raw_data, axis=0), embedding_arr, image_feature_dim, vocab_size)

In [4]:
# Evaluate the target data generator on the data it produced itself (td).
# The value shown as "GT Loss" in the cell output is presumably the baseline
# that the oracle losses of generated samples are compared against -- confirm.
gt_loss = tdg.evaluate(sess, td)

GT Loss: 5.94868


In [5]:
# Generator hyper-parameters, one keyword per line.
gen_spec = GeneratorSpec(
    # Network shape.
    input_dim=None,
    hidden_dim=hidden_dim * 2,
    output_dim=vocab_size,
    rnn_activation=None,
    image_feature_dim=image_feature_dim,
    # One step fewer than seq_len; the sampling cells prepend a start token in
    # column 0, which presumably accounts for the difference -- confirm.
    n_seq_steps=seq_len - 1,
    # Embedding variable initialised from the random embedding table.
    embedding_init=tf.Variable(embedding_arr, name="embedding", dtype=tf.float32),
    # Baseline network.
    n_baseline_layers=1,
    baseline_hidden_dim=32,
    # Optimisation settings.
    mle_learning_rate=1e-2,
    pg_learning_rate=5e-4,
    baseline_learning_rate=5e-3,
    batch_size=batch_size,
    epsilon=0.2,
)

### Pretrain the Generator

In [6]:
# Build the generator from the spec. The second argument is None for
# pre-training; the joint-training section later assigns
# gen._discriminator_reward, which is presumably what this slot is for --
# confirm against GeneratorWrapper.
gen = GeneratorWrapper(gen_spec, None)
# Initialize whatever variables are still uninitialized after construction.
initialize_sess()

In [None]:
# Histories filled by the training cells below; kept in their own cell so that
# re-running a training cell keeps appending instead of starting over.
oracle_loss = []
mle_cross_entropies = []
accuracies = []
discriminator_init_data = []

In [None]:
# Alternate MLE training of the generator with an evaluation of its samples
# by the target data generator.
for epoch in range(20):
    # 400 MLE iterations on the target data; keep the returned metrics per epoch.
    cross_entropy, accuracy = gen.train(sess, td, num_iterations=400, training_type='MLE')
    mle_cross_entropies.append(cross_entropy)
    accuracies.append(accuracy)

    # Sample from the generator; only the first returned value is used.
    # NOTE(review): the meaning of the third argument (20) is not visible
    # here. Binding `_`, `__`, `___` also overrides IPython's output-history
    # variables for the rest of the kernel session.
    dat, _, __, ___ = gen.test(sess, td, 20)
    # Put the start token in column 0 and the samples in the remaining
    # columns (dat presumably has seq_len - 1 columns -- confirm).
    full_dat = np.ones((dat.shape[0], seq_len)) * tdg.start_token[0]
    full_dat[:, 1:] = dat
    # Wrap the samples the same way as the target data and score them.
    gd = TargetData(full_dat, embedding_arr, image_feature_dim, vocab_size)
    gd.set_mode('MLE').set_batch_size(batch_size)
    oracle_loss.append(tdg.evaluate(sess, gd))
    # Keep the samples; the discriminator pre-training cell concatenates them
    # into its generated ("bad") data set.
    discriminator_init_data.append(full_dat)

GT Loss: 10.35077
GT Loss: 10.14843
GT Loss: 10.08874
GT Loss: 10.06451
GT Loss: 10.04653
GT Loss: 10.00387
GT Loss: 9.97735


### Pretrain the Discriminator

In [None]:
# Generated examples: every batch of generator samples collected during MLE
# pre-training, stacked along axis 0.
bad_data = TargetData(np.concatenate(discriminator_init_data, axis=0), embedding_arr, image_feature_dim, vocab_size)

# NOTE(review): td is passed as both the first and the third argument. If the
# third one is a validation set, validation losses are computed on training
# data -- confirm against DiscriminatorWrapper.
disc = DiscriminatorWrapper(td, bad_data, td, hidden_dim * 2)
# Initialize whatever variables are still uninitialized (the rest are skipped).
initialize_sess()

# 20 rounds of 100 pre-training iterations each; train_loss / val_loss are
# overwritten every round, so only the last round's values remain afterwards.
for i in range(20):
    train_loss, val_loss = disc.pre_train(sess, iter_num=100, batch_size=batch_size)

### Joint Training

In [None]:
# Attach the discriminator's reward function to the generator, which was
# constructed with None as its second argument.
# NOTE(review): this writes a private attribute directly; a constructor
# argument or public setter would be safer if GeneratorWrapper offers one.
gen._discriminator_reward = disc.assign_reward

In [None]:
# Joint training: one PPO generator step per cycle, followed by discriminator
# updates on the new captions mixed with replayed older ones.
# NOTE(review): prev_sentences grows by one batch per cycle for 10000 cycles
# and is never trimmed -- consider bounding it.
prev_sentences = []
for cycle in range(10000):
    td.shuffle()
    caption_sentences, _, img_idxs, r = gen.train(sess, td, 1, training_type='PPO')
    prev_sentences.extend(caption_sentences)

    # From the second cycle on, add twice as many sentences drawn at random
    # (with replacement) from everything generated so far.
    if len(prev_sentences) > len(caption_sentences):
        idxs = np.random.choice(len(prev_sentences), 2 * len(caption_sentences))
        caption_sentences.extend([prev_sentences[i] for i in idxs])
        # Triple the index sequence so its length matches the enlarged batch.
        # NOTE(review): this relies on img_idxs being a list (repetition, not
        # element-wise scaling), and pairs replayed sentences with the current
        # cycle's indices rather than their original ones -- confirm intended.
        img_idxs = img_idxs * 3

    # Train the discriminator, repeating until its latest training loss is no
    # longer above 1.05. The negated comparison (rather than `<=`) keeps the
    # original exit behavior if the loss is NaN.
    # NOTE(review): there is no retry limit; a discriminator that never gets
    # below the threshold keeps this loop running forever.
    while True:
        train_losses, val_losses = disc.online_train(
            sess, 5, np.array(img_idxs), caption_sentences, batch_size=batch_size)
        if not train_losses[-1] > 1.05:
            break

    # Every data_set_size cycles (the constant doubles as evaluation
    # interval), sample from the generator and log the oracle loss.
    if cycle % data_set_size == 0:
        dat, _, __, ___ = gen.test(sess, td, 20)
        # Start token in column 0, generated samples in the remaining columns.
        full_dat = np.ones((dat.shape[0], seq_len)) * tdg.start_token[0]
        full_dat[:, 1:] = dat
        gd = TargetData(full_dat, embedding_arr, image_feature_dim, vocab_size)
        gd.set_mode('MLE').set_batch_size(batch_size)
        oracle_loss.append(tdg.evaluate(sess, gd))