9 changes: 8 additions & 1 deletion xrspatial/geotiff/__init__.py
@@ -411,7 +411,8 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
overview_levels: list[int] | None = None,
overview_resampling: str = 'mean',
bigtiff: bool | None = None,
gpu: bool | None = None) -> None:
gpu: bool | None = None,
streaming_buffer_bytes: int = 256 * 1024 * 1024) -> None:
"""Write data as a GeoTIFF or Cloud Optimized GeoTIFF.

Dask-backed DataArrays are written in streaming mode: one tile-row
@@ -467,6 +468,11 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
'min', 'max', 'median', 'mode', or 'cubic'.
gpu : bool or None
Force GPU compression. None (default) auto-detects CuPy data.
streaming_buffer_bytes : int
Soft cap on bytes materialised per dask compute call when
streaming a dask-backed DataArray. Defaults to 256 MB. Wide
rasters whose tile-row exceeds this budget are split into
horizontal segments. Ignored for numpy / CuPy / COG paths.
"""
# VRT tiled output
if path.lower().endswith('.vrt'):
@@ -599,6 +605,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
gdal_metadata_xml=gdal_meta_xml,
extra_tags=extra_tags_list,
bigtiff=bigtiff,
streaming_buffer_bytes=streaming_buffer_bytes,
)
return

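For reviewers, a minimal usage sketch of the new keyword. The array shape, chunking, output path, and the xrspatial.geotiff import are illustrative, mirroring the tests added in this PR; 256 x 256 output tiles are assumed, as in those tests.

import numpy as np
import xarray as xr

from xrspatial.geotiff import to_geotiff  # import path assumed from this diff

# Wide dask-backed raster: with 256 px tiles, one tile-row is
# 256 * 8192 * 8 bytes = 16 MiB of float64 pixels.
arr = np.random.default_rng(0).random((1024, 8192))
dask_da = xr.DataArray(arr, dims=['y', 'x']).chunk({'y': 256, 'x': 1024})

# Cap each dask compute call at ~4 MB; tile-rows wider than the budget are
# written in horizontal segments instead of being materialised whole.
to_geotiff(dask_da, 'wide.tif', compression='zstd',
           streaming_buffer_bytes=4 * 1024 * 1024)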
138 changes: 89 additions & 49 deletions xrspatial/geotiff/_writer.py
@@ -1118,15 +1118,31 @@ def write_streaming(dask_data, path: str, *,
resolution_unit: int | None = None,
gdal_metadata_xml: str | None = None,
extra_tags: list | None = None,
bigtiff: bool | None = None) -> None:
"""Write a dask array as a GeoTIFF by streaming one tile-row at a time.

Peak memory is approximately ``tile_height * width * bytes_per_sample``
for tiled output, or ``rows_per_strip * width * bytes_per_sample`` for
stripped output.
bigtiff: bool | None = None,
streaming_buffer_bytes: int = 256 * 1024 * 1024) -> None:
"""Write a dask array as a GeoTIFF by streaming pixel data.

For tiled output, each tile-row is computed in horizontal segments
that fit within ``streaming_buffer_bytes``. Most rasters fit in a
single segment per tile-row, matching the previous behaviour. Wide
rasters get bounded peak memory at the cost of more dask compute
calls.

Peak materialised memory is approximately
``min(streaming_buffer_bytes, tile_height * width * bytes_per_sample
* samples)`` for tiled output, or
``rows_per_strip * width * bytes_per_sample * samples`` for stripped
output (no horizontal segmentation in strip mode).

After all pixel data is written the IFD offset and byte-count arrays
are patched in place.

Parameters
----------
streaming_buffer_bytes : int
Soft cap on bytes materialised per dask compute call when
writing tiles. Defaults to 256 MB. Values smaller than one tile
column are clamped up to one tile column.
"""
import os
import tempfile
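A worked sketch of the segment-sizing rule described in the docstring above. The standalone helper below is illustrative only; the writer computes the same quantities inline in the tiled branch further down.

def tiles_per_segment(buffer_bytes, tile_h, tile_w,
                      bytes_per_sample, samples, tiles_across):
    # One tile-column of the raster, across all bands.
    bytes_per_tile_col = tile_h * tile_w * bytes_per_sample * samples
    bytes_per_full_row = bytes_per_tile_col * tiles_across
    if bytes_per_full_row <= buffer_bytes:
        return tiles_across  # whole tile-row in a single compute call
    return max(1, buffer_bytes // bytes_per_tile_col)  # clamp to >= 1 tile

# 8192-px-wide float64 single-band raster with 256 x 256 tiles:
# a tile-column is 256 * 256 * 8 = 512 KiB and a full tile-row is 16 MiB,
# so a 4 MiB budget yields 8 tile-columns (2048 px) per compute call.
assert tiles_per_segment(4 * 1024 * 1024, 256, 256, 8, 1, 8192 // 256) == 8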
@@ -1312,59 +1328,83 @@ def write_streaming(dask_data, path: str, *,

# Stream pixel data
if tiled:
# Decide how many tile-columns we can buffer at once.
# A full tile-row is tile_h * width * bytes_per_sample * samples
# bytes; if that fits the budget we buffer the whole row (matches
# the original behaviour). Otherwise segment horizontally, always
# at tile boundaries to keep slicing aligned.
bytes_per_tile_col = (
th * tw * bytes_per_sample * samples)
bytes_per_full_row = bytes_per_tile_col * tiles_across
if bytes_per_full_row <= streaming_buffer_bytes:
tiles_per_segment = tiles_across
else:
tiles_per_segment = max(
1, streaming_buffer_bytes // bytes_per_tile_col)

for tr in range(tiles_down):
r0 = tr * th
r1 = min(r0 + th, height)
actual_h = r1 - r0

# Compute one tile-row from the dask graph
if dask_data.ndim == 3:
row_np = np.asarray(dask_data[r0:r1, :, :].compute())
else:
row_np = np.asarray(dask_data[r0:r1, :].compute())
if hasattr(row_np, 'get'):
row_np = row_np.get()

if row_np.dtype != out_dtype:
row_np = row_np.astype(out_dtype)

# NaN -> nodata sentinel
if (nodata is not None and row_np.dtype.kind == 'f'
and not np.isnan(nodata)):
nan_mask = np.isnan(row_np)
if nan_mask.any():
row_np = row_np.copy()
row_np[nan_mask] = row_np.dtype.type(nodata)

for tc in range(tiles_across):
c0 = tc * tw
c1 = min(c0 + tw, width)
actual_w = c1 - c0

tile_slice = row_np[:, c0:c1]
for seg_start in range(0, tiles_across, tiles_per_segment):
seg_end = min(seg_start + tiles_per_segment,
tiles_across)
seg_c0 = seg_start * tw
seg_c1 = min(seg_end * tw, width)

if actual_h < th or actual_w < tw:
if row_np.ndim == 3:
padded = np.zeros((th, tw, samples),
dtype=out_dtype)
else:
padded = np.zeros((th, tw), dtype=out_dtype)
padded[:actual_h, :actual_w] = tile_slice
tile_arr = padded
# Compute just this horizontal segment
if dask_data.ndim == 3:
seg_np = np.asarray(
dask_data[r0:r1, seg_c0:seg_c1, :].compute())
else:
tile_arr = np.ascontiguousarray(tile_slice)
seg_np = np.asarray(
dask_data[r0:r1, seg_c0:seg_c1].compute())
if hasattr(seg_np, 'get'):
seg_np = seg_np.get()

if seg_np.dtype != out_dtype:
seg_np = seg_np.astype(out_dtype)

# NaN -> nodata sentinel
if (nodata is not None and seg_np.dtype.kind == 'f'
and not np.isnan(nodata)):
nan_mask = np.isnan(seg_np)
if nan_mask.any():
seg_np = seg_np.copy()
seg_np[nan_mask] = seg_np.dtype.type(nodata)

for tc in range(seg_start, seg_end):
c0 = tc * tw
c1 = min(c0 + tw, width)
actual_w = c1 - c0

local_c0 = c0 - seg_c0
local_c1 = c1 - seg_c0
tile_slice = seg_np[:, local_c0:local_c1]

if actual_h < th or actual_w < tw:
if seg_np.ndim == 3:
padded = np.zeros((th, tw, samples),
dtype=out_dtype)
else:
padded = np.zeros((th, tw), dtype=out_dtype)
padded[:actual_h, :actual_w] = tile_slice
tile_arr = padded
else:
tile_arr = np.ascontiguousarray(tile_slice)

compressed = _compress_block(
tile_arr, tw, th, samples, out_dtype,
bytes_per_sample, pred_int, comp_tag,
compression_level)
compressed = _compress_block(
tile_arr, tw, th, samples, out_dtype,
bytes_per_sample, pred_int, comp_tag,
compression_level)

actual_offsets.append(current_offset)
actual_counts.append(len(compressed))
f.write(compressed)
current_offset += len(compressed)
actual_offsets.append(current_offset)
actual_counts.append(len(compressed))
f.write(compressed)
current_offset += len(compressed)

del row_np
del seg_np
else:
# Strip layout
for i in range(n_entries):
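The per-segment offset bookkeeping in the tiled branch can be checked in isolation. A pure-numpy sketch (dimensions and segment size are illustrative): slicing a segment and indexing it with segment-local columns must yield the same tiles as slicing the full row.

import numpy as np

height, width, tw = 256, 2048, 256
tiles_across, tiles_per_segment = width // tw, 3
row = np.arange(height * width, dtype=np.float64).reshape(height, width)

for seg_start in range(0, tiles_across, tiles_per_segment):
    seg_end = min(seg_start + tiles_per_segment, tiles_across)
    seg_c0, seg_c1 = seg_start * tw, min(seg_end * tw, width)
    seg = row[:, seg_c0:seg_c1]  # what one dask compute call would return
    for tc in range(seg_start, seg_end):
        c0, c1 = tc * tw, min((tc + 1) * tw, width)
        # The segment-local slice equals the tile taken from the full row.
        np.testing.assert_array_equal(seg[:, c0 - seg_c0:c1 - seg_c0],
                                      row[:, c0:c1])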
71 changes: 71 additions & 0 deletions xrspatial/geotiff/tests/test_streaming_write.py
@@ -261,3 +261,74 @@ def test_cog_with_dask_still_works(self, sample_raster, tmp_path):
result = open_geotiff(path)
np.testing.assert_array_almost_equal(
result.values, sample_raster.values, decimal=5)


# -- Tile-row segmentation by streaming_buffer_bytes (#1485) ------------------

class TestStreamingBufferBudget:
"""Horizontal segmentation when a tile-row exceeds the buffer budget."""

def test_round_trip_wide_raster(self, tmp_path):
"""Tight buffer forces multi-segment compute; output must be byte-equal."""
rng = np.random.default_rng(1485)
arr = rng.random((1024, 8192), dtype=np.float64)
da = xr.DataArray(arr, dims=['y', 'x'])
dask_da = da.chunk({'y': 256, 'x': 1024})

# Default budget: one segment covers the whole row (control)
ref_path = str(tmp_path / 'wide_default_1485.tif')
to_geotiff(dask_da, ref_path, compression='zstd')
ref = open_geotiff(ref_path).values

# Tight budget: forces ~2-tile segments per tile-row
seg_path = str(tmp_path / 'wide_segmented_1485.tif')
to_geotiff(dask_da, seg_path, compression='zstd',
streaming_buffer_bytes=2 * 256 * 256 * 8) # 2 tile cols
seg = open_geotiff(seg_path).values

np.testing.assert_array_equal(ref, arr)
np.testing.assert_array_equal(seg, arr)

def test_tight_4mb_budget_succeeds(self, tmp_path):
"""A 4 MB cap on a wide raster must succeed without OOM."""
# 256 rows * 8192 cols * 8 bytes = 16 MB per tile-row.
# 4 MB budget forces splitting each tile-row into segments.
arr = np.arange(256 * 8192, dtype=np.float64).reshape(256, 8192)
da = xr.DataArray(arr, dims=['y', 'x'])
dask_da = da.chunk({'y': 256, 'x': 2048})

path = str(tmp_path / 'tight_budget_1485.tif')
to_geotiff(dask_da, path, compression='zstd',
streaming_buffer_bytes=4 * 1024 * 1024)

result = open_geotiff(path).values
np.testing.assert_array_equal(result, arr)

def test_smaller_than_one_tile_clamps(self, tmp_path):
"""Budget below one tile must still produce a valid file (clamped)."""
arr = np.arange(256 * 1024, dtype=np.float32).reshape(256, 1024)
da = xr.DataArray(arr, dims=['y', 'x'])
dask_da = da.chunk({'y': 256, 'x': 256})

path = str(tmp_path / 'tiny_budget_1485.tif')
# 1 byte is well below one tile (256*256*4 = 262144 bytes)
to_geotiff(dask_da, path, compression='zstd',
streaming_buffer_bytes=1)

result = open_geotiff(path).values
np.testing.assert_array_equal(result, arr)

def test_multiband_segmentation(self, tmp_path):
"""3-band float64 raster with horizontal segmentation."""
rng = np.random.default_rng(1485)
arr = rng.random((512, 2048, 3), dtype=np.float64)
da = xr.DataArray(arr, dims=['y', 'x', 'band'])
dask_da = da.chunk({'y': 256, 'x': 512, 'band': 3})

path = str(tmp_path / 'multiband_1485.tif')
# Force ~2 tile-cols per segment
to_geotiff(dask_da, path, compression='zstd',
streaming_buffer_bytes=2 * 256 * 256 * 8 * 3)

result = open_geotiff(path).values
np.testing.assert_array_almost_equal(result, arr, decimal=10)