In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import tensorflow.contrib.slim as slim  
import matplotlib.pyplot as plt
import numpy as np
import time
import math
from gumbel import *
from ops import *
import reader
from tensorflow.python.client import device_lib

In [None]:
class Config(object):
    """Hyper-parameters for the mixture-of-LSTM PTB language model.

    Every value is a class attribute; create an instance and override
    attributes on it to obtain a variant configuration without touching
    the class-level defaults.
    """

    # ----- Optimisation -----
    lr = 30.0              # initial SGD learning rate
    lr_decay = 0.5         # multiplicative decay applied when validation perplexity worsens
    max_grad_norm = 0.25   # global-norm gradient clipping threshold
    max_epoch = 3          # lr decay is only considered after this many epochs
    max_max_epoch = 50     # hard upper bound on the number of training epochs
    init_scale = 0.1       # range for the (commented-out) uniform initializer
    batch_size = 20
    display_step = 100     # not read by any cell shown here
    keep_prob = 0.5        # dropout keep probability (embedded inputs and LSTM outputs)

    # ----- Gumbel-softmax temperature schedule -----
    # tau(epoch) = max(tau0 * exp(-ANNEAL_RATE * epoch), MIN_TEMP)
    tau0 = 5.0             # initial temperature
    ANNEAL_RATE = 0.1
    MIN_TEMP = 0.1
    # alpha = 1            (unused)
    # beta = 0.25          (unused)

    # ----- Network parameters -----
    vocab_size = 10000
    input_size = 300       # word-embedding width
    hidden_size = 900      # units per LSTM cell
    num_steps = 35         # unrolled timesteps per batch
    num_layers = 1         # not read by any cell shown here
    K = 2                  # number of LSTM cells mixed by the switch variable z

In [None]:
class PTBInput(object):
    """Batched PTB input pipeline built on ``reader.ptb_producer``.

    Attributes:
      batch_size, num_steps: copied from ``config``.
      epoch_size: number of (batch_size x num_steps) windows per pass over
        ``data``.
      input_data, targets: the two tensors returned by ``reader.ptb_producer``.
    """
    def __init__(self, config, data, name=None):
        self.batch_size = config.batch_size
        self.num_steps = config.num_steps
        row_length = len(data) // self.batch_size
        # One position per row is held back, presumably so targets can be the
        # inputs shifted by one token (ptb_producer is not visible here).
        self.epoch_size = (row_length - 1) // self.num_steps
        produced = reader.ptb_producer(data, self.batch_size, self.num_steps, name=name)
        self.input_data, self.targets = produced

def PrintPTBSentence(data, dict):
    """Translate a 2-D array of token ids into an array of words.

    Despite its name nothing is printed; the decoded array is returned.

    Args:
      data: 2-D array of token ids (``data.shape`` must unpack to two values).
      dict: mapping from token id to word (e.g. ``id2w``). The name shadows the
        builtin but is kept so keyword callers keep working.

    Returns:
      ``np.array`` built from the row-by-row list of looked-up words.
    """
    num_rows, num_cols = data.shape
    return np.array([[dict[data[r][t]] for t in range(num_cols)]
                     for r in range(num_rows)])

def Dropoutcell(config, is_training):
    """Create one LSTM cell of width ``config.hidden_size``.

    The cell uses ``forget_bias=0.0`` and tuple state. When ``is_training`` is
    truthy it is wrapped in a ``DropoutWrapper`` that keeps output units with
    probability ``config.keep_prob`` (non-variational, float32).

    Returns:
      The bare ``BasicLSTMCell`` or its dropout-wrapped version.
    """
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(config.hidden_size,
                                             forget_bias=0.0,
                                             state_is_tuple=True)
    if not is_training:
        return lstm_cell
    return tf.contrib.rnn.DropoutWrapper(
        lstm_cell,
        output_keep_prob=config.keep_prob,
        variational_recurrent=False,
        dtype=tf.float32,
    )

