Improving local run behavior in estimator.train_and_evaluate.
The current behavior is unintuitive (it depends on throttle_secs) and leads to more frequent checkpointing than desired.
This CL synchronizes evaluation with checkpointing. It also makes the behavior closer to the distributed setting in the following ways:
* In the distributed setting, the training input pipeline is created only once; the current local-run behavior recreates it on every loop iteration. This CL creates the training input pipeline only once.
* In the distributed setting, the evaluator job waits for checkpoints written by the training job; in the current local-run behavior, the evaluator controls the checkpoint schedule. This CL gives that control back to the trainer (see the usage sketch below).

PiperOrigin-RevId: 201085814
ispirmustafa authored and tensorflower-gardener committed Jun 19, 2018
1 parent f91b5b0 commit 3edb609
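
For orientation, a minimal usage sketch (not part of this commit) of the user-facing API the new behavior applies to. With this change, a local train_and_evaluate run evaluates after each checkpoint, so checkpoint cadence comes from RunConfig (save_checkpoints_secs or save_checkpoints_steps) and EvalSpec.throttle_secs only skips evaluations that arrive too soon. The model_fn, input_fn, paths, and numbers below are illustrative assumptions, not code from this commit.

import numpy as np
import tensorflow as tf

def model_fn(features, labels, mode):
  # Tiny linear model, purely illustrative.
  predictions = tf.layers.dense(features['x'], 1)
  loss = tf.losses.mean_squared_error(labels, predictions)
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(
      mode, loss=loss, train_op=train_op, predictions=predictions)

def input_fn():
  x = np.random.rand(256, 1).astype(np.float32)
  return tf.data.Dataset.from_tensor_slices(({'x': x}, 2. * x)).repeat().batch(32)

# Checkpoint cadence now drives local evaluation: a checkpoint is written every
# save_checkpoints_secs, the new saver listener evaluates after each save, and
# EvalSpec.throttle_secs only skips evaluations that would come too soon.
config = tf.estimator.RunConfig(save_checkpoints_secs=60)
estimator = tf.estimator.Estimator(model_fn, model_dir='/tmp/tae_demo', config=config)
train_spec = tf.estimator.TrainSpec(input_fn=input_fn, max_steps=1000)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn, steps=10, throttle_secs=60)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
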
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 251 deletions.
160 changes: 77 additions & 83 deletions tensorflow/python/estimator/training.py
@@ -470,6 +470,61 @@ def after_run(self, run_context, run_values):
      run_context.request_stop()


class _NewCheckpointListenerForEvaluate(
    basic_session_run_hooks.CheckpointSaverListener):
  """A saver listener to run evaluation with every checkpoint."""

  def __init__(self, evaluator, eval_throttle_secs, continuous_eval_listener):
    self._evaluator = evaluator
    self._eval_throttle_secs = eval_throttle_secs
    self._continuous_eval_listener = continuous_eval_listener
    self.eval_result, self.export_results = None, None

  def begin(self):
    self._timer = basic_session_run_hooks.SecondOrStepTimer(
        every_secs=self._eval_throttle_secs)
    self._is_first_run = True

  def after_save(self, session, global_step_value):
    del session  # unused; required by signature.
    # Skip the first run; the model is not trained yet.
    if self._is_first_run:
      self._is_first_run = False
      return

    if not self._continuous_eval_listener.before_eval():
      logging.info('Exiting training and evaluation loop, as requested by '
                   '_ContinuousEvalListener.before_eval.')
      return True
    if self._timer.should_trigger_for_step(global_step_value):
      self._evaluate(global_step_value)  # updates self.eval_result
      if not self._continuous_eval_listener.after_eval(self.eval_result):
        logging.info('Exiting evaluation, as requested by '
                     '_ContinuousEvalListener.after_eval.')
        return True
    else:
      # TODO(ispir): add remaining time in the log.
      logging.info('Skip the current checkpoint eval due to throttle secs '
                   '({} secs).'.format(self._eval_throttle_secs))

  def end(self, session, global_step_value):
    # Evaluate if the last step has not been evaluated yet.
    if global_step_value != self._timer.last_triggered_step():
      if self._continuous_eval_listener.before_eval():
        self._evaluate(global_step_value)
        self._continuous_eval_listener.after_eval(self.eval_result)

  def _evaluate(self, global_step_value):
    self._timer.update_last_triggered_step(global_step_value)
    self.eval_result, self.export_results = (
        self._evaluator.evaluate_and_export())
    if self.eval_result.status != _EvalStatus.EVALUATED:
      # This is unexpected; should never happen.
      # Training should always end with a new checkpoint.
      raise RuntimeError('There was no new checkpoint after the training. '
                         'Eval status: {}'.format(self.eval_result.status))


class _TrainingExecutor(object):
"""The executor to run `Estimator` training and evaluation.
@@ -576,28 +631,6 @@ def run_worker(self):

  def run_master(self):
    """Runs task master."""

    class NewCheckpointListener(
        basic_session_run_hooks.CheckpointSaverListener):

      def __init__(self, evaluator, eval_throttle_secs):
        self._evaluator = evaluator
        self._eval_throttle_secs = eval_throttle_secs

      def begin(self):
        self._timer = basic_session_run_hooks.SecondOrStepTimer(
            every_secs=self._eval_throttle_secs)

      def after_save(self, session, global_step_value):
        del session  # unused; required by signature.

        if self._timer.should_trigger_for_step(global_step_value):
          self._timer.update_last_triggered_step(global_step_value)
          self._evaluator.evaluate_and_export()
        else:
          logging.info('Skip the current checkpoint eval due to throttle secs '
                       '({} secs).'.format(self._eval_throttle_secs))

    _assert_eval_spec(self._eval_spec)

    # Final export signal: For any eval result with global_step >= train
@@ -617,16 +650,12 @@ def after_save(self, session, global_step_value):
    # When the underlying `Estimator` object saves a new checkpoint, we would
    # like this callback to be called so that evaluation and export can trigger.
    saving_listeners = [
        NewCheckpointListener(evaluator, self._eval_spec.throttle_secs)
        _NewCheckpointListenerForEvaluate(evaluator,
                                          self._eval_spec.throttle_secs,
                                          _ContinuousEvalListener())
    ]
    self._start_distributed_training(saving_listeners=saving_listeners)

    if not evaluator.is_final_export_triggered:
      logging.info('Training has already ended. But the last eval is skipped '
                   'due to eval throttle_secs. Now evaluating the final '
                   'checkpoint.')
      evaluator.evaluate_and_export()

  def run_evaluator(self):
    """Runs task evaluator."""
    # TODO(xiejw): To allow execution framework to add continuous eval listener.
@@ -640,68 +669,33 @@ def run_ps(self):

  def run_local(self):
    """Runs training and evaluation locally (non-distributed)."""

    def _should_stop_local_train(global_step):
      if self._train_spec.max_steps is None:
        return False
      if global_step >= self._train_spec.max_steps:
        return True
      return False

    _assert_eval_spec(self._eval_spec)

    if self._eval_spec.throttle_secs <= 0:
      raise ValueError('eval_spec.throttle_secs should be positive, given: {}.'
                       'It is used do determine how long each training '
                       'iteration should go when train and evaluate '
                       'locally.'.format(self._eval_spec.throttle_secs))

    stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs)
    train_hooks = (
        list(self._train_spec.hooks) + [stop_hook] + list(self._train_hooks))
    train_hooks = list(self._train_spec.hooks) + list(self._train_hooks)
    logging.info('Start train and evaluate loop. The evaluate will happen '
                 'after {} secs (eval_spec.throttle_secs) or training is '
                 'finished.'.format(self._eval_spec.throttle_secs))
                 'after every checkpoint. Checkpoint frequency is determined '
                 'based on RunConfig arguments: save_checkpoints_steps {} or '
                 'save_checkpoints_secs {}.'.format(
                     self._estimator.config.save_checkpoints_steps,
                     self._estimator.config.save_checkpoints_secs))

    evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
                                             self._train_spec.max_steps)

    eval_result = _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT)
    export_results = []

    while True:
      self._estimator.train(
          input_fn=self._train_spec.input_fn,
          max_steps=self._train_spec.max_steps,
          hooks=train_hooks)

      if not self._continuous_eval_listener.before_eval():
        logging.info('Exiting training and evaluation loop, as requested by '
                     '_ContinuousEvalListener.before_eval.')
        break

      # Final export signal: For any eval result with global_step >= train
      # max_steps, the evaluator will send the final export signal. The
      # _should_stop_local_train will then end the while True as the stopping
      # condition is satisfied (both checks use the same global_step value,
      # i.e., no race condition)
      eval_result, export_results = evaluator.evaluate_and_export()

      if eval_result.status != _EvalStatus.EVALUATED:
        # This is unexpected; should never happen.
        # Training should always end with a new checkpoint.
        raise RuntimeError('There was no new checkpoint after the training. '
                           'Eval status: {}'.format(eval_result.status))

      if not self._continuous_eval_listener.after_eval(eval_result):
        logging.info('Exiting evaluation, as requested by '
                     '_ContinuousEvalListener.after_eval.')
        break
    listener_for_eval = _NewCheckpointListenerForEvaluate(
        evaluator, self._eval_spec.throttle_secs,
        self._continuous_eval_listener)
    saving_listeners = [listener_for_eval]

    self._estimator.train(
        input_fn=self._train_spec.input_fn,
        max_steps=self._train_spec.max_steps,
        hooks=train_hooks,
        saving_listeners=saving_listeners)

      if _should_stop_local_train(
          eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]):
        break
    return eval_result.metrics, export_results
    eval_result = listener_for_eval.eval_result or _EvalResult(
        status=_EvalStatus.MISSING_CHECKPOINT)
    return eval_result.metrics, listener_for_eval.export_results

  def _start_std_server(self, config):
    """Creates, starts, and returns a server_lib.Server."""
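For readers unfamiliar with the saving_listeners mechanism this change builds on, below is a simplified, hypothetical analogue of the listener above, written against the public TF 1.x API (tf.train.CheckpointSaverListener). The class name, evaluation step count, and timer interval are assumptions made for this sketch; the actual implementation is the private _NewCheckpointListenerForEvaluate shown in the diff.

import tensorflow as tf

class EvalAfterSaveListener(tf.train.CheckpointSaverListener):
  """Simplified analogue of the listener above: evaluate after each save."""

  def __init__(self, estimator, eval_input_fn, min_secs_between_evals=60):
    self._estimator = estimator
    self._eval_input_fn = eval_input_fn
    self._timer = tf.train.SecondOrStepTimer(every_secs=min_secs_between_evals)
    self._is_first_run = True

  def after_save(self, session, global_step_value):
    del session  # unused; required by the listener signature.
    if self._is_first_run:
      # Skip the checkpoint written before any training has happened.
      self._is_first_run = False
      return
    if self._timer.should_trigger_for_step(global_step_value):
      self._timer.update_last_triggered_step(global_step_value)
      metrics = self._estimator.evaluate(self._eval_input_fn, steps=10)
      tf.logging.info('Eval at step %d: %s', global_step_value, metrics)

# Usage: pass the listener through Estimator.train's saving_listeners argument,
# the same hook point run_local now uses internally.
# estimator.train(input_fn=train_input_fn, max_steps=1000,
#                 saving_listeners=[EvalAfterSaveListener(estimator, eval_input_fn)])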
