In [1]:
import pickle
import numpy as np
import tensorflow as tf
from preppy import UserPreppy
from user import UserModel
from seq2seq import Seq2SeqModel
from tensorflow.contrib.seq2seq import *
from tensorflow.python.layers.core import Dense

np.set_printoptions(threshold=np.nan)

In [2]:
tf.reset_default_graph()

In [3]:
params = {
    'embedding_size': 50, # user embedding
    'num_users': 101,
    'hidden_size': 64, # dense layer
    
    'epochs': 10,
    'batch_size': 32,
    
    'grad_clip': 5.0,
    'learning_rate': 0.001,
    
    'save_path' : './Model/User/model.ckpt'
}


Make a dataset by reading the train 

In [4]:
def expand(x):
    x['label'] = tf.expand_dims(tf.convert_to_tensor(x['label']),0)
    x['user'] = tf.expand_dims(tf.convert_to_tensor(x['user']),0)
    return x

def deflate(x):
    x['label'] = tf.squeeze(x['label'])
    x['user'] = tf.squeeze(x['user'])
    return x

def tokenizer(sentence):
    return sentence.split()

def save_params(params, path='./Model/User/params.pkl'):
    with open(path, 'wb') as out_file:
        pickle.dump(params, out_file)

def load_params(path='./Model/User/params.pkl'):
    with open(path, 'rb') as in_file:
        return pickle.load(in_file)

In [5]:
preppy = pickle.load(open('./data/user/preppy.pkl','rb'))
dataset_train = tf.data.TFRecordDataset(['./data/user/train.tfrecord']).map(preppy.parse)
dataset_val = tf.data.TFRecordDataset(['./data/user/val.tfrecord']).map(preppy.parse)

In [6]:
dataset_train.output_shapes

{'sentence': TensorShape([Dimension(None)]),
 'user': TensorShape([]),
 'label': TensorShape([])}

In [7]:
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))

In [8]:
batched_train = dataset_train.map(expand).padded_batch(32,padded_shapes={
    "sentence":tf.TensorShape([None]),
    "label":1,
    "user":1
}, drop_remainder=True).map(deflate)

batched_val = dataset_val.map(expand).padded_batch(32,padded_shapes={
    "sentence":tf.TensorShape([None]),
    "label":1,
    "user":1
}, drop_remainder=True).map(deflate)

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, batched_train.output_types, batched_train.output_shapes)

next_item = iterator.get_next()

In [9]:
batched_train.output_shapes

{'sentence': TensorShape([Dimension(32), Dimension(None)]),
 'user': TensorShape([Dimension(32)]),
 'label': TensorShape([Dimension(32)])}

In [10]:
iterator_train = batched_train.make_initializable_iterator()
iterator_val = batched_val.make_initializable_iterator()

handle_train = sess.run(iterator_train.string_handle())
handle_val = sess.run(iterator_val.string_handle())

In [11]:
seqParams = load_params('./Model/Seq2seq/params.pkl')
seqParams["vocab_size"] = len(preppy.vocab)

Seq = Seq2SeqModel(next_item, seqParams)

saver = tf.train.Saver()
saver.restore(sess, seqParams["save_path"])

params["sentence_size"] = seqParams["hidden_size"]

INFO:tensorflow:Restoring parameters from ./Model/Seq2seq/model.ckpt


In [12]:
M = UserModel(next_item, params)
sess.run(tf.global_variables_initializer())

In [13]:
sentence = np.random.rand(params["batch_size"], seqParams["hidden_size"])
sentence
print(np.shape(sentence))

(32, 64)


In [31]:
loss_train = []
loss_val = []
for epoch in range(params["epochs"]):
    print("Epoch: %d"%(epoch))
    sess.run(iterator_train.initializer)
    print("Training")
    while True:
        try:
            sentence = sess.run([Seq.encoder_state],feed_dict={handle: handle_train})
            sentence = sentence[0][1]
            _, loss = sess.run([M.train_op, M.loss],feed_dict={handle: handle_train, M.lr: params["learning_rate"], M.sentence: sentence})
            loss_train.append(loss)
            print(loss)
        except tf.errors.OutOfRangeError:
            break
        except tf.errors.DataLossError:
            break
    
    print("Validation")
    sess.run(iterator_val.initializer)
    while True:
        try:
            loss = sess.run([M.loss], feed_dict={handle: handle_val, M.sentence: sentence})
            loss_val.append(loss)
        except tf.errors.OutOfRangeError:
            break
        except tf.errors.DataLossError:
            break

print("Training and Validation Finish")

# Save Model
saver = tf.train.Saver()
saver.save(sess, params["save_path"])
save_params(params["save_path"])

print('Model Trained and Saved')

Epoch: 0
Training
0.675029
0.6670722
0.6523012
0.72999626
0.67252916
0.63465667
0.6983559
0.62789303
0.70183873
0.7200966
0.67391837
0.6800141
0.6531574
0.7153834
0.73229754
0.7568157
0.7057307
0.75013375
0.71488965
0.7100783
0.6504344
0.6899806
0.69403976
0.69795763
0.71849704
0.6684686
0.69949245
0.70484734
0.69305366
0.6815133
0.71252835
0.68626654
0.69254434
0.65513647
0.70442915
0.6716073
0.67620236
0.7194283
0.69569296
0.68630713
0.6994825
0.69218695
0.7085073
0.67646694
0.6676594
0.66750515
0.6959096
0.70726824
0.6777333
0.66397953
0.70961905
0.6864778
0.7026472
0.6707838
0.6573089
0.70112824
0.70692956
0.7142452
0.6942569
0.69936097
0.64874434
0.70411026
0.72077584
0.652691
0.7388009
0.6924424
0.7054165
0.7161866
0.67414474
0.6438148
0.67217845
0.6782994
0.6944267
0.6708162
0.6969852
0.7248253
0.7285496
0.7161125
0.6927978
0.6917727
0.70366836
0.6835671
0.6898908
0.7140877
0.7158962
0.66463137
0.7183664
0.6874913
0.7046777
0.690469
0.7161585
0.69594705
0.6834272
0.69039017
0.69

0.67900646
0.6391311
0.6943078
0.6958252
0.72457695
0.70060146
0.72195166
0.7080943
0.72410583
0.6591858
0.67370474
0.6807076
0.6843397
0.69066775
0.6771771
0.6617391
0.67397916
0.6852362
0.6796367
0.6879277
0.6891203
0.6874554
0.6423245
0.67358875
0.6587086
0.6781608
0.7302122
0.6803287
0.7076108
0.69131994
0.70499444
0.6748015
0.66290385
0.67104226
0.64070475
0.7217897
0.71922517
0.682933
0.63503206
0.7050233
0.69302446
0.7135754
0.67649794
0.6623136
0.68206656
0.6990291
0.7120718
0.66265273
0.66154486
0.6420203
0.6902468
0.72163075
0.6680323
0.7310366
0.7023934
0.71936786
0.7068995
0.65701514
0.6311375
0.6620907
0.67925394
0.6974059
0.6652703
0.6911301
0.71285224
0.70865273
0.6815282
0.697068
0.67858315
0.7053782
0.67324555
0.7136486
0.7023705
0.6990537
0.67261755
0.7044811
0.68162394
0.70274764
0.7031907
0.7095821
0.68757445
0.7021304
0.68317604
0.69697726
0.6782762
0.6954262
0.68432134
0.66174316
0.6835599
0.69032013
0.68377787
0.670705
0.6722779
0.6797426
0.7190971
0.6724335
0.68