In [1]:
# import tensorflow & libraries 

import tensorflow as tf
import numpy as np

from tensorflow.contrib.rnn import LSTMCell, MultiRNNCell, DropoutWrapper

In [2]:
## WLM models 

class ptb_wlm(object): 
    """Word-level LSTM language model (TensorFlow 1.x graph mode).

    Building an instance adds the whole network, loss and training op to the
    current default graph under ``tf.variable_scope(name)``.

    Args:
        seq_len: number of time steps per example (second dim of ``x`` / ``y``).
        LSTM_dim: hidden size of every LSTM layer; also used here as the
            word-embedding size.
        n_label: vocabulary size (number of output classes).
        dropout_ratio: interpreted as the *drop* probability applied to each
            LSTM layer's output while ``phase`` is True (keep_prob = 1 - ratio).
            NOTE(review): the original model body was missing from the
            notebook, so this interpretation is an assumption -- confirm.
        clip_norm: global-norm threshold used for gradient clipping.
        name: variable scope name (also determines checkpoint variable names).
        n_layers: number of stacked LSTM layers. Must equal the leading
            dimension of the array fed to ``state`` (the training code in this
            notebook feeds 2).

    Attributes (graph endpoints):
        x, y: int32 placeholders ``[batch, seq_len]`` for inputs / targets.
        phase: bool scalar placeholder, True while training (enables dropout).
        lr: float32 scalar placeholder, learning rate for SGD.
        state: float32 placeholder ``[n_layers, 2, batch, LSTM_dim]`` holding
            the initial (c, h) pair of every layer.
        last_state: tensor with the same layout as ``state`` holding the LSTM
            state after the last time step, so it can be fed back as ``state``.
        predict: argmax word id per position, ``[batch, seq_len]``.
        loss: mean softmax cross-entropy over all positions.
        train_op: one clipped-gradient SGD step.
        saver: ``tf.train.Saver`` over the variables present at build time.
    """
    def __init__(self, seq_len, LSTM_dim = 250, n_label = 10000, dropout_ratio = 0.5, clip_norm = 1.0, name = 'ptb_wlm', n_layers = 2):
        self.seq_len = seq_len
        self.name = name
        self.LSTM_dim = LSTM_dim
        self.n_label = n_label
        self.dropout_ratio = dropout_ratio
        self.clip_norm = clip_norm
        self.n_layers = n_layers

        with tf.variable_scope(self.name):                                                                
            ## construct networks 
            
            # placeholders 
            self.x = tf.placeholder(tf.int32, [None, self.seq_len], name = 'x') # input 
            self.y = tf.placeholder(tf.int32, [None, self.seq_len], name = 'y') # labels 
            self.phase = tf.placeholder(tf.bool, [], name = 'phase') # train or inference
            self.lr = tf.placeholder(tf.float32, [], name = 'lr') # learning rate, for lr scheduling 
            self.state = tf.placeholder(tf.float32, [None, 2, None, self.LSTM_dim], name = 'state')
            
            # word embedding: [batch, seq_len] -> [batch, seq_len, LSTM_dim]
            embedding = tf.get_variable('embedding', [self.n_label, self.LSTM_dim], dtype = tf.float32)
            embedded = tf.nn.embedding_lookup(embedding, self.x)

            # dropout is only active while training; at inference keep everything
            keep_prob = tf.cond(
                self.phase,
                lambda: tf.constant(1.0 - self.dropout_ratio, dtype = tf.float32),
                lambda: tf.constant(1.0, dtype = tf.float32),
            )

            # stacked LSTM with dropout on each layer's output
            cells = [
                DropoutWrapper(LSTMCell(self.LSTM_dim), output_keep_prob = keep_prob)
                for _ in range(self.n_layers)
            ]
            stacked_cell = MultiRNNCell(cells)

            # The leading dim of `state` is unknown statically, so `num` must be
            # given explicitly for unstack to work.
            per_layer_state = tf.unstack(self.state, num = self.n_layers, axis = 0)
            initial_state = tuple(
                tf.contrib.rnn.LSTMStateTuple(c = layer_state[0], h = layer_state[1])
                for layer_state in per_layer_state
            )

            # lstm_output: output of lstm layers, [batch, seq_len, LSTM_dim]
            lstm_output, final_state = tf.nn.dynamic_rnn(
                stacked_cell, embedded, initial_state = initial_state, dtype = tf.float32
            )

            # Re-pack the final state into the same [n_layers, 2, batch, dim]
            # layout as the `state` placeholder so callers can feed it back.
            self.last_state = tf.stack(
                [tf.stack([layer_state.c, layer_state.h], axis = 0) for layer_state in final_state],
                axis = 0,
            )

            logit = tf.layers.dense(lstm_output, self.n_label)
            self.predict = tf.argmax(logit, -1)
            # softmax loss             
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits = logit, labels = self.y)
            loss = tf.reduce_mean(loss)
            self.loss = loss
            
            # plain SGD with global-norm gradient clipping
            opt = tf.train.GradientDescentOptimizer(self.lr)
            gradients, variables = zip(*opt.compute_gradients(loss))
            gradients, _ = tf.clip_by_global_norm(gradients, clip_norm)
            train_op = opt.apply_gradients(zip(gradients, variables))
            self.train_op = train_op
        
            self.saver = tf.train.Saver()
            
            
