Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions xrspatial/reproject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
_compute_chunk_layout,
_compute_output_grid,
_make_output_coords,
_validate_grid_params,
)
from ._interpolate import (
_resample_cupy,
Expand Down Expand Up @@ -527,6 +528,15 @@ def reproject(
f"got {type(raster).__name__}"
)

_validate_grid_params(
resolution=resolution,
bounds=bounds,
width=width,
height=height,
transform_precision=transform_precision,
func_name='reproject',
)

_validate_resampling(resampling)

# Resolve CRS
Expand Down Expand Up @@ -1350,6 +1360,15 @@ def merge(
if not rasters:
raise ValueError("merge(): rasters list must not be empty")

_validate_grid_params(
resolution=resolution,
bounds=bounds,
width=None,
height=None,
transform_precision=None,
func_name='merge',
)

_validate_resampling(resampling)
_validate_strategy(strategy)

Expand Down
84 changes: 84 additions & 0 deletions xrspatial/reproject/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,90 @@
import numpy as np


def _validate_grid_params(*, resolution, bounds, width, height,
transform_precision, func_name):
"""Range-check user-supplied grid parameters."""
if resolution is not None:
if isinstance(resolution, (tuple, list)):
if len(resolution) != 2:
raise ValueError(
f"{func_name}(): resolution tuple must have length 2, "
f"got length {len(resolution)}"
)
res_values = list(resolution)
else:
res_values = [resolution]
for r in res_values:
try:
r_float = float(r)
except (TypeError, ValueError):
raise ValueError(
f"{func_name}(): resolution must be a positive finite "
f"number, got {r!r}"
)
if not (np.isfinite(r_float) and r_float > 0):
raise ValueError(
f"{func_name}(): resolution must be a positive finite "
f"number, got {r!r}"
)

for label, val in (('width', width), ('height', height)):
if val is None:
continue
if not isinstance(val, (int, np.integer)) or isinstance(val, bool):
raise ValueError(
f"{func_name}(): {label} must be a positive integer, got {val!r}"
)
if val <= 0:
raise ValueError(
f"{func_name}(): {label} must be a positive integer, got {val!r}"
)

if bounds is not None:
try:
left, bottom, right, top = bounds
except (TypeError, ValueError):
raise ValueError(
f"{func_name}(): bounds must be a 4-tuple "
f"(left, bottom, right, top), got {bounds!r}"
)
for label, v in (('left', left), ('bottom', bottom),
('right', right), ('top', top)):
try:
v_float = float(v)
except (TypeError, ValueError):
raise ValueError(
f"{func_name}(): bounds {label}={v!r} is not numeric"
)
if not np.isfinite(v_float):
raise ValueError(
f"{func_name}(): bounds {label} must be finite, got {v!r}"
)
if float(right) <= float(left):
raise ValueError(
f"{func_name}(): bounds right ({right}) must be greater "
f"than left ({left})"
)
if float(top) <= float(bottom):
raise ValueError(
f"{func_name}(): bounds top ({top}) must be greater "
f"than bottom ({bottom})"
)

if transform_precision is not None:
if (not isinstance(transform_precision, (int, np.integer))
or isinstance(transform_precision, bool)):
raise ValueError(
f"{func_name}(): transform_precision must be a non-negative "
f"integer, got {transform_precision!r}"
)
if transform_precision < 0:
raise ValueError(
f"{func_name}(): transform_precision must be a non-negative "
f"integer, got {transform_precision!r}"
)


def _transform_boundary(source_crs, target_crs, xs, ys):
"""Transform coordinate arrays, preferring Numba fast path over pyproj.

Expand Down
115 changes: 115 additions & 0 deletions xrspatial/tests/test_reproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,121 @@ def test_numpy_chunk_source_window_guard(self):
assert result.shape[0] > 0 and result.shape[1] > 0


# =====================================================================
# Issue #1433: grid/bounds/precision parameter validation
# =====================================================================

class TestValidateGridParams:
"""reproject(): grid params reject zero / negative / non-finite."""

@staticmethod
def _good_raster():
return xr.DataArray(
np.zeros((4, 4), dtype=np.float64),
dims=('y', 'x'),
coords={'y': np.arange(4), 'x': np.arange(4)},
attrs={'crs': 'EPSG:4326'},
)

@pytest.mark.parametrize("res", [0, 0.0, -1, -2.5,
float('inf'), float('-inf'),
float('nan')])
def test_resolution_rejected(self, res):
from xrspatial.reproject import reproject
r = self._good_raster()
with pytest.raises(ValueError, match="resolution"):
reproject(r, 'EPSG:4326', resolution=res)

def test_resolution_tuple_with_zero_rejected(self):
from xrspatial.reproject import reproject
r = self._good_raster()
with pytest.raises(ValueError, match="resolution"):
reproject(r, 'EPSG:4326', resolution=(1.0, 0.0))

def test_resolution_tuple_wrong_length_rejected(self):
from xrspatial.reproject import reproject
r = self._good_raster()
with pytest.raises(ValueError, match="length 2"):
reproject(r, 'EPSG:4326', resolution=(1.0, 2.0, 3.0))

@pytest.mark.parametrize("w", [0, -1, 1.5])
def test_width_rejected(self, w):
from xrspatial.reproject import reproject
r = self._good_raster()
with pytest.raises(ValueError, match="width"):
reproject(r, 'EPSG:4326', width=w, height=10)

@pytest.mark.parametrize("h", [0, -1, 1.5])
def test_height_rejected(self, h):
from xrspatial.reproject import reproject
r = self._good_raster()
with pytest.raises(ValueError, match="height"):
reproject(r, 'EPSG:4326', width=10, height=h)

def test_bounds_collapsed_x_rejected(self):
from xrspatial.reproject import reproject
r = self._good_raster()
with pytest.raises(ValueError, match="right"):
reproject(r, 'EPSG:4326', bounds=(10, 0, 10, 10))

def test_bounds_collapsed_y_rejected(self):
from xrspatial.reproject import reproject
r = self._good_raster()
with pytest.raises(ValueError, match="top"):
reproject(r, 'EPSG:4326', bounds=(0, 10, 10, 10))

def test_bounds_inverted_x_rejected(self):
from xrspatial.reproject import reproject
r = self._good_raster()
with pytest.raises(ValueError, match="right"):
reproject(r, 'EPSG:4326', bounds=(10, 0, 0, 10))

def test_bounds_nan_rejected(self):
from xrspatial.reproject import reproject
r = self._good_raster()
with pytest.raises(ValueError, match="finite"):
reproject(r, 'EPSG:4326', bounds=(0, 0, float('nan'), 10))

def test_bounds_wrong_length_rejected(self):
from xrspatial.reproject import reproject
r = self._good_raster()
with pytest.raises(ValueError, match="4-tuple"):
reproject(r, 'EPSG:4326', bounds=(0, 0, 10))

def test_transform_precision_negative_rejected(self):
from xrspatial.reproject import reproject
r = self._good_raster()
with pytest.raises(ValueError, match="transform_precision"):
reproject(r, 'EPSG:4326', transform_precision=-1)

def test_transform_precision_float_rejected(self):
from xrspatial.reproject import reproject
r = self._good_raster()
with pytest.raises(ValueError, match="transform_precision"):
reproject(r, 'EPSG:4326', transform_precision=1.5)


class TestValidateMergeGridParams:
@staticmethod
def _raster():
return xr.DataArray(
np.zeros((4, 4), dtype=np.float64),
dims=('y', 'x'),
coords={'y': np.arange(4), 'x': np.arange(4)},
attrs={'crs': 'EPSG:4326'},
)

def test_merge_resolution_rejected(self):
from xrspatial.reproject import merge
with pytest.raises(ValueError, match="resolution"):
merge([self._raster()], resolution=-1.0)

def test_merge_bounds_rejected(self):
from xrspatial.reproject import merge
with pytest.raises(ValueError, match="right"):
merge([self._raster()], bounds=(10, 0, 0, 10))


# =====================================================================
# Issue #1435: NaN/Inf rejection in scalar inputs
# =====================================================================
Expand Down