Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)

![alt text](vae.png)

Related Papers
* [Generating Sentences from a Continuous Space](https://arxiv.org/abs/1511.06349)

In [1]:
import tensorflow as tf
import numpy as np
from modified import ModifiedBasicDecoder, ModifiedBeamSearchDecoder

In [2]:
PARAMS = {
    'max_len': 15,
    'embed_dims': 128,
    'rnn_size': 128,
    'latent_size': 16,
    'beam_width': 5,
    'num_sampled': 1000,
    'clip_norm': 5.0,
    'anneal_max': 1.0,
    'anneal_bias': 6000,
    'batch_size': 128,
    'n_epochs': 30,
    'word_dropout_rate': 0.8,
}

In [3]:
def build_vocab(index_from=4):
    PARAMS['word2idx'] = tf.keras.datasets.imdb.get_word_index()
    PARAMS['word2idx'] = {k: (v + index_from) for k, v in PARAMS['word2idx'].items()}
    PARAMS['word2idx']['<pad>'] = 0
    PARAMS['word2idx']['<start>'] = 1
    PARAMS['word2idx']['<unk>'] = 2
    PARAMS['word2idx']['<end>'] = 3
    PARAMS['idx2word'] = {i: w for w, i in PARAMS['word2idx'].items()}
    PARAMS['vocab_size'] = len(PARAMS['word2idx'])

    
def load_data(index_from=4):
    (X_train, _), (X_test, _) = tf.contrib.keras.datasets.imdb.load_data(
        num_words=None, index_from=index_from)
    return (X_train, X_test)

In [4]:
def reparam_trick(z_mean, z_logvar):
    gaussian = tf.truncated_normal(tf.shape(z_logvar))
    z = z_mean + tf.exp(0.5 * z_logvar) * gaussian
    return z


def kl_w_fn(global_step):
    return PARAMS['anneal_max'] * tf.sigmoid((10 / PARAMS['anneal_bias']) * (
        tf.to_float(global_step) - tf.constant(PARAMS['anneal_bias'] / 2)))


def kl_loss_fn(mean, logvar):
    loss = 0.5 * tf.reduce_sum(tf.exp(logvar) + tf.square(mean) - 1 - logvar)
    return loss / tf.to_float(tf.shape(mean)[0])

In [5]:
def clip_grads(loss):
    variables = tf.trainable_variables()
    grads = tf.gradients(loss, variables)
    clipped_grads, _ = tf.clip_by_global_norm(grads, PARAMS['clip_norm'])
    return zip(clipped_grads, variables)


def rnn_cell():
    return tf.nn.rnn_cell.GRUCell(PARAMS['rnn_size'],
                                  kernel_initializer=tf.orthogonal_initializer())


def forward(inputs, labels, mode):
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    enc_seq_len = tf.count_nonzero(inputs, 1, dtype=tf.int32)
    batch_sz = tf.shape(inputs)[0]
    
    with tf.variable_scope('Encoder'):
        embedding = tf.get_variable('lookup_table', [len(PARAMS['word2idx']), PARAMS['embed_dims']])
        x = tf.nn.embedding_lookup(embedding, inputs)
        
        _, enc_state = tf.nn.dynamic_rnn(rnn_cell(), x, enc_seq_len, dtype=tf.float32)
        
        z_mean = tf.layers.dense(enc_state, PARAMS['latent_size'])
        z_logvar = tf.layers.dense(enc_state, PARAMS['latent_size'])
        
    z = reparam_trick(z_mean, z_logvar)
        
    with tf.variable_scope('Decoder'):
        init_state = tf.layers.dense(z, PARAMS['rnn_size'], tf.nn.elu)
        output_proj = tf.layers.Dense(len(PARAMS['word2idx']), _scope='decoder/output_proj')
        dec_cell = rnn_cell()
        
        if is_training:
            dec_seq_len = tf.count_nonzero(labels['dec_out'], 1, dtype=tf.int32)
            
            helper = tf.contrib.seq2seq.TrainingHelper(
                inputs = tf.nn.embedding_lookup(embedding, labels['dec_inp']),
                sequence_length = dec_seq_len)
            decoder = ModifiedBasicDecoder(
                cell = dec_cell,
                helper = helper,
                initial_state = init_state,
                concat_z = z)
            decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = decoder,
                maximum_iterations = tf.reduce_max(dec_seq_len))
            rnn_output = decoder_output.rnn_output
            
            return rnn_output, output_proj(rnn_output), (z_mean, z_logvar)
        else:
            tiled_z = tf.tile(tf.expand_dims(z, 1), [1, PARAMS['beam_width'], 1])
            
            decoder = ModifiedBeamSearchDecoder(
                cell = dec_cell,
                embedding = embedding,
                start_tokens = tf.tile(tf.constant([PARAMS['word2idx']['<start>']], tf.int32),
                                       [batch_sz]),
                end_token = PARAMS['word2idx']['<end>'],
                initial_state = tf.contrib.seq2seq.tile_batch(init_state, PARAMS['beam_width']),
                beam_width = PARAMS['beam_width'],
                output_layer = output_proj,
                concat_z = tiled_z)
            decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = decoder)
            
            return decoder_output.predicted_ids[:, :, 0]

