forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[tune] Add scikit-optimize to Tune (ray-project#3924)
- Loading branch information
1 parent
8df7728
commit 9797028
Showing
6 changed files
with
175 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""This test checks that Skopt is functional. | ||
It also checks that it is usable with a separate scheduler. | ||
""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import ray | ||
from ray.tune import run_experiments, register_trainable | ||
from ray.tune.schedulers import AsyncHyperBandScheduler | ||
from ray.tune.suggest import SkOptSearch | ||
|
||
|
||
def easy_objective(config, reporter): | ||
import time | ||
time.sleep(0.2) | ||
for i in range(config["iterations"]): | ||
reporter( | ||
timesteps_total=i, | ||
neg_mean_loss=-(config["height"] - 14)**2 + | ||
abs(config["width"] - 3)) | ||
time.sleep(0.02) | ||
|
||
|
||
if __name__ == '__main__': | ||
import argparse | ||
from skopt import Optimizer | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--smoke-test", action="store_true", help="Finish quickly for testing") | ||
args, _ = parser.parse_known_args() | ||
ray.init(redirect_output=True) | ||
|
||
register_trainable("exp", easy_objective) | ||
|
||
config = { | ||
"skopt_exp": { | ||
"run": "exp", | ||
"num_samples": 10 if args.smoke_test else 50, | ||
"config": { | ||
"iterations": 100, | ||
}, | ||
"stop": { | ||
"timesteps_total": 100 | ||
}, | ||
} | ||
} | ||
optimizer = Optimizer([(0, 20), (-100, 100)]) | ||
algo = SkOptSearch( | ||
optimizer, ["width", "height"], | ||
max_concurrent=4, | ||
reward_attr="neg_mean_loss") | ||
scheduler = AsyncHyperBandScheduler(reward_attr="neg_mean_loss") | ||
run_experiments(config, search_alg=algo, scheduler=scheduler) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
try: | ||
import skopt | ||
except Exception: | ||
skopt = None | ||
|
||
from ray.tune.suggest.suggestion import SuggestionAlgorithm | ||
|
||
|
||
class SkOptSearch(SuggestionAlgorithm): | ||
"""A wrapper around skopt to provide trial suggestions. | ||
Requires skopt to be installed. | ||
Parameters: | ||
optimizer (skopt.optimizer.Optimizer): Optimizer provided | ||
from skopt. | ||
parameter_names (list): List of parameter names. Should match | ||
the dimension of the optimizer output. | ||
max_concurrent (int): Number of maximum concurrent trials. Defaults | ||
to 10. | ||
reward_attr (str): The training result objective value attribute. | ||
This refers to an increasing value. | ||
Example: | ||
>>> from skopt import Optimizer | ||
>>> optimizer = Optimizer([(0,20),(-100,100)]) | ||
>>> config = { | ||
>>> "my_exp": { | ||
>>> "run": "exp", | ||
>>> "num_samples": 10, | ||
>>> "stop": { | ||
>>> "training_iteration": 100 | ||
>>> }, | ||
>>> } | ||
>>> } | ||
>>> algo = SkOptSearch(optimizer, | ||
>>> ["width", "height"], max_concurrent=4, | ||
>>> reward_attr="neg_mean_loss") | ||
""" | ||
|
||
def __init__(self, | ||
optimizer, | ||
parameter_names, | ||
max_concurrent=10, | ||
reward_attr="episode_reward_mean", | ||
**kwargs): | ||
assert skopt is not None, """skopt must be installed! | ||
You can install Skopt with the command: | ||
`pip install scikit-optimize`.""" | ||
assert type(max_concurrent) is int and max_concurrent > 0 | ||
self._max_concurrent = max_concurrent | ||
self._parameters = parameter_names | ||
self._reward_attr = reward_attr | ||
self._skopt_opt = optimizer | ||
self._live_trial_mapping = {} | ||
super(SkOptSearch, self).__init__(**kwargs) | ||
|
||
def _suggest(self, trial_id): | ||
if self._num_live_trials() >= self._max_concurrent: | ||
return None | ||
suggested_config = self._skopt_opt.ask() | ||
self._live_trial_mapping[trial_id] = suggested_config | ||
return dict(zip(self._parameters, suggested_config)) | ||
|
||
def on_trial_result(self, trial_id, result): | ||
pass | ||
|
||
def on_trial_complete(self, | ||
trial_id, | ||
result=None, | ||
error=False, | ||
early_terminated=False): | ||
"""Passes the result to skopt unless early terminated or errored. | ||
The result is internally negated when interacting with Skopt | ||
so that Skopt Optimizers can "maximize" this value, | ||
as it minimizes on default. | ||
""" | ||
skopt_trial_info = self._live_trial_mapping.pop(trial_id) | ||
if result: | ||
self._skopt_opt.tell(skopt_trial_info, -result[self._reward_attr]) | ||
|
||
def _num_live_trials(self): | ||
return len(self._live_trial_mapping) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters