Skip to content

Commit

Permalink
[tune] Fix SearchAlg finishing early (#3081)
Browse files Browse the repository at this point in the history
* Fix trial search alg finishing early

* Fix lint

* fix lint

* nit fix
  • Loading branch information
richardliaw authored and ericl committed Oct 22, 2018
1 parent 221d166 commit eff7cb4
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
2 changes: 1 addition & 1 deletion doc/source/tune-usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ For TensorFlow model training, this would look something like this `(full tensor
.. code-block:: python
class MyClass(Trainable):
def _setup(self):
def _setup(self, config):
self.saver = tf.train.Saver()
self.sess = ...
self.iteration = 0
Expand Down
28 changes: 27 additions & 1 deletion python/ray/tune/test/trial_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from ray.tune.trial import Trial, Resources
from ray.tune.trial_runner import TrialRunner
from ray.tune.suggest import grid_search, BasicVariantGenerator
from ray.tune.suggest.suggestion import _MockSuggestionAlgorithm
from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
SuggestionAlgorithm)
from ray.tune.suggest.variant_generator import RecursiveDependencyError


Expand Down Expand Up @@ -1385,6 +1386,31 @@ def testSearchAlgStalled(self):
self.assertTrue(searcher.is_finished())
self.assertTrue(runner.is_finished())

def testSearchAlgFinishes(self):
"""SearchAlg changing state in `next_trials` does not crash."""

class FinishFastAlg(SuggestionAlgorithm):
def next_trials(self):
self._finished = True
return []

ray.init(num_cpus=4, num_gpus=2)
experiment_spec = {
"run": "__fake",
"num_samples": 3,
"stop": {
"training_iteration": 1
}
}
searcher = FinishFastAlg()
experiments = [Experiment.from_json("test", experiment_spec)]
searcher.add_configurations(experiments)

runner = TrialRunner(search_alg=searcher)
runner.step() # This should not fail
self.assertTrue(searcher.is_finished())
self.assertTrue(runner.is_finished())


if __name__ == "__main__":
unittest.main(verbosity=2)
4 changes: 4 additions & 0 deletions python/ray/tune/trial_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def step(self):
self.trial_executor.start_trial(next_trial)
elif self.trial_executor.get_running_trials():
self._process_events()
elif self.is_finished():
# We check `is_finished` again here because the experiment
# may have finished while getting the next trial.
pass
else:
for trial in self._trials:
if trial.status == Trial.PENDING:
Expand Down

0 comments on commit eff7cb4

Please sign in to comment.