In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import tensorflow.contrib.slim as slim  
import matplotlib.pyplot as plt
import numpy as np
import time
import math
from gumbel import *
from ops import *
import reader
from tensorflow.python.client import device_lib

In [2]:
class Config(object):
    """Hyper-parameters for the mixture-of-LSTMs PTB language model.

    Every value is a class attribute, so a fresh ``Config()`` sees these
    defaults until an attribute is overwritten on that instance.
    """

    # --- Network shape ---
    vocab_size = 10000
    input_size = 256      # embedding width
    hidden_size = 512     # units per LSTM cell
    num_layers = 1
    num_steps = 20        # unrolled time steps per mini-batch
    K = 2                 # number of LSTM cells mixed at each step

    # --- Optimisation ---
    batch_size = 20
    lr = 1.0
    lr_decay = 0.5        # multiplier applied when the learning rate is decayed
    max_epoch = 3         # lr decay is only considered after this many epochs
    max_max_epoch = 50    # hard upper bound on training epochs
    max_grad_norm = 5     # global-norm gradient clipping threshold
    keep_prob = 0.5       # dropout keep probability
    init_scale = 0.1
    display_step = 100

    # --- Gumbel-softmax temperature schedule ---
    tau0 = 5.0            # initial temperature
    ANNEAL_RATE = 0.1
    MIN_TEMP = 0.1

In [3]:
class PTBInput(object):
    """Queue-backed PTB mini-batches plus the number of batches per epoch.

    Attributes:
        batch_size, num_steps: copied from ``config``.
        epoch_size: ``(len(data) // batch_size - 1) // num_steps``.
        input_data, targets: the pair returned by
            ``reader.ptb_producer(data, batch_size, num_steps, name=name)``.
    """

    def __init__(self, config, data, name=None):
        batch_size = config.batch_size
        num_steps = config.num_steps
        self.batch_size = batch_size
        self.num_steps = num_steps
        # One token per row is held back, presumably because the targets are the
        # inputs shifted by one position (see reader.ptb_producer).
        tokens_per_row = len(data) // batch_size
        self.epoch_size = (tokens_per_row - 1) // num_steps
        self.input_data, self.targets = reader.ptb_producer(
            data, batch_size, num_steps, name=name)

def PrintPTBSentence(data, dict):
    """Map a 2-D array of token ids to a same-shaped array of words.

    Despite its name, nothing is printed; the decoded words are returned.

    Args:
        data: 2-D array-like with ``.shape == (rows, timesteps)``; every
            element is used as a key into ``dict``.
        dict: mapping from token id to word.
    Returns:
        ``np.ndarray`` built from the nested list of looked-up words.
    """
    n_rows, n_steps = data.shape
    id_to_word = dict
    words = [[id_to_word[data[r][t]] for t in range(n_steps)]
             for r in range(n_rows)]
    return np.array(words)

def Dropoutcell(config, is_training):
    """Build one ``BasicLSTMCell`` with ``config.hidden_size`` units.

    When ``is_training`` is truthy, the cell is wrapped in a ``DropoutWrapper``
    that applies output dropout with keep probability ``config.keep_prob``
    (``variational_recurrent=False``, float32). Otherwise the bare cell is
    returned.
    """
    lstm = tf.contrib.rnn.BasicLSTMCell(config.hidden_size, forget_bias=0.0, state_is_tuple=True)
    if not is_training:
        return lstm
    return tf.contrib.rnn.DropoutWrapper(lstm,
                                         output_keep_prob=config.keep_prob,
                                         variational_recurrent=False,
                                         dtype=tf.float32)