def run_epoch(session, model, eval_op=None, verbose=False):
    """Run ``model`` over one full pass of its data and return the perplexity.

    The recurrent state returned by each batch is fed back as the initial
    state of the next batch, so consecutive batches form one continuous stream.

    Args:
      session: a ``tf.Session`` (or any object with a compatible ``run``).
      model: a built ``PTBModel``; reads ``initial_state``, ``final_state``,
        ``NLL``, ``KL`` and ``data`` (``epoch_size``, ``num_steps``,
        ``batch_size``).
      eval_op: optional op fetched together with the losses (e.g. the training
        op); ``None`` for evaluation-only passes.
      verbose: if True, print progress roughly ten times per epoch.

    Returns:
      ``exp(total NLL / total timesteps)`` over the epoch.

    Raises:
      ValueError: if ``model.data.epoch_size`` is not positive (there are no
        batches, so perplexity is undefined).
    """
    epoch_size = model.data.epoch_size
    if epoch_size <= 0:
        raise ValueError("run_epoch: epoch_size must be positive, got %r" % (epoch_size,))

    start_time = time.time()
    costs = 0.0
    iters = 0
    state = session.run(model.initial_state)

    fetches = {
        "NLL": model.NLL,
        "final_state": model.final_state,
        "KL": model.KL
    }
    if eval_op is not None:
        fetches["eval_op"] = eval_op

    # Log about ten times per epoch. max(1, ...) avoids a modulo by zero when
    # epoch_size < 10, and comparing against 10 % log_every keeps the usual
    # 10-step offset while still firing when log_every <= 10.
    log_every = max(1, epoch_size // 10)
    log_phase = 10 % log_every

    for step in range(epoch_size):
        # initial_state is the (c, h) pair; feed the previous batch's final state.
        feed_dict = {
            model.initial_state[0]: state.c,
            model.initial_state[1]: state.h,
        }

        vals = session.run(fetches, feed_dict)
        cost = vals["NLL"]
        state = vals["final_state"]
        KL = vals["KL"]

        costs += cost
        iters += model.data.num_steps

        if verbose and step % log_every == log_phase:
            # Guard against a zero elapsed time on coarse clocks.
            elapsed = max(time.time() - start_time, 1e-9)
            print("%.3f perplexity: %.3f speed: %.0f wps NLL: %.3f KL: %f" %
                  (step * 1.0 / epoch_size, np.exp(costs / iters),
                   iters * model.data.batch_size / elapsed,
                   cost, KL))

    return np.exp(costs / iters)

In [None]:
class PTBModel(object):
    """Mixture-of-LSTMs language model with a per-timestep switch variable z.

    At each unrolled timestep, K separate LSTM cells read the same input and
    the same previous state. Their K candidate (c, h) pairs are combined with
    per-example weights z over the K cells, and the combined pair becomes the
    state for the next step. z comes from ``gumbel_softmax`` (training) or
    ``softmax_sample`` (evaluation), applied to logits computed from the
    current input and the previous h. ``linear``, ``gumbel_softmax`` and
    ``softmax_sample`` come from the star-imported ``ops`` / ``gumbel``
    modules (not visible here).
    """
    def __init__(self, is_training, config, data):
        """Build the graph for one data split.

        Args:
          is_training: enables input / LSTM-output dropout and the SGD training
            ops. NOTE(review): the z-sampling branch tests
            ``is_training is True``, so pass a real ``bool``.
          config: ``Config``-like object; hidden_size, input_size, vocab_size,
            K, keep_prob and max_grad_norm are read here (config.batch_size and
            config.num_steps are not).
          data: ``PTBInput``; its batch_size / num_steps fix the unrolled graph.

        Exposes (among others): initial_state, final_state, logits, NLL, KL,
        ALL_z, ALL_qz; plus train_op, lr_update and tau_update when training.
        """
        # NOTE(review): this placeholder is never used or fed; looks like dead code.
        training = tf.placeholder(tf.bool)
        self.data = data
        self.batch_size = data.batch_size
        self.num_steps = data.num_steps
        # Temperature passed to gumbel_softmax and the SGD learning rate; both
        # non-trainable and set at run time via assign_tau / assign_lr.
        self.tau = tf.Variable(0.0, trainable=False)
        self.lr  = tf.Variable(0.0, trainable=False)
        
        # Embedding table and lookup are pinned to the CPU.
        with tf.device("/cpu:0"):
            self.embedding = tf.get_variable("embedding", [config.vocab_size, config.input_size], dtype=tf.float32)
            self.inputs = tf.nn.embedding_lookup(self.embedding, self.data.input_data)
        if is_training and config.keep_prob < 1:
            self.inputs = tf.nn.dropout(self.inputs, config.keep_prob)
        with tf.variable_scope("model", reuse=tf.AUTO_REUSE ):
            
            # One LSTM cell per mixture component.
            self.stcells = [Dropoutcell(config, is_training) for _ in range(config.K)]
            # All cells share hidden_size, so cell 0's zero state is used as the
            # common initial (c, h) state.
            self.initial_state = self.stcells[0].zero_state(self.batch_size, dtype=tf.float32)
            
            self.state = self.initial_state
            self.outputs = []
            self.ALL_z = []   # sampled switch weights, one tensor per timestep
            self.ALL_qz = []  # softmax(logit_z), one tensor per timestep
        with tf.variable_scope("RNN", reuse=tf.AUTO_REUSE):
            # Manual unroll: z at step t depends on the state produced at t-1.
            for time_step in range(self.num_steps):
                with tf.variable_scope('logit_enc'):
                    # Switch logits from the current input and the previous h (state[1]).
                    logit_z = linear(self.inputs[:, time_step, :], config.K, name='L1_x_zin') + linear(self.state[1], config.K, name='L1_h_zin')
                q_z = tf.nn.softmax(logit_z)
                # NOTE(review): identity check -- only the literal True takes the
                # gumbel_softmax branch; any other truthy value uses softmax_sample.
                if is_training is True:
                    z = gumbel_softmax(logit_z,self.tau,hard=False)
                else:
                    z = softmax_sample(logit_z)
                self.ALL_z.append(z)
                self.ALL_qz.append(q_z)
                
                # Run every cell on the same input and the previous mixed state.
                h = []
                c = []
                for i in range(config.K):
                    temp_name = 'L1_LSTM_'+str(i)
                    with tf.variable_scope(temp_name):
                        temp_h, temp_hc = self.stcells[i](self.inputs[:, time_step, :], self.state)
                    h.append(temp_h)
                    temp_c, _ = temp_hc
                    c.append(temp_c)
                # Stack the K candidates into [K, batch, hidden].
                h = tf.reshape(h, [config.K,self.batch_size,config.hidden_size])
                c = tf.reshape(c, [config.K, self.batch_size, config.hidden_size])
                # Per-example weighted sum over the K candidates; z is [batch, K].
                new_h = tf.einsum('knd,nk->nd',h,z)
                new_c = tf.einsum('knd,nk->nd',c,z)
                self.cell_output = new_h
                self.state = tf.nn.rnn_cell.LSTMStateTuple(new_c, new_h)
                self.outputs.append(self.cell_output)
        # outputs: num_steps x [batch, hidden]; concat on axis 1 gives
        # [batch, num_steps*hidden], reshape gives [batch*num_steps, hidden] in
        # batch-major order, matching the [batch, num_steps, vocab] reshape below.
        self.arrange_outputs = tf.reshape(tf.concat(self.outputs, 1), [-1, config.hidden_size])
        self.logits = tf.reshape(linear(self.arrange_outputs, config.vocab_size), [self.batch_size, self.num_steps, config.vocab_size])
        
        # Averaged over the batch only, so one loss value per timestep.
        cross_entropy = tf.contrib.seq2seq.sequence_loss(
            self.logits,
            self.data.targets,
            tf.ones([self.batch_size, self.num_steps], dtype=tf.float32),
            average_across_timesteps=False,
            average_across_batch=True)
        # NLL summed over timesteps (batch-averaged).
        self.NLL = tf.reduce_sum(cross_entropy)
        # ALL_qz stacks to [num_steps, batch, K]; transpose to [batch, num_steps, K].
        ALL_q_z_tmp = tf.transpose(self.ALL_qz,[1,0,2])
        # Elementwise terms of KL(q_z || Uniform(K)); 1e-20 guards log(0).
        KL_tmp = ALL_q_z_tmp*(tf.log(ALL_q_z_tmp+1e-20)-tf.log(1.0/config.K))
        # Sum over timesteps and K, then average over the batch.
        self.KL = tf.reduce_mean(tf.reduce_sum(KL_tmp, [1, 2]))
        # Only the NLL is optimised; KL is exported for monitoring.
        self.cost = self.NLL
        self.final_state = self.state
        
        # Evaluation models stop here: they get no train_op, lr_update or
        # tau_update, so assign_lr / assign_tau must not be called on them.
        if not is_training:
            return
    
        # NOTE(review): tf.trainable_variables() returns every trainable
        # variable in the graph at this point, not only this model's.
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),config.max_grad_norm)
        optimizer = tf.train.GradientDescentOptimizer(self.lr)
        self.train_op = optimizer.apply_gradients(
            zip(grads, tvars),
            global_step=tf.contrib.framework.get_or_create_global_step())
    
        # Placeholders and assign ops used by assign_lr / assign_tau.
        self.new_lr = tf.placeholder(tf.float32, shape=[], name="new_learning_rate")
        self.lr_update = tf.assign(self.lr, self.new_lr)
        self.new_tau = tf.placeholder(tf.float32, shape=[], name="new_tau")
        self.tau_update = tf.assign(self.tau, self.new_tau)
        
    def assign_lr(self, session, lr_value):
        """Set the learning-rate variable to ``lr_value`` (training model only)."""
        session.run(self.lr_update, feed_dict={self.new_lr: lr_value})
    def assign_tau(self, session, tau_value):
        """Set the gumbel_softmax temperature to ``tau_value`` (training model only)."""
        session.run(self.tau_update, feed_dict={self.new_tau: tau_value})

