From 55ab86f4b9d8e20b439313290fdedf6ac0f7083d Mon Sep 17 00:00:00 2001
From: Brendan Collins
Date: Tue, 5 May 2026 04:07:27 -0700
Subject: [PATCH] Stream tile writes per dask chunk segment to bound peak
 memory in to_geotiff (#1485)

`write_streaming` previously materialised one full tile-row per dask
compute call. For wide rasters, that buffer dwarfs a single tile and
defeats the point of streaming.

Add `streaming_buffer_bytes` (default 256 MB) to `to_geotiff` and
thread it through to `write_streaming`. Tile-rows whose footprint
exceeds the budget are split into horizontal segments at tile-column
boundaries; each segment is computed, written, and freed before the
next. Behaviour is unchanged for rasters that fit in one segment (the
common case). Strip layout, COG, GPU, and eager paths are not touched.

Tests added in `test_streaming_write.py::TestStreamingBufferBudget`:
round-trip equality, a 4 MB tight-budget regression on a 16 MB
tile-row, sub-tile clamping, and a 3-band multiband case.
---
 xrspatial/geotiff/__init__.py                 |   9 +-
 xrspatial/geotiff/_writer.py                  | 138 +++++++++++-------
 .../geotiff/tests/test_streaming_write.py     |  71 +++++++++
 3 files changed, 168 insertions(+), 50 deletions(-)

diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py
index 3c02b675..d71d623e 100644
--- a/xrspatial/geotiff/__init__.py
+++ b/xrspatial/geotiff/__init__.py
@@ -411,7 +411,8 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
                overview_levels: list[int] | None = None,
                overview_resampling: str = 'mean',
                bigtiff: bool | None = None,
-               gpu: bool | None = None) -> None:
+               gpu: bool | None = None,
+               streaming_buffer_bytes: int = 256 * 1024 * 1024) -> None:
     """Write data as a GeoTIFF or Cloud Optimized GeoTIFF.
 
     Dask-backed DataArrays are written in streaming mode: one tile-row
@@ -467,6 +468,11 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
         'min', 'max', 'median', 'mode', or 'cubic'.
     gpu : bool or None
         Force GPU compression. None (default) auto-detects CuPy data.
+    streaming_buffer_bytes : int
+        Soft cap on bytes materialised per dask compute call when
+        streaming a dask-backed DataArray. Defaults to 256 MB. Wide
+        rasters whose tile-row exceeds this budget are split into
+        horizontal segments. Ignored for numpy / CuPy / COG paths.
     """
     # VRT tiled output
     if path.lower().endswith('.vrt'):
@@ -599,6 +605,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
             gdal_metadata_xml=gdal_meta_xml,
             extra_tags=extra_tags_list,
             bigtiff=bigtiff,
+            streaming_buffer_bytes=streaming_buffer_bytes,
         )
         return
 
diff --git a/xrspatial/geotiff/_writer.py b/xrspatial/geotiff/_writer.py
index 13311a94..9b6470d4 100644
--- a/xrspatial/geotiff/_writer.py
+++ b/xrspatial/geotiff/_writer.py
@@ -1118,15 +1118,31 @@ def write_streaming(dask_data, path: str, *,
                     resolution_unit: int | None = None,
                     gdal_metadata_xml: str | None = None,
                     extra_tags: list | None = None,
-                    bigtiff: bool | None = None) -> None:
-    """Write a dask array as a GeoTIFF by streaming one tile-row at a time.
-
-    Peak memory is approximately ``tile_height * width * bytes_per_sample``
-    for tiled output, or ``rows_per_strip * width * bytes_per_sample`` for
-    stripped output.
+                    bigtiff: bool | None = None,
+                    streaming_buffer_bytes: int = 256 * 1024 * 1024) -> None:
+    """Write a dask array as a GeoTIFF by streaming pixel data.
+
+    For tiled output, each tile-row is computed in horizontal segments
+    that fit within ``streaming_buffer_bytes``. Most rasters fit in a
+    single segment per tile-row, matching the previous behaviour. Wide
+    rasters get bounded peak memory at the cost of more dask compute
+    calls.
+
+    Peak materialised memory is approximately
+    ``min(streaming_buffer_bytes, tile_height * width * bytes_per_sample
+    * samples)`` for tiled output, or
+    ``rows_per_strip * width * bytes_per_sample * samples`` for stripped
+    output (no horizontal segmentation in strip mode).
 
     After all pixel data is written the IFD offset and byte-count
     arrays are patched in place.
+
+    Parameters
+    ----------
+    streaming_buffer_bytes : int
+        Soft cap on bytes materialised per dask compute call when
+        writing tiles. Defaults to 256 MB. Values smaller than one tile
+        column are clamped up to one tile column.
     """
     import os
     import tempfile
@@ -1312,59 +1328,83 @@
 
     # Stream pixel data
     if tiled:
+        # Decide how many tile-columns we can buffer at once.
+        # bytes_per_full_row = th * width * bytes_per_sample * samples;
+        # if it fits the budget we buffer the whole row (matches
+        # original behaviour). Otherwise segment horizontally,
+        # always at tile boundaries to keep slicing aligned.
+        bytes_per_tile_col = (
+            th * tw * bytes_per_sample * samples)
+        bytes_per_full_row = bytes_per_tile_col * tiles_across
+        if bytes_per_full_row <= streaming_buffer_bytes:
+            tiles_per_segment = tiles_across
+        else:
+            tiles_per_segment = max(
+                1, streaming_buffer_bytes // bytes_per_tile_col)
+
         for tr in range(tiles_down):
             r0 = tr * th
             r1 = min(r0 + th, height)
             actual_h = r1 - r0
 
-            # Compute one tile-row from the dask graph
-            if dask_data.ndim == 3:
-                row_np = np.asarray(dask_data[r0:r1, :, :].compute())
-            else:
-                row_np = np.asarray(dask_data[r0:r1, :].compute())
-            if hasattr(row_np, 'get'):
-                row_np = row_np.get()
-
-            if row_np.dtype != out_dtype:
-                row_np = row_np.astype(out_dtype)
-
-            # NaN -> nodata sentinel
-            if (nodata is not None and row_np.dtype.kind == 'f'
-                    and not np.isnan(nodata)):
-                nan_mask = np.isnan(row_np)
-                if nan_mask.any():
-                    row_np = row_np.copy()
-                    row_np[nan_mask] = row_np.dtype.type(nodata)
-
-            for tc in range(tiles_across):
-                c0 = tc * tw
-                c1 = min(c0 + tw, width)
-                actual_w = c1 - c0
-
-                tile_slice = row_np[:, c0:c1]
+            for seg_start in range(0, tiles_across, tiles_per_segment):
+                seg_end = min(seg_start + tiles_per_segment,
+                              tiles_across)
+                seg_c0 = seg_start * tw
+                seg_c1 = min(seg_end * tw, width)
 
-                if actual_h < th or actual_w < tw:
-                    if row_np.ndim == 3:
-                        padded = np.zeros((th, tw, samples),
-                                          dtype=out_dtype)
-                    else:
-                        padded = np.zeros((th, tw), dtype=out_dtype)
-                    padded[:actual_h, :actual_w] = tile_slice
-                    tile_arr = padded
+                # Compute just this horizontal segment
+                if dask_data.ndim == 3:
+                    seg_np = np.asarray(
+                        dask_data[r0:r1, seg_c0:seg_c1, :].compute())
                 else:
-                    tile_arr = np.ascontiguousarray(tile_slice)
+                    seg_np = np.asarray(
+                        dask_data[r0:r1, seg_c0:seg_c1].compute())
+                if hasattr(seg_np, 'get'):
+                    seg_np = seg_np.get()
+
+                if seg_np.dtype != out_dtype:
+                    seg_np = seg_np.astype(out_dtype)
+
+                # NaN -> nodata sentinel
+                if (nodata is not None and seg_np.dtype.kind == 'f'
+                        and not np.isnan(nodata)):
+                    nan_mask = np.isnan(seg_np)
+                    if nan_mask.any():
+                        seg_np = seg_np.copy()
+                        seg_np[nan_mask] = seg_np.dtype.type(nodata)
+
+                for tc in range(seg_start, seg_end):
+                    c0 = tc * tw
+                    c1 = min(c0 + tw, width)
+                    actual_w = c1 - c0
+
+                    local_c0 = c0 - seg_c0
+                    local_c1 = c1 - seg_c0
+                    tile_slice = seg_np[:, local_c0:local_c1]
+
+                    if actual_h < th or actual_w < tw:
+                        if seg_np.ndim == 3:
+                            padded = np.zeros((th, tw, samples),
+                                              dtype=out_dtype)
+                        else:
+                            padded = np.zeros((th, tw), dtype=out_dtype)
+                        padded[:actual_h, :actual_w] = tile_slice
+                        tile_arr = padded
+                    else:
+                        tile_arr = np.ascontiguousarray(tile_slice)
 
-                compressed = _compress_block(
-                    tile_arr, tw, th, samples, out_dtype,
-                    bytes_per_sample, pred_int, comp_tag,
-                    compression_level)
+                    compressed = _compress_block(
+                        tile_arr, tw, th, samples, out_dtype,
+                        bytes_per_sample, pred_int, comp_tag,
+                        compression_level)
 
-                actual_offsets.append(current_offset)
-                actual_counts.append(len(compressed))
-                f.write(compressed)
-                current_offset += len(compressed)
+                    actual_offsets.append(current_offset)
+                    actual_counts.append(len(compressed))
+                    f.write(compressed)
+                    current_offset += len(compressed)
 
-            del row_np
+                del seg_np
 
     else:
         # Strip layout
         for i in range(n_entries):
diff --git a/xrspatial/geotiff/tests/test_streaming_write.py b/xrspatial/geotiff/tests/test_streaming_write.py
index 328a4c40..23585294 100644
--- a/xrspatial/geotiff/tests/test_streaming_write.py
+++ b/xrspatial/geotiff/tests/test_streaming_write.py
@@ -261,3 +261,74 @@ def test_cog_with_dask_still_works(self, sample_raster, tmp_path):
         result = open_geotiff(path)
         np.testing.assert_array_almost_equal(
             result.values, sample_raster.values, decimal=5)
+
+
+# -- Tile-row segmentation by streaming_buffer_bytes (#1485) ------------------
+
+class TestStreamingBufferBudget:
+    """Horizontal segmentation when a tile-row exceeds the buffer budget."""
+
+    def test_round_trip_wide_raster(self, tmp_path):
+        """Tight buffer forces multi-segment compute; output must match exactly."""
+        rng = np.random.default_rng(1485)
+        arr = rng.random((1024, 8192), dtype=np.float64)
+        da = xr.DataArray(arr, dims=['y', 'x'])
+        dask_da = da.chunk({'y': 256, 'x': 1024})
+
+        # Default budget: one segment covers the whole row (control)
+        ref_path = str(tmp_path / 'wide_default_1485.tif')
+        to_geotiff(dask_da, ref_path, compression='zstd')
+        ref = open_geotiff(ref_path).values
+
+        # Tight budget: forces ~2-tile segments per tile-row
+        seg_path = str(tmp_path / 'wide_segmented_1485.tif')
+        to_geotiff(dask_da, seg_path, compression='zstd',
+                   streaming_buffer_bytes=2 * 256 * 256 * 8)  # 2 tile cols
+        seg = open_geotiff(seg_path).values
+
+        np.testing.assert_array_equal(ref, arr)
+        np.testing.assert_array_equal(seg, arr)
+
+    def test_tight_4mb_budget_succeeds(self, tmp_path):
+        """A 4 MB cap on a wide raster must succeed without OOM."""
+        # 256 rows * 8192 cols * 8 bytes = 16 MB per tile-row.
+        # 4 MB budget forces splitting each tile-row into segments.
+ arr = np.arange(256 * 8192, dtype=np.float64).reshape(256, 8192) + da = xr.DataArray(arr, dims=['y', 'x']) + dask_da = da.chunk({'y': 256, 'x': 2048}) + + path = str(tmp_path / 'tight_budget_1485.tif') + to_geotiff(dask_da, path, compression='zstd', + streaming_buffer_bytes=4 * 1024 * 1024) + + result = open_geotiff(path).values + np.testing.assert_array_equal(result, arr) + + def test_smaller_than_one_tile_clamps(self, tmp_path): + """Budget below one tile must still produce a valid file (clamped).""" + arr = np.arange(256 * 1024, dtype=np.float32).reshape(256, 1024) + da = xr.DataArray(arr, dims=['y', 'x']) + dask_da = da.chunk({'y': 256, 'x': 256}) + + path = str(tmp_path / 'tiny_budget_1485.tif') + # 1 byte is well below one tile (256*256*4 = 262144 bytes) + to_geotiff(dask_da, path, compression='zstd', + streaming_buffer_bytes=1) + + result = open_geotiff(path).values + np.testing.assert_array_equal(result, arr) + + def test_multiband_segmentation(self, tmp_path): + """3-band float64 raster with horizontal segmentation.""" + rng = np.random.default_rng(1485) + arr = rng.random((512, 2048, 3), dtype=np.float64) + da = xr.DataArray(arr, dims=['y', 'x', 'band']) + dask_da = da.chunk({'y': 256, 'x': 512, 'band': 3}) + + path = str(tmp_path / 'multiband_1485.tif') + # Force ~2 tile-cols per segment + to_geotiff(dask_da, path, compression='zstd', + streaming_buffer_bytes=2 * 256 * 256 * 8 * 3) + + result = open_geotiff(path).values + np.testing.assert_array_almost_equal(result, arr, decimal=10)
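
A quick way to sanity-check the segmentation arithmetic in the patch: the
standalone sketch below re-derives the tiles-per-segment calculation outside
the writer. The function name and the 256 x 256 float64 defaults are
illustrative only (they mirror the tests above), not part of the xrspatial
API.

    def tiles_per_segment(budget_bytes, tile_h=256, tile_w=256,
                          itemsize=8, samples=1, tiles_across=32):
        # One buffered tile-column: tile_h rows x tile_w cols x samples.
        bytes_per_tile_col = tile_h * tile_w * itemsize * samples
        if bytes_per_tile_col * tiles_across <= budget_bytes:
            return tiles_across  # whole tile-row in a single segment
        # Clamp to at least one tile-column, as the docstring promises.
        return max(1, budget_bytes // bytes_per_tile_col)

    # 16 MB tile-row (256 x 8192 float64) under a 4 MB cap -> 8-tile segments
    assert tiles_per_segment(4 * 1024 * 1024) == 8
    # A budget below one tile clamps instead of failing
    assert tiles_per_segment(1) == 1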