In [6]:
def model_fn(features, labels, mode):
    logits_or_ids = forward(features, labels, mode)        
    
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=logits_or_ids)
        
    if mode == tf.estimator.ModeKeys.TRAIN:
        rnn_output, logits, (z_mean, z_logvar) = logits_or_ids
        
        global_step = tf.train.get_global_step()
        
        with tf.variable_scope('Decoder/decoder/output_proj', reuse=True):
            _weights = tf.transpose(tf.get_variable('kernel'))
            _biases = tf.get_variable('bias')
        
        mask = tf.reshape(tf.to_float(tf.sign(labels['dec_out'])), [-1])
        
        nll_loss = tf.reduce_sum(mask * tf.nn.sampled_softmax_loss(
            weights = _weights,
            biases = _biases,
            labels = tf.reshape(labels['dec_out'], [-1, 1]),
            inputs = tf.reshape(rnn_output, [-1, PARAMS['rnn_size']]),
            num_sampled = PARAMS['num_sampled'],
            num_classes = PARAMS['vocab_size'],
        )) / tf.to_float(tf.shape(features)[0])
        
        kl_w = kl_w_fn(global_step)
        
        kl_loss = kl_loss_fn(z_mean, z_logvar)
        
        loss_op = nll_loss + kl_w * kl_loss
        
        train_op = tf.train.AdamOptimizer().apply_gradients(
            clip_grads(loss_op),
            global_step = global_step)
        
        lth = tf.train.LoggingTensorHook(
            {'nll_loss': nll_loss, 'kl_w': kl_w, 'kl_loss': kl_loss}, every_n_iter=100)
        
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss_op, train_op=train_op, training_hooks=[lth])

In [7]:
def inf_inp(test_strs):
    x = [[PARAMS['word2idx'].get(w, 2) for w in s.split()] for s in test_strs]
    x = tf.keras.preprocessing.sequence.pad_sequences(
        x, PARAMS['max_len'], truncating='post', padding='post')
    return x


def demo(test_strs, pred_ids):
    for s, pred in zip(test_strs, pred_ids):
        print('\nOriginal:', s)
        print('Reconstr:', ' '.join([PARAMS['idx2word'].get(idx, '<unk>') for idx in pred]))

        
def word_dropout(x):
    is_dropped = np.random.binomial(1, PARAMS['word_dropout_rate'], x.shape)
    fn = np.vectorize(lambda x, k: PARAMS['word2idx']['<unk>'] if (
                      k and (x not in range(4))) else x)
    return fn(x, is_dropped)
        
        
def main():
    build_vocab()
    
    X = np.concatenate(load_data())
    X = np.concatenate((tf.keras.preprocessing.sequence.pad_sequences(
                            X, PARAMS['max_len'], truncating='post', padding='post'),
                        tf.keras.preprocessing.sequence.pad_sequences(
                            X, PARAMS['max_len'], truncating='pre', padding='post')))

    enc_inp = X[:, 1:]
    dec_inp = X[:]
    dec_out = np.concatenate([X[:, 1:], np.full([X.shape[0], 1], PARAMS['word2idx']['<end>'])], 1)

    test_strs = ['i love this film and i think it is one of the best films',
                 'this movie is a waste of time and there is no point to watch it']

    estimator = tf.estimator.Estimator(model_fn)

    for _ in range(PARAMS['n_epochs']):
        estimator.train(tf.estimator.inputs.numpy_input_fn(
            x = enc_inp,
            y = {'dec_inp': word_dropout(dec_inp), 'dec_out': dec_out},
            batch_size = PARAMS['batch_size'],
            shuffle = True))

        pred_ids = list(estimator.predict(tf.estimator.inputs.numpy_input_fn(
            x = inf_inp(test_strs),
            shuffle = False)))

        demo(test_strs, pred_ids)
        print()

