Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add max dist to NearestNDInterpolator #19483

Merged
merged 21 commits into from Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
84fde3f
ENH: query_max_dist option to speedup interpolate.NearestNDInterpolator
harshilkamdar Nov 7, 2023
ef4e6c5
ENH: query_max_dist option to speedup interpolate.NearestNDInterpolator
harshilkamdar Nov 7, 2023
cf2e6b7
MAINT: apply suggestions from code review
harshilkamdar Nov 8, 2023
2fa01c9
ENH: distance_upper_bound & workers for NearestNDInterpolator
harshilkamdar Nov 8, 2023
b625a28
BUG: add missing self for workers
harshilkamdar Nov 8, 2023
ff1d52f
ENH: supply query_options to NearestNDInterpolator
harshilkamdar Nov 11, 2023
b867112
MAINT: apply suggestions from code review
harshilkamdar Nov 15, 2023
0041cac
MAINT: update tests and doc with new API
harshilkamdar Nov 15, 2023
2f9f2f2
BUG: check all dists instead of if.
harshilkamdar Nov 15, 2023
839f533
Merge branch 'scipy:main' into add-max-dist-nearestnd-interpolate
harshilkamdar Nov 16, 2023
43ea043
BUG: update old test.
harshilkamdar Nov 16, 2023
8f7f8b2
Merge remote-tracking branch 'origin/add-max-dist-nearestnd-interpola…
harshilkamdar Nov 16, 2023
3d750bf
BUG: fix nD case
harshilkamdar Nov 22, 2023
2e497d3
BUG: fix tests and multidimensional query & y points
harshilkamdar Nov 22, 2023
4d01d49
BUG: fix test case for complex dtypes.
harshilkamdar Nov 22, 2023
7c3367a
DOC: fix documentation for dimensionality to be more accurate
harshilkamdar Nov 22, 2023
35c9577
Update scipy/interpolate/_ndgriddata.py
harshilkamdar Nov 22, 2023
b142521
DOC: fix documentation to be clearer for y values
harshilkamdar Nov 23, 2023
4189d20
Merge remote-tracking branch 'origin/add-max-dist-nearestnd-interpola…
harshilkamdar Nov 23, 2023
ca3f52f
DOC: forgot dim
harshilkamdar Nov 23, 2023
fd06795
DOC: a trivial doc tweak
ev-br Nov 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
51 changes: 46 additions & 5 deletions scipy/interpolate/_ndgriddata.py
Expand Up @@ -30,9 +30,9 @@ class NearestNDInterpolator(NDInterpolatorBase):

Parameters
----------
x : (npoints, ndims) 2-D ndarray of floats
x : (npoints, ndims) n-D ndarray of floats
ev-br marked this conversation as resolved.
Show resolved Hide resolved
Data point coordinates.
y : (npoints, ) 1-D ndarray of float or complex
y : (npoints, ...) n-D ndarray of floats or complex
ev-br marked this conversation as resolved.
Show resolved Hide resolved
Data values.
rescale : boolean, optional
Rescale points to unit cube before performing interpolation.
Expand Down Expand Up @@ -98,7 +98,7 @@ def __init__(self, x, y, rescale=False, tree_options=None):
self.tree = cKDTree(self.points, **tree_options)
self.values = np.asarray(y)

def __call__(self, *args):
def __call__(self, *args, **query_options):
"""
Evaluate interpolator at given points.

Expand All @@ -108,6 +108,11 @@ def __call__(self, *args):
Points where to interpolate data at.
x1, x2, ... xn can be array-like of float with broadcastable shape.
or x1 can be array-like of float with shape ``(..., ndim)``
query_options : dict, optional
harshilkamdar marked this conversation as resolved.
Show resolved Hide resolved
This allows `eps`, `p`, `distance_upper_bound`, and `workers` being passed to the cKDTree's query function
ev-br marked this conversation as resolved.
Show resolved Hide resolved
to be explicitly set. See the `scipy.spatial.cKDTree.query` for an overview of the different options.

..versionadded:: 1.12.0
ev-br marked this conversation as resolved.
Show resolved Hide resolved

"""
# For the sake of enabling subclassing, NDInterpolatorBase._set_xi performs some operations
Expand All @@ -116,14 +121,50 @@ def __call__(self, *args):
xi = _ndim_coords_from_arrays(args, ndim=self.points.shape[1])
xi = self._check_call_shape(xi)
xi = self._scale_x(xi)
dist, i = self.tree.query(xi)
return self.values[i]

# We need to handle two important cases for compatibility with a flexible griddata:
Copy link
Member

@ev-br ev-br Nov 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's not bring griddata here. The rest of the comment is also confusing (at least to me, and this is not the first time I'm seeing the story about leading/trailing dimensions).

