Skip to content

Commit

Permalink
Merge pull request #61 from mraspaud/fix-dask-chunking-geogrid
Browse files Browse the repository at this point in the history
Fix geogrid chunking to accept "auto" and to preserve dtype
  • Loading branch information
mraspaud authored Nov 21, 2023
2 parents 11e0519 + ed60b5c commit de23a17
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 24 deletions.
45 changes: 22 additions & 23 deletions geotiepoints/interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,6 @@ def _interp(self):
if np.array_equal(self.hrow_indices, self.row_indices):
return self._interp1d()

xpoints, ypoints = np.meshgrid(self.hrow_indices,
self.hcol_indices)

for num, data in enumerate(self.tie_data):
spl = RectBivariateSpline(self.row_indices,
self.col_indices,
Expand All @@ -221,8 +218,7 @@ def _interp(self):
kx=self.kx_,
ky=self.ky_)

new_data_ = spl.ev(xpoints.ravel(), ypoints.ravel())
self.new_data[num] = new_data_.reshape(xpoints.shape).T.copy(order='C')
self.new_data[num] = spl(self.hrow_indices, self.hcol_indices, grid=True)

def _interp1d(self):
"""Interpolate in one dimension."""
Expand Down Expand Up @@ -279,38 +275,29 @@ def interpolate_dask(self, fine_points, method, chunks):
"""Interpolate (lazily) to a dask array."""
from dask.base import tokenize
import dask.array as da
from dask.array.core import normalize_chunks
v_fine_points, h_fine_points = fine_points
shape = len(v_fine_points), len(h_fine_points)

try:
v_chunk_size, h_chunk_size = chunks
except TypeError:
v_chunk_size, h_chunk_size = chunks, chunks

vchunks = range(0, shape[0], v_chunk_size)
hchunks = range(0, shape[1], h_chunk_size)
chunks = normalize_chunks(chunks, shape, dtype=self.values.dtype)

token = tokenize(v_chunk_size, h_chunk_size, self.points, self.values, fine_points, method)
token = tokenize(chunks, self.points, self.values, fine_points, method)
name = 'interpolate-' + token

dskx = {(name, i, j): (self.interpolate_slices,
(slice(vcs, min(vcs + v_chunk_size, shape[0])),
slice(hcs, min(hcs + h_chunk_size, shape[1]))),
method
)
for i, vcs in enumerate(vchunks)
for j, hcs in enumerate(hchunks)
}
dskx = {(name, ) + position: (self.interpolate_slices,
slices,
method)
for position, slices in _enumerate_chunk_slices(chunks)}

res = da.Array(dskx, name, shape=list(shape),
chunks=(v_chunk_size, h_chunk_size),
chunks=chunks,
dtype=self.values.dtype)
return res

def interpolate_numpy(self, fine_points, method="linear"):
"""Interpolate to a numpy array."""
fine_x, fine_y = np.meshgrid(*fine_points, indexing='ij')
return self.interpolator((fine_x, fine_y), method=method)
return self.interpolator((fine_x, fine_y), method=method).astype(self.values.dtype)

def interpolate_slices(self, fine_points, method="linear"):
"""Interpolate using slices.
Expand All @@ -325,6 +312,18 @@ def interpolate_slices(self, fine_points, method="linear"):
return self.interpolate_numpy(fine_points, method=method)


def _enumerate_chunk_slices(chunks):
"""Enumerate chunks with slices."""
for position in np.ndindex(tuple(map(len, (chunks)))):
slices = []
for pos, chunk in zip(position, chunks):
chunk_size = chunk[pos]
offset = sum(chunk[:pos])
slices.append(slice(offset, offset + chunk_size))

yield (position, slices)


class MultipleGridInterpolator:
"""Interpolator that works on multiple data arrays."""

Expand Down
34 changes: 34 additions & 0 deletions geotiepoints/tests/test_geointerpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,40 @@ def test_geogrid_interpolation_to_shape(self):
np.testing.assert_allclose(lons[0, :], lons_expected, rtol=5e-5)
np.testing.assert_allclose(lats[:, 0], lats_expected, rtol=5e-5)

def test_geogrid_interpolation_preserves_dtype(self):
"""Test that interpolating float32 tie points yields float32 longitudes and latitudes."""
x_points = np.array([0, 1, 3, 7])
y_points = np.array([0, 1, 3, 7, 15])

# Cast the tie points to float32 so that an upcast of the result would be detected below.
interpolator = GeoGridInterpolator((y_points, x_points),
TIE_LONS.astype(np.float32), TIE_LATS.astype(np.float32))

lons, lats = interpolator.interpolate_to_shape((16, 8))

assert lons.dtype == np.float32
assert lats.dtype == np.float32

def test_chunked_geogrid_interpolation(self):
"""Test that the interpolated arrays are chunked as requested, with an integer chunk size and with "auto" chunking."""
# Skip this test when dask is not installed.
dask = pytest.importorskip("dask")

x_points = np.array([0, 1, 3, 7])
y_points = np.array([0, 1, 3, 7, 15])

interpolator = GeoGridInterpolator((y_points, x_points),
TIE_LONS.astype(np.float32), TIE_LATS.astype(np.float32))

# An integer chunk size is applied to both dimensions of the (16, 8) result.
lons, lats = interpolator.interpolate_to_shape((16, 8), chunks=4)

assert lons.chunks == ((4, 4, 4, 4), (4, 4))
assert lats.chunks == ((4, 4, 4, 4), (4, 4))

# Limit dask's target chunk size (presumably 64 bytes, i.e. 16 float32 values -- TODO confirm)
# so that "auto" chunking is expected to resolve to 4x4 chunks.
with dask.config.set({"array.chunk-size": 64}):

lons, lats = interpolator.interpolate_to_shape((16, 8), chunks="auto")
assert lons.chunks == ((4, 4, 4, 4), (4, 4))
assert lats.chunks == ((4, 4, 4, 4), (4, 4))

def test_geogrid_interpolation_can_extrapolate(self):
"""Test that the interpolator can also extrapolate given the right parameters."""
x_points = np.array([0, 1, 3, 7])
Expand Down
21 changes: 20 additions & 1 deletion geotiepoints/tests/test_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ def grid_interpolator():
[2, 2, 2, 1],
[0, 3, 3, 3],
[1, 2, 1, 2],
[4, 4, 4, 4]])
[4, 4, 4, 4]],
dtype=np.float64)

return SingleGridInterpolator((ypoints, xpoints), data)

Expand Down Expand Up @@ -378,3 +379,21 @@ def test_interpolate_dask(self, grid_interpolator, chunks, expected_chunks):

np.testing.assert_allclose(res, self.expected, atol=2e-9)
assert interpolate.called

def test_interpolate_preserves_dtype(self):
"""Test that cubic interpolation of float32 data returns a result with the same dtype."""
xpoints = np.array([0, 3, 7, 15])
ypoints = np.array([0, 3, 7, 15, 31])
# The data is explicitly float32 so that an upcast of the result would be detected below.
data = np.array([[0, 1, 0, 1],
[2, 2, 2, 1],
[0, 3, 3, 3],
[1, 2, 1, 2],
[4, 4, 4, 4]],
dtype=np.float32)

grid_interpolator = SingleGridInterpolator((ypoints, xpoints), data)
# Fine grid spanning the same index ranges as the tie points (0..15 and 0..31).
fine_x = np.arange(16)
fine_y = np.arange(32)

res = grid_interpolator.interpolate((fine_y, fine_x), method="cubic")
assert res.dtype == data.dtype

0 comments on commit de23a17

Please sign in to comment.