Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add max dist to NearestNDInterpolator #19483

Merged
merged 21 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
84fde3f
ENH: query_max_dist option to speedup interpolate.NearestNDInterpolator
harshilkamdar Nov 7, 2023
ef4e6c5
ENH: query_max_dist option to speedup interpolate.NearestNDInterpolator
harshilkamdar Nov 7, 2023
cf2e6b7
MAINT: apply suggestions from code review
harshilkamdar Nov 8, 2023
2fa01c9
ENH: distance_upper_bound & workers for NearestNDInterpolator
harshilkamdar Nov 8, 2023
b625a28
BUG: add missing self for workers
harshilkamdar Nov 8, 2023
ff1d52f
ENH: supply query_options to NearestNDInterpolator
harshilkamdar Nov 11, 2023
b867112
MAINT: apply suggestions from code review
harshilkamdar Nov 15, 2023
0041cac
MAINT: update tests and doc with new API
harshilkamdar Nov 15, 2023
2f9f2f2
BUG: check all dists instead of if.
harshilkamdar Nov 15, 2023
839f533
Merge branch 'scipy:main' into add-max-dist-nearestnd-interpolate
harshilkamdar Nov 16, 2023
43ea043
BUG: update old test.
harshilkamdar Nov 16, 2023
8f7f8b2
Merge remote-tracking branch 'origin/add-max-dist-nearestnd-interpola…
harshilkamdar Nov 16, 2023
3d750bf
BUG: fix nD case
harshilkamdar Nov 22, 2023
2e497d3
BUG: fix tests and multidimensional query & y points
harshilkamdar Nov 22, 2023
4d01d49
BUG: fix test case for complex dtypes.
harshilkamdar Nov 22, 2023
7c3367a
DOC: fix documentation for dimensionality to be more accurate
harshilkamdar Nov 22, 2023
35c9577
Update scipy/interpolate/_ndgriddata.py
harshilkamdar Nov 22, 2023
b142521
DOC: fix documentation to be clearer for y values
harshilkamdar Nov 23, 2023
4189d20
Merge remote-tracking branch 'origin/add-max-dist-nearestnd-interpola…
harshilkamdar Nov 23, 2023
ca3f52f
DOC: forgot dim
harshilkamdar Nov 23, 2023
fd06795
DOC: a trivial doc tweak
ev-br Nov 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions scipy/interpolate/_ndgriddata.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,26 +98,55 @@ def __init__(self, x, y, rescale=False, tree_options=None):
self.tree = cKDTree(self.points, **tree_options)
self.values = np.asarray(y)

