In [1]:
import pickle
import numpy as np
import tensorflow as tf
from preppy import RankPreppy
from user import UserModel
from rank import RankModel
from seq2seq import Seq2SeqModel
from tensorflow.contrib.seq2seq import *
from tensorflow.python.layers.core import Dense

np.set_printoptions(threshold=np.nan)

In [2]:
tf.reset_default_graph()

In [3]:
params = {
    'sentence_embedding': 0,  # sentence vector size, read from seq2seq params
    'user_embeding': 0,   # user  embedding size, read from user params
    
    'p_size': 64,  # hidden layer size, according to the paper
    'r_size': 64,   # hidden layer size, according to the paper
    'f_size': 64,   # hidden layer size, according to the paper
    's_size': 64,  # hidden layer size, according to the paper
    
    'epochs': 1,
    'batch_size': 32,
    
    'grad_clip': 5.0,
    'learning_rate': 0.001,
    
    'save_path' : './Model/Rank/model.ckpt',
    'user_embedding_file': ''
}


Make a dataset by reading the train 

In [4]:
def expand(x):
    x['label'] = tf.expand_dims(tf.convert_to_tensor(x['label']),0)
    x['user'] = tf.expand_dims(tf.convert_to_tensor(x['user']),0)
    return x

def deflate(x):
    x['label'] = tf.squeeze(x['label'])
    x['user'] = tf.squeeze(x['user'])
    return x

def tokenizer(sentence):
    return sentence.split()

def save_params(params, path='./Model/Rank/params.pkl'):
    with open(path, 'wb') as out_file:
        pickle.dump(params, out_file)

def load_params(path='./Model/Rank/params.pkl'):
    with open(path, 'rb') as in_file:
        return pickle.load(in_file)

In [5]:
preppy = pickle.load(open('./data/rank/preppy.pkl','rb'))
dataset_train = tf.data.TFRecordDataset(['./data/rank/train.tfrecord']).map(preppy.parse)

In [23]:
val = pickle.load(open('./data/rank/val.pkl','rb'))
for i in range(len(val)):
    for key in val[i]:
        if key != "user":
            val[i][key] = [preppy.sentence_to_id_list(val[i][key])]
        else:
            val[i][key] = [val[i][key]]
print(val[0])

{'user': [11], 'query': [[52, 5, 101, 257, 16, 1234, 33, 720, 153, 2594, 16, 6, 225, 974, 7, 159, 309, 5, 991, 3, 4294, 5, 2423, 16, 4798, 288, 52, 489, 5, 341, 5338]], 'response_0': [[3811, 16, 1547, 1595, 337, 90, 1750, 3, 70, 81, 497, 866, 52, 174, 2551, 14]], 'response_1': [[5, 119, 3298, 38, 14, 462, 3, 1259, 261, 1482, 565, 462, 973, 819, 11, 727, 1072, 1259]], 'response_2': [[462, 973, 974, 51, 2975, 769, 3]], 'response_3': [[3267, 409, 974, 2688, 727, 1378, 4141, 81, 543, 1530, 12, 16, 2423, 158, 397, 16, 838, 119, 123, 102, 2133, 210, 3, 494, 3, 4141, 27]], 'response_4': [[462, 973, 974, 11, 727, 253, 253, 30, 727, 73, 83, 462, 975, 410]], 'response_5': [[4735, 25, 49, 621, 225, 4735, 3449, 270, 963, 6, 18, 1937, 25, 225, 174, 304, 621, 25]], 'response_6': [[225, 341, 304, 2621, 2500, 25, 740, 25, 11, 225, 2901, 288, 25, 2901, 16, 2991, 106, 277, 225, 7, 884, 25, 2621, 92, 225, 809, 304, 556, 81, 2301, 2815, 16, 3239, 16, 1520, 304, 2971, 25]], 'response_7': [[3, 52, 6, 1649, 

In [7]:
dataset_train.output_shapes

{'query': TensorShape([Dimension(None)]),
 'response': TensorShape([Dimension(None)]),
 'user': TensorShape([]),
 'label': TensorShape([])}

In [8]:
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))

In [9]:
batched_train = dataset_train.map(expand).padded_batch(32,padded_shapes={
    "query":tf.TensorShape([None]),
    "response":tf.TensorShape([None]),
    "label":1,
    "user":1
}, drop_remainder=True).map(deflate)

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, batched_train.output_types, batched_train.output_shapes)

