Skip to content

Commit

Permalink
[tune] Reduce sampling API clutter (#4739)
Browse files Browse the repository at this point in the history
Adds some sugar for tune sampling API (for commonplace sampling idioms).
  • Loading branch information
richardliaw committed May 7, 2019
1 parent 71b2dec commit 7f50c96
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 65 deletions.
6 changes: 4 additions & 2 deletions python/ray/tune/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from ray.tune.experiment import Experiment
from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.suggest import grid_search, function, sample_from
from ray.tune.suggest import grid_search
from ray.tune.sample import (function, sample_from, uniform, choice, randint,
randn)

__all__ = [
"Trainable", "TuneError", "grid_search", "register_env",
"register_trainable", "run", "run_experiments", "Experiment", "function",
"sample_from"
"sample_from", "uniform", "choice", "randint", "randn"
]
7 changes: 2 additions & 5 deletions python/ray/tune/examples/mnist_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def test():
datasets.MNIST("~/data", train=True, download=True)
args = parser.parse_args()

import numpy as np
import ray
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
Expand Down Expand Up @@ -183,9 +182,7 @@ def test():
},
"num_samples": 1 if args.smoke_test else 10,
"config": {
"lr": tune.sample_from(
lambda spec: np.random.uniform(0.001, 0.1)),
"momentum": tune.sample_from(
lambda spec: np.random.uniform(0.1, 0.9)),
"lr": tune.uniform(0.001, 0.1),
"momentum": tune.uniform(0.1, 0.9),
}
})
7 changes: 2 additions & 5 deletions python/ray/tune/examples/mnist_pytorch_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def _restore(self, checkpoint_path):
datasets.MNIST("~/data", train=True, download=True)
args = parser.parse_args()

import numpy as np
import ray
from ray import tune
from ray.tune.schedulers import HyperBandScheduler
Expand All @@ -193,9 +192,7 @@ def _restore(self, checkpoint_path):
"checkpoint_at_end": True,
"config": {
"args": args,
"lr": tune.sample_from(
lambda spec: np.random.uniform(0.001, 0.1)),
"momentum": tune.sample_from(
lambda spec: np.random.uniform(0.1, 0.9)),
"lr": tune.uniform(0.001, 0.1),
"momentum": tune.uniform(0.1, 0.9),
}
})
2 changes: 1 addition & 1 deletion python/ray/tune/log_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
from ray.tune.error import TuneError
from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.suggest.variant_generator import function as tune_function
from ray.tune.sample import function as tune_function

logger = logging.getLogger(__name__)
_log_sync_warned = False
Expand Down
71 changes: 71 additions & 0 deletions python/ray/tune/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import numpy as np

logger = logging.getLogger(__name__)


class sample_from(object):
"""Specify that tune should sample configuration values from this function.
The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.sample_from() or tune.function().
Arguments:
func: An callable function to draw a sample from.
"""

def __init__(self, func):
self.func = func

def __str__(self):
return "tune.sample_from({})".format(str(self.func))

def __repr__(self):
return "tune.sample_from({})".format(repr(self.func))


class function(object):
"""Wraps `func` to make sure it is not expanded during resolution.
The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.sample_from() or tune.function().
Arguments:
func: A function literal.
"""

def __init__(self, func):
self.func = func

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

def __str__(self):
return "tune.function({})".format(str(self.func))

def __repr__(self):
return "tune.function({})".format(repr(self.func))


def uniform(*args, **kwargs):
"""A wrapper around np.random.uniform."""
return sample_from(lambda _: np.random.uniform(*args, **kwargs))


def choice(*args, **kwargs):
"""A wrapper around np.random.choice."""
return sample_from(lambda _: np.random.choice(*args, **kwargs))


def randint(*args, **kwargs):
"""A wrapper around np.random.randint."""
return sample_from(lambda _: np.random.randint(*args, **kwargs))


def randn(*args, **kwargs):
"""A wrapper around np.random.randn."""
return sample_from(lambda _: np.random.randn(*args, **kwargs))
11 changes: 3 additions & 8 deletions python/ray/tune/suggest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.basic_variant import BasicVariantGenerator
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest.variant_generator import grid_search, function, \
sample_from
from ray.tune.suggest.variant_generator import grid_search

__all__ = [
"SearchAlgorithm",
"BasicVariantGenerator",
"SuggestionAlgorithm",
"grid_search",
"function",
"sample_from",
"SearchAlgorithm", "BasicVariantGenerator", "SuggestionAlgorithm",
"grid_search"
]


Expand Down
44 changes: 1 addition & 43 deletions python/ray/tune/suggest/variant_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import types

from ray.tune import TuneError
from ray.tune.sample import sample_from

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,49 +55,6 @@ def grid_search(values):
return {"grid_search": values}


class sample_from(object):
"""Specify that tune should sample configuration values from this function.
The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.sample_from() or tune.function().
Arguments:
func: An callable function to draw a sample from.
"""

def __init__(self, func):
self.func = func

def __str__(self):
return "tune.sample_from({})".format(str(self.func))

def __repr__(self):
return "tune.sample_from({})".format(repr(self.func))


class function(object):
"""Wraps `func` to make sure it is not expanded during resolution.
The use of function arguments in tune configs must be disambiguated by
either wrapped the function in tune.sample_from() or tune.function().
Arguments:
func: A function literal.
"""

def __init__(self, func):
self.func = func

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

def __str__(self):
return "tune.function({})".format(str(self.func))

def __repr__(self):
return "tune.function({})".format(repr(self.func))


_STANDARD_IMPORTS = {
"random": random,
"np": numpy,
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tune/trial_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.result import TIME_THIS_ITER_S, RESULT_DUPLICATE
from ray.tune.trial import Trial, Checkpoint
from ray.tune.suggest import function
from ray.tune.sample import function
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
from ray.tune.util import warn_if_slow
from ray.utils import binary_to_hex, hex_to_binary
Expand Down

0 comments on commit 7f50c96

Please sign in to comment.