In [1]:
import tensorflow as tf
from config import FLAGS
from IPython.core.debugger import set_trace

  from ._conv import register_converters as _register_converters


In [2]:
def _linear(_input, output_size, scope, use_bias=True):
    if len(_input.get_shape().as_list())>3:
        raise AssertionError('At most rank 3 tensor is supported!')
    shape = _input.get_shape().as_list()
    
    _input = tf.reshape(_input, (-1,shape[-1]))
    
    with tf.variable_scope(scope):
        W = tf.get_variable('W', [shape[-1], output_size], dtype='float32')
        if use_bias:
            B = tf.get_variable('B', [output_size])
            res = tf.matmul(_input,W)+B
        else:
            res = tf.matmul(_input,W)
    
    return tf.reshape(res,(shape[:2] + [output_size])) if len(shape)==3 else res

# Model config

In [3]:
batch_size = FLAGS.batch_size
num_units = FLAGS.num_units
num_layers = FLAGS.num_layers
dim_emb = FLAGS.dim_emb
attn_size =FLAGS.attn_size
# these values will be changed to biggest bucket's length
maxlen_src = 10
maxlen_target = 10
# vocabulary size
src_vocab_size = target_vocab_size = 1000

In [4]:
# placeholders
src = [ tf.placeholder('int32', [batch_size]) for _ in range(maxlen_src) ]
target = [ tf.placeholder('int32', [batch_size]) for _ in range(maxlen_target) ]

# Encoder & Decoder cell (if need dropout,,, add dropout!)
encoder_cell = tf.nn.rnn_cell.LSTMCell(num_units=num_units)
encoder_cell = tf.nn.rnn_cell.MultiRNNCell([encoder_cell]*num_layers)

decoder_cell = tf.nn.rnn_cell.LSTMCell(num_units=num_units)
decoder_cell = tf.nn.rnn_cell.MultiRNNCell([decoder_cell]*num_layers)

# embedding layer(embedding for src sentence & target)
src_embedding = tf.get_variable('src_embedding', [src_vocab_size,dim_emb], 'float32')
target_embedding = tf.get_variable('target_embedding', [target_vocab_size,dim_emb], 'float32')

src_embed = tf.nn.embedding_lookup(src_embedding, src)
target_embed = tf.nn.embedding_lookup(target_embedding, target)

# Encoder(can be bi-directional rnn)
encoder_outputs, encoder_state = tf.nn.dynamic_rnn(cell=encoder_cell, inputs=_linear(src_embed, num_units, 'Encoder/input_projection'),
                                                   time_major=True, scope='Encoder', dtype='float32')

In [5]:
# Decoder
with tf.variable_scope('Decoder'):
    decoder_state = encoder_state

    for t in range(maxlen_target):
        if t==0:
            go_embed = tf.nn.embedding_lookup(target_embedding, tf.fill([batch_size],FLAGS.GO))
            next_input = tf.concat([go_embed, tf.zeros([batch_size, attn_size], 'float32')], axis=1)
        else:
            tf.get_variable_scope().reuse_variables()
        
        # projection
        decoder_output, decoder_state = decoder_cell(_linear(next_input, num_units, 'input_projection'), decoder_state)

        # Attention block
        # 1. attention weights

        # target projections for comparison(num_units --> attn_size)
        target_projections = [_linear(decoder_output,attn_size,'compare_target')]*maxlen_src
        target_projections = tf.stack(target_projections)

        # source projections for comparison(num_units --> attn_size)
        src_projections = _linear(encoder_outputs,attn_size,'compare_src')

        # attention weights(attention distribution)
        attn_weights = tf.nn.softmax(tf.nn.tanh(target_projections+src_projections))
        
        # 2. context vector
        ctx_vec = tf.reduce_sum(attn_weights*encoder_outputs, axis=0)

        # 3. attention vector
        concat = tf.concat([ctx_vec, decoder_output], axis=1)
        attn_vec = tf.nn.tanh(_linear(concat,attn_size,'attn'))

        cur_embed = target_embed[t,:]
        next_input = tf.concat([cur_embed, attn_vec], axis=1)
        #print next_input



In [6]:
[ v.name for v in tf.trainable_variables() ]

[u'src_embedding:0',
 u'target_embedding:0',
 u'Encoder/input_projection/W:0',
 u'Encoder/input_projection/B:0',
 u'Encoder/multi_rnn_cell/cell_0/lstm_cell/kernel:0',
 u'Encoder/multi_rnn_cell/cell_0/lstm_cell/bias:0',
 u'Decoder/input_projection/W:0',
 u'Decoder/input_projection/B:0',
 u'Decoder/multi_rnn_cell/cell_0/lstm_cell/kernel:0',
 u'Decoder/multi_rnn_cell/cell_0/lstm_cell/bias:0',
 u'Decoder/compare_target/W:0',
 u'Decoder/compare_target/B:0',
 u'Decoder/compare_src/W:0',
 u'Decoder/compare_src/B:0',
 u'Decoder/attn/W:0',
 u'Decoder/attn/B:0']