[MRG] implements log-uniform random variable #11232

Merged: 42 commits (into master from stsievert:lograndom), Oct 2, 2019

Changes from all commits (42 commits)
513416c
basic implementation of loguniform
Nov 10, 2017
905b797
short demo that tests loguniform dist
Nov 10, 2017
ac1ddca
small tests to benchmark implementation
Dec 7, 2017
ee86407
MAINT: remove demo
stsievert Jun 11, 2018
81e1ff0
ENH: provide loguniform impl
stsievert Jun 11, 2018
bbd27d0
DOC: show example, document class
stsievert Jun 11, 2018
6161fe7
TST: make sure works with ParameterSampler
stsievert Jun 11, 2018
8f89799
DOC: better document
stsievert Jun 13, 2018
9f8675b
DOC: what's new
stsievert Feb 5, 2019
c2ad57d
TST: add test to make sure fairly uniform
stsievert Mar 19, 2019
f47c83e
Clarify when doesn't conform to scipy.stats API
stsievert Aug 30, 2019
93c395a
Update test_random.py
stsievert Aug 30, 2019
79d0bf7
Adds scipy.stats.rv_continuous inheritance
stsievert Aug 30, 2019
1762381
Merge branch 'master' into lograndom
stsievert Aug 30, 2019
9fe651e
DOC: add to whats new
stsievert Aug 30, 2019
1352c75
flake8
stsievert Aug 30, 2019
353c795
Add CDF
stsievert Aug 30, 2019
4d84c21
Clean tests
stsievert Aug 30, 2019
2a03f25
Alias stats.reciprrocal
stsievert Sep 10, 2019
6162180
edit other tests/examples
stsievert Sep 10, 2019
98003a9
Add example use (and remove unnecessary test)
stsievert Sep 22, 2019
a831af2
Merge branch 'master' into lograndom
stsievert Sep 22, 2019
ec6eb92
lint
stsievert Sep 22, 2019
ad966f4
Reorganize into fixes
stsievert Sep 23, 2019
f0d28fb
More moving
stsievert Sep 23, 2019
541668e
More fixing
stsievert Sep 23, 2019
686ad99
fixes
stsievert Sep 23, 2019
d7beefc
Add 'requires_positive_X' back in
stsievert Sep 23, 2019
b429b99
delete
stsievert Sep 23, 2019
7010a23
Add requires_positive_X back in
stsievert Sep 23, 2019
cbec9b8
Reorder?
stsievert Sep 23, 2019
450fa37
Clean imports
stsievert Sep 23, 2019
cef5534
clean imports
stsievert Sep 23, 2019
0ef7f74
Add scipy link in grid search
stsievert Sep 23, 2019
9a7b7b0
Merge branch 'master' into lograndom
stsievert Sep 24, 2019
bd32a62
Respond to review
stsievert Sep 24, 2019
bd98e24
Merge branch 'lograndom' of https://github.com/stsievert/scikit-learn…
stsievert Sep 24, 2019
27429a9
[doc build] trigger ci
glemaitre Sep 24, 2019
ac34b67
[doc build] fix issue with integral in power
glemaitre Sep 24, 2019
cb86fbb
Add line break
stsievert Sep 24, 2019
fabb9aa
Merge branch 'lograndom' of https://github.com/stsievert/scikit-learn…
stsievert Sep 24, 2019
80cc4d8
[build doc] fix code snippet in doc
glemaitre Oct 2, 2019
17 changes: 17 additions & 0 deletions doc/modules/grid_search.rst
@@ -121,6 +121,7 @@ discrete choices (which will be sampled uniformly) can be specified::
This example uses the ``scipy.stats`` module, which contains many useful
distributions for sampling parameters, such as ``expon``, ``gamma``,
``uniform`` or ``randint``.

In principle, any function can be passed that provides a ``rvs`` (random
variate sample) method to sample a value. A call to the ``rvs`` function should
provide independent random samples from possible parameter values on
@@ -139,6 +140,22 @@ For continuous parameters, such as ``C`` above, it is important to specify
a continuous distribution to take full advantage of the randomization. This way,
increasing ``n_iter`` will always lead to a finer search.
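As an aside on the ``rvs`` contract mentioned above (an illustrative sketch, not
part of the diff): any object exposing an ``rvs`` method that draws independent
samples can serve as a parameter distribution, e.g. a hypothetical
``uniform_int`` class::

    import numpy as np
    from sklearn.utils import check_random_state

    class uniform_int:
        # Hypothetical distribution object; only ``rvs`` is required.
        def __init__(self, low, high):
            self.low, self.high = low, high

        def rvs(self, size=None, random_state=None):
            # check_random_state accepts None, an int seed, or a
            # RandomState instance.
            rng = check_random_state(random_state)
            return rng.randint(self.low, self.high, size=size)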

A continuous log-uniform random variable is available through
:class:`~sklearn.utils.fixes.loguniform`. This is a continuous version of
log-spaced parameters. For example, to specify ``C`` above, ``loguniform(1,
100)`` can be used instead of ``[1, 10, 100]`` or ``np.logspace(0, 2,
num=1000)``. This is an alias to SciPy's `stats.reciprocal
<https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.reciprocal.html>`_.

Mirroring the example above in grid search, we can specify a continuous random
variable that is log-uniformly distributed between ``1e0`` and ``1e3``::

    from sklearn.utils.fixes import loguniform
    {'C': loguniform(1e0, 1e3),
     'gamma': loguniform(1e-4, 1e-3),
     'kernel': ['rbf'],
     'class_weight': ['balanced', None]}
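For context (an illustrative sketch, not part of the diff), such a dictionary
can be passed to ``RandomizedSearchCV``; the ``SVC`` estimator and digits data
here are assumptions for illustration::

    from sklearn.datasets import load_digits
    from sklearn.model_selection import RandomizedSearchCV
    from sklearn.svm import SVC
    from sklearn.utils.fixes import loguniform

    X, y = load_digits(return_X_y=True)
    param_distributions = {'C': loguniform(1e0, 1e3),
                           'gamma': loguniform(1e-4, 1e-3),
                           'kernel': ['rbf'],
                           'class_weight': ['balanced', None]}
    # Each parameter is sampled independently for each of the 20 draws.
    search = RandomizedSearchCV(SVC(), param_distributions, n_iter=20,
                                random_state=0)
    search.fit(X, y)
    print(search.best_params_)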

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_model_selection_plot_randomized_search.py` compares the usage and efficiency
7 changes: 7 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -554,6 +554,13 @@ Changelog
  :func:`~utils.estimator_checks.parametrize_with_checks`, to parametrize
  estimator checks for a list of estimators. :pr:`14381` by `Thomas Fan`_.

- A new random variable, :class:`utils.fixes.loguniform`, implements a
  log-uniform random variable (e.g., for use in ``RandomizedSearchCV``).
  For example, the outcomes ``1``, ``10`` and ``100`` are all equally likely
  for ``loguniform(1, 100)``; a quick empirical check appears below.
  See :issue:`11232` by :user:`Scott Sievert <stsievert>` and
  :user:`Nathaniel Saul <sauln>`, and
  `SciPy PR 10815 <https://github.com/scipy/scipy/pull/10815>`_.

- |API| The following utils have been deprecated and are now private:

  - ``choose_check_classifiers_labels``
  - ``enforce_estimator_tags_y``
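As a quick empirical check of the ``loguniform`` entry above (an illustrative
sketch, not part of the changelog), the two decades of ``loguniform(1, 100)``
should receive roughly equal probability mass::

    import numpy as np
    from sklearn.utils.fixes import loguniform

    samples = loguniform(1, 100).rvs(size=100000, random_state=0)
    # Under a log-uniform law, [1, 10) and [10, 100] are equally likely.
    print(np.mean(samples < 10), np.mean(samples >= 10))  # both ~0.5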
37 changes: 17 additions & 20 deletions examples/model_selection/plot_randomized_search.py
@@ -12,8 +12,8 @@
parameters. The result in parameter settings is quite similar, while the run
time for randomized search is drastically lower.

-The performance is slightly worse for the randomized search, though this
-is most likely a noise effect and would not carry over to a held-out test set.
+The performance may be slightly worse for the randomized search, and is
+likely due to a noise effect and would not carry over to a held-out test set.

Note that in practice, one would not search over this many different parameters
simultaneously using grid search, but pick only the ones deemed most important.
@@ -23,18 +23,19 @@
import numpy as np

from time import time
-from scipy.stats import randint as sp_randint
+import scipy.stats as stats
+from sklearn.utils.fixes import loguniform

-from sklearn.model_selection import GridSearchCV
-from sklearn.model_selection import RandomizedSearchCV
+from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.datasets import load_digits
-from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import SGDClassifier

# get some data
X, y = load_digits(return_X_y=True)

# build a classifier
-clf = RandomForestClassifier(n_estimators=20)
+clf = SGDClassifier(loss='hinge', penalty='elasticnet',
+                    fit_intercept=True)


# Utility function to report best scores
@@ -43,19 +44,17 @@ def report(results, n_top=3):
        candidates = np.flatnonzero(results['rank_test_score'] == i)
        for candidate in candidates:
            print("Model with rank: {0}".format(i))
-            print("Mean validation score: {0:.3f} (std: {1:.3f})".format(
-                  results['mean_test_score'][candidate],
-                  results['std_test_score'][candidate]))
+            print("Mean validation score: {0:.3f} (std: {1:.3f})"
+                  .format(results['mean_test_score'][candidate],
+                          results['std_test_score'][candidate]))
            print("Parameters: {0}".format(results['params'][candidate]))
            print("")


# specify parameters and distributions to sample from
param_dist = {"max_depth": [3, None],
"max_features": sp_randint(1, 11),
"min_samples_split": sp_randint(2, 11),
"bootstrap": [True, False],
"criterion": ["gini", "entropy"]}
param_dist = {'average': [True, False],
'l1_ratio': stats.uniform(0, 1),
'alpha': loguniform(1e-4, 1e0)}

# run randomized search
n_iter_search = 20
@@ -69,11 +68,9 @@ def report(results, n_top=3):
report(random_search.cv_results_)

# use a full grid over all parameters
param_grid = {"max_depth": [3, None],
"max_features": [1, 3, 10],
"min_samples_split": [2, 3, 10],
"bootstrap": [True, False],
"criterion": ["gini", "entropy"]}
param_grid = {'average': [True, False],
'l1_ratio': np.linspace(0, 1, num=10),
'alpha': np.power(10, np.arange(-4, 1, dtype=float))}
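# Note (illustrative aside): np.power(10, np.arange(-4, 1, dtype=float))
# evaluates to [1e-04, 1e-03, 1e-02, 1e-01, 1e+00] -- the discrete
# counterpart of the continuous loguniform(1e-4, 1e0) used in the random
# search above.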

# run grid search
grid_search = GridSearchCV(clf, param_grid=param_grid)
50 changes: 50 additions & 0 deletions sklearn/utils/fixes.py
@@ -15,6 +15,7 @@
import numpy as np
import scipy.sparse as sp
import scipy
import scipy.stats
from scipy.sparse.linalg import lsqr as sparse_lsqr # noqa


@@ -256,3 +257,52 @@ def _joblib_parallel_args(**kwargs):
        if require == 'sharedmem':
            args['backend'] = 'threading'
    return args


class loguniform(scipy.stats.reciprocal):
    """A class supporting log-uniform random variables.

    Parameters
    ----------
    low : float
        The minimum value
    high : float
        The maximum value

    Methods
    -------
    rvs(self, size=None, random_state=None)
        Generate log-uniform random variables

    The most useful method for Scikit-learn usage is highlighted here.
    For a full list, see
    `scipy.stats.reciprocal
    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.reciprocal.html>`_.
    This list includes all functions of ``scipy.stats`` continuous
    distributions such as ``pdf``.

    Notes
    -----
    This class generates values between ``low`` and ``high``::

        low <= loguniform(low, high).rvs() <= high

    The logarithmic probability density function (PDF) is uniform: when
    ``x`` is a uniformly distributed random variable between 0 and 1,
    ``10**x`` is a random variable distributed according to
    ``loguniform(1, 10)``.

    This class is an alias to ``scipy.stats.reciprocal``, which uses the
    reciprocal distribution:
    https://en.wikipedia.org/wiki/Reciprocal_distribution

    Examples
    --------
    >>> from sklearn.utils.fixes import loguniform
    >>> rv = loguniform(1e-3, 1e1)
    >>> rvs = rv.rvs(random_state=42, size=1000)
    >>> rvs.min()  # doctest: +SKIP
    0.0010435856341129003
    >>> rvs.max()  # doctest: +SKIP
    9.97403052786026
    """
27 changes: 27 additions & 0 deletions sklearn/utils/tests/test_fixes.py
@@ -3,16 +3,19 @@
# Lars Buitinck
# License: BSD 3 clause

import math
import pickle

import numpy as np
import pytest
import scipy.stats

from sklearn.utils.testing import assert_array_equal

from sklearn.utils.fixes import MaskedArray
from sklearn.utils.fixes import _joblib_parallel_args
from sklearn.utils.fixes import _object_dtype_isnan
from sklearn.utils.fixes import loguniform


def test_masked_array_obj_dtype_pickleable():
@@ -68,3 +71,27 @@ def test_object_dtype_isnan(dtype, val):
    mask = _object_dtype_isnan(X)

    assert_array_equal(mask, expected_mask)


@pytest.mark.parametrize("low,high,base",
[(-1, 0, 10), (0, 2, np.exp(1)), (-1, 1, 2)])
def test_loguniform(low, high, base):
rv = loguniform(base ** low, base ** high)
assert isinstance(rv, scipy.stats._distn_infrastructure.rv_frozen)
rvs = rv.rvs(size=2000, random_state=0)

# Test the basics; right bounds, right size
assert (base ** low <= rvs).all() and (rvs <= base ** high).all()
assert len(rvs) == 2000

# Test that it's actually (fairly) uniform
log_rvs = np.array([math.log(x, base) for x in rvs])
counts, _ = np.histogram(log_rvs)
assert counts.mean() == 200
assert np.abs(counts - counts.mean()).max() <= 40

# Test that random_state works
assert (
loguniform(base ** low, base ** high).rvs(random_state=0)
== loguniform(base ** low, base ** high).rvs(random_state=0)
)