In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=invalid-name, too-many-locals, too-many-arguments, no-member

import os
import importlib
import numpy as np
import tensorflow as tf
import texar as tx

from RL_model import RLModel

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from itertools import chain
from torchtext import data

In [2]:
# Load the experiment settings from the local `RLconfig` module. Later cells
# read `config.train_data`, `config.val_data`, `config.test_data`,
# `config.model` and `config.restore` from it.
config = importlib.import_module('RLconfig')

In [3]:
from texar.modules import WordEmbedder, UnidirectionalRNNEncoder, \
        MLPTransformConnector, AttentionRNNDecoder, \
        GumbelSoftmaxEmbeddingHelper, Conv1DClassifier
from texar.core import get_train_op
from texar.utils import collect_trainable_variables, get_batch_size

In [4]:
# Data: one MultiAlignedData object per split, each built from the matching
# entry of the config module.
train_data = tx.data.MultiAlignedData(config.train_data)
val_data = tx.data.MultiAlignedData(config.val_data)
test_data = tx.data.MultiAlignedData(config.test_data)
# vocab(0): presumably the vocabulary of the first aligned field -- TODO
# confirm against the dataset definition in RLconfig.
vocab = train_data.vocab(0)

# A feedable iterator yields a single `batch` whose source split is selected
# at run time by feeding `iterator.handle` (see the train/eval functions).
# Only the generator training split ('train_g') is registered here, alongside
# 'val' and 'test'.
# NOTE(review): an earlier note said each training batch is also used for a
# discriminator update; no such split is registered in this notebook -- confirm.
iterator = tx.data.FeedableDataIterator({
    'train_g': train_data,
    'val': val_data,
    'test': test_data,
})
batch = iterator.get_next()

In [5]:
# Model
# Build the RLModel on top of the iterator's `batch`, the training vocabulary
# and the `config.model` settings. The cells below run `model.train_g` for
# training steps and `model.samples` for evaluation.
model = RLModel(batch, vocab, config.model)

In [6]:
def _train_epoch(sess, epoch, verbose=True):
    """Run generator training steps until the 'train_g' split is exhausted.

    Relies on the notebook globals `iterator` and `model`.

    Args:
        sess: TensorFlow session used to run `model.train_g`.
        epoch: Epoch label; only used in the end-of-epoch summary line.
        verbose: When truthy, print the running averages at step 1 and at
            every 5th step.

    Returns:
        None. Progress is reported through `print`.
    """
    # size=10 is forwarded to AverageRecorder; presumably a moving-average
    # window over the last 10 records -- TODO confirm against texar docs.
    recorder = tx.utils.AverageRecorder(size=10)

    step = 0
    while True:
        step += 1
        try:
            handle_feed = {
                iterator.handle: iterator.get_handle(sess, 'train_g'),
            }
            fetched = sess.run(model.train_g, feed_dict=handle_feed)
        except tf.errors.OutOfRangeError:
            # The dataset ran out of batches: report the epoch and stop.
            print('epoch: {}, {}'.format(epoch, recorder.to_str(4)))
            break

        recorder.add(fetched)

        should_log = verbose and (step == 1 or step % 5 == 0)
        if should_log:
            print('step: {}, {}'.format(step, recorder.to_str(4)))

        # Periodic in-epoch validation (restart 'val' + _eval_epoch) is
        # intentionally not performed here.

