Skip to content

Commit

Permalink
Merge 56666d3 into de73d6f
Browse files Browse the repository at this point in the history
  • Loading branch information
vindelico committed Aug 23, 2023
2 parents de73d6f + 56666d3 commit 9fb0c44
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 22 deletions.
43 changes: 26 additions & 17 deletions clisops/core/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,20 +1101,25 @@ def subset_shape(

sp_dims = set(mask_2d.dims) # Spatial dimensions

# Find the outer mask. When subsetting unconnected shapes,
# we don't want to drop the inner NaN regions, it may cause problems downstream.
inner_mask = xarray.full_like(mask_2d, True, dtype=bool)
for dim in sp_dims:
# For each dimension, propagate shape indexes in either directions
# Then sum on the other dimension. You get a step function going from 0 to X.
# The non-zero part that left and right have in common is the "inner" zone.
left = mask_2d.bfill(dim).sum(sp_dims - {dim})
right = mask_2d.ffill(dim).sum(sp_dims - {dim})
# True in the inner zone, False in the outer
inner_mask = inner_mask & (left != 0) & (right != 0)

# inner_mask including the shapes
inner_mask = mask_2d.notnull() | inner_mask
if len(sp_dims) > 1:
# Find the outer mask. When subsetting unconnected shapes,
# we don't want to drop the inner NaN regions, it may cause problems downstream.
inner_mask = xarray.full_like(mask_2d, True, dtype=bool)
for dim in sp_dims:
# For each dimension, propagate shape indexes in either directions
# Then sum on the other dimension. You get a step function going from 0 to X.
# The non-zero part that left and right have in common is the "inner" zone.
left = mask_2d.bfill(dim).sum(sp_dims - {dim})
right = mask_2d.ffill(dim).sum(sp_dims - {dim})
# True in the inner zone, False in the outer
inner_mask = inner_mask & (left != 0) & (right != 0)

# inner_mask including the shapes
inner_mask = mask_2d.notnull() | inner_mask
else:
# in the locstream case inner_mask remains all True, but all non-polygon values can be dropped,
# so here "outside inner_mask" is everything outside the polygon.
inner_mask = mask_2d.notnull()

# loop through variables
for v in ds_copy.data_vars:
Expand All @@ -1125,9 +1130,13 @@ def subset_shape(
# Remove grid points outside the inner mask
# Then extract the coords.
# Using a where(inner_mask) on ds_copy triggers warnings with dask, sel seems safer.
mask_2d = mask_2d.where(inner_mask, drop=True)
for dim in sp_dims:
ds_copy = ds_copy.sel({dim: mask_2d[dim]})
# But this only works if dims have coords.
if set(sp_dims).issubset(ds_copy.coords.keys()):
mask_2d = mask_2d.where(inner_mask, drop=True)
for dim in sp_dims:
ds_copy = ds_copy.sel({dim: mask_2d[dim]})
else:
ds_copy = ds_copy.where(inner_mask, drop=True)

# Add a CRS definition using CF conventions and as a global attribute in CRS_WKT for reference purposes
ds_copy.attrs["crs"] = raster_crs.to_string()
Expand Down
10 changes: 5 additions & 5 deletions tests/test_core_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,17 +868,17 @@ def test_vectorize_touches_polygons(self):

def test_locstream(self):
da = xr.DataArray(
[1, 2, 3, 4],
[1, 2, 3, 4, 5, 6, 7, 8, 9],
dims=("site",),
coords={
"lat": (("site",), [10, 30, 20, 40]),
"lon": (("site",), [-50, -80, -70, -100]),
"lat": (("site",), [55, 55, 55, 40, 40, 40, 25, 25, 25]),
"lon": (("site",), [-80, -70, -60, -80, -70, -60, -80, -70, -60]),
},
)
poly = Polygon([[-90, 15], [-65, 15], [-65, 35], [-90, 35]])
poly = Polygon([[-90, 40], [-70, 20], [-50, 40], [-70, 60]])
shape = gpd.GeoDataFrame(geometry=[poly])
sub = subset.subset_shape(da, shape=shape)
exp = da.isel(site=[1, 2])
exp = da.isel(site=[1, 3, 4, 5, 7])
xr.testing.assert_identical(sub, exp)


Expand Down

0 comments on commit 9fb0c44

Please sign in to comment.