# RNN - Encoder-Decoder with `static_rnn()`

## Let's Go

In [1]:
!rm -fr logdir3
!mkdir -p logdir3

In [2]:
%load_ext do_not_print_href
%matplotlib inline
from __future__ import print_function, division
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '-'
import sys
import time
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

## Prepare data set - symbols & symbol map

In [3]:
import numpy as np

# - 입력/출력 문자는 0~9 까지의 숫자,
#   연산자 `'+'`, `' '` (END), `'='` (START) 포함하여 13 종류
# - 사용가능한 연산자는 `'+'`
# - 문자(symbol) 와 해당 문자의 인덱스 넘버 사이의 변환테이블들
#   - `symbols[]`    : 인덱스에서 문자로
#   - `symbol_map[]` : 문자에서 인덱스로

symbols        = [' ', '0', '1', '2', '3', '4', '5',
                  '6', '7', '8', '9', '+', '=']
operators      = ['+']

symbol_map     = {s: i \
                  for i,s in enumerate(symbols)}

input_units    = output_units    = len(symbol_map)

hidden_units   = 100

encoder_max_seq_len = 7
decoder_max_seq_len = 5

BATCH_SIZE = 200


def make_random_data():
    t1        = str(np.random.randint(1000))
    op        = np.random.choice(operators)
    t2        = str(np.random.randint(1000))
    
    expr      = t1 + op + t2
    ans       = '='+str(eval(expr))+' '
    
    return expr, ans

def one_hot(n):
    """
    3 ==> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    """
    res = np.zeros(13, dtype=np.float32)
    res[n] = 1.0
    return res

def arg_max(v):
    """
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] ==> 3
    """
    return np.argmax(v, axis=-1)

# test
assert 7 == arg_max(one_hot(7))

def str_to_onehot(str, max_seq_len):
    buf           = np.zeros([max_seq_len,input_units])
    buf          += \
      one_hot(symbol_map[' ']).reshape([1,-1]) # <<<===
    seq_len       = len(str)
    buf[:seq_len] = [one_hot(symbol_map[c]) for c in str]
    return buf

def onehot_to_str(data, data_len):
    return ''.join([symbols[v] \
                    for v in arg_max(data)][:data_len])





def encode_data(expr, ans):
    e_seq_len         = len(expr)
    e_in              = str_to_onehot(expr,
                                      encoder_max_seq_len)
    d_seq_len         = len(ans) - 1
    d_in              = str_to_onehot(ans[:-1],
                                      decoder_max_seq_len)
    d_out             = str_to_onehot(ans[1:],
                                      decoder_max_seq_len)
    return e_seq_len, e_in, d_seq_len, d_in, d_out

def decode_data(e_len, e_in, d_len, d_in, d_out):
    return  e_len, \
            onehot_to_str(e_in, e_len), \
            d_len, \
            onehot_to_str(d_in, d_len), \
            onehot_to_str(d_out, d_len)

            

class Dataset:
    def __init__(self):
        self.encoder_seq_len  = []
        self.encoder_in_data  = []
        self.decoder_seq_len  = []
        self.decoder_in_data  = []
        self.decoder_out_data = []
        
    def append(self, t):
        self.encoder_seq_len.append(t[0])
        self.encoder_in_data.append(t[1])
        self.decoder_seq_len.append(t[2])
        self.decoder_in_data.append(t[3])
        self.decoder_out_data.append(t[4])
        
    def next_batch(self,batch_size=BATCH_SIZE):
        data_len = len(self.encoder_seq_len)
        batch_pointer = 0
        while batch_pointer + batch_size <= data_len:
            ss   = np.random.randint(
                data_len - batch_size - 1)
            yield \
                self.encoder_seq_len[ss:ss+batch_size], \
                self.encoder_in_data[ss:ss+batch_size], \
                self.decoder_seq_len[ss:ss+batch_size], \
                self.decoder_in_data[ss:ss+batch_size], \
                self.decoder_out_data[ss:ss+batch_size]
            batch_pointer += batch_size

    def get(self, index):
        return \
            self.encoder_seq_len[index], \
            self.encoder_in_data[index], \
            self.decoder_seq_len[index], \
            self.decoder_in_data[index], \
            self.decoder_out_data[index]