In [None]:
import os
raw_data = reader.ptb_raw_data('./data/')
train_data, valid_data, test_data, w2id, id2w = raw_data
config = Config()

# Evaluation config: same hyper-parameters, but a single sequence fed one
# token per step, for token-by-token test evaluation.
eval_config = Config()
eval_config.batch_size, eval_config.num_steps = 1, 1

# Checkpoint directory tagged with the hidden size and number of LSTM cells.
URL = "./checkpoint/MRNN_D={}_K={}".format(config.hidden_size, config.K)
# os.makedirs(URL)

In [None]:
# Weight initializer applied to all three graphs (alternatives kept for reference).
# initializer = tf.contrib.layers.xavier_initializer()
# initializer = tf.random_uniform_initializer(-config.init_scale,config.init_scale)
initializer = tf.orthogonal_initializer()

# The training graph is built first with reuse=None so it creates the
# tf.get_variable-backed weights under "Model"; the validation and test graphs
# are then built with reuse=True to share them. The order of these blocks matters.
with tf.name_scope("Train"):
    train_input = PTBInput(config=config, data=train_data, name="TrainInput")
    with tf.variable_scope("Model", reuse=None, initializer=initializer):
        m = PTBModel(is_training=True, config=config, data=train_input)

# Validation uses the training batch_size / num_steps.
with tf.name_scope("Valid"):
    valid_input = PTBInput(config=config, data=valid_data, name="ValidInput")
    with tf.variable_scope("Model", reuse=True, initializer=initializer):
        mvalid = PTBModel(is_training=False, config=config, data=valid_input)

