Skip to content

Commit

Permalink
Add test and validation results
Browse files Browse the repository at this point in the history
  • Loading branch information
siddk committed Jul 19, 2017
1 parent 8c409fb commit d00af39
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 10 deletions.
16 changes: 16 additions & 0 deletions model/relation_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, S, S_len, Q, Q_len, A, word2id, a_word2id, embed_size=32, lst

if self.restore:
# Restore from Checkpoint
self.session.run(tf.global_variables_initializer())
self.saver.restore(self.session, self.restore)
else:
# Initialize all Variables
Expand Down Expand Up @@ -181,3 +182,18 @@ def fit(self, epochs):
self.session.run(self.epoch_increment)
self.saver.save(self.session, self.ckpt_dir + "model.ckpt", global_step=self.epoch_step)

def eval(self, evalS, evalS_len, evalQ, evalQ_len, evalA):
    """
    Compute the accuracy of the current model on one evaluation set.

    The five inputs are bound to the model's input placeholders and the accuracy op is run once
    over the whole set (no mini-batching).

    :param evalS: bAbI stories, fed to self.XS [N, story_len, max_s]
    :param evalS_len: Story line lengths, fed to self.XS_len [N, story_len]
    :param evalQ: Queries, fed to self.XQ [N, max_q]
    :param evalQ_len: Query lengths, fed to self.XQ_len [N]
    :param evalA: Query answers, fed to self.YA [N]
    :return: Whatever the session returns for self.accuracy (the accuracy, as a float)
    """
    # Map every input placeholder to its evaluation data.
    feed = {
        self.XS: evalS,
        self.XS_len: evalS_len,
        self.XQ: evalQ,
        self.XQ_len: evalQ_len,
        self.YA: evalA,
    }
    return self.session.run(self.accuracy, feed_dict=feed)

69 changes: 59 additions & 10 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
FLAGS = tf.app.flags.FLAGS

# Run Parameters
tf.app.flags.DEFINE_string("mode", "train", "Mode to run - choose from [train, valid, test].")
tf.app.flags.DEFINE_string("mode", "valid", "Mode to run - choose from [train, valid, test].")
tf.app.flags.DEFINE_string("ckpt_dir", "ckpt/", "Directory to store checkpoints, log files.")

# Eval Mode Parameters
tf.app.flags.DEFINE_integer("task", 1, "Task to evaluate trained model on.")
tf.app.flags.DEFINE_integer("task", 1, "Task to evaluate trained model on. [0] - Means evaluate on all!")


def main(_):
Expand All @@ -32,20 +32,69 @@ def main(_):

# Train for 5 Epochs
print '[*] Training Model!'
rn.fit(epochs=5)
rn.fit(epochs=50)

elif FLAGS.mode == "valid":
S, S_len, Q, Q_len, A, _, _ = parse("valid",
pik_path=os.path.join(FLAGS.ckpt_dir, 'valid', 'valid_%d.pik' % FLAGS.task),
voc_path=os.path.join(FLAGS.ckpt_dir, 'voc.pik'))
# Restore Model
print '[*] Restoring Model!'
S, S_len, Q, Q_len, A, word2id, a_word2id = parse("train",
pik_path=os.path.join(FLAGS.ckpt_dir, 'train', 'train.pik'),
voc_path=os.path.join(FLAGS.ckpt_dir, 'voc.pik'))
rn = RelationNetwork(S, S_len, Q, Q_len, A, word2id, a_word2id,
restore=tf.train.latest_checkpoint(os.path.join(FLAGS.ckpt_dir, 'ckpts')))

if FLAGS.task == 0:
print '[*] Validating on all Tasks!'
for task in range(1, 21):
print '[*] Loading Task %d!' % task
S, S_len, Q, Q_len, A, _, _ = parse("valid",
pik_path=os.path.join(FLAGS.ckpt_dir, 'valid',
'valid_%d.pik' % task),
voc_path=os.path.join(FLAGS.ckpt_dir, 'voc.pik'), task_id=task)
accuracy = rn.eval(S, S_len, Q, Q_len, A)
print 'Task %d\tAccuracy: %.3f' % (task, accuracy)

else:
task = FLAGS.task
print '[*] Validating on Task %d' % task
S, S_len, Q, Q_len, A, _, _ = parse("valid",
pik_path=os.path.join(FLAGS.ckpt_dir, 'valid', 'valid_%d.pik' % task),
voc_path=os.path.join(FLAGS.ckpt_dir, 'voc.pik'), task_id=task)
accuracy = rn.eval(S, S_len, Q, Q_len, A)
print 'Task %d\tAccuracy: %.3f' % (task, accuracy)