def run_epoch(session, model, eval_op=None, verbose=False):
    """Run ``model`` over one epoch of its input and return the perplexity.

    The LSTM state is carried across mini-batches. The ``final_state`` fetched
    by one ``session.run`` is fed back as ``initial_state`` for the next.

    Args:
        session: an open ``tf.Session`` (anything with a compatible ``run``).
        model: object exposing ``initial_state`` (indexable ``(c, h)`` pair of
            tensors), ``final_state``, ``NLL``, ``KL`` and ``data`` with
            ``epoch_size``, ``num_steps`` and ``batch_size``.
        eval_op: optional op (e.g. the train op) run alongside every step.
        verbose: if True, print progress roughly ten times per epoch.

    Returns:
        ``exp(sum of NLL over steps / total time steps processed)``. When
        ``model.NLL`` is the per-batch sum over time steps of batch-averaged
        cross-entropy (as ``PTBModel`` builds it), this is per-word perplexity.

    Raises:
        ValueError: if ``model.data.epoch_size`` is 0, since no perplexity can
            be computed from zero steps.
    """
    start_time = time.time()
    epoch_size = model.data.epoch_size
    if epoch_size <= 0:
        raise ValueError("run_epoch: model.data.epoch_size is %r; nothing to run" % (epoch_size,))
    costs = 0.0
    iters = 0
    state = session.run(model.initial_state)

    fetches = {
        "NLL": model.NLL,
        "final_state": model.final_state,
        "KL": model.KL
    }
    if eval_op is not None:
        fetches["eval_op"] = eval_op

    # Report about 10 times per epoch, first at step 10 (as in the PTB tutorial).
    # max(1, ...) avoids a modulo by zero when epoch_size < 10. Reducing the
    # offset modulo the interval keeps short epochs reporting; otherwise
    # `step % interval == 10` could never be true for an interval <= 10.
    report_every = max(1, epoch_size // 10)
    report_offset = 10 % report_every

    for step in range(epoch_size):
        feed_dict = {
            model.initial_state[0]: state.c,
            model.initial_state[1]: state.h,
        }

        vals = session.run(fetches, feed_dict)
        cost = vals["NLL"]
        state = vals["final_state"]
        KL = vals["KL"]

        costs += cost
        iters += model.data.num_steps

        if verbose and step % report_every == report_offset:
            # Guard the words-per-second division against a zero clock delta.
            elapsed = max(time.time() - start_time, 1e-12)
            print("%.3f perplexity: %.3f speed: %.0f wps NLL: %.3f KL: %f" %
                  (step * 1.0 / epoch_size, np.exp(costs / iters),
                   iters * model.data.batch_size / elapsed,
                   cost, KL))

    return np.exp(costs / iters)

In [4]:
class PTBModel(object):
    """Mixture-of-LSTMs language model with a per-step K-way cell gate.

    At every time step the gate logits are computed from the current input
    embedding and the previous hidden state. All ``config.K`` LSTM cells are
    run from the same shared previous state, and the next ``(c, h)`` is the
    ``z``-weighted sum of their outputs. While training, ``z`` comes from
    ``gumbel_softmax(logits, self.tau, hard=False)``; otherwise it comes from
    ``softmax_sample(logits)``. Both are defined in the project ``gumbel``
    module.

    Attributes set for every model:
        initial_state / final_state: ``LSTMStateTuple`` fed / fetched by
            ``run_epoch`` to carry the state across batches.
        ALL_z, ALL_qz: per-time-step gate values and gate probabilities.
        logits: ``[batch_size, num_steps, vocab_size]`` output logits.
        NLL: sum over time steps of the batch-averaged cross-entropy.
        KL: batch mean of KL(q(z) || Uniform(K)) summed over time steps. It is
            only reported; it is not part of ``cost``.
    Training models additionally get ``train_op``, ``lr_update`` and
    ``tau_update``; see ``assign_lr`` / ``assign_tau``.
    """

    def __init__(self, is_training, config, data):
        self.data = data
        self.batch_size = data.batch_size
        self.num_steps = data.num_steps
        # Non-trainable knobs set from Python through assign_tau / assign_lr.
        self.tau = tf.Variable(0.0, trainable=False)
        self.lr = tf.Variable(0.0, trainable=False)

        # Keep the embedding table and its lookup pinned to the CPU.
        with tf.device("/cpu:0"):
            self.embedding = tf.get_variable("embedding", [config.vocab_size, config.input_size], dtype=tf.float32)
            self.inputs = tf.nn.embedding_lookup(self.embedding, self.data.input_data)
        if is_training and config.keep_prob < 1:
            self.inputs = tf.nn.dropout(self.inputs, config.keep_prob)
        with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
            self.stcells = [Dropoutcell(config, is_training) for _ in range(config.K)]
            self.initial_state = self.stcells[0].zero_state(self.batch_size, dtype=tf.float32)

            self.state = self.initial_state
            self.outputs = []
            self.ALL_z = []
            self.ALL_qz = []
        with tf.variable_scope("RNN", reuse=tf.AUTO_REUSE):
            for time_step in range(self.num_steps):
                # Input embedding for this step, shared by the gate and all K cells.
                x_t = self.inputs[:, time_step, :]
                # Gate logits over the K cells from the input and the previous h.
                with tf.variable_scope('logit_enc'):
                    logit_z = linear(x_t, config.K, name='L1_x_zin') + linear(self.state[1], config.K, name='L1_h_zin')
                q_z = tf.nn.softmax(logit_z)
                if is_training:
                    z = gumbel_softmax(logit_z, self.tau, hard=False)
                else:
                    z = softmax_sample(logit_z)
                self.ALL_z.append(z)
                self.ALL_qz.append(q_z)

                # Run every cell from the same shared previous state.
                h = []
                c = []
                for i in range(config.K):
                    with tf.variable_scope('L1_LSTM_' + str(i)):
                        temp_h, temp_hc = self.stcells[i](x_t, self.state)
                    h.append(temp_h)
                    temp_c, _ = temp_hc
                    c.append(temp_c)
                # Pack to [K, batch, hidden], then mix over k with z ([batch, K]).
                h = tf.reshape(h, [config.K, self.batch_size, config.hidden_size])
                c = tf.reshape(c, [config.K, self.batch_size, config.hidden_size])
                new_h = tf.einsum('knd,nk->nd', h, z)
                new_c = tf.einsum('knd,nk->nd', c, z)
                self.cell_output = new_h
                self.state = tf.nn.rnn_cell.LSTMStateTuple(new_c, new_h)
                self.outputs.append(self.cell_output)
        # [batch, num_steps * hidden] -> [batch * num_steps, hidden] (batch-major rows),
        # matching the [batch, num_steps, vocab] reshape of the logits below.
        self.arrange_outputs = tf.reshape(tf.concat(self.outputs, 1), [-1, config.hidden_size])
        self.logits = tf.reshape(linear(self.arrange_outputs, config.vocab_size), [self.batch_size, self.num_steps, config.vocab_size])

        cross_entropy = tf.contrib.seq2seq.sequence_loss(
            self.logits,
            self.data.targets,
            tf.ones([self.batch_size, self.num_steps], dtype=tf.float32),
            average_across_timesteps=False,
            average_across_batch=True)
        # Per-step losses are batch-averaged; summing over the steps gives the
        # quantity run_epoch accumulates and divides by num_steps.
        self.NLL = tf.reduce_sum(cross_entropy)
        # List of num_steps [batch, K] tensors -> [batch, num_steps, K].
        ALL_q_z_tmp = tf.transpose(self.ALL_qz, [1, 0, 2])
        # KL(q || Uniform(K)); the 1e-20 keeps the log finite if a probability underflows.
        KL_tmp = ALL_q_z_tmp * (tf.log(ALL_q_z_tmp + 1e-20) - tf.log(1.0 / config.K))
        self.KL = tf.reduce_mean(tf.reduce_sum(KL_tmp, [1, 2]))
        # Only the NLL is optimised; KL is monitored but not part of the objective.
        self.cost = self.NLL
        self.final_state = self.state

        if not is_training:
            return

        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), config.max_grad_norm)
        optimizer = tf.train.GradientDescentOptimizer(self.lr)
        # tf.train.get_or_create_global_step replaces the deprecated
        # tf.contrib.framework alias of the same function.
        self.train_op = optimizer.apply_gradients(
            zip(grads, tvars),
            global_step=tf.train.get_or_create_global_step())

        self.new_lr = tf.placeholder(tf.float32, shape=[], name="new_learning_rate")
        self.lr_update = tf.assign(self.lr, self.new_lr)
        self.new_tau = tf.placeholder(tf.float32, shape=[], name="new_tau")
        self.tau_update = tf.assign(self.tau, self.new_tau)

    def assign_lr(self, session, lr_value):
        """Set the learning-rate variable (only defined on training models)."""
        session.run(self.lr_update, feed_dict={self.new_lr: lr_value})

    def assign_tau(self, session, tau_value):
        """Set the gate temperature variable (only defined on training models)."""
        session.run(self.tau_update, feed_dict={self.new_tau: tau_value})

In [5]:
import os

# Load the PTB splits plus the vocabulary tables
# (presumably w2id: word -> id and id2w: id -> word).
raw_data = reader.ptb_raw_data('./data/')
(train_data, valid_data, test_data,
 w2id, id2w) = raw_data

# Training uses the defaults; the test-time input steps one token at a time.
config = Config()
eval_config = Config()
eval_config.batch_size, eval_config.num_steps = 1, 1

# Checkpoint directory, tagged with the hidden size and the number of mixed cells.
URL = "".join(["./checkpoint/MRNN_D=", str(config.hidden_size),
               "_K=", str(config.K)])

In [6]:
# Build three models that share one set of weights. The "Train" model creates
# the variables under variable scope "Model" (reuse=None), and the "Valid" and
# "Test" models reuse them (reuse=True), so the Train block must run first.
# `initializer` becomes the default initializer for variables created in
# these scopes that do not specify their own.
initializer = tf.orthogonal_initializer()

with tf.name_scope("Train"):
    train_input = PTBInput(config=config, data=train_data, name="TrainInput")
    with tf.variable_scope("Model", reuse=None, initializer=initializer):
        m = PTBModel(is_training=True, config=config, data=train_input)

with tf.name_scope("Valid"):
    # Validation batches have the same shape as the training batches (config).
    valid_input = PTBInput(config=config, data=valid_data, name="ValidInput")
    with tf.variable_scope("Model", reuse=True, initializer=initializer):
        mvalid = PTBModel(is_training=False, config=config, data=valid_input)

