Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)

![Architecture of the sequence variational autoencoder (VAE)](vae.png)

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Hyper-parameters for the whole notebook. build_vocab() later adds the
# runtime entries 'word2idx', 'idx2word' and 'vocab_size' to this same dict.
PARAMS = dict(
    # --- data ---
    max_len=15,          # sequence length after padding / truncation
    # --- model ---
    embed_dims=128,
    rnn_size=128,
    latent_size=16,
    beam_width=5,
    num_sampled=1000,    # sampled classes for tf.nn.sampled_softmax_loss
    # --- optimisation ---
    clip_norm=5.0,       # global-norm gradient clipping threshold
    anneal_max=1.0,      # upper bound of the KL weight
    anneal_bias=6000,    # controls midpoint and slope of the KL-weight ramp
    batch_size=128,
    n_epochs=10,
)

In [3]:
def build_vocab(index_from=4):
    """Build the IMDB vocabulary and store it in ``PARAMS``.

    Every raw word id from the Keras IMDB word index is shifted by
    ``index_from`` so the low ids can hold the special tokens
    ``<pad>`` (0), ``<start>`` (1), ``<unk>`` (2) and ``<end>`` (3).
    Keep ``index_from`` equal to the value passed to ``load_data``.

    Sets ``PARAMS['word2idx']``, ``PARAMS['idx2word']`` and
    ``PARAMS['vocab_size']``.

    Args:
        index_from: offset added to every raw word id.

    Returns:
        The ``word2idx`` dict (the same object stored in ``PARAMS``).
    """
    word2idx = {w: i + index_from
                for w, i in tf.keras.datasets.imdb.get_word_index().items()}
    word2idx['<pad>'] = 0
    word2idx['<start>'] = 1
    word2idx['<unk>'] = 2
    word2idx['<end>'] = 3

    idx2word = {i: w for w, i in word2idx.items()}

    # Give every id that no token maps to a placeholder. One example is the
    # gap between the special tokens and the first shifted word id.
    # Afterwards the ids form the dense range [0, len(word2idx)).
    # Without this, if word ranks are contiguous from 1, the largest word id
    # equals len(word2idx). It would then index one past the embedding and
    # output-projection tables, which are sized from the vocabulary length.
    for i in range(max(idx2word) + 1):
        if i not in idx2word:
            placeholder = '<unused{}>'.format(i)
            word2idx[placeholder] = i
            idx2word[i] = placeholder

    PARAMS['word2idx'] = word2idx
    PARAMS['idx2word'] = idx2word
    PARAMS['vocab_size'] = len(word2idx)
    return word2idx

    
def load_data(index_from=4):
    """Load the IMDB train and test reviews as lists of word ids.

    Word ids are shifted by ``index_from``; keep it equal to the value given
    to ``build_vocab`` so ids and vocabulary agree. Sentiment labels are
    discarded and the full vocabulary is kept (``num_words=None``).

    Returns:
        ``(X_train, X_test)`` as returned by the Keras IMDB loader.
    """
    # Use tf.keras, as build_vocab does; tf.contrib.keras is a deprecated alias.
    (X_train, _), (X_test, _) = tf.keras.datasets.imdb.load_data(
        num_words=None, index_from=index_from)
    return X_train, X_test

In [4]:
build_vocab()
X_raw = np.concatenate(load_data())

# Two fixed-length views of every review: its first and its last `max_len`
# tokens, zero-padded at the end when the review is shorter.
X = np.concatenate((
    tf.keras.preprocessing.sequence.pad_sequences(
        X_raw, PARAMS['max_len'], truncating='post', padding='post'),
    tf.keras.preprocessing.sequence.pad_sequences(
        X_raw, PARAMS['max_len'], truncating='pre', padding='post')))

# Encoder input drops the first token; the decoder is teacher-forced on the
# full row.
enc_inp = X[:, 1:]
dec_inp = X

# Decoder target = row shifted left by one, with <end> placed right after the
# last real (non-pad) token. Appending <end> at the last column instead would
# put it after the padding for rows shorter than max_len.
# Assumes id 0 appears only as trailing padding.
seq_lens = np.count_nonzero(X, axis=1)
dec_out = np.zeros(X.shape, dtype=np.int64)
dec_out[:, :-1] = X[:, 1:]
dec_out[np.arange(X.shape[0]), seq_lens - 1] = PARAMS['word2idx']['<end>']