In [4]:
train_num_data = 60000
test_num_data  = 10000

In [5]:
np.random.seed(37L)

train_data = Dataset()
for i in range(train_num_data):
    expr, ans     = make_random_data()
    train_data.append(encode_data(expr, ans))

test_data  = Dataset()
for i in range(test_num_data):
    expr, ans     = make_random_data()
    test_data.append(encode_data(expr, ans))

In [6]:
e_len,e_in,d_len,d_in,d_out = train_data.next_batch().next()
for i in range(10):
    print(decode_data(e_len[i],e_in[i],d_len[i],d_in[i],d_out[i]))

(6, '756+55', 4, '=811', '811 ')
(7, '269+927', 5, '=1196', '1196 ')
(7, '208+259', 4, '=467', '467 ')
(7, '529+653', 5, '=1182', '1182 ')
(7, '928+119', 5, '=1047', '1047 ')
(7, '811+285', 5, '=1096', '1096 ')
(7, '713+889', 5, '=1602', '1602 ')
(7, '866+703', 5, '=1569', '1569 ')
(7, '868+410', 5, '=1278', '1278 ')
(7, '790+602', 5, '=1392', '1392 ')


## Tensorflow - build graph

In [7]:
tf.reset_default_graph()

## Tensorflow - placeholders

In [8]:
encoder_inputs   = tf.placeholder(
    dtype=tf.float32,
    shape=[None, encoder_max_seq_len, input_units],
    name='encoder_inputs')
encoder_seqlen   = tf.placeholder(
    dtype=tf.int32,
    shape=[None],
    name='encoder_seqlen')
decoder_inputs   = tf.placeholder(
    dtype=tf.float32,
    shape=[None, decoder_max_seq_len, input_units],
    name='decoder_inputs')
decoder_targets  = tf.placeholder(
    dtype=tf.float32,
    shape=[None, decoder_max_seq_len, output_units],
    name='decoder_targets')
decoder_seqlen   = tf.placeholder(
    dtype=tf.int32,
    shape=[None],
    name='decoder_seqlen')
encoder_training = tf.placeholder(
    dtype=tf.bool,
    shape=None,
    name='encoder_training')
tf_batch_size = tf.shape(encoder_inputs)[0] # <<== !!!

In [9]:
dropout_rate = 0.2
keep_prob = tf.cond(encoder_training,
                    lambda: tf.constant(1.0-dropout_rate),
                    lambda: tf.constant(1.0),
                    name='keep_prob')

## RNN, Encoder/Decoder - using `static_rnn()`

- legacy_seq2seq 모듈 아래 메소드들은 입력데이터 포맷 요구 사항이 dynamic_rnn 때와 다름


- dynamic_rnn: (default)
    
  - <span style="color:red">Tensor</span> of shape: [batch_size, max_seq_len, num_units]


- legacy_seqseq:
    
  - <span style="color:red">List</span> of tensors of shape: [batch_size, num_units]



### _batch-major_ to _time-major_