with tf.name_scope("Test"):
    # The test input uses eval_config (batch_size = num_steps = 1). PTBModel
    # reads batch_size / num_steps from its `data` argument, so passing
    # `config` (not eval_config) to the model does not change those shapes.
    test_input = PTBInput(config=eval_config, data=test_data, name="TestInput")
    with tf.variable_scope("Model", reuse=True, initializer=initializer):
        mtest = PTBModel(is_training=False, config=config, data=test_input)

Instructions for updating:
Please switch to tf.train.get_or_create_global_step


In [7]:
# Launch the graph. Train with per-epoch temperature annealing and
# plateau-based learning-rate decay, then evaluate on the test split and save
# a checkpoint under URL.
MIN_LR = 0.0001  # stop training once the decayed learning rate falls below this

# tf.train.Saver.save refuses to write into a missing parent directory. Without
# this, the checkpoint (and the whole training run) would be lost at the end.
if not os.path.isdir(URL):
    os.makedirs(URL)

sv = tf.train.Supervisor()
with sv.managed_session() as sess:

    ppl_train = []
    # np.inf sentinel, so epoch 1's "did validation get worse?" check has a
    # predecessor. ppl_valid[k] is therefore the validation perplexity of epoch k.
    ppl_valid = [np.inf]
    m.assign_lr(sess, config.lr)
    start_time = time.time()
    for i in range(config.max_max_epoch):
        # Exponentially anneal the gate temperature, floored at MIN_TEMP.
        np_temp = np.maximum(config.tau0 * np.exp(-config.ANNEAL_RATE * i), config.MIN_TEMP)
        m.assign_tau(sess, np_temp)
        temp_lr, temp_tau = sess.run([m.lr, m.tau])
        if temp_lr < MIN_LR:
            break

        print("Epoch: %d Learning rate: %f, Temperature: %f" % (i + 1, temp_lr, temp_tau))
        train_perplexity = run_epoch(sess, m, eval_op=m.train_op, verbose=True)
        ppl_train.append(train_perplexity)
        print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))

        valid_perplexity = run_epoch(sess, mvalid)
        print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))
        ppl_valid.append(valid_perplexity)

        # After the first max_epoch epochs, multiply the learning rate by
        # lr_decay whenever validation perplexity got worse than the previous epoch.
        if (i + 1) > config.max_epoch and ppl_valid[-1] > ppl_valid[-2]:
            m.assign_lr(sess, temp_lr * config.lr_decay)

    end_time = time.time()
    test_perplexity = run_epoch(sess, mtest)
    print("Test Perplexity: %.3f" % test_perplexity)
    print("Training time: %f, Testing time: %f" % (end_time - start_time, time.time() - end_time))

    sv.saver.save(sess, URL + "/model.ckpt")

