In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=invalid-name, too-many-locals, too-many-arguments, no-member

import os
import importlib
import numpy as np
import tensorflow as tf
import texar as tx

from ctrl_gen_model import CtrlGenModel


config = importlib.import_module('config')

In [2]:
# Data: one MultiAlignedData per split, built in train/val/test order from
# the hyperparameters held in `config`.
train_data, val_data, test_data = (
    tx.data.MultiAlignedData(split_hparams)
    for split_hparams in (
        config.train_data, config.val_data, config.test_data))
vocab = train_data.vocab(0)

# The training data is registered under two names because every training
# batch is consumed twice: once by the discriminator update ('train_d') and
# once by the generator update ('train_g'). A feedable iterator lets the
# dataset be chosen per `sess.run` call through its handle.
datasets = {'train_g': train_data, 'train_d': train_data,
            'val': val_data, 'test': test_data}
iterator = tx.data.FeedableDataIterator(datasets)
batch = iterator.get_next()

# Model: `gamma` and `lambda_g` are scalar placeholders fed on every run.
gamma = tf.placeholder(tf.float32, shape=[], name='gamma')
lambda_g = tf.placeholder(tf.float32, shape=[], name='lambda_g')
model = CtrlGenModel(batch, vocab, gamma, lambda_g, config.model)


Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.contrib.distributions`.
Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.contrib.distributions`.
Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.contrib.distributions`.
Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.contrib.distributions`.


In [56]:
def _train_epoch(sess, gamma_, lambda_g_, epoch, verbose=True, display=10):
    """Runs one training epoch, alternating discriminator and generator steps.

    Every step first runs `model.fetches_train_d` on the 'train_d' dataset and
    then `model.fetches_train_g` on the 'train_g' dataset. The epoch ends when
    either run raises `tf.errors.OutOfRangeError`, at which point the current
    running averages are printed. The caller is expected to restart the
    'train_g' and 'train_d' datasets before calling.

    Args:
        sess: Session used to run the fetches.
        gamma_: Value fed to the `gamma` placeholder.
        lambda_g_: Value fed to the `lambda_g` placeholder.
        epoch: Epoch number; printed in the end-of-epoch summary and forwarded
            to `_eval_epoch`.
        verbose: If True, prints the running averages every `display` steps
            (and on step 1) and runs a validation pass every
            `config.display_eval` steps.
        display: Positive number of steps between progress printouts.
            Defaults to 10, the interval that used to be hard-coded.
    """
    avg_meters_d = tx.utils.AverageRecorder(size=10)
    avg_meters_g = tx.utils.AverageRecorder(size=10)

    def _feed(dataset_name):
        # Selects the dataset through the feedable handle and sets both
        # scalar placeholders.
        return {
            iterator.handle: iterator.get_handle(sess, dataset_name),
            gamma: gamma_,
            lambda_g: lambda_g_
        }

    step = 0
    while True:
        try:
            step += 1

            # Discriminator update.
            vals_d = sess.run(
                model.fetches_train_d, feed_dict=_feed('train_d'))
            avg_meters_d.add(vals_d)

            # Generator update.
            vals_g = sess.run(
                model.fetches_train_g, feed_dict=_feed('train_g'))
            avg_meters_g.add(vals_g)

            if verbose and (step == 1 or step % display == 0):
                print('step: {}, {}'.format(step, avg_meters_d.to_str(4)))
                print('step: {}, {}'.format(step, avg_meters_g.to_str(4)))

            if verbose and step % config.display_eval == 0:
                # The validation dataset must be rewound before each pass.
                iterator.restart_dataset(sess, 'val')
                _eval_epoch(sess, gamma_, lambda_g_, epoch)

        except tf.errors.OutOfRangeError:
            # Training data exhausted: report the averages and stop.
            print('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4)))
            print('epoch: {}, {}'.format(epoch, avg_meters_g.to_str(4)))
            break

def _eval_epoch(sess, gamma_, lambda_g_, epoch, val_or_test='val'):
    """Evaluates the model on one pass over the 'val' or 'test' dataset.

    For every batch this runs `model.fetches_eval` in EVAL mode, decodes the
    original and transferred token ids to strings, computes corpus BLEU of the
    transferred text against the original, accumulates all metrics weighted by
    batch size, and appends the (original, transferred) pairs to a sample file.
    The loop ends when the dataset raises `tf.errors.OutOfRangeError`. The
    caller is expected to restart the dataset before calling.

    Args:
        sess: Session used to run the fetches.
        gamma_: Value fed to the `gamma` placeholder.
        lambda_g_: Value fed to the `lambda_g` placeholder.
        epoch: Epoch number; used as the suffix of the sample file name.
        val_or_test: Name of the dataset to evaluate, 'val' or 'test'. Also
            used as the prefix of the sample file name.

    Returns:
        The result of `avg_meters.avg()` over all evaluated batches.
    """
    avg_meters = tx.utils.AverageRecorder()

    # One sample file per split and epoch. Previously the name was always
    # 'val.<epoch>', so test samples were appended to the validation file.
    sample_file = os.path.join(
        config.sample_path, '%s.%d' % (val_or_test, epoch))

    while True:
        try:
            feed_dict = {
                iterator.handle: iterator.get_handle(sess, val_or_test),
                gamma: gamma_,
                lambda_g: lambda_g_,
                tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
            }

            vals = sess.run(model.fetches_eval, feed_dict=feed_dict)

            # `batch_size` is a weight, not a metric, so take it out of `vals`.
            batch_size = vals.pop('batch_size')

            # Decodes the sampled ids to text and computes BLEU
            samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
            hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)
            print("samples: ",hyps)

            refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
            # Adds a singleton axis so each hypothesis has a list holding its
            # one reference.
            refs = np.expand_dims(refs, axis=1)
            print("reference: ",refs)

            bleu = tx.evals.corpus_bleu_moses(refs, hyps)
            vals['bleu'] = bleu

            avg_meters.add(vals, weight=batch_size)

            # Writes samples. Only the singleton reference axis is squeezed:
            # a bare `squeeze()` would also drop the batch axis when a batch
            # holds a single example, leaving a 0-d array.
            tx.utils.write_paired_text(
                refs.squeeze(axis=1), hyps,
                sample_file,
                append=True, mode='v')

        except tf.errors.OutOfRangeError:
            print('{}: {}'.format(
                val_or_test, avg_meters.to_str(precision=4)))
            break

    return avg_meters.avg()