In [5]:
def reparam_trick(z_mean, z_logvar):
    """Sample z ~ N(z_mean, exp(z_logvar)) with the reparameterisation trick.

    Computes z = z_mean + exp(0.5 * z_logvar) * eps with eps ~ N(0, I), so
    gradients flow back into both z_mean and z_logvar.

    Args:
        z_mean: posterior mean tensor.
        z_logvar: posterior log-variance tensor (same shape as z_mean).

    Returns:
        A sampled latent tensor.
    """
    # Use an untruncated standard normal. tf.truncated_normal drops draws
    # beyond 2 std, shrinking the variance to about 0.77. z would then not
    # follow the Gaussian posterior whose KL term kl_loss_fn penalises.
    eps = tf.random_normal(tf.shape(z_logvar))
    return z_mean + tf.exp(0.5 * z_logvar) * eps


def kl_w_fn(global_step):
    """Weight of the KL term: a logistic ramp up to PARAMS['anneal_max'].

    The weight reaches half of anneal_max at step anneal_bias / 2. The
    sigmoid argument grows by 10 / anneal_bias per step.
    """
    bias = PARAMS['anneal_bias']
    slope = 10 / bias
    step = tf.to_float(global_step)
    midpoint = tf.constant(bias / 2)
    return PARAMS['anneal_max'] * tf.sigmoid(slope * (step - midpoint))


def kl_loss_fn(mean, gamma):
    """KL(N(mean, exp(gamma)) || N(0, I)), averaged over the first axis.

    ``gamma`` is the log-variance. The element-wise KL terms are summed over
    every element, then divided by the size of the first dimension of ``mean``.
    """
    batch_count = tf.to_float(tf.shape(mean)[0])
    elementwise = tf.exp(gamma) + tf.square(mean) - 1 - gamma
    return 0.5 * tf.reduce_sum(elementwise) / batch_count

In [6]:
def clip_grads(loss):
    """Gradients of ``loss`` w.r.t. all trainable variables, clipped by global norm.

    The threshold is PARAMS['clip_norm']. Returns a lazy iterator of
    (gradient, variable) pairs for ``Optimizer.apply_gradients``.
    """
    params = tf.trainable_variables()
    clipped, _global_norm = tf.clip_by_global_norm(
        tf.gradients(loss, params), PARAMS['clip_norm'])
    return zip(clipped, params)


def rnn_cell():
    """Return a new GRU cell of width PARAMS['rnn_size'] with orthogonal kernel init."""
    init = tf.orthogonal_initializer()
    return tf.nn.rnn_cell.GRUCell(PARAMS['rnn_size'], kernel_initializer=init)


