Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)

![alt text](vae.png)

In [1]:
import tensorflow as tf
import numpy as np
from modified import ModifiedBasicDecoder, ModifiedBeamSearchDecoder

In [2]:
PARAMS = {
    'max_len': 15,
    'vocab_size': 20000,
    'embed_dims': 128,
    'rnn_size': 128,
    'latent_size': 16,
    'beam_width': 5,
    'clip_norm': 5.0,
    'anneal_max': 1.0,
    'anneal_bias': 6000,
    'batch_size': 128,
    'n_epochs': 30,
    'word_dropout_rate': 0.8,
}

In [3]:
def build_vocab(index_from=4):
    PARAMS['word2idx'] = tf.keras.datasets.imdb.get_word_index()
    PARAMS['word2idx'] = {k: (v + index_from) for k, v in PARAMS['word2idx'].items()}
    PARAMS['word2idx']['<pad>'] = 0
    PARAMS['word2idx']['<start>'] = 1
    PARAMS['word2idx']['<unk>'] = 2
    PARAMS['word2idx']['<end>'] = 3
    PARAMS['idx2word'] = {i: w for w, i in PARAMS['word2idx'].items()}

    
def load_data(index_from=4):
    (X_train, _), (X_test, _) = tf.contrib.keras.datasets\
        .imdb.load_data(num_words=PARAMS['vocab_size'], index_from=index_from)
    return (X_train, X_test)

In [4]:
def kl_w_fn(global_step):
    return PARAMS['anneal_max'] * tf.sigmoid(
        (10 / PARAMS['anneal_bias']) * (
        tf.to_float(global_step) - tf.constant(PARAMS['anneal_bias'] / 2)))

def clip_grads(loss):
    variables = tf.trainable_variables()
    grads = tf.gradients(loss, variables)
    clipped_grads, _ = tf.clip_by_global_norm(grads, PARAMS['clip_norm'])
    return zip(clipped_grads, variables)

def rnn_cell():
    return tf.nn.rnn_cell.GRUCell(PARAMS['rnn_size'],
                                  kernel_initializer=tf.orthogonal_initializer())

In [5]:
def forward(inputs, labels, mode):
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    enc_seq_len = tf.count_nonzero(inputs, 1, dtype=tf.int32)
    batch_sz = tf.shape(inputs)[0]
    
    with tf.variable_scope('Encoder'):
        embedding = tf.get_variable('lookup_table', [PARAMS['vocab_size'], PARAMS['embed_dims']])
        x = tf.nn.embedding_lookup(embedding, inputs)
        
        _, enc_state = tf.nn.dynamic_rnn(rnn_cell(), x, enc_seq_len, dtype=tf.float32)
        
        z_mean = tf.layers.dense(enc_state, PARAMS['latent_size'])
        z_var = tf.layers.dense(enc_state, PARAMS['latent_size'])
        
    posterior = tf.contrib.distributions.MultivariateNormalDiag(z_mean, z_var)
    prior = tf.contrib.distributions.MultivariateNormalDiag(tf.zeros_like(z_mean),
                                                            tf.ones_like(z_var))
    z = posterior.sample()
        
    with tf.variable_scope('Decoder'):
        init_state = tf.layers.dense(z, PARAMS['rnn_size'], tf.nn.elu)
        output_proj = tf.layers.Dense(PARAMS['vocab_size'])
        dec_cell = rnn_cell()
        
        if is_training:
            dec_seq_len = tf.count_nonzero(labels['dec_out'], 1, dtype=tf.int32)
            
            helper = tf.contrib.seq2seq.TrainingHelper(
                inputs = tf.nn.embedding_lookup(embedding, labels['dec_inp']),
                sequence_length = dec_seq_len)
            decoder = ModifiedBasicDecoder(
                cell = dec_cell,
                helper = helper,
                initial_state = init_state,
                output_layer = output_proj,
                concat_z = z)
            decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = decoder,
                maximum_iterations = tf.reduce_max(dec_seq_len))
            logits = decoder_output.rnn_output
            
            return logits, posterior, prior
        else:
            tiled_z = tf.tile(tf.expand_dims(z, 1), [1, PARAMS['beam_width'], 1])
            
            decoder = ModifiedBeamSearchDecoder(
                cell = dec_cell,
                embedding = embedding,
                start_tokens = tf.tile(tf.constant([PARAMS['word2idx']['<start>']], tf.int32),
                                       [batch_sz]),
                end_token = PARAMS['word2idx']['<end>'],
                initial_state = tf.contrib.seq2seq.tile_batch(init_state, PARAMS['beam_width']),
                beam_width = PARAMS['beam_width'],
                output_layer = output_proj,
                concat_z = tiled_z)
            decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = decoder)
            
            return decoder_output.predicted_ids[:, :, 0]