In [8]:
main()

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x121dd76d8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See @{tf.nn.softmax_cross_entropy_with_logits_v2}.

INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointS

INFO:tensorflow:global_step/sec: 2.89008
INFO:tensorflow:loss = 70.74055, step = 2265 (34.602 sec)
INFO:tensorflow:nll_loss = 67.16752, kl_w = 0.22676536, kl_loss = 15.756486 (34.602 sec)
INFO:tensorflow:Saving checkpoints for 2346 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a/model.ckpt.
INFO:tensorflow:Loss for final step: 67.699715.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a/model.ckpt-2346
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

Original: i love this film and i think it is one of the best films
Reconstr: it was this movie i have seen it is a lot of the time <end>

Original: this movie is a waste of time and there is no point to watch it
Reconstr: this movie is the worst movie and ever seen this movie ever seen it <end>

INFO:tensorflow:Calling model_fn

INFO:tensorflow:Saving checkpoints for 4692 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a/model.ckpt.
INFO:tensorflow:Loss for final step: 73.8559.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a/model.ckpt-4692
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

Original: i love this film and i think it is one of the best films
Reconstr: i watched this movie and i have seen it is one of the time <end>

Original: this movie is a waste of time and there is no point to watch it
Reconstr: this movie was one of the time it was a chance to watch it <end>

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8nj

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a/model.ckpt-7038
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

Original: i love this film and i think it is one of the best films
Reconstr: i love this movie when i have seen it is a lot of the <end>

Original: this movie is a waste of time and there is no point to watch it
Reconstr: this is one of the worst movie i've ever seen it was some and <end>

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a/model.ckpt-7038
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 7

INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

Original: i love this film and i think it is one of the best films
Reconstr: i saw this movie and i think it is a lot of the best <end> <end>

Original: this movie is a waste of time and there is no point to watch it
Reconstr: this movie is a great movie of it is a good and it was <end> <end>

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a/model.ckpt-9384
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 9385 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a/model.ckpt.
INFO:tensorflow:loss = 70.768326, step = 9385
INFO:tensorflow:nll_loss = 60.89624, kl_w = 0.99997604, kl_loss = 9.872324
INFO:tensorflow:global

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a/model.ckpt-11730
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 11731 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a/model.ckpt.
INFO:tensorflow:loss = 64.29591, step = 11731
INFO:tensorflow:nll_loss = 53.901, kl_w = 0.9999995, kl_loss = 10.394917
INFO:tensorflow:global_step/sec: 3.16656
INFO:tensorflow:loss = 66.59424, step = 11831 (31.581 sec)
INFO:tensorflow:nll_loss = 56.07374, kl_w = 0.99999964, kl_loss = 10.520504 (31.581 sec)
INFO:tensorflow:global_step/sec: 2.95834
INFO:tensorflow:loss = 68.798035, step = 11931 (33.803 sec)
INFO:tensorflow:nll_loss = 58.88025, kl_w = 0.99999964, kl_loss = 9.917789 (33.803 sec)
INFO:tensorflow:global_

INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a/model.ckpt-14076
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 14077 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp88gb9b9a/model.ckpt.
INFO:tensorflow:loss = 65.15292, step = 14077
INFO:tensorflow:nll_loss = 54.234844, kl_w = 1.0, kl_loss = 10.918081
INFO:tensorflow:global_step/sec: 2.83678
INFO:tensorflow:loss = 63.45938, step = 14177 (35.253 sec)
INFO:tensorflow:nll_loss = 52.544594, kl_w = 1.0, kl_loss = 10.914787 (35.253 sec)
INFO:tensorflow:global_step/sec: 2.84078
INFO:tensorflow:loss = 63.396935, step = 14277 (35.201 sec)
INFO:tensorflow:nll_loss = 52.534294, kl_w = 1.0, kl_loss = 10.86264 (35.201 sec)
INFO:tensorflow:global_step/sec: 3.08253
INFO:tensorflow:loss = 67.05235, step = 14377 (32.441 sec)
INFO:tensorflow:nll_loss = 56.17409, kl_w = 1.0, kl_loss = 10.87826 (32.441 sec)
INFO:tensor

