Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 100 additions & 4 deletions xrspatial/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,12 +945,53 @@ def _run_dask_cupy(data, scale_y, scale_x, method):

# -- Public API --------------------------------------------------------------

def _resolve_nodata(agg, nodata):
"""Resolve the input-side nodata sentinel.

Explicit *nodata* wins. Otherwise fall back to ``_FillValue`` then
``nodata`` in ``agg.attrs``. Returns ``None`` when no sentinel was
found (the caller skips the masking step).

NaN sentinels are returned as NaN so the caller can branch on
``np.isnan`` rather than ``==`` (which never matches NaN).
"""
if nodata is None:
for key in ('_FillValue', 'nodata'):
v = agg.attrs.get(key)
if v is not None:
nodata = v
break
if nodata is None:
return None
nd = float(nodata)
if np.isinf(nd):
raise ValueError(f"nodata must be finite or NaN, got {nodata!r}")
return nd


def _apply_nodata_mask(agg, nodata):
"""Return a float copy of *agg* with sentinel pixels replaced by NaN.

Works for numpy, cupy, dask+numpy, and dask+cupy backings via
xarray's ``.where`` (which dispatches per backend).
"""
if nodata is None:
return agg
# Promote to float so NaN can be stored. xr.where keeps the backend.
if not np.issubdtype(agg.dtype, np.floating):
agg = agg.astype(np.float64)
if np.isnan(nodata):
return agg # already-NaN sentinels need no replacement
return agg.where(agg != nodata)


@supports_dataset
def resample(
agg,
scale_factor=None,
target_resolution=None,
method='nearest',
nodata=None,
name='resample',
):
"""Change raster resolution without changing its CRS.
Expand All @@ -960,21 +1001,30 @@ def resample(
Parameters
----------
agg : xarray.DataArray
Input raster (2-D).
Input raster. 2-D ``(y, x)`` or 3-D ``(band, y, x)``. For 3-D
inputs each band is resampled independently and the leading
non-spatial coordinate is preserved.
scale_factor : float or (float, float), optional
Multiplicative factor applied to the number of pixels.
``0.5`` halves the pixel count (doubles the cell size);
``2.0`` doubles the pixel count (halves the cell size).
A two-element tuple sets ``(scale_y, scale_x)`` independently.
target_resolution : float, optional
target_resolution : float or (float, float), optional
Desired cell size in the same units as the raster coordinates.
Both axes are set to this resolution.
A scalar sets both axes to the same resolution; a 2-tuple sets
``(res_y, res_x)`` independently.
method : str, default ``'nearest'``
Resampling algorithm. Interpolation methods (``'nearest'``,
``'bilinear'``, ``'cubic'``) work for both upsampling and
downsampling. Aggregation methods (``'average'``, ``'min'``,
``'max'``, ``'median'``, ``'mode'``) only support downsampling
(scale_factor <= 1).
nodata : float, optional
Sentinel value in the input that should be treated as missing.
Input pixels equal to *nodata* are replaced with NaN before
resampling. When ``None``, falls back to ``agg.attrs['_FillValue']``
then ``agg.attrs['nodata']``. The output uses NaN as the sentinel
regardless of the input convention.
name : str, default ``'resample'``
Name for the output DataArray.

Expand All @@ -984,7 +1034,7 @@ def resample(
Resampled raster with updated coordinates, ``res`` attribute,
and float32 dtype.
"""
_validate_raster(agg, func_name='resample', name='agg')
_validate_raster(agg, func_name='resample', name='agg', ndim=(2, 3))

