In [1]:
import tensorflow as tf
import data.reader as reader
import datetime
import os
import logging
import signal
import data.reader as reader

In [2]:
import sys  # used by the SIGTERM handler below; the imports cell does not import it

logging.basicConfig(format="[%(asctime)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
logger = logging.getLogger()
logger.setLevel('INFO')


def _exit_on_sigterm(signum, frame):
  """SIGTERM handler: raise SystemExit(0) so train() can save a checkpoint before exiting."""
  sys.exit(0)


signal.signal(signal.SIGTERM, _exit_on_sigterm)

# Hyperparameters, exposed as tf.app.flags.
# NOTE(review): re-running this cell in the same kernel likely fails because the
# flags are already defined — restart the kernel instead of re-executing.
FLAGS = tf.app.flags.FLAGS

# NOTE(review): data_dir is not read by any visible cell — confirm it is used by data.reader.
tf.app.flags.DEFINE_string('data_dir', 'data', """Directory of data.""")
tf.app.flags.DEFINE_integer('batch_size', 32, """Batch size.""")
tf.app.flags.DEFINE_float('learning_rate', 1.0, """Learning rate.""")
tf.app.flags.DEFINE_integer('lstm_size', 64, """LSTM hidden size.""")
tf.app.flags.DEFINE_integer('num_layers', 2, """Number of LSTM layers.""")
tf.app.flags.DEFINE_integer('num_steps', 100, """Sequence length.""")

In [3]:
def construct_graph(graph):
  """Build the LSTM language-model training graph inside `graph`.

  Hyperparameters are read from FLAGS (lstm_size, batch_size, num_steps,
  num_layers, learning_rate).

  NOTE(review): `vocab_size` and `vocabulary` are module-level globals that are
  not defined in any visible cell (cell 4 is missing from this notebook), so
  this only works after they have been set elsewhere — TODO define them
  explicitly, presumably from `reader`.

  Args:
    graph: the tf.Graph to populate.

  Returns:
    Tuple (train_op, loss, inputs, probs, learning_rate, vocabulary, saver, summary):
      train_op: op applying globally-norm-clipped SGD updates and incrementing global_step.
      loss: summed per-token cross-entropy divided by batch size.
      inputs: the input batch tensor returned by reader.get_batch.
      probs: softmax over the vocabulary, one row per (example, step).
      learning_rate: exponentially decayed learning-rate tensor.
      vocabulary: the global `vocabulary`, passed through unchanged.
      saver: tf.train.Saver over all variables.
      summary: merged summary op ('loss' and 'learning_rate').
  """
  size = FLAGS.lstm_size
  batch_size = FLAGS.batch_size
  num_steps = FLAGS.num_steps
  num_layers = FLAGS.num_layers

  with graph.as_default():
    # assumes reader yields integer id tensors of shape [batch_size, num_steps] — TODO confirm
    inputs, labels = reader.get_batch(batch_size, num_steps)

    # Build a distinct cell object per layer; `[cell] * num_layers` hands the
    # same object to every layer, which newer TF versions reject or weight-share.
    cell = tf.nn.rnn_cell.MultiRNNCell(
        [tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias=0.0, state_is_tuple=True)
         for _ in range(num_layers)],
        state_is_tuple=True)
    initial_state = cell.zero_state(batch_size, tf.float32)

    with tf.variable_scope("RNN"):
      embedding = tf.get_variable("embedding", shape=[vocab_size, size], dtype=tf.float32)
      embed_inputs = tf.nn.embedding_lookup(embedding, inputs)
      # dynamic_rnn is batch-major by default, so `outputs` is a single tensor
      # [batch_size, num_steps, size] given the input shape assumed above.
      outputs, _final_state = tf.nn.dynamic_rnn(cell, embed_inputs, initial_state=initial_state)

    with tf.name_scope('loss'):
      # One row per (example, step): row i * num_steps + t is example i, step t.
      output = tf.reshape(outputs, [-1, size])
      softmax_w = tf.get_variable("softmax_w", [size, vocab_size], dtype=tf.float32)
      softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=tf.float32)
      logits = tf.matmul(output, softmax_w) + softmax_b
      targets = tf.reshape(labels, [-1])
      # Every position has weight 1. To ignore padding instead ((PAD) gets
      # weight 0, everything else 1), use:
      #   weights = tf.cast(targets > 0, tf.float32)
      weights = tf.ones([batch_size * num_steps], dtype=tf.float32)
      sequence_loss = tf.nn.seq2seq.sequence_loss_by_example(
          [logits],
          [targets],
          [weights])
      # Cross-entropy summed over each sequence, averaged over the batch.
      # This is NOT perplexity; perplexity would be exp(loss / num_steps).
      loss = tf.reduce_sum(sequence_loss) / batch_size
      tf.scalar_summary('loss', loss)
      probs = tf.nn.softmax(logits)

    with tf.name_scope('optimizer'):
      global_step = tf.Variable(0, trainable=False)
      # Multiply the learning rate by 0.96 every 10000 steps (staircase decay).
      learning_rate = tf.train.exponential_decay(
          FLAGS.learning_rate, global_step, 10000, 0.96, staircase=True)
      sgd = tf.train.GradientDescentOptimizer(learning_rate)
      gradients, variables = zip(*sgd.compute_gradients(loss))
      # Clip by global norm to keep LSTM gradient steps bounded.
      gradients, _ = tf.clip_by_global_norm(gradients, 1.25)
      train_op = sgd.apply_gradients(
          zip(gradients, variables), global_step=global_step)
      tf.scalar_summary('learning_rate', learning_rate)

    saver = tf.train.Saver()
    summary = tf.merge_all_summaries()
  return train_op, loss, inputs, probs, learning_rate, vocabulary, saver, summary

