In [1]:
from __future__ import absolute_import, division, print_function

import math, os, random, sys, time
import cPickle, gzip
import progressbar

import numpy as np
from six.moves import xrange
import tensorflow as tf

from tensorflow.models.rnn.translate import data_utils
from tensorflow.models.rnn.translate import seq2seq_model

In [2]:
# --- Optimisation hyper-parameters ------------------------------------------
tf.app.flags.DEFINE_float(
    "learning_rate", 0.5, "Learning rate.")
tf.app.flags.DEFINE_float(
    "learning_rate_decay_factor", 0.99, "Learning rate decays by this much.")
tf.app.flags.DEFINE_float(
    "max_gradient_norm", 5.0, "Clip gradients to this norm.")
tf.app.flags.DEFINE_integer(
    "batch_size", 64, "Batch size to use during training.")

# --- Model shape --------------------------------------------------------------
tf.app.flags.DEFINE_integer(
    "size", 600, "Size of each model layer.")
tf.app.flags.DEFINE_integer(
    "num_layers", 3, "Number of layers in the model.")
tf.app.flags.DEFINE_integer(
    "reactant_vocab_size", 326, "Reactant vocabulary size.")
tf.app.flags.DEFINE_integer(
    "product_vocab_size", 197, "Product vocabulary size.")

# --- Data / checkpointing -----------------------------------------------------
tf.app.flags.DEFINE_string(
    "train_dir", "/tmp", "Training directory.")
tf.app.flags.DEFINE_integer(
    "max_train_data_size", 0,
    "Limit on the size of training data (0: no limit).")
tf.app.flags.DEFINE_integer(
    "steps_per_checkpoint", 200,
    "How many training steps to do per checkpoint.")

# --- Run modes ----------------------------------------------------------------
tf.app.flags.DEFINE_boolean(
    "decode", False, "Set to True for interactive decoding.")
tf.app.flags.DEFINE_boolean(
    "self_test", False, "Run a self-test if this is set to True.")

# Single handle through which the rest of the notebook reads the values above.
FLAGS = tf.app.flags.FLAGS

In [3]:
_buckets = [(54, 54), (70, 60), (90, 65), (150, 80)]

In [4]:
def _load_bucketed(path, buckets, progress):
    """Read pickled (reactants, products) pairs from `path` and bucket them.

    The file is a gzip stream of consecutive pickles; it is read until the
    unpickler raises EOFError. `data_utils.EOS_ID` is appended to every
    `products` list before bucketing. A pair goes into the first bucket whose
    (source_size, target_size) are both strictly greater than the lengths of
    `reactants` and `products`; a pair that fits no bucket is dropped.

    Args:
        path: path of the gzipped pickle file.
        buckets: list of (source_size, target_size) pairs.
        progress: progress bar updated with the running record count (dropped
            records are counted too) and finished once the file is exhausted.

    Returns:
        A list with one list per bucket, each holding [reactants, products]
        entries.
    """
    bucketed = [[] for _ in buckets]
    # NOTE(review): unpickling can execute arbitrary code -- only point this at
    # files you produced yourself or otherwise trust.
    with gzip.open(path, 'rb') as pickled:
        records_read = 0
        while True:
            # Keep the try narrow so only end-of-stream ends the loop.
            try:
                reactants, products = cPickle.load(pickled)
            except EOFError:
                break
            products.append(data_utils.EOS_ID)
            for bucket_id, (source_size, target_size) in enumerate(buckets):
                if len(reactants) < source_size and len(products) < target_size:
                    bucketed[bucket_id].append([reactants, products])
                    break
            records_read += 1
            progress.update(records_read)
    # Close the bar so its final redraw does not interleave with later prints.
    progress.finish()
    return bucketed


# max_value figures are hard-coded totals for the two data files.
bar_dev = progressbar.ProgressBar(max_value=10942)
bar_train = progressbar.ProgressBar(max_value=1083292)

dev_set = _load_bucketed('data/dev.pkl.gz', _buckets, bar_dev)
train_set = _load_bucketed('data/train.pkl.gz', _buckets, bar_train)

print("dev_set size:", [len(d) for d in dev_set])
print("train_set size:", [len(t) for t in train_set])

 99% (1082881 of 1083292) |################ | Elapsed Time: 0:02:30 ETA: 0:00:00

dev_set size: [3793, 2772, 2307, 2056]
train_set size: [373867, 278092, 228864, 201503]


100% (1083292 of 1083292) |################| Elapsed Time: 0:02:30 ETA:  0:00:00

In [5]:
def create_model(session, forward_only):
    """Build a Seq2SeqModel and either restore it or initialise it in `session`.

    The model is constructed from the values in FLAGS and the module-level
    `_buckets`. If FLAGS.train_dir holds a checkpoint state whose
    model_checkpoint_path exists, the parameters are restored from it;
    otherwise all variables are freshly initialised.

    Args:
        session: TensorFlow session used to restore or initialise variables.
        forward_only: passed through to Seq2SeqModel as `forward_only`.

    Returns:
        The constructed model.
    """
    model = seq2seq_model.Seq2SeqModel(
        FLAGS.reactant_vocab_size,
        FLAGS.product_vocab_size,
        _buckets,
        FLAGS.size,
        FLAGS.num_layers,
        FLAGS.max_gradient_norm,
        FLAGS.batch_size,
        FLAGS.learning_rate,
        FLAGS.learning_rate_decay_factor,
        forward_only=forward_only)

    checkpoint = tf.train.get_checkpoint_state(FLAGS.train_dir)
    restorable = checkpoint and tf.gfile.Exists(checkpoint.model_checkpoint_path)
    if not restorable:
        # Nothing usable on disk: start from scratch.
        print("Created model with fresh parameters.")
        session.run(tf.initialize_all_variables())
        return model

    print("Reading model parameters from %s" % checkpoint.model_checkpoint_path)
    model.saver.restore(session, checkpoint.model_checkpoint_path)
    return model