In [7]:
 def _eval_epoch(sess, epoch, val_or_test='val'):
     """Evaluate `model.samples` over one pass of the chosen split.

     Relies on the notebook globals `iterator`, `model` and `vocab`. For each
     batch the transferred and original token ids are mapped to strings,
     printed, and scored with `tx.evals.corpus_bleu_moses`; the BLEU value is
     recorded weighted by the batch size.

     Args:
         sess: TensorFlow session used to run `model.samples`.
         epoch: Epoch label; not used by the executed code.
         val_or_test: Name of the registered dataset to read from
             ('val' by default).

     Returns:
         The result of `AverageRecorder.avg()` over all recorded batches.
     """
     recorder = tx.utils.AverageRecorder()

     while True:
         try:
             eval_feed = {
                 iterator.handle: iterator.get_handle(sess, val_or_test),
                 tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
             }
             fetched = sess.run(model.samples, feed_dict=eval_feed)
         except tf.errors.OutOfRangeError:
             # The split is exhausted: print the aggregate and stop.
             print('{}: {}'.format(
                 val_or_test, recorder.to_str(precision=4)))
             break

         n_examples = fetched.pop('batch_size')

         # Move every sample entry out of `fetched`, leaving it free to hold
         # only the metrics that should be averaged.
         sample_keys = list(model.samples.keys())
         popped = tx.utils.dict_pop(fetched, sample_keys)

         hyps = tx.utils.map_ids_to_strs(popped['transferred'], vocab)
         print("samples: ",hyps)

         # Add an axis at position 1 before scoring (presumably one reference
         # per hypothesis -- TODO confirm against corpus_bleu_moses docs).
         refs = np.expand_dims(
             tx.utils.map_ids_to_strs(popped['original'], vocab), axis=1)
         print("reference: ",refs)

         fetched['bleu'] = tx.evals.corpus_bleu_moses(refs, hyps)
         recorder.add(fetched, weight=n_examples)

         # Writing paired samples to disk (tx.utils.write_paired_text) is
         # intentionally not performed here.

     return recorder.avg()

In [8]:
# Create the TF session and run the global-variable, local-variable and
# lookup-table initializers before any training or evaluation step.
# The session is kept open for the rest of the notebook.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())

# max_to_keep=None: per the TF 1.x Saver docs, no old checkpoints are deleted.
saver = tf.train.Saver(max_to_keep=None)

In [9]:
# Optionally resume: when `config.restore` is truthy it is used as the
# checkpoint path to load into the session.
if config.restore:
    print('Restore from: {}'.format(config.restore))
    saver.restore(sess, config.restore)

# Initialize the datasets registered with the feedable iterator.
iterator.initialize_dataset(sess)

In [10]:
# Epoch label passed to `_train_epoch` by the manual training cells below.
epoch = 1

In [11]:
# Ad-hoc check: fetch `model.samples` for one batch of the 'val' split in
# EVAL mode.
feed_dict = {
    iterator.handle: iterator.get_handle(sess, 'val'),
    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
}

vals = sess.run(model.samples, feed_dict=feed_dict)

# Remove 'batch_size' so `vals` holds only the remaining fetched entries.
batch_size = vals.pop('batch_size')

In [12]:
# Manual training pass: rewind the 'train_g' split, then train until it is
# exhausted (or the cell is interrupted).
iterator.restart_dataset(sess, ['train_g'])
_train_epoch(sess, epoch)

step: 1, loss_g_ae: 18.0444 train_op_g_ae: 18.0444
step: 5, loss_g_ae: 17.8378 train_op_g_ae: 17.8378
step: 10, loss_g_ae: 16.2451 train_op_g_ae: 16.2451
step: 15, loss_g_ae: 12.4887 train_op_g_ae: 12.4887
step: 20, loss_g_ae: 9.2054 train_op_g_ae: 9.2054
step: 25, loss_g_ae: 7.7582 train_op_g_ae: 7.7582
step: 30, loss_g_ae: 7.3081 train_op_g_ae: 7.3081
step: 35, loss_g_ae: 6.9869 train_op_g_ae: 6.9869
step: 40, loss_g_ae: 6.6773 train_op_g_ae: 6.6773
step: 45, loss_g_ae: 6.6282 train_op_g_ae: 6.6282
step: 50, loss_g_ae: 6.5619 train_op_g_ae: 6.5619
step: 55, loss_g_ae: 6.3774 train_op_g_ae: 6.3774
step: 60, loss_g_ae: 6.3412 train_op_g_ae: 6.3412
step: 65, loss_g_ae: 6.3216 train_op_g_ae: 6.3216
step: 70, loss_g_ae: 6.2219 train_op_g_ae: 6.2219
step: 75, loss_g_ae: 6.1349 train_op_g_ae: 6.1349
step: 80, loss_g_ae: 6.0558 train_op_g_ae: 6.0558
step: 85, loss_g_ae: 5.9959 train_op_g_ae: 5.9959
step: 90, loss_g_ae: 5.9284 train_op_g_ae: 5.9284
step: 95, loss_g_ae: 5.8516 train_op_g_ae: 5

KeyboardInterrupt: 

In [15]:
# Another manual training pass over 'train_g'. `epoch` is not changed by this
# cell, so the end-of-epoch summary reports the same epoch label as before.
iterator.restart_dataset(sess, ['train_g'])
_train_epoch(sess, epoch)

