Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions xrspatial/polygon_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,9 @@ def clip_polygon(
raster = _crop_to_bbox(raster, geom_pairs, all_touched=all_touched)

# Build a binary mask via rasterize, aligned to the (possibly cropped)
# raster grid. Always produce a plain numpy mask first, then convert
# to the raster's backend so xarray's .where() sees matching types.
# raster grid. Propagate the raster's chunk structure so the mask is
# built lazily for dask backends instead of materializing a full numpy
# array.
from .rasterize import rasterize

kw = dict(rasterize_kw or {})
Expand All @@ -208,6 +209,12 @@ def clip_polygon(
kw['dtype'] = np.uint8
kw['all_touched'] = all_touched

if has_dask_array() and isinstance(raster.data, da.Array):
rc, cc = raster.data.chunks[-2], raster.data.chunks[-1]
kw.setdefault('chunks', (rc[0], cc[0]))
if has_cuda_and_cupy() and is_dask_cupy(raster):
kw.setdefault('use_cuda', True)

mask = rasterize(geom_pairs, **kw)

# Apply the mask. Keep it lazy for dask backends to avoid
Expand Down
33 changes: 33 additions & 0 deletions xrspatial/tests/test_polygon_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,36 @@ def test_geoseries_input(self):
gs = geopandas.GeoSeries([poly])
result = clip_polygon(raster, gs, crop=False)
assert result.shape == raster.shape


# ---------------------------------------------------------------------------
# Issue #1207 regression tests
# ---------------------------------------------------------------------------

@dask_array_available
class TestClipPolygonDaskLazyMask:
def test_mask_stays_lazy_for_dask_input(self):
"""clip_polygon on dask input should not materialize a full numpy mask (#1207).

We verify by checking that the dask task graph contains rasterize
chunk tasks (not just a single from_array wrapping a pre-computed
numpy array).
"""
import dask.array as da

dk_raster = _make_raster(backend='dask+numpy', chunks=(4, 3))
poly = _inner_polygon()

result = clip_polygon(dk_raster, poly, crop=False)
assert isinstance(result.data, da.Array)

# With chunked rasterize, the graph has tasks per chunk.
# With the old approach (numpy mask + da.from_array), the graph
# would have fewer chunk-level tasks for the mask.
graph = dict(result.data.__dask_graph__())
# At minimum, a 8x6 raster with (4,3) chunks = 2x2 = 4 mask chunks
# plus raster chunks plus where-condition tasks.
# Just verify we have more than the trivial single-mask case.
assert len(graph) > 4, (
f"graph has only {len(graph)} tasks; mask may not be chunked"
)
Loading