Skip to content

Commit

Permalink
MNT fix MacOS failure with more stable estimator_checks (#18667)
Browse files Browse the repository at this point in the history
* More stable estimator_checks

* Revert back to default check_estimators_dtypes

* delayed wrapper does not expose check_pickle kwarg
  • Loading branch information
ogrisel committed Oct 22, 2020
1 parent d933c20 commit 8d10285
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
8 changes: 2 additions & 6 deletions sklearn/neighbors/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,13 +713,11 @@ class from an array representing our data set and ask who's
parse_version(joblib.__version__) < parse_version('0.12'))
if old_joblib:
# Deal with change of API in joblib
delayed_query = delayed(_tree_query_parallel_helper)
parallel_kwargs = {"backend": "threading"}
else:
delayed_query = delayed(_tree_query_parallel_helper)
parallel_kwargs = {"prefer": "threads"}
chunked_results = Parallel(n_jobs, **parallel_kwargs)(
delayed_query(
delayed(_tree_query_parallel_helper)(
self._tree, X[s], n_neighbors, return_distance)
for s in gen_even_slices(X.shape[0], n_jobs)
)
Expand Down Expand Up @@ -1038,13 +1036,11 @@ class from an array representing our data set and ask who's
"or set algorithm='brute'" % self._fit_method)

n_jobs = effective_n_jobs(self.n_jobs)
delayed_query = delayed(_tree_query_radius_parallel_helper)
if parse_version(joblib.__version__) < parse_version('0.12'):
# Deal with change of API in joblib
delayed_query = delayed(_tree_query_radius_parallel_helper,
check_pickle=False)
parallel_kwargs = {"backend": "threading"}
else:
delayed_query = delayed(_tree_query_radius_parallel_helper)
parallel_kwargs = {"prefer": "threads"}

chunked_results = Parallel(n_jobs, **parallel_kwargs)(
Expand Down
3 changes: 1 addition & 2 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from ._testing import raises
from . import is_scalar_nan

from ..discriminant_analysis import LinearDiscriminantAnalysis
from ..linear_model import LogisticRegression
from ..linear_model import Ridge

Expand Down Expand Up @@ -346,7 +345,7 @@ def _construct_instance(Estimator):
if issubclass(Estimator, RegressorMixin):
estimator = Estimator(Ridge())
else:
estimator = Estimator(LinearDiscriminantAnalysis())
estimator = Estimator(LogisticRegression(C=1))
elif required_parameters in (['estimators'],):
# Heterogeneous ensemble classes (i.e. stacking, voting)
if issubclass(Estimator, RegressorMixin):
Expand Down

0 comments on commit 8d10285

Please sign in to comment.