In [57]:
# Rewind both training datasets and run one training epoch.
# NOTE(review): relies on `sess`, `gamma_`, `lambda_g_` and `epoch`, which are
# only assigned in cells further down this notebook, so this cell cannot run
# on a fresh kernel in top-to-bottom order.
iterator.restart_dataset(sess, ['train_g', 'train_d'])
_train_epoch(sess, gamma_, lambda_g_, epoch)

KeyboardInterrupt: 

In [4]:
# Create the sample and checkpoint output directories named in the config.
for out_dir in (config.sample_path, config.checkpoint_path):
    tf.gfile.MakeDirs(out_dir)

In [5]:
# Session shared by every other cell in this notebook; no cell shown here
# closes it.
sess = tf.InteractiveSession()

In [6]:
# Run the global-variable, local-variable and table initializers, in that
# order, each in its own `sess.run` call.
for make_init_op in (tf.global_variables_initializer,
                     tf.local_variables_initializer,
                     tf.tables_initializer):
    sess.run(make_init_op())

# max_to_keep=None: do not cap the number of checkpoints kept.
saver = tf.train.Saver(max_to_keep=None)

In [9]:
# Initialize the iterator's datasets (no dataset name is given, so presumably
# all of them -- TODO confirm against the Texar docs).
iterator.initialize_dataset(sess)

# Starting values fed to the `gamma` and `lambda_g` placeholders.
gamma_ = 1.
lambda_g_ = 0.

In [10]:
# Epoch number used by the ad-hoc cells: it names the sample file written by
# _eval_epoch and is passed as the step suffix to saver.save.
epoch = 1

In [13]:
# Show the values that will be fed to the `gamma` / `lambda_g` placeholders.
print('gamma: {}, lambda_g: {}'.format(gamma_, lambda_g_))

gamma: 1.0, lambda_g: 0.0


In [86]:
# Ad-hoc training run: up to 9 epochs (1..9) with the current `gamma_` and
# `lambda_g_` held fixed. Unlike the full loop later in this notebook there is
# no annealing, no end-of-epoch evaluation and no checkpointing here.
# NOTE(review): this rebinds the global `epoch`; the recorded output below
# ends in a KeyboardInterrupt.
for epoch in range(1,10):
    iterator.restart_dataset(sess, ['train_g', 'train_d'])
    _train_epoch(sess, gamma_, lambda_g_, epoch)

