
ENH add randomized hyperparameter optimization

1 parent db6f005 · commit 0c94b55c63fcd28a47b156cbc0cf331ec6d42a7a · amueller committed Mar 3, 2013
4 doc/modules/classes.rst
@@ -455,7 +455,9 @@ From text
:template: class.rst
grid_search.GridSearchCV
- grid_search.IterGrid
+ grid_search.ParameterGrid
+ grid_search.ParameterSampler
+ grid_search.RandomizedSearchCV
.. _hmm_ref:
80 doc/modules/grid_search.rst
@@ -1,11 +1,11 @@
.. _grid_search:
+.. currentmodule:: sklearn.grid_search
+
==========================================
Grid Search: setting estimator parameters
==========================================
-.. currentmodule:: sklearn
-
Grid Search is used to optimize the parameters of a model (e.g. ``C``,
``kernel`` and ``gamma`` for Support Vector Classifier, ``alpha`` for
Lasso, etc.) using an internal :ref:`cross_validation` scheme.
@@ -15,10 +15,10 @@ GridSearchCV
============
The main class for implementing hyper-parameter grid search in
-scikit-learn is :class:`grid_search.GridSearchCV`. This class is passed
+scikit-learn is :class:`GridSearchCV`. This class is passed
a base model instance (for example ``sklearn.svm.SVC()``) along with a
grid of potential hyper-parameter values specified with the `param_grid`
-attribute. For instace the following `param_grid`::
+attribute. For instance the following `param_grid`::
param_grid = [
{'C': [1, 10, 100, 1000], 'kernel': ['linear']},
@@ -30,7 +30,7 @@ C values in [1, 10, 100, 1000], and the second one with an RBF kernel,
and the cross-product of C values in [1, 10, 100, 1000] and gamma
values in [0.001, 0.0001].
-The :class:`grid_search.GridSearchCV` instance implements the usual
+The :class:`GridSearchCV` instance implements the usual
estimator API: when "fitting" it on a dataset all the possible
combinations of hyperparameter values are evaluated and the best
combination is retained.
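Put together, a minimal sketch of how such a grid might be wired into a search
(this assumes an ``sklearn.svm.SVC`` estimator and an already loaded dataset ``X``, ``y``)::

    from sklearn.svm import SVC
    from sklearn.grid_search import GridSearchCV

    param_grid = [
        {'C': [1, 10, 100, 1000], 'kernel': ['linear']},
        {'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
    ]
    # fit evaluates every parameter combination by cross-validation
    search = GridSearchCV(SVC(), param_grid=param_grid, cv=5)
    search.fit(X, y)
    print(search.best_params_, search.best_score_)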
@@ -64,24 +64,76 @@ alternative scoring function can be specified via the ``scoring`` parameter to
:class:`GridSearchCV`.
See :ref:`score_func_objects` for more details.
-Examples
-========
+.. topic:: Examples:
-- See :ref:`example_grid_search_digits.py` for an example of
- Grid Search computation on the digits dataset.
+ - See :ref:`example_grid_search_digits.py` for an example of
+ Grid Search computation on the digits dataset.
-- See :ref:`example_grid_search_text_feature_extraction.py` for an example
- of Grid Search coupling parameters from a text documents feature
- extractor (n-gram count vectorizer and TF-IDF transformer) with a
- classifier (here a linear SVM trained with SGD with either elastic
- net or L2 penalty) using a :class:`pipeline.Pipeline` instance.
+ - See :ref:`example_grid_search_text_feature_extraction.py` for an example
+ of Grid Search coupling parameters from a text documents feature
+ extractor (n-gram count vectorizer and TF-IDF transformer) with a
+ classifier (here a linear SVM trained with SGD with either elastic
+ net or L2 penalty) using a :class:`pipeline.Pipeline` instance.
.. note::
Computations can be run in parallel if your OS supports it, by using
the keyword ``n_jobs=-1``; see the function signature for more details.
+Randomized Hyper-Parameter Optimization
+=======================================
+While using a grid of parameter settings is currently the most widely used
+method for hyper-parameter optimization, other search methods have more
+favourable properties.
+:class:`RandomizedSearchCV` implements a randomized search over hyperparameters,
+where each setting is sampled from a distribution over possible parameter values.
+This has two main benefits over searching over a grid:
+
+* A budget can be chosen independent of the number of parameters and possible values.
+
+* Adding parameters that do not influence the performance does not decrease efficiency.
+
+Specifying how parameters should be sampled is done using a dictionary, very
+similar to specifying parameters for :class:`GridSearchCV`. Additionally,
+a computation budget is specified using ``n_iter``, which is the number
+of iterations (parameter samples) to be used.
+For each parameter, either a distribution over possible values or a list of
+discrete choices (which will be sampled uniformly) can be specified::
+
+ [{'C': scipy.stats.expon(scale=100), 'gamma': scipy.stats.expon(scale=.1),
+ 'kernel': ['rbf'], 'class_weight':['auto', None]}]
+
+This example uses the ``scipy.stats`` module, which contains many useful
+distributions for sampling hyperparameters, such as ``expon``, ``gamma``,
+``uniform`` or ``randint``.
+In principle, any object can be passed that provides a ``rvs`` (random
+variate sample) method to sample a value. A call to the ``rvs`` method should
+provide independent random samples from possible parameter values on
+consecutive calls.
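For instance, a hand-rolled sampler only needs to expose ``rvs``; the hypothetical
log-uniform distribution below is a sketch of that contract::

    import numpy as np

    class LogUniform(object):
        """Sample 10**u with u drawn uniformly from [low, high)."""
        def __init__(self, low, high):
            self.low, self.high = low, high

        def rvs(self):
            # uses the global numpy random state, like scipy.stats
            return 10 ** np.random.uniform(self.low, self.high)

    param_distributions = {'C': LogUniform(-3, 3), 'kernel': ['rbf', 'linear']}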
+
+ .. warning::
+
+ The distributions in ``scipy.stats`` do not allow specifying a random
+ state. Instead, they use the global numpy random state, which can be seeded
+ via ``np.random.seed`` or set using ``np.random.set_state``.
+
+For continuous parameters, such as ``C`` above, it is important to specify
+a continuous distribution to take full advantage of the randomization. This way,
+increasing ``n_iter`` will always lead to a finer search.
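Tying this together, a minimal sketch of a randomized search over the distributions
shown above (assuming an ``SVC`` estimator and an already loaded ``X``, ``y``)::

    import scipy.stats
    from sklearn.svm import SVC
    from sklearn.grid_search import RandomizedSearchCV

    param_distributions = {'C': scipy.stats.expon(scale=100),
                           'gamma': scipy.stats.expon(scale=.1),
                           'kernel': ['rbf'], 'class_weight': ['auto', None]}
    # budget: 20 sampled parameter settings
    search = RandomizedSearchCV(SVC(), param_distributions=param_distributions,
                                n_iter=20)
    search.fit(X, y)
    print(search.best_params_, search.best_score_)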
+
+.. topic:: Examples:
+
+ * :ref:`example_randomized_search.py` compares the usage and efficiency
+ of randomized search and grid search.
+
+.. topic:: References:
+
+ * Bergstra, J. and Bengio, Y.,
+ Random search for hyper-parameter optimization,
+ The Journal of Machine Learning Research (2012)
+
+
Alternatives to brute force grid search
=======================================
2 doc/tutorial/statistical_inference/model_selection.rst
@@ -146,7 +146,7 @@ estimator during the construction and exposes an estimator API::
>>> clf.fit(X_digits[:1000], y_digits[:1000]) # doctest: +ELLIPSIS
GridSearchCV(cv=None,...
>>> clf.best_score_
- 0.988991985997974
+ 0.98899999999999999
>>> clf.best_estimator_.gamma
9.9999999999999995e-07
11 doc/whats_new.rst
@@ -35,6 +35,10 @@ Changelog
attribute. Setting ``compute_importances=True`` is no longer required.
By `Gilles Louppe`_.
+ - Added :class:`grid_search.RandomizedSearchCV` and
+ :class:`grid_search.ParameterSampler` for randomized hyperparameter
+ optimization. By `Andreas Müller`_.
+
- :class:`LinearSVC`, :class:`SGDClassifier` and :class:`SGDRegressor`
now have a ``sparsify`` method that converts their ``coef_`` into a
sparse matrix, meaning stored models trained using these estimators
@@ -46,6 +50,13 @@ Changelog
- Fixed bug in :class:`MinMaxScaler` causing incorrect scaling of the
features for non-default ``feature_range`` settings. By `Andreas Müller`_.
+
+API changes summary
+-------------------
+
+ - :class:`grid_search.IterGrid` was renamed to
+ :class:`grid_search.ParameterGrid`.
+
- Fixed bug in :class:`KFold` causing imperfect class balance in some
cases. By `Alexandre Gramfort`_ and Tadej Janež.
2 examples/grid_search_digits.py
@@ -59,7 +59,7 @@
print()
print("Grid scores on development set:")
print()
- for params, mean_score, scores in clf.grid_scores_:
+ for params, mean_score, scores in clf.cv_scores_:
print("%0.3f (+/-%0.03f) for %r"
% (mean_score, scores.std() / 2, params))
print()
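Since the entries of ``cv_scores_`` are now named tuples, the loop above can also be
written with attribute access; a sketch, assuming a fitted ``clf``::

    for entry in clf.cv_scores_:
        print("%0.3f (+/-%0.03f) for %r"
              % (entry.mean_validation_score,
                 entry.cv_validation_scores.std() / 2,
                 entry.parameters))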
4 examples/svm/plot_rbf_parameters.py
@@ -105,8 +105,8 @@
pl.axis('tight')
# plot the scores of the grid
-# grid_scores_ contains parameter settings and scores
-score_dict = grid.grid_scores_
+# cv_scores_ contains parameter settings and scores
+score_dict = grid.cv_scores_
# We extract just the scores
scores = [x[1] for x in score_dict]
2 examples/svm/plot_svm_scale_c.py
@@ -131,7 +131,7 @@
cv=ShuffleSplit(n=n_samples, train_size=train_size,
n_iter=250, random_state=1))
grid.fit(X, y)
- scores = [x[1] for x in grid.grid_scores_]
+ scores = [x[1] for x in grid.cv_scores_]
scales = [(1, 'No scaling'),
((n_samples * train_size), '1/n_samples'),
649 sklearn/grid_search.py
@@ -8,26 +8,29 @@
# Gael Varoquaux <gael.varoquaux@normalesup.org>
# License: BSD Style.
-from itertools import product
import time
import warnings
import numbers
+from itertools import product
+from collections import namedtuple
+from abc import ABCMeta, abstractmethod
import numpy as np
from .base import BaseEstimator, is_classifier, clone
from .base import MetaEstimatorMixin
from .cross_validation import check_cv
from .externals.joblib import Parallel, delayed, logger
-from .utils.validation import _num_samples
-from .utils import check_arrays, safe_mask
+from .utils import safe_mask, check_random_state
+from .utils.validation import _num_samples, check_arrays
from .metrics import SCORERS, Scorer
-__all__ = ['GridSearchCV', 'IterGrid', 'fit_grid_point']
+__all__ = ['GridSearchCV', 'ParameterGrid', 'fit_grid_point',
+ 'ParameterSampler', 'RandomizedSearchCV']
-class IterGrid(object):
- """Generators on the combination of the various parameter lists given
+class ParameterGrid(object):
+ """Generators on the combination of the various parameter lists given.
Parameters
----------
@@ -43,16 +46,16 @@ class IterGrid(object):
Examples
--------
- >>> from sklearn.grid_search import IterGrid
+ >>> from sklearn.grid_search import ParameterGrid
>>> param_grid = {'a':[1, 2], 'b':[True, False]}
- >>> list(IterGrid(param_grid)) #doctest: +NORMALIZE_WHITESPACE
+ >>> list(ParameterGrid(param_grid)) #doctest: +NORMALIZE_WHITESPACE
[{'a': 1, 'b': True}, {'a': 1, 'b': False},
{'a': 2, 'b': True}, {'a': 2, 'b': False}]
See also
--------
:class:`GridSearchCV`:
- uses ``IterGrid`` to perform a full parallelized grid search.
+ uses ``ParameterGrid`` to perform a full parallelized grid search.
"""
def __init__(self, param_grid):
@@ -72,11 +75,146 @@ def __iter__(self):
yield params
+class IterGrid(ParameterGrid):
+ """Generators on the combination of the various parameter lists given.
+
+ This class is DEPRECATED. It was renamed to ``ParameterGrid``. The name
+ ``IterGrid`` will be removed in 0.15.
+
+ Parameters
+ ----------
+ param_grid: dict of string to sequence
+ The parameter grid to explore, as a dictionary mapping estimator
+ parameters to sequences of allowed values.
+
+ Returns
+ -------
+ params: dict of string to any
+ **Yields** dictionaries mapping each estimator parameter to one of its
+ allowed values.
+
+ Examples
+ --------
+ >>> from sklearn.grid_search import IterGrid
+ >>> param_grid = {'a':[1, 2], 'b':[True, False]}
+ >>> list(IterGrid(param_grid)) #doctest: +NORMALIZE_WHITESPACE
+ [{'a': 1, 'b': True}, {'a': 1, 'b': False},
+ {'a': 2, 'b': True}, {'a': 2, 'b': False}]
+
+ See also
+ --------
+ :class:`GridSearchCV`:
+ uses ``IterGrid`` to perform a full parallelized grid search.
+ """
+
+ def __init__(self, param_grid):
+ warnings.warn("IterGrid was renamed to ParameterGrid and will be"
+ " removed in 0.15.", DeprecationWarning)
+ super(IterGrid, self).__init__(param_grid)
+
+
+class ParameterSampler(object):
+ """Generator on parameters sampled from given distributions.
+
+ Parameters
+ ----------
+ param_distributions : dict
+ Dictionary where the keys are parameters and values
+ are distributions from which a parameter is to be sampled.
+ Distributions either have to provide a ``rvs`` function
+ to sample from them, or can be given as a list of values,
+ where a uniform distribution is assumed.
+
+ n_iter : integer
+ Number of parameter settings that are produced.
+
+ random_state : int or RandomState
+ Pseudo random number generator state used for random sampling.
+
+ Returns
+ -------
+ params: dict of string to any
+ **Yields** dictionaries mapping each estimator parameter to
+ a sampled value.
+
+ Examples
+ --------
+ >>> from sklearn.grid_search import ParameterSampler
+ >>> from scipy.stats.distributions import expon
+ >>> import numpy as np
+ >>> np.random.seed(0)
+ >>> param_grid = {'a':[1, 2], 'b': expon()}
+ >>> list(ParameterSampler(param_grid, n_iter=4))
+ ... #doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
+ [{'a': 1, 'b': 0.89...}, {'a': 1, 'b': 0.92...},
+ {'a': 2, 'b': 1.87...}, {'a': 2, 'b': 1.03...}]
+
+ """
+ def __init__(self, param_distributions, n_iter, random_state=None):
+ self.param_distributions = param_distributions
+ self.n_iter = n_iter
+ self.random_state = random_state
+
+ def __iter__(self):
+ rnd = check_random_state(self.random_state)
+ # Always sort the keys of a dictionary, for reproducibility
+ items = sorted(self.param_distributions.items())
+ for i in range(self.n_iter):
+ params = dict()
+ for k, v in items:
+ if hasattr(v, "rvs"):
+ params[k] = v.rvs()
+ else:
+ params[k] = v[rnd.randint(len(v))]
+ yield params
+
+
def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer,
verbose, loss_func=None, **fit_params):
- """Run fit on one set of parameters
+ """Run fit on one set of parameters.
+
+ Parameters
+ ----------
+ X : array-like, sparse matrix or list
+ Input data.
+
+ y : array-like or None
+ Targets for input data.
+
+ base_clf : estimator object
+ This estimator will be cloned and then fitted.
- Returns the score and the instance of the classifier
+ clf_params : dict
+ Parameters to be set on the clone of base_clf for this grid point.
+
+ train : ndarray, dtype int or bool
+ Boolean mask or indices for training set.
+
+ test : ndarray, dtype int or bool
+ Boolean mask or indices for test set.
+
+ scorer : callable or None.
+ If provided must be a scoring object / function with signature
+ ``scorer(estimator, X, y)``.
+
+ verbose : int
+ Verbosity level.
+
+ **fit_params : kwargs
+ Additional parameters passed to the fit function of the estimator.
+
+
+ Returns
+ -------
+ score : float
+ Score of this parameter setting on given training / test split.
+
+ estimator : estimator object
+ Estimator object of type base_clf that was fitted using clf_params
+ and provided train / test split.
+
+ n_samples_test : int
+ Number of test samples in this split.
"""
if verbose > 1:
start_time = time.time()
@@ -137,7 +275,7 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, scorer,
logger.short_format_time(time.time() -
start_time))
print("[GridSearchCV] %s %s" % ((64 - len(end_msg)) * '.', end_msg))
- return this_score, clf_params, _num_samples(X)
+ return this_score, clf_params, _num_samples(X_test)
def _check_param_grid(param_grid):
@@ -170,8 +308,193 @@ def _has_one_grid_point(param_grid):
return True
-class GridSearchCV(BaseEstimator, MetaEstimatorMixin):
- """Grid search on the parameters of a classifier
+class BaseSearchCV(BaseEstimator, MetaEstimatorMixin):
+ """Base class for hyper parameter search with cross-validation.
+ """
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def __init__(self, estimator, scoring=None, loss_func=None,
+ score_func=None, fit_params=None, n_jobs=1, iid=True,
+ refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs'):
+
+ self.scoring = scoring
+ self.estimator = estimator
+ self.loss_func = loss_func
+ self.score_func = score_func
+ self.n_jobs = n_jobs
+ self.fit_params = fit_params if fit_params is not None else {}
+ self.iid = iid
+ self.refit = refit
+ self.cv = cv
+ self.verbose = verbose
+ self.pre_dispatch = pre_dispatch
+ self._check_estimator()
+
+ def score(self, X, y=None):
+ """Returns the mean accuracy on the given test data and labels.
+
+ Parameters
+ ----------
+ X : array-like, shape = [n_samples, n_features]
+ Training set.
+
+ y : array-like, shape = [n_samples], optional
+ Labels for X.
+
+ Returns
+ -------
+ score : float
+
+ """
+ if hasattr(self.best_estimator_, 'score'):
+ return self.best_estimator_.score(X, y)
+ if self.scorer_ is None:
+ raise ValueError("No score function explicitly defined, "
+ "and the estimator doesn't provide one %s"
+ % self.best_estimator_)
+ y_predicted = self.predict(X)
+ return self.scorer(y, y_predicted)
+
+ def _check_estimator(self):
+ """Check that estimator can be fitted and score can be computed."""
+ if (not hasattr(self.estimator, 'fit') or
+ not (hasattr(self.estimator, 'predict')
+ or hasattr(self.estimator, 'score'))):
+ raise TypeError("estimator should a be an estimator implementing"
+ " 'fit' and 'predict' or 'score' methods,"
+ " %s (type %s) was passed" %
+ (self.estimator, type(self.estimator)))
+ if (self.scoring is None and self.loss_func is None and self.score_func
+ is None):
+ if not hasattr(self.estimator, 'score'):
+ raise TypeError(
+ "If no scoring is specified, the estimator passed "
+ "should have a 'score' method. The estimator %s "
+ "does not." % self.estimator)
+
+ def _set_methods(self):
+ """Create predict and predict_proba if present in best estimator."""
+ if hasattr(self.best_estimator_, 'predict'):
+ self.predict = self.best_estimator_.predict
+ if hasattr(self.best_estimator_, 'predict_proba'):
+ self.predict_proba = self.best_estimator_.predict_proba
+
+ def _fit(self, X, y, parameter_iterator, **params):
+ """Actual fitting, performing the search over parameters."""
+ estimator = self.estimator
+ cv = self.cv
+
+ n_samples = _num_samples(X)
+ X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr')
+
+ if self.loss_func is not None:
+ warnings.warn("Passing a loss function is "
+ "deprecated and will be removed in 0.15. "
+ "Either use strings or score objects.")
+ scorer = Scorer(self.loss_func, greater_is_better=False)
+ elif self.score_func is not None:
+ warnings.warn("Passing function as ``score_func`` is "
+ "deprecated and will be removed in 0.15. "
+ "Either use strings or score objects.")
+ scorer = Scorer(self.score_func)
+ elif isinstance(self.scoring, basestring):
+ scorer = SCORERS[self.scoring]
+ else:
+ scorer = self.scoring
+
+ self.scorer_ = scorer
+
+ if y is not None:
+ if len(y) != n_samples:
+ raise ValueError('Target variable (y) has a different number '
+ 'of samples (%i) than data (X: %i samples)'
+ % (len(y), n_samples))
+ y = np.asarray(y)
+ cv = check_cv(cv, X, y, classifier=is_classifier(estimator))
+
+ base_clf = clone(self.estimator)
+
+ pre_dispatch = self.pre_dispatch
+
+ out = Parallel(
+ n_jobs=self.n_jobs, verbose=self.verbose,
+ pre_dispatch=pre_dispatch)(
+ delayed(fit_grid_point)(
+ X, y, base_clf, clf_params, train, test, scorer,
+ self.verbose, **self.fit_params) for clf_params in
+ parameter_iterator for train, test in cv)
+
+ # Out is a list of triplet: score, estimator, n_test_samples
+ n_param_points = len(list(parameter_iterator))
+ n_fits = len(out)
+ n_folds = n_fits // n_param_points
+
+ scores = list()
+ cv_scores = list()
+ for start in range(0, n_fits, n_folds):
+ n_test_samples = 0
+ mean_validation_score = 0
+ these_points = list()
+ for this_score, clf_params, this_n_test_samples in \
+ out[start:start + n_folds]:
+ these_points.append(this_score)
+ if self.iid:
+ this_score *= this_n_test_samples
+ mean_validation_score += this_score
+ n_test_samples += this_n_test_samples
+ if self.iid:
+ mean_validation_score /= float(n_test_samples)
+ scores.append((mean_validation_score, clf_params))
+ cv_scores.append(these_points)
+
+ cv_scores = np.asarray(cv_scores)
+
+ # Note: we do not use max(out) to make ties deterministic even if
+ # comparison on estimator instances is not deterministic
+ if scorer is not None:
+ greater_is_better = scorer.greater_is_better
+ else:
+ greater_is_better = True
+
+ if greater_is_better:
+ best_score = -np.inf
+ else:
+ best_score = np.inf
+
+ for score, params in scores:
+ if ((score > best_score and greater_is_better)
+ or (score < best_score and not greater_is_better)):
+ best_score = score
+ best_params = params
+
+ self.best_params_ = best_params
+ self.best_score_ = best_score
+
+ if self.refit:
+ # fit the best estimator using the entire dataset
+ # clone first to work around broken estimators
+ best_estimator = clone(base_clf).set_params(**best_params)
+ if y is not None:
+ best_estimator.fit(X, y, **self.fit_params)
+ else:
+ best_estimator.fit(X, **self.fit_params)
+ self.best_estimator_ = best_estimator
+ self._set_methods()
+
+ # Store the computed scores
+ CVScoreTuple = namedtuple('CVScoreTuple', ('parameters',
+ 'mean_validation_score',
+ 'cv_validation_scores'))
+ self.cv_scores_ = [
+ CVScoreTuple(clf_params, score, all_scores)
+ for clf_params, (score, _), all_scores
+ in zip(parameter_iterator, scores, cv_scores)]
+ return self
+
+
+class GridSearchCV(BaseSearchCV):
+ """Grid search on the parameters of an estimator.
Important members are fit, predict.
@@ -197,10 +520,10 @@ class GridSearchCV(BaseEstimator, MetaEstimatorMixin):
for details.
fit_params : dict, optional
- parameters to pass to the fit method
+ Parameters to pass to the fit method.
n_jobs: int, optional
- number of jobs to run in parallel (default 1)
+ Number of jobs to run in parallel (default 1).
pre_dispatch: int, or string, optional
Controls the number of jobs that get dispatched during parallel
@@ -230,9 +553,9 @@ class GridSearchCV(BaseEstimator, MetaEstimatorMixin):
sklearn.cross_validation module for the list of possible objects
refit: boolean
- refit the best estimator with the entire dataset.
+ Refit the best estimator with the entire dataset.
If "False", it is impossible to make predictions using
- this GridSearch instance after fitting.
+ this GridSearchCV instance after fitting.
verbose: integer
Controls the verbosity: the higher, the more messages.
@@ -256,8 +579,15 @@ class GridSearchCV(BaseEstimator, MetaEstimatorMixin):
Attributes
----------
- `grid_scores_` : dict of any to float
+ `cv_scores_` : list of named tuples
Contains scores for all parameter combinations in param_grid.
+ Each entry corresponds to one parameter setting.
+ Each named tuple has the attributes:
+
+ * ``parameters``, a dict of parameter settings
+ * ``mean_validation_score``, the mean score over the
+ cross-validation folds
+ * ``cv_validation_scores``, the list of scores for each fold
`best_estimator_` : estimator
Estimator that was chosen by grid search, i.e. estimator
@@ -285,7 +615,7 @@ class GridSearchCV(BaseEstimator, MetaEstimatorMixin):
See Also
---------
- :class:`IterGrid`:
+ :class:`ParameterGrid`:
generates all the combinations of a hyperparameter grid.
:func:`sklearn.cross_validation.train_test_split`:
@@ -298,46 +628,25 @@ class GridSearchCV(BaseEstimator, MetaEstimatorMixin):
def __init__(self, estimator, param_grid, scoring=None, loss_func=None,
score_func=None, fit_params=None, n_jobs=1, iid=True,
refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs'):
- if (not hasattr(estimator, 'score') and
- (not hasattr(estimator, 'predict')
- or (scoring is None and loss_func is None
- and score_func is None))):
- raise TypeError("The provided estimator %s does not implement a "
- "score function. In this case, it needs to "
- "implement a predict fuction and you have to "
- "provide either a score_func or a loss_func."
- % type(estimator))
-
+ super(GridSearchCV, self).__init__(
+ estimator, scoring, loss_func, score_func, fit_params, n_jobs, iid,
+ refit, cv, verbose, pre_dispatch)
+ self.param_grid = param_grid
_check_param_grid(param_grid)
- self.estimator = estimator
- self.param_grid = param_grid
- self.loss_func = loss_func
- self.score_func = score_func
- self.scoring = scoring
- self.n_jobs = n_jobs
- self.fit_params = fit_params if fit_params is not None else {}
- self.iid = iid
- self.refit = refit
- self.cv = cv
- self.verbose = verbose
- self.pre_dispatch = pre_dispatch
-
- def _set_methods(self):
- if hasattr(self.best_estimator_, 'predict'):
- self.predict = self.best_estimator_.predict
- if hasattr(self.best_estimator_, 'predict_proba'):
- self.predict_proba = self.best_estimator_.predict_proba
+ @property
+ def grid_scores_(self):
+ warnings.warn("grid_scores_ is deprecated and will be removed in 0.15."
+ " Use estimator_scores_ instead.", DeprecationWarning)
+ return self.cv_scores_
def fit(self, X, y=None, **params):
- """Run fit with all sets of parameters
-
- Returns the best classifier
+ """Run fit with all sets of parameters.
Parameters
----------
- X: array, [n_samples, n_features]
+ X: array-like, shape = [n_samples, n_features]
Training vector, where n_samples is the number of samples and
n_features is the number of features.
@@ -348,11 +657,9 @@ def fit(self, X, y=None, **params):
"""
estimator = self.estimator
cv = self.cv
-
- X, y = check_arrays(X, y, sparse_format="csr", allow_lists=True)
cv = check_cv(cv, X, y, classifier=is_classifier(estimator))
- grid = IterGrid(self.param_grid)
+ grid = ParameterGrid(self.param_grid)
base_clf = clone(self.estimator)
# Return early if there is only one grid point.
@@ -367,103 +674,159 @@ def fit(self, X, y=None, **params):
self._set_methods()
return self
- if self.loss_func is not None:
- warnings.warn("Passing a loss function is "
- "deprecated and will be removed in 0.15. "
- "Either use strings or score objects.")
- scorer = Scorer(self.loss_func, greater_is_better=False)
- elif self.score_func is not None:
- warnings.warn("Passing function as ``score_func`` is "
- "deprecated and will be removed in 0.15. "
- "Either use strings or score objects.")
- scorer = Scorer(self.score_func)
- elif isinstance(self.scoring, basestring):
- scorer = SCORERS[self.scoring]
- else:
- scorer = self.scoring
+ return self._fit(X, y, grid, **params)
- self.scorer_ = scorer
- pre_dispatch = self.pre_dispatch
- out = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
- pre_dispatch=pre_dispatch)(
- delayed(fit_grid_point)(X, y, base_clf, clf_params,
- train, test, scorer,
- self.verbose,
- **self.fit_params) for
- clf_params in grid for train, test in cv)
+class RandomizedSearchCV(BaseSearchCV):
+ """Randomized search on hyper parameters.
- # Out is a list of triplet: score, estimator, n_test_samples
- n_grid_points = len(list(grid))
- n_fits = len(out)
- n_folds = n_fits // n_grid_points
+ RandomizedSearchCV implements a "fit" method and a "predict" method like
+ any classifier except that the parameters of the classifier
+ used to predict is optimized by cross-validation.
- scores = list()
- cv_scores = list()
- for grid_start in range(0, n_fits, n_folds):
- n_test_samples = 0
- score = 0
- these_points = list()
- for this_score, clf_params, this_n_test_samples in \
- out[grid_start:grid_start + n_folds]:
- these_points.append(this_score)
- if self.iid:
- this_score *= this_n_test_samples
- score += this_score
- n_test_samples += this_n_test_samples
- if self.iid:
- score /= float(n_test_samples)
- scores.append((score, clf_params))
- cv_scores.append(these_points)
+ In contrast to GridSearchCV, not all parameter values are tried out, but
+ rather a fixed number of parameter settings is sampled from the specified
+ distributions. The number of parameter settings that are tried is
+ given by n_iter.
- cv_scores = np.asarray(cv_scores)
+ Parameters
+ ----------
+ estimator: object type that implements the "fit" and "predict" methods
+ An object of that type is instantiated for each parameter setting.
- # Note: we do not use max(out) to make ties deterministic even if
- # comparison on estimator instances is not deterministic
- if scorer is not None:
- greater_is_better = scorer.greater_is_better
- else:
- greater_is_better = True
+ param_distributions: dict
+ Dictionary with parameter names (string) as keys and distributions
+ or lists of parameters to try. Distributions must provide a ``rvs``
+ method for sampling (such as those from scipy.stats.distributions).
+ If a list is given, it is sampled uniformly.
- if greater_is_better:
- best_score = -np.inf
- else:
- best_score = np.inf
+ n_iter: int, default=10
+ Number of parameter settings that are sampled. n_iter trades
+ off runtime vs quality of the solution.
- for score, params in scores:
- if ((score > best_score and greater_is_better)
- or (score < best_score and not greater_is_better)):
- best_score = score
- best_params = params
+ scoring : string or callable, optional
+ Either a string ("zero_one", "f1", "roc_auc", ... for
+ classification, "mse", "r2",... for regression) or a callable.
+ See 'Scoring objects' in the model evaluation section of the user guide
+ for details.
- self.best_score_ = best_score
- self.best_params_ = best_params
+ fit_params : dict, optional
+ Parameters to pass to the fit method.
- if self.refit:
- # fit the best estimator using the entire dataset
- # clone first to work around broken estimators
- best_estimator = clone(base_clf).set_params(**best_params)
- if y is not None:
- best_estimator.fit(X, y, **self.fit_params)
- else:
- best_estimator.fit(X, **self.fit_params)
- self.best_estimator_ = best_estimator
- self._set_methods()
+ n_jobs: int, optional
+ Number of jobs to run in parallel (default 1).
- # Store the computed scores
- # XXX: the name is too specific, it shouldn't have
- # 'grid' in it. Also, we should be retrieving/storing variance
- self.grid_scores_ = [(clf_params, score, all_scores)
- for clf_params, (score, _), all_scores
- in zip(grid, scores, cv_scores)]
- return self
+ pre_dispatch: int, or string, optional
+ Controls the number of jobs that get dispatched during parallel
+ execution. Reducing this number can be useful to avoid an
+ explosion of memory consumption when more jobs get dispatched
+ than CPUs can process. This parameter can be:
- def score(self, X, y=None):
- if hasattr(self.best_estimator_, 'score'):
- return self.best_estimator_.score(X, y)
- if self.scorer_ is None:
- raise ValueError("No score function explicitly defined, "
- "and the estimator doesn't provide one %s"
- % self.best_estimator_)
- y_predicted = self.predict(X)
- return self.scorer(y, y_predicted)
+ - None, in which case all the jobs are immediately
+ created and spawned. Use this for lightweight and
+ fast-running jobs, to avoid delays due to on-demand
+ spawning of the jobs
+
+ - An int, giving the exact number of total jobs that are
+ spawned
+
+ - A string, giving an expression as a function of n_jobs,
+ as in '2*n_jobs'
+
+ iid: boolean, optional
+ If True, the data is assumed to be identically distributed across
+ the folds, and the loss minimized is the total loss per sample,
+ and not the mean loss across the folds.
+
+ cv : integer or cross-validation generator, optional
+ If an integer is passed, it is the number of folds (default 3).
+ Specific cross-validation objects can be passed, see
+ sklearn.cross_validation module for the list of possible objects
+
+ refit: boolean
+ Refit the best estimator with the entire dataset.
+ If "False", it is impossible to make predictions using
+ this RandomizedSearchCV instance after fitting.
+
+ verbose: integer
+ Controls the verbosity: the higher, the more messages.
+
+
+ Attributes
+ ----------
+ `cv_scores_` : list of named tuples
+ Contains scores for all parameter settings sampled from param_distributions.
+ Each entry corresponds to one parameter setting.
+ Each named tuple has the attributes:
+
+ * ``parameters``, a dict of parameter settings
+ * ``mean_validation_score``, the mean score over the
+ cross-validation folds
+ * ``cv_validation_scores``, the list of scores for each fold
+
+ `best_estimator_` : estimator
+ Estimator that was chosen by the search, i.e. estimator
+ which gave highest score (or smallest loss if specified)
+ on the left out data.
+
+ `best_score_` : float
+ Score of best_estimator on the left out data.
+
+ `best_params_` : dict
+ Parameter setting that gave the best results on the hold out data.
+
+ Notes
+ -----
+ The parameters selected are those that maximize the score on the left out
+ data, unless an explicit score_func is passed, in which case it is used
+ instead. If a loss function loss_func is passed, it overrides the score
+ functions and is minimized.
+
+ If `n_jobs` was set to a value higher than one, the data is copied for each
+ parameter setting (and not `n_jobs` times). This is done for efficiency
+ reasons if individual jobs take very little time, but may raise errors if
+ the dataset is large and not enough memory is available. A workaround in
+ this case is to set `pre_dispatch`. Then, the memory is copied only
+ `pre_dispatch` times. A reasonable value for `pre_dispatch` is 2 *
+ `n_jobs`.
+
+ See Also
+ --------
+ :class:`GridSearchCV`:
+ Does exhaustive search over a grid of parameters.
+
+ :class:`ParameterSampler`:
+ A generator over parameter settings, constructed from
+ param_distributions.
+
+ """
+
+ def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
+ loss_func=None, score_func=None, fit_params=None, n_jobs=1,
+ iid=True, refit=True, cv=None, verbose=0,
+ pre_dispatch='2*n_jobs'):
+
+ self.param_distributions = param_distributions
+ self.n_iter = n_iter
+ super(RandomizedSearchCV, self).__init__(
+ estimator, scoring, loss_func, score_func, fit_params, n_jobs, iid,
+ refit, cv, verbose, pre_dispatch)
+
+ def fit(self, X, y=None, **params):
+ """Run fit on the estimator with randomly drawn parameters.
+
+ Parameters
+ ----------
+
+ X: array-like, shape = [n_samples, n_features]
+ Training vector, where n_samples in the number of samples and
+ n_features is the number of features.
+
+ y: array-like, shape = [n_samples], optional
+ Target vector relative to X for classification;
+ None for unsupervised learning.
+
+ """
+ sampled_params = ParameterSampler(self.param_distributions,
+ self.n_iter)
+ return self._fit(X, y, sampled_params, **params)
33 sklearn/tests/test_grid_search.py
@@ -16,9 +16,12 @@
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_almost_equal
+from scipy.stats import distributions
+
from sklearn.base import BaseEstimator
-from sklearn.grid_search import GridSearchCV
from sklearn.datasets.samples_generator import make_classification, make_blobs
+from sklearn.grid_search import (GridSearchCV, RandomizedSearchCV,
+ ParameterSampler)
from sklearn.svm import LinearSVC, SVC
from sklearn.cluster import KMeans, MeanShift
from sklearn.metrics import f1_score
@@ -86,7 +89,8 @@ def test_grid_search():
assert_equal(grid_search.best_estimator_.foo_param, 2)
for i, foo_i in enumerate([1, 2, 3]):
- assert_true(grid_search.grid_scores_[i][0] == {'foo_param': foo_i})
+ assert_true(grid_search.cv_scores_[i][0]
+ == {'foo_param': foo_i})
# Smoke test the score:
grid_search.score(X, y)
@@ -341,6 +345,29 @@ def test_bad_estimator():
scoring='ari')
+def test_param_sampler():
+ # test basic properties of param sampler
+ param_distributions = {"kernel": ["rbf", "linear"],
+ "C": distributions.uniform(0, 1)}
+ sampler = ParameterSampler(param_distributions=param_distributions,
+ n_iter=10, random_state=0)
+ samples = [x for x in sampler]
+ assert_equal(len(samples), 10)
+ for sample in samples:
+ assert_true(sample["kernel"] in ["rbf", "linear"])
+ assert_true(0 <= sample["C"] <= 1)
+
+
+def test_randomized_search():
+ # very basic smoke test
+ X, y = make_classification(n_samples=200, n_features=100, random_state=0)
+
+ params = dict(C=distributions.expon())
+ search = RandomizedSearchCV(LinearSVC(), param_distributions=params)
+ search.fit(X, y)
+ assert_equal(len(search.cv_scores_), 10)
+
+
def test_grid_search_score_consistency():
# test that correct scores are used
from sklearn.metrics import auc_score
@@ -351,7 +378,7 @@ def test_grid_search_score_consistency():
grid_search = GridSearchCV(clf, {'C': Cs}, scoring=score)
grid_search.fit(X, y)
cv = StratifiedKFold(n_folds=3, y=y)
- for C, scores in zip(Cs, grid_search.grid_scores_):
+ for C, scores in zip(Cs, grid_search.cv_scores_):
clf.set_params(C=C)
scores = scores[2] # get the separate runs from grid scores
i = 0
2 sklearn/utils/testing.py
@@ -168,7 +168,7 @@ def quote(self, string, safe='/'):
"OutputCodeClassifier", "OneVsRestClassifier", "RFE",
"RFECV", "BaseEnsemble"]
# estimators that there is no way to default-construct sensibly
-other = ["Pipeline", "FeatureUnion", "GridSearchCV"]
+other = ["Pipeline", "FeatureUnion", "GridSearchCV", "RandomizedSearchCV"]
def all_estimators(include_meta_estimators=False, include_other=False,
