In [1]:
import string
import tensorflow as tf
import numpy as np
import os

In [2]:
# TODO: check all special symbols used as strings
# TODO: beam search

In [3]:
special_symbols = {
  '<pad>':    0,
  '<format>': 1, 
  '<next>':   2, 
  '<start>':  3, 
  '<end>':    4, 
  '{':        5, 
  '}':        6,
  ' ':        7
}

sym2id = {sym: i + len(special_symbols) for i, sym in enumerate(string.ascii_letters)}
sym2id = {**special_symbols, **sym2id}
id2sym = {sym2id[sym]: sym for sym in sym2id}

vocab_size = len(string.ascii_letters) + len(special_symbols)
assert vocab_size == len(sym2id) and vocab_size == len(id2sym)

In [4]:
def encode(template, subs):
  input = ['<start>'] + list(template) + ['<format>']
  for sub in subs:
    input += list(sub) + ['<next>']
  input[-1] = '<end>'
  output = ['<start>'] + list(template.format(*subs)) + ['<end>']
  return np.array([sym2id[x] for x in input]), np.array([sym2id[x] for x in output])

words = ['hello', 'world', 'boy', 'real', 'talk', 'something', 'might', 'be', 'wrong', 'just', 'random', 'words']
words_2 = ['MY', 'NAME', 'IS', 'NOT', 'IMPORTANT', 'I', 'JUST', 'WORK', 'HERE']
def sample_dataset():
  n_words = np.random.randint(7)
  sampled = [np.random.choice(words) for _ in range(n_words)]
  text = ' {} '.join(sampled)
  subs = [np.random.choice(words_2) for _ in range(n_words - 1)]
  return text, subs

def next_batch(batch_size):
  inputs = []
  inputs_len = []
  outputs = []
  outputs_len = []
  for i in range(batch_size):
    input, output = encode(*sample_dataset())
    inputs.append(input)
    inputs_len.append(len(input))
    outputs.append(output)
    outputs_len.append(len(output))
    
  input_max_len = np.max([len(x) for x in inputs])
  output_max_len = np.max([len(x) for x in outputs])
  
  inputs = [np.pad(x, [[0, input_max_len - len(x)]], 'constant') for x in inputs]
  outputs = [np.pad(x, [[0, output_max_len - len(x)]], 'constant') for x in outputs]
    
  return np.array(inputs), np.array(outputs), np.array(inputs_len), np.array(outputs_len)

In [5]:
class Encoder(object):
  def __init__(self, num_units):
    self.encoder_fw = tf.nn.rnn_cell.GRUCell(num_units=num_units)
    self.encoder_bw = tf.nn.rnn_cell.GRUCell(num_units=num_units)
    
  def train(self, inputs, seq_len):
    (encoder_fw_outputs, encoder_bw_outputs), (encoder_fw_state, encoder_bw_state) = tf.nn.bidirectional_dynamic_rnn(
      self.encoder_fw, 
      self.encoder_bw, 
      inputs=inputs,
      sequence_length=seq_len,
      dtype=tf.float32)

    return encoder_fw_outputs, encoder_fw_state
  
class Decoder(object):
  def __init__(self, encoder_states, encoder_seq_len, num_units):
    decoder_fw = tf.nn.rnn_cell.GRUCell(num_units=num_units)
    
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
      num_units, 
      encoder_states,
      memory_sequence_length=encoder_seq_len)
    
    decoder_fw = tf.contrib.seq2seq.AttentionWrapper(
      decoder_fw, 
      attention_mechanism,
      attention_layer_size=num_units)
    
    self.decoder_fw = decoder_fw
    self.projection_layer = tf.layers.Dense(vocab_size, use_bias=False)
    
  def train(self, initial_state, inputs, seq_len):
    helper = tf.contrib.seq2seq.TrainingHelper(
      inputs, 
      seq_len)

    decoder = tf.contrib.seq2seq.BasicDecoder(
      self.decoder_fw, 
      helper, 
      self.decoder_fw.zero_state(batch_size, tf.float32).clone(cell_state=initial_state),
      output_layer=self.projection_layer)

    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
      decoder,
      impute_finished=True)

    logits = outputs.rnn_output
    translations = outputs.sample_id
    return logits, translations
  
  def infer(self, initial_state, max_iterations):
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
      lambda ids: tf.one_hot(ids, vocab_size),
      tf.fill([batch_size], sym2id['<start>']), sym2id['<end>'])
    
    decoder = tf.contrib.seq2seq.BasicDecoder(
      self.decoder_fw, 
      helper, 
      self.decoder_fw.zero_state(batch_size, tf.float32).clone(cell_state=initial_state),
      output_layer=self.projection_layer)
      
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
      decoder, 
      impute_finished=True,
      maximum_iterations=max_iterations)
    
    logits = outputs.rnn_output
    translations = outputs.sample_id
    return logits, translations

In [6]:
tf.reset_default_graph()

batch_size = 32
max_time = None
  
encoder_num_units = 256
encoder_inputs = tf.placeholder(tf.int32, [batch_size, max_time], name='encoder_inputs')
encoder_seq_len = tf.placeholder(tf.int32, [batch_size], name='encoder_seq_len')

