Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)

![title](img/dilated_cnn.jpg)

In [1]:
"""
pip3 install tensor2tensor
"""
import tensorflow as tf
import numpy as np
from tensor2tensor.utils import beam_search

In [2]:
# Hyper-parameters / run-time settings read as a module-level global by the
# cells below. A later cell also stores 'char2idx', 'vocab_size' and
# 'idx2char' into this same dict after the corpus is parsed.
params = dict(
    batch_size=128,       # sequences per training batch
    text_iter_step=25,    # stride (in characters) between consecutive batch windows
    seq_len=200,          # characters per sequence; also the beam-search decode length
    kernel_sz=5,          # width of every dilated convolution
    hidden_dim=128,       # embedding size == conv channel count
    n_hidden_layer=4,     # residual conv blocks, dilations 1, 2, 4, 8
    dropout_rate=0.1,     # dropout inside each conv block while training
    display_step=10,      # print the loss every N steps
    generate_step=100,    # print a beam-search sample every N steps
    beam_size=5,          # beam width for generation
)

In [3]:
def parse_text(file_path):
    """Read a text file and encode it as a sequence of character ids.

    Ids 0, 1 and 2 are reserved for '<pad>', '<start>' and '<end>'; every
    distinct character of the file gets an id from 3 upward, assigned in
    sorted character order so the mapping is the same on every run.

    Args:
        file_path: path of the text file to read.

    Returns:
        (ints, char2idx): ``ints`` is a 1-D numpy array holding the id of each
        character of the file in order; ``char2idx`` maps every character and
        the three special tokens to its id.
    """
    with open(file_path) as f:
        text = f.read()

    # sorted(): iterating a bare set of str depends on the per-process hash
    # seed, which would give a different char -> id mapping on each run.
    char2idx = {c: i + 3 for i, c in enumerate(sorted(set(text)))}
    char2idx['<pad>'] = 0
    char2idx['<start>'] = 1
    char2idx['<end>'] = 2

    ints = np.array([char2idx[c] for c in text])
    return ints, char2idx

def next_batch(ints, batch_size=None, seq_len=None, text_iter_step=None):
    """Yield overlapping ``(batch_size, seq_len)`` windows over ``ints``.

    Each batch is a contiguous run of ``batch_size * seq_len`` ids reshaped
    row-major, so row ``r`` holds ids ``[start + r*seq_len, start + (r+1)*seq_len)``.
    Consecutive batches start ``text_iter_step`` ids apart. Nothing is yielded
    if ``ints`` is shorter than one full window.

    Args:
        ints: 1-D numpy array of token ids.
        batch_size, seq_len, text_iter_step: optional overrides; ``None``
            (the default) reads the value from the global ``params`` dict.

    Yields:
        numpy arrays of shape ``(batch_size, seq_len)``.
    """
    if batch_size is None:
        batch_size = params['batch_size']
    if seq_len is None:
        seq_len = params['seq_len']
    if text_iter_step is None:
        text_iter_step = params['text_iter_step']

    len_win = seq_len * batch_size
    # "+ 1" because range() excludes its stop value: the window starting at
    # len(ints) - len_win ends exactly at the last id and is still complete.
    for i in range(0, len(ints) - len_win + 1, text_iter_step):
        clip = ints[i: i + len_win]
        yield clip.reshape([batch_size, seq_len])
        
def input_fn(ints):
    """Build a TF1 input pipeline over the character ids in ``ints``.

    Wraps ``next_batch(ints)`` in a ``tf.data.Dataset`` and returns the
    ``get_next()`` tensor of a one-shot iterator: int32 batches of static
    shape ``(?, params['seq_len'])``. The training loop stops when evaluating
    this tensor raises ``tf.errors.OutOfRangeError``.
    """
    def gen():
        return next_batch(ints)

    batch_shape = tf.TensorShape([None, params['seq_len']])
    dataset = tf.data.Dataset.from_generator(gen, tf.int32, batch_shape)
    one_shot = dataset.make_one_shot_iterator()
    return one_shot.get_next()