next_item = iterator.get_next()

In [10]:
batched_train.output_shapes

{'query': TensorShape([Dimension(32), Dimension(None)]),
 'response': TensorShape([Dimension(32), Dimension(None)]),
 'user': TensorShape([Dimension(32)]),
 'label': TensorShape([Dimension(32)])}

In [11]:
iterator_train = batched_train.make_initializable_iterator()

handle_train = sess.run(iterator_train.string_handle())

In [12]:
seqParams = load_params('./Model/Seq2seq/params.pkl')
seqParams["vocab_size"] = len(preppy.vocab)

seqModel = Seq2SeqModel(seqParams)

saver = tf.train.Saver()
saver.restore(sess, seqParams["save_path"])

params["sentence_embedding"] = seqParams["hidden_size"]

INFO:tensorflow:Restoring parameters from ./Model/Seq2seq/model.ckpt


In [13]:
userParams = load_params('./Model/User/params.pkl')
params["user_embedding_file"] = userParams["embedding_path"]
params["user_embedding"] = userParams["embedding_size"]

In [14]:
with open("./Model/User/user_embedding.pkl","rb") as user_emb:
    user_embedding = pickle.load(user_emb)

In [15]:
print(np.shape(user_embedding))

(101, 50)


In [16]:
M = RankModel(params)
sess.run(tf.global_variables_initializer())

In [17]:
saver.restore(sess, params["save_path"])

INFO:tensorflow:Restoring parameters from ./Model/Rank/model.ckpt


In [24]:
loss_train = []
loss_val = []
for epoch in range(params["epochs"]):
    print("Epoch: %d"%(epoch))
    sess.run(iterator_train.initializer)
    print("Training")
    while True:
        try:
            item_dict  = sess.run(next_item,feed_dict={handle: handle_train})
            query= sess.run(seqModel.encoder_state,feed_dict={seqModel.sentence:item_dict["query"]})
            response= sess.run(seqModel.encoder_state,feed_dict={seqModel.sentence:item_dict["response"]})
            query = query[1]
            response = response[1]
            _, loss = sess.run([M.train_op, M.loss],feed_dict={
                M.lr: params["learning_rate"], 
                M.query: query,
                M.response: response,
                M.label: item_dict["label"],
                M.user: item_dict["user"]
            })
            loss_train.append(loss)
            print(loss)
        except tf.errors.OutOfRangeError:
            break
        except tf.errors.DataLossError:
            break
    val_scores = []
    print("Validation")
    predictions = []
    for i in range(len(val)):
        query = val[i]["query"]
        user = val[i]["user"]
        scores = []
        query= sess.run(seqModel.encoder_state,feed_dict={seqModel.sentence:query})
        query = query[1]
        for j in range(10):
            response = val[i]["response_"+str(j)]
            response= sess.run(seqModel.encoder_state,feed_dict={seqModel.sentence:response})
            response = response[1]
            val_score = sess.run(M.predict,feed_dict={
                M.query: query,
                M.response: response,
                M.user: user
            })
            scores.append(val_score)
        predict = np.argmax(np.array(scores))
        predictions.append(predict)
        print(predict)

print("Training and Validation Finish")

# Save Model
saver = tf.train.Saver()
saver.save(sess, params["save_path"])
save_params(params)

print('Model Trained and Saved')

