diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 3c02b675..d71d623e 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -411,7 +411,8 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, overview_levels: list[int] | None = None, overview_resampling: str = 'mean', bigtiff: bool | None = None, - gpu: bool | None = None) -> None: + gpu: bool | None = None, + streaming_buffer_bytes: int = 256 * 1024 * 1024) -> None: """Write data as a GeoTIFF or Cloud Optimized GeoTIFF. Dask-backed DataArrays are written in streaming mode: one tile-row @@ -467,6 +468,11 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, 'min', 'max', 'median', 'mode', or 'cubic'. gpu : bool or None Force GPU compression. None (default) auto-detects CuPy data. + streaming_buffer_bytes : int + Soft cap on bytes materialised per dask compute call when + streaming a dask-backed DataArray. Defaults to 256 MB. Wide + rasters whose tile-row exceeds this budget are split into + horizontal segments. Ignored for numpy / CuPy / COG paths. """ # VRT tiled output if path.lower().endswith('.vrt'): @@ -599,6 +605,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, gdal_metadata_xml=gdal_meta_xml, extra_tags=extra_tags_list, bigtiff=bigtiff, + streaming_buffer_bytes=streaming_buffer_bytes, ) return diff --git a/xrspatial/geotiff/_writer.py b/xrspatial/geotiff/_writer.py index 13311a94..9b6470d4 100644 --- a/xrspatial/geotiff/_writer.py +++ b/xrspatial/geotiff/_writer.py @@ -1118,15 +1118,31 @@ def write_streaming(dask_data, path: str, *, resolution_unit: int | None = None, gdal_metadata_xml: str | None = None, extra_tags: list | None = None, - bigtiff: bool | None = None) -> None: - """Write a dask array as a GeoTIFF by streaming one tile-row at a time. - - Peak memory is approximately ``tile_height * width * bytes_per_sample`` - for tiled output, or ``rows_per_strip * width * bytes_per_sample`` for - stripped output. + bigtiff: bool | None = None, + streaming_buffer_bytes: int = 256 * 1024 * 1024) -> None: + """Write a dask array as a GeoTIFF by streaming pixel data. + + For tiled output, each tile-row is computed in horizontal segments + that fit within ``streaming_buffer_bytes``. Most rasters fit in a + single segment per tile-row, matching the previous behaviour. Wide + rasters get bounded peak memory at the cost of more dask compute + calls. + + Peak materialised memory is approximately + ``min(streaming_buffer_bytes, tile_height * width * bytes_per_sample + * samples)`` for tiled output, or + ``rows_per_strip * width * bytes_per_sample * samples`` for stripped + output (no horizontal segmentation in strip mode). After all pixel data is written the IFD offset and byte-count arrays are patched in place. + + Parameters + ---------- + streaming_buffer_bytes : int + Soft cap on bytes materialised per dask compute call when + writing tiles. Defaults to 256 MB. Values smaller than one tile + column are clamped up to one tile column. """ import os import tempfile @@ -1312,59 +1328,83 @@ def write_streaming(dask_data, path: str, *, # Stream pixel data if tiled: + # Decide how many tile-columns we can buffer at once. + # bytes_per_full_tile_row = tile_h * width * dtype * samples; + # if it fits the budget we buffer the whole row (matches + # original behaviour). Otherwise segment horizontally, + # always at tile boundaries to keep slicing aligned. + bytes_per_tile_col = ( + th * tw * bytes_per_sample * samples) + bytes_per_full_row = bytes_per_tile_col * tiles_across + if bytes_per_full_row <= streaming_buffer_bytes: + tiles_per_segment = tiles_across + else: + tiles_per_segment = max( + 1, streaming_buffer_bytes // bytes_per_tile_col) + for tr in range(tiles_down): r0 = tr * th r1 = min(r0 + th, height) actual_h = r1 - r0 - # Compute one tile-row from the dask graph - if dask_data.ndim == 3: - row_np = np.asarray(dask_data[r0:r1, :, :].compute()) - else: - row_np = np.asarray(dask_data[r0:r1, :].compute()) - if hasattr(row_np, 'get'): - row_np = row_np.get() - - if row_np.dtype != out_dtype: - row_np = row_np.astype(out_dtype) - - # NaN -> nodata sentinel - if (nodata is not None and row_np.dtype.kind == 'f' - and not np.isnan(nodata)): - nan_mask = np.isnan(row_np) - if nan_mask.any(): - row_np = row_np.copy() - row_np[nan_mask] = row_np.dtype.type(nodata) - - for tc in range(tiles_across): - c0 = tc * tw - c1 = min(c0 + tw, width) - actual_w = c1 - c0 - - tile_slice = row_np[:, c0:c1] + for seg_start in range(0, tiles_across, tiles_per_segment): + seg_end = min(seg_start + tiles_per_segment, + tiles_across) + seg_c0 = seg_start * tw + seg_c1 = min(seg_end * tw, width) - if actual_h < th or actual_w < tw: - if row_np.ndim == 3: - padded = np.zeros((th, tw, samples), - dtype=out_dtype) - else: - padded = np.zeros((th, tw), dtype=out_dtype) - padded[:actual_h, :actual_w] = tile_slice - tile_arr = padded + # Compute just this horizontal segment + if dask_data.ndim == 3: + seg_np = np.asarray( + dask_data[r0:r1, seg_c0:seg_c1, :].compute()) else: - tile_arr = np.ascontiguousarray(tile_slice) + seg_np = np.asarray( + dask_data[r0:r1, seg_c0:seg_c1].compute()) + if hasattr(seg_np, 'get'): + seg_np = seg_np.get() + + if seg_np.dtype != out_dtype: + seg_np = seg_np.astype(out_dtype) + + # NaN -> nodata sentinel + if (nodata is not None and seg_np.dtype.kind == 'f' + and not np.isnan(nodata)): + nan_mask = np.isnan(seg_np) + if nan_mask.any(): + seg_np = seg_np.copy() + seg_np[nan_mask] = seg_np.dtype.type(nodata) + + for tc in range(seg_start, seg_end): + c0 = tc * tw + c1 = min(c0 + tw, width) + actual_w = c1 - c0 + + local_c0 = c0 - seg_c0 + local_c1 = c1 - seg_c0 + tile_slice = seg_np[:, local_c0:local_c1] + + if actual_h < th or actual_w < tw: + if seg_np.ndim == 3: + padded = np.zeros((th, tw, samples), + dtype=out_dtype) + else: + padded = np.zeros((th, tw), dtype=out_dtype) + padded[:actual_h, :actual_w] = tile_slice + tile_arr = padded + else: + tile_arr = np.ascontiguousarray(tile_slice) - compressed = _compress_block( - tile_arr, tw, th, samples, out_dtype, - bytes_per_sample, pred_int, comp_tag, - compression_level) + compressed = _compress_block( + tile_arr, tw, th, samples, out_dtype, + bytes_per_sample, pred_int, comp_tag, + compression_level) - actual_offsets.append(current_offset) - actual_counts.append(len(compressed)) - f.write(compressed) - current_offset += len(compressed) + actual_offsets.append(current_offset) + actual_counts.append(len(compressed)) + f.write(compressed) + current_offset += len(compressed) - del row_np + del seg_np else: # Strip layout for i in range(n_entries): diff --git a/xrspatial/geotiff/tests/test_streaming_write.py b/xrspatial/geotiff/tests/test_streaming_write.py index 328a4c40..23585294 100644 --- a/xrspatial/geotiff/tests/test_streaming_write.py +++ b/xrspatial/geotiff/tests/test_streaming_write.py @@ -261,3 +261,74 @@ def test_cog_with_dask_still_works(self, sample_raster, tmp_path): result = open_geotiff(path) np.testing.assert_array_almost_equal( result.values, sample_raster.values, decimal=5) + + +# -- Tile-row segmentation by streaming_buffer_bytes (#1485) ------------------ + +class TestStreamingBufferBudget: + """Horizontal segmentation when a tile-row exceeds the buffer budget.""" + + def test_round_trip_wide_raster(self, tmp_path): + """Tight buffer forces multi-segment compute; output must be byte-equal.""" + rng = np.random.default_rng(1485) + arr = rng.random((1024, 8192), dtype=np.float64) + da = xr.DataArray(arr, dims=['y', 'x']) + dask_da = da.chunk({'y': 256, 'x': 1024}) + + # Default budget: one segment covers the whole row (control) + ref_path = str(tmp_path / 'wide_default_1485.tif') + to_geotiff(dask_da, ref_path, compression='zstd') + ref = open_geotiff(ref_path).values + + # Tight budget: forces ~2-tile segments per tile-row + seg_path = str(tmp_path / 'wide_segmented_1485.tif') + to_geotiff(dask_da, seg_path, compression='zstd', + streaming_buffer_bytes=2 * 256 * 256 * 8) # 2 tile cols + seg = open_geotiff(seg_path).values + + np.testing.assert_array_equal(ref, arr) + np.testing.assert_array_equal(seg, arr) + + def test_tight_4mb_budget_succeeds(self, tmp_path): + """A 4 MB cap on a wide raster must succeed without OOM.""" + # 256 rows * 8192 cols * 8 bytes = 16 MB per tile-row. + # 4 MB budget forces splitting each tile-row into segments. + arr = np.arange(256 * 8192, dtype=np.float64).reshape(256, 8192) + da = xr.DataArray(arr, dims=['y', 'x']) + dask_da = da.chunk({'y': 256, 'x': 2048}) + + path = str(tmp_path / 'tight_budget_1485.tif') + to_geotiff(dask_da, path, compression='zstd', + streaming_buffer_bytes=4 * 1024 * 1024) + + result = open_geotiff(path).values + np.testing.assert_array_equal(result, arr) + + def test_smaller_than_one_tile_clamps(self, tmp_path): + """Budget below one tile must still produce a valid file (clamped).""" + arr = np.arange(256 * 1024, dtype=np.float32).reshape(256, 1024) + da = xr.DataArray(arr, dims=['y', 'x']) + dask_da = da.chunk({'y': 256, 'x': 256}) + + path = str(tmp_path / 'tiny_budget_1485.tif') + # 1 byte is well below one tile (256*256*4 = 262144 bytes) + to_geotiff(dask_da, path, compression='zstd', + streaming_buffer_bytes=1) + + result = open_geotiff(path).values + np.testing.assert_array_equal(result, arr) + + def test_multiband_segmentation(self, tmp_path): + """3-band float64 raster with horizontal segmentation.""" + rng = np.random.default_rng(1485) + arr = rng.random((512, 2048, 3), dtype=np.float64) + da = xr.DataArray(arr, dims=['y', 'x', 'band']) + dask_da = da.chunk({'y': 256, 'x': 512, 'band': 3}) + + path = str(tmp_path / 'multiband_1485.tif') + # Force ~2 tile-cols per segment + to_geotiff(dask_da, path, compression='zstd', + streaming_buffer_bytes=2 * 256 * 256 * 8 * 3) + + result = open_geotiff(path).values + np.testing.assert_array_almost_equal(result, arr, decimal=10)