Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Successive halving for faster parameter search #13900

Merged
merged 124 commits into from Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from 121 commits
Commits
Show all changes
124 commits
Select commit Hold shift + click to select a range
0d1401f
More flexible grid search interface
NicolasHug Feb 12, 2019
3cb78e1
Merge branch 'master' into grid_search_pass_X
NicolasHug Feb 13, 2019
80963f3
added info dict parameter
NicolasHug Feb 13, 2019
326fe39
Put back removed test
NicolasHug Feb 13, 2019
bae0d95
renamed info into more_results
NicolasHug Feb 14, 2019
abbb606
Merge branch 'master' into grid_search_pass_X
NicolasHug Mar 24, 2019
7d4cb56
Passed grroups as well since we need n_to use get_n_splits(X, y, groups)
NicolasHug Mar 26, 2019
58a931f
Merge branch 'grid_search_pass_X' into successive_halving
NicolasHug May 17, 2019
cdb6b50
port
NicolasHug May 17, 2019
ab29554
pep8
NicolasHug May 17, 2019
c725fee
dabl -> sklearn
NicolasHug May 17, 2019
c9c87c3
add _required_parameters
NicolasHug May 17, 2019
9e7fc3c
skipping check in rst file if pandas not installed
NicolasHug May 17, 2019
81cee9b
Update sklearn/model_selection/_search_successive_halving.py
NicolasHug May 22, 2019
a8e4ca8
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug May 25, 2019
79fae17
Merge branch 'successive_halving' of github.com:NicolasHug/scikit-lea…
NicolasHug May 25, 2019
4d79f7c
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Jul 29, 2019
cbed1e3
renamed into GridHalvingSearchCV and RandomHalvingSearchCV
NicolasHug Jul 29, 2019
0c1fd07
Addressed thomas' comments
NicolasHug Jul 29, 2019
5d70859
repr
NicolasHug Jul 30, 2019
55d82d0
removed passing group as a parameter to evaluate_candidates
NicolasHug Jul 30, 2019
00c99d5
Joels comments
NicolasHug Jul 30, 2019
40d36db
pep8
NicolasHug Jul 30, 2019
7662476
reorganized user user guide
NicolasHug Jul 30, 2019
00df22d
renaming
NicolasHug Jul 31, 2019
1b1554e
update user guide
NicolasHug Jul 31, 2019
3ff395e
remove groups support + pass fit_params
NicolasHug Jul 31, 2019
1cf9bf7
parameter renaming
NicolasHug Jul 31, 2019
a91b119
pep8
NicolasHug Jul 31, 2019
935525b
r_i -> resource_iter
NicolasHug Jul 31, 2019
9ad17c6
fixed r_i issues
NicolasHug Aug 1, 2019
618d637
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Aug 23, 2019
64bcc93
examples + removed use of word budget
NicolasHug Aug 23, 2019
a438890
Added inpute checking tests
NicolasHug Aug 23, 2019
98161b3
added cv_resutlts_ user guide
NicolasHug Aug 23, 2019
19243d6
minor title change
NicolasHug Aug 23, 2019
5203a30
fixed doc layout
NicolasHug Sep 20, 2019
9b88b76
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Sep 20, 2019
ed5de25
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Oct 4, 2019
1c67463
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Nov 11, 2019
9f049ec
Addressed some comments
NicolasHug Nov 11, 2019
b02c53e
properly pass down fit_params
NicolasHug Nov 11, 2019
fd4a41d
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Nov 13, 2019
4d720ad
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Nov 16, 2019
db736a5
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Dec 3, 2019
0e1e38c
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Dec 13, 2019
243e02a
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Feb 11, 2020
866c08e
change default value of force_exhaust_resources and update doc
NicolasHug Feb 11, 2020
d7c4fd8
should fix doc
NicolasHug Feb 11, 2020
3d6d952
Used check_fit_params
NicolasHug Feb 12, 2020
51e4dbd
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Mar 12, 2020
cabef66
Update section about min_resources and number of candidates
NicolasHug Mar 12, 2020
9d9a5d6
Clarified ratio section
NicolasHug Mar 12, 2020
0eace47
Use ~ to refer to classes
NicolasHug Mar 12, 2020
8eb7fe7
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Jul 8, 2020
39bf2e2
fixed doc checks
NicolasHug Jul 8, 2020
9a303cc
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Jul 9, 2020
1a0808e
Apply suggestions from code review
NicolasHug Jul 9, 2020
d4d7d10
Addressed easy comments from Joel
NicolasHug Jul 9, 2020
dd69a0e
Merge branch 'successive_halving' of github.com:NicolasHug/scikit-lea…
NicolasHug Jul 9, 2020
2cffdc3
missed some
NicolasHug Jul 9, 2020
1403dfa
updated docstring of run_search
NicolasHug Jul 20, 2020
446666c
Used f strings instead of format
NicolasHug Jul 20, 2020
ed4f86d
remove candidate duplication checks
NicolasHug Jul 20, 2020
e09229a
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Jul 20, 2020
c86be6d
fix example
NicolasHug Jul 20, 2020
bb178a0
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Jul 24, 2020
907ed9a
Addressed easy comments
NicolasHug Jul 24, 2020
dcb7f46
rotate ticks labels
NicolasHug Jul 24, 2020
22d1986
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Jul 29, 2020
ac23683
Added discussion in the intro as suggested by Joel
NicolasHug Jul 29, 2020
33b60d7
Split examples into sections
NicolasHug Jul 29, 2020
762c889
minor changes
NicolasHug Jul 29, 2020
f218a9c
remove force_exhaust_budget and introduce min_resources=exhaust
NicolasHug Jul 29, 2020
c19f989
some minor validation
NicolasHug Jul 29, 2020
a49acc3
Added a n_resources_ attribute
NicolasHug Jul 29, 2020
08dd96e
update examples
NicolasHug Jul 30, 2020
57d9466
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Aug 2, 2020
c3ee547
Addressed comments
NicolasHug Aug 2, 2020
97e6040
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Aug 7, 2020
b193999
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Aug 10, 2020
31d8195
passing CV instead of X,y
NicolasHug Aug 10, 2020
cdebb6e
minor revert for handling fit_params
NicolasHug Aug 10, 2020
99072bf
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Aug 17, 2020
0507093
updated docs
NicolasHug Aug 17, 2020
749d941
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Aug 21, 2020
be87756
fix len
NicolasHug Aug 21, 2020
beda557
whatsnew
NicolasHug Aug 21, 2020
982a2ae
Merge branch 'fix_len_param_sampler' into successive_halving
NicolasHug Aug 21, 2020
d807d26
Add test for sampling when all_list
NicolasHug Aug 21, 2020
0350176
minor change to top-k
NicolasHug Aug 21, 2020
f83a436
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Aug 21, 2020
0bc44a1
Force CV splits to be consistent across calls
NicolasHug Aug 21, 2020
88840a5
reorder parameters
NicolasHug Aug 21, 2020
084ca7c
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Aug 21, 2020
c9ec1c4
reduced diff
NicolasHug Aug 21, 2020
b702abc
added tests for top_k
NicolasHug Aug 21, 2020
4c7a1b1
put back doc for groups
NicolasHug Aug 21, 2020
79cac35
not sure what went wrong
NicolasHug Aug 21, 2020
7c55a29
put import at its place
NicolasHug Aug 23, 2020
72ae482
some comment
NicolasHug Aug 23, 2020
1b71491
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Aug 30, 2020
a68bac4
Addressed comments
NicolasHug Aug 30, 2020
5bf1586
Added tests for cv_results_ and base estimator inputs
NicolasHug Aug 30, 2020
ee4724b
pep8
NicolasHug Aug 30, 2020
c35c48d
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Aug 31, 2020
d8849f5
avoid monkeypatching
NicolasHug Aug 31, 2020
be849cb
rename df
NicolasHug Aug 31, 2020
0064d49
use Joel's suggestions for testing masks
NicolasHug Aug 31, 2020
af5a809
Made it experimental
NicolasHug Aug 31, 2020
2b39677
Should fix docs
NicolasHug Aug 31, 2020
46afbca
whats new entry
NicolasHug Sep 2, 2020
7a2cd4d
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Sep 5, 2020
d8c2519
Apply suggestions from code review
NicolasHug Sep 5, 2020
669fdce
Merge branch 'successive_halving' of github.com:NicolasHug/scikit-lea…
NicolasHug Sep 5, 2020
b537ce7
Addressed comments to docs
NicolasHug Sep 5, 2020
54a6276
Addressed comments in examples
NicolasHug Sep 5, 2020
8adf44e
minor doc update
NicolasHug Sep 6, 2020
3d96178
minor renaming in UG
NicolasHug Sep 7, 2020
e5bb4bb
forgot some
NicolasHug Sep 7, 2020
9d2a628
some sad note about splitter statefulness :'(
NicolasHug Sep 8, 2020
143c4e8
Merge branch 'master' of github.com:scikit-learn/scikit-learn into su…
NicolasHug Sep 8, 2020
820ceb5
Addressed comments
NicolasHug Sep 8, 2020
645b50d
ratio -> factor
NicolasHug Sep 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/conf.py
Expand Up @@ -356,6 +356,7 @@ def __call__(self, directory):
# discovered properly by sphinx
from sklearn.experimental import enable_hist_gradient_boosting # noqa
from sklearn.experimental import enable_iterative_imputer # noqa
from sklearn.experimental import enable_successive_halving # noqa