step: 1, loss_d: 0.0832 accu_d: 0.9688
step: 1, loss_g: 1.3574 loss_g_ae: 1.3574 loss_g_clas: 1.8153 accu_g: 0.4219 accu_g_gdy: 0.4531
step: 10, loss_d: 0.1106 accu_d: 0.9625
step: 10, loss_g: 1.2945 loss_g_ae: 1.2945 loss_g_clas: 2.1759 accu_g: 0.4016 accu_g_gdy: 0.3688
step: 20, loss_d: 0.0738 accu_d: 0.9734
step: 20, loss_g: 1.2504 loss_g_ae: 1.2504 loss_g_clas: 2.2360 accu_g: 0.4109 accu_g_gdy: 0.3922
step: 30, loss_d: 0.0834 accu_d: 0.9719
step: 30, loss_g: 1.2639 loss_g_ae: 1.2639 loss_g_clas: 2.4046 accu_g: 0.3812 accu_g_gdy: 0.3875
step: 40, loss_d: 0.1122 accu_d: 0.9594
step: 40, loss_g: 1.2858 loss_g_ae: 1.2858 loss_g_clas: 2.4674 accu_g: 0.3766 accu_g_gdy: 0.3812
step: 50, loss_d: 0.0953 accu_d: 0.9656
step: 50, loss_g: 1.1643 loss_g_ae: 1.1643 loss_g_clas: 2.3910 accu_g: 0.3812 accu_g_gdy: 0.3766
step: 60, loss_d: 0.0967 accu_d: 0.9625
step: 60, loss_g: 1.2551 loss_g_ae: 1.2551 loss_g_clas: 2.1699 accu_g: 0.4016 accu_g_gdy: 0.4078
step: 70, loss_d: 0.0729 accu_d: 0.9766
ste

step: 600, loss_d: 0.0544 accu_d: 0.9750
step: 600, loss_g: 0.7937 loss_g_ae: 0.7937 loss_g_clas: 3.4107 accu_g: 0.2922 accu_g_gdy: 0.2891
step: 610, loss_d: 0.0683 accu_d: 0.9844
step: 610, loss_g: 0.8060 loss_g_ae: 0.8060 loss_g_clas: 3.4462 accu_g: 0.2328 accu_g_gdy: 0.2641
step: 620, loss_d: 0.0941 accu_d: 0.9672
step: 620, loss_g: 0.8092 loss_g_ae: 0.8092 loss_g_clas: 3.4596 accu_g: 0.2594 accu_g_gdy: 0.2531
step: 630, loss_d: 0.0773 accu_d: 0.9703
step: 630, loss_g: 0.8320 loss_g_ae: 0.8320 loss_g_clas: 3.6549 accu_g: 0.2500 accu_g_gdy: 0.2609
step: 640, loss_d: 0.1109 accu_d: 0.9641
step: 640, loss_g: 0.7936 loss_g_ae: 0.7936 loss_g_clas: 3.7584 accu_g: 0.2203 accu_g_gdy: 0.2344
step: 650, loss_d: 0.0918 accu_d: 0.9578
step: 650, loss_g: 0.7552 loss_g_ae: 0.7552 loss_g_clas: 3.7028 accu_g: 0.2266 accu_g_gdy: 0.2281
step: 660, loss_d: 0.0782 accu_d: 0.9734
step: 660, loss_g: 0.8139 loss_g_ae: 0.8139 loss_g_clas: 3.5512 accu_g: 0.2609 accu_g_gdy: 0.2422
step: 670, loss_d: 0.1000 a

KeyboardInterrupt: 

In [101]:
# Evaluate a single validation batch by hand (an unrolled copy of one
# iteration of _eval_epoch) so the raw fetches can be inspected in the
# following cells.
iterator.restart_dataset(sess, 'val')
# Full-pass alternative: _eval_epoch(sess, gamma_, lambda_g_, epoch, 'val')
val_or_test='val'
# Fresh recorder; it is only updated by the later `avg_meters.add` cell.
avg_meters = tx.utils.AverageRecorder()
# Select the dataset through the feedable handle, set both placeholders and
# switch the global mode to EVAL.
feed_dict = {
                iterator.handle: iterator.get_handle(sess, val_or_test),
                gamma: gamma_,
                lambda_g: lambda_g_,
                tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
            }
vals = sess.run(model.fetches_eval, feed_dict=feed_dict)

In [102]:
# Inspect the raw fetches of the validation batch evaluated in the cell above.
vals

