diff --git a/xrspatial/polygon_clip.py b/xrspatial/polygon_clip.py index 7e49de2f..3338c912 100644 --- a/xrspatial/polygon_clip.py +++ b/xrspatial/polygon_clip.py @@ -198,8 +198,9 @@ def clip_polygon( raster = _crop_to_bbox(raster, geom_pairs, all_touched=all_touched) # Build a binary mask via rasterize, aligned to the (possibly cropped) - # raster grid. Always produce a plain numpy mask first, then convert - # to the raster's backend so xarray's .where() sees matching types. + # raster grid. For dask rasters, default rasterize's 'chunks' to the + # raster's first chunk size on the last two axes (and 'use_cuda' for + # dask+cupy) so the mask need not be a fully materialized numpy array. from .rasterize import rasterize kw = dict(rasterize_kw or {}) @@ -208,6 +209,12 @@ def clip_polygon( kw['dtype'] = np.uint8 kw['all_touched'] = all_touched + if has_dask_array() and isinstance(raster.data, da.Array): + rc, cc = raster.data.chunks[-2], raster.data.chunks[-1] + kw.setdefault('chunks', (rc[0], cc[0])) + if has_cuda_and_cupy() and is_dask_cupy(raster): + kw.setdefault('use_cuda', True) + mask = rasterize(geom_pairs, **kw) # Apply the mask. Keep it lazy for dask backends to avoid diff --git a/xrspatial/tests/test_polygon_clip.py b/xrspatial/tests/test_polygon_clip.py index 61373609..c38d66a6 100644 --- a/xrspatial/tests/test_polygon_clip.py +++ b/xrspatial/tests/test_polygon_clip.py @@ -325,3 +325,36 @@ def test_geoseries_input(self): gs = geopandas.GeoSeries([poly]) result = clip_polygon(raster, gs, crop=False) assert result.shape == raster.shape + + +# --------------------------------------------------------------------------- +# Issue #1207 regression tests +# --------------------------------------------------------------------------- + +@dask_array_available +class TestClipPolygonDaskLazyMask: + def test_mask_stays_lazy_for_dask_input(self): + """clip_polygon on dask input should not materialize a full numpy mask (#1207). 
+ + We check this indirectly: the result must still be a dask array and + its task graph must contain more than four entries. The graph is + not inspected for rasterize-specific task names. + """ + import dask.array as da + + dk_raster = _make_raster(backend='dask+numpy', chunks=(4, 3)) + poly = _inner_polygon() + + result = clip_polygon(dk_raster, poly, crop=False) + assert isinstance(result.data, da.Array) + + # With chunked rasterize, the graph has tasks per chunk. + # With the old approach (numpy mask + da.from_array), the graph + # would have fewer chunk-level tasks for the mask. + graph = dict(result.data.__dask_graph__()) + # At minimum, an 8x6 raster with (4,3) chunks = 2x2 = 4 mask chunks + # plus raster chunks plus where-condition tasks. + # Just verify we have more than the trivial single-mask case. + assert len(graph) > 4, ( + f"graph has only {len(graph)} tasks; mask may not be chunked" + )