diff --git a/xvec/tests/test_zonal_stats.py b/xvec/tests/test_zonal_stats.py index 64edabe..18b4f48 100644 --- a/xvec/tests/test_zonal_stats.py +++ b/xvec/tests/test_zonal_stats.py @@ -147,6 +147,36 @@ def test_dataset(method): ) +@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"]) +def test_dataset_flat(method): + ds = xr.tutorial.open_dataset("eraint_uvz").isel(month=0).isel(level=0) + world = gpd.read_file(geodatasets.get_path("naturalearth land")) + result = ds.xvec.zonal_stats(world.geometry, "longitude", "latitude", method=method) + + if method in ["exactextract", None]: + xr.testing.assert_allclose( + xr.Dataset( + { + "z": np.array(114857.63685302), + "u": np.array(9.84182437), + "v": np.array(-0.00330402), + } + ), + result.mean(), + ) + else: + xr.testing.assert_allclose( + xr.Dataset( + { + "z": np.array(114302.08524294), + "u": np.array(9.5196515), + "v": np.array(0.29297792), + } + ), + result.drop_vars(["month", "level"]).mean(), + ) + + @pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"]) def test_dataarray(method): ds = xr.tutorial.open_dataset("eraint_uvz") @@ -163,6 +193,24 @@ def test_dataarray(method): assert result.mean() == pytest.approx(61367.76185577) +@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"]) +def test_dataarray_flat(method): + ds = xr.tutorial.open_dataset("eraint_uvz") + world = gpd.read_file(geodatasets.get_path("naturalearth land")) + result = ( + ds.z.isel(month=0) + .isel(level=0) + .xvec.zonal_stats(world.geometry, "longitude", "latitude", method=method) + ) + + assert result.shape == (127,) + assert result.dims == ("geometry",) + if method in ["exactextract", None]: + assert result.mean() == pytest.approx(114857.63685302) + else: + assert result.mean() == pytest.approx(114302.08524294) + + @pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"]) def test_stat(method): ds = xr.tutorial.open_dataset("eraint_uvz") diff --git a/xvec/zonal.py b/xvec/zonal.py index 628ad3b..bb22806 100644 --- a/xvec/zonal.py +++ b/xvec/zonal.py @@ -698,8 +698,14 @@ def _agg_exactextract( # Stack the other dimensions into one dimension called "location" arr_dims = tuple(dim for dim in acc._obj.dims if dim not in [x_coords, y_coords]) - data = acc._obj.stack(location=arr_dims) - locs = data.location.size + if arr_dims: + # Stack non-spatial dimensions if they exist + data = acc._obj.stack(location=arr_dims) + locs = data.location.size + else: + # No additional dimensions to stack, create a dummy "location" dimension + data = acc._obj.expand_dims("location") + locs = 1 # Check the order of dimensions data = data.transpose("location", y_coords, x_coords) @@ -713,23 +719,25 @@ def _agg_exactextract( results = exactextract.exact_extract( rast=data, vec=gdf, ops=stats, output="pandas", strategy=strategy ) - # Get all the dimensions execpt x_coords, y_coords, they will be used to stack the + # Get all the dimensions except x_coords, y_coords, they will be used to stack the # dataarray later if original_is_ds is True: - # Get the original dataset information to use for unstacking the resulte later + # Get the original dataset information to use for unstacking the result later coords_info = {name: geometry} original_shape = [len(geometry)] - for dim in arr_dims: - original_shape.append(acc._obj[dim].size) - if dim != "variable": - coords_info[dim] = acc._obj[dim].values + if arr_dims: + for dim in arr_dims: + original_shape.append(acc._obj[dim].size) + if dim != "variable": + coords_info[dim] = acc._obj[dim].values else: - # Get the original dataarray information to use for unstacking the resulte later + # Get the original dataarray information to use for unstacking the result later coords_info = {name: geometry} original_shape = [len(geometry)] - for dim in arr_dims: - original_shape.append(acc._obj[dim].size) - coords_info[dim] = acc._obj[dim].values + if arr_dims: + for dim in arr_dims: + original_shape.append(acc._obj[dim].size) + coords_info[dim] = acc._obj[dim].values return results, original_shape, coords_info, locs