INFO:tensorflow:Starting standard services.
INFO:tensorflow:Starting queue runners.
Epoch: 1 Learning rate: 1.000000, Temperature: 5.000000
0.004 perplexity: 3544.068 speed: 3818 wps NLL: 153.652 KL: 0.205823
0.104 perplexity: 782.620 speed: 9628 wps NLL: 122.325 KL: 1.208275
0.204 perplexity: 600.917 speed: 10000 wps NLL: 120.487 KL: 3.941295
0.304 perplexity: 500.007 speed: 10096 wps NLL: 110.561 KL: 5.276466
0.404 perplexity: 440.432 speed: 10165 wps NLL: 118.639 KL: 5.021752
0.504 perplexity: 400.846 speed: 10197 wps NLL: 107.750 KL: 6.228884
0.604 perplexity: 365.603 speed: 10231 wps NLL: 110.628 KL: 5.431989
0.703 perplexity: 341.293 speed: 10254 wps NLL: 113.856 KL: 5.373662
0.803 perplexity: 322.152 speed: 10262 wps NLL: 103.884 KL: 6.792552
0.903 perplexity: 304.350 speed: 10257 wps NLL: 99.758 KL: 5.713828
Epoch: 1 Train Perplexity: 290.773
Epoch: 1 Valid Perplexity: 185.617
Epoch: 2 Learning rate: 1.000000, Temperature: 4.524187
0.004 perplexity: 243.965 speed: 10151 wps NLL

0.004 perplexity: 101.330 speed: 10081 wps NLL: 94.593 KL: 5.199228
0.104 perplexity: 80.981 speed: 10324 wps NLL: 87.173 KL: 5.100006
0.204 perplexity: 87.729 speed: 10312 wps NLL: 87.301 KL: 5.854121
0.304 perplexity: 86.243 speed: 10304 wps NLL: 81.829 KL: 5.333508
0.404 perplexity: 86.339 speed: 10308 wps NLL: 96.736 KL: 4.919614
0.504 perplexity: 86.700 speed: 10307 wps NLL: 85.548 KL: 5.836352
0.604 perplexity: 85.728 speed: 10304 wps NLL: 94.520 KL: 5.381948
0.703 perplexity: 85.993 speed: 10303 wps NLL: 95.803 KL: 5.764557
0.803 perplexity: 86.272 speed: 10302 wps NLL: 85.977 KL: 6.146986
0.903 perplexity: 85.205 speed: 10303 wps NLL: 84.139 KL: 5.692077
Epoch: 11 Train Perplexity: 85.019
Epoch: 11 Valid Perplexity: 103.926
Epoch: 12 Learning rate: 1.000000, Temperature: 1.664355
0.004 perplexity: 103.719 speed: 10315 wps NLL: 93.657 KL: 5.815705
0.104 perplexity: 80.535 speed: 10297 wps NLL: 88.439 KL: 6.041402
0.204 perplexity: 87.721 speed: 10307 wps NLL: 88.253 KL: 6.414481

0.404 perplexity: 74.665 speed: 10276 wps NLL: 91.703 KL: 10.469229
0.504 perplexity: 74.635 speed: 10282 wps NLL: 81.668 KL: 9.778066
0.604 perplexity: 73.311 speed: 10274 wps NLL: 88.836 KL: 9.819635
0.703 perplexity: 73.140 speed: 10270 wps NLL: 90.203 KL: 9.637306
0.803 perplexity: 72.966 speed: 10272 wps NLL: 81.959 KL: 9.677843
0.903 perplexity: 71.815 speed: 10266 wps NLL: 81.264 KL: 8.979802
Epoch: 21 Train Perplexity: 71.336
Epoch: 21 Valid Perplexity: 92.289
Epoch: 22 Learning rate: 0.250000, Temperature: 0.612282
0.004 perplexity: 86.934 speed: 10417 wps NLL: 91.717 KL: 8.331179
0.104 perplexity: 67.846 speed: 10178 wps NLL: 84.167 KL: 8.486487
0.204 perplexity: 72.874 speed: 10217 wps NLL: 83.833 KL: 8.810019
0.304 perplexity: 71.330 speed: 10254 wps NLL: 78.268 KL: 9.222363
0.404 perplexity: 71.175 speed: 10260 wps NLL: 91.681 KL: 8.432203
0.504 perplexity: 71.322 speed: 10271 wps NLL: 81.479 KL: 8.448047
0.604 perplexity: 70.316 speed: 10274 wps NLL: 88.415 KL: 8.433241
0