def forward(inputs, labels, mode):
    """Build the sequence-VAE graph (GRU encoder -> latent z -> GRU decoder).

    Args:
        inputs: int tensor of encoder token ids, 0 = <pad>; presumably
            (batch, time) -- the token count is taken along axis 1.
        labels: dict with 'dec_inp' and 'dec_out' id tensors; read only in
            TRAIN mode (may be None otherwise).
        mode: a ``tf.estimator.ModeKeys`` value.

    Returns:
        TRAIN: ``(rnn_output, logits, (z_mean, z_logvar))``. The time axis of
            rnn_output is zero-padded to the length of labels['dec_out'] so
            outputs and targets line up.
        Any other mode: token ids of the best beam, ``predicted_ids[:, :, 0]``.
    """
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    enc_seq_len = tf.count_nonzero(inputs, 1, dtype=tf.int32)
    batch_sz = tf.shape(inputs)[0]
    # Size the tables with the same vocabulary size the sampled-softmax loss uses.
    vocab_size = PARAMS['vocab_size']

    with tf.variable_scope('Encoder'):
        embedding = tf.get_variable('lookup_table', [vocab_size, PARAMS['embed_dims']])
        x = tf.nn.embedding_lookup(embedding, inputs)

        _, enc_state = tf.nn.dynamic_rnn(rnn_cell(), x, enc_seq_len, dtype=tf.float32)

        z_mean = tf.layers.dense(enc_state, PARAMS['latent_size'])
        z_logvar = tf.layers.dense(enc_state, PARAMS['latent_size'])

    z = reparam_trick(z_mean, z_logvar)

    with tf.variable_scope('Decoder'):
        init_state = tf.layers.dense(z, PARAMS['rnn_size'], tf.nn.elu)
        # `_scope` (a private kwarg) presumably pins the variable names.
        # model_fn reads kernel/bias back from 'Decoder/decoder/output_proj'.
        output_proj = tf.layers.Dense(vocab_size, _scope='decoder/output_proj')
        dec_cell = rnn_cell()

        if is_training:
            dec_seq_len = tf.count_nonzero(labels['dec_out'], 1, dtype=tf.int32)

            helper = tf.contrib.seq2seq.TrainingHelper(
                inputs = tf.nn.embedding_lookup(embedding, labels['dec_inp']),
                sequence_length = dec_seq_len)
            decoder = tf.contrib.seq2seq.BasicDecoder(
                cell = dec_cell,
                helper = helper,
                initial_state = init_state)
            decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = decoder,
                maximum_iterations = tf.reduce_max(dec_seq_len))
            rnn_output = decoder_output.rnn_output

            # Decoding stops at the longest sequence in the batch, which can
            # be shorter than labels['dec_out']. Pad the time axis with zeros
            # so model_fn can flatten outputs and targets to the same rows.
            # The padding is never negative: maximum_iterations <= target width.
            time_pad = tf.shape(labels['dec_out'])[1] - tf.shape(rnn_output)[1]
            rnn_output = tf.pad(rnn_output, [[0, 0], [0, time_pad], [0, 0]])
            # Restore the static feature size; the Dense layer requires it.
            rnn_output.set_shape([None, None, PARAMS['rnn_size']])

            return rnn_output, output_proj(rnn_output), (z_mean, z_logvar)
        else:
            decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                cell = dec_cell,
                embedding = embedding,
                start_tokens = tf.tile(tf.constant([PARAMS['word2idx']['<start>']], tf.int32),
                                       [batch_sz]),
                end_token = PARAMS['word2idx']['<end>'],
                initial_state = tf.contrib.seq2seq.tile_batch(init_state, PARAMS['beam_width']),
                beam_width = PARAMS['beam_width'],
                output_layer = output_proj)
            # Cap the decoding length. Without a cap, dynamic_decode keeps
            # stepping until every beam emits <end>, which an under-trained
            # model may never do.
            decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = decoder,
                maximum_iterations = PARAMS.get('max_decode_len', 2 * PARAMS['max_len']))

            return decoder_output.predicted_ids[:, :, 0]

In [7]:
def model_fn(features, labels, mode):
    """``tf.estimator`` model_fn for the sequence VAE.

    PREDICT returns the best-beam token ids produced by ``forward``.

    TRAIN minimises a sum of two terms:
      * the sampled-softmax reconstruction NLL (pad positions masked out,
        normalised by batch size), and
      * the KL term weighted by ``kl_w_fn``.

    Raises:
        NotImplementedError: for any other mode (e.g. EVAL). Previously such
            modes fell through and implicitly returned None.
    """
    if mode not in (tf.estimator.ModeKeys.PREDICT, tf.estimator.ModeKeys.TRAIN):
        raise NotImplementedError(
            'model_fn supports only TRAIN and PREDICT modes, got {!r}'.format(mode))

    outputs = forward(features, labels, mode)

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=outputs)

    # --- TRAIN ---
    rnn_output, _logits, (z_mean, z_logvar) = outputs

    global_step = tf.train.get_global_step()

    # Reuse the decoder's output projection as the sampled-softmax classifier.
    # The Dense kernel is transposed to the [num_classes, dim] layout that
    # tf.nn.sampled_softmax_loss expects.
    with tf.variable_scope('Decoder/decoder/output_proj', reuse=True):
        _weights = tf.transpose(tf.get_variable('kernel'))
        _biases = tf.get_variable('bias')

    # 1.0 for real target tokens, 0.0 for <pad> (id 0).
    mask = tf.reshape(tf.to_float(tf.sign(labels['dec_out'])), [-1])

    # Per-token losses summed over the batch, then divided by batch size.
    nll_loss = tf.reduce_sum(mask * tf.nn.sampled_softmax_loss(
        weights = _weights,
        biases = _biases,
        labels = tf.reshape(labels['dec_out'], [-1, 1]),
        inputs = tf.reshape(rnn_output, [-1, PARAMS['rnn_size']]),
        num_sampled = PARAMS['num_sampled'],
        num_classes = PARAMS['vocab_size'],
    )) / tf.to_float(tf.shape(features)[0])

    kl_w = kl_w_fn(global_step)
    kl_loss = kl_loss_fn(z_mean, z_logvar)
    loss_op = nll_loss + kl_w * kl_loss

    train_op = tf.train.AdamOptimizer().apply_gradients(
        clip_grads(loss_op),
        global_step = global_step)

    lth = tf.train.LoggingTensorHook(
        {'nll_loss': nll_loss, 'kl_w': kl_w, 'kl_loss': kl_loss}, every_n_iter=100)

    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss_op, train_op=train_op, training_hooks=[lth])

