Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions xvec/tests/test_zonal_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,36 @@ def test_dataset(method):
)


@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
def test_dataset_flat(method):
ds = xr.tutorial.open_dataset("eraint_uvz").isel(month=0).isel(level=0)
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
result = ds.xvec.zonal_stats(world.geometry, "longitude", "latitude", method=method)

if method in ["exactextract", None]:
xr.testing.assert_allclose(
xr.Dataset(
{
"z": np.array(114857.63685302),
"u": np.array(9.84182437),
"v": np.array(-0.00330402),
}
),
result.mean(),
)
else:
xr.testing.assert_allclose(
xr.Dataset(
{
"z": np.array(114302.08524294),
"u": np.array(9.5196515),
"v": np.array(0.29297792),
}
),
result.drop_vars(["month", "level"]).mean(),
)


@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
def test_dataarray(method):
ds = xr.tutorial.open_dataset("eraint_uvz")
Expand All @@ -163,6 +193,24 @@ def test_dataarray(method):
assert result.mean() == pytest.approx(61367.76185577)


@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
def test_dataarray_flat(method):
ds = xr.tutorial.open_dataset("eraint_uvz")
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
result = (
ds.z.isel(month=0)
.isel(level=0)
.xvec.zonal_stats(world.geometry, "longitude", "latitude", method=method)
)

assert result.shape == (127,)
assert result.dims == ("geometry",)
if method in ["exactextract", None]:
assert result.mean() == pytest.approx(114857.63685302)
else:
assert result.mean() == pytest.approx(114302.08524294)


@pytest.mark.parametrize("method", [None, "rasterize", "iterate", "exactextract"])
def test_stat(method):
ds = xr.tutorial.open_dataset("eraint_uvz")
Expand Down
32 changes: 20 additions & 12 deletions xvec/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,8 +698,14 @@ def _agg_exactextract(

# Stack the other dimensions into one dimension called "location"
arr_dims = tuple(dim for dim in acc._obj.dims if dim not in [x_coords, y_coords])
data = acc._obj.stack(location=arr_dims)
locs = data.location.size
if arr_dims:
# Stack non-spatial dimensions if they exist
data = acc._obj.stack(location=arr_dims)
locs = data.location.size
else:
# No additional dimensions to stack, create a dummy "location" dimension
data = acc._obj.expand_dims("location")
locs = 1

# Check the order of dimensions
data = data.transpose("location", y_coords, x_coords)
Expand All @@ -713,23 +719,25 @@ def _agg_exactextract(
results = exactextract.exact_extract(
rast=data, vec=gdf, ops=stats, output="pandas", strategy=strategy
)
# Get all the dimensions execpt x_coords, y_coords, they will be used to stack the
# Get all the dimensions except x_coords, y_coords, they will be used to stack the
# dataarray later
if original_is_ds is True:
# Get the original dataset information to use for unstacking the resulte later
# Get the original dataset information to use for unstacking the result later
coords_info = {name: geometry}
original_shape = [len(geometry)]
for dim in arr_dims:
original_shape.append(acc._obj[dim].size)
if dim != "variable":
coords_info[dim] = acc._obj[dim].values
if arr_dims:
for dim in arr_dims:
original_shape.append(acc._obj[dim].size)
if dim != "variable":
coords_info[dim] = acc._obj[dim].values
else:
# Get the original dataarray information to use for unstacking the resulte later
# Get the original dataarray information to use for unstacking the result later
coords_info = {name: geometry}
original_shape = [len(geometry)]
for dim in arr_dims:
original_shape.append(acc._obj[dim].size)
coords_info[dim] = acc._obj[dim].values
if arr_dims:
for dim in arr_dims:
original_shape.append(acc._obj[dim].size)
coords_info[dim] = acc._obj[dim].values
return results, original_shape, coords_info, locs


Expand Down
Loading