INFO:tensorflow:loss = 66.3771, step = 16423
INFO:tensorflow:nll_loss = 55.477703, kl_w = 1.0, kl_loss = 10.899393
INFO:tensorflow:global_step/sec: 2.80556
INFO:tensorflow:loss = 66.11226, step = 16523 (35.645 sec)
INFO:tensorflow:nll_loss = 55.187656, kl_w = 1.0, kl_loss = 10.924601 (35.645 sec)
INFO:tensorflow:global_step/sec: 2.96816
INFO:tensorflow:loss = 64.67673, step = 16623 (33.691 sec)
INFO:tensorflow:nll_loss = 53.762253, kl_w = 1.0, kl_loss = 10.914475 (33.691 sec)
INFO:tensorflow:global_step/sec: 2.7908
INFO:tensorflow:loss = 60.311344, step = 16723 (35.832 sec)
INFO:tensorflow:nll_loss = 48.914627, kl_w = 1.0, kl_loss = 11.396717 (35.832 sec)
INFO:tensorflow:global_step/sec: 2.89528
INFO:tensorflow:loss = 62.41993, step = 16823 (34.539 sec)
INFO:tensorflow:nll_loss = 51.144005, kl_w = 1.0, kl_loss = 11.275923 (34.539 sec)
INFO:tensorflow:global_step/sec: 2.76824
INFO:tensorflow:loss = 60.796547, step = 16923 (36.124 sec)
INFO:tensorflow:nll_loss = 49.394997, kl_w = 1.0, kl

INFO:tensorflow:nll_loss = 52.75791, kl_w = 1.0, kl_loss = 11.353512 (32.101 sec)
INFO:tensorflow:global_step/sec: 3.06004
INFO:tensorflow:loss = 62.172626, step = 18969 (32.679 sec)
INFO:tensorflow:nll_loss = 50.66104, kl_w = 1.0, kl_loss = 11.511585 (32.679 sec)
INFO:tensorflow:global_step/sec: 2.38843
INFO:tensorflow:loss = 62.3083, step = 19069 (41.870 sec)
INFO:tensorflow:nll_loss = 51.093185, kl_w = 1.0, kl_loss = 11.215116 (41.871 sec)
INFO:tensorflow:global_step/sec: 2.35495
INFO:tensorflow:loss = 61.315468, step = 19169 (42.463 sec)
INFO:tensorflow:nll_loss = 49.836723, kl_w = 1.0, kl_loss = 11.4787445 (42.462 sec)
INFO:tensorflow:global_step/sec: 2.43249
INFO:tensorflow:loss = 65.64402, step = 19269 (41.111 sec)
INFO:tensorflow:nll_loss = 54.46495, kl_w = 1.0, kl_loss = 11.1790695 (41.111 sec)
INFO:tensorflow:global_step/sec: 2.82874
INFO:tensorflow:loss = 64.85683, step = 19369 (35.350 sec)
INFO:tensorflow:nll_loss = 53.872993, kl_w = 1.0, kl_loss = 10.983837 (35.349 sec)
IN

INFO:tensorflow:global_step/sec: 2.73048
INFO:tensorflow:loss = 64.149254, step = 21415 (36.624 sec)
INFO:tensorflow:nll_loss = 52.80629, kl_w = 1.0, kl_loss = 11.342966 (36.626 sec)
INFO:tensorflow:global_step/sec: 2.81947
INFO:tensorflow:loss = 62.412476, step = 21515 (35.467 sec)
INFO:tensorflow:nll_loss = 50.978462, kl_w = 1.0, kl_loss = 11.434011 (35.466 sec)
INFO:tensorflow:global_step/sec: 2.82561
INFO:tensorflow:loss = 60.556, step = 21615 (35.391 sec)
INFO:tensorflow:nll_loss = 48.821335, kl_w = 1.0, kl_loss = 11.734667 (35.391 sec)
INFO:tensorflow:global_step/sec: 2.67037
INFO:tensorflow:loss = 62.503704, step = 21715 (37.448 sec)
INFO:tensorflow:nll_loss = 50.98715, kl_w = 1.0, kl_loss = 11.516556 (37.449 sec)
INFO:tensorflow:global_step/sec: 2.83202
INFO:tensorflow:loss = 60.516922, step = 21815 (35.310 sec)
INFO:tensorflow:nll_loss = 49.139114, kl_w = 1.0, kl_loss = 11.37781 (35.310 sec)
INFO:tensorflow:Saving checkpoints for 21896 into /var/folders/sx/fv0r97j96fz8njp14dt5