def make_carousel_thumbs(app, exception):
Expand Down
9 changes: 9 additions & 0 deletions doc/conftest.py
Expand Up @@ -57,6 +57,13 @@ def setup_impute():
raise SkipTest("Skipping impute.rst, pandas not installed")


def setup_grid_search():
try:
import pandas # noqa
except ImportError:
raise SkipTest("Skipping grid_search.rst, pandas not installed")


def setup_unsupervised_learning():
try:
import skimage # noqa
Expand Down Expand Up @@ -86,5 +93,7 @@ def pytest_runtest_setup(item):
raise SkipTest('FeatureHasher is not compatible with PyPy')
elif fname.endswith('modules/impute.rst'):
setup_impute()
elif fname.endswith('modules/grid_search.rst'):
setup_grid_search()
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved
elif fname.endswith('statistical_inference/unsupervised_learning.rst'):
setup_unsupervised_learning()
2 changes: 2 additions & 0 deletions doc/modules/classes.rst
Expand Up @@ -1193,9 +1193,11 @@ Hyper-parameter optimizers
:template: class.rst

model_selection.GridSearchCV
model_selection.HalvingGridSearchCV
model_selection.ParameterGrid
model_selection.ParameterSampler
model_selection.RandomizedSearchCV
model_selection.HalvingRandomSearchCV


