diff --git a/sklearn/neighbors/base.py b/sklearn/neighbors/base.py
index cba4d0d87c225..730a605cd5baa 100644
--- a/sklearn/neighbors/base.py
+++ b/sklearn/neighbors/base.py
@@ -283,6 +283,15 @@ def _pairwise(self):
         return self.metric == 'precomputed'
 
 
+def _tree_query_parallel_helper(tree, data, n_neighbors, return_distance):
+    """Helper for the Parallel calls in KNeighborsMixin.kneighbors.
+
+    Plain function given to ``delayed`` because the bound Cython method
+    ``tree.query`` is not directly picklable by cloudpickle under PyPy.
+    """
+    return tree.query(data, n_neighbors, return_distance)
+
+
 class KNeighborsMixin(object):
     """Mixin for k-neighbors searches"""
 
@@ -433,15 +442,15 @@ class from an array representing our data set and ask who's
             if (sys.version_info < (3,) or
                     LooseVersion(joblib_version) < LooseVersion('0.12')):
                 # Deal with change of API in joblib
-                delayed_query = delayed(self._tree.query,
+                delayed_query = delayed(_tree_query_parallel_helper,
                                         check_pickle=False)
                 parallel_kwargs = {"backend": "threading"}
             else:
-                delayed_query = delayed(self._tree.query)
+                delayed_query = delayed(_tree_query_parallel_helper)
                 parallel_kwargs = {"prefer": "threads"}
             result = Parallel(n_jobs, **parallel_kwargs)(
                 delayed_query(
-                    X[s], n_neighbors, return_distance)
+                    self._tree, X[s], n_neighbors, return_distance)
                 for s in gen_even_slices(X.shape[0], n_jobs)
             )
         else:
@@ -561,6 +570,15 @@ def kneighbors_graph(self, X=None, n_neighbors=None,
         return kneighbors_graph
 
 
+def _tree_query_radius_parallel_helper(tree, data, radius, return_distance):
+    """Helper for the Parallel calls in RadiusNeighborsMixin.radius_neighbors.
+
+    Plain function given to ``delayed`` because the bound Cython method
+    ``tree.query_radius`` is not directly picklable by cloudpickle under PyPy.
+    """
+    return tree.query_radius(data, radius, return_distance)
+
+
 class RadiusNeighborsMixin(object):
     """Mixin for radius-based neighbors searches"""
 
@@ -717,14 +735,14 @@ class from an array representing our data set and ask who's
             n_jobs = effective_n_jobs(self.n_jobs)
             if LooseVersion(joblib_version) < LooseVersion('0.12'):
                 # Deal with change of API in joblib
-                delayed_query = delayed(self._tree.query_radius,
+                delayed_query = delayed(_tree_query_radius_parallel_helper,
                                         check_pickle=False)
                 parallel_kwargs = {"backend": "threading"}
             else:
-                delayed_query = delayed(self._tree.query_radius)
+                delayed_query = delayed(_tree_query_radius_parallel_helper)
                 parallel_kwargs = {"prefer": "threads"}
             results = Parallel(n_jobs, **parallel_kwargs)(
-                delayed_query(X[s], radius, return_distance)
+                delayed_query(self._tree, X[s], radius, return_distance)
                 for s in gen_even_slices(X.shape[0], n_jobs)
             )
             if return_distance: