In [1]:
import string
import tensorflow as tf
import numpy as np
import os
import datasets.python_format as dataset

In [2]:
# TODO: beam search

In [3]:
class Encoder(object):
  def __init__(self, num_units):
    self.encoder_fw = tf.nn.rnn_cell.GRUCell(num_units=num_units)
    self.encoder_bw = tf.nn.rnn_cell.GRUCell(num_units=num_units)
    
  def train(self, inputs, seq_len):
    (encoder_fw_outputs, encoder_bw_outputs), (encoder_fw_state, encoder_bw_state) = tf.nn.bidirectional_dynamic_rnn(
      self.encoder_fw, 
      self.encoder_bw, 
      inputs=inputs,
      sequence_length=seq_len,
      dtype=tf.float32)

    return encoder_fw_outputs, encoder_fw_state
  
class Decoder(object):
  def __init__(self, encoder_states, encoder_seq_len, num_units):
    decoder_fw = tf.nn.rnn_cell.GRUCell(num_units=num_units)
    
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
      num_units, 
      encoder_states,
      memory_sequence_length=encoder_seq_len)
    
    decoder_fw = tf.contrib.seq2seq.AttentionWrapper(
      decoder_fw, 
      attention_mechanism,
      attention_layer_size=num_units)
    
    self.decoder_fw = decoder_fw
    self.projection_layer = tf.layers.Dense(dataset.vocab_size, use_bias=False)
    
  def train(self, initial_state, inputs, seq_len):
    helper = tf.contrib.seq2seq.TrainingHelper(
      inputs, 
      seq_len)

    decoder = tf.contrib.seq2seq.BasicDecoder(
      self.decoder_fw, 
      helper, 
      self.decoder_fw.zero_state(batch_size, tf.float32).clone(cell_state=initial_state),
      output_layer=self.projection_layer)

    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
      decoder,
      impute_finished=True)

    logits = outputs.rnn_output
    translations = outputs.sample_id
    return logits, translations
  
  def infer(self, initial_state, max_iterations):
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
      lambda ids: tf.one_hot(ids, dataset.vocab_size),
      tf.fill([batch_size], dataset.sos), 
      dataset.eos,
    )
    
    decoder = tf.contrib.seq2seq.BasicDecoder(
      self.decoder_fw, 
      helper, 
      self.decoder_fw.zero_state(batch_size, tf.float32).clone(cell_state=initial_state),
      output_layer=self.projection_layer)
      
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
      decoder, 
      impute_finished=True,
      maximum_iterations=max_iterations)
    
    logits = outputs.rnn_output
    translations = outputs.sample_id
    return logits, translations

In [4]:
tf.reset_default_graph()

batch_size = 32
max_time = None

ds = tf.data.Dataset.from_generator(
  lambda: dataset.gen(min_len=3, max_len=7), 
  (tf.int32, tf.int32), 
  ([None], [None]),
)
ds = ds.map(lambda source, target: (
  (tf.concat([[dataset.sos], source, [dataset.eos]], 0), tf.size(source) + 2), 
  (tf.concat([[dataset.sos], target, [dataset.eos]], 0), tf.size(target) + 2),
))
ds = ds.padded_batch(
  batch_size, 
  padded_shapes=(([None], []), ([None], [])), 
  padding_values=((dataset.pad, 0), (dataset.pad, 0)))

iterator = ds.make_one_shot_iterator()
(source, source_seq_len), (target, target_seq_len) = iterator.get_next()

encoder_inputs = source
encoder_seq_len = source_seq_len
decoder_inputs = target[:, :-1]
decoder_targets = target[:, 1:]
decoder_seq_len = target_seq_len - 1

In [5]:
encoder_num_units = 256
encoder = Encoder(encoder_num_units)

encoder_outputs, encoder_state = encoder.train(
  tf.one_hot(encoder_inputs, dataset.vocab_size),
  encoder_seq_len
)

In [6]:
decoder_num_units = 256
decoder = Decoder(encoder_outputs, encoder_seq_len, decoder_num_units)

decoder_logits, _ = decoder.train(
  encoder_state,
  tf.one_hot(decoder_inputs, dataset.vocab_size),
  decoder_seq_len
)

max_iterations = tf.round(tf.reduce_max(encoder_seq_len) * 2)
_, translations = decoder.infer(
  encoder_state,
  max_iterations
)

In [7]:
cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
  labels=decoder_targets, 
  logits=decoder_logits)

loss_mask = tf.sequence_mask(decoder_seq_len, dtype=tf.float32)
loss = tf.reduce_sum(cross_ent * loss_mask) / batch_size

predicted = tf.argmax(tf.reshape(decoder_logits, [-1, dataset.vocab_size]), axis=1, output_type=tf.int32)
actual = tf.reshape(decoder_targets, [-1])
accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, actual), tf.float32))

global_step = tf.get_variable('global_step', initializer=0, trainable=False)
learning_rate = tf.placeholder(tf.float32, name='learning_rate')
train = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

init = tf.global_variables_initializer()
saver = tf.train.Saver()

In [8]:
steps = 2000
lr = 0.001
log_interval = 200
save_interval = 200
log_path = os.path.join('tf_log', 'pyformat_rnn_attention')
model_name = os.path.join(log_path, 'model')
checkpoint_filename = tf.train.latest_checkpoint(log_path)
restore = checkpoint_filename is not None

with tf.Session() as sess:
  if restore:
    saver.restore(sess, checkpoint_filename)
  else:
    sess.run(init)
  
  feed_dict = {learning_rate: lr}
  
  for i in range(sess.run(global_step), steps + 1):
    sess.run(train, feed_dict)
  
    if i % log_interval == 0:
      l, a = sess.run([loss, accuracy], feed_dict)
      print('step: {}, loss: {:.4f}, accuracy: {:.2f}%, learning rate: {}'.format(i, l, a * 100, lr))
      
    if i % save_interval == 0:
      save_path = saver.save(sess, model_name)
      print('model saved: {}'.format(save_path))
      
  dlv = sess.run(decoder_logits, feed_dict)
  sb, tb, pb = sess.run([encoder_inputs, target, translations], feed_dict)

INFO:tensorflow:Restoring parameters from tf_log/pyformat_rnn_attention/model
step: 1800, loss: 2.2008, accuracy: 99.43%, learning rate: 0.001
model saved: tf_log/pyformat_rnn_attention/model
step: 2000, loss: 0.5558, accuracy: 99.72%, learning rate: 0.001
model saved: tf_log/pyformat_rnn_attention/model


In [9]:
i = 0
s = sb[i]
t = tb[i]
p = pb[i]

s = dataset.decode(s)
t = dataset.decode(t)
p = dataset.decode(p)

s = s.split('<s>')[1].split('</s>')[0]
t = t.split('<s>')[1].split('</s>')[0]
p = p.split('</s>')[0]

print(s)
print(t)
print(p)

{} {} {}<f>voluptatibus<n>ipsum<n>Alias
voluptatibus ipsum Alias
voluptatibus ipsum Alias
