# In [1]:
import tensorflow as tf
import numpy as np
import time 
import os
import pickle


# Path of the pickled list that RecLSTM.test loads to look up the expected
# target for an input prefix: for every entry, all elements but the last are
# joined with '|' to form the lookup key and the last element is the target.
# NOTE(review): "taget" looks like a typo for "target", but this string is the
# file name actually opened at runtime -- confirm the on-disk name before
# changing it.
input_target_list_pkl = 'data/input_taget_list.pkl'


class RecLSTM:
    """Stacked-LSTM next-item model built with the TensorFlow 1.x graph API.

    Constructing an instance resets the default graph and builds, in order:
    fixed-shape input placeholders, a multi-layer LSTM with output dropout, a
    softmax output layer over ``num_items`` classes, a mean cross-entropy loss
    and an Adam optimizer with global-norm gradient clipping.

    ``train`` and ``load`` each create the session stored in ``self.session``;
    ``test`` reuses that session.  Call ``close`` when the model is no longer
    needed.

    Args:
        num_items: Number of distinct item ids (one-hot depth and number of
            output classes).
        num_seqs: Number of sequences per batch (first placeholder dimension).
        num_steps: Number of time steps per sequence (second placeholder
            dimension).
        lstm_size: Number of units in every LSTM layer.
        num_layers: Number of stacked LSTM layers.
        learning_rate: Learning rate passed to the Adam optimizer.
        grad_clip: Global-norm threshold used to clip the gradients.
        train_keep_prob: Dropout keep probability fed during training.
    """

    def __init__(self, num_items, num_seqs=64, num_steps=50,
                 lstm_size=128, num_layers=2, learning_rate=0.001,
                 grad_clip=5, train_keep_prob=0.5):
        self.num_items = num_items
        self.num_seqs = num_seqs
        self.num_steps = num_steps
        self.lstm_size = lstm_size
        self.num_layers = num_layers
        self.learning_rate = learning_rate
        self.grad_clip = grad_clip
        self.train_keep_prob = train_keep_prob

        # No session exists until train() or load() creates one.
        self.session = None

        tf.reset_default_graph()
        self.build_inputs()
        self.build_lstm()
        self.build_loss()
        self.build_optimizer()
        self.saver = tf.train.Saver()

    def build_inputs(self):
        """Create the input/target/keep_prob placeholders and one-hot inputs."""
        with tf.name_scope('inputs'):
            self.inputs = tf.placeholder(tf.int32, shape=(
                self.num_seqs, self.num_steps), name='inputs')
            self.targets = tf.placeholder(tf.int32, shape=(
                self.num_seqs, self.num_steps), name='targets')
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
            # TODO num_items?
            self.lstm_inputs = tf.one_hot(self.inputs, self.num_items)

    def build_lstm(self):
        """Build the stacked LSTM and the softmax layer on top of it.

        Sets ``initial_state``, ``lstm_output``, ``final_state``, ``logits``
        and ``proba_prediction``.  ``logits``/``proba_prediction`` have one row
        per (sequence, step) pair, laid out sequence after sequence.
        """
        def get_a_cell(lstm_size, keep_prob):
            # One LSTM layer whose outputs go through dropout.
            lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)
            return tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)

        with tf.name_scope('lstm'):
            cell = tf.nn.rnn_cell.MultiRNNCell(
                [get_a_cell(self.lstm_size, self.keep_prob)
                 for _ in range(self.num_layers)]
            )
            self.initial_state = cell.zero_state(self.num_seqs, tf.float32)

            self.lstm_output, self.final_state = tf.nn.dynamic_rnn(
                cell, self.lstm_inputs, initial_state=self.initial_state)
            # Kept exactly as in the original graph so that op/variable naming
            # (and therefore existing checkpoints) is unaffected.
            seq_output = tf.concat(self.lstm_output, 1)
            # Flatten (num_seqs, num_steps, lstm_size) -> (-1, lstm_size).
            x = tf.reshape(seq_output, [-1, self.lstm_size])

            with tf.variable_scope('softmax'):
                softmax_w = tf.Variable(tf.truncated_normal(
                    [self.lstm_size, self.num_items], stddev=0.1))
                softmax_b = tf.Variable(tf.zeros(self.num_items))

            self.logits = tf.matmul(x, softmax_w) + softmax_b
            self.proba_prediction = tf.nn.softmax(self.logits, name='predictions')

    def build_loss(self):
        """Define ``self.loss`` as the mean softmax cross-entropy over all steps."""
        with tf.name_scope('loss'):
            y_one_hot = tf.one_hot(self.targets, self.num_items)
            # Match the flattened (row-per-step) layout of the logits.
            y_reshaped = tf.reshape(y_one_hot, self.logits.get_shape())
            loss = tf.nn.softmax_cross_entropy_with_logits(
                logits=self.logits, labels=y_reshaped)
            self.loss = tf.reduce_mean(loss)

    def build_optimizer(self):
        """Define ``self.optimizer``: Adam applied to globally clipped gradients."""
        # use clipping gradients
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars), self.grad_clip)
        train_op = tf.train.AdamOptimizer(self.learning_rate)
        self.optimizer = train_op.apply_gradients(zip(grads, tvars))

    def close(self):
        """Close the current session, if there is one."""
        sess = getattr(self, 'session', None)
        if sess is not None:
            sess.close()
            self.session = None

    def train(self, batch_generator, max_steps, save_path, save_every_n, log_every_n):
        """Train the model and write checkpoints under ``save_path``.

        A fresh session is created and all variables are re-initialised.  The
        session is left open afterwards so that ``test`` can be called on the
        trained model; use ``close`` to release it.

        Args:
            batch_generator: Iterable yielding ``(x, y)`` pairs that are fed to
                the ``inputs`` and ``targets`` placeholders.
            max_steps: Training stops after this many batches (or earlier if
                the generator is exhausted).
            save_path: Directory in which ``lstm_model`` checkpoints are saved.
            save_every_n: Save a checkpoint every ``save_every_n`` steps.
            log_every_n: Print step, loss and batch time every ``log_every_n``
                steps.
        """
        # Release any session left over from a previous train()/load() call.
        self.close()
        self.session = tf.Session()
        # Deliberately not used as a context manager: leaving the ``with``
        # block would close the session and make a later test() call fail.
        sess = self.session
        sess.run(tf.global_variables_initializer())

        checkpoint_prefix = os.path.join(save_path, 'lstm_model')
        step = 0
        # The recurrent state is carried over from one batch to the next.
        new_state = sess.run(self.initial_state)
        for x, y in batch_generator:
            step += 1
            start = time.time()
            feed = {self.inputs: x,
                    self.targets: y,
                    self.keep_prob: self.train_keep_prob,
                    self.initial_state: new_state}
            batch_loss, new_state, _ = sess.run([self.loss,
                                                 self.final_state,
                                                 self.optimizer],
                                                feed_dict=feed)
            end = time.time()
            if step % log_every_n == 0:
                print('step: {}/{}... '.format(step, max_steps),
                      'loss: {:.4f}... '.format(batch_loss),
                      '{:.4f} sec/batch'.format(end - start))
            if step % save_every_n == 0:
                self.saver.save(sess, checkpoint_prefix, global_step=step)
            if step >= max_steps:
                break
        # Always leave a checkpoint for the last completed step.
        self.saver.save(sess, checkpoint_prefix, global_step=step)

    @staticmethod
    def _load_input_target_map(path):
        """Load the pickled list at ``path`` into a ``prefix key -> target`` dict."""
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as rf:
            input_target_list = pickle.load(rf)
        # Key: all elements but the last joined by '|'; value: the last element.
        return {'|'.join(str(v) for v in item[:-1]): item[-1]
                for item in input_target_list}

    @staticmethod
    def _prefix_key(sequence):
        """Return ``(key, last_index)`` for the items before the first 0.

        ``key`` is the '|'-joined prefix and ``last_index`` is the position of
        its last item (``-1`` when the sequence starts with 0 or is empty).
        """
        prefix = []
        for item_id in sequence:
            if item_id == 0:
                break
            prefix.append(str(item_id))
        return '|'.join(prefix), len(prefix) - 1

    def test(self, test_generator, item_size, max_steps=100, input_target_path=None):
        """Evaluate predictions for the first sequence of each test batch.

        For every batch, the items of its first sequence that precede the first
        0 form the lookup key for the expected target.  The probabilities
        predicted at the last of those positions are compared with that
        target: the method prints the predicted item, the probability assigned
        to the true item and how many items received a higher probability.
        After the loop it prints the number of exact hits and the sorted list
        of those "items ranked higher" counts.

        Args:
            test_generator: Iterable yielding batches fed to ``inputs``.
            item_size: Unused; kept for interface compatibility.
            max_steps: Maximum number of batches to evaluate.
            input_target_path: Optional path of the pickled input/target list;
                defaults to the module-level ``input_target_list_pkl``.

        Raises:
            RuntimeError: If neither ``train`` nor ``load`` has created a
                session yet.
            KeyError: If a prefix key is missing from the input/target list.
        """
        sess = getattr(self, 'session', None)
        if sess is None:
            raise RuntimeError(
                'No active session: call train() or load() before test().')
        if input_target_path is None:
            input_target_path = input_target_list_pkl
        input_target_map = self._load_input_target_map(input_target_path)

        new_state = sess.run(self.initial_state)
        hit = 0
        higher_counts = []
        for step, batch in enumerate(test_generator, start=1):
            key, pred_index = self._prefix_key(batch[0])
            feed = {self.inputs: batch,
                    self.keep_prob: 1,
                    self.initial_state: new_state}
            preds, new_state = sess.run([self.proba_prediction, self.final_state],
                                        feed_dict=feed)
            # Rows of ``preds`` are laid out sequence after sequence, so row
            # ``pred_index`` is step ``pred_index`` of the first sequence.
            probs = preds[pred_index]
            predicted_item = np.argmax(probs)

            print('----------------input----------------')
            print(batch[0])
            print('----------------output---------------')
            print(predicted_item)
            print('pred_output: {}'.format(probs[predicted_item]))
            print('----------------truth----------------')
            truth = input_target_map[key]
            print(truth)
            # Targets may be stored as strings; convert once and use the same
            # integer for both indexing and the hit comparison.
            truth_index = int(truth)
            pred_truth = probs[truth_index]
            print(pred_truth)
            bigger_num = int(np.sum(probs > pred_truth))
            print('{} bigger than pred_truth'.format(bigger_num))
            higher_counts.append(bigger_num)
            if int(predicted_item) == truth_index:
                hit += 1
            if step >= max_steps:
                break
        print(hit)
        higher_counts.sort()
        print(higher_counts)

    def load(self, checkpoint):
        """Create a new session and restore all variables from ``checkpoint``."""
        # Release any session left over from a previous train()/load() call.
        self.close()
        self.session = tf.Session()
        self.saver.restore(self.session, checkpoint)
        print('Restored from: {}'.format(checkpoint))