# Smoke-test the constructor in a throwaway graph. Building it in the
# process-wide default graph would leave variables behind under the same
# scope name, which is expected to make a second run of this cell fail
# ("variable already exists") and clutters the default graph with dead ops.
with tf.Graph().as_default():
    model = ptb_wlm(20)


In [3]:
## training code 
import os
import time 

#fix random seed
# NOTE: only numpy is seeded here. The TensorFlow graph-level seed must be set
# on the graph the model is built in, so it is set inside the
# `tf.Graph().as_default()` block below.
seed = 123
np.random.seed(seed)

#model parameters
lstm_dim = 256
dr_ratio = 0.5
n_state_layers = 2  # leading dim of the LSTM state array fed to model.state

#for early stopping technique
initial_lr = 1.0
decay_factor = 0.5
decay_time = 3
patience = 2
max_epoch = 100

#batch size for train & evaluate
train_batch = 50
test_batch = 20
sequence_length = 30

#paths for data
data_path = './ptb_data/'
result_path = './results/'
model_name = 'ptb_wlm'

result_file = result_path + 'check_point_' + model_name

# The saver cannot write into a directory that does not exist.
os.makedirs(result_path, exist_ok = True)


#internal args for early stopping
cur_lr = initial_lr
cur_patience = 0
cur_decay_time = 0
best_valid_loss = 10000
train_phase = True
eval_phase = False

train_loss_hist = []
valid_loss_hist = []

#load dataset
from dataset import PTBDataset
dataset = PTBDataset(train_batch, sequence_length, data_path = data_path, seed = seed)


def run_eval_epoch(sess, model, dataset, mode, batch_size):
    """Run one full pass over `mode` ('valid' or 'test') without training.

    The LSTM state starts at zero and is carried from batch to batch.
    Returns the mean of the per-batch losses as a float32 scalar.
    """
    dataset.set_batch_size(batch_size)
    dataset.set_mode(mode)

    batch_losses = []
    with tf.name_scope(mode):
        state = np.zeros([n_state_layers, 2, batch_size, lstm_dim], dtype = 'float32')
        while dataset.iter_flag():
            batch_x, batch_y = dataset.get_data()
            cur_loss, state = sess.run(
                [model.loss, model.last_state],
                feed_dict={model.x: batch_x,
                           model.y: batch_y,
                           model.state: state,
                           model.phase: eval_phase}
            )
            batch_losses.append(cur_loss)
    return np.mean(np.asarray(batch_losses, dtype='float32'))


