Skip to content

Commit

Permalink
FIX Workaround limitation of cloudpickle under PyPy (#12566)
Browse files Browse the repository at this point in the history
  • Loading branch information
ogrisel authored and jnothman committed Nov 14, 2018
1 parent 01e1529 commit fc538bd
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions sklearn/neighbors/base.py
Expand Up @@ -283,6 +283,15 @@ def _pairwise(self):
return self.metric == 'precomputed'


def _tree_query_parallel_helper(tree, data, n_neighbors, return_distance):
"""Helper for the Parallel calls in KNeighborsMixin.kneighbors
The Cython method tree.query is not directly picklable by cloudpickle
under PyPy.
"""
return tree.query(data, n_neighbors, return_distance)


class KNeighborsMixin(object):
"""Mixin for k-neighbors searches"""

Expand Down Expand Up @@ -433,15 +442,15 @@ class from an array representing our data set and ask who's
if (sys.version_info < (3,) or
LooseVersion(joblib_version) < LooseVersion('0.12')):
# Deal with change of API in joblib
delayed_query = delayed(self._tree.query,
delayed_query = delayed(_tree_query_parallel_helper,
check_pickle=False)
parallel_kwargs = {"backend": "threading"}
else:
delayed_query = delayed(self._tree.query)
delayed_query = delayed(_tree_query_parallel_helper)
parallel_kwargs = {"prefer": "threads"}
result = Parallel(n_jobs, **parallel_kwargs)(
delayed_query(
X[s], n_neighbors, return_distance)
self._tree, X[s], n_neighbors, return_distance)
for s in gen_even_slices(X.shape[0], n_jobs)
)
else:
Expand Down Expand Up @@ -561,6 +570,15 @@ def kneighbors_graph(self, X=None, n_neighbors=None,
return kneighbors_graph


def _tree_query_radius_parallel_helper(tree, data, radius, return_distance):
"""Helper for the Parallel calls in RadiusNeighborsMixin.radius_neighbors
The Cython method tree.query_radius is not directly picklable by
cloudpickle under PyPy.
"""
return tree.query_radius(data, radius, return_distance)


class RadiusNeighborsMixin(object):
"""Mixin for radius-based neighbors searches"""

Expand Down Expand Up @@ -717,14 +735,14 @@ class from an array representing our data set and ask who's
n_jobs = effective_n_jobs(self.n_jobs)
if LooseVersion(joblib_version) < LooseVersion('0.12'):
# Deal with change of API in joblib
delayed_query = delayed(self._tree.query_radius,
delayed_query = delayed(_tree_query_radius_parallel_helper,
check_pickle=False)
parallel_kwargs = {"backend": "threading"}
else:
delayed_query = delayed(self._tree.query_radius)
delayed_query = delayed(_tree_query_radius_parallel_helper)
parallel_kwargs = {"prefer": "threads"}
results = Parallel(n_jobs, **parallel_kwargs)(
delayed_query(X[s], radius, return_distance)
delayed_query(self._tree, X[s], radius, return_distance)
for s in gen_even_slices(X.shape[0], n_jobs)
)
if return_distance:
Expand Down

0 comments on commit fc538bd

Please sign in to comment.