{'batch_size': 64,
 'loss_g': 0.5256332,
 'loss_g_ae': 0.5256332,
 'loss_g_clas': 4.253072,
 'loss_d': 0.052145958,
 'accu_d': 0.96875,
 'accu_g': 0.171875,
 'accu_g_gdy': 0.078125,
 'original': array([[ 300,   31,  173, ...,   17,    8,    2],
        [  25,  353,   70, ...,    0,    0,    0],
        [   7,   11,  201, ...,    0,    0,    0],
        ...,
        [1097,    2,    0, ...,    0,    0,    0],
        [  86,  226,   38, ...,    2,    0,    0],
        [   7,  119,  582, ...,    2,    0,    0]]),
 'transferred': array([[ 300,   31,  173, ...,   17,    8,    2],
        [  25,  353,   70, ...,    0,    0,    0],
        [   7,   11,  201, ...,    0,    0,    0],
        ...,
        [1097,    2,    0, ...,    0,    0,    0],
        [  86,  226,   38, ...,    2,    0,    0],
        [   7,  119, 1017, ...,    2,    0,    0]], dtype=int32)}

In [80]:
# Inspect the raw eval fetches. The output below was recorded at a lower
# execution count than the cell that currently assigns `vals`.
vals

{'batch_size': 64,
 'loss_g': 1.1538124,
 'loss_g_ae': 1.1538124,
 'loss_g_clas': 2.4324093,
 'loss_d': 0.09194705,
 'accu_d': 0.953125,
 'accu_g': 0.40625,
 'accu_g_gdy': 0.359375,
 'original': array([[ 365,  139,    6, ...,    0,    0,    0],
        [3339, 2903,    5, ...,    2,    0,    0],
        [   5,   17,    6, ...,    2,    0,    0],
        ...,
        [  67,  388,   30, ...,    0,    0,    0],
        [  15,  333,   10, ...,    0,    0,    0],
        [  35,  101,   63, ...,    0,    0,    0]]),
 'transferred': array([[185, 139,   6, ...,   0,   0,   0],
        [658, 253,   5, ...,   0,   0,   0],
        [  5,  17,   6, ...,   0,   0,   0],
        ...,
        [ 67, 306,  30, ...,   0,   0,   0],
        [ 15,  19,  10, ...,   0,   0,   0],
        [ 35, 101,  63, ...,   0,   0,   0]], dtype=int32)}

In [72]:
# Inspect the raw eval fetches. The output below was recorded at a lower
# execution count than the cell that currently assigns `vals`.
vals

{'batch_size': 64,
 'loss_g': 4.487425,
 'loss_g_ae': 4.487425,
 'loss_g_clas': 0.2423923,
 'loss_d': 0.1387409,
 'accu_d': 0.953125,
 'accu_g': 0.953125,
 'accu_g_gdy': 0.953125,
 'original': array([[   7,   65,   31, ...,    0,    0,    0],
        [1284,   55,   22, ...,    0,    0,    0],
        [3684,   10,  258, ...,    0,    0,    0],
        ...,
        [   7,   81,    5, ...,    0,    0,    0],
        [   7,  119,   52, ...,    0,    0,    0],
        [  21,  276,    8, ...,    0,    0,    0]]),
 'transferred': array([[ 5,  5,  5, ...,  0,  0,  0],
        [ 5,  5, 11, ...,  0,  0,  0],
        [ 5,  5, 11, ...,  0,  0,  0],
        ...,
        [ 5,  5,  5, ...,  0,  0,  0],
        [ 5, 10, 10, ...,  0,  0,  0],
        [20,  4,  4, ...,  0,  0,  0]], dtype=int32)}

In [66]:
# Inspect the raw eval fetches. The output below was recorded at a lower
# execution count than the cell that currently assigns `vals`.
vals

{'batch_size': 64,
 'loss_g': 4.594885,
 'loss_g_ae': 4.594885,
 'loss_g_clas': 0.36999792,
 'loss_d': 0.1608523,
 'accu_d': 0.9375,
 'accu_g': 0.921875,
 'accu_g_gdy': 0.640625,
 'original': array([[  20,   10,   45, ...,    0,    0,    0],
        [  12,   25,  871, ...,    0,    0,    0],
        [  23, 3893,  436, ...,  896,    4,    2],
        ...,
        [  56, 1387,  112, ...,    0,    0,    0],
        [1190,    6, 2502, ...,    0,    0,    0],
        [  15,   10,   12, ...,    0,    0,    0]]),
 'transferred': array([[ 7,  5,  5, ...,  0,  0,  0],
        [ 7,  5,  5, ...,  0,  0,  0],
        [ 5,  5,  5, ...,  4,  4,  2],
        ...,
        [ 5, 16, 16, ...,  0,  0,  0],
        [ 5,  5,  6, ...,  0,  0,  0],
        [ 7,  5,  5, ...,  0,  0,  0]], dtype=int32)}

