Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 100 additions & 42 deletions xrspatial/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,42 @@
_INTERP_DEPTH = {'nearest': 1, 'bilinear': 1, 'cubic': 10}

# Approximate working-set size per output cell for the eager backends:
# one float64 working buffer (8 B) plus a float32 output cell (4 B).
# scipy.ndimage.map_coordinates also allocates a temporary of the same
# size during higher-order spline evaluation; the 0.5 * available bound
# below leaves room for that.
_BYTES_PER_OUTPUT_CELL = 12
# one float64 working buffer (8 B) plus a float64 output cell (8 B) in
# the worst case. scipy.ndimage.map_coordinates also allocates a
# temporary of the same size during higher-order spline evaluation; the
# 0.5 * available bound below leaves room for that.
_BYTES_PER_OUTPUT_CELL = 16


# -- Working / output dtype selection ----------------------------------------

def _working_dtype(input_dtype):
    """Return the float dtype used for intermediate resampling arithmetic.

    Parameters
    ----------
    input_dtype : dtype-like
        Anything ``np.dtype`` accepts (dtype object, scalar type, string).

    Returns
    -------
    type
        ``np.float64`` when the input is a floating-point dtype at least
        8 bytes wide, so double-precision data is not narrowed;
        ``np.float32`` for every other input (narrower floats, integers,
        bool).
    """
    resolved = np.dtype(input_dtype)
    is_wide_float = resolved.kind == 'f' and resolved.itemsize >= 8
    return np.float64 if is_wide_float else np.float32


def _output_dtype(input_dtype):
    """Return the dtype the resampled result is delivered in.

    Parameters
    ----------
    input_dtype : dtype-like
        Anything ``np.dtype`` accepts (dtype object, scalar type, string).

    Returns
    -------
    type
        The input's own scalar type when it is floating point, so float
        rasters round-trip at their original precision. ``np.float32``
        for integer / bool inputs, which need a float type to carry the
        NaN sentinel.
    """
    resolved = np.dtype(input_dtype)
    # Guard clause: anything that is not a float is promoted to float32.
    if resolved.kind != 'f':
        return np.float32
    return resolved.type


def _maybe_astype(arr, dtype):
    """Cast *arr* to *dtype*, skipping the copy when it already matches.

    Returns *arr* itself (the same object) when its dtype already equals
    the requested one; otherwise returns the result of ``arr.astype(dtype)``.
    """
    if arr.dtype == np.dtype(dtype):
        return arr
    return arr.astype(dtype)


# -- Memory guard ------------------------------------------------------------
Expand Down Expand Up @@ -599,13 +630,13 @@ def _agg_block_mode_nb(data, target_h, target_w,
def _interp_block_np(block, global_in_h, global_in_w,
global_out_h, global_out_w,
cum_in_y, cum_in_x, cum_out_y, cum_out_x,
depth, order, block_info=None):
depth, order, work_dtype, out_dtype, block_info=None):
"""Interpolate one (possibly overlapped) numpy block."""
yi, xi = block_info[0]['chunk-location']
target_h = int(cum_out_y[yi + 1] - cum_out_y[yi])
target_w = int(cum_out_x[xi + 1] - cum_out_x[xi])

block = block.astype(np.float64)
block = _maybe_astype(block, work_dtype)

# Global output pixel indices for this chunk
oy = np.arange(cum_out_y[yi], cum_out_y[yi + 1], dtype=np.float64)
Expand All @@ -628,28 +659,29 @@ def _interp_block_np(block, global_in_h, global_in_w,
result = _scipy_map_coords(block, coords, order=order, mode='nearest')
else:
filled = np.where(mask, 0.0, block)
weights = (~mask).astype(np.float64)
weights = (~mask).astype(block.dtype)
z_data = _scipy_map_coords(filled, coords, order=order, mode='nearest')
z_wt = _scipy_map_coords(weights, coords, order=order, mode='nearest')
# Majority-weight gate (see _nan_aware_interp_np for rationale).
result = np.where(z_wt > 0.5,
z_data / np.maximum(z_wt, 1e-10), np.nan)

return result.reshape(target_h, target_w).astype(np.float32)
return _maybe_astype(result.reshape(target_h, target_w), out_dtype)


def _interp_block_cupy(block, global_in_h, global_in_w,
global_out_h, global_out_w,
cum_in_y, cum_in_x, cum_out_y, cum_out_x,
depth, order, block_info=None):
depth, order, work_dtype, out_dtype, block_info=None):
"""CuPy variant of :func:`_interp_block_np`."""
from cupyx.scipy.ndimage import map_coordinates as _cupy_map_coords