if method not in ALL_METHODS:
raise ValueError(
Expand Down Expand Up @@ -1025,12 +1075,56 @@ def resample(
f"(scale_factor <= 1.0)"
)

# -- nodata: replace sentinels with NaN before resampling ----------------
nd_resolved = _resolve_nodata(agg, nodata)
has_nodata = nd_resolved is not None
if has_nodata:
agg = _apply_nodata_mask(agg, nd_resolved)

# -- fast path: identity -------------------------------------------------
if scale_y == 1.0 and scale_x == 1.0:
out = agg.copy()
out.name = name
# When nodata was applied, advertise NaN as the new sentinel.
if has_nodata:
out.attrs['_FillValue'] = float('nan')
return out

# -- 3D: dispatch per band ----------------------------------------------
if agg.ndim == 3:
leading_dim = agg.dims[0]
bands = []
for i in range(agg.sizes[leading_dim]):
band_2d = agg.isel({leading_dim: i})
band_out = resample(
band_2d,
scale_factor=scale_factor,
target_resolution=target_resolution,
method=method,
# Pass NaN so the recursive call short-circuits masking
# (we already applied the mask on the 3D input above) and
# ignores the original attrs sentinel.
nodata=float('nan'),
name=name,
)
bands.append(band_out)
# Stack along the leading dim. concat preserves the per-band
# coordinate when each input has it.
result = xr.concat(bands, dim=leading_dim)
# concat may reorder dims; transpose to the original layout.
result = result.transpose(*agg.dims)
result.name = name
# Carry across input attrs (concat picks the first; merge with input).
new_attrs = dict(agg.attrs)
new_attrs.update(bands[0].attrs) # res from per-band resample
if has_nodata:
new_attrs['_FillValue'] = float('nan')
result.attrs = new_attrs
# Preserve the leading-dim coordinate if it was on the input.
if leading_dim in agg.coords:
result = result.assign_coords({leading_dim: agg.coords[leading_dim]})
return result

# -- memory guard for eager backends ------------------------------------
# Dask paths build per-chunk allocations lazily (chunk size already
# bounds peak memory). The eager numpy and cupy paths allocate the
Expand Down Expand Up @@ -1077,6 +1171,8 @@ def _new_coords(vals, n_out):

new_attrs = dict(agg.attrs)
new_attrs['res'] = (abs(px), abs(py))
if has_nodata:
new_attrs['_FillValue'] = float('nan')

# Refresh `transform` if the input had one. The rasterio 6-tuple is
# (res_x, 0.0, left, 0.0, -res_y, top). `top` is the upper edge of
Expand Down
178 changes: 178 additions & 0 deletions xrspatial/tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,3 +728,181 @@ def test_dask_aggregate_smoke_200x200(self):
assert elapsed < 30.0, (
f"dask aggregate took {elapsed:.2f}s; expected well under 5s"
)


# ---------------------------------------------------------------------------
# 3D rasters (issue #1466)
# ---------------------------------------------------------------------------

class TestThreeDRasters:
    """Resampling of multi-band ``(band, y, x)`` rasters, band by band."""

    def _make_3d(self, backend='numpy'):
        """Build a 3-band 8x8 float32 raster on the requested backend."""
        # Every band is the same gradient shifted by a distinct offset
        # (0, 100, 200), so any mixing of bands during dispatch shows up
        # in the values.
        rows, cols = np.mgrid[0:8, 0:8]
        base = (rows * 10 + cols).astype(np.float32)
        cube = np.stack([base, base + 100, base + 200], axis=0)

        axis_vals = np.arange(8, dtype=np.float64)
        raster = xr.DataArray(
            cube,
            dims=('band', 'y', 'x'),
            coords={
                'band': np.array([1, 2, 3]),
                'y': axis_vals,
                'x': np.arange(8, dtype=np.float64),
            },
            attrs={'res': (1.0, 1.0)},
            name='myraster',
        )

        if backend == 'dask':
            import dask.array as da
            raster = raster.copy()
            raster.data = da.from_array(raster.data, chunks=(1, 4, 4))
        elif backend == 'cupy':
            import cupy
            raster = raster.copy()
            raster.data = cupy.asarray(raster.data)
        return raster

    def _assert_halved_with_bands(self, result):
        """Check the halved output shape and the preserved band coordinate."""
        assert result.shape == (3, 4, 4)
        np.testing.assert_array_equal(result['band'].values, [1, 2, 3])

    def test_3d_numpy_shape_and_band_coord(self):
        result = resample(
            self._make_3d('numpy'), scale_factor=0.5, method='nearest'
        )
        assert result.shape == (3, 4, 4)
        assert result.dims == ('band', 'y', 'x')
        np.testing.assert_array_equal(result['band'].values, [1, 2, 3])

    def test_3d_per_band_independence(self):
        """A band of the 3D result equals the 2D resample of that band."""
        multi = self._make_3d('numpy')
        stacked = resample(multi, scale_factor=0.5, method='average')
        for idx in range(3):
            single = multi.isel(band=idx).reset_coords(drop=True)
            expected = resample(single, scale_factor=0.5, method='average')
            np.testing.assert_allclose(
                stacked.isel(band=idx).values,
                expected.values,
                atol=1e-5,
            )

    def test_3d_target_resolution_tuple(self):
        result = resample(self._make_3d('numpy'), target_resolution=(2.0, 4.0))
        assert result.shape == (3, 4, 2)

    @dask_array_available
    def test_3d_dask(self):
        result = resample(
            self._make_3d('dask'), scale_factor=0.5, method='nearest'
        )
        self._assert_halved_with_bands(result)

    @cuda_and_cupy_available
    def test_3d_cupy(self):
        result = resample(
            self._make_3d('cupy'), scale_factor=0.5, method='nearest'
        )
        self._assert_halved_with_bands(result)


# ---------------------------------------------------------------------------
# nodata handling (issue #1466)
# ---------------------------------------------------------------------------

class TestNodata:
    """Sentinel (nodata) pixels must come out of ``resample`` as NaN."""

    @staticmethod
    def _int_grid():
        """4x4 int32 grid whose top-left 2x2 block holds the -9999 sentinel."""
        return np.array([
            [-9999, -9999, 10, 10],
            [-9999, -9999, 10, 10],
            [20, 20, 30, 30],
            [20, 20, 30, 30],
        ], dtype=np.int32)

    @staticmethod
    def _float_grid():
        """4x4 float32 grid whose top-left 2x2 block holds -1.0."""
        return np.array([[-1.0, -1.0, 5.0, 5.0],
                         [-1.0, -1.0, 5.0, 5.0],
                         [3.0, 3.0, 7.0, 7.0],
                         [3.0, 3.0, 7.0, 7.0]], dtype=np.float32)

    def test_explicit_nodata_int_sentinel(self):
        # Integer input with an explicit -9999 sentinel: the sentinel block
        # turns into NaN while valid pixels remain finite.
        raster = create_test_raster(self._int_grid(),
                                    attrs={'res': (1.0, 1.0)})
        result = resample(raster, scale_factor=0.5, method='nearest',
                          nodata=-9999)
        assert result.shape == (2, 2)
        # Output [0, 0] falls in the -9999 region.
        assert np.isnan(result.values[0, 0])
        # Output [1, 1] falls in a valid region.
        assert np.isfinite(result.values[1, 1])
        # The output advertises NaN as its sentinel.
        assert result.attrs.get('_FillValue') is not None
        assert np.isnan(result.attrs['_FillValue'])

    def test_nodata_from_fillvalue_attr(self):
        # No explicit argument: the sentinel is read from attrs['_FillValue'].
        raster = create_test_raster(
            self._int_grid(), attrs={'res': (1.0, 1.0), '_FillValue': -9999}
        )
        result = resample(raster, scale_factor=0.5, method='nearest')
        assert np.isnan(result.values[0, 0])
        assert np.isfinite(result.values[1, 1])

    def test_nodata_from_nodata_attr(self):
        # No explicit argument: the sentinel is read from attrs['nodata'].
        raster = create_test_raster(
            self._int_grid(), attrs={'res': (1.0, 1.0), 'nodata': -9999}
        )
        result = resample(raster, scale_factor=0.5, method='average')
        assert np.isnan(result.values[0, 0])

    def test_nodata_none_no_attrs_unchanged(self):
        # With neither an argument nor an attr there is no masking, and
        # no _FillValue is added to the output (pre-#1466 behaviour).
        values = np.arange(16, dtype=np.float32).reshape(4, 4)
        raster = create_test_raster(values, attrs={'res': (1.0, 1.0)})
        result = resample(raster, scale_factor=0.5, method='nearest')
        assert '_FillValue' not in result.attrs

    def test_nodata_float_explicit(self):
        # Float sentinel, e.g. -1.0 marking masked pixels.
        raster = create_test_raster(self._float_grid(),
                                    attrs={'res': (1.0, 1.0)})
        result = resample(raster, scale_factor=0.5, method='nearest',
                          nodata=-1.0)
        assert np.isnan(result.values[0, 0])

    def test_explicit_nodata_overrides_attr(self):
        # attrs['_FillValue'] is -999, which never occurs in the data; the
        # explicit nodata=-1 must win and mask the top-left corner.
        raster = create_test_raster(
            self._float_grid(),
            attrs={'res': (1.0, 1.0), '_FillValue': -999.0},
        )
        result = resample(raster, scale_factor=0.5, method='nearest',
                          nodata=-1.0)
        # Had the attr been used, -1 would leak through as a valid value.
        assert np.isnan(result.values[0, 0])


# ---------------------------------------------------------------------------
# target_resolution as 2-tuple (issue #1466)
# ---------------------------------------------------------------------------

class TestTargetResolutionTuple:
    """``target_resolution`` given as a ``(res_y, res_x)`` pair."""

    def test_tuple_resolution_independent_axes(self, grid_8x8):
        """Target (2, 4) on the 8x8 fixture yields a (4, 2) output."""
        resampled = resample(grid_8x8, target_resolution=(2.0, 4.0))
        assert resampled.shape == (4, 2)

    def test_tuple_resolution_matches_scale_factor(self, grid_8x8):
        """``target_resolution=(2.0, 2.0)`` agrees with ``scale_factor=0.5``."""
        by_scale = resample(grid_8x8, scale_factor=0.5, method='nearest')
        by_resolution = resample(
            grid_8x8, target_resolution=(2.0, 2.0), method='nearest'
        )
        np.testing.assert_allclose(by_resolution.values, by_scale.values)