In [8]:
def inf_inp(test_strs):
    """Encode whitespace-tokenised sentences as a padded id matrix for inference.

    Words missing from PARAMS['word2idx'] map to id 2 (the '<unk>' id set in
    build_vocab). Rows are truncated / zero-padded at the end to
    PARAMS['max_len'].
    """
    unk_id = 2
    rows = []
    for sentence in test_strs:
        rows.append([PARAMS['word2idx'].get(token, unk_id) for token in sentence.split()])
    return tf.keras.preprocessing.sequence.pad_sequences(
        rows, PARAMS['max_len'], truncating='post', padding='post')

def demo(test_strs, pred_ids):
    """Print each input sentence followed by its decoded reconstruction.

    Ids not found in PARAMS['idx2word'] are shown as '<unk>'.
    """
    for original, ids in zip(test_strs, pred_ids):
        print('\nOriginal:', original)
        words = [PARAMS['idx2word'].get(i, '<unk>') for i in ids]
        print('Reconstr:', ' '.join(words))


test_strs = ['i love this film and i think it is one of the best films',
             'this movie is a waste of time and there is no point to watch it']

estimator = tf.estimator.Estimator(model_fn)

# The demo inputs do not change between epochs, so encode them once.
test_features = inf_inp(test_strs)

for _ in range(PARAMS['n_epochs']):
    # One shuffled pass over the training data per call.
    estimator.train(tf.estimator.inputs.numpy_input_fn(
        x = enc_inp,
        y = {'dec_inp': dec_inp, 'dec_out': dec_out},
        batch_size = PARAMS['batch_size'],
        num_epochs = 1,
        shuffle = True))

    # Reconstruct the demo sentences after every epoch to monitor progress.
    pred_ids = list(estimator.predict(tf.estimator.inputs.numpy_input_fn(
        x = test_features,
        shuffle = False)))

    demo(test_strs, pred_ids)
    print()

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpgmez5mla', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x119065518>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See @{tf.nn.softmax_cross_entropy_with_logits_v2}.

INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointS

INFO:tensorflow:global_step/sec: 3.08222
INFO:tensorflow:loss = 62.404377, step = 2265 (32.444 sec)
INFO:tensorflow:nll_loss = 59.62514, kl_w = 0.22676536, kl_loss = 12.256002 (32.444 sec)
INFO:tensorflow:Saving checkpoints for 2346 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpgmez5mla/model.ckpt.
INFO:tensorflow:Loss for final step: 62.869072.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpgmez5mla/model.ckpt-2346
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

Original: i love this film and i think it is one of the best films
Reconstr: i give this one of this movie and this is a great movie and <end>

Original: this movie is a waste of time and there is no point to watch it
Reconstr: it is a great film and not to watch your time for the movie <end>

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:nll_loss = 55.239914, kl_w = 0.93603605, kl_loss = 4.6789 (32.706 sec)
INFO:tensorflow:Saving checkpoints for 4692 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpgmez5mla/model.ckpt.
INFO:tensorflow:Loss for final step: 57.501835.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpgmez5mla/model.ckpt-4692
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

Original: i love this film and i think it is one of the best films
Reconstr: and i love this movie is one of the worst movies i've ever seen <end>

Original: this movie is a waste of time and there is no point to watch it
Reconstr: this movie is a must see if you want to be a fan of <end>

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph 

INFO:tensorflow:Loss for final step: 54.943295.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpgmez5mla/model.ckpt-7038
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.

Original: i love this film and i think it is one of the best films
Reconstr: i think that this is one of the best movies i have ever seen <end>

Original: this movie is a waste of time and there is no point to watch it
Reconstr: this is one of the worst movies i've ever seen and i have seen <end>

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpgmez5mla/model.ckpt-7038
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running