yi, xi = block_info[0]['chunk-location']
target_h = int(cum_out_y[yi + 1] - cum_out_y[yi])
target_w = int(cum_out_x[xi + 1] - cum_out_x[xi])

block = block.astype(cupy.float64)
if block.dtype != cupy.dtype(work_dtype):
block = block.astype(work_dtype)

oy = cupy.arange(int(cum_out_y[yi]), int(cum_out_y[yi + 1]),
dtype=cupy.float64)
Expand All @@ -671,20 +703,23 @@ def _interp_block_cupy(block, global_in_h, global_in_w,
result = _cupy_map_coords(block, coords, order=order, mode='nearest')
else:
filled = cupy.where(mask, 0.0, block)
weights = (~mask).astype(cupy.float64)
weights = (~mask).astype(block.dtype)
z_data = _cupy_map_coords(filled, coords, order=order, mode='nearest')
z_wt = _cupy_map_coords(weights, coords, order=order, mode='nearest')
# Majority-weight gate (see _nan_aware_interp_np for rationale).
result = cupy.where(z_wt > 0.5,
z_data / cupy.maximum(z_wt, 1e-10), cupy.nan)

return result.reshape(target_h, target_w).astype(cupy.float32)
result = result.reshape(target_h, target_w)
if result.dtype != cupy.dtype(out_dtype):
result = result.astype(out_dtype)
return result


def _agg_block_np(block, method, global_in_h, global_in_w,
global_out_h, global_out_w,
cum_in_y, cum_in_x, cum_out_y, cum_out_x,
depth_y, depth_x, block_info=None):
depth_y, depth_x, out_dtype, block_info=None):
"""Block-aggregate one (possibly overlapped) numpy chunk.

Runs the entire chunk inside one numba dispatch via the
Expand All @@ -695,7 +730,9 @@ def _agg_block_np(block, method, global_in_h, global_in_w,
target_h = int(cum_out_y[yi + 1] - cum_out_y[yi])
target_w = int(cum_out_x[xi + 1] - cum_out_x[xi])

block = block.astype(np.float64)
# _AGG_FUNCS kernels are @ngjit-compiled with hard-coded float64
# working buffers; cast accordingly so numba dispatch matches.
block = _maybe_astype(block, np.float64)
# The overlapped block starts depth pixels before the original chunk
in_y0 = int(cum_in_y[yi]) - depth_y
in_x0 = int(cum_in_x[xi]) - depth_x
Expand All @@ -708,44 +745,51 @@ def _agg_block_np(block, method, global_in_h, global_in_w,
int(global_in_h), int(global_in_w),
int(global_out_h), int(global_out_w),
in_y0, in_x0)
return out.astype(np.float32)
return _maybe_astype(out, out_dtype)


def _agg_block_cupy(block, method, global_in_h, global_in_w,
                    global_out_h, global_out_w,
                    cum_in_y, cum_in_x, cum_out_y, cum_out_x,
                    depth_y, depth_x, out_dtype, block_info=None):
    """Block-aggregate one cupy chunk (falls back to CPU).

    The chunk is copied to host memory with ``cupy.asnumpy``, aggregated
    by :func:`_agg_block_np`, and the result is moved back to the device
    with ``cupy.asarray``.

    Parameters
    ----------
    block : cupy.ndarray
        The (possibly overlapped) input chunk.
    method, global_in_h, global_in_w, global_out_h, global_out_w,
    cum_in_y, cum_in_x, cum_out_y, cum_out_x, depth_y, depth_x :
        Forwarded unchanged to :func:`_agg_block_np`.
    out_dtype : dtype-like
        Output dtype, forwarded to :func:`_agg_block_np`, which performs
        the final cast.
    block_info : optional
        Forwarded to :func:`_agg_block_np` by keyword.

    Returns
    -------
    cupy.ndarray
        The aggregated chunk on the device.
    """
    cpu = cupy.asnumpy(block)
    result = _agg_block_np(
        cpu, method, global_in_h, global_in_w,
        global_out_h, global_out_w,
        cum_in_y, cum_in_x, cum_out_y, cum_out_x,
        depth_y, depth_x, out_dtype, block_info=block_info,
    )
    return cupy.asarray(result)


# -- Per-backend runners -----------------------------------------------------

def _run_numpy(data, scale_y, scale_x, method):
    """Resample an in-memory NumPy array.

    Parameters
    ----------
    data : numpy.ndarray
        Input raster; its shape is unpacked into ``_output_shape`` along
        with the two scale factors.
    scale_y, scale_x : float
        Scale factors passed to ``_output_shape`` to derive the output
        height and width.
    method : str
        Either a key of ``INTERP_METHODS`` (interpolation; the mapped
        value is passed to ``_nan_aware_interp_np``) or a key of
        ``_AGG_FUNCS`` (aggregation).

    Returns
    -------
    numpy.ndarray
        Resampled array in the dtype chosen by ``_output_dtype`` for the
        *original* input dtype.
    """
    # Both dtypes must be derived from the caller's dtype, so pick them
    # before the input is cast to the working precision.
    work_dt = _working_dtype(data.dtype)
    out_dt = _output_dtype(data.dtype)
    data = _maybe_astype(data, work_dt)
    out_h, out_w = _output_shape(*data.shape, scale_y, scale_x)

    if method in INTERP_METHODS:
        result = _nan_aware_interp_np(data, out_h, out_w,
                                      INTERP_METHODS[method])
    else:
        result = _AGG_FUNCS[method](data, out_h, out_w)

    # Single exit: cast only when the kernel's dtype differs from the
    # requested output dtype.
    return _maybe_astype(result, out_dt)


def _run_cupy(data, scale_y, scale_x, method):
data = data.astype(cupy.float64)
work_dt = _working_dtype(data.dtype)
out_dt = _output_dtype(data.dtype)
data = data if data.dtype == cupy.dtype(work_dt) else data.astype(work_dt)
out_h, out_w = _output_shape(*data.shape, scale_y, scale_x)

if method in INTERP_METHODS:
return _nan_aware_interp_cupy(data, out_h, out_w,
INTERP_METHODS[method]).astype(cupy.float32)
result = _nan_aware_interp_cupy(data, out_h, out_w,
INTERP_METHODS[method])
return result if result.dtype == cupy.dtype(out_dt) else result.astype(out_dt)

# Aggregate: GPU reshape+reduce for integer factors, CPU fallback otherwise
fy, fx = data.shape[0] / out_h, data.shape[1] / out_w
Expand All @@ -757,11 +801,12 @@ def _run_cupy(data, scale_y, scale_x, method):
reducer = {'average': cupy.nanmean,
'min': cupy.nanmin,
'max': cupy.nanmax}[method]
return reducer(reshaped, axis=(1, 3)).astype(cupy.float32)
result = reducer(reshaped, axis=(1, 3))
return result if result.dtype == cupy.dtype(out_dt) else result.astype(out_dt)

cpu = cupy.asnumpy(data)
return cupy.asarray(
_AGG_FUNCS[method](cpu, out_h, out_w).astype(np.float32)
_maybe_astype(_AGG_FUNCS[method](cpu, out_h, out_w), out_dt)
)


Expand Down Expand Up @@ -794,8 +839,11 @@ def _ensure_min_chunksize(data, min_size):


def _run_dask_numpy(data, scale_y, scale_x, method):
data = data.astype(np.float64)
meta = np.array((), dtype=np.float32)
work_dt = _working_dtype(data.dtype)
out_dt = _output_dtype(data.dtype)
if data.dtype != np.dtype(work_dt):
data = data.astype(work_dt)
meta = np.array((), dtype=out_dt)

if method in INTERP_METHODS:
order = INTERP_METHODS[method]
Expand Down Expand Up @@ -829,9 +877,10 @@ def _run_dask_numpy(data, scale_y, scale_x, method):
global_out_h=global_out_h, global_out_w=global_out_w,
cum_in_y=cum_in_y, cum_in_x=cum_in_x,
cum_out_y=cum_out_y, cum_out_x=cum_out_x,
depth=depth, order=order)
depth=depth, order=order,
work_dtype=work_dt, out_dtype=out_dt)
return da.map_blocks(fn, src, chunks=(out_y, out_x),
dtype=np.float32, meta=meta)
dtype=out_dt, meta=meta)

import math
# Aggregate windows can cross chunk boundaries; size chunks to satisfy
Expand Down Expand Up @@ -867,14 +916,18 @@ def _run_dask_numpy(data, scale_y, scale_x, method):
global_out_h=global_out_h, global_out_w=global_out_w,
cum_in_y=cum_in_y, cum_in_x=cum_in_x,
cum_out_y=cum_out_y, cum_out_x=cum_out_x,
depth_y=depth_y, depth_x=depth_x)
depth_y=depth_y, depth_x=depth_x,
out_dtype=out_dt)
return da.map_blocks(fn, src, chunks=(out_y, out_x),
dtype=np.float32, meta=meta)
dtype=out_dt, meta=meta)


def _run_dask_cupy(data, scale_y, scale_x, method):
data = data.astype(cupy.float64)
meta = cupy.array((), dtype=cupy.float32)
work_dt = _working_dtype(data.dtype)
out_dt = _output_dtype(data.dtype)
if data.dtype != cupy.dtype(work_dt):
data = data.astype(work_dt)
meta = cupy.array((), dtype=out_dt)

if method in INTERP_METHODS:
order = INTERP_METHODS[method]
Expand Down Expand Up @@ -908,9 +961,10 @@ def _run_dask_cupy(data, scale_y, scale_x, method):
global_out_h=global_out_h, global_out_w=global_out_w,
cum_in_y=cum_in_y, cum_in_x=cum_in_x,
cum_out_y=cum_out_y, cum_out_x=cum_out_x,
depth=depth, order=order)
depth=depth, order=order,
work_dtype=work_dt, out_dtype=out_dt)
return da.map_blocks(fn, src, chunks=(out_y, out_x),
dtype=cupy.float32, meta=meta)
dtype=out_dt, meta=meta)

import math
# Aggregate windows can cross chunk boundaries; size chunks to satisfy
Expand Down Expand Up @@ -946,9 +1000,10 @@ def _run_dask_cupy(data, scale_y, scale_x, method):
global_out_h=global_out_h, global_out_w=global_out_w,
cum_in_y=cum_in_y, cum_in_x=cum_in_x,
cum_out_y=cum_out_y, cum_out_x=cum_out_x,
depth_y=depth_y, depth_x=depth_x)
depth_y=depth_y, depth_x=depth_x,
out_dtype=out_dt)
return da.map_blocks(fn, src, chunks=(out_y, out_x),
dtype=cupy.float32, meta=meta)
dtype=out_dt, meta=meta)


# -- Public API --------------------------------------------------------------
Expand Down Expand Up @@ -986,8 +1041,9 @@ def _apply_nodata_mask(agg, nodata):
if nodata is None:
return agg
# Promote to float so NaN can be stored. xr.where keeps the backend.
# Integer / bool inputs become float32 (consistent with _output_dtype).
if not np.issubdtype(agg.dtype, np.floating):
agg = agg.astype(np.float64)
agg = agg.astype(np.float32)
if np.isnan(nodata):
return agg # already-NaN sentinels need no replacement
return agg.where(agg != nodata)
Expand Down Expand Up @@ -1039,8 +1095,10 @@ def resample(
Returns
-------
xarray.DataArray
Resampled raster with updated coordinates, ``res`` attribute,
and float32 dtype.
Resampled raster with updated coordinates and ``res`` attribute.
Output dtype matches the input float dtype (float32 or float64);
integer inputs return float32 since NaN-sentinel resampling
requires a float type.
"""
_validate_raster(agg, func_name='resample', name='agg', ndim=(2, 3))

Expand Down
Loading