encoder = Encoder(encoder_num_units)
encoder_outputs, encoder_state = encoder.train(
  tf.one_hot(encoder_inputs, vocab_size),
  encoder_seq_len
)

decoder_num_units = 256
decoder_inputs = tf.placeholder(tf.int32, [batch_size, max_time], name='decoder_inputs')
decoder_targets = tf.placeholder(tf.int32, [batch_size, max_time], name='decoder_targets')
decoder_seq_len = tf.placeholder(tf.int32, [batch_size], name='decoder_seq_len')

decoder = Decoder(encoder_outputs, encoder_seq_len, decoder_num_units)

decoder_logits, _ = decoder.train(
  encoder_state,
  tf.one_hot(decoder_inputs, vocab_size),
  decoder_seq_len
)

max_iterations = tf.round(tf.reduce_max(encoder_seq_len) * 2)
_, translations = decoder.infer(
  encoder_state,
  max_iterations
)

cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
  labels=decoder_targets, 
  logits=decoder_logits)

loss_mask = tf.sequence_mask(decoder_seq_len, dtype=tf.float32)
loss = tf.reduce_sum(cross_ent * loss_mask) / batch_size

predicted = tf.argmax(tf.reshape(decoder_logits, [-1, vocab_size]), axis=1, output_type=tf.int32)
actual = tf.reshape(decoder_targets, [-1])
accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, actual), tf.float32))

global_step = tf.Variable(0, name='global_step', trainable=False)
learning_rate = tf.placeholder(tf.float32, name='learning_rate')
train = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

init = tf.global_variables_initializer()
saver = tf.train.Saver()

In [7]:
steps = 3000
lr = 0.001
log_interval = 200
save_interval = 200
restore = True
log_path = os.path.join('tf_log', 'pyformat_rnn_attention')
model_name = os.path.join(log_path, 'model')

with tf.Session() as sess:
  if restore:
    chkpt_fname = tf.train.latest_checkpoint(log_path)
    saver.restore(sess, chkpt_fname)
  else:
    sess.run(init)
  
  for i in range(sess.run(global_step), steps + 1):
    inputs, outputs, inputs_len, outputs_len = next_batch(batch_size)
    feed_dict = {learning_rate: lr,
                 encoder_inputs: inputs,
                 encoder_seq_len: inputs_len - 1,
                 decoder_inputs: outputs[:, :-1],
                 decoder_seq_len: outputs_len - 1,
                 decoder_targets: outputs[:, 1:]}
    
    sess.run(train, feed_dict=feed_dict)
  
    if i % log_interval == 0:
      l, a = sess.run([loss, accuracy], feed_dict=feed_dict)
      print('step: {}, loss: {:.4f}, accuracy: {:.2f}%, learning rate: {}'.format(i, l, a * 100, lr))
      if l < 1:
        lr = 0.0001
      
    if i % save_interval == 0:
      save_path = saver.save(sess, model_name)
      print('model saved: {}'.format(save_path))
      
  dlv = sess.run(decoder_logits, feed_dict)
  e, d = sess.run([encoder_inputs, translations], feed_dict)

INFO:tensorflow:Restoring parameters from tf_log/pyformat_rnn_attention/model
step: 2600, loss: 0.4297, accuracy: 99.79%
step: 2800, loss: 0.1931, accuracy: 99.91%
step: 3000, loss: 0.0436, accuracy: 100.00%
model saved: tf_log/pyformat_rnn_attention/model


In [8]:
i = 0

print(dlv[i, :, 0])

er = e[i]
dr = d[i]
source = ''.join([id2sym[id] for id in er])
trans = ''.join([id2sym[id] for id in dr])
print(source)
print(trans)

[-2.98189044 -5.23765612 -2.71927404 -5.57502174 -4.33229923 -6.28060246
 -4.4224081  -3.3578763  -1.7469703  -5.72459555 -4.03153229 -2.79955482
 -3.47255945 -4.89219236 -5.48732996 -4.38164806 -1.31911492 -3.60341144
 -2.48678541 -1.28363502 -2.0377028  -3.11512589 -1.21032524 -0.55045027
 -3.18950129 -1.92290854 -3.51135135 -1.00990498 -1.64928806 -2.54523993
 -1.32728994 -2.58081651 -2.7898376  -2.61958647 -0.680484   -3.73724937
 -2.7693305  -3.96833801 -1.42586076 -2.03165746 -2.32882977 -3.4909904
 -2.79012346 -4.97259235 -3.88369823 -3.59264708 -3.01291203 -3.46519208
 -4.20907211 -2.59122896 -4.75237989 -3.58106995 -3.71496677 -1.98655272
 -5.69217539  0.          0.          0.          0.          0.          0.        ]
<start>something {} real {} boy {} world {} might {} wrong<format>MY<next>MY<next>HERE<next>I<next>HERE<end><pad><pad><pad><pad><pad><pad>
something MY real MY boy HERE world I might HERE wrong<end><pad><pad><pad><pad><pad><pad>