In [50]:
# Inspect the raw eval fetches. The output below was recorded at a lower
# execution count than the cell that currently assigns `vals`.
vals

{'batch_size': 64,
 'loss_g': 4.973015,
 'loss_g_ae': 4.973015,
 'loss_g_clas': 0.70567334,
 'loss_d': 0.16262153,
 'accu_d': 0.953125,
 'accu_g': 0.609375,
 'accu_g_gdy': 0.6875,
 'original': array([[  7, 177,  37, ...,   0,   0,   0],
        [240,  27, 670, ...,   0,   0,   0],
        [  7,  65,  28, ...,   0,   0,   0],
        ...,
        [250, 125,   4, ...,   0,   0,   0],
        [258,  12,  55, ...,   0,   0,   0],
        [ 25, 184,   9, ...,   0,   0,   0]]),
 'transferred': array([[ 7,  5,  5, ...,  0,  0,  0],
        [ 7,  5,  5, ...,  0,  0,  0],
        [ 5,  5,  5, ...,  0,  0,  0],
        ...,
        [16,  8,  8, ...,  0,  0,  0],
        [ 5,  5,  5, ...,  0,  0,  0],
        [ 5,  5,  5, ...,  0,  0,  0]], dtype=int32)}

In [95]:
# Take the batch size out of `vals`; it is used later as the averaging weight.
# NOTE: `pop` mutates `vals`, so running this cell a second time without
# re-fetching `vals` raises KeyError.
batch_size = vals.pop('batch_size')

In [82]:
# Decodes the transferred token ids to strings (the BLEU hypotheses); BLEU
# itself is computed in a later cell.
# NOTE(review): `dict_pop` presumably removes the sample entries from `vals`,
# which would make this cell non-idempotent -- confirm.
samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)
print("samples: ",hyps)