In [5]:
def _save_checkpoint(saver, sess, checkpoint_dir):
  """Save `sess` variables to `<checkpoint_dir>/model.ckpt`, creating the directory if needed.

  Returns:
    The path returned by `saver.save`.
  """
  # Saver.save does not create missing parent directories.
  os.makedirs(checkpoint_dir, exist_ok=True)
  save_path = saver.save(sess, os.path.join(checkpoint_dir, "model.ckpt"))
  logger.info("Model saved in file: %s" % save_path)
  return save_path


def train(log_every=100, checkpoint_every=0):
  """Build the graph and run the training loop until the coordinator stops.

  A summary is added to `tensorboard/train` every step. A checkpoint is written
  to `checkpoints/<YYYYMMDD>/model.ckpt` when the run is interrupted
  (KeyboardInterrupt), asked to exit (SystemExit, e.g. from the SIGTERM
  handler) or runs out of input (tf.errors.OutOfRangeError).

  Args:
    log_every: if > 0, log loss plus a sample input / prediction and flush the
      summary writer every this many steps.
    checkpoint_every: if > 0, also save a checkpoint every this many steps;
      the default 0 leaves periodic checkpointing off.

  Raises:
    Any exception recorded on the coordinator is re-raised by `coord.join`.
  """
  graph = tf.Graph()
  # `_vocabulary` is unused until an id -> text decoding helper is available.
  train_op, loss, inputs, probs, learning_rate, _vocabulary, saver, summary = construct_graph(graph)

  train_writer = tf.train.SummaryWriter('tensorboard/train', graph)

  today = datetime.date.today().strftime("%Y%m%d")
  checkpoint_dir = "checkpoints/{}".format(today)

  with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    step = 0
    try:
      while not coord.should_stop():
        _, train_loss, train_inputs, train_probs, train_lr, train_summary = sess.run(
            [train_op, loss, inputs, probs, learning_rate, summary])
        train_writer.add_summary(train_summary, step)

        if log_every > 0 and step % log_every == 0:
          train_writer.flush()
          logger.info("=" * 80)
          logger.info("Loss at step {}: {}".format(step, train_loss))
          # `probs` is flattened row-major from [batch, steps, vocab], so the
          # first len(train_inputs[0]) rows belong to example 0.
          seq_len = len(train_inputs[0])
          # Use %-style logging args: `str + ndarray` raises TypeError.
          logger.info('  Input: %s', train_inputs[0])
          logger.info('Decoded: %s', train_probs[:seq_len].argmax(axis=1))

        if checkpoint_every > 0 and step > 0 and step % checkpoint_every == 0:
          logger.info("Learning rate: {}".format(train_lr))
          _save_checkpoint(saver, sess, checkpoint_dir)
        step += 1

    except KeyboardInterrupt:
      logger.warning('Interrupted')
      _save_checkpoint(saver, sess, checkpoint_dir)
      coord.request_stop()
    except SystemExit as e:
      logger.warning('Exited')
      _save_checkpoint(saver, sess, checkpoint_dir)
      # Recording SystemExit makes coord.join re-raise it, so the process still exits.
      coord.request_stop(e)
    except tf.errors.OutOfRangeError:
      # Running out of input (e.g. an epoch-limited queue) is a normal finish.
      logger.info('Input exhausted after {} steps'.format(step))
      _save_checkpoint(saver, sess, checkpoint_dir)
      coord.request_stop()
    except Exception as e:
      # logger.exception keeps the traceback that e.args alone loses.
      logger.exception('Exception: {}'.format(e.args))
      coord.request_stop(e)
    finally:
      coord.request_stop()
      train_writer.close()
      coord.join(threads)
      
# Run training: builds the graph and loops until the coordinator stops, the
# cell is interrupted, or the process receives SIGTERM.
train()

ValueError: setting an array element with a sequence.