elif FLAGS.mode == "test":
S, S_len, Q, Q_len, A, _, _ = parse("test",
pik_path=os.path.join(FLAGS.ckpt_dir, 'test', 'test_%d.pik' % FLAGS.task),
voc_path=os.path.join(FLAGS.ckpt_dir, 'voc.pik'))
# Restore Model
print '[*] Restoring Model!'
S, S_len, Q, Q_len, A, word2id, a_word2id = parse("train",
pik_path=os.path.join(FLAGS.ckpt_dir, 'train', 'train.pik'),
voc_path=os.path.join(FLAGS.ckpt_dir, 'voc.pik'))
rn = RelationNetwork(S, S_len, Q, Q_len, A, word2id, a_word2id,
restore=tf.train.latest_checkpoint(os.path.join(FLAGS.ckpt_dir, 'ckpts')))

if FLAGS.task == 0:
print '[*] Testing on all Tasks!'
for task in range(1, 21):
print '[*] Loading Task %d!' % task
S, S_len, Q, Q_len, A, _, _ = parse("test",
pik_path=os.path.join(FLAGS.ckpt_dir, 'test',
'test_%d.pik' % task),
voc_path=os.path.join(FLAGS.ckpt_dir, 'voc.pik'), task_id=task)
accuracy = rn.eval(S, S_len, Q, Q_len, A)
print 'Task %d\tAccuracy: %.3f' % (task, accuracy)

else:
task = FLAGS.task
print '[*] Testing on Task %d' % task
S, S_len, Q, Q_len, A, _, _ = parse("test",
pik_path=os.path.join(FLAGS.ckpt_dir, 'test', 'test_%d.pik' % task),
voc_path=os.path.join(FLAGS.ckpt_dir, 'voc.pik'), task_id=task)
accuracy = rn.eval(S, S_len, Q, Q_len, A)
print 'Task %d\tAccuracy: %.3f' % (task, accuracy)
else:
print "Unsupported Mode, use one of [train, valid, test]"
raise UserWarning


if __name__ == "__main__":
tf.app.run()
tf.app.run()
42 changes: 42 additions & 0 deletions test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
[*] Restoring Model!
[*] Testing on all Tasks!
[*] Loading Task 1!
Task 1 Accuracy: 0.990
[*] Loading Task 2!
Task 2 Accuracy: 0.338
[*] Loading Task 3!
Task 3 Accuracy: 0.336
[*] Loading Task 4!
Task 4 Accuracy: 0.993
[*] Loading Task 5!
Task 5 Accuracy: 0.939
[*] Loading Task 6!
Task 6 Accuracy: 0.984
[*] Loading Task 7!
Task 7 Accuracy: 0.935
[*] Loading Task 8!
Task 8 Accuracy: 0.989
[*] Loading Task 9!
Task 9 Accuracy: 0.991
[*] Loading Task 10!
Task 10 Accuracy: 0.953
[*] Loading Task 11!
Task 11 Accuracy: 0.964
[*] Loading Task 12!
Task 12 Accuracy: 0.984
[*] Loading Task 13!
Task 13 Accuracy: 0.987
[*] Loading Task 14!
Task 14 Accuracy: 0.833
[*] Loading Task 15!
Task 15 Accuracy: 0.994
[*] Loading Task 16!
Task 16 Accuracy: 0.444
[*] Loading Task 17!
Task 17 Accuracy: 0.771
[*] Loading Task 18!
Task 18 Accuracy: 0.930
[*] Loading Task 19!
Task 19 Accuracy: 0.099
[*] Loading Task 20!
Task 20 Accuracy: 0.999
42 changes: 42 additions & 0 deletions valid.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
[*] Restoring Model!
[*] Validating on all Tasks!
[*] Loading Task 1!
Task 1 Accuracy: 0.995
[*] Loading Task 2!
Task 2 Accuracy: 0.333
[*] Loading Task 3!
Task 3 Accuracy: 0.307
[*] Loading Task 4!
Task 4 Accuracy: 0.995
[*] Loading Task 5!
Task 5 Accuracy: 0.948
[*] Loading Task 6!
Task 6 Accuracy: 0.984
[*] Loading Task 7!
Task 7 Accuracy: 0.955
[*] Loading Task 8!
Task 8 Accuracy: 0.975
[*] Loading Task 9!
Task 9 Accuracy: 0.994
[*] Loading Task 10!
Task 10 Accuracy: 0.952
[*] Loading Task 11!
Task 11 Accuracy: 0.977
[*] Loading Task 12!
Task 12 Accuracy: 0.978
[*] Loading Task 13!
Task 13 Accuracy: 0.977
[*] Loading Task 14!
Task 14 Accuracy: 0.833
[*] Loading Task 15!
Task 15 Accuracy: 1.000
[*] Loading Task 16!
Task 16 Accuracy: 0.439
[*] Loading Task 17!
Task 17 Accuracy: 0.754
[*] Loading Task 18!
Task 18 Accuracy: 0.951
[*] Loading Task 19!
Task 19 Accuracy: 0.119
[*] Loading Task 20!
Task 20 Accuracy: 1.000

0 comments on commit d00af39

Please sign in to comment.