Model validation
Expand Down
406 changes: 390 additions & 16 deletions doc/modules/grid_search.rst

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions doc/whats_new/v0.24.rst
Expand Up @@ -412,6 +412,14 @@ Changelog
:pr:`17478` by :user:`Teon Brooks <teonbrooks>` and
:user:`Mohamed Maskani <maskani-moh>`.

- |Feature| Added (experimental) parameter search estimators
:class:`model_selection.HalvingRandomSearchCV` and
:class:`model_selection.HalvingGridSearchCV` which implement Successive
Halving, and can be used as a drop-in replacements for
:class:`model_selection.RandomizedSearchCV` and
:class:`model_selection.GridSearchCV`. :pr:`13900` by `Nicolas Hug`_, `Joel
Nothman`_ and `Andreas Müller`_.
Comment on lines +420 to +421
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jnothman @amueller added you guys here considering the amount of design you both made


- |Fix| Fixed the `len` of :class:`model_selection.ParameterSampler` when
all distributions are lists and `n_iter` is more than the number of unique
parameter combinations. :pr:`18222` by `Nicolas Hug`_.
Expand Down
122 changes: 122 additions & 0 deletions examples/model_selection/plot_successive_halving_heatmap.py
@@ -0,0 +1,122 @@
"""
Comparison between grid search and successive halving
=====================================================

This example compares the parameter search performed by
:class:`~sklearn.model_selection.HalvingGridSearchCV` and
:class:`~sklearn.model_selection.GridSearchCV`.

"""
from time import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.svm import SVC
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.experimental import enable_successive_halving # noqa
from sklearn.model_selection import HalvingGridSearchCV


print(__doc__)

# %%
# We first define the parameter space for an :class:`~sklearn.svm.SVC`
# estimator, and compute the time required to train a
# :class:`~sklearn.model_selection.HalvingGridSearchCV` instance, as well as a
# :class:`~sklearn.model_selection.GridSearchCV` instance.

rng = np.random.RandomState(0)
X, y = datasets.make_classification(n_samples=1000, random_state=rng)

gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
Cs = [1, 10, 100, 1e3, 1e4, 1e5]
param_grid = {'gamma': gammas, 'C': Cs}

clf = SVC(random_state=rng)

tic = time()
gsh = HalvingGridSearchCV(estimator=clf, param_grid=param_grid, ratio=2,
random_state=rng)
gsh.fit(X, y)
gsh_time = time() - tic

tic = time()
gs = GridSearchCV(estimator=clf, param_grid=param_grid)
gs.fit(X, y)
gs_time = time() - tic

# %%
# We now plot heatmaps for both search estimators.


def make_heatmap(ax, gs, is_sh=False, make_cbar=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume we can't easily reuse any of the confusion matrix plot? I've been nagging @thomasjpfan to do a grid-search visualizer ;) But I guess pandas out is nice, too.