step: 1, loss_g_ae: 4.8476 train_op_g_ae: 4.8476
step: 5, loss_g_ae: 4.7750 train_op_g_ae: 4.7750
step: 10, loss_g_ae: 4.7905 train_op_g_ae: 4.7905
step: 15, loss_g_ae: 4.7529 train_op_g_ae: 4.7529
step: 20, loss_g_ae: 4.6318 train_op_g_ae: 4.6318
step: 25, loss_g_ae: 4.5595 train_op_g_ae: 4.5595
step: 30, loss_g_ae: 4.6126 train_op_g_ae: 4.6126
step: 35, loss_g_ae: 4.6772 train_op_g_ae: 4.6772
step: 40, loss_g_ae: 4.6215 train_op_g_ae: 4.6215
step: 45, loss_g_ae: 4.5977 train_op_g_ae: 4.5977
step: 50, loss_g_ae: 4.6583 train_op_g_ae: 4.6583
step: 55, loss_g_ae: 4.5719 train_op_g_ae: 4.5719
step: 60, loss_g_ae: 4.5181 train_op_g_ae: 4.5181
step: 65, loss_g_ae: 4.5695 train_op_g_ae: 4.5695
step: 70, loss_g_ae: 4.5535 train_op_g_ae: 4.5535
step: 75, loss_g_ae: 4.5394 train_op_g_ae: 4.5394
step: 80, loss_g_ae: 4.5399 train_op_g_ae: 4.5399
step: 85, loss_g_ae: 4.4974 train_op_g_ae: 4.4974
step: 90, loss_g_ae: 4.4795 train_op_g_ae: 4.4795
step: 95, loss_g_ae: 4.4959 train_op_g_ae: 4.4959
st

step: 810, loss_g_ae: 1.4366 train_op_g_ae: 1.4366
step: 815, loss_g_ae: 1.4341 train_op_g_ae: 1.4341
step: 820, loss_g_ae: 1.3809 train_op_g_ae: 1.3809
step: 825, loss_g_ae: 1.3295 train_op_g_ae: 1.3295
step: 830, loss_g_ae: 1.3073 train_op_g_ae: 1.3073
step: 835, loss_g_ae: 1.3237 train_op_g_ae: 1.3237
step: 840, loss_g_ae: 1.3285 train_op_g_ae: 1.3285
step: 845, loss_g_ae: 1.2960 train_op_g_ae: 1.2960
step: 850, loss_g_ae: 1.2772 train_op_g_ae: 1.2772
step: 855, loss_g_ae: 1.3133 train_op_g_ae: 1.3133
step: 860, loss_g_ae: 1.3926 train_op_g_ae: 1.3926
step: 865, loss_g_ae: 1.4339 train_op_g_ae: 1.4339
step: 870, loss_g_ae: 1.4136 train_op_g_ae: 1.4136
step: 875, loss_g_ae: 1.3642 train_op_g_ae: 1.3642
step: 880, loss_g_ae: 1.3208 train_op_g_ae: 1.3208
step: 885, loss_g_ae: 1.2337 train_op_g_ae: 1.2337
step: 890, loss_g_ae: 1.1677 train_op_g_ae: 1.1677
step: 895, loss_g_ae: 1.1575 train_op_g_ae: 1.1575
step: 900, loss_g_ae: 1.1727 train_op_g_ae: 1.1727
step: 905, loss_g_ae: 1.2481 tr

KeyboardInterrupt: 

In [18]:
# Train-only loop over epoch labels 1..9. No validation, test or checkpoint
# step is performed here. Rebinds the notebook-global `epoch`.
for epoch in range(1, 10):
    # Train
    iterator.restart_dataset(sess, ['train_g'])
    _train_epoch(sess, epoch)

