Actually use eval hooks in Experiment across all usages.
Fix a bug where est.evaluate did not copy the hooks list passed to it, so appended hooks leaked back into Experiment's eval_hooks and crashed the second evaluation.
Change: 144777878
Illia Polosukhin authored and tensorflower-gardener committed Jan 18, 2017
1 parent 3d5de06 commit 781ccc8
Showing 5 changed files with 56 additions and 10 deletions.
3 changes: 2 additions & 1 deletion tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -814,7 +814,8 @@ def _evaluate_model(self,

update_op, eval_dict = self._extract_metric_update_ops(eval_dict)

hooks = hooks or []
# We need to copy the hook array as we modify it, thus [:].
hooks = hooks[:] if hooks else []
if feed_fn:
hooks.append(basic_session_run_hooks.FeedFnHook(feed_fn))
if steps:
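For context, a minimal, self-contained sketch of the aliasing pattern this hunk fixes (illustrative names only, not the TensorFlow code): with hooks = hooks or [], a non-empty list passed by the caller is reused and mutated, whereas the hooks[:] copy leaves the caller's list untouched.

def evaluate_buggy(hooks=None):
  hooks = hooks or []                  # a non-empty argument is aliased, not copied
  hooks.append('feed_fn_hook')         # mutation leaks back to the caller
  return hooks

def evaluate_fixed(hooks=None):
  hooks = hooks[:] if hooks else []    # shallow copy; the caller's list stays intact
  hooks.append('feed_fn_hook')
  return hooks

shared = ['user_hook']                 # e.g. Experiment's eval_hooks list
evaluate_fixed(shared)
print(shared)                          # ['user_hook'] -- only the copy was modified
evaluate_buggy(shared)
print(shared)                          # ['user_hook', 'feed_fn_hook'] -- the caller's list grew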
10 changes: 10 additions & 0 deletions tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -620,6 +620,16 @@ def testIrisInputFn(self):
predictions = list(est.predict(x=iris.data))
self.assertEqual(len(predictions), iris.target.shape[0])

def testHooksNotChanged(self):
est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
# Pass an empty list and expect it to stay empty after fit and evaluate.
# This requires the implementation to copy the list before appending
# any hooks of its own.
my_array = []
est.fit(input_fn=iris_input_fn, steps=100, monitors=my_array)
_ = est.evaluate(input_fn=iris_input_fn, steps=1, hooks=my_array)
self.assertEqual(my_array, [])

def testIrisInputFnLabelsDict(self):
iris = base.load_iris()
est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
8 changes: 5 additions & 3 deletions tensorflow/contrib/learn/python/learn/experiment.py
@@ -378,7 +378,8 @@ def _continuous_eval(self,
steps=self._eval_steps,
metrics=self._eval_metrics,
name=name,
checkpoint_path=latest_path)
checkpoint_path=latest_path,
hooks=self._eval_hooks)
# Ensure eval result is not None for next round of evaluation.
if not eval_result:
eval_result = {}
@@ -454,14 +455,15 @@ def train_and_evaluate(self):
self._train_monitors += [monitors.ValidationMonitor(
input_fn=self._eval_input_fn, eval_steps=self._eval_steps,
metrics=self._eval_metrics, every_n_steps=self._min_eval_frequency,
name=eval_dir_suffix,
name=eval_dir_suffix, hooks=self._eval_hooks
)]
self.train(delay_secs=0)

eval_result = self._estimator.evaluate(input_fn=self._eval_input_fn,
steps=self._eval_steps,
metrics=self._eval_metrics,
name=eval_dir_suffix)
name=eval_dir_suffix,
hooks=self._eval_hooks)
export_results = self._maybe_export(eval_result)
return eval_result, export_results

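With the experiment.py change above, hooks supplied through Experiment's eval_hooks argument now reach every evaluate call (continuous evaluation, train_and_evaluate, and the ValidationMonitor). A hedged usage sketch against the public tf.contrib.learn API of this era (roughly TF 1.0); my_estimator, train_input_fn, and eval_input_fn are placeholders assumed to exist:

import tensorflow as tf

class EvalCountingHook(tf.train.SessionRunHook):
  """Toy hook: counts session.run calls made during evaluation."""

  def __init__(self):
    self.run_count = 0

  def after_run(self, run_context, run_values):
    self.run_count += 1

# my_estimator is assumed to be a trained tf.contrib.learn estimator (a
# checkpoint exists); this only shows where eval_hooks plugs in.
hook = EvalCountingHook()
ex = tf.contrib.learn.Experiment(
    my_estimator,
    train_input_fn=train_input_fn,
    eval_input_fn=eval_input_fn,
    eval_steps=1,
    eval_hooks=[hook])   # forwarded to estimator.evaluate() after this commit
ex.evaluate(delay_secs=0)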
36 changes: 32 additions & 4 deletions tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -42,6 +42,7 @@
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
from tensorflow.python.util.all_util import reveal_undocumented

@@ -74,6 +75,7 @@ def __init__(self, config=None, max_evals=5):
self._max_evals = max_evals
self.export_count = 0
self.monitors = []
self.eval_hooks = []
self._config = config or run_config.RunConfig()
self._model_dir = tempfile.mkdtemp()

@@ -87,6 +89,8 @@ def config(self):

def evaluate(self, **kwargs):
tf_logging.info('evaluate called with args: %s' % kwargs)
if 'hooks' in kwargs:
self.eval_hooks = kwargs['hooks']
self.eval_count += 1
if self.eval_count > self._max_evals:
tf_logging.info('Ran %d evals. Done.' % self.eval_count)
@@ -117,6 +121,10 @@ def export_savedmodel(self, export_dir_base, export_input_fn, **kwargs):
compat.as_bytes(export_dir_base), compat.as_bytes('bogus_timestamp'))


class _NoopHook(session_run_hook.SessionRunHook):
pass


class ExperimentTest(test.TestCase):

def _cluster_spec(self):
@@ -253,52 +261,63 @@ def test_train_raises_if_job_name_is_missing(self):
def test_evaluate(self):
est = TestEstimator()
est.fake_checkpoint()
noop_hook = _NoopHook()
ex = experiment.Experiment(
est,
train_input_fn='train_input',
eval_input_fn='eval_input',
eval_metrics='eval_metrics',
eval_hooks=[noop_hook],
eval_steps='steps',
eval_delay_secs=0)
ex.evaluate()
self.assertEquals(1, est.eval_count)
self.assertEquals(0, est.fit_count)
self.assertEquals(1, est.eval_count)
self.assertEquals([noop_hook], est.eval_hooks)

def test_evaluate_delay(self):
est = TestEstimator()
est.fake_checkpoint()
noop_hook = _NoopHook()
ex = experiment.Experiment(
est, train_input_fn='train_input', eval_input_fn='eval_input')
est, train_input_fn='train_input', eval_input_fn='eval_input',
eval_hooks=[noop_hook])

for delay in [0, 1, 3]:
with test.mock.patch('time.sleep', SheepCounter()) as sheep:
ex.evaluate(delay_secs=delay)
self.assertAlmostEqual(delay, sheep.total_time, delta=0.1)
self.assertEquals([noop_hook], est.eval_hooks)

def test_continuous_eval(self):
est = TestEstimator()
est.fake_checkpoint()
noop_hook = _NoopHook()
ex = experiment.Experiment(
est,
train_input_fn='train_input',
eval_input_fn='eval_input',
eval_metrics='eval_metrics',
eval_hooks=[noop_hook],
eval_delay_secs=0,
continuous_eval_throttle_secs=0)
self.assertRaises(
StopIteration, ex.continuous_eval, evaluate_checkpoint_only_once=False)
self.assertEquals(6, est.eval_count)
self.assertEquals(0, est.fit_count)
self.assertEquals(6, est.eval_count)
self.assertEquals([noop_hook], est.eval_hooks)

def test_continuous_eval_throttle_delay(self):
for delay in [0, 1, 2]:
est = TestEstimator()
est.fake_checkpoint()
noop_hook = _NoopHook()
ex = experiment.Experiment(
est,
train_input_fn='train_input',
eval_input_fn='eval_input',
eval_metrics='eval_metrics',
eval_hooks=[noop_hook],
continuous_eval_throttle_secs=delay,
eval_delay_secs=0)
with test.mock.patch('time.sleep', SheepCounter()) as sheep:
@@ -311,6 +330,7 @@ def test_continuous_eval_throttle_delay(self):
def test_continuous_eval_predicate_fn(self):
est = TestEstimator()
est.fake_checkpoint()
noop_hook = _NoopHook()

def _predicate_fn(unused_eval_result):
return est.eval_count < 3
@@ -320,38 +340,45 @@ def _predicate_fn(unused_eval_result):
train_input_fn='train_input',
eval_input_fn='eval_input',
eval_metrics='eval_metrics',
eval_hooks=[noop_hook],
eval_delay_secs=0,
continuous_eval_throttle_secs=0,
continuous_eval_predicate_fn=_predicate_fn)
ex.continuous_eval(evaluate_checkpoint_only_once=False)
self.assertEquals(3, est.eval_count)
self.assertEquals(0, est.fit_count)
self.assertEquals(3, est.eval_count)
self.assertEquals([noop_hook], est.eval_hooks)

def test_run_local(self):
est = TestEstimator()
noop_hook = _NoopHook()
ex = experiment.Experiment(
est,
train_input_fn='train_input',
eval_input_fn='eval_input',
eval_metrics='eval_metrics',
eval_hooks=[noop_hook],
train_steps=100,
eval_steps=100,
local_eval_frequency=10)
ex.local_run()
self.assertEquals(1, est.fit_count)
self.assertEquals(1, est.eval_count)
self.assertEquals(1, len(est.monitors))
self.assertEquals([noop_hook], est.eval_hooks)
self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor))

def test_train_and_evaluate(self):
est = TestEstimator()
noop_hook = _NoopHook()
export_strategy = saved_model_export_utils.make_export_strategy(
est, 'export_input', exports_to_keep=None)
ex = experiment.Experiment(
est,
train_input_fn='train_input',
eval_input_fn='eval_input',
eval_metrics='eval_metrics',
eval_hooks=[noop_hook],
train_steps=100,
eval_steps=100,
export_strategies=export_strategy)
@@ -360,6 +387,7 @@ def test_train_and_evaluate(self):
self.assertEquals(1, est.eval_count)
self.assertEquals(1, est.export_count)
self.assertEquals(1, len(est.monitors))
self.assertEquals([noop_hook], est.eval_hooks)
self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor))

@test.mock.patch.object(server_lib, 'Server')
9 changes: 7 additions & 2 deletions tensorflow/contrib/learn/python/learn/monitors.py
@@ -618,7 +618,8 @@ class ValidationMonitor(EveryN):

def __init__(self, x=None, y=None, input_fn=None, batch_size=None,
eval_steps=None,
every_n_steps=100, metrics=None, early_stopping_rounds=None,
every_n_steps=100, metrics=None, hooks=None,
early_stopping_rounds=None,
early_stopping_metric="loss",
early_stopping_metric_minimize=True, name=None):
"""Initializes a ValidationMonitor.
@@ -632,6 +633,8 @@ def __init__(self, x=None, y=None, input_fn=None, batch_size=None,
every_n_steps: Check for new checkpoints to evaluate every N steps. If a
new checkpoint is found, it is evaluated. See `EveryN`.
metrics: See `BaseEstimator.evaluate`.
hooks: A list of `SessionRunHook` hooks to pass to the
`Estimator`'s `evaluate` function.
early_stopping_rounds: `int`. If the metric indicated by
`early_stopping_metric` does not change according to
`early_stopping_metric_minimize` for this many steps, then training
@@ -660,6 +663,7 @@ def __init__(self, x=None, y=None, input_fn=None, batch_size=None,
self.batch_size = batch_size
self.eval_steps = eval_steps
self.metrics = metrics
self.hooks = hooks
self.early_stopping_rounds = early_stopping_rounds
self.early_stopping_metric = early_stopping_metric
self.early_stopping_metric_minimize = early_stopping_metric_minimize
@@ -709,7 +713,8 @@ def every_n_step_end(self, step, outputs):
# Run evaluation and log it.
validation_outputs = self._estimator.evaluate(
x=self.x, y=self.y, input_fn=self.input_fn, batch_size=self.batch_size,
steps=self.eval_steps, metrics=self.metrics, name=self.name)
steps=self.eval_steps, metrics=self.metrics, hooks=self.hooks,
name=self.name)
stats = []
for name in validation_outputs:
stats.append("%s = %s" % (name, str(validation_outputs[name])))
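The monitors.py change lets a ValidationMonitor forward hooks to the evaluate call it triggers during training. A hedged sketch of how that could be wired up; my_estimator, train_input_fn, eval_input_fn, and my_eval_hook are placeholders assumed to exist:

import tensorflow as tf

# Placeholders assumed to be defined elsewhere: my_estimator, train_input_fn,
# eval_input_fn, and my_eval_hook (any tf.train.SessionRunHook instance).
validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
    input_fn=eval_input_fn,
    eval_steps=100,
    every_n_steps=500,
    hooks=[my_eval_hook])   # new in this commit: passed through to evaluate()

my_estimator.fit(input_fn=train_input_fn,
                 steps=10000,
                 monitors=[validation_monitor])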