0.703 perplexity: 65.276 speed: 10277 wps NLL: 89.764 KL: 12.365152
0.803 perplexity: 65.320 speed: 10271 wps NLL: 80.560 KL: 12.796083
0.903 perplexity: 64.455 speed: 10273 wps NLL: 79.580 KL: 13.079564
Epoch: 31 Train Perplexity: 64.183
Epoch: 31 Valid Perplexity: 88.943
Epoch: 32 Learning rate: 0.125000, Temperature: 0.225246
0.004 perplexity: 79.982 speed: 10637 wps NLL: 87.121 KL: 12.772757
0.104 perplexity: 61.948 speed: 10314 wps NLL: 83.381 KL: 12.508075
0.204 perplexity: 66.791 speed: 10292 wps NLL: 82.648 KL: 12.530244
0.304 perplexity: 65.469 speed: 10282 wps NLL: 75.999 KL: 13.040982
0.404 perplexity: 65.461 speed: 10284 wps NLL: 89.322 KL: 12.767982
0.504 perplexity: 65.658 speed: 10277 wps NLL: 82.074 KL: 12.715817
0.604 perplexity: 64.723 speed: 10286 wps NLL: 87.472 KL: 12.759073
0.703 perplexity: 64.852 speed: 10288 wps NLL: 89.377 KL: 12.453888
0.803 perplexity: 64.898 speed: 10291 wps NLL: 80.091 KL: 12.768628
0.903 perplexity: 64.067 speed: 10295 wps NLL: 77.038 KL:

0.903 perplexity: 60.954 speed: 10287 wps NLL: 75.073 KL: 12.299204
Epoch: 41 Train Perplexity: 60.696
Epoch: 41 Valid Perplexity: 87.566
Epoch: 42 Learning rate: 0.031250, Temperature: 0.100000
0.004 perplexity: 77.859 speed: 10423 wps NLL: 86.435 KL: 12.729106
0.104 perplexity: 58.902 speed: 10218 wps NLL: 80.724 KL: 12.920148
0.204 perplexity: 63.454 speed: 10252 wps NLL: 81.880 KL: 12.733215
0.304 perplexity: 62.241 speed: 10267 wps NLL: 76.605 KL: 12.975183
0.404 perplexity: 62.170 speed: 10289 wps NLL: 89.620 KL: 13.001396
0.504 perplexity: 62.399 speed: 10296 wps NLL: 81.074 KL: 12.822783
0.604 perplexity: 61.505 speed: 10299 wps NLL: 87.139 KL: 12.566619
0.703 perplexity: 61.600 speed: 10296 wps NLL: 90.186 KL: 12.990352
0.803 perplexity: 61.581 speed: 10298 wps NLL: 79.240 KL: 13.048021
0.903 perplexity: 60.753 speed: 10296 wps NLL: 75.639 KL: 12.844942
Epoch: 42 Train Perplexity: 60.455
Epoch: 42 Valid Perplexity: 87.534
Epoch: 43 Learning rate: 0.031250, Temperature: 0.10000

In [8]:
def _num_elements(variable):
    """Product of a variable's static dimensions (1 for a scalar)."""
    count = 1
    for dim in variable.get_shape():
        count *= dim.value
    return count

# Total number of trainable scalars across the shared model weights.
total_para = sum(_num_elements(v) for v in tf.trainable_variables())
print('Total parameters: ', total_para)
print(tf.trainable_variables())

Total parameters:  10841364
[<tf.Variable 'Model/embedding:0' shape=(10000, 256) dtype=float32_ref>, <tf.Variable 'Model/RNN/logit_enc/L1_x_zin/weights:0' shape=(256, 2) dtype=float32_ref>, <tf.Variable 'Model/RNN/logit_enc/L1_x_zin/biases:0' shape=(2,) dtype=float32_ref>, <tf.Variable 'Model/RNN/logit_enc/L1_h_zin/weights:0' shape=(512, 2) dtype=float32_ref>, <tf.Variable 'Model/RNN/logit_enc/L1_h_zin/biases:0' shape=(2,) dtype=float32_ref>, <tf.Variable 'Model/RNN/L1_LSTM_0/basic_lstm_cell/kernel:0' shape=(768, 2048) dtype=float32_ref>, <tf.Variable 'Model/RNN/L1_LSTM_0/basic_lstm_cell/bias:0' shape=(2048,) dtype=float32_ref>, <tf.Variable 'Model/RNN/L1_LSTM_1/basic_lstm_cell/kernel:0' shape=(768, 2048) dtype=float32_ref>, <tf.Variable 'Model/RNN/L1_LSTM_1/basic_lstm_cell/bias:0' shape=(2048,) dtype=float32_ref>, <tf.Variable 'Model/linear/weights:0' shape=(512, 10000) dtype=float32_ref>, <tf.Variable 'Model/linear/biases:0' shape=(10000,) dtype=float32_ref>]


