95 changes: 73 additions & 22 deletions xrspatial/geotiff/__init__.py
@@ -398,6 +398,14 @@ def _is_gpu_data(data) -> bool:
'lz4': (0, 16),
}

# Names accepted by ``compression=`` in :func:`to_geotiff`. Kept in sync with
# ``_compression_tag`` in ``_writer.py``. Validated up-front so users see a
# friendly error rather than the deeper traceback from ``_compression_tag``.
_VALID_COMPRESSIONS = (
'none', 'deflate', 'lzw', 'jpeg', 'packbits', 'zstd', 'lz4',
'jpeg2000', 'j2k', 'lerc',
)


def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
crs: int | str | None = None,
@@ -452,12 +460,17 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
tiled : bool
Use tiled layout (default True).
tile_size : int
Tile size in pixels (default 256).
Tile size in pixels (default 256). Ignored when ``tiled=False``;
a warning is emitted if a non-default value is passed alongside
strip mode.
predictor : bool or int
TIFF predictor. ``False``/``0``/``1`` -> none, ``True``/``2`` ->
horizontal differencing (good for integer data), ``3`` ->
floating-point predictor (float dtypes only; typically gives
better deflate/zstd ratios on float data than predictor 2).
TIFF predictor. Accepted values:

* ``False``, ``0``, or ``1`` -> no predictor.
* ``True`` or ``2`` -> horizontal differencing (good for integer
data; ``True`` and ``2`` are exactly equivalent).
* ``3`` -> floating-point predictor (float dtypes only; typically
gives better deflate/zstd ratios on float data than predictor 2).
cog : bool
Write as Cloud Optimized GeoTIFF.
overview_levels : list[int] or None
@@ -468,6 +481,27 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
gpu : bool or None
Force GPU compression. None (default) auto-detects CuPy data.
"""
# Up-front validation: catch bad compression names before they reach
# any of the deeper write paths (streaming, GPU, VRT, COG) where the
# error surfaces from _compression_tag with a less obvious traceback.
if isinstance(compression, str):
if compression.lower() not in _VALID_COMPRESSIONS:
raise ValueError(
f"Unknown compression {compression!r}. "
f"Valid options: {list(_VALID_COMPRESSIONS)}.")

# tile_size only applies to tiled output; warn if the caller passed a
# non-default size alongside strip mode (it would otherwise be silently
# ignored).
if not tiled and tile_size != 256:
import warnings
warnings.warn(
f"tile_size={tile_size} is ignored when tiled=False "
"(strip layout). Pass tiled=True to use tile_size, or drop "
"tile_size to silence this warning.",
stacklevel=2,
)
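With the new validation in place, bad arguments fail at the call site instead of deep inside the writer. A minimal usage sketch of the intended behavior -- `data` is assumed to be a georeferenced float32 `xarray.DataArray`, and the file names are illustrative, not taken from this PR:

```python
from xrspatial.geotiff import to_geotiff

# Unknown compression names now raise immediately, listing the valid options,
# rather than surfacing later from _compression_tag with a deeper traceback.
try:
    to_geotiff(data, "out.tif", compression="brotli")
except ValueError as exc:
    print(exc)

# Valid call: zstd with the floating-point predictor; tile_size applies
# because tiled=True (the default).
to_geotiff(data, "out.tif", compression="zstd", predictor=3, tile_size=512)

# tile_size passed alongside tiled=False is ignored and now emits a UserWarning.
to_geotiff(data, "strips.tif", tiled=False, tile_size=512)
```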

# VRT tiled output
if path.lower().endswith('.vrt'):
if cog:
@@ -900,7 +934,11 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,
"""
import dask.array as da

# VRT files: delegate to read_vrt which handles chunks
# ``read_geotiff`` already routes ``.vrt`` to ``read_vrt`` before
# reaching here, so this branch is only hit when ``read_geotiff_dask``
# is called directly with a VRT path. Keep it as a defensive fallback
# rather than letting the windowed-read path try to parse VRT XML as
# TIFF bytes. ``read_vrt`` is the single source of truth for VRT.
if source.lower().endswith('.vrt'):
return read_vrt(source, dtype=dtype, name=name, chunks=chunks)
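So a direct call on a VRT still ends up in ``read_vrt``; a small sketch (the file name is illustrative):

```python
import xrspatial.geotiff as gt

# Handled by the defensive fallback above: delegates to read_vrt instead of
# trying to parse the VRT XML as TIFF bytes.
arr = gt.read_geotiff_dask("mosaic.vrt", chunks=1024)
```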

@@ -944,23 +982,24 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,
# Graph-size guard. Each chunk becomes a delayed task whose Python graph
# entry retains ~1KB. At very large chunk counts the graph itself OOMs
# the driver before any read executes (30TB at chunks=256 => ~500M tasks
# => ~500GB graph on host). Auto-scale chunks up to cap total task count.
_MAX_DASK_CHUNKS = 1_000_000
# => ~500GB graph on host). Refuse anything past the cap and ask the
# caller to pick a chunk size, rather than silently rescaling -- the
# rescaled chunks may not align with the user's downstream pipeline.
_MAX_DASK_CHUNKS = 50_000
n_chunks = ((full_h + ch_h - 1) // ch_h) * ((full_w + ch_w - 1) // ch_w)
if n_chunks > _MAX_DASK_CHUNKS:
import math
scale = math.sqrt(n_chunks / _MAX_DASK_CHUNKS)
new_ch_h = int(math.ceil(ch_h * scale))
new_ch_w = int(math.ceil(ch_w * scale))
import warnings
warnings.warn(
f"read_geotiff_dask: requested chunks=({ch_h}, {ch_w}) on a "
f"{full_h}x{full_w} image would produce {n_chunks} dask tasks, "
f"exceeding the {_MAX_DASK_CHUNKS}-task cap. Auto-scaling to "
f"chunks=({new_ch_h}, {new_ch_w}).",
stacklevel=2,
suggested_h = int(math.ceil(ch_h * scale))
suggested_w = int(math.ceil(ch_w * scale))
raise ValueError(
f"read_geotiff_dask: chunks=({ch_h}, {ch_w}) on a "
f"{full_h}x{full_w} image would produce {n_chunks:,} dask "
f"tasks, exceeding the {_MAX_DASK_CHUNKS:,}-task cap. Pass a "
f"larger chunks=... value explicitly (e.g. chunks="
f"({suggested_h}, {suggested_w}) keeps the task count under "
"the cap)."
)
ch_h, ch_w = new_ch_h, new_ch_w

# Build dask array from delayed windowed reads
rows = list(range(0, full_h, ch_h))
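For anyone tuning ``chunks=`` against the new cap, the arithmetic is just the product of per-axis chunk counts; a quick sketch with illustrative sizes:

```python
import math

MAX_CHUNKS = 50_000                 # the cap introduced in this PR

full_h = full_w = 1_048_576         # hypothetical ~1M x 1M pixel raster
ch = 512                            # default chunks=512
print(math.ceil(full_h / ch) * math.ceil(full_w / ch))   # 4,194,304 -> ValueError

ch = 8192                           # caller picks a larger chunk explicitly
print(math.ceil(full_h / ch) * math.ceil(full_w / ch))   # 16,384 -> under the cap
```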
@@ -1355,12 +1394,14 @@ def _gpu_compress_to_part(gpu_arr, w, h, spp):
# Full resolution
parts = [_gpu_compress_to_part(arr, width, height, samples)]

# Overview generation
# Overview generation -- mirrors the CPU writer's 8-level cap.
if cog:
if overview_levels is None:
from ._writer import _MAX_OVERVIEW_LEVELS
overview_levels = []
oh, ow = height, width
while oh > tile_size and ow > tile_size:
while (oh > tile_size and ow > tile_size and
len(overview_levels) < _MAX_OVERVIEW_LEVELS):
oh //= 2
ow //= 2
if oh > 0 and ow > 0:
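The cap only matters for very large rasters; a sketch of the same halving loop with illustrative dimensions (the 8-level value comes from the CPU writer, per the comment above):

```python
MAX_OVERVIEW_LEVELS = 8
tile_size = 256

oh, ow, n = 200_000, 150_000, 0
while oh > tile_size and ow > tile_size and n < MAX_OVERVIEW_LEVELS:
    oh //= 2
    ow //= 2
    n += 1
print(n, oh, ow)   # 8 781 585 -- halving stops at the level cap, not the tile size
```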
@@ -1505,13 +1546,23 @@ def write_vrt(vrt_path: str, source_files: list[str], **kwargs) -> str:
Output .vrt file path.
source_files : list of str
Paths to the source GeoTIFF files.
**kwargs
relative, crs_wkt, nodata -- see _vrt.write_vrt.
relative : bool, optional
Store source paths relative to the VRT file (default True).
crs_wkt : str or None, optional
CRS as a WKT string. If None, the CRS is taken from the first
source GeoTIFF.
nodata : float or None, optional
NoData value. If None, taken from the first source GeoTIFF.

Returns
-------
str
Path to the written VRT file.

Notes
-----
Only the keyword arguments listed above are accepted. Passing any
other keyword raises ``TypeError`` from the underlying writer.
"""
from ._vrt import write_vrt as _write_vrt_internal
return _write_vrt_internal(vrt_path, source_files, **kwargs)
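A usage sketch against the documented keywords (tile paths are illustrative):

```python
from xrspatial.geotiff import write_vrt

tiles = ["tiles/r0_c0.tif", "tiles/r0_c1.tif", "tiles/r1_c0.tif"]

# Paths stored relative to mosaic.vrt; CRS and nodata taken from the first tile.
write_vrt("mosaic.vrt", tiles, relative=True)

# Anything outside relative/crs_wkt/nodata is rejected by the underlying writer.
write_vrt("mosaic.vrt", tiles, overwrite=True)   # -> TypeError
```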
53 changes: 53 additions & 0 deletions xrspatial/geotiff/_gpu_decode.py
@@ -11,6 +11,51 @@
import numpy as np
from numba import cuda

#: Fraction of free GPU memory we're willing to allocate in a single call.
#: Above this, raise MemoryError up-front so the caller gets an actionable
#: error rather than a CUDA OOM deep inside the kernel launch.
_GPU_FREE_MEMORY_FRACTION = 0.9


def _check_gpu_memory(required_bytes: int, what: str = "tile buffer") -> None:
"""Raise MemoryError if *required_bytes* would exhaust the GPU.

Calls ``cupy.cuda.runtime.memGetInfo()`` and refuses any allocation
that would consume more than ``_GPU_FREE_MEMORY_FRACTION`` of the
currently free memory. This is a soft guard -- another process can
grab memory between the check and the allocation -- but it catches
the common 'this single buffer is way too big' case before CUDA
raises a less informative error.

Parameters
----------
required_bytes : int
Bytes the caller is about to allocate (sum across all buffers in
the same logical step).
what : str
Short label included in the error message, e.g. ``"tile buffer"``.
"""
if required_bytes <= 0:
return
try:
import cupy
free, total = cupy.cuda.runtime.memGetInfo()
except Exception:
# If we can't query, fall through and let the real allocation
# surface the error. Don't add a second failure mode here.
return

budget = int(free * _GPU_FREE_MEMORY_FRACTION)
if required_bytes > budget:
raise MemoryError(
f"GPU out of memory: {what} needs {required_bytes:,} bytes "
f"but only {free:,} bytes free on device (cap is "
f"{_GPU_FREE_MEMORY_FRACTION:.0%} of free = {budget:,} "
"bytes). Consider reading the file in chunks via "
"read_geotiff_dask(..., chunks=...) or freeing GPU memory "
"with cupy.get_default_memory_pool().free_all_blocks()."
)
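A quick illustration of how the guard trips, assuming a device reporting roughly 8 GB free:

```python
from xrspatial.geotiff._gpu_decode import _check_gpu_memory

# A 100,000 x 100,000 single-band uint8 image needs ~10 GB for the assembled
# output buffer. With ~8 GB free this raises MemoryError up-front, pointing at
# read_geotiff_dask(..., chunks=...) instead of a CUDA OOM mid-kernel.
_check_gpu_memory(100_000 * 100_000 * 1, what="full-image output buffer")

# If CuPy is unavailable or memGetInfo() fails, the check is a silent no-op and
# the real allocation surfaces the error instead.
```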

# LZW constants (same as _compression.py)
LZW_CLEAR_CODE = 256
LZW_EOI_CODE = 257
@@ -1006,6 +1051,8 @@ class _NvjpegImage(ctypes.Structure):
('pitch', ctypes.c_size_t * 4),
]

_check_gpu_memory(n_tiles * tile_bytes,
what="nvJPEG output buffer")
d_all = cupy.empty(n_tiles * tile_bytes, dtype=cupy.uint8)

decode_fn = getattr(lib, 'nvjpegDecode')
@@ -1353,6 +1400,8 @@ def _apply_predictor_and_assemble(d_decomp, d_decomp_offsets, n_tiles,

tiles_across = math.ceil(image_width / tile_width)
total_pixels = image_width * image_height
_check_gpu_memory(total_pixels * bytes_per_pixel,
what="full-image output buffer")
d_output = cupy.empty(total_pixels * bytes_per_pixel, dtype=cupy.uint8)

tpb = 256
@@ -1440,6 +1489,7 @@ def gpu_decode_tiles(

# Allocate decompressed buffer on device
decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes
_check_gpu_memory(n_tiles * tile_bytes, what="tile decode buffer")
d_decomp = cupy.zeros(n_tiles * tile_bytes, dtype=cupy.uint8)
d_decomp_offsets = cupy.asarray(decomp_offsets)
d_tile_sizes = cupy.full(n_tiles, tile_bytes, dtype=cupy.int64)
@@ -1470,6 +1520,7 @@ def gpu_decode_tiles(
d_comp_sizes = cupy.asarray(np.array(comp_sizes, dtype=np.int64))

decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes
_check_gpu_memory(n_tiles * tile_bytes, what="tile decode buffer")
d_decomp = cupy.zeros(n_tiles * tile_bytes, dtype=cupy.uint8)
d_decomp_offsets = cupy.asarray(decomp_offsets)
d_tile_sizes = cupy.full(n_tiles, tile_bytes, dtype=cupy.int64)
@@ -1602,6 +1653,8 @@ def gpu_decode_tiles(
# Assemble tiles into output image on GPU
tiles_across = math.ceil(image_width / tile_width)
total_pixels = image_width * image_height
_check_gpu_memory(total_pixels * bytes_per_pixel,
what="full-image output buffer")
d_output = cupy.empty(total_pixels * bytes_per_pixel, dtype=cupy.uint8)

tpb = 256