In [4]:
def start_sent(x):
    """Prepend a '<start>' id column to a rank-2 batch of int32 ids.

    The result has one more column than ``x``; column 0 holds
    ``params['char2idx']['<start>']`` for every row.
    """
    start_id = params['char2idx']['<start>']
    n_rows = tf.shape(x)[0]
    start_col = tf.fill([n_rows, 1], start_id)
    return tf.concat([start_col, x], axis=1)

def end_sent(x):
    """Append an '<end>' id column to a rank-2 batch of int32 ids.

    The result has one more column than ``x``; the last column holds
    ``params['char2idx']['<end>']`` for every row.
    """
    end_id = params['char2idx']['<end>']
    n_rows = tf.shape(x)[0]
    end_col = tf.fill([n_rows, 1], end_id)
    return tf.concat([x, end_col], axis=1)

def embed_seq(x, vocab_sz, embed_dim, name, zero_pad=True):
    """Look up embeddings for the ids in ``x``.

    Creates (or reuses, per the enclosing variable scope) a
    ``[vocab_sz, embed_dim]`` variable called ``name``. With ``zero_pad`` the
    row for id 0 ('<pad>') is swapped for a constant zero vector, so id 0
    always embeds to zeros and row 0 of the variable is never used.

    Returns:
        Tensor of shape ``x.shape + [embed_dim]``.
    """
    table = tf.get_variable(name, [vocab_sz, embed_dim])
    if zero_pad:
        pad_row = tf.zeros([1, embed_dim])
        table = tf.concat([pad_row, table[1:, :]], 0)
    return tf.nn.embedding_lookup(table, x)

In [5]:
def position_encoding(inputs):
    """Sinusoidal position encoding with the same shape as ``inputs``.

    ``inputs`` must be rank 3 ``(batch, time, dim)`` with a static last
    dimension. For position ``p`` and frequency index ``j = 0, 1, ...`` the
    first block of features is ``sin(p / 10000**(2j/dim))`` and the second
    block the matching ``cos`` (concatenated, not interleaved).

    Returns:
        float32 tensor of shape ``(batch, time, dim)``.
    """
    repr_dim = inputs.get_shape()[-1].value
    pos = tf.reshape(tf.range(0.0, tf.to_float(tf.shape(inputs)[1]), dtype=tf.float32), [-1, 1])
    i = np.arange(0, repr_dim, 2, np.float32)
    denom = np.reshape(np.power(10000.0, i / repr_dim), [1, -1])
    enc = tf.concat([tf.sin(pos / denom), tf.cos(pos / denom)], 1)
    # An odd dim gives ceil(dim/2) frequencies, i.e. dim + 1 columns; trim so
    # the encoding can be added to ``inputs``. No-op when dim is even.
    enc = enc[:, :repr_dim]
    enc = tf.expand_dims(enc, 0)
    return tf.tile(enc, [tf.shape(inputs)[0], 1, 1])


def layer_norm(inputs, epsilon=1e-8):
    """Normalise ``inputs`` over its last axis, then apply a learned affine map.

    Creates two variables in the current variable scope, 'gamma' (initialised
    to ones) and 'beta' (initialised to zeros), each shaped like the last
    dimension of ``inputs``; calling this twice in one scope without reuse
    would therefore collide on those names.

    Args:
        inputs: float32 tensor whose last dimension is statically known.
        epsilon: added to the variance before the square root for stability.
    """
    feature_shape = inputs.get_shape()[-1:]
    mu, var = tf.nn.moments(inputs, [-1], keep_dims=True)
    centred = inputs - mu
    normalized = centred / tf.sqrt(var + epsilon)
    gamma = tf.get_variable('gamma', feature_shape, tf.float32, tf.ones_initializer())
    beta = tf.get_variable('beta', feature_shape, tf.float32, tf.zeros_initializer())
    return normalized * gamma + beta