"""Helper to make a heatmap."""
results = pd.DataFrame.from_dict(gs.cv_results_)
results['params_str'] = results.params.apply(str)
if is_sh:
# SH dataframe: get mean_test_score values for the highest iter
scores_matrix = results.sort_values('iter').pivot_table(
index='param_gamma', columns='param_C',
values='mean_test_score', aggfunc='last'
)
else:
scores_matrix = results.pivot(index='param_gamma', columns='param_C',
values='mean_test_score')

im = ax.imshow(scores_matrix)

ax.set_xticks(np.arange(len(Cs)))
ax.set_xticklabels(['{:.0E}'.format(x) for x in Cs])
ax.set_xlabel('C', fontsize=15)

ax.set_yticks(np.arange(len(gammas)))
ax.set_yticklabels(['{:.0E}'.format(x) for x in gammas])
ax.set_ylabel('gamma', fontsize=15)

# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")

if is_sh:
iterations = results.pivot_table(index='param_gamma',
columns='param_C', values='iter',
aggfunc='max').values
for i in range(len(gammas)):
for j in range(len(Cs)):
ax.text(j, i, iterations[i, j],
ha="center", va="center", color="w", fontsize=20)

if make_cbar:
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax)
cbar_ax.set_ylabel('mean_test_score', rotation=-90, va="bottom",
fontsize=15)


fig, axes = plt.subplots(ncols=2, sharey=True)
ax1, ax2 = axes

make_heatmap(ax1, gsh, is_sh=True)
make_heatmap(ax2, gs, make_cbar=True)

ax1.set_title('Successive Halving\ntime = {:.3f}s'.format(gsh_time),
fontsize=15)
ax2.set_title('GridSearch\ntime = {:.3f}s'.format(gs_time), fontsize=15)

plt.show()

# %%
# The heatmaps show the mean test score of the parameter combinations for an
# :class:`~sklearn.svm.SVC` instance. The
# :class:`~sklearn.model_selection.HalvingGridSearchCV` also shows the
# iteration at which the combinations where last used. The combinations marked
# as ``0`` were only evaluated at the first iteration, while the ones with
# ``5`` are the parameter combinations that are considered the best ones.
#
# We can see that the :class:`~sklearn.model_selection.HalvingGridSearchCV`
# class is able to find parameter combinations that are just as accurate as
# :class:`~sklearn.model_selection.GridSearchCV`, in much less time.
84 changes: 84 additions & 0 deletions examples/model_selection/plot_successive_halving_iterations.py
@@ -0,0 +1,84 @@
"""
Successive halving Iterations
NicolasHug marked this conversation as resolved.
Show resolved Hide resolved
=============================

This example illustrates how a successive halving search (
:class:`~sklearn.model_selection.HalvingGridSearchCV` and
:class:`~sklearn.model_selection.HalvingRandomSearchCV`) iteratively chooses
the best parameter combination out of multiple candidates.

"""
import pandas as pd
from sklearn import datasets
import matplotlib.pyplot as plt
from scipy.stats import randint
import numpy as np

from sklearn.experimental import enable_successive_halving # noqa
from sklearn.model_selection import HalvingRandomSearchCV
from sklearn.ensemble import RandomForestClassifier


print(__doc__)

# %%
# We first define the parameter space and train a
# :class:`~sklearn.model_selection.HalvingRandomSearchCV` instance.

rng = np.random.RandomState(0)

X, y = datasets.make_classification(n_samples=700, random_state=rng)

clf = RandomForestClassifier(n_estimators=20, random_state=rng)

param_dist = {"max_depth": [3, None],
"max_features": randint(1, 11),
"min_samples_split": randint(2, 11),
"bootstrap": [True, False],
"criterion": ["gini", "entropy"]}

rsh = HalvingRandomSearchCV(
estimator=clf,
param_distributions=param_dist,
ratio=2,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason to use ratio=2? I thought ratio=3 was the more common choice.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed. In this case ratio=2 makes the plot look nicer (longer) and makes it easier illustrate the SH process in the narrative docs

random_state=rng)
rsh.fit(X, y)

# %%
# We can now use the `cv_results_` attribute of the search estimator to inspect
# and plot the evolution of the search.

results = pd.DataFrame(rsh.cv_results_)
results['params_str'] = results.params.apply(str)
results.drop_duplicates(subset=('params_str', 'iter'), inplace=True)
mean_scores = results.pivot(index='iter', columns='params_str',
values='mean_test_score')
ax = mean_scores.plot(legend=False, alpha=.6)

