Skip to content

Commit

Permalink
ENH: supply query_options to NearestNDInterpolator
Browse files Browse the repository at this point in the history
  • Loading branch information
harshilkamdar committed Nov 11, 2023
1 parent b625a28 commit ff1d52f
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 26 deletions.
38 changes: 22 additions & 16 deletions scipy/interpolate/_ndgriddata.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,6 @@ class NearestNDInterpolator(NDInterpolatorBase):
Options passed to the underlying ``cKDTree``.
.. versionadded:: 0.17.0
distance_upper_bound : float, optional
Option to truncate ``cKDTree`` nearest neighbor query to some
``distance_upper_bound``. This is useful for larger interpolation problems.
.. versionadded:: 1.12.0
workers : int, optional
Number of workers to use for parallel processing during nearest neighbor searches.
If -1 is given all CPU threads are used. Default: 1.
.. versionadded:: 1.12.0
See Also
--------
Expand Down Expand Up @@ -99,37 +89,53 @@ class NearestNDInterpolator(NDInterpolatorBase):
"""

def __init__(self, x, y, rescale=False, tree_options=None, *, distance_upper_bound=np.inf, workers=1):
def __init__(self, x, y, rescale=False, tree_options=None):
NDInterpolatorBase.__init__(self, x, y, rescale=rescale,
need_contiguous=False,
need_values=False)
if tree_options is None:
tree_options = dict()
self.tree = cKDTree(self.points, **tree_options)
self.distance_upper_bound = distance_upper_bound
self.workers = workers
self.values = np.asarray(y)

def __call__(self, *args):
def __call__(self, *args, query_options=None):
"""
Evaluate interpolator at given points.
Parameters
----------
query_options : dict, optional
This allows `eps`, `p`, `distance_upper_bound`, and `workers` being passed to the cKDTree's query function
to be explicitly set. See the `scipy.spatial.cKDTree.query` for an overview of the different options. Note
that k is restricted to 1 since NearestNDInterpolator has to have k=1 by definition.
..versionadded:: 1.12.0
x1, x2, ... xn : array-like of float
Points where to interpolate data at.
x1, x2, ... xn can be array-like of float with broadcastable shape.
or x1 can be array-like of float with shape ``(..., ndim)``
"""
if query_options is not None and not isinstance(query_options, dict):
raise TypeError("query_options must be a dictionary")

# gather the query options to pass to cKDTree.query
query_options = query_options or {}
eps = query_options.get('eps', 0)
p = query_options.get('p', 2)
distance_upper_bound = query_options.get('distance_upper_bound', np.inf)
workers = query_options.get('workers', 1)

# For the sake of enabling subclassing, NDInterpolatorBase._set_xi performs some operations
# which are not required by NearestNDInterpolator.__call__, hence here we operate
# on xi directly, without calling a parent class function.
xi = _ndim_coords_from_arrays(args, ndim=self.points.shape[1])
xi = self._check_call_shape(xi)
xi = self._scale_x(xi)
dist, i = self.tree.query(xi, distance_upper_bound=self.distance_upper_bound, workers=self.workers)
if np.isfinite(self.distance_upper_bound):
dist, i = self.tree.query(xi, eps=eps, p=p, distance_upper_bound=distance_upper_bound, workers=workers)

if np.isfinite(distance_upper_bound):
# if distance_upper_bound is set to not be infinite, then we need to consider the case where cKDtree
# does not find any points within distance_upper_bound to return. It marks those points as having infinte
# distance, which is what will be used below to mask the array and return only the points that were deemed
Expand Down
34 changes: 24 additions & 10 deletions scipy/interpolate/tests/test_ndgriddata.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,23 +196,37 @@ def test_nearest_list_argument(self):
NI = NearestNDInterpolator((d[0], d[1]), list(d[2]))
assert_array_equal(NI([0.1, 0.9], [0.1, 0.9]), [0, 2])

def test_nearest_max_dist(self):
def test_nearest_query_options(self):
nd = np.array([[0, 1, 0, 1],
[0, 0, 1, 1],
[0, 1, 1, 2]])
delta = 0.1
query_points = [0 + delta, 1 + delta], [0 + delta, 1 + delta]

# case 1 - query max_dist is smaller than the query points' nearest distance to nd.
NI = NearestNDInterpolator((nd[0], nd[1]), nd[2],
distance_upper_bound=np.sqrt(delta ** 2 + delta ** 2) - 1e-7)
assert_array_equal(NI(query_points), [np.nan, np.nan])

# case 2 - query max_dist is larger, so should return non np.nan
NI = NearestNDInterpolator((nd[0], nd[1]), nd[2],
distance_upper_bound=np.sqrt(delta ** 2 + delta ** 2) + 1e-7)
assert_array_equal(NI(query_points), [0, 2])

NI = NearestNDInterpolator((nd[0], nd[1]), nd[2])
distance_upper_bound = np.sqrt(delta ** 2 + delta ** 2) - 1e-7
assert_array_equal(NI(query_points, query_options={"distance_upper_bound": distance_upper_bound}),
[np.nan, np.nan])

# case 2 - query p is inf, will return [0, 2]
distance_upper_bound = np.sqrt(delta ** 2 + delta ** 2) - 1e-7
p = np.inf
assert_array_equal(NI(query_points, query_options={"distance_upper_bound": distance_upper_bound, "p": p}),
[0, 2])

# case 3 - query max_dist is larger, so should return non np.nan
distance_upper_bound = np.sqrt(delta ** 2 + delta ** 2) + 1e-7
assert_array_equal(NI(query_points, query_options={"distance_upper_bound": distance_upper_bound}),
[0, 2])

def test_nearest_query_valid_inputs(self):
nd = np.array([[0, 1, 0, 1],
[0, 0, 1, 1],
[0, 1, 1, 2]])
NI = NearestNDInterpolator((nd[0], nd[1]), nd[2])
with self.assertRaises(TypeError):
NI([0.5, 0.5], query_options="not a dictionary")


class TestNDInterpolators:
Expand Down

0 comments on commit ff1d52f

Please sign in to comment.