In [None]:
# Per-epoch training and validation perplexity curves.
plt.figure()
plt.plot(range(1, len(ppl_train) + 1), ppl_train, 'b')
plt.title('Train perplexity')
plt.xlabel('epoch')
plt.ylabel('perplexity')
plt.show()
# ppl_valid[0] is the np.inf sentinel seeded before training. Skip it so that
# x = k is epoch k, as on the training curve, instead of shifted by one.
plt.figure()
plt.plot(range(1, len(ppl_valid)), ppl_valid[1:], 'r')
plt.title('Validation perplexity')
plt.xlabel('epoch')
plt.ylabel('perplexity')
plt.show()

In [None]:
def softmax(x, axis=2):
    """Numerically stable softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating, so large
    scores do not overflow. The normaliser keeps its reduced dimension
    (``keepdims=True``), so it broadcasts correctly against ``x``. The default
    ``axis=2`` keeps the original call signature working.

    Args:
        x: array-like of scores with more than ``axis`` dimensions.
        axis: axis holding the scores to normalise.
    Returns:
        ``np.ndarray`` shaped like ``x`` whose entries sum to 1 along ``axis``.
    """
    scores = np.asarray(x)
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

# Replay the test split through the restored model one token at a time
# (mtest's input uses eval_config: batch_size = num_steps = 1). For every step,
# record the gate value z, its probabilities q(z), the (c, h) state and the
# input id.
model = mtest
sv = tf.train.Supervisor()
z = []
qz = []
h = []
x = []
start_time = time.time()
with sv.managed_session() as sess:
    sv.saver.restore(sess, URL + "/model.ckpt")

    state = sess.run(model.initial_state)
    # Fetch only what is recorded below. The targets and the vocab-sized logits
    # are unused here, and fetching the logits would run the output projection
    # on every step.
    fetches = {
        "z": model.ALL_z,
        "qz": model.ALL_qz,
        "x": model.data.input_data,
        "final_state": model.final_state,
    }
    print(model.data.epoch_size)
    for i in range(model.data.epoch_size):
        feed_dict = {
            model.initial_state[0]: state.c,
            model.initial_state[1]: state.h,
        }
        vals = sess.run(fetches, feed_dict)
        state = vals["final_state"]

        # model.final_state is the same tuple as model.state, so it serves as "h".
        h.append(state)
        # [0][0]: first time step, first batch row (both of size 1 here).
        qz.append(vals['qz'][0][0])
        z.append(vals['z'][0][0])
        x.append(vals['x'][0][0])

    end_time = time.time()
    print("time: %f" % (end_time - start_time))

In [None]:
# Stack the per-step records: h -> [steps, 2 (c, h), hidden]; z, qz -> [steps, K]; x -> [steps].
h = np.reshape(h, [model.data.epoch_size, 2, -1])
z, qz, x = [np.array(records) for records in (z, qz, x)]

In [None]:
from sklearn.manifold import TSNE

# 2-D t-SNE of the hidden state h (index 1 of each (c, h) pair) for the first N steps.
N = 1000
TSNE_SEED = 0  # fixed so the saved embedding is reproducible across runs
start_time = time.time()
h_tsne = TSNE(n_components=2, random_state=TSNE_SEED).fit_transform(h[:N, 1, :])
print("time: %f" % (time.time() - start_time))
# Tag the file with the model's actual K, so runs with different K do not
# mislabel or overwrite each other.
np.save('./tsne_MRNN_K=' + str(config.K), h_tsne)

In [None]:
# Colour each t-SNE point by the cell with the largest gate value (argmax of z).
fig = plt.figure()
color = ['r', 'g', 'b', 'y']
n_points = len(h_tsne)
# Compute the per-step argmax once, rather than over all of z for every point.
selected = np.argmax(z, 1)[:n_points]
plt.scatter(h_tsne[:, 0], h_tsne[:, 1], c=[color[k] for k in selected], marker='.')
plt.title('t-SNE of hidden states, coloured by argmax of z')
plt.show()

In [None]:
# First 100 steps of the gate values z, then of the gate probabilities q(z);
# transposed so the K cells run along the vertical axis.
for gate_matrix in (z, qz):
    fig = plt.figure(figsize=[15, 15])
    plt.imshow(gate_matrix[:100, :].T)
    plt.show()