Skip to content

Commit

Permalink
correct grid world
Browse files Browse the repository at this point in the history
  • Loading branch information
shehzaadzd committed May 10, 2018
1 parent c2f09dd commit 6faf567
Show file tree
Hide file tree
Showing 35 changed files with 294,518 additions and 83 deletions.
9 changes: 3 additions & 6 deletions code/model/trainer.py
Expand Up @@ -270,9 +270,7 @@ def train(self, sess):
self.path_logger_file_ = self.path_logger_file + "/" + str(self.batch_counter) + "/paths"


self.test_environment = self.dev_test_environment
self.test(sess, beam=True, print_paths=False)
self.test_environment = self.test_test_environment

self.test(sess, beam=True, print_paths=False)

logger.info('Memory usage: %s (kb)' % resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
Expand Down Expand Up @@ -473,7 +471,7 @@ def test(self, sess, beam=False, print_paths=False, save_model = True, auc = Fal
all_final_reward_20 /= total_examples
auc /= total_examples
if save_model:
if auc >= self.max_hits_at_10:
if all_final_reward_10 >= self.max_hits_at_10:
self.max_hits_at_10 = all_final_reward_10
self.save_path = self.model_saver.save(sess, self.model_dir + "model" + '.ckpt')

Expand Down Expand Up @@ -584,8 +582,7 @@ def top_k(self, scores, k):

trainer.test(sess, beam=True, print_paths=True, save_model=False)

# trainer.test_environment = trainer.dev_test_environment
# trainer.test(sess, beam=True, print_paths=True, save_model=False)

print options['nell_evaluation']
if options['nell_evaluation'] == 1:
nell_eval(path_logger_file + "/" + "test_beam/" + "pathsanswers", trainer.data_input_dir+'/sort_test.pairs' )
Expand Down
6 changes: 3 additions & 3 deletions code/options.py
Expand Up @@ -12,7 +12,7 @@ def read_options():
parser.add_argument("--input_file", default="train.txt", type=str)
parser.add_argument("--create_vocab", default=0, type=int)
parser.add_argument("--vocab_dir", default="", type=str)
parser.add_argument("--max_num_actions", default=250, type=int)
parser.add_argument("--max_num_actions", default=200, type=int)
parser.add_argument("--path_length", default=3, type=int)
parser.add_argument("--hidden_size", default=50, type=int)
parser.add_argument("--embedding_size", default=50, type=int)
Expand All @@ -28,15 +28,15 @@ def read_options():
parser.add_argument("--log_file_name", default="reward.txt", type=str)
parser.add_argument("--output_file", default="", type=str)
parser.add_argument("--num_rollouts", default=20, type=int)
parser.add_argument("--test_rollouts", default=100, type=int)
parser.add_argument("--test_rollouts", default=10, type=int)
parser.add_argument("--LSTM_layers", default=1, type=int)
parser.add_argument("--model_dir", default='', type=str)
parser.add_argument("--base_output_dir", default='', type=str)
parser.add_argument("--total_iterations", default=2000, type=int)

parser.add_argument("--Lambda", default=0.0, type=float)
parser.add_argument("--pool", default="max", type=str)
parser.add_argument("--eval_every", default=100, type=int)
parser.add_argument("--eval_every", default=500, type=int)
parser.add_argument("--use_entity_embeddings", default=0, type=int)
parser.add_argument("--train_entity_embeddings", default=0, type=int)
parser.add_argument("--train_relation_embeddings", default=1, type=int)
Expand Down
2 changes: 1 addition & 1 deletion configs/fb15k-237.sh
Expand Up @@ -2,7 +2,7 @@

data_input_dir="datasets/data_preprocessed/FB15K-237/"
vocab_dir="datasets/data_preprocessed/FB15K-237/vocab"
total_iterations=3000
total_iterations=2000
path_length=3
hidden_size=50
embedding_size=50
Expand Down
Binary file modified datasets/.DS_Store
Binary file not shown.
Binary file modified datasets/data_preprocessed/.DS_Store
Binary file not shown.
Binary file modified datasets/data_preprocessed/grid-world/.DS_Store
Binary file not shown.
Binary file modified datasets/data_preprocessed/grid-world/problem_16_10_8/.DS_Store
Binary file not shown.

0 comments on commit 6faf567

Please sign in to comment.