In [1]:
import string
import tensorflow as tf
import numpy as np
import os

In [2]:
# TODO: check all special symbols used as strings
# TODO: beam search

In [3]:
special_symbols = {
  '<pad>':    0,
  '<format>': 1, 
  '<next>':   2, 
  '<start>':  3, 
  '<end>':    4, 
  '{':        5, 
  '}':        6,
  ' ':        7
}

vocab = string.ascii_letters + '0123456789'

sym2id = {sym: i + len(special_symbols) for i, sym in enumerate(vocab)}
sym2id = {**special_symbols, **sym2id}
id2sym = {sym2id[sym]: sym for sym in sym2id}

vocab_size = len(vocab) + len(special_symbols)
assert vocab_size == len(sym2id) and vocab_size == len(id2sym)

In [4]:
def encode(template, subs):
  input = ['<start>'] + list(template) + ['<format>']
  for sub in subs:
    input += list(sub) + ['<next>']
  input[-1] = '<end>'
  output = ['<start>'] + list(template.format(*subs)) + ['<end>']
  return np.array([sym2id[x] for x in input]), np.array([sym2id[x] for x in output])

words = ['hello', 'world', 'boy', 'real', 'talk', 'something', 'might', 'be', 'not', 'important', 'wrong', 'here', 'random', 'just', 'i', 'work', 'words', 'my', 'name', 'is']

def sample_num():
  return str(np.random.randint(1000)), []
  
def sample_word():
  style = np.random.choice([str.upper, str.lower, str.title])
  word = np.random.choice(words)
  return style(word), []

def sample_placeholder():
  sample = np.random.choice([sample_word, sample_num])
  sub, _ = sample()
  return '{}', [sub]

def sample_placeholder_left():
  sample = np.random.choice([sample_word, sample_num])
  text, _  = sample()
  sub, _ = sample()
  return '{}' + text, [sub]

def sample_placeholder_right():
  sample = np.random.choice([sample_word, sample_num])
  text, _  = sample()
  sub, _ = sample()
  return text + '{}', [sub]

def sample_chunk():
  style = np.random.choice([
    sample_word, 
    sample_word, 
    sample_word, 
    sample_num, 
    sample_placeholder, 
    sample_placeholder_left,
    sample_placeholder_right
  ])
  
  return style()
  
def sample_chunks():
  template = []
  subs = []
  
  for i in range(np.random.randint(10)):
    t, s = sample_chunk()
    template.append(t)
    subs += s
  
  return ' '.join(template), subs
  
def dataset_gen():
  while True:
    source, target = encode(*sample_chunks())
    yield source, target

In [5]:
class Encoder(object):
  def __init__(self, num_units):
    self.encoder_fw = tf.nn.rnn_cell.GRUCell(num_units=num_units)
    self.encoder_bw = tf.nn.rnn_cell.GRUCell(num_units=num_units)
    
  def train(self, inputs, seq_len):
    (encoder_fw_outputs, encoder_bw_outputs), (encoder_fw_state, encoder_bw_state) = tf.nn.bidirectional_dynamic_rnn(
      self.encoder_fw, 
      self.encoder_bw, 
      inputs=inputs,
      sequence_length=seq_len,
      dtype=tf.float32)

    return encoder_fw_outputs, encoder_fw_state
  
class Decoder(object):
  def __init__(self, encoder_states, encoder_seq_len, num_units):
    decoder_fw = tf.nn.rnn_cell.GRUCell(num_units=num_units)
    
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
      num_units, 
      encoder_states,
      memory_sequence_length=encoder_seq_len)
    
    decoder_fw = tf.contrib.seq2seq.AttentionWrapper(
      decoder_fw, 
      attention_mechanism,
      attention_layer_size=num_units)
    
    self.decoder_fw = decoder_fw
    self.projection_layer = tf.layers.Dense(vocab_size, use_bias=False)
    
  def train(self, initial_state, inputs, seq_len):
    helper = tf.contrib.seq2seq.TrainingHelper(
      inputs, 
      seq_len)

    decoder = tf.contrib.seq2seq.BasicDecoder(
      self.decoder_fw, 
      helper, 
      self.decoder_fw.zero_state(batch_size, tf.float32).clone(cell_state=initial_state),
      output_layer=self.projection_layer)

    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
      decoder,
      impute_finished=True)

    logits = outputs.rnn_output
    translations = outputs.sample_id
    return logits, translations
  
  def infer(self, initial_state, max_iterations):
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
      lambda ids: tf.one_hot(ids, vocab_size),
      tf.fill([batch_size], sym2id['<start>']), sym2id['<end>'])
    
    decoder = tf.contrib.seq2seq.BasicDecoder(
      self.decoder_fw, 
      helper, 
      self.decoder_fw.zero_state(batch_size, tf.float32).clone(cell_state=initial_state),
      output_layer=self.projection_layer)
      
    outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
      decoder, 
      impute_finished=True,
      maximum_iterations=max_iterations)
    
    logits = outputs.rnn_output
    translations = outputs.sample_id
    return logits, translations

In [6]:
tf.reset_default_graph()

batch_size = 32
max_time = None

