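Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)

A character-level language model built from causally masked self-attention layers in TensorFlow. It is trained on the text in `../temp/anna.txt` and periodically generates sample text with beam search.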

<img src="img/self_attn.png" width="200">

In [1]:
"""
pip3 install tensor2tensor
"""
import tensorflow as tf
import numpy as np
from tensor2tensor.utils import beam_search
from tensor2tensor.layers.common_attention import add_timing_signal_1d

In [2]:
# Hyper-parameters read through the global `params` dict by the cells below.
# build_graph() later adds 'char2idx', 'idx2char' and 'vocab_size' to it.
params = dict(
    batch_size=64,        # rows in each training batch
    text_iter_step=25,    # stride, in characters, between consecutive batch windows
    seq_len=200,          # characters per row of a batch
    hidden_dim=128,       # embedding width, also the attention / feed-forward output width
    num_head=8,           # number of attention heads
    n_hidden_layer=2,     # number of stacked (attention, feed-forward) layers
    display_step=10,      # print the loss every this many steps
    generate_step=100,    # run beam-search generation every this many steps
    beam_size=5,          # beam width used during generation
)

In [3]:
def parse_text(file_path):
    """Read a text file and encode it as an array of character ids.

    Ids 0, 1 and 2 are reserved for the special tokens ``<pad>``, ``<start>``
    and ``<end>``; each distinct character of the file gets an id from 3 up.

    Args:
        file_path: path of the text file to read.

    Returns:
        A tuple ``(ints, char2idx)`` where ``ints`` is a NumPy array holding the
        id of every character of the file, in order, and ``char2idx`` maps each
        character (and the three special tokens) to its id.
    """
    with open(file_path) as f:
        text = f.read()

    # Number the characters in sorted order: iterating a bare set of strings
    # gives a different order on each interpreter run (hash randomization),
    # which would make the vocabulary ids irreproducible between runs.
    char2idx = {c: i + 3 for i, c in enumerate(sorted(set(text)))}
    char2idx['<pad>'] = 0
    char2idx['<start>'] = 1
    char2idx['<end>'] = 2

    ints = np.array([char2idx[char] for char in text])
    return ints, char2idx

def next_batch(ints, batch_size=None, seq_len=None, step=None):
    """Yield overlapping training batches cut from a 1-D array of ids.

    A window of ``batch_size * seq_len`` consecutive ids slides over ``ints``
    in strides of ``step``; each window is reshaped to
    ``[batch_size, seq_len]`` and yielded.

    Args:
        ints: 1-D NumPy array of ids.
        batch_size: rows per batch; defaults to ``params['batch_size']``.
        seq_len: ids per row; defaults to ``params['seq_len']``.
        step: stride between window starts; defaults to
            ``params['text_iter_step']``.

    Yields:
        Arrays of shape ``[batch_size, seq_len]``. Nothing is yielded when
        ``ints`` is shorter than one window.
    """
    if batch_size is None:
        batch_size = params['batch_size']
    if seq_len is None:
        seq_len = params['seq_len']
    if step is None:
        step = params['text_iter_step']

    len_win = seq_len * batch_size
    # "+ 1" keeps the last full window (the one starting at len(ints) - len_win);
    # without it an input exactly one window long would yield no batch at all.
    for i in range(0, len(ints) - len_win + 1, step):
        clip = ints[i: i + len_win]
        yield clip.reshape([batch_size, seq_len])
        
def input_fn(ints):
    """Expose the batches of `next_batch(ints)` as a tensor.

    Builds a `tf.data` dataset from the batch generator, declared as int32
    with shape ``[None, params['seq_len']]``, and returns the next-element
    tensor of a one-shot iterator over it.
    """
    batch_shape = tf.TensorShape([None, params['seq_len']])

    def batch_generator():
        # from_generator needs a callable so it can start a fresh generator.
        return next_batch(ints)

    dataset = tf.data.Dataset.from_generator(batch_generator, tf.int32, batch_shape)
    return dataset.make_one_shot_iterator().get_next()

In [4]:
def start_sent(x):
    """Prepend one column of `<start>` ids to the 2-D id tensor `x`."""
    n_rows = tf.shape(x)[0]
    start_col = tf.fill([n_rows, 1], params['char2idx']['<start>'])
    return tf.concat([start_col, x], 1)

def end_sent(x):
    """Append one column of `<end>` ids to the 2-D id tensor `x`."""
    n_rows = tf.shape(x)[0]
    end_col = tf.fill([n_rows, 1], params['char2idx']['<end>'])
    return tf.concat([x, end_col], 1)

def embed_seq(x, vocab_sz, embed_dim, zero_pad=False, scale=False):
    """Look up embeddings for the ids in `x`.

    Args:
        x: tensor of ids to embed.
        vocab_sz: number of rows of the embedding table.
        embed_dim: width of each embedding vector.
        zero_pad: when true, row 0 of the table is replaced by zeros for the
            lookup, so id 0 always embeds to a zero vector.
        scale: when true, the embeddings are multiplied by sqrt(embed_dim).

    Returns:
        The embedded tensor; the variable is created (or reused) under the
        name 'word2vec' in the current variable scope.
    """
    table = tf.get_variable('word2vec', [vocab_sz, embed_dim])
    if zero_pad:
        zero_row = tf.zeros([1, embed_dim])
        table = tf.concat([zero_row, table[1:, :]], 0)

    embedded = tf.nn.embedding_lookup(table, x)
    if not scale:
        return embedded
    return embedded * tf.sqrt(tf.to_float(embed_dim))

def layer_norm(inputs, epsilon=1e-8):
    """Normalize `inputs` over its last axis, then apply a learned scale and shift.

    Args:
        inputs: tensor whose last dimension is statically known.
        epsilon: small constant added to the variance before the inverse
            square root.

    Returns:
        ``gamma * normalized + beta``, where 'gamma' (initialised to ones) and
        'beta' (initialised to zeros) are variables shaped like the last
        dimension of `inputs`, created in the current variable scope.
    """
    mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
    inv_std = tf.rsqrt(variance + epsilon)
    normalized = (inputs - mean) * inv_std

    feature_shape = inputs.get_shape()[-1:]
    gamma = tf.get_variable(
        'gamma', feature_shape, tf.float32, tf.ones_initializer())
    beta = tf.get_variable(
        'beta', feature_shape, tf.float32, tf.zeros_initializer())

    return gamma * normalized + beta

def self_attention(inputs, is_training, activation=None):
    """Multi-head self-attention with a lower-triangular mask, residual and layer norm.

    Args:
        inputs: rank-3 tensor; axis 1 is used as the sequence length and the
            last axis must match ``params['hidden_dim']`` for the residual add
            (presumably laid out as (batch, time, hidden) -- confirm with caller).
        is_training: passed to dropout as its `training` flag.
        activation: optional activation applied to the joint Q/K/V projection.

    Returns:
        Tensor with the same shape as `inputs`.
    """
    hidden = params['hidden_dim']
    heads = params['num_head']
    seq_len = tf.shape(inputs)[1]

    def split_heads(t):
        # Cut the feature axis into `heads` slices and stack them on axis 0.
        return tf.concat(tf.split(t, heads, axis=2), 0)

    # One dense layer produces queries, keys and values side by side.
    projected = tf.layers.dense(inputs, 3*hidden, activation)
    query, key, value = [split_heads(t) for t in tf.split(projected, 3, -1)]

    # Dot-product scores scaled by 1/sqrt(per-head width).
    scores = tf.matmul(query, key, transpose_b=True)
    scores *= tf.rsqrt(tf.to_float(key.get_shape()[-1].value))

    # Position i may only attend to positions <= i: everything above the
    # diagonal is set to -inf so softmax gives it zero weight.
    neg_inf = tf.fill(tf.shape(scores), float('-inf'))
    lower_tri = tf.linalg.LinearOperatorLowerTriangular(
        tf.ones([seq_len, seq_len])).to_dense()
    mask = tf.tile(tf.expand_dims(lower_tri, 0), [tf.shape(scores)[0], 1, 1])
    scores = tf.where(tf.equal(mask, 0), neg_inf, scores)

    weights = tf.nn.softmax(scores)
    weights = tf.layers.dropout(weights, 0.1, training=is_training)

    # Weighted sum of values, then undo the head split (axis 0 -> feature axis).
    attended = tf.matmul(weights, value)
    attended = tf.concat(tf.split(attended, heads, axis=0), 2)

    # Residual connection followed by layer normalisation.
    return layer_norm(attended + inputs)

def ffn(inputs, activation=tf.nn.relu):
    """Position-wise feed-forward block with residual connection and layer norm.

    Two kernel-size-1 convolutions: the first widens to
    ``4 * params['hidden_dim']`` with `activation`, the second projects back
    to ``params['hidden_dim']`` with no activation.
    """
    width = params['hidden_dim']
    expanded = tf.layers.conv1d(inputs, 4*width, 1, activation=activation)
    projected = tf.layers.conv1d(expanded, width, 1, activation=None)
    return layer_norm(projected + inputs)

In [5]:
def forward(inputs, reuse, is_training):
    """Run the model on a tensor of ids and return per-position vocabulary logits.

    Args:
        inputs: tensor of character ids.
        reuse: forwarded to every `tf.variable_scope`; False on the first
            call (creates the variables), True to share them afterwards.
        is_training: enables dropout when true.

    Returns:
        Logits whose last dimension is ``params['vocab_size']``.
    """
    vocab_size = params['vocab_size']
    with tf.variable_scope('model', reuse=reuse):
        # Scaled embeddings (id 0 embeds to zeros) plus the timing signal.
        x = embed_seq(
            inputs, vocab_size, params['hidden_dim'], zero_pad=True, scale=True)
        x = add_timing_signal_1d(x)
        x = tf.layers.dropout(x, 0.1, training=is_training)

        # Each layer is an attention block followed by a feed-forward block,
        # each in its own numbered scope.
        for layer in range(params['n_hidden_layer']):
            with tf.variable_scope('attn_%d'%layer, reuse=reuse):
                x = self_attention(x, is_training)
            with tf.variable_scope('ffn_%d'%layer, reuse=reuse):
                x = ffn(x)

        logits = tf.layers.dense(x, vocab_size)
    return logits

In [6]:
def beam_search_decoding():
    """Build the generation op: beam search seeded with a single `<start>` id.

    Reuses the variables created by the training call to `forward` and decodes
    up to ``2 * params['seq_len']`` steps with beam width
    ``params['beam_size']``, stopping on `<end>`.

    Returns:
        The id sequence at beam index 0 of the single batch element
        (presumably the best-scoring beam -- confirm against tensor2tensor).
    """
    start_id = params['char2idx']['<start>']
    end_id = params['char2idx']['<end>']
    initial_ids = tf.constant(start_id, tf.int32, [1])  # batch of one

    def symbols_to_logits(ids):
        # Re-run the whole prefix and keep only the logits of its last position.
        logits = forward(ids, reuse=True, is_training=False)
        last_pos = tf.shape(ids)[1] - 1
        return logits[:, last_pos, :]

    decode_length = 2 * params['seq_len']
    final_ids, _ = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        params['beam_size'],
        decode_length,
        params['vocab_size'],
        0.0,
        eos_id=end_id)

    return final_ids[0, 0, :]

In [7]:
def build_graph():
    """Load the corpus, build the training and generation graph, return its ops.

    Side effects: stores 'char2idx', 'vocab_size' and 'idx2char' in the global
    `params` dict and prints the vocabulary size.

    Returns:
        Dict with the ops 'global_step', 'loss', 'train' and 'generate'.
    """
    ints, char2idx = parse_text('../temp/anna.txt')
    params['char2idx'] = char2idx
    params['vocab_size'] = len(char2idx)
    params['idx2char'] = {idx: ch for ch, idx in char2idx.items()}
    print('Vocabulary size:', params['vocab_size'])

    # Input is <start> + X and the target is X + <end>, i.e. the model is
    # trained to predict each next character.
    X = input_fn(ints)
    logits = forward(start_sent(X), reuse=False, is_training=True)
    global_step = tf.Variable(0, trainable=False)
    targets = end_sent(X)

    # Every target position carries weight 1.
    uniform_weights = tf.to_float(tf.ones_like(targets))
    loss = tf.reduce_mean(tf.contrib.seq2seq.sequence_loss(
        logits=logits, targets=targets, weights=uniform_weights))

    train_op = tf.train.AdamOptimizer().minimize(loss, global_step=global_step)
    generate_op = beam_search_decoding()

    return {
        'global_step': global_step,
        'loss': loss,
        'train': train_op,
        'generate': generate_op,
    }

In [None]:
def main():
    """Train until the input pipeline is exhausted, logging loss and samples.

    Prints the loss at step 1 and every ``params['display_step']`` steps, and
    prints a beam-search sample every ``params['generate_step']`` steps.
    """
    ops = build_graph()
    # The context manager closes the session (releasing its resources) when
    # training ends or an unexpected error propagates; the original never
    # closed it.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        while True:
            try:
                _, step, loss = sess.run(
                    [ops['train'], ops['global_step'], ops['loss']])
            except tf.errors.OutOfRangeError:
                # The batch generator has run out of data: training is done.
                break

            if step % params['display_step'] == 0 or step == 1:
                print("Step %d | Loss %.3f" % (step, loss))
            if step % params['generate_step'] == 0 and step > 1:
                generated = sess.run(ops['generate'])
                print('\n'+''.join([params['idx2char'][i] for i in generated])+'\n')

In [None]:
# Entry point: build the graph, then train (printing the loss and generated
# samples) until the batch generator is exhausted.
main()

Vocabulary size: 86
Step 1 | Loss 5.169
Step 10 | Loss 2.878
Step 20 | Loss 2.671
Step 30 | Loss 2.580
Step 40 | Loss 2.532
Step 50 | Loss 2.488
Step 60 | Loss 2.455
Step 70 | Loss 2.431
Step 80 | Loss 2.407
Step 90 | Loss 2.395
Step 100 | Loss 2.375

<start> his his his his his his his his his his his his his he his he his his his he his his his his his his he his his his his his he his his his his he his his his his his he he his his he his his he he h<end><pad>

Step 110 | Loss 2.355
Step 120 | Loss 2.337
Step 130 | Loss 2.311
Step 140 | Loss 2.276
Step 150 | Loss 2.253
Step 160 | Loss 2.223
Step 170 | Loss 2.208
Step 180 | Loss 2.180
Step 190 | Loss 2.160
Step 200 | Loss 2.122

<start> his his his his his his his his hise his his his his his his his his his his hise his his hise his his his his his his hise his his his his his his cofe his his his his his his hise his his his hise <end>

Step 210 | Loss 2.096
Step 220 | Loss 2.060
Step 230 | Loss 2.044
Step 240 | Loss 2.019
Step 25

Step 1910 | Loss 0.620
Step 1920 | Loss 0.615
Step 1930 | Loss 0.611
Step 1940 | Loss 0.604
Step 1950 | Loss 0.612
Step 1960 | Loss 0.623
Step 1970 | Loss 0.612
Step 1980 | Loss 0.595
Step 1990 | Loss 0.622
Step 2000 | Loss 0.642

<start>the could not live with the Shtcherbatskys were old, noble Moscow
families, and had always been on intimate and friendly terms. This
intimacy had grown still closer during Levin's student days. He had<end>

Step 2010 | Loss 0.634
Step 2020 | Loss 0.613
Step 2030 | Loss 0.610
Step 2040 | Loss 0.608
Step 2050 | Loss 0.618
Step 2060 | Loss 0.599
Step 2070 | Loss 0.603
Step 2080 | Loss 0.599
Step 2090 | Loss 0.600
Step 2100 | Loss 0.581

<start>in the professor, and he went
back to his argument. "No," he said; "but said Sergey Ivanovitch, with his habitual
clearness, precision of expression, and elegance of phrase. "I cannnnot in
the could n<end>

Step 2110 | Loss 0.581
Step 2120 | Loss 0.586
Step 2130 | Loss 0.578
Step 2140 | Loss 0.585
Step 2150 | Loss 0