# The test input uses eval_config (batch_size=1, num_steps=1). PTBModel reads
# batch_size / num_steps from its `data`, so passing `config` here only supplies
# the remaining hyper-parameters (hidden_size, K, vocab_size, ...).
with tf.name_scope("Test"):
    test_input = PTBInput(config=eval_config, data=test_data, name="TestInput")
    with tf.variable_scope("Model", reuse=True, initializer=initializer):
        mtest = PTBModel(is_training=False, config=config, data=test_input)

In [None]:
# # Launch the graph
# sv = tf.train.Supervisor()
# with sv.managed_session() as sess:
#     state = sess.run(m.initial_state)
#     feed_dict = {}
#     for i, (c, h) in enumerate(m.initial_state):
#         feed_dict[c] = state[i].c
#         feed_dict[h] = state[i].h

#     print(m.initial_state)

In [None]:
# Launch the graph
# Train for up to config.max_max_epoch epochs, annealing the Gumbel temperature
# each epoch and halving the learning rate when validation perplexity worsens.
sv = tf.train.Supervisor()
with sv.managed_session() as sess:
    
    ppl_train = []
    # Seeded with inf so ppl_valid[-2] exists for the first decay comparison.
    ppl_valid = [np.inf]
    m.assign_lr(sess, config.lr)
    start_time = time.time()
    for i in range(config.max_max_epoch):
        
        # tau(i) = max(tau0 * exp(-ANNEAL_RATE * i), MIN_TEMP), i is 0-based.
        np_temp=np.maximum(config.tau0*np.exp(-config.ANNEAL_RATE*i),config.MIN_TEMP)
        m.assign_tau(sess, np_temp)
        temp_lr, temp_tau = sess.run([m.lr, m.tau])
        # Stop early once the learning rate has decayed below 1e-4.
        if temp_lr < 0.0001:
            break
        
        print("Epoch: %d Learning rate: %f, Temperature: %f" % (i + 1, temp_lr, temp_tau))
        train_perplexity = run_epoch(sess, m, eval_op=m.train_op,verbose=True)
        ppl_train.append(train_perplexity)
        print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
        
        valid_perplexity = run_epoch(sess, mvalid)
        print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))
        ppl_valid.append(valid_perplexity)

        # After max_epoch epochs, decay lr whenever validation perplexity went up.
        if (i+1) > config.max_epoch:
            if ppl_valid[-2] - ppl_valid[-1] < 0:
                m.assign_lr(sess, temp_lr * config.lr_decay)
                