In [6]:
def model_fn(features, labels, mode):
    logits_or_ids = forward(features, labels, mode)        
    
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=logits_or_ids)
        
    if mode == tf.estimator.ModeKeys.TRAIN:
        logits, posterior, prior = logits_or_ids
        
        out_dist = tf.distributions.Categorical(logits)
        
        global_step = tf.train.get_global_step()
        
        nll_loss = - tf.reduce_sum(out_dist.log_prob(labels['dec_out']))
        
        kl_w = kl_w_fn(global_step)
        
        kl_loss = tf.reduce_sum(tf.distributions.kl_divergence(posterior, prior))
        
        loss_op = nll_loss + kl_w * kl_loss
        
        train_op = tf.train.AdamOptimizer().apply_gradients(clip_grads(loss_op),
                                                            global_step = global_step)
        
        lth = tf.train.LoggingTensorHook(
            {'nll_loss': nll_loss, 'kl_w': kl_w, 'kl_loss': kl_loss}, every_n_iter=100)
        
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss_op, train_op=train_op, training_hooks=[lth])

In [7]:
def inf_inp(test_strs):
    x = [[PARAMS['word2idx'].get(w, 2) for w in s.split()] for s in test_strs]
    x = tf.keras.preprocessing.sequence.pad_sequences(
        x, PARAMS['max_len'], truncating='post', padding='post')
    return x


def demo(test_strs, pred_ids):
    for s, pred in zip(test_strs, pred_ids):
        print('\nOriginal:', s)
        print('Reconstr:', ' '.join([PARAMS['idx2word'].get(idx, '<unk>') for idx in pred]))

        
def word_dropout(x):
    is_dropped = np.random.binomial(1, PARAMS['word_dropout_rate'], x.shape)
    fn = np.vectorize(lambda x, k: PARAMS['word2idx']['<unk>'] if (
                      k and (x not in range(4))) else x)
    return fn(x, is_dropped)
        
        
def main():
    build_vocab()
    
    X = np.concatenate(load_data())
    X = np.concatenate((tf.keras.preprocessing.sequence.pad_sequences(
                            X, PARAMS['max_len'], truncating='post', padding='post'),
                        tf.keras.preprocessing.sequence.pad_sequences(
                            X, PARAMS['max_len'], truncating='pre', padding='post')))

    enc_inp = X[:, 1:]
    dec_inp = X[:]
    dec_out = np.concatenate([X[:, 1:], np.full([X.shape[0], 1], PARAMS['word2idx']['<end>'])], 1)

    test_strs = ['i love this film and i think it is one of the best films',
                 'this movie is a waste of time and there is no point to watch it']

    estimator = tf.estimator.Estimator(model_fn)

    for i in range(PARAMS['n_epochs']):
        print('Epoch %d/%d'%(i+1, PARAMS['n_epochs']))
        print()
        
        estimator.train(tf.estimator.inputs.numpy_input_fn(
            x = enc_inp,
            y = {'dec_inp': word_dropout(dec_inp), 'dec_out': dec_out},
            batch_size = PARAMS['batch_size'],
            shuffle = True))

        pred_ids = list(estimator.predict(tf.estimator.inputs.numpy_input_fn(
            x = inf_inp(test_strs),
            shuffle = False)))

        demo(test_strs, pred_ids)
        print()

In [8]:
main()

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x123f16eb8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
Epoch 1/30

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 1 into /var

INFO:tensorflow:global_step/sec: 0.71162
INFO:tensorflow:loss = 10165.197, step = 2265 (140.524 sec)
INFO:tensorflow:nll_loss = 9637.449, kl_w = 0.22676536, kl_loss = 2327.2856 (140.524 sec)
INFO:tensorflow:Saving checkpoints for 2346 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt.
INFO:tensorflow:Loss for final step: 2470.2595.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt-2346
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

Original: i love this film and i think it is one of the best films
Reconstr: i was this movie i have seen this movie of the film of the <end>

Original: this movie is a waste of time and there is no point to watch it
Reconstr: this is a lot of the movie to not the time it it it <end>

Epoch 4/30

INFO:tensorflow:Calling mode