I'm not sure what your m, n, k and l refer to TBH. Maybe take a look at how RegularGridInterpolator documents essentially the same story, possible trailing dimensions in y and xi:

https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.RegularGridInterpolator.__call__.html#scipy.interpolate.RegularGridInterpolator.__call__

I'm not saying RGI does it perfectly (it is not), all I'm saying you maybe can come up with a better formulation for both (likely in a follow-up PR).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is confusing and griddata is not relevant - updated.

# (1) the case where xi is of some dimension (n, m, ..., D), where D is the coordinate dimension, and
# (2) the case where y is multidimensional (npoints, k, l, ...).
# We will first flatten xi to deal with case (1) and build an intermediate return array with shape
# (n*m*.., k, l, ...) and then reshape back to (n, m, ..., k, l, ...).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n*m* is confusing: what does the trailing asterisk stand for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'm not being careful here in my words. Have given the description another go.


# Flatten xi for the query
xi_flat = xi.reshape(-1, xi.shape[-1])
original_shape = xi.shape
flattened_shape = xi_flat.shape

# if distance_upper_bound is set to not be infinite, then we need to consider the case where cKDtree
# does not find any points within distance_upper_bound to return. It marks those points as having infinte
# distance, which is what will be used below to mask the array and return only the points that were deemed
# to have a close enough neighbor to return something useful.
dist, i = self.tree.query(xi_flat, **query_options)
valid_mask = np.isfinite(dist)

# create a holder interp_values array with shape (n*m*.., k, l, ...) and fill with nans.
interp_shape = flattened_shape[:-1] + self.values.shape[1:] if self.values.ndim > 1 else flattened_shape[:-1]

if np.issubdtype(self.values.dtype, np.complexfloating):
interp_values = np.full(interp_shape, np.nan, dtype=self.values.dtype)
else:
interp_values = np.full(interp_shape, np.nan)
harshilkamdar marked this conversation as resolved.
Show resolved Hide resolved

if self.values.ndim == 1:
interp_values[valid_mask] = self.values[i[valid_mask]]
else:
interp_values[valid_mask] = self.values[i[valid_mask], ...]
ev-br marked this conversation as resolved.
Show resolved Hide resolved

# (n*m*.., k, l, ...) -> (n, m, ..., k, l, ...)
new_shape = original_shape[:-1] + self.values.shape[1:] if self.values.ndim > 1 else original_shape[:-1]
interp_values = interp_values.reshape(new_shape)

return interp_values


#------------------------------------------------------------------------------
# Convenience interface function
#------------------------------------------------------------------------------


def griddata(points, values, xi, method='linear', fill_value=np.nan,
rescale=False):
"""
Expand Down
32 changes: 32 additions & 0 deletions scipy/interpolate/tests/test_ndgriddata.py
Expand Up @@ -196,6 +196,38 @@ def test_nearest_list_argument(self):
NI = NearestNDInterpolator((d[0], d[1]), list(d[2]))
assert_array_equal(NI([0.1, 0.9], [0.1, 0.9]), [0, 2])

def test_nearest_query_options(self):
nd = np.array([[0, 0.5, 0, 1],
[0, 0, 0.5, 1],
[0, 1, 1, 2]])
delta = 0.1
query_points = [0 + delta, 1 + delta], [0 + delta, 1 + delta]

# case 1 - query max_dist is smaller than the query points' nearest distance to nd.
NI = NearestNDInterpolator((nd[0], nd[1]), nd[2])
distance_upper_bound = np.sqrt(delta ** 2 + delta ** 2) - 1e-7
assert_array_equal(NI(query_points, distance_upper_bound=distance_upper_bound),
[np.nan, np.nan])

# case 2 - query p is inf, will return [0, 2]
distance_upper_bound = np.sqrt(delta ** 2 + delta ** 2) - 1e-7
p = np.inf
assert_array_equal(NI(query_points, distance_upper_bound=distance_upper_bound, p=p),
[0, 2])

# case 3 - query max_dist is larger, so should return non np.nan
distance_upper_bound = np.sqrt(delta ** 2 + delta ** 2) + 1e-7
assert_array_equal(NI(query_points, distance_upper_bound=distance_upper_bound),
[0, 2])

def test_nearest_query_valid_inputs(self):
nd = np.array([[0, 1, 0, 1],
[0, 0, 1, 1],
[0, 1, 1, 2]])
NI = NearestNDInterpolator((nd[0], nd[1]), nd[2])
with assert_raises(TypeError):
NI([0.5, 0.5], query_options="not a dictionary")


class TestNDInterpolators:
@parametrize_interpolators
Expand Down