#         if (i+1) > config.max_epoch:
#             m.assign_lr(sess, temp_lr * config.lr_decay)
    
    # "Training time" below includes the per-epoch validation passes.
    end_time = time.time()
    test_perplexity = run_epoch(sess, mtest)
    print("Test Perplexity: %.3f" % test_perplexity)
    print("Training time: %f, Testing time: %f" % (end_time-start_time,time.time()-end_time))
    
# NOTE(review): the save below is commented out, yet later cells restore from
# URL + "/model.ckpt"; on a fresh Run All that checkpoint may not exist -- confirm.
#     sv.saver.save(sess, URL+"/model.ckpt")

In [None]:
# Count every scalar across all trainable variables in the default graph.
total_para = sum(var.get_shape().num_elements() for var in tf.trainable_variables())
print('Total parameters: ', total_para)
print(tf.trainable_variables())

In [None]:
# Learning curves: per-epoch train (blue) and validation (red) perplexity.
# ppl_valid[0] is the np.inf placeholder seeded before training, so it is
# dropped to keep the k-th validation value at x = k, aligned with training.
valid_curve = ppl_valid[1:]

plt.figure()
plt.plot(range(1, len(ppl_train) + 1), ppl_train, 'b')
plt.title("Train perplexity")
plt.xlabel("Epoch")
plt.ylabel("Perplexity")
plt.show()

plt.figure()
plt.plot(range(1, len(valid_curve) + 1), valid_curve, 'r')
plt.title("Validation perplexity")
plt.xlabel("Epoch")
plt.ylabel("Perplexity")
plt.show()

In [None]:
# Persist the learning curves, one value per row. ppl_valid (temp2.csv) still
# begins with the np.inf placeholder used for the first lr-decay comparison.
for csv_path, curve in (("temp1.csv", ppl_train), ("temp2.csv", ppl_valid)):
    np.savetxt(csv_path, curve, delimiter=",")

In [None]:
# Reload the trained weights from URL/model.ckpt and report test perplexity.
checkpoint_path = URL + "/model.ckpt"
sv = tf.train.Supervisor()
with sv.managed_session() as sess:
    sv.saver.restore(sess, checkpoint_path)
    test_perplexity = run_epoch(sess, mtest)
    print("Test Perplexity: %.3f" % test_perplexity)

In [None]:
def softmax(x, axis=2):
    """Compute softmax values for each set of scores in ``x`` along ``axis``.

    Args:
      x: array-like of scores with at least ``axis + 1`` dimensions (or a
        negative ``axis``).
      axis: axis to normalise over; defaults to 2.

    Returns:
      ndarray with the same shape as ``x`` whose slices along ``axis`` sum to 1.
    """
    x = np.asarray(x)
    # Subtract the per-slice maximum so exp() cannot overflow to inf (-> nan).
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    # keepdims=True keeps the reduced axis so the division broadcasts per slice.
    return exps / np.sum(exps, axis=axis, keepdims=True)