samples:  ['off work and they are ever recommended !'
 'needless took the front desk is a tasted doctors ... ever too .'
 'the food and service at this location has always been on selection .'
 'service selection on providers and in line , and a flavorful staff .'
 'a sure of asking checked restaurants in rice town specials .'
 'this place is awesome .' 'the menu is pretty highlight .'
 'plenty of chips places around with better works , and nice food .'
 'they are sour helpful , always friendly , always interested .'
 'the person cream and service are over good too .'
 "thank 's booths with the burgh ." 'i have always why this sign bell .'
 'my : its way too consistently in perfect .'
 'great people there and they all taste very friendly .'
 'pretty good food !' 'no bueno !'
 'to get any room of service , you need to any at the bar area .'
 'the food was great , service friendly and spot on .'
 "love my dropped 's your doctors ." 'so so disappointed .'
 'i wish continue eat was an chai

In [83]:
# Decodes the original token ids to strings (the BLEU references).
refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
# Adds a singleton axis so every example carries a list of one reference.
refs = np.expand_dims(refs, axis=1)
print("reference: ",refs)

reference:  [['honest work and they are highly recommended !']
 ['josh @ the front desk is a real hoot ... fun guy .']
 ['the food and service at this location has always been top notch .']
 ['excellent selection on tap and in bottles , and a knowledgeable staff .']
 ['a bit of touristy southwest heaven in old town scottsdale .']
 ['this place is awesome .']
 ['the menu is pretty extensive .']
 ['tons of sandwich places around with better attitudes , and better food .']
 ['they are soooo helpful , always friendly , always cheerful .']
 ['the ice cream and service are both good too .']
 ["let 's begin with the positives ."]
 ['i have always used this discount tire .']
 ['2nd : its way too dark in night .']
 ['great people there and they all seemed very friendly .']
 ['pretty good food !']
 ['no dice !']
 ['to get any kind of service , you need to sit at the bar .']
 ['the food was great , service friendly and spot on .']
 ["love my mother 's day pie ."]
 ['so so disappointed .']
 ['i wi

In [103]:
# Decodes the transferred token ids to strings (the BLEU hypotheses); BLEU
# itself is computed in a later cell.
# NOTE(review): this cell duplicates an earlier one. `dict_pop` presumably
# removes the sample entries from `vals`, which would make it
# non-idempotent -- confirm.
samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)
print("samples: ",hyps)

samples:  ["ca n't wait to come back to watch some anyone and eat amazing food !"
 'very pleasant experience .'
 'i was extremely pleased with my experience .'
 'although the melting are a little expensive , it is nor worth it .'
 'will definitely be back !' 'no , no , no .' 'not fancy .' 'wow !'
 'i am so not overwhelming !'
 'its sweet with that brown answering practices .'
 "i 've eaten a lot of it in my day ." 'our bartender was very nice .'
 'so looking not worth it .'
 'no server with the food here , just the service .'
 'service was really slow .'
 'if i could give it a negative star i would in a safety !'
 'i have continue the owner several times with no return phone call .'
 'i thought and moist and found a gem for the kids to enjoy .'
 'so wo for the total means !'
 'place needs a favorite sale and techs in customer service .'
 'the casino and apologetic is so good .'
 'i wanted to like this place but it just seeing a big anyway .'
 'great family owned restaurant !'
 'i would

In [104]:
# Decodes the original token ids to strings (the BLEU references).
# NOTE(review): this cell duplicates an earlier one.
refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
# Adds a singleton axis so every example carries a list of one reference.
refs = np.expand_dims(refs, axis=1)
print("reference: ",refs)

reference:  [["ca n't wait to come back to watch some games and eat amazing food !"]
 ['very pleasant experience .']
 ['i was extremely pleased with my experience .']
 ['although the treatments are a little expensive , it is certainly worth it .']
 ['will definitely be back !']
 ['no , no , no .']
 ['not anymore .']
 ['wow !']
 ['i am so not lying !']
 ['its sweet with that brown sugar topping .']
 ["i 've eaten a lot of it in my day ."]
 ['our bartender was very nice .']
 ['so totally not worth it .']
 ['no problem with the food here , just the service .']
 ['service was really slow .']
 ['if i could give it a negative star i would in a heartbeat !']
 ['i have called the owner several times with no return phone call .']
 ['i live and chandler and found a gem for the kids to enjoy .']
 ['so excited for the year ahead !']
 ['place needs a deep cleaning and training in customer service .']
 ['the pepperoni and ricotta is so good .']
 ['i wanted to like this place but it just became a big

In [33]:
# Corpus BLEU of the transferred sentences against the original sentences,
# stored next to the other metrics.
bleu = tx.evals.corpus_bleu_moses(refs, hyps)
vals['bleu'] = bleu

# Fold this batch's metrics into the recorder, weighted by the batch size.
# NOTE(review): needs `batch_size` from the `vals.pop('batch_size')` cell and
# `refs` / `hyps` from the decoding cells.
avg_meters.add(vals, weight=batch_size)

{'loss_g': 15.633689880371094,
 'loss_g_ae': 15.633689880371094,
 'loss_g_clas': 0.6949148774147034,
 'loss_d': 0.6888538599014282,
 'accu_d': 0.5625,
 'accu_g': 0.4375,
 'accu_g_gdy': 0.4375,
 'bleu': 0.0}

In [None]:

        try:
            feed_dict = {
                iterator.handle: iterator.get_handle(sess, val_or_test),
                gamma: gamma_,
                lambda_g: lambda_g_,
                tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
            }

            vals = sess.run(model.fetches_eval, feed_dict=feed_dict)

            batch_size = vals.pop('batch_size')

            # Computes BLEU
            samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
            hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)
            print("samples: ",hyps)

            refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
            refs = np.expand_dims(refs, axis=1)
            print("reference: ",refs)

            bleu = tx.evals.corpus_bleu_moses(refs, hyps)
            vals['bleu'] = bleu

            avg_meters.add(vals, weight=batch_size)

            ###################################!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
            # Writes samples
            tx.utils.write_paired_text(
                refs.squeeze(), hyps,
                os.path.join(config.sample_path, 'val.%d'%epoch),
                append=True, mode='v')

        except tf.errors.OutOfRangeError:
            print('{}: {}'.format(
                val_or_test, avg_meters.to_str(precision=4)))
            break

    return avg_meters.avg()


In [None]:
# Optionally warm-start from the checkpoint named in the config.
restore_path = config.restore
if restore_path:
    print('Restore from: {}'.format(restore_path))
    saver.restore(sess, restore_path)

ckpt_prefix = os.path.join(config.checkpoint_path, 'ckpt')
last_epoch = config.max_nepochs

for epoch in range(1, last_epoch + 1):
    # Once pre-training is over, anneal the gumbel-softmax temperature
    # (floored at 0.001) and switch to the configured lambda_g.
    if epoch > config.pretrain_nepochs:
        gamma_ = max(0.001, gamma_ * config.gamma_decay)
        lambda_g_ = config.lambda_g
    print('gamma: {}, lambda_g: {}'.format(gamma_, lambda_g_))

    # Train on a fresh pass over both training datasets.
    iterator.restart_dataset(sess, ['train_g', 'train_d'])
    _train_epoch(sess, gamma_, lambda_g_, epoch)

    # Validate on a fresh pass over the validation dataset.
    iterator.restart_dataset(sess, 'val')
    _eval_epoch(sess, gamma_, lambda_g_, epoch, 'val')

    # Checkpoint with the epoch number as the step suffix.
    saver.save(sess, ckpt_prefix, epoch)

    # Test on a fresh pass over the test dataset.
    iterator.restart_dataset(sess, 'test')
    _eval_epoch(sess, gamma_, lambda_g_, epoch, 'test')

In [98]:
# Ad-hoc checkpoint with `epoch` as the step suffix.
# NOTE(review): hard-coded 'save' directory; the full training loop saves
# under config.checkpoint_path instead.
saver.save(sess, os.path.join('save', 'model'), epoch)

INFO:tensorflow:save/model-1 is not in all_model_checkpoint_paths. Manually adding it.


'save/model-1'

In [113]:
# Reload the weights from the fixed-name ad-hoc checkpoint.
# NOTE(review): the cell that writes this checkpoint appears after this one in
# the notebook but has the lower execution count, so it was run first.
saver.restore(sess, 'save2/model.ckpt')

INFO:tensorflow:Restoring parameters from save2/model.ckpt


In [112]:
# Ad-hoc checkpoint to a fixed, hard-coded path with no step suffix.
saver.save(sess,'save2/model.ckpt')

INFO:tensorflow:save2/model.ckpt is not in all_model_checkpoint_paths. Manually adding it.


'save2/model.ckpt'

In [None]:
def _main(_):
    """Builds the data pipeline and model, then trains, validates and tests.

    For every epoch from 1 to `config.max_nepochs` this trains on the training
    data, evaluates on the validation data, saves a checkpoint under
    `config.checkpoint_path`, and evaluates on the test data. Once
    `config.pretrain_nepochs` epochs have passed, `gamma_` is annealed by
    `config.gamma_decay` each epoch (floored at 0.001) and `lambda_g_` is set
    to `config.lambda_g`.

    Args:
        _: Unused; accepted because `tf.app.run` calls `main` with one
            positional argument.
    """
    # Data
    train_data = tx.data.MultiAlignedData(config.train_data)
    val_data = tx.data.MultiAlignedData(config.val_data)
    test_data = tx.data.MultiAlignedData(config.test_data)
    vocab = train_data.vocab(0)

    # Each training batch is used twice: once for updating the generator and
    # once for updating the discriminator. Feedable data iterator is used for
    # such case.
    iterator = tx.data.FeedableDataIterator(
        {'train_g': train_data, 'train_d': train_data,
         'val': val_data, 'test': test_data})
    batch = iterator.get_next()

    # Model
    gamma = tf.placeholder(dtype=tf.float32, shape=[], name='gamma')
    lambda_g = tf.placeholder(dtype=tf.float32, shape=[], name='lambda_g')
    model = CtrlGenModel(batch, vocab, gamma, lambda_g, config.model)

    def _train_epoch(sess, gamma_, lambda_g_, epoch, verbose=True):
        """Runs one training epoch, alternating discriminator and generator
        steps, until the training data raises `tf.errors.OutOfRangeError`.

        When `verbose` is True, running averages are printed every
        `config.display` steps and a validation pass is run every
        `config.display_eval` steps.
        """
        avg_meters_d = tx.utils.AverageRecorder(size=10)
        avg_meters_g = tx.utils.AverageRecorder(size=10)

        step = 0
        while True:
            try:
                step += 1
                # Discriminator update.
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_d'),
                    gamma: gamma_,
                    lambda_g: lambda_g_
                }

                vals_d = sess.run(model.fetches_train_d, feed_dict=feed_dict)
                avg_meters_d.add(vals_d)

                # Generator update.
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, 'train_g'),
                    gamma: gamma_,
                    lambda_g: lambda_g_
                }
                vals_g = sess.run(model.fetches_train_g, feed_dict=feed_dict)
                avg_meters_g.add(vals_g)

                if verbose and (step == 1 or step % config.display == 0):
                    print('step: {}, {}'.format(step, avg_meters_d.to_str(4)))
                    print('step: {}, {}'.format(step, avg_meters_g.to_str(4)))

                if verbose and step % config.display_eval == 0:
                    # The validation dataset must be rewound before each pass.
                    iterator.restart_dataset(sess, 'val')
                    _eval_epoch(sess, gamma_, lambda_g_, epoch)

            except tf.errors.OutOfRangeError:
                # Training data exhausted: report the averages and stop.
                print('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4)))
                print('epoch: {}, {}'.format(epoch, avg_meters_g.to_str(4)))
                break

    def _eval_epoch(sess, gamma_, lambda_g_, epoch, val_or_test='val'):
        """Evaluates one pass over the 'val' or 'test' dataset.

        Runs `model.fetches_eval` in EVAL mode per batch, computes corpus BLEU
        of the transferred text against the original, accumulates metrics
        weighted by batch size, and appends the text pairs to the sample file
        '<val_or_test>.<epoch>' under `config.sample_path`.

        Returns:
            The result of `avg_meters.avg()` over all evaluated batches.
        """
        avg_meters = tx.utils.AverageRecorder()

        # One sample file per split and epoch. Previously the name was always
        # 'val.<epoch>', so test samples were appended to the validation file.
        sample_file = os.path.join(
            config.sample_path, '%s.%d' % (val_or_test, epoch))

        while True:
            try:
                feed_dict = {
                    iterator.handle: iterator.get_handle(sess, val_or_test),
                    gamma: gamma_,
                    lambda_g: lambda_g_,
                    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
                }

                vals = sess.run(model.fetches_eval, feed_dict=feed_dict)

                # `batch_size` is a weight, not a metric.
                batch_size = vals.pop('batch_size')

                # Decodes the sampled ids to text and computes BLEU
                samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
                hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)
                print("samples: ",hyps)

                refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
                # Adds a singleton axis so each hypothesis has a list holding
                # its one reference.
                refs = np.expand_dims(refs, axis=1)
                print("reference: ",refs)

                bleu = tx.evals.corpus_bleu_moses(refs, hyps)
                vals['bleu'] = bleu

                avg_meters.add(vals, weight=batch_size)

                # Writes samples. Only the singleton reference axis is
                # squeezed: a bare `squeeze()` would also drop the batch axis
                # when a batch holds a single example, leaving a 0-d array.
                tx.utils.write_paired_text(
                    refs.squeeze(axis=1), hyps,
                    sample_file,
                    append=True, mode='v')

            except tf.errors.OutOfRangeError:
                print('{}: {}'.format(
                    val_or_test, avg_meters.to_str(precision=4)))
                break

        return avg_meters.avg()

    tf.gfile.MakeDirs(config.sample_path)
    tf.gfile.MakeDirs(config.checkpoint_path)

    # Runs the logics
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        saver = tf.train.Saver(max_to_keep=None)
        # Restoring after the initializers lets the checkpoint values
        # overwrite the freshly initialized ones.
        if config.restore:
            print('Restore from: {}'.format(config.restore))
            saver.restore(sess, config.restore)

        iterator.initialize_dataset(sess)

        gamma_ = 1.
        lambda_g_ = 0.
        for epoch in range(1, config.max_nepochs+1):
            if epoch > config.pretrain_nepochs:
                # Anneals the gumbel-softmax temperature
                gamma_ = max(0.001, gamma_ * config.gamma_decay)
                lambda_g_ = config.lambda_g
            print('gamma: {}, lambda_g: {}'.format(gamma_, lambda_g_))

            # Train
            iterator.restart_dataset(sess, ['train_g', 'train_d'])
            _train_epoch(sess, gamma_, lambda_g_, epoch)

            # Val
            iterator.restart_dataset(sess, 'val')
            _eval_epoch(sess, gamma_, lambda_g_, epoch, 'val')

            saver.save(
                sess, os.path.join(config.checkpoint_path, 'ckpt'), epoch)

            # Test
            iterator.restart_dataset(sess, 'test')
            _eval_epoch(sess, gamma_, lambda_g_, epoch, 'test')

# Script entry point: hands control to `_main` through `tf.app.run`.
# NOTE(review): inside a notebook kernel `__name__` is also '__main__', so
# executing this cell starts the full `_main` training run -- confirm that is
# intended before running it here.
if __name__ == '__main__':
    tf.app.run(main=_main)