Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions xrspatial/focal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,33 @@ def _focal_stats_cpu(agg, kernel, stats_funcs, boundary='nan'):
return stats


_VALID_STATS_FUNCS = ('mean', 'max', 'min', 'range', 'std', 'var',
'sum', 'variety')


def _validate_stats_funcs(stats_funcs):
"""Normalise and validate the ``stats_funcs`` argument of focal_stats.

A bare string is wrapped into a one-element list so it is not iterated
character by character. Unknown names raise a ValueError listing the
valid options.
"""
if isinstance(stats_funcs, str):
stats_funcs = [stats_funcs]
if len(stats_funcs) == 0:
raise ValueError(
f"stats_funcs must not be empty, "
f"choose from {list(_VALID_STATS_FUNCS)}"
)
unknown = [s for s in stats_funcs if s not in _VALID_STATS_FUNCS]
if unknown:
raise ValueError(
f"Invalid stats_funcs {unknown}, "
f"must be one of {list(_VALID_STATS_FUNCS)}"
)
return stats_funcs


def focal_stats(agg,
kernel,
stats_funcs=None,
Expand Down Expand Up @@ -1193,8 +1220,9 @@ def focal_stats(agg,
Dimensions without coordinates: dim_0, dim_1
"""
if stats_funcs is None:
stats_funcs = ['mean', 'max', 'min', 'range', 'std', 'var',
'sum', 'variety']
stats_funcs = list(_VALID_STATS_FUNCS)
else:
stats_funcs = _validate_stats_funcs(stats_funcs)

_validate_raster(agg, func_name='focal_stats', name='agg', ndim=(2, 3))

Expand Down
33 changes: 33 additions & 0 deletions xrspatial/tests/test_focal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,39 @@ def test_focal_stats_default_stats_funcs():
assert result.sizes['stats'] == 8


def test_focal_stats_rejects_unknown_stats_func():
# Regression for #2770: an unknown name used to fall through as a raw
# KeyError. It must now raise a clear ValueError listing valid options.
agg = xr.DataArray(data_random)
with pytest.raises(ValueError, match=r"Invalid stats_funcs.*bogus"):
focal_stats(agg, _api_kernel, stats_funcs=['bogus'])


def test_focal_stats_accepts_bare_string():
# Regression for #2770: a bare string used to be iterated character by
# character (e.g. 'mean' -> 'm','e','a','n') and fail. It must be treated
# as a single stat name.
agg = xr.DataArray(data_random)
result = focal_stats(agg, _api_kernel, stats_funcs='mean')
assert result.sizes['stats'] == 1
assert list(result.coords['stats'].values) == ['mean']


def test_focal_stats_rejects_empty_stats_funcs():
# Regression for #2770: an empty list used to reach xr.concat and fail with
# an obscure error. It must raise a clear ValueError instead.
agg = xr.DataArray(data_random)
with pytest.raises(ValueError, match=r"stats_funcs must not be empty"):
focal_stats(agg, _api_kernel, stats_funcs=[])


def test_focal_stats_valid_list_happy_path():
agg = xr.DataArray(data_random)
result = focal_stats(agg, _api_kernel, stats_funcs=['mean', 'sum'])
assert result.sizes['stats'] == 2
assert list(result.coords['stats'].values) == ['mean', 'sum']


@cuda_and_cupy_available
def test_focal_stats_name_gpu():
import cupy
Expand Down
Loading