def cnn_block(x, dilation_rate, pad_sz, is_training):
    """One causal dilated-convolution block: layer norm -> dropout -> conv -> ReLU.

    Args:
        x: float32 tensor ``(batch, time, channels)`` with a static channel dim.
        dilation_rate: dilation of the 1-D convolution.
        pad_sz: zero steps prepended on the time axis. ``forward`` passes
            ``(kernel_sz - 1) * dilation_rate`` so that output step ``t`` only
            depends on input steps ``t - pad_sz .. t`` (causal).
        is_training: enables dropout when True.

    Returns:
        Tensor ``(batch, time, params['hidden_dim'])`` when ``pad_sz`` is
        ``(kernel_sz - 1) * dilation_rate``.
    """
    x = layer_norm(x)
    x = tf.layers.dropout(x, params['dropout_rate'], training=is_training)
    # Pad the time axis on the left only: the 'valid' convolution then emits
    # exactly one output per input step. (Padding both sides and slicing
    # [:, :-pad_sz] gives the same values for pad_sz > 0, but returns an
    # empty tensor when pad_sz == 0, e.g. kernel_sz == 1.)
    x = tf.pad(x, [[0, 0], [pad_sz, 0], [0, 0]])
    x = tf.layers.conv1d(inputs=x,
                         filters=params['hidden_dim'],
                         kernel_size=params['kernel_sz'],
                         dilation_rate=dilation_rate)
    return tf.nn.relu(x)


def forward(inputs, reuse, is_training):
    """Run the dilated causal CNN language model and return per-step logits.

    Embeds ``inputs`` (zero vector for '<pad>') and adds the sinusoidal
    position encoding. It then applies ``params['n_hidden_layer']`` residual
    conv blocks with dilation 1, 2, 4, ... and projects onto the vocabulary.
    All variables live under the 'model' scope. Pass ``reuse=True`` to share
    them with an earlier call, as the beam-search decoder does.

    Args:
        inputs: int32 tensor of token ids ``(batch, time)``.
        reuse: whether to reuse the variables created by a previous call.
        is_training: enables dropout inside the conv blocks.

    Returns:
        Logits of shape ``(batch, time, params['vocab_size'])``.
    """
    with tf.variable_scope('model', reuse=reuse):
        h = embed_seq(inputs, params['vocab_size'], params['hidden_dim'], 'word_embedding')
        h = h + position_encoding(h)

        for layer_idx in range(params['n_hidden_layer']):
            rate = 2 ** layer_idx
            causal_pad = (params['kernel_sz'] - 1) * rate
            with tf.variable_scope('block_%d' % layer_idx, reuse=reuse):
                # Residual connection around each conv block.
                h = h + cnn_block(h, rate, causal_pad, is_training)

        logits = tf.layers.dense(h, params['vocab_size'])
    return logits

In [6]:
def beam_search_decoding():
    """Build the graph that generates one character sequence with beam search.

    Decoding starts from a single '<start>' id. It runs for up to
    ``params['seq_len']`` steps with beam width ``params['beam_size']``,
    alpha = 0.0, and '<end>' as the EOS id. The model variables are reused,
    so ``forward`` must already have been called once with ``reuse=False``.

    Returns:
        1-D int32 tensor: ids ``[0, 0, :]`` of the decoder output
        (batch element 0, beam 0).
    """
    start_id = params['char2idx']['<start>']
    end_id = params['char2idx']['<end>']
    initial_ids = tf.constant(start_id, tf.int32, [1])

    def symbols_to_logits(ids):
        # Re-run the whole prefix through the model and keep only the logits
        # of its last position.
        step_logits = forward(ids, reuse=True, is_training=False)
        return step_logits[:, tf.shape(ids)[1] - 1, :]

    decoded, _ = beam_search.beam_search(
        symbols_to_logits,
        initial_ids,
        params['beam_size'],
        params['seq_len'],
        params['vocab_size'],
        0.0,
        eos_id=end_id)

    return decoded[0, 0, :]

In [None]:
# Build the training and generation graph. Order matters: the training
# forward pass creates the model variables (reuse=False) before the
# beam-search decoder reuses them (reuse=True).
DATA_PATH = '../temp/anna.txt'  # plain-text training corpus

ints, params['char2idx'] = parse_text(DATA_PATH)
params['vocab_size'] = len(params['char2idx'])
params['idx2char'] = {i: c for c, i in params['char2idx'].items()}
print('Vocabulary size:', params['vocab_size'])