with tf.Graph().as_default():
    # Graph-level seed: must be set on *this* graph, before any op is created.
    tf.set_random_seed(seed)

    # graph construct
    model = ptb_wlm(
            name = model_name,
            seq_len = sequence_length,
            LSTM_dim = lstm_dim,
            dropout_ratio = dr_ratio,
            )
    #set gpu usage to allow_growth
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    #begin training
    epoch = 0
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        status = 'keep_train'
        stopped_early = False      # True once the 'end_train' branch has run
        checkpoint_saved = False   # True once result_file has been written
        #epoch loop; `status` was decided at the end of the previous epoch
        while epoch < max_epoch:
            start_time = time.time()

            #early stopping control
            if status == 'end_train':
                time.sleep(1)
                # roll back to the best checkpoint and export it as the final model
                model.saver.restore(sess, result_file)
                model.saver.save(sess, result_path + model_name)
                stopped_early = True
                break
            elif status == 'change_lr':
                time.sleep(1)
                # roll back to the best checkpoint and continue with a smaller lr
                model.saver.restore(sess, result_file)
                cur_lr *= decay_factor
                cur_patience = 0
                cur_decay_time +=1
                
                print('lr changed to : ', cur_lr)
                print('current decay : ', cur_decay_time, " / ", decay_time)
            elif status == 'save_param':
                cur_patience = 0
                model.saver.save(sess, result_file)
                checkpoint_saved = True
            elif status == 'keep_train':
                cur_patience +=1
            else:
                raise NotImplementedError


            print('--------', epoch, '/', max_epoch, '--------')
            #train
            dataset.set_batch_size(train_batch)
            dataset.set_mode('train')
            epoch_loss = []
            with tf.name_scope('train'):
                state = np.zeros([n_state_layers, 2, train_batch, lstm_dim], dtype = 'float32')
                
                while dataset.iter_flag():
                    batch_x, batch_y = dataset.get_data()
                    cur_loss, state, _ = sess.run(
                        [model.loss, model.last_state, model.train_op],
                        feed_dict = {model.x: batch_x,
                                     model.y: batch_y,
                                     model.phase: train_phase,
                                     model.state: state, 
                                     model.lr: cur_lr}
                    )
                    epoch_loss.append(cur_loss)

                epoch_loss = np.mean(np.asarray(epoch_loss, dtype='float32'))
                train_loss_hist.append(epoch_loss)

            # evaluation
            epoch_loss = run_eval_epoch(sess, model, dataset, 'valid', test_batch)
            valid_loss_hist.append(epoch_loss)

            #early stopping
            if epoch_loss >= best_valid_loss:
                if cur_patience == patience:
                    if cur_decay_time == decay_time:
                        status = 'end_train'
                    else:
                        status = 'change_lr'
                else:
                    status = 'keep_train'
            else:
                status = 'save_param'
                best_valid_loss = epoch_loss

            end_time = time.time()


            print('train loss - ', train_loss_hist[-1], ' | ppl - ', np.exp(train_loss_hist[-1]))
            print('valid loss - ', valid_loss_hist[-1], ' | ppl - ', np.exp(valid_loss_hist[-1]))
            print('status : ', status, ', training time : ', end_time - start_time)
            epoch+=1

        # If the loop ended because max_epoch was reached, the status decided
        # after the last epoch was never acted on. Apply it now so that the
        # checkpoint on disk and the weights used for the final test are the
        # best ones seen on the validation set.
        if not stopped_early and epoch > 0:
            if status == 'save_param':
                model.saver.save(sess, result_file)
                checkpoint_saved = True
            elif checkpoint_saved:
                model.saver.restore(sess, result_file)
            if checkpoint_saved:
                model.saver.save(sess, result_path + model_name)


        #final test
        # evaluation
        start_time = time.time()
        epoch_loss = run_eval_epoch(sess, model, dataset, 'test', test_batch)
        end_time = time.time()
        print('\n\n--------final result for test data--------')
        print('loss - ', epoch_loss, ' | ppl - ', np.exp(epoch_loss))
        print('test set inference time : ', end_time - start_time)
        



-------- 0 / 100 --------
train loss -  6.78307  | ppl -  882.775
valid loss -  6.5555  | ppl -  703.102
status :  save_param , training time :  29.919086694717407
-------- 1 / 100 --------
train loss -  6.58648  | ppl -  725.226
valid loss -  6.53572  | ppl -  689.332
status :  save_param , training time :  28.839855909347534
-------- 2 / 100 --------
train loss -  6.53509  | ppl -  688.898
valid loss -  6.43803  | ppl -  625.173
status :  save_param , training time :  27.04491949081421
-------- 3 / 100 --------
train loss -  6.41239  | ppl -  609.348
valid loss -  6.29264  | ppl -  540.578
status :  save_param , training time :  27.29948377609253
-------- 4 / 100 --------
train loss -  6.28777  | ppl -  537.952
valid loss -  6.17401  | ppl -  480.109
status :  save_param , training time :  28.188401460647583
-------- 5 / 100 --------
train loss -  6.1947  | ppl -  490.142
valid loss -  6.09798  | ppl -  444.957
status :  save_param , training time :  28.537468671798706
-------- 6 / 1