step: 1, loss_g_ae: 1.3450 train_op_g_ae: 1.3450
step: 5, loss_g_ae: 1.2362 train_op_g_ae: 1.2362
step: 10, loss_g_ae: 1.1934 train_op_g_ae: 1.1934
step: 15, loss_g_ae: 1.0762 train_op_g_ae: 1.0762
step: 20, loss_g_ae: 1.0415 train_op_g_ae: 1.0415
step: 25, loss_g_ae: 1.0993 train_op_g_ae: 1.0993
step: 30, loss_g_ae: 1.1369 train_op_g_ae: 1.1369
step: 35, loss_g_ae: 1.0971 train_op_g_ae: 1.0971
step: 40, loss_g_ae: 1.0907 train_op_g_ae: 1.0907
step: 45, loss_g_ae: 1.1162 train_op_g_ae: 1.1162
step: 50, loss_g_ae: 1.0587 train_op_g_ae: 1.0587
step: 55, loss_g_ae: 1.0665 train_op_g_ae: 1.0665
step: 60, loss_g_ae: 1.0563 train_op_g_ae: 1.0563
step: 65, loss_g_ae: 1.0465 train_op_g_ae: 1.0465
step: 70, loss_g_ae: 1.0888 train_op_g_ae: 1.0888
step: 75, loss_g_ae: 1.0295 train_op_g_ae: 1.0295
step: 80, loss_g_ae: 0.9901 train_op_g_ae: 0.9901
step: 85, loss_g_ae: 1.0438 train_op_g_ae: 1.0438
step: 90, loss_g_ae: 1.0498 train_op_g_ae: 1.0498
step: 95, loss_g_ae: 1.0276 train_op_g_ae: 1.0276
st

KeyboardInterrupt: 

In [22]:
# Manually checkpoint the current session state to a fixed, relative path.
saver.save(sess,'RLsave/texar_data_model2.ckpt')

INFO:tensorflow:RLsave/texar_data_model2.ckpt is not in all_model_checkpoint_paths. Manually adding it.


'RLsave/texar_data_model2.ckpt'

In [21]:
# Ad-hoc inspection: fetch one 'val' batch in EVAL mode and print the decoded
# transferred sentences next to the originals.
feed_dict = {
    iterator.handle: iterator.get_handle(sess, 'val'),
    tx.context.global_mode(): tf.estimator.ModeKeys.EVAL
}

vals = sess.run(model.samples, feed_dict=feed_dict)

batch_size = vals.pop('batch_size')

# Decode token ids to strings (no BLEU score is computed in this cell).
samples = tx.utils.dict_pop(vals, list(model.samples.keys()))
hyps = tx.utils.map_ids_to_strs(samples['transferred'], vocab)
print("samples: ",hyps)

refs = tx.utils.map_ids_to_strs(samples['original'], vocab)
# Add an axis at position 1, matching what `_eval_epoch` does before scoring.
refs = np.expand_dims(refs, axis=1)
print("reference: ",refs)

samples:  ['service was terrible .' 'service was good , but food was very slow .'
 "it 's the outcome of it that makes me so supposed ."
 '1st seating of ladies and chili oil .'
 'i never thought anyone could make that one bad .'
 'then they are too big for the other cookies .' 'great service .'
 'the prices they charge are sealed for what you get .'
 'this was a fun and delicious dinner to spend a birthday .'
 "i visited t 's because i mean that they had a glaze menu ."
 'this place is a serious amount with improved service .'
 'when i went back with my wife , the cooking was completely car .'
 'top _num_ .' 'every side of walnut is just amazing .'
 'the food is delicious and served family style .'
 'free bike train , etc .' 'our dogs mean everything to us .' 'wow !'
 'it was a filthy friday .' 'the gyro fish with character was so good !'
 'anyway , i have been twice for a dinner and it has been notice .'
 'she credit back _num_ mins later and meatballs me his afternoon work looks gre

In [None]:
# Full schedule for epoch labels 1..9: train on 'train_g', evaluate on 'val',
# write a checkpoint, then evaluate on 'test'. Rebinds the global `epoch`.
for epoch in range(1, 10):
    # Train
    iterator.restart_dataset(sess, ['train_g'])
    _train_epoch(sess, epoch)

    # Val
    iterator.restart_dataset(sess, 'val')
    _eval_epoch(sess, epoch, 'val')

    #saver.save(sess, os.path.join(config.checkpoint_path, 'ckpt'), epoch)
    # The third positional argument is the Saver's global_step, so each epoch
    # gets its own numbered checkpoint under 'screensave/'.
    saver.save(sess, os.path.join('screensave', 'model'), epoch)

    # Test -- `_eval_epoch` takes exactly (sess, epoch, val_or_test).
    iterator.restart_dataset(sess, 'test')
    _eval_epoch(sess, epoch, 'test')