In [6]:
def train():
    """Train the seq2seq model on `train_set` until interrupted.

    Opens a TensorFlow session, obtains a model from create_model(), then loops
    with no exit condition. Each step picks a bucket at random, weighted by how
    many training examples the bucket holds, and runs one model.step() on a
    batch from that bucket.

    Every FLAGS.steps_per_checkpoint steps it:
      * prints the global step, learning rate, mean step time and perplexity
        of the mean loss accumulated since the previous report;
      * runs the learning-rate decay op if at least three losses are recorded
        and the current one exceeds the largest of the last three;
      * saves a checkpoint as "translate.ckpt" under FLAGS.train_dir;
      * runs one forward-only step per non-empty `dev_set` bucket and prints
        its perplexity.
    """
    with tf.Session() as sess:
        print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, False)

        # Number of training examples in each bucket, and their total.
        bucket_counts = [len(train_set[idx]) for idx in xrange(len(_buckets))]
        total_examples = float(sum(bucket_counts))

        # Cumulative fraction of examples up to and including each bucket; the
        # gap between consecutive entries is that bucket's share of the data.
        train_buckets_scale = []
        examples_so_far = 0
        for count in bucket_counts:
            examples_so_far += count
            train_buckets_scale.append(examples_so_far / total_examples)

        # Accumulators for the current reporting window.
        window_time = 0.0
        window_loss = 0.0
        current_step = 0
        previous_losses = []

        while True:
            # Draw a number in [0, 1) and take the first bucket whose
            # cumulative bound lies above it.
            draw = np.random.random_sample()
            bucket_id = min(idx
                            for idx, upper_bound in enumerate(train_buckets_scale)
                            if upper_bound > draw)

            # One training step; batch construction is included in the timing.
            started = time.time()
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                train_set, bucket_id)
            _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
                                         target_weights, bucket_id, False)
            window_time += (time.time() - started) / FLAGS.steps_per_checkpoint
            window_loss += step_loss / FLAGS.steps_per_checkpoint
            current_step += 1

            # Everything below happens only on checkpoint steps.
            if current_step % FLAGS.steps_per_checkpoint != 0:
                continue

            # Report statistics for the window that just ended.
            if window_loss < 300:
                perplexity = math.exp(window_loss)
            else:
                perplexity = float('inf')
            print ("global step %d learning rate %.4f step-time %.2f perplexity "
                   "%.2f" % (model.global_step.eval(), model.learning_rate.eval(),
                             window_time, perplexity))

            # Decay the learning rate when this window is worse than each of
            # the last three recorded ones.
            no_improvement = (len(previous_losses) > 2
                              and window_loss > max(previous_losses[-3:]))
            if no_improvement:
                sess.run(model.learning_rate_decay_op)
            previous_losses.append(window_loss)

            # Persist the model, then reset the window accumulators.
            checkpoint_path = os.path.join(FLAGS.train_dir, "translate.ckpt")
            model.saver.save(sess, checkpoint_path, global_step=model.global_step)
            window_time, window_loss = 0.0, 0.0

            # Forward-only evaluation, one step per development bucket.
            for eval_bucket in xrange(len(_buckets)):
                if len(dev_set[eval_bucket]) == 0:
                    print("  eval: empty bucket %d" % (eval_bucket))
                    continue
                encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                    dev_set, eval_bucket)
                _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
                                             target_weights, eval_bucket, True)
                if eval_loss < 300:
                    eval_ppx = math.exp(eval_loss)
                else:
                    eval_ppx = float('inf')
                print("  eval: bucket %d perplexity %.2f" % (eval_bucket, eval_ppx))

In [7]:
# Start training. The loop inside train() has no exit condition, so this call
# only ends when the kernel is interrupted.
train()

Creating 3 layers of 600 units.
Created model with fresh parameters.
global step 200 learning rate 0.5000 step-time 0.64 perplexity 24.82
  eval: bucket 0 perplexity 10.18
  eval: bucket 1 perplexity 9.98
  eval: bucket 2 perplexity 11.04
  eval: bucket 3 perplexity 11.12
global step 400 learning rate 0.5000 step-time 0.62 perplexity 9.45
  eval: bucket 0 perplexity 7.24
  eval: bucket 1 perplexity 7.30
  eval: bucket 2 perplexity 7.74
  eval: bucket 3 perplexity 7.67
global step 600 learning rate 0.5000 step-time 0.59 perplexity 6.76
  eval: bucket 0 perplexity 5.51
  eval: bucket 1 perplexity 6.06
  eval: bucket 2 perplexity 6.27
  eval: bucket 3 perplexity 6.69
global step 800 learning rate 0.5000 step-time 0.61 perplexity 5.10
  eval: bucket 0 perplexity 4.48
  eval: bucket 1 perplexity 4.51
  eval: bucket 2 perplexity 4.80
  eval: bucket 3 perplexity 4.81
global step 1000 learning rate 0.5000 step-time 0.59 perplexity 4.14
  eval: bucket 0 perplexity 3.45
  eval: bucket 1 perplexi

KeyboardInterrupt: 