diff --git a/benchmarks/bench_saga.py b/benchmarks/bench_saga.py index 4e0e2a81875bd..492527d7e4c67 100644 --- a/benchmarks/bench_saga.py +++ b/benchmarks/bench_saga.py @@ -7,7 +7,8 @@ import time import os -from joblib import delayed, Parallel +from joblib import Parallel +from sklearn.utils.fixes import delayed import matplotlib.pyplot as plt import numpy as np diff --git a/build_tools/circle/linting.sh b/build_tools/circle/linting.sh index c23cd80d3afd6..cad62b496d294 100755 --- a/build_tools/circle/linting.sh +++ b/build_tools/circle/linting.sh @@ -172,3 +172,11 @@ then echo "$doctest_directive" exit 1 fi + +joblib_import="$(git grep -l -A 10 -E "joblib import.+delayed" -- "*.py" ":!sklearn/utils/_joblib.py" ":!sklearn/utils/fixes.py")" + +if [ ! -z "$joblib_import" ]; then + echo "Use from sklearn.utils.fixes import delayed instead of joblib delayed. The following files contains imports to joblib.delayed:" + echo "$joblib_import" + exit 1 +fi diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 06acb08709e5d..d1b0febb15605 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -13,7 +13,7 @@ from math import log import numpy as np -from joblib import delayed, Parallel +from joblib import Parallel from scipy.special import expit from scipy.special import xlogy @@ -24,6 +24,7 @@ MetaEstimatorMixin) from .preprocessing import label_binarize, LabelBinarizer from .utils import check_array, indexable, column_or_1d +from .utils.fixes import delayed from .utils.validation import check_is_fitted, check_consistent_length from .utils.validation import _check_sample_weight from .pipeline import Pipeline diff --git a/sklearn/cluster/_mean_shift.py b/sklearn/cluster/_mean_shift.py index 6c52d266d5466..777d3b1832291 100644 --- a/sklearn/cluster/_mean_shift.py +++ b/sklearn/cluster/_mean_shift.py @@ -16,10 +16,11 @@ import numpy as np import warnings -from joblib import Parallel, delayed +from joblib import Parallel from collections import defaultdict from ..utils.validation import check_is_fitted, _deprecate_positional_args +from ..utils.fixes import delayed from ..utils import check_random_state, gen_batches, check_array from ..base import BaseEstimator, ClusterMixin from ..neighbors import NearestNeighbors diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 780ce46433b90..bae63c5e989d1 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -12,7 +12,7 @@ import numbers import numpy as np from scipy import sparse -from joblib import Parallel, delayed +from joblib import Parallel from ..base import clone, TransformerMixin from ..utils._estimator_html_repr import _VisualBlock @@ -25,6 +25,7 @@ from ..utils.metaestimators import _BaseComposition from ..utils.validation import check_array, check_is_fitted from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed __all__ = [ diff --git a/sklearn/covariance/_graph_lasso.py b/sklearn/covariance/_graph_lasso.py index 65024242e0a56..c43b465def374 100644 --- a/sklearn/covariance/_graph_lasso.py +++ b/sklearn/covariance/_graph_lasso.py @@ -13,13 +13,14 @@ import numpy as np from scipy import linalg -from joblib import Parallel, delayed +from joblib import Parallel from . import empirical_covariance, EmpiricalCovariance, log_likelihood from ..exceptions import ConvergenceWarning from ..utils.validation import check_random_state from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed # mypy error: Module 'sklearn.linear_model' has no attribute '_cd_fast' from ..linear_model import _cd_fast as cd_fast # type: ignore from ..linear_model import lars_path_gram diff --git a/sklearn/decomposition/_dict_learning.py b/sklearn/decomposition/_dict_learning.py index 388910444195d..71cbfde40d1c6 100644 --- a/sklearn/decomposition/_dict_learning.py +++ b/sklearn/decomposition/_dict_learning.py @@ -11,7 +11,7 @@ import numpy as np from scipy import linalg -from joblib import Parallel, delayed, effective_n_jobs +from joblib import Parallel, effective_n_jobs from ..base import BaseEstimator, TransformerMixin from ..utils import deprecated @@ -19,6 +19,7 @@ gen_batches) from ..utils.extmath import randomized_svd, row_norms from ..utils.validation import check_is_fitted, _deprecate_positional_args +from ..utils.fixes import delayed from ..linear_model import Lasso, orthogonal_mp_gram, LassoLars, Lars diff --git a/sklearn/decomposition/_lda.py b/sklearn/decomposition/_lda.py index d3c6fe2e57cd4..14dd87b9db130 100644 --- a/sklearn/decomposition/_lda.py +++ b/sklearn/decomposition/_lda.py @@ -14,13 +14,14 @@ import numpy as np import scipy.sparse as sp from scipy.special import gammaln, logsumexp -from joblib import Parallel, delayed, effective_n_jobs +from joblib import Parallel, effective_n_jobs from ..base import BaseEstimator, TransformerMixin from ..utils import check_random_state, gen_batches, gen_even_slices from ..utils.validation import check_non_negative from ..utils.validation import check_is_fitted from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed from ._online_lda_fast import (mean_change, _dirichlet_expectation_1d, _dirichlet_expectation_2d) diff --git a/sklearn/ensemble/_bagging.py b/sklearn/ensemble/_bagging.py index 375adbf6a3672..58104d23fcf4e 100644 --- a/sklearn/ensemble/_bagging.py +++ b/sklearn/ensemble/_bagging.py @@ -10,7 +10,7 @@ from abc import ABCMeta, abstractmethod from warnings import warn -from joblib import Parallel, delayed +from joblib import Parallel from ._base import BaseEnsemble, _partition_estimators from ..base import ClassifierMixin, RegressorMixin @@ -23,6 +23,7 @@ from ..utils.random import sample_without_replacement from ..utils.validation import has_fit_parameter, check_is_fitted, \ _check_sample_weight, _deprecate_positional_args +from ..utils.fixes import delayed __all__ = ["BaggingClassifier", diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index ebbb38244aa95..81fc319fdfadb 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -48,7 +48,7 @@ class calls the ``fit`` method of each sub-estimator on random samples import numpy as np from scipy.sparse import issparse from scipy.sparse import hstack as sparse_hstack -from joblib import Parallel, delayed +from joblib import Parallel from ..base import ClassifierMixin, RegressorMixin, MultiOutputMixin from ..metrics import r2_score @@ -59,6 +59,7 @@ class calls the ``fit`` method of each sub-estimator on random samples from ..utils import check_random_state, check_array, compute_sample_weight from ..exceptions import DataConversionWarning from ._base import BaseEnsemble, _partition_estimators +from ..utils.fixes import delayed from ..utils.fixes import _joblib_parallel_args from ..utils.multiclass import check_classification_targets from ..utils.validation import check_is_fitted, _check_sample_weight diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index 7eddc4c167165..47a7841db6bdd 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -7,7 +7,7 @@ from copy import deepcopy import numpy as np -from joblib import Parallel, delayed +from joblib import Parallel import scipy.sparse as sparse from ..base import clone @@ -33,6 +33,7 @@ from ..utils.validation import check_is_fitted from ..utils.validation import column_or_1d from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed class _BaseStacking(TransformerMixin, _BaseHeterogeneousEnsemble, diff --git a/sklearn/ensemble/_voting.py b/sklearn/ensemble/_voting.py index e6c6effdd0bb9..811dbf59f405d 100644 --- a/sklearn/ensemble/_voting.py +++ b/sklearn/ensemble/_voting.py @@ -17,7 +17,7 @@ import numpy as np -from joblib import Parallel, delayed +from joblib import Parallel from ..base import ClassifierMixin from ..base import RegressorMixin @@ -33,6 +33,7 @@ from ..utils.validation import _deprecate_positional_args from ..exceptions import NotFittedError from ..utils._estimator_html_repr import _VisualBlock +from ..utils.fixes import delayed class _BaseVoting(TransformerMixin, _BaseHeterogeneousEnsemble): diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index f1b5e4793c551..2b7e9571ce6dd 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -8,12 +8,13 @@ import numpy as np import numbers -from joblib import Parallel, delayed, effective_n_jobs +from joblib import Parallel, effective_n_jobs from ..utils.metaestimators import if_delegate_has_method from ..utils.metaestimators import _safe_split from ..utils.validation import check_is_fitted from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed from ..base import BaseEstimator from ..base import MetaEstimatorMixin from ..base import clone diff --git a/sklearn/inspection/_permutation_importance.py b/sklearn/inspection/_permutation_importance.py index f3a46c3504a82..688f6e9e68e03 100644 --- a/sklearn/inspection/_permutation_importance.py +++ b/sklearn/inspection/_permutation_importance.py @@ -1,13 +1,13 @@ """Permutation importance for estimators.""" import numpy as np from joblib import Parallel -from joblib import delayed from ..metrics import check_scoring from ..utils import Bunch from ..utils import check_random_state from ..utils import check_array from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed def _weights_scorer(scorer, estimator, X, y, sample_weight): diff --git a/sklearn/inspection/_plot/partial_dependence.py b/sklearn/inspection/_plot/partial_dependence.py index fb8473dacb641..744d8d5493c9e 100644 --- a/sklearn/inspection/_plot/partial_dependence.py +++ b/sklearn/inspection/_plot/partial_dependence.py @@ -6,7 +6,7 @@ import numpy as np from scipy import sparse from scipy.stats.mstats import mquantiles -from joblib import Parallel, delayed +from joblib import Parallel from .. import partial_dependence from ...base import is_regressor @@ -14,6 +14,7 @@ from ...utils import check_matplotlib_support # noqa from ...utils import _safe_indexing from ...utils.validation import _deprecate_positional_args +from ...utils.fixes import delayed @_deprecate_positional_args diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py index 9d165829c5e7e..2399e1216238f 100644 --- a/sklearn/linear_model/_base.py +++ b/sklearn/linear_model/_base.py @@ -23,7 +23,7 @@ from scipy import optimize from scipy import sparse from scipy.special import expit -from joblib import Parallel, delayed +from joblib import Parallel from ..base import (BaseEstimator, ClassifierMixin, RegressorMixin, MultiOutputMixin) @@ -37,6 +37,7 @@ from ..utils._seq_dataset import ArrayDataset32, CSRDataset32 from ..utils._seq_dataset import ArrayDataset64, CSRDataset64 from ..utils.validation import check_is_fitted, _check_sample_weight +from ..utils.fixes import delayed from ..preprocessing import normalize as f_normalize # TODO: bayesian_ridge_regression and bayesian_regression_ard diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index e4a3bcf6cbb7f..428e2c63378a2 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -12,7 +12,7 @@ import numpy as np from scipy import sparse -from joblib import Parallel, delayed, effective_n_jobs +from joblib import Parallel, effective_n_jobs from ._base import LinearModel, _pre_fit from ..base import RegressorMixin, MultiOutputMixin @@ -25,6 +25,7 @@ from ..utils.validation import check_is_fitted, _check_sample_weight from ..utils.validation import column_or_1d from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed # mypy error: Module 'sklearn.linear_model' has no attribute '_cd_fast' from . import _cd_fast as cd_fast # type: ignore diff --git a/sklearn/linear_model/_least_angle.py b/sklearn/linear_model/_least_angle.py index 889570452e365..2a2bc84e9d9eb 100644 --- a/sklearn/linear_model/_least_angle.py +++ b/sklearn/linear_model/_least_angle.py @@ -15,7 +15,7 @@ import numpy as np from scipy import linalg, interpolate from scipy.linalg.lapack import get_lapack_funcs -from joblib import Parallel, delayed +from joblib import Parallel from ._base import LinearModel from ..base import RegressorMixin, MultiOutputMixin @@ -25,6 +25,7 @@ from ..model_selection import check_cv from ..exceptions import ConvergenceWarning from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed SOLVE_TRIANGULAR_ARGS = {'check_finite': False} diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index bd205114fb99f..d58b4142de467 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -16,7 +16,7 @@ import numpy as np from scipy import optimize, sparse from scipy.special import expit, logsumexp -from joblib import Parallel, delayed, effective_n_jobs +from joblib import Parallel, effective_n_jobs from ._base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator from ._sag import sag_solver @@ -32,6 +32,7 @@ from ..utils.validation import _deprecate_positional_args from ..utils.multiclass import check_classification_targets from ..utils.fixes import _joblib_parallel_args +from ..utils.fixes import delayed from ..model_selection import check_cv from ..metrics import get_scorer diff --git a/sklearn/linear_model/_omp.py b/sklearn/linear_model/_omp.py index 93879af95c912..d8b4e4fec0758 100644 --- a/sklearn/linear_model/_omp.py +++ b/sklearn/linear_model/_omp.py @@ -11,12 +11,13 @@ import numpy as np from scipy import linalg from scipy.linalg.lapack import get_lapack_funcs -from joblib import Parallel, delayed +from joblib import Parallel from ._base import LinearModel, _pre_fit from ..base import RegressorMixin, MultiOutputMixin from ..utils import as_float_array, check_array from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed from ..model_selection import check_cv premature = """ Orthogonal matching pursuit ended prematurely due to linear diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index e4fb1298abd58..e99116ca4f3e3 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -9,7 +9,7 @@ from abc import ABCMeta, abstractmethod -from joblib import Parallel, delayed +from joblib import Parallel from ..base import clone, is_classifier from ._base import LinearClassifierMixin, SparseCoefMixin @@ -20,6 +20,7 @@ from ..utils.multiclass import _check_partial_fit_first_call from ..utils.validation import check_is_fitted, _check_sample_weight from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed from ..exceptions import ConvergenceWarning from ..model_selection import StratifiedShuffleSplit, ShuffleSplit diff --git a/sklearn/linear_model/_theil_sen.py b/sklearn/linear_model/_theil_sen.py index 559e94bb02c6b..f008c3f82ecb1 100644 --- a/sklearn/linear_model/_theil_sen.py +++ b/sklearn/linear_model/_theil_sen.py @@ -15,12 +15,13 @@ from scipy import linalg from scipy.special import binom from scipy.linalg.lapack import get_lapack_funcs -from joblib import Parallel, delayed, effective_n_jobs +from joblib import Parallel, effective_n_jobs from ._base import LinearModel from ..base import RegressorMixin from ..utils import check_random_state from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed from ..exceptions import ConvergenceWarning _EPSILON = np.finfo(np.double).eps diff --git a/sklearn/manifold/_mds.py b/sklearn/manifold/_mds.py index 8c3cf7ed03912..4cdb6b1b29cf5 100644 --- a/sklearn/manifold/_mds.py +++ b/sklearn/manifold/_mds.py @@ -6,7 +6,7 @@ # License: BSD import numpy as np -from joblib import Parallel, delayed, effective_n_jobs +from joblib import Parallel, effective_n_jobs import warnings @@ -15,6 +15,7 @@ from ..utils import check_random_state, check_array, check_symmetric from ..isotonic import IsotonicRegression from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed def _smacof_single(dissimilarities, metric=True, n_components=2, init=None, diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 570828f67f202..91178f9b7f2ab 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -17,7 +17,7 @@ from scipy.spatial import distance from scipy.sparse import csr_matrix from scipy.sparse import issparse -from joblib import Parallel, delayed, effective_n_jobs +from joblib import Parallel, effective_n_jobs from ..utils.validation import _num_samples from ..utils.validation import check_non_negative @@ -29,6 +29,7 @@ from ..preprocessing import normalize from ..utils._mask import _get_mask from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed from ..utils.fixes import sp_version, parse_version from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index cade49345d539..d62e8f318e5fa 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -32,12 +32,13 @@ from ._validation import _insert_error_scores from ._validation import _normalize_score_results from ..exceptions import NotFittedError -from joblib import Parallel, delayed +from joblib import Parallel from ..utils import check_random_state from ..utils.random import sample_without_replacement from ..utils.validation import indexable, check_is_fitted, _check_fit_params from ..utils.validation import _deprecate_positional_args from ..utils.metaestimators import if_delegate_has_method +from ..utils.fixes import delayed from ..metrics._scorer import _check_multimetric_scoring from ..metrics import check_scoring from ..utils import deprecated diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index a7c374ab33b90..40342c2d0ae00 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -18,13 +18,14 @@ import numpy as np import scipy.sparse as sp -from joblib import Parallel, delayed, logger +from joblib import Parallel, logger from ..base import is_classifier, clone from ..utils import indexable, check_random_state, _safe_indexing from ..utils.validation import _check_fit_params from ..utils.validation import _num_samples from ..utils.validation import _deprecate_positional_args +from ..utils.fixes import delayed from ..utils.metaestimators import _safe_split from ..metrics import check_scoring from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py index 4050524bb3f2d..e761784651787 100644 --- a/sklearn/multiclass.py +++ b/sklearn/multiclass.py @@ -53,9 +53,10 @@ check_classification_targets, _ovr_decision_function) from .utils.metaestimators import _safe_split, if_delegate_has_method +from .utils.fixes import delayed from .exceptions import NotFittedError -from joblib import Parallel, delayed +from joblib import Parallel __all__ = [ "OneVsRestClassifier", diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index 8336d7d126a57..484041e476173 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -16,7 +16,7 @@ import numpy as np import scipy.sparse as sp -from joblib import Parallel, delayed +from joblib import Parallel from abc import ABCMeta, abstractmethod from .base import BaseEstimator, clone, MetaEstimatorMixin @@ -27,6 +27,7 @@ from .utils.validation import (check_is_fitted, has_fit_parameter, _check_fit_params, _deprecate_positional_args) from .utils.multiclass import check_classification_targets +from .utils.fixes import delayed __all__ = ["MultiOutputRegressor", "MultiOutputClassifier", "ClassifierChain", "RegressorChain"] diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py index 4b8de6b3655f6..7ffadbf7a1c03 100644 --- a/sklearn/neighbors/_base.py +++ b/sklearn/neighbors/_base.py @@ -15,7 +15,7 @@ import numpy as np from scipy.sparse import csr_matrix, issparse import joblib -from joblib import Parallel, delayed, effective_n_jobs +from joblib import Parallel, effective_n_jobs from ._ball_tree import BallTree from ._kd_tree import KDTree @@ -28,6 +28,7 @@ from ..utils.multiclass import check_classification_targets from ..utils.validation import check_is_fitted from ..utils.validation import check_non_negative +from ..utils.fixes import delayed from ..utils.fixes import parse_version from ..exceptions import DataConversionWarning, EfficiencyWarning @@ -703,9 +704,7 @@ class from an array representing our data set and ask who's parse_version(joblib.__version__) < parse_version('0.12')) if old_joblib: # Deal with change of API in joblib - check_pickle = False if old_joblib else None - delayed_query = delayed(_tree_query_parallel_helper, - check_pickle=check_pickle) + delayed_query = delayed(_tree_query_parallel_helper) parallel_kwargs = {"backend": "threading"} else: delayed_query = delayed(_tree_query_parallel_helper) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index e4b3da262e193..3baa4346366c3 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -14,7 +14,7 @@ import numpy as np from scipy import sparse -from joblib import Parallel, delayed +from joblib import Parallel from .base import clone, TransformerMixin from .utils._estimator_html_repr import _VisualBlock @@ -22,6 +22,7 @@ from .utils import Bunch, _print_elapsed_time from .utils.validation import check_memory from .utils.validation import _deprecate_positional_args +from .utils.fixes import delayed from .utils.metaestimators import _BaseComposition diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index d715584665ab6..d9896468959a1 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -10,7 +10,9 @@ # # License: BSD 3 clause +from functools import update_wrapper from distutils.version import LooseVersion +import functools import numpy as np import scipy.sparse as sp @@ -18,6 +20,7 @@ import scipy.stats from scipy.sparse.linalg import lsqr as sparse_lsqr # noqa from numpy.ma import MaskedArray as _MaskedArray # TODO: remove in 0.25 +from .._config import config_context, get_config from .deprecation import deprecated @@ -196,3 +199,24 @@ def _take_along_axis(arr, indices, axis): fancy_index = tuple(fancy_index) return arr[fancy_index] + + +# remove when https://github.com/joblib/joblib/issues/1071 is fixed +def delayed(function): + """Decorator used to capture the arguments of a function.""" + @functools.wraps(function) + def delayed_function(*args, **kwargs): + return _FuncWrapper(function), args, kwargs + return delayed_function + + +class _FuncWrapper: + """"Load the global configuration before calling the function.""" + def __init__(self, function): + self.function = function + self.config = get_config() + update_wrapper(self, self.function) + + def __call__(self, *args, **kwargs): + with config_context(**self.config): + return self.function(*args, **kwargs) diff --git a/sklearn/utils/tests/test_parallel.py b/sklearn/utils/tests/test_parallel.py new file mode 100644 index 0000000000000..c5f2c6a2f94ec --- /dev/null +++ b/sklearn/utils/tests/test_parallel.py @@ -0,0 +1,30 @@ +from distutils.version import LooseVersion + +import pytest +from joblib import Parallel +import joblib + +from numpy.testing import assert_array_equal + +from sklearn._config import config_context, get_config +from sklearn.utils.fixes import delayed + + +def get_working_memory(): + return get_config()["working_memory"] + + +@pytest.mark.parametrize("n_jobs", [1, 2]) +@pytest.mark.parametrize("backend", ["loky", "threading", + "multiprocessing"]) +def test_configuration_passes_through_to_joblib(n_jobs, backend): + # Tests that the global global configuration is passed to joblib jobs + + if joblib.__version__ < LooseVersion('0.12') and backend == 'loky': + pytest.skip('loky backend does not exist in joblib <0.12') + + with config_context(working_memory=123): + results = Parallel(n_jobs=n_jobs, backend=backend)( + delayed(get_working_memory)() for _ in range(2)) + + assert_array_equal(results, [123] * 2)