train loss -  5.03175  | ppl -  153.201
valid loss -  5.06884  | ppl -  158.99
status :  save_param , training time :  26.93541979789734
-------- 51 / 100 --------
train loss -  5.02312  | ppl -  151.885
valid loss -  5.05146  | ppl -  156.251
status :  save_param , training time :  26.004714488983154
-------- 52 / 100 --------
train loss -  5.01431  | ppl -  150.553
valid loss -  5.04999  | ppl -  156.02
status :  save_param , training time :  27.0794415473938
-------- 53 / 100 --------
train loss -  5.00615  | ppl -  149.328
valid loss -  5.04273  | ppl -  154.893
status :  save_param , training time :  27.22673511505127
-------- 54 / 100 --------
train loss -  4.99932  | ppl -  148.312
valid loss -  5.03795  | ppl -  154.154
status :  save_param , training time :  26.808645725250244
-------- 55 / 100 --------
train loss -  4.99056  | ppl -  147.019
valid loss -  5.03528  | ppl -  153.742
status :  save_param , training time :  28.347527742385864
-------- 56 / 100 --------
train loss



--------final result for test data--------
loss -  4.83679  | ppl -  126.064
test set inference time :  1.9982125759124756


In [4]:
## generate 

# Prompt used to warm up the LSTM state; any number of in-vocabulary words
# separated by single spaces works (an unknown word raises KeyError).
init_words = "the company said a word"
max_words = 50  # stop once prompt + generated steps reach this count
word_to_id = dataset.word_to_id
id_to_word = dataset.id_to_word
init_ids = [word_to_id[word] for word in init_words.split(' ')]
gen_words = init_words
gen_ids = list(init_ids)  # copy, so appending generated ids leaves the prompt untouched

# zero state for batch size 1; layout matches the array fed during training
init_state = np.zeros([2, 2, 1, lstm_dim], dtype = 'float32')


def predict_next(sess, model, word_id, state):
    """Feed a single word id and return (prediction array, new LSTM state)."""
    cur_x = np.reshape(np.asarray([word_id], dtype = 'int32'), [1,1])
    return sess.run([model.predict, model.last_state],
            feed_dict = {
                model.x : cur_x,
                model.state: state,
                model.phase: eval_phase
                }
            )


with tf.Graph().as_default():
    # same variables as the trained model, but unrolled for one step at a time
    model = ptb_wlm(
            name = model_name,
            seq_len = 1,
            LSTM_dim = lstm_dim,
            dropout_ratio = dr_ratio,
            )
    #set gpu usage to allow_growth
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess, result_file)

        # warm-up: run the prompt through the network, printing what the model
        # would have predicted after each prompt word
        state = init_state
        for word_id in init_ids:
            cur_predict, state = predict_next(sess, model, word_id, state)
            print(id_to_word[cur_predict[0][0]])

        # the prediction after the last prompt word is the first generated word
        gen_ids.append(cur_predict[0][0])
        last_word = id_to_word[gen_ids[-1]]
        gen_words = gen_words + ' ' + last_word
        
        # greedy decoding: feed the previous prediction back in until the
        # end-of-sentence token appears or the length limit is reached
        counter = len(init_ids)
        while last_word != '<eos>' and counter < max_words:
            cur_predict, state = predict_next(sess, model, gen_ids[-1], state)
            gen_ids.append(cur_predict[0][0])
            last_word = id_to_word[gen_ids[-1]]
            gen_words = gen_words + ' ' + last_word
            counter+=1

        print('-----generated ids-----')
        print(gen_ids)
        print('-----generated sentences-----')
        print(gen_words)

INFO:tensorflow:Restoring parameters from ./results/check_point_ptb_wlm
<unk>
's
it
<unk>
of
-----generated ids-----
[0, 37, 15, 6, 2079, 4, 0, 1, 4, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
-----generated sentences-----
the company said a word of the <unk> of the <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>
