fix saving variables
jindrahelcl committed Jun 23, 2016
1 parent 8c35a34 commit add9bdc
Showing 1 changed file with 59 additions and 37 deletions.
96 changes: 59 additions & 37 deletions learning_utils.py
@@ -103,7 +103,7 @@ def training_loop(sess, saver,
                   runner,
                   test_datasets=[],
                   save_n_best_vars=1,
-                  link_best_vars=None,
+                  link_best_vars="/tmp/variables.data.best",
                   vars_prefix="/tmp/variables.data",
                   initial_variables=None,
                   logging_period=20,
@@ -170,23 +170,26 @@ def training_loop(sess, saver,
     elif save_n_best_vars > 1:
         variables_files = ['{}.{}'.format(vars_prefix, i)
                            for i in range(save_n_best_vars)]
-        var_file_i = 0
+
+    if minimize_metric:
+        saved_scores = [np.inf for _ in range(save_n_best_vars)]
+        best_score = np.inf
+    else:
+        saved_scores = [-np.inf for _ in range(save_n_best_vars)]
+        best_score = -np.inf
 
     saver.save(sess, variables_files[0])
 
-    if link_best_vars is not None:
-        if os.path.islink(link_best_vars):
-            # if overwriting output dir
-            os.unlink(link_best_vars)
-        os.symlink(variables_files[0], link_best_vars)
+    if os.path.islink(link_best_vars):
+        # if overwriting output dir
+        os.unlink(link_best_vars)
+    os.symlink(variables_files[0], link_best_vars)
 
     if log_directory:
         log("Initializing TensorBoard summary writer.")
         tb_writer = tf.train.SummaryWriter(log_directory, sess.graph)
         log("TensorBoard writer initialized.")
 
-    best_score = 0.0
-    if minimize_metric:
-        best_score = float('inf')
+
     best_score_epoch = 0
     best_score_batch_no = 0
@@ -219,49 +222,67 @@ def training_loop(sess, saver,
             else:
                 trainer.run(sess, batch_feed_dict, summary=False)
 
-
             if step % validation_period == validation_period - 1:
-                decoded_val_sentences, decoded_raw_val_sentences, val_evaluation = \
-                    run_on_dataset(sess, runner, all_coders, decoder, val_dataset,
-                                   evaluation_functions, postprocess, write_out=False)
+                decoded_val_sentences, decoded_raw_val_sentences, \
+                    val_evaluation = run_on_dataset(
+                        sess, runner, all_coders, decoder, val_dataset,
+                        evaluation_functions, postprocess, write_out=False)
 
                 this_score = val_evaluation[evaluation_functions[-1]]
-                best_var_file = None
 
-                if (minimize_metric and this_score < best_score) or \
-                        (not minimize_metric and this_score > best_score):
+                def is_better(score1, score2, minimize):
+                    if minimize:
+                        return score1 < score2
+                    else:
+                        return score1 > score2
+
+                def argworst(scores, minimize):
+                    if minimize:
+                        return np.argmax(scores)
+                    else:
+                        return np.argmin(scores)
+
+                if is_better(this_score, best_score, minimize_metric):
                     best_score = this_score
                     best_score_epoch = i + 1
                     best_score_batch_no = batch_n
 
-                    ## increment var file index here or restoring must decrement
-                    var_file_i = (var_file_i + 1) % save_n_best_vars
-                    best_var_file = variables_files[var_file_i]
-                    saver.save(sess, best_var_file)
+                worst_index = argworst(saved_scores, minimize_metric)
+                worst_score = saved_scores[worst_index]
 
-                    ## TODO make link best vars never be none
-                    if save_n_best_vars > 1 and link_best_vars is not None:
-                        ## make the symlink point to the best vars
-                        os.unlink(link_best_vars)
-                        os.symlink(variables_files[var_file_i], link_best_vars)
+                if is_better(this_score, worst_score, minimize_metric):
+                    # we need to save this score instead of the worst score
+                    worst_var_file = variables_files[worst_index]
+                    saver.save(sess, worst_var_file)
+                    saved_scores[worst_index] = this_score
+                    log("Variable file saved in {}".format(worst_var_file))
+
+                    # update symlink
+                    if best_score == this_score:
+                        os.unlink(link_best_vars)
+                        os.symlink(worst_var_file, link_best_vars)
+
+                    log("Best scores saved so far: {}".format(saved_scores))
 
-                log("Validation (epoch {}, batch number {}):"\
-                    .format(i + 1, batch_n), color='blue')
-                process_evaluation(evaluation_functions, tb_writer, val_evaluation,
-                                   seen_instances, summary_str, None, train=False)
+                log("Validation (epoch {}, batch number {}):"
+                    .format(i + 1, batch_n), color='blue')
+
+                process_evaluation(evaluation_functions, tb_writer,
+                                   val_evaluation, seen_instances,
+                                   summary_str, None, train=False)
 
                 if this_score == best_score:
-                    best_score_str = colored("{:.2f}".format(best_score), attrs=['bold'])
+                    best_score_str = colored("{:.2f}".format(best_score),
+                                             attrs=['bold'])
                 else:
                     best_score_str = "{:.2f}".format(best_score)
 
-                log("best {} on validation: {} (in epoch {}, after batch number {})"\
-                    .format(evaluation_labels[-1], best_score_str,
-                            best_score_epoch, best_score_batch_no), color='blue')
+                log("best {} on validation: {} (in epoch {}, "
+                    "after batch number {})"
+                    .format(evaluation_labels[-1], best_score_str,
+                            best_score_epoch, best_score_batch_no),
+                    color='blue')
 
-                if best_var_file is not None:
-                    log("variable file saved in {}".format(best_var_file))
-
                 log_print("")
                 log_print("Examples:")
@@ -292,13 +313,14 @@ def training_loop(sess, saver,
     except KeyboardInterrupt:
         log("Training interrupted by user.")
 
-    saver.restore(sess, variables_files[var_file_i])
+    saver.restore(sess, link_best_vars)
     log("Training finished. Maximum {} on validation data: {:.2f}, epoch {}"
         .format(evaluation_labels[-1], best_score, best_score_epoch))
 
     for dataset in test_datasets:
         _, _, evaluation = run_on_dataset(sess, runner, all_coders, decoder,
-                                          dataset, evaluation_functions, postprocess, write_out=True)
+                                          dataset, evaluation_functions,
+                                          postprocess, write_out=True)
         if evaluation:
             print_dataset_evaluation(dataset.name, evaluation)
 
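The core of the fix is easiest to see in isolation. The sketch below restates the scheme the diff implements: keep the `save_n_best_vars` best checkpoints, overwrite the worst slot whenever a new validation score beats it, and keep `link_best_vars` pointing at the file that holds the best score. The names `NBestCheckpoints` and `save_checkpoint` are invented here for illustration; in the commit itself the state lives in local variables (`saved_scores`, `variables_files`) and the saving is done by `saver.save(sess, path)`.

# Standalone sketch of the n-best checkpoint bookkeeping from this diff.
# Symlinks assume a Unix-like OS; `save_checkpoint` stands in for
# `saver.save(sess, path)`.
import os

import numpy as np


def is_better(score1, score2, minimize):
    """True iff score1 beats score2 under the given objective."""
    return score1 < score2 if minimize else score1 > score2


def argworst(scores, minimize):
    """Index of the worst score: the maximum when minimizing, else the minimum."""
    return int(np.argmax(scores) if minimize else np.argmin(scores))


class NBestCheckpoints(object):
    def __init__(self, n, vars_prefix, link_best_vars, minimize):
        self.minimize = minimize
        self.link_best_vars = link_best_vars
        self.files = ["{}.{}".format(vars_prefix, i) for i in range(n)]
        worst = np.inf if minimize else -np.inf
        self.saved_scores = [worst for _ in range(n)]
        self.best_score = worst

    def update(self, this_score, save_checkpoint):
        """Overwrite the worst slot if this_score beats it; relink on a new best."""
        if is_better(this_score, self.best_score, self.minimize):
            self.best_score = this_score

        worst_index = argworst(self.saved_scores, self.minimize)
        if is_better(this_score, self.saved_scores[worst_index], self.minimize):
            worst_var_file = self.files[worst_index]
            save_checkpoint(worst_var_file)
            self.saved_scores[worst_index] = this_score

            # repoint the symlink so it always resolves to the best variables
            if this_score == self.best_score:
                if os.path.islink(self.link_best_vars):
                    os.unlink(self.link_best_vars)
                os.symlink(worst_var_file, self.link_best_vars)


if __name__ == "__main__":
    keeper = NBestCheckpoints(n=3, vars_prefix="/tmp/variables.data",
                              link_best_vars="/tmp/variables.data.best",
                              minimize=False)
    for bleu in [10.0, 12.5, 11.0, 14.2]:
        # a dummy "saver" that just touches the checkpoint file
        keeper.update(bleu, lambda path: open(path, "w").close())
    print(keeper.saved_scores)  # the three best scores seen: [14.2, 12.5, 11.0]

Because the symlink always resolves to the file holding the current best score, restoring with `saver.restore(sess, link_best_vars)` at the end of training, as the last hunk does, loads the best model even though individual checkpoint files are overwritten as better scores arrive. This is also why the commit changes the `link_best_vars` default from `None` to a real path.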