- [`tf.unstack()`](http://devdocs.io/tensorflow~python/tf/unstack)

In [10]:
config = tf.ConfigProto(gpu_options={'per_process_gpu_memory_fraction': 0.0})
sess  = tf.InteractiveSession(config=config)

# b_maj : batch_size = 2, max_seq_len = 3, num_units = 4
b_maj = tf.reshape(tf.range(24,dtype=tf.float32),[2,3,4])
b_maj.eval()

array([[[  0.,   1.,   2.,   3.],
        [  4.,   5.,   6.,   7.],
        [  8.,   9.,  10.,  11.]],

       [[ 12.,  13.,  14.,  15.],
        [ 16.,  17.,  18.,  19.],
        [ 20.,  21.,  22.,  23.]]], dtype=float32)

In [11]:
t_maj = tf.unstack(b_maj,3,1)
t_maj

[<tf.Tensor 'unstack:0' shape=(2, 4) dtype=float32>,
 <tf.Tensor 'unstack:1' shape=(2, 4) dtype=float32>,
 <tf.Tensor 'unstack:2' shape=(2, 4) dtype=float32>]

In [12]:
[t.eval() for t in t_maj]

[array([[  0.,   1.,   2.,   3.],
        [ 12.,  13.,  14.,  15.]], dtype=float32),
 array([[  4.,   5.,   6.,   7.],
        [ 16.,  17.,  18.,  19.]], dtype=float32),
 array([[  8.,   9.,  10.,  11.],
        [ 20.,  21.,  22.,  23.]], dtype=float32)]

### _time-major_ to _batch-major_

- [`tf.stack()`](http://devdocs.io/tensorflow~python/tf/stack)


In [13]:
b_maj_again = tf.stack(t_maj, 1)
b_maj_again.eval()

array([[[  0.,   1.,   2.,   3.],
        [  4.,   5.,   6.,   7.],
        [  8.,   9.,  10.,  11.]],

       [[ 12.,  13.,  14.,  15.],
        [ 16.,  17.,  18.,  19.],
        [ 20.,  21.,  22.,  23.]]], dtype=float32)

### _batch-major_ input/output 에 _time-major_ rnn 적용

In [14]:
# [batch_size, max_seq_len, num_units] 형태의 tensor들을
# list of [batch_size, num_units] 형태로 바꿔야 한다.
# (batch-major to time-major)

with tf.name_scope('encoder_input_transform'):
    e_inputs   = tf.unstack(
        encoder_inputs,
        encoder_max_seq_len,
        1,
        name='transformed_encoder_inputs')
    # shape: list of ([batch_size, num_units])

with tf.name_scope('decoder_input_transform'):
    d_inputs   = tf.unstack(
        decoder_inputs,
        decoder_max_seq_len,
        1,
        name='transformed_decoder_inputs')
    # shape: list of ([batch_size, num_units])

# tf.contrib.legacy_seq2seq.basic_rnn_seq2seq 를
# 쓸 수 없음:
#   - decoder state 를 줄 수 없음
#   - training/inference 모드를 지정할 수 없음

# dec_outs_, state = \
#     tf.contrib.legacy_seq2seq.basic_rnn_seq2seq(
#         e_inputs,
#         d_inputs,
#         cell,
#         dtype=tf.float32)

def e_cell(input_size):
    cell = tf.contrib.rnn.BasicRNNCell(hidden_units)
#     cell = tf.contrib.rnn.LSTMCell(hidden_units)
#     cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
#              hidden_units)
    cell = tf.contrib.rnn.DropoutWrapper(
        cell,
        state_keep_prob = keep_prob,
        variational_recurrent = True,
        input_size = input_size,
        dtype = tf.float32)
    return cell

with tf.variable_scope('encoder'):
    cell = tf.contrib.rnn.MultiRNNCell(
        [
            e_cell(input_units),
            e_cell(hidden_units),
            e_cell(hidden_units)
        ])
    initial_state = cell.zero_state(
        batch_size=tf_batch_size,
        dtype=tf.float32)
    enc_outs_, encoder_state = \
        tf.nn.static_rnn(
            cell,
            e_inputs,
            initial_state=initial_state,
            sequence_length=encoder_seqlen,
            dtype=tf.float32)

def d_cell():
    cell = tf.contrib.rnn.BasicRNNCell(hidden_units)
#     cell = tf.contrib.rnn.LSTMCell(hidden_units)
#     cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
#              hidden_units)
    return cell

with tf.variable_scope('decoder'):
    cell = tf.contrib.rnn.MultiRNNCell(
        [d_cell() for _ in range(3)])
    dec_outs_, decoder_state = \
        tf.nn.static_rnn(
            cell,
            d_inputs,
            initial_state=encoder_state,
            sequence_length=decoder_seqlen,
            dtype=tf.float32)

# outputs 에 대해서는 위 inputs 들에 대한 변환의 역변환
# (time-major to batch-major)
encoder_out = \
    tf.stack(enc_outs_, 1, name='encoder_outputs')
decoder_out = \
    tf.stack(dec_outs_, 1, name='decoder_outputs')


In [15]:
decoder_out.shape.as_list()

[None, 5, 100]

## Fully Connected Network after RNN


In [16]:
outputs = tf.layers.dense(
    decoder_out,
    output_units,
    name='seq2seq_outputs')

## Compare training target vs output, do optimize

In [17]:
seq_mask_ = tf.sequence_mask(
                decoder_seqlen,
                maxlen=decoder_max_seq_len,
                dtype=tf.float32)

seq_mask  = \
    tf.tile(
        tf.reshape(
            seq_mask_,
            [-1,decoder_max_seq_len,1]),
        [1,1,output_units])

In [18]:
loss         = tf.losses.softmax_cross_entropy(
                    decoder_targets * seq_mask,
                    outputs * seq_mask)

optimizer    = tf.train.AdamOptimizer( \
                  learning_rate=0.001)
optimize     = optimizer.minimize(
    loss,
    name='minimize')

## Comparing Sequences

In [19]:
def seq_equals(a,b,a_len=None,b_len=None):
    if a_len is None: a_len = len(a)
    if b_len is None: b_len = len(b)
    a_nums = np.argmax(a[:a_len],-1)
    b_nums = np.argmax(b[:b_len],-1)
    return 1.0 * np.all(np.equal(a_nums, b_nums))

## Training loop

In [20]:
import time

In [21]:
def train(num_epochs, writer):
    t_start = time.time()
    step = 0
    for epoch in range(num_epochs):
        losses  = []
        errs    = []
        for e_len, e_in, d_len, d_in, d_out \
                in train_data.next_batch():
            feed = {
                encoder_training: True,
                encoder_seqlen:   e_len,
                encoder_inputs:   e_in,
                decoder_seqlen:   d_len,
                decoder_inputs:   d_in,
                decoder_targets:  d_out,
            }
            _, out, training_loss = \
                sess.run([optimize, outputs, loss], feed)
            training_err = 1.0 - \
                np.mean([
                    seq_equals(a,b,a_len,b_len)
                    for a,b,a_len,b_len in
                    zip(d_out,out,d_len,d_len)
                ])
            losses.append(training_loss)
            errs.append(training_err)
        test_errs   = []
        for e_len, e_in, d_len, d_in, d_out \
                in test_data.next_batch():
            feed = {
                encoder_training: False,
                encoder_seqlen:   e_len,
                encoder_inputs:   e_in,
                decoder_seqlen:   d_len,
                decoder_inputs:   d_in,
                decoder_targets:  d_out,
            }
            out, = sess.run([outputs], feed)
            test_err = 1.0 - \
                np.mean([
                    seq_equals(a,b,a_len,b_len)
                    for a,b,a_len,b_len in
                    zip(d_out,out,d_len,d_len)
                ])
            test_errs.append(test_err)
        mean_loss       = np.mean(losses)
        mean_err        = np.mean(errs)
        mean_test_err   = np.mean(test_errs)
        summary = tf.Summary(
            value=[
                tf.Summary.Value(
                    tag='loss',
                    simple_value=mean_loss),
                tf.Summary.Value(
                    tag='train_err',
                    simple_value=mean_err),
                tf.Summary.Value(
                    tag='test_err',
                    simple_value=mean_test_err),
            ])
        writer.add_summary(summary,epoch+1)
        if 0 == (epoch+1) % 10:
            t_elapsed = time.time() - t_start
            print(('epoch: {:d}, loss: {:.5f}, ' +
                   'err: {:.5f}, test_err: {:.5f}, ' +
                   'elapsed: {:.2f}').format(
                epoch+1,
                mean_loss,
                mean_err,
                mean_test_err,
                t_elapsed))
            t_start = time.time()

## Start Training

In [22]:
tf_config   = tf.ConfigProto(
    allow_soft_placement=True,
    gpu_options={'allow_growth': True})
sess        = tf.InteractiveSession(config=tf_config)

sess.run(tf.global_variables_initializer())

writer      = tf.summary.FileWriter( \
           'logdir3/encoder_decoder',
           tf.get_default_graph())

In [23]:
train(100, writer)

epoch: 10, loss: 0.79396, err: 0.92645, test_err: 0.78910, elapsed: 31.96
epoch: 20, loss: 0.46629, err: 0.68028, test_err: 0.37710, elapsed: 31.15
epoch: 30, loss: 0.37138, err: 0.57937, test_err: 0.23600, elapsed: 31.06
epoch: 40, loss: 0.31004, err: 0.49672, test_err: 0.25150, elapsed: 29.97
epoch: 50, loss: 0.26574, err: 0.43688, test_err: 0.27790, elapsed: 35.21
epoch: 60, loss: 0.22741, err: 0.37223, test_err: 0.22680, elapsed: 33.87
epoch: 70, loss: 0.20682, err: 0.34608, test_err: 0.28190, elapsed: 49.75
epoch: 80, loss: 0.18469, err: 0.30578, test_err: 0.23730, elapsed: 51.41
epoch: 90, loss: 0.18562, err: 0.31292, test_err: 0.16370, elapsed: 49.72
epoch: 100, loss: 0.17146, err: 0.28697, test_err: 0.23330, elapsed: 51.18


## Training progress was

In [24]:
# !tensorboard --logdir logdir

## Examine inference steps

In [25]:
expr = '123+456'

In [26]:
str_to_onehot(expr,encoder_max_seq_len)

array([[ 0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.]])

In [27]:
# batch size == 1
feed = {
    encoder_inputs: [str_to_onehot(expr,encoder_max_seq_len)],
    encoder_seqlen: [7],
    encoder_training: False,
}
e_out, e_state, = sess.run(
    [encoder_out, encoder_state,],
    feed)

In [28]:
decoder_inputs_ = [str_to_onehot('=',decoder_max_seq_len)]
decoder_input_buf = np.zeros_like(decoder_inputs_)
decoder_input_buf[:,0,:] = np.array(decoder_inputs_)[:,0,:]

In [29]:
list(onehot_to_str(decoder_input_buf[0],5))

['=', ' ', ' ', ' ', ' ']

In [30]:
collect_output = []

In [31]:
i = 0
feed = {
    encoder_state:  e_state,
    decoder_inputs: decoder_input_buf,
    decoder_seqlen: [1],
}
out, d_state      = \
    sess.run([outputs,decoder_state,], feed)
out_decoded = onehot_to_str(out[0],5)
collect_output.append(out_decoded[0])
print('infer: step {}: {} ==> {}'.format(
    i,
    list(onehot_to_str(decoder_input_buf[0],5)),
    list(out_decoded)))

infer: step 0: ['=', ' ', ' ', ' ', ' '] ==> ['5', ' ', ' ', ' ', ' ']


In [32]:
decoder_input_buf[:,0,:] = out[:,0,:]
i = 1
feed = {
    encoder_state:  d_state,
    decoder_inputs: decoder_input_buf,
    decoder_seqlen: [1],
}
out, d_state      = \
    sess.run([outputs,decoder_state,], feed)
out_decoded = onehot_to_str(out[0],5)
collect_output.append(out_decoded[0])
print('infer: step {}: {} ==> {}'.format(
    i,
    list(onehot_to_str(decoder_input_buf[0],5)),
    list(out_decoded)))

infer: step 1: ['5', ' ', ' ', ' ', ' '] ==> ['1', ' ', ' ', ' ', ' ']


In [33]:
decoder_input_buf[:,0,:] = out[:,0,:]
i = 2
feed = {
    encoder_state:  d_state,
    decoder_inputs: decoder_input_buf,
    decoder_seqlen: [1],
}
out, d_state      = \
    sess.run([outputs,decoder_state,], feed)
out_decoded = onehot_to_str(out[0],5)
collect_output.append(out_decoded[0])
print('infer: step {}: {} ==> {}'.format(
    i,
    list(onehot_to_str(decoder_input_buf[0],5)),
    list(out_decoded)))

infer: step 2: ['1', ' ', ' ', ' ', ' '] ==> ['4', ' ', ' ', ' ', ' ']


In [34]:
decoder_input_buf[:,0,:] = out[:,0,:]
i = 3
feed = {
    encoder_state:  d_state,
    decoder_inputs: decoder_input_buf,
    decoder_seqlen: [1],
}
out, d_state      = \
    sess.run([outputs,decoder_state,], feed)
out_decoded = onehot_to_str(out[0],5)
collect_output.append(out_decoded[0])
print('infer: step {}: {} ==> {}'.format(
    i,
    list(onehot_to_str(decoder_input_buf[0],5)),
    list(out_decoded)))

infer: step 3: ['4', ' ', ' ', ' ', ' '] ==> ['1', ' ', ' ', ' ', ' ']


In [35]:
decoder_input_buf[:,0,:] = out[:,0,:]
i = 4
feed = {
    encoder_state:  d_state,
    decoder_inputs: decoder_input_buf,
    decoder_seqlen: [1],
}
out, d_state      = \
    sess.run([outputs,decoder_state,], feed)
out_decoded = onehot_to_str(out[0],5)
collect_output.append(out_decoded[0])
print('infer: step {}: {} ==> {}'.format(
    i,
    list(onehot_to_str(decoder_input_buf[0],5)),
    list(out_decoded)))

infer: step 4: ['1', ' ', ' ', ' ', ' '] ==> ['1', ' ', ' ', ' ', ' ']


In [36]:
collect_output

['5', '1', '4', '1', '1']

## infer()

In [37]:
def infer(expr):
    # encoder
    feed = {
        encoder_seqlen: [len(expr)],
        encoder_inputs: \
            [str_to_onehot(expr,encoder_max_seq_len)],
        encoder_training: False
    }
    e_out, e_state = \
        sess.run([encoder_out, encoder_state], feed)
    
    # decoder: step 0
    out_buf = []
    feed = {
        encoder_state: e_state,
        decoder_seqlen: [1],
        decoder_inputs: \
            [str_to_onehot('=',decoder_max_seq_len)]
    }
    out, d_state = sess.run([outputs, decoder_state], feed)
    out_decoded = onehot_to_str(out[0],1)
    out_buf.append(out_decoded)
    
    # decoder: step 1..n-1
    for _ in range(1,decoder_max_seq_len):
        feed = {
            encoder_state: d_state,
            decoder_seqlen: [1],
            decoder_inputs: \
                [str_to_onehot(out_decoded,decoder_max_seq_len)]
        }
        out, d_state = sess.run([outputs, decoder_state], feed)
        out_decoded = onehot_to_str(out[0],1)
        out_buf.append(out_decoded)

    return ''.join(out_buf), e_out[0]

In [38]:
ans, _ = infer('345+111')
ans

'456  '

In [39]:
ans, _ = infer('345+222')
ans

'567  '

In [40]:
ans, e_out = infer('111+222')
ans

'333  '

In [41]:
ans, e_out = infer('999+99')
ans

'1088 '

In [42]:
errs = []
for _ in range(10):
    expr, ans_ = make_random_data()
    truth = (ans_+' ')[1:6]
    ans, _  = infer(expr)
    print('['+truth+']', '['+ans+']', expr)
    errs.append(0 if truth == ans else 1)
print('errs: {:.5f}'.format(np.mean(errs,dtype=np.float32)))

[1541 ] [1551 ] 905+636
[996  ] [996  ] 525+471
[662  ] [662  ] 61+601
[1016 ] [1015 ] 29+987
[830  ] [820  ] 286+544
[796  ] [796  ] 78+718
[437  ] [447  ] 408+29
[1422 ] [1422 ] 488+934
[869  ] [858  ] 861+8
[1237 ] [1237 ] 428+809
errs: 0.50000
