diff --git a/xrspatial/focal.py b/xrspatial/focal.py index 2cb276fb..93c7bf69 100644 --- a/xrspatial/focal.py +++ b/xrspatial/focal.py @@ -1121,6 +1121,33 @@ def _focal_stats_cpu(agg, kernel, stats_funcs, boundary='nan'): return stats +_VALID_STATS_FUNCS = ('mean', 'max', 'min', 'range', 'std', 'var', + 'sum', 'variety') + + +def _validate_stats_funcs(stats_funcs): + """Normalise and validate the ``stats_funcs`` argument of focal_stats. + + A bare string is wrapped into a one-element list so it is not iterated + character by character. Unknown names raise a ValueError listing the + valid options. + """ + if isinstance(stats_funcs, str): + stats_funcs = [stats_funcs] + if len(stats_funcs) == 0: + raise ValueError( + f"stats_funcs must not be empty, " + f"choose from {list(_VALID_STATS_FUNCS)}" + ) + unknown = [s for s in stats_funcs if s not in _VALID_STATS_FUNCS] + if unknown: + raise ValueError( + f"Invalid stats_funcs {unknown}, " + f"must be one of {list(_VALID_STATS_FUNCS)}" + ) + return stats_funcs + + def focal_stats(agg, kernel, stats_funcs=None, @@ -1193,8 +1220,9 @@ def focal_stats(agg, Dimensions without coordinates: dim_0, dim_1 """ if stats_funcs is None: - stats_funcs = ['mean', 'max', 'min', 'range', 'std', 'var', - 'sum', 'variety'] + stats_funcs = list(_VALID_STATS_FUNCS) + else: + stats_funcs = _validate_stats_funcs(stats_funcs) _validate_raster(agg, func_name='focal_stats', name='agg', ndim=(2, 3)) diff --git a/xrspatial/tests/test_focal.py b/xrspatial/tests/test_focal.py index fdb7c67e..a803af40 100644 --- a/xrspatial/tests/test_focal.py +++ b/xrspatial/tests/test_focal.py @@ -1397,6 +1397,39 @@ def test_focal_stats_default_stats_funcs(): assert result.sizes['stats'] == 8 +def test_focal_stats_rejects_unknown_stats_func(): + # Regression for #2770: an unknown name used to fall through as a raw + # KeyError. It must now raise a clear ValueError listing valid options. + agg = xr.DataArray(data_random) + with pytest.raises(ValueError, match=r"Invalid stats_funcs.*bogus"): + focal_stats(agg, _api_kernel, stats_funcs=['bogus']) + + +def test_focal_stats_accepts_bare_string(): + # Regression for #2770: a bare string used to be iterated character by + # character (e.g. 'mean' -> 'm','e','a','n') and fail. It must be treated + # as a single stat name. + agg = xr.DataArray(data_random) + result = focal_stats(agg, _api_kernel, stats_funcs='mean') + assert result.sizes['stats'] == 1 + assert list(result.coords['stats'].values) == ['mean'] + + +def test_focal_stats_rejects_empty_stats_funcs(): + # Regression for #2770: an empty list used to reach xr.concat and fail with + # an obscure error. It must raise a clear ValueError instead. + agg = xr.DataArray(data_random) + with pytest.raises(ValueError, match=r"stats_funcs must not be empty"): + focal_stats(agg, _api_kernel, stats_funcs=[]) + + +def test_focal_stats_valid_list_happy_path(): + agg = xr.DataArray(data_random) + result = focal_stats(agg, _api_kernel, stats_funcs=['mean', 'sum']) + assert result.sizes['stats'] == 2 + assert list(result.coords['stats'].values) == ['mean', 'sum'] + + @cuda_and_cupy_available def test_focal_stats_name_gpu(): import cupy