Epoch: 0
Training
0.47476962
0.3276026
0.41601965
0.48160124
0.24021786
0.658562
0.43334627
0.398009
0.5676975
0.37094367
0.59196657
0.31432384
0.23204938
0.4393373
0.49548697
0.34006107
0.525295
0.445009
0.4810596
0.2571851
0.32639915
0.4679556
0.27033445
0.52282244
0.5665869
0.33226573
0.38346237
0.34524184
0.21056077
0.3890518
0.40338176
0.3719858
0.43448615
0.4934253
0.32570815
0.31917977
0.42883348
0.18388817
0.2914665
0.32172972
0.42681658
0.2906142
0.44946837
0.2807832
0.39934883
0.1729584
0.20328087
0.2965841
0.39577883
0.36500454
0.16734914
0.40122384
0.34947228
0.33299136
0.19055451
0.49160913
0.3640092
0.100351974
0.3093309
0.3120603
0.15249136
0.24589728
0.44603905
0.3807531
0.38060066
0.36607006
0.40813074
0.23560803
0.42502987
0.30694595
0.17251647
0.23285148
0.3350166
0.3263651
0.23926188
0.46863753
0.37282532
0.2286796
0.3077584
0.38507038
0.39333594
0.530054
0.29573473
0.2626981
0.39301366
0.5629115
0.38289672
0.43868953
0.370881
0.23311399
0.47641358
0.43305534
0.2984

In [28]:
print(predictions)

[5, 3, 4, 9, 5, 1, 7, 4, 2, 4, 3, 3, 9, 4, 6, 4, 6, 9, 0, 6, 0, 1, 5, 8, 7, 6, 5, 3, 3, 3, 0, 2, 0, 4, 5, 3, 8, 6, 7, 9, 6, 5, 3, 5, 4, 0, 8, 3, 2, 8, 3, 9, 0, 3, 7, 5, 6, 4, 3, 3, 5, 8, 0, 2, 6, 4, 6, 8, 0, 9, 7, 2, 7, 2, 6, 3, 8, 1, 9, 9, 3, 5, 0, 2, 6, 2, 9, 3, 5, 1, 8, 3, 4, 1, 7, 1, 0, 3, 7, 3, 6, 3, 0, 9, 9, 8, 9, 6, 8, 4, 0, 0, 3, 1, 1, 7, 6, 9, 0, 1, 8, 3, 8, 3, 0, 4, 8, 0, 6, 5, 4, 3, 3, 1, 8, 7, 7, 3, 3, 4, 4, 3, 8, 3, 5, 6, 6, 8, 3, 2, 9, 7, 4, 8, 3, 0, 0, 5, 0, 9, 7, 9, 6, 2, 9, 3, 2, 6, 9, 9, 9, 3, 2, 9, 7, 1, 9, 7, 6, 0, 1, 6, 6, 8, 7, 4, 5, 5, 5, 0, 1, 4, 7, 1, 9, 8, 4, 8, 8, 4, 3, 0, 2, 7, 2, 9, 1, 1, 3, 2, 0, 4, 2, 3, 3, 2, 2, 4, 9, 6, 6, 8, 0, 5, 6, 1, 3, 2, 9, 3, 4, 3, 8, 4, 9, 5, 5, 0, 9, 7, 5, 6, 6, 7, 3, 0, 0, 4, 8, 3, 1, 8, 8, 6, 1, 4, 2, 2, 4, 6, 7, 8, 0, 1, 8, 7, 8, 0, 8, 2, 3, 2, 5, 7, 6, 5, 0, 0, 3, 4, 2, 1, 0, 6, 0, 3, 3, 4, 9, 4, 4, 3, 0, 2, 7, 6, 0, 7, 6, 2, 7, 4, 8, 6, 4, 4, 4, 1, 1, 5, 4, 5, 3, 0, 4, 9, 7, 1, 7, 0, 4, 0, 7, 9, 6, 8, 4, 4, 3, 0, 8, 3, 1, 