def __call__(self, *args, query_options=None):
    """
    Evaluate interpolator at given points.

    Parameters
    ----------
    x1, x2, ... xn : array-like of float
        Points where to interpolate data at.
        x1, x2, ... xn can be array-like of float with broadcastable shape.
        or x1 can be array-like of float with shape ``(..., ndim)``
    query_options : dict, optional
        Options forwarded to the cKDTree's query function. The keys
        `eps`, `p`, `distance_upper_bound`, and `workers` are read from
        this dictionary; see `scipy.spatial.cKDTree.query` for an overview
        of the different options. Note that ``k`` is restricted to 1 since
        NearestNDInterpolator has to have ``k=1`` by definition.

        .. versionadded:: 1.12.0

    Returns
    -------
    interp_values : ndarray
        Values of the nearest data points. When a finite
        `distance_upper_bound` is given, query points without a neighbor
        within that distance are filled with ``nan``; the result is then
        at least of floating-point dtype.

    Raises
    ------
    TypeError
        If `query_options` is neither None nor a dictionary.

    """
    if query_options is not None and not isinstance(query_options, dict):
        raise TypeError("query_options must be a dictionary")

    # Gather the query options to pass to cKDTree.query; the fallbacks
    # mirror cKDTree.query's own defaults.
    query_options = query_options or {}
    eps = query_options.get('eps', 0)
    p = query_options.get('p', 2)
    distance_upper_bound = query_options.get('distance_upper_bound', np.inf)
    workers = query_options.get('workers', 1)

    # For the sake of enabling subclassing, NDInterpolatorBase._set_xi
    # performs some operations which are not required by
    # NearestNDInterpolator.__call__, hence here we operate on xi directly,
    # without calling a parent class function.
    xi = _ndim_coords_from_arrays(args, ndim=self.points.shape[1])
    xi = self._check_call_shape(xi)
    xi = self._scale_x(xi)
    dist, i = self.tree.query(xi, eps=eps, p=p,
                              distance_upper_bound=distance_upper_bound,
                              workers=workers)

    if not np.isfinite(distance_upper_bound):
        # Every query point has a neighbor, so plain indexing suffices.
        return self.values[i]

    # With a finite distance_upper_bound, cKDTree reports an infinite
    # distance for query points that have no neighbor within the bound.
    # Use that to mask the result and fill only the points that were deemed
    # to have a close enough neighbor.
    valid_mask = np.isfinite(dist)

    # Keep any trailing dimensions of the data values, and pick a dtype
    # that can both represent nan and hold the values (e.g. complex).
    out_shape = i.shape + self.values.shape[1:]
    out_dtype = np.result_type(self.values.dtype, np.float64)
    interp_values = np.full(out_shape, np.nan, dtype=out_dtype)
    interp_values[valid_mask] = self.values[i[valid_mask]]
    return interp_values


#------------------------------------------------------------------------------
Expand Down
32 changes: 32 additions & 0 deletions scipy/interpolate/tests/test_ndgriddata.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,38 @@ def test_nearest_list_argument(self):
NI = NearestNDInterpolator((d[0], d[1]), list(d[2]))
assert_array_equal(NI([0.1, 0.9], [0.1, 0.9]), [0, 2])

def test_nearest_query_options(self):
    # Check that ``distance_upper_bound`` and ``p`` supplied through
    # ``query_options`` reach the underlying tree query.
    offset = 0.1
    data = np.array([[0, 1, 0, 1],
                     [0, 0, 1, 1],
                     [0, 1, 1, 2]])
    interp = NearestNDInterpolator((data[0], data[1]), data[2])

    # Two query points, each shifted by ``offset`` along both axes away
    # from the data points (0, 0) and (1, 1).
    targets = ([offset, 1 + offset], [offset, 1 + offset])

    # Euclidean distance from each query point to its nearest data point.
    nearest_dist = np.sqrt(offset ** 2 + offset ** 2)
    just_below = nearest_dist - 1e-7
    just_above = nearest_dist + 1e-7

    # case 1 - the bound is slightly smaller than the nearest Euclidean
    # distance, so no neighbor is found and nan is returned.
    result = interp(targets,
                    query_options={"distance_upper_bound": just_below})
    assert_array_equal(result, [np.nan, np.nan])

    # case 2 - same bound with p=inf: the max-norm distance is only
    # ``offset``, which is within the bound, so [0, 2] is returned.
    result = interp(targets,
                    query_options={"distance_upper_bound": just_below,
                                   "p": np.inf})
    assert_array_equal(result, [0, 2])

    # case 3 - the bound is slightly larger than the nearest Euclidean
    # distance, so the nearest values are returned instead of nan.
    result = interp(targets,
                    query_options={"distance_upper_bound": just_above})
    assert_array_equal(result, [0, 2])

def test_nearest_query_valid_inputs(self):
nd = np.array([[0, 1, 0, 1],
[0, 0, 1, 1],
[0, 1, 1, 2]])
NI = NearestNDInterpolator((nd[0], nd[1]), nd[2])
with self.assertRaises(TypeError):
NI([0.5, 0.5], query_options="not a dictionary")


class TestNDInterpolators:
@parametrize_interpolators
Expand Down
Loading