# Replay the test split through the restored model one step at a time
# (mtest's input uses eval_config: batch_size=1, num_steps=1), collecting per
# token the switch sample z, its probabilities qz, the state and the input id.
model = mtest
sv = tf.train.Supervisor()
z = []
qz = []
h = []
x = []
start_time = time.time()
with sv.managed_session() as sess:
    sv.saver.restore(sess, URL+"/model.ckpt")

    state = sess.run(model.initial_state)
    # "h" is model.state, i.e. the same (c, h) LSTMStateTuple as final_state.
    # "y" and "logits" are only used by the commented-out debug prints below.
    fetches = {
        "z": model.ALL_z,
        "qz": model.ALL_qz,
        "h": model.state,
        "x": model.data.input_data,
        "y": model.data.targets,
        "logits":model.logits,
        "final_state": model.final_state,
    }
    print(model.data.epoch_size)
    for i in range(model.data.epoch_size):
        # Carry the state across steps so the test text is processed as one stream.
        feed_dict = {}
        feed_dict[model.initial_state[0]] = state.c
        feed_dict[model.initial_state[1]] = state.h
        vals = sess.run(fetches, feed_dict)
        state = vals["final_state"]
        
        
        # [0][0] selects the first (only) timestep / batch row.
        h.append(vals['h'])
        qz.append(vals['qz'][0][0])
        z.append(vals['z'][0][0])
        x.append(vals['x'][0][0])
        
        
#         print(id2w[vals['x'][0][0]])
#         print(id2w[vals['y'][0][0]])
#         print(id2w[np.argmax(vals['logits'])])
#         print('-----------------')
    end_time = time.time()
    print("time: %f" % (end_time-start_time))

In [None]:
# Turn the per-token lists into arrays. Each collected state is the (c, h)
# pair, hence the middle axis of size 2 (index 0 = c, index 1 = h).
h = np.reshape(h, [model.data.epoch_size, 2, -1])
z, qz, x = np.array(z), np.array(qz), np.array(x)

In [None]:
from sklearn.manifold import TSNE
start_time = time.time()
# Embed at most 1000 hidden states, capped by how many were collected so
# downstream code indexing range(N) stays within h_tsne.
N = min(1000, len(h))
# h[:, 1, :] is the hidden state h of each (c, h) pair.
h_tsne = TSNE(n_components=2).fit_transform(h[:N,1,:])
print("time: %f" % (time.time()-start_time))
# Name the file after the actual number of LSTM cells instead of a fixed K=4.
np.save('./tsne_MRNN_K=%d' % config.K, h_tsne)

In [None]:
fig = plt.figure()
color = ['r','g','b','y']
# Colour each embedded point by its most probable component (argmax of z).
# The argmax is computed once, only for the embedded rows, and all points are
# drawn in one scatter call rather than one call per point.
# NOTE(review): `color` has 4 entries, so this assumes config.K <= 4.
component = np.argmax(z[:len(h_tsne)], 1)
plt.scatter(h_tsne[:, 0], h_tsne[:, 1], c=[color[k] for k in component], marker='.')
plt.show()

In [None]:
from sklearn.manifold import TSNE
start_time = time.time()
# Flattening the [2, hidden] (c, h) pair per token embeds c and h jointly.
h_tsne = TSNE(n_components=2).fit_transform(np.reshape(h,[model.data.epoch_size,-1])[:N,:])
print("time: %f" % (time.time()-start_time))
# Use the actual number of LSTM cells and a distinct "_ch" suffix so this
# embedding does not overwrite the h-only embedding file.
np.save('./tsne_MRNN_K=%d_ch' % config.K, h_tsne)

In [None]:
fig = plt.figure()
color = ['r','g','b','y']
# Colour each embedded point by its most probable component (argmax of z).
# The argmax is computed once, only for the embedded rows, and all points are
# drawn in one scatter call rather than one call per point.
# NOTE(review): `color` has 4 entries, so this assumes config.K <= 4.
component = np.argmax(z[:len(h_tsne)], 1)
plt.scatter(h_tsne[:, 0], h_tsne[:, 1], c=[color[k] for k in component], marker='.')
plt.show()

In [None]:
# Heat maps over the first 100 test tokens (rows = components, columns =
# tokens): first the sampled z, then the softmax probabilities qz.
for assignments in (z, qz):
    fig = plt.figure(figsize=[15, 15])
    plt.imshow(assignments[:100, :].T)
    plt.show()