In [1]:
import string
import tensorflow as tf
import numpy as np

  return f(*args, **kwds)


In [2]:
special_symbols = {
  '<format>': 0, 
  '<next>': 1, 
  '<end>': 2, 
  '{': 3, 
  '}': 4,
  ' ': 5
}

In [3]:
def sym2id(id):
  if id in special_symbols:
    return special_symbols[id]
  else:
    i = string.ascii_letters.index(id) + len(special_symbols)
    assert 0 < i < vocab_size
    return i

def encode(template, subs):
  input = list(template) + ['<format>']
  for sub in subs:
    input += list(sub) + ['<next>']
  input[-1] = '<end>'
  output = ['<end>'] + list(template.format(*subs)) + ['<end>']
  return np.array([sym2id(x) for x in input]), np.array([sym2id(x) for x in output])

def next_batch(batch_size):
  inputs = []
  inputs_len = []
  outputs = []
  outputs_len = []
  for i in range(batch_size):
    input, output = encode('hello {} world {} boy', ['ABC' * np.random.randint(3), 'DEF' * np.random.randint(3)])
    inputs.append(input)
    inputs_len.append(len(input))
    outputs.append(output)
    outputs_len.append(len(output))
    
  input_max_len = np.max([len(x) for x in inputs])
  output_max_len = np.max([len(x) for x in outputs])
  
  inputs = [np.pad(x, [[0, input_max_len - len(x)]], 'constant') for x in inputs]
  outputs = [np.pad(x, [[0, output_max_len - len(x)]], 'constant') for x in outputs]
    
  return np.array(inputs), np.array(outputs), np.array(inputs_len), np.array(outputs_len)

In [22]:
tf.reset_default_graph()

batch_size = 32
vocab_size = len(string.ascii_letters) + len(special_symbols)
learning_rate = tf.placeholder(tf.float32, name='learning_rate')

encoder_hidden_size = 128
encoder_fw = tf.nn.rnn_cell.GRUCell(num_units=encoder_hidden_size)
encoder_bw = tf.nn.rnn_cell.GRUCell(num_units=encoder_hidden_size)
encoder_fw_init = tf.zeros([batch_size, encoder_hidden_size])
encoder_bw_init = tf.zeros([batch_size, encoder_hidden_size])
encoder_inputs = tf.placeholder(tf.int32, [batch_size, None], name='encoder_inputs')
encoder_seq_len = tf.placeholder(tf.int32, [batch_size], name='encoder_seq_len')

(_encoder_fw_outputs, _encoder_bw_outputs), (encoder_fw_state, encoder_bw_state) = tf.nn.bidirectional_dynamic_rnn(
  encoder_fw, 
  encoder_bw, 
  inputs=tf.one_hot(encoder_inputs, vocab_size),
  initial_state_fw=encoder_fw_init,
  initial_state_bw=encoder_bw_init,
  sequence_length=encoder_seq_len)

decoder_hidden_size = 128
decoder_fw = tf.nn.rnn_cell.GRUCell(num_units=decoder_hidden_size)
decoder_inputs = tf.placeholder(tf.int32, [batch_size, None], name='decoder_inputs')
decoder_targets = tf.placeholder(tf.int32, [batch_size, None], name='decoder_targets')
decoder_seq_len = tf.placeholder(tf.int32, [batch_size], name='decoder_seq_len')

(decoder_outputs, _decoder_state) = tf.nn.dynamic_rnn(
  decoder_fw,
  inputs=tf.one_hot(decoder_inputs, vocab_size),
  initial_state=encoder_fw_state,
  sequence_length=decoder_seq_len)

dense = tf.layers.Dense(vocab_size)
decoder_logits = dense(decoder_outputs)
loss_mask = tf.sequence_mask(decoder_seq_len, dtype=tf.float32)

def softmax_loss_function(labels, logits):
  return tf.nn.softmax_cross_entropy_with_logits(
    labels=tf.one_hot(labels, vocab_size), 
    logits=logits)

loss = tf.contrib.seq2seq.sequence_loss(
  logits=decoder_logits, 
  targets=decoder_targets, 
  weights=loss_mask)

predicted = tf.argmax(tf.reshape(decoder_logits, [-1, vocab_size]), axis=1, output_type=tf.int32)
actual = tf.reshape(decoder_targets, [-1])
accuracy = tf.reduce_mean(tf.cast(tf.equal(predicted, actual), tf.float32))

train = tf.train.AdamOptimizer(learning_rate).minimize(loss)
init = tf.global_variables_initializer()

# start_tokens = tf.zeros([batch_size], dtype=tf.int32) + sym2id('<end>')
# end_token = sym2id('<end>')
# bs = tf.contrib.seq2seq.BeamSearchDecoder(
#   decoder_fw, 
#   embedding=lambda ids: tf.one_hot(ids, vocab_size),
#   start_tokens=start_tokens,
#   end_token=end_token,
#   initial_state=encoder_fw_state,
#   output_layer=dense,
#   beam_width=1)

In [5]:
steps = 3000
lr = 0.001
log_interval = 100

with tf.Session() as sess:
  sess.run(init)
  
  for i in range(steps):
    inputs, outputs, inputs_len, outputs_len = next_batch(batch_size)
    feed_dict = {learning_rate: lr,
                 encoder_inputs: inputs,
                 encoder_seq_len: inputs_len - 1,
                 decoder_inputs: outputs[:, :-1],
                 decoder_seq_len: outputs_len - 1,
                 decoder_targets: outputs[:, 1:]}
    
    sess.run(train, feed_dict=feed_dict)
  
    if i % log_interval == 0:
      l, a = sess.run([loss, accuracy], feed_dict=feed_dict)
      print(i, l, a)

0 4.01839 0.0822917
100 0.95198 0.728125
200 0.144824 0.771875
300 0.0814916 0.783333
400 0.125753 0.71875
500 0.0167648 0.821875
600 0.00928334 0.7875
700 0.00489141 0.81875
800 0.00344787 0.775
900 0.00244948 0.83125
1000 0.00196289 0.809375
1100 0.00158587 0.8375
1200 0.00133057 0.7875
1300 0.00109435 0.80625
1400 0.000970491 0.778125
1500 0.000848987 0.7875
1600 0.000713601 0.83125
1700 0.000634996 0.79375
1800 0.000556995 0.778125
1900 0.000502672 0.8125
2000 0.000447512 0.815625
2100 0.000403591 0.796875
2200 0.000366558 0.809375
2300 0.000321132 0.81875
2400 0.000309769 0.809375
2500 0.000277892 0.809375
2600 0.000256962 0.775
2700 0.000233995 0.821875
2800 0.000214114 0.790625
2900 0.000196642 0.803125