X = input_fn(ints)
# Teacher forcing: the model reads '<start>' + X and must predict X + '<end>'.
logits = forward(start_sent(X), reuse=False, is_training=True)

ops = {}
ops['global_step'] = tf.Variable(0, trainable=False)

targets = end_sent(X)
# sequence_loss averages over both batch and time by default
# (average_across_timesteps / average_across_batch), so it already returns the
# scalar mean cross-entropy; an extra reduce_mean would be a no-op.
ops['loss'] = tf.contrib.seq2seq.sequence_loss(
    logits=logits,
    targets=targets,
    weights=tf.ones_like(targets, dtype=tf.float32))

ops['train'] = tf.train.AdamOptimizer().minimize(ops['loss'], global_step=ops['global_step'])

ops['generate'] = beam_search_decoding()

Vocabulary size: 86


In [None]:
# Train until the one-shot input iterator is exhausted. The loss is printed
# every `display_step` steps, and a beam-search sample every `generate_step` steps.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
while True:
    try:
        _, step, loss = sess.run([ops['train'], ops['global_step'], ops['loss']])
    except tf.errors.OutOfRangeError:
        # Raised once next_batch has no more windows.
        break
    if step % params['display_step'] == 0 or step == 1:
        print("Step %d | Loss %.3f" % (step, loss))
    if step % params['generate_step'] == 0 and step > 1:
        # Separate name so the training corpus `ints` is not overwritten;
        # re-running input_fn(ints) afterwards would otherwise read the sample.
        sample_ids = sess.run(ops['generate'])
        print('\n' + ''.join(params['idx2char'][i] for i in sample_ids) + '\n')

Step 1 | Loss 7.378
Step 10 | Loss 3.380
Step 20 | Loss 3.177
Step 30 | Loss 3.090
Step 40 | Loss 3.022
Step 50 | Loss 2.962
Step 60 | Loss 2.899
Step 70 | Loss 2.837
Step 80 | Loss 2.753
Step 90 | Loss 2.682
Step 100 | Loss 2.607

<start> he the the the the the the the the ther the the the the thes the the the the the thin the the the the ther the the the thes the the the thin the the the the the the the ther the the the thes the the<end>

Step 110 | Loss 2.557
Step 120 | Loss 2.505
Step 130 | Loss 2.457
Step 140 | Loss 2.417
Step 150 | Loss 2.384
Step 160 | Loss 2.348
Step 170 | Loss 2.327
Step 180 | Loss 2.297
Step 190 | Loss 2.268
Step 200 | Loss 2.240

<start> and her and and her and and her and and her and he ther and an thing he thing of he ther and on the thing he thing he thing on the sher and he thing of the she ther and and he ther and her and he the

Step 210 | Loss 2.210
Step 220 | Loss 2.188
Step 230 | Loss 2.166
Step 240 | Loss 2.148
Step 250 | Loss 2.121
Step 260 | Loss

Step 1910 | Loss 0.318
Step 1920 | Loss 0.313
Step 1930 | Loss 0.316
Step 1940 | Loss 0.323
Step 1950 | Loss 0.313
Step 1960 | Loss 0.312
Step 1970 | Loss 0.309
Step 1980 | Loss 0.309
Step 1990 | Loss 0.305
Step 2000 | Loss 0.298

<start> that he had no sort of proof that he would be rejected. And
he had now come to Moscow with a firm determination to make an offer,
and get married if he were accepted. Or ... he could not conceive wha

Step 2010 | Loss 0.300
Step 2020 | Loss 0.306
Step 2030 | Loss 0.311
Step 2040 | Loss 0.313
Step 2050 | Loss 0.301
Step 2060 | Loss 0.301
Step 2070 | Loss 0.305
Step 2080 | Loss 0.309
Step 2090 | Loss 0.304
Step 2100 | Loss 0.297

<start> that he had no sort of proof that he would be rejected. And
he had now come to Moscow with a firm determination to make an offer,
and get married if he were accepted. Or ... he could not conceive wha

Step 2110 | Loss 0.310
Step 2120 | Loss 0.302