dataset = tf.data.Dataset.from_generator(dataset_gen, (tf.int32, tf.int32), ([None], [None]))
dataset = dataset.map(lambda source, target: ((source, tf.size(source)), (target, tf.size(target))))
dataset = dataset.padded_batch(
  batch_size, 
  padded_shapes=(([None], []), ([None], [])), 
  padding_values=((sym2id['<pad>'], 0), (sym2id['<pad>'], 0)))

iterator = dataset.make_one_shot_iterator()
(source, source_seq_len), (target, target_seq_len) = iterator.get_next()

encoder_inputs = source
encoder_seq_len = source_seq_len
decoder_inputs = target[:, :-1]
decoder_targets = target[:, 1:]
decoder_seq_len = target_seq_len - 1

In [7]:
encoder_num_units = 256
encoder = Encoder(encoder_num_units)

encoder_outputs, encoder_state = encoder.train(
  tf.one_hot(encoder_inputs, vocab_size),
  encoder_seq_len
)

In [8]:
decoder_num_units = 256
decoder = Decoder(encoder_outputs, encoder_seq_len, decoder_num_units)

decoder_logits, _ = decoder.train(
  encoder_state,
  tf.one_hot(decoder_inputs, vocab_size),
  decoder_seq_len
)

max_iterations = tf.round(tf.reduce_max(encoder_seq_len) * 2)
_, translations = decoder.infer(
  encoder_state,
  max_iterations
)

In [9]:
cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
  labels=decoder_targets, 
  logits=decoder_logits)

loss_mask = tf.sequence_mask(decoder_seq_len, dtype=tf.float32)
loss = tf.reduce_sum(cross_ent * loss_mask) / batch_size

predicted = tf.argmax(tf.reshape(decoder_logits, [-1, vocab_size]), axis=1, output_type=tf.int32)
actual = tf.reshape(decoder_targets, [-1])
accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, actual), tf.float32))

global_step = tf.get_variable('global_step', initializer=0, trainable=False)
learning_rate = tf.placeholder(tf.float32, name='learning_rate')
train = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

init = tf.global_variables_initializer()
saver = tf.train.Saver()

In [10]:
steps = 3000
lr = 0.001
log_interval = 200
save_interval = 200
restore = True
log_path = os.path.join('tf_log', 'pyformat_rnn_attention')
model_name = os.path.join(log_path, 'model')

with tf.Session() as sess:
  if restore:
    chkpt_fname = tf.train.latest_checkpoint(log_path)
    saver.restore(sess, chkpt_fname)
  else:
    sess.run(init)
  
  for i in range(sess.run(global_step), steps + 1):
    feed_dict = {learning_rate: lr}
    sess.run(train, feed_dict)
  
    if i % log_interval == 0:
      l, a = sess.run([loss, accuracy], feed_dict)
      print('step: {}, loss: {:.4f}, accuracy: {:.2f}%, learning rate: {}'.format(i, l, a * 100, lr))
      if l < 1:
        lr = 0.0001
      
    if i % save_interval == 0:
      save_path = saver.save(sess, model_name)
      print('model saved: {}'.format(save_path))
      
  dlv = sess.run(decoder_logits, feed_dict)
  e, d = sess.run([encoder_inputs, translations], feed_dict)

INFO:tensorflow:Restoring parameters from tf_log/pyformat_rnn_attention/model
step: 2400, loss: 2.6944, accuracy: 99.18%, learning rate: 0.001
model saved: tf_log/pyformat_rnn_attention/model
step: 2600, loss: 1.9991, accuracy: 99.18%, learning rate: 0.001
model saved: tf_log/pyformat_rnn_attention/model
step: 2800, loss: 1.1851, accuracy: 99.41%, learning rate: 0.001
model saved: tf_log/pyformat_rnn_attention/model
step: 3000, loss: 1.4311, accuracy: 99.37%, learning rate: 0.001
model saved: tf_log/pyformat_rnn_attention/model


In [37]:
i = 1

print(dlv[i, :, 0])

er = e[i]
dr = d[i]
source = ''.join([id2sym[id] for id in er])
trans = ''.join([id2sym[id] for id in dr])

f, s = source.split('<start>')[1].split('<format>')
s = s.split('<end>')[0].split('<next>')
expected = f.format(*s)
actual = trans.split('<end>')[0]

print(expected)
print(actual)

[-5.13029146 -2.66684866 -2.82518935 -6.95160198 -3.17210317 -1.44655383
 -2.6025188  -4.75844955 -2.57914996 -2.65188861 -2.3346839  -2.83357072
 -1.6656127  -2.54256296 -2.77526116 -1.12430501 -2.55588794 -2.46709704
 -2.37699795 -3.61170888 -2.79718423 -2.75532269 -1.72889149 -1.66192591
 -0.86099607 -2.54329658 -2.80042386 -4.04833412 -3.37177563 -4.44982529
 -1.66958427 -2.47438049 -2.06568766 -3.47707772 -2.36756301 -2.52705979
 -3.42812824 -2.22227359 -1.69416821 -2.68117976 -3.1669271  -2.28482294
 -3.43186975 -0.54691577 -1.41413486 -2.52128363 -1.45983052 -2.46355486
 -1.4677577  -1.79266477 -3.40038729  0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.          0.        ]
BE Might 353964 725733 work World
BE Might 353964 725733 work World