labels = [
f'iter={i}\nn_samples={rsh.n_resources_[i]}\n'
f'n_candidates={rsh.n_candidates_[i]}'
for i in range(rsh.n_iterations_)
]
ax.set_xticklabels(labels, rotation=45, multialignment='left')
ax.set_title('Scores of candidates over iterations')
ax.set_ylabel('mean test score', fontsize=15)
ax.set_xlabel('iterations', fontsize=15)
plt.tight_layout()
plt.show()

# %%
# Number of candidates and amount of resource at each iteration
# -------------------------------------------------------------
#
# At the first iteration, a small amount of resources is used. The resource
# here is the number of samples that the estimators are trained on. All
# candidates are evaluated.
#
# At the second iteration, only the best half of the candidates is evaluated.
# The number of allocated resources is doubled: candidates are evaluated on
# twice as many samples.
#
# This process is repeated until the last iteration, where only 2 candidates
# are left. The best candidate is the candidate that has the best score at the
# last iteration.
35 changes: 35 additions & 0 deletions sklearn/experimental/enable_successive_halving.py
@@ -0,0 +1,35 @@
"""Enables Successive Halving search-estimators

The API and results of these estimators might change without any deprecation
cycle.

Importing this file dynamically sets the
:class:`~sklearn.model_selection.HalvingRandomSearchCV` and
:class:`~sklearn.model_selection.HalvingGridSearchCV` as attributes of the
`model_selection` module::

>>> # explicitly require this experimental feature
>>> from sklearn.experimental import enable_successive_halving # noqa
>>> # now you can import normally from model_selection
>>> from sklearn.model_selection import HalvingRandomSearchCV
>>> from sklearn.model_selection import HalvingGridSearchCV


The ``# noqa`` comment comment can be removed: it just tells linters like
flake8 to ignore the import, which appears as unused.
"""

from ..model_selection._search_successive_halving import (
HalvingRandomSearchCV,
HalvingGridSearchCV
)

from .. import model_selection

# use settattr to avoid mypy errors when monkeypatching
setattr(model_selection, "HalvingRandomSearchCV",
HalvingRandomSearchCV)
setattr(model_selection, "HalvingGridSearchCV",
HalvingGridSearchCV)

model_selection.__all__ += ['HalvingRandomSearchCV', 'HalvingGridSearchCV']
43 changes: 43 additions & 0 deletions sklearn/experimental/tests/test_enable_successive_halving.py
@@ -0,0 +1,43 @@
"""Tests for making sure experimental imports work as expected."""

import textwrap

from sklearn.utils._testing import assert_run_python_script


def test_imports_strategies():
# Make sure different import strategies work or fail as expected.

# Since Python caches the imported modules, we need to run a child process
# for every test case. Else, the tests would not be independent
# (manually removing the imports from the cache (sys.modules) is not
# recommended and can lead to many complications).

good_import = """
from sklearn.experimental import enable_successive_halving
from sklearn.model_selection import HalvingGridSearchCV
from sklearn.model_selection import HalvingRandomSearchCV
"""
assert_run_python_script(textwrap.dedent(good_import))

good_import_with_model_selection_first = """
import sklearn.model_selection
from sklearn.experimental import enable_successive_halving
from sklearn.model_selection import HalvingGridSearchCV
from sklearn.model_selection import HalvingRandomSearchCV
"""
assert_run_python_script(
textwrap.dedent(good_import_with_model_selection_first)
)

bad_imports = """
import pytest

with pytest.raises(ImportError):
from sklearn.model_selection import HalvingGridSearchCV

import sklearn.experimental
with pytest.raises(ImportError):
from sklearn.model_selection import HalvingGridSearchCV
"""
assert_run_python_script(textwrap.dedent(bad_imports))
14 changes: 12 additions & 2 deletions sklearn/model_selection/__init__.py
@@ -1,3 +1,5 @@
import typing

from ._split import BaseCrossValidator
from ._split import KFold
from ._split import GroupKFold
Expand Down Expand Up @@ -29,7 +31,15 @@
from ._search import ParameterSampler
from ._search import fit_grid_point

__all__ = ('BaseCrossValidator',
if typing.TYPE_CHECKING:
# Avoid errors in type checkers (e.g. mypy) for experimental estimators.
# TODO: remove this check once the estimator is no longer experimental.
from ._search_successive_halving import ( # noqa
HalvingGridSearchCV, HalvingRandomSearchCV
)


__all__ = ['BaseCrossValidator',
'GridSearchCV',
'TimeSeriesSplit',
'KFold',
Expand All @@ -56,4 +66,4 @@
'learning_curve',
'permutation_test_score',
'train_test_split',
'validation_curve')
'validation_curve']