INFO:tensorflow:loss = 11141.858, step = 4411 (138.940 sec)
INFO:tensorflow:nll_loss = 10165.846, kl_w = 0.9129343, kl_loss = 1069.0938 (138.940 sec)
INFO:tensorflow:global_step/sec: 0.777566
INFO:tensorflow:loss = 10777.204, step = 4511 (128.606 sec)
INFO:tensorflow:nll_loss = 9790.451, kl_w = 0.92530197, kl_loss = 1066.4117 (128.606 sec)
INFO:tensorflow:global_step/sec: 0.756028
INFO:tensorflow:loss = 11016.194, step = 4611 (132.270 sec)
INFO:tensorflow:nll_loss = 10036.266, kl_w = 0.93603605, kl_loss = 1046.8916 (132.271 sec)
INFO:tensorflow:Saving checkpoints for 4692 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt.
INFO:tensorflow:Loss for final step: 2645.3716.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt-4692
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done runnin

INFO:tensorflow:loss = 10160.389, step = 6657 (143.167 sec)
INFO:tensorflow:nll_loss = 9010.473, kl_w = 0.99774724, kl_loss = 1152.5125 (143.167 sec)
INFO:tensorflow:Saving checkpoints for 6690 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt.
INFO:tensorflow:global_step/sec: 0.723974
INFO:tensorflow:loss = 10228.971, step = 6757 (138.127 sec)
INFO:tensorflow:nll_loss = 9054.753, kl_w = 0.99809235, kl_loss = 1176.4619 (138.126 sec)
INFO:tensorflow:global_step/sec: 0.703937
INFO:tensorflow:loss = 10054.073, step = 6857 (142.058 sec)
INFO:tensorflow:nll_loss = 8881.458, kl_w = 0.99838483, kl_loss = 1174.5118 (142.058 sec)
INFO:tensorflow:global_step/sec: 0.699849
INFO:tensorflow:loss = 10893.012, step = 6957 (142.888 sec)
INFO:tensorflow:nll_loss = 9842.256, kl_w = 0.9986324, kl_loss = 1052.1951 (142.888 sec)
INFO:tensorflow:Saving checkpoints for 7038 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt.
INFO:tensorflow:Loss for final 

INFO:tensorflow:nll_loss = 9509.91, kl_w = 0.9999368, kl_loss = 1175.0089 (134.527 sec)
INFO:tensorflow:global_step/sec: 0.795414
INFO:tensorflow:loss = 10582.662, step = 8903 (125.720 sec)
INFO:tensorflow:nll_loss = 9437.434, kl_w = 0.9999466, kl_loss = 1145.2896 (125.719 sec)
INFO:tensorflow:global_step/sec: 0.759166
INFO:tensorflow:loss = 10479.792, step = 9003 (131.724 sec)
INFO:tensorflow:nll_loss = 9346.637, kl_w = 0.9999547, kl_loss = 1133.2068 (131.724 sec)
INFO:tensorflow:Saving checkpoints for 9055 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt.
INFO:tensorflow:global_step/sec: 0.725741
INFO:tensorflow:loss = 10373.581, step = 9103 (137.790 sec)
INFO:tensorflow:nll_loss = 9161.928, kl_w = 0.99996173, kl_loss = 1211.6995 (137.790 sec)
INFO:tensorflow:global_step/sec: 0.705545
INFO:tensorflow:loss = 10183.721, step = 9203 (141.735 sec)
INFO:tensorflow:nll_loss = 8972.045, kl_w = 0.9999676, kl_loss = 1211.7156 (141.735 sec)
INFO:tensorflow:global_st

INFO:tensorflow:nll_loss = 9295.404, kl_w = 0.9999982, kl_loss = 1285.5261
INFO:tensorflow:global_step/sec: 0.808415
INFO:tensorflow:loss = 10628.579, step = 11049 (123.700 sec)
INFO:tensorflow:nll_loss = 9387.16, kl_w = 0.99999845, kl_loss = 1241.4207 (123.700 sec)
INFO:tensorflow:global_step/sec: 0.798896
INFO:tensorflow:loss = 10555.951, step = 11149 (125.173 sec)
INFO:tensorflow:nll_loss = 9339.686, kl_w = 0.9999987, kl_loss = 1216.2676 (125.173 sec)
INFO:tensorflow:global_step/sec: 0.796561
INFO:tensorflow:loss = 10566.429, step = 11249 (125.540 sec)
INFO:tensorflow:nll_loss = 9316.506, kl_w = 0.9999989, kl_loss = 1249.924 (125.539 sec)
INFO:tensorflow:global_step/sec: 0.793878
INFO:tensorflow:loss = 9876.546, step = 11349 (125.964 sec)
INFO:tensorflow:nll_loss = 8576.07, kl_w = 0.99999905, kl_loss = 1300.4768 (125.964 sec)
INFO:tensorflow:Saving checkpoints for 11427 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt.
INFO:tensorflow:global_step/sec: 0.7

INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 13295 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt.
INFO:tensorflow:loss = 10042.709, step = 13295
INFO:tensorflow:nll_loss = 8669.055, kl_w = 1.0, kl_loss = 1373.6543
INFO:tensorflow:global_step/sec: 0.813114
INFO:tensorflow:loss = 9901.697, step = 13395 (122.985 sec)
INFO:tensorflow:nll_loss = 8574.932, kl_w = 1.0, kl_loss = 1326.7659 (122.985 sec)
INFO:tensorflow:global_step/sec: 0.801922
INFO:tensorflow:loss = 10293.38, step = 13495 (124.700 sec)
INFO:tensorflow:nll_loss = 8996.635, kl_w = 1.0, kl_loss = 1296.7451 (124.700 sec)
INFO:tensorflow:global_step/sec: 0.802225
INFO:tensorflow:loss = 10361.375, step = 13595 (124.653 sec)
INFO:tensorflow:nll_loss = 9046.932, kl_w = 1.0, kl_loss = 1314.4438 (124.653 sec)
INFO:tensorflow:global_step/sec: 0.801316
INFO:tensorflow:loss = 10293.802, step = 13695 (124.795 sec)
INFO:tensorflow:nll

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt-15640
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 15641 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt.
INFO:tensorflow:loss = 10558.953, step = 15641
INFO:tensorflow:nll_loss = 9237.1455, kl_w = 1.0, kl_loss = 1321.8071
INFO:tensorflow:global_step/sec: 0.812453
INFO:tensorflow:loss = 10064.067, step = 15741 (123.085 sec)
INFO:tensorflow:nll_loss = 8726.362, kl_w = 1.0, kl_loss = 1337.7051 (123.085 sec)
INFO:tensorflow:global_step/sec: 0.80349
INFO:tensorflow:loss = 10009.711, step = 15841 (124.457 sec)
INFO:tensorflow:nll_loss = 8604.828, kl_w = 1.0, kl_loss = 1404.883 (124.457 sec)
INFO:tensorflow:global_step/sec: 0.801593
INFO:tensorflow:loss = 9671.873, step = 15941 (124.752 sec)
INFO

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt-17986
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 17987 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt.
INFO:tensorflow:loss = 9669.643, step = 17987
INFO:tensorflow:nll_loss = 8199.959, kl_w = 1.0, kl_loss = 1469.6836
INFO:tensorflow:global_step/sec: 0.736541
INFO:tensorflow:loss = 9624.725, step = 18087 (135.771 sec)
INFO:tensorflow:nll_loss = 8155.2695, kl_w = 1.0, kl_loss = 1469.4547 (135.771 sec)
INFO:tensorflow:global_step/sec: 0.697171
INFO:tensorflow:loss = 9413.46, step = 18187 (143.437 sec)
INFO:tensorflow:nll_loss = 7940.933, kl_w = 1.0, kl_loss = 1472.5267 (143.437 sec)
INFO:tensorflow:global_step/sec: 0.

INFO:tensorflow:Done running local_init_op.

Original: i love this film and i think it is one of the best films
Reconstr: i loved this movie and i first saw it on the top of the <end>

Original: this movie is a waste of time and there is no point to watch it
Reconstr: this movie is one of the best i've seen in all time to the <end>

Epoch 27/30

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt-20332
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 20333 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt.
INFO:tensorflow:loss = 9808.472, step = 20333
INFO:tensorflow:nll_loss = 8283.312, kl_w = 1.0, kl_loss = 1525.1604
INFO:tensorflow:global_step/sec: 0.620099
INFO:tensorflow

INFO:tensorflow:Loss for final step: 2530.2808.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt-22678
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

Original: i love this film and i think it is one of the best films
Reconstr: i loved this movie but i thought it was a lot of the worst <end>

Original: this movie is a waste of time and there is no point to watch it
Reconstr: this is a complete waste of time but it's a hard to watch it <end>

Epoch 30/30

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpqgwokls5/model.ckpt-22678
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:D