diff --git a/xrspatial/geotiff/_backends/gpu.py b/xrspatial/geotiff/_backends/gpu.py index 26db314d..d9d1c7b6 100644 --- a/xrspatial/geotiff/_backends/gpu.py +++ b/xrspatial/geotiff/_backends/gpu.py @@ -766,11 +766,8 @@ def read_geotiff_gpu(source: str, *, def _read_once(): if not _shared_data_cache: - src2 = _FileSource(source) - try: + with _FileSource(source) as src2: _shared_data_cache.append(src2.read_all()) - finally: - src2.close() return _shared_data_cache[0] band_arrays = [] @@ -880,15 +877,12 @@ def _read_once(): for i in range(len(offsets)) ] else: - src2 = _FileSource(source) - data2 = src2.read_all() - try: + with _FileSource(source) as src2: + data2 = src2.read_all() compressed_tiles = [ bytes(data2[offsets[i]:offsets[i] + byte_counts[i]]) for i in range(len(offsets)) ] - finally: - src2.close() if arr_gpu is None: try: @@ -1284,15 +1278,12 @@ def _decode_window_gpu_direct(file_path, all_offsets, all_byte_counts, # usable on the host. Open the file via mmap, slice out just the # bytes for these tiles, and run the GPU decoder on those. from .._reader import _FileSource - src = _FileSource(file_path) - try: + with _FileSource(file_path) as src: data = src.read_all() compressed_tiles = [ bytes(data[sub_offsets[i]:sub_offsets[i] + sub_byte_counts[i]]) for i in range(len(sub_offsets)) ] - finally: - src.close() arr_gpu = gpu_decode_tiles( compressed_tiles, tw, th, sub_w, sub_h, compression, predictor, file_dtype, samples, @@ -1361,11 +1352,8 @@ def _read_geotiff_gpu_chunked(source, *, dtype, chunks, overview_level, if isinstance(src_path, str) and not src_path.startswith( ('http://', 'https://')): try: - _cap_fs = _FileSource(src_path) - try: + with _FileSource(src_path) as _cap_fs: _cap_raw = _cap_fs.read_all() - finally: - _cap_fs.close() _cap_header = parse_header(_cap_raw) _cap_ifds = parse_all_ifds(_cap_raw, _cap_header) _cap_ifd = select_overview_ifd(_cap_ifds, overview_level) @@ -1395,11 +1383,8 @@ def _read_geotiff_gpu_chunked(source, *, dtype, chunks, overview_level, try: if isinstance(src_path, str) and not src_path.startswith( ('http://', 'https://')): - fs = _FileSource(src_path) - try: + with _FileSource(src_path) as fs: raw = fs.read_all() - finally: - fs.close() header = parse_header(raw) ifds = parse_all_ifds(raw, header) if not ifds: diff --git a/xrspatial/geotiff/_sources.py b/xrspatial/geotiff/_sources.py index 217271d4..1a3ca669 100644 --- a/xrspatial/geotiff/_sources.py +++ b/xrspatial/geotiff/_sources.py @@ -324,6 +324,12 @@ def close(self): _mmap_cache.release(self._entry) self._entry = None + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + # --------------------------------------------------------------------------- # HTTP source: pool, timeouts, SSRF defences (issue #1664) diff --git a/xrspatial/geotiff/_vrt.py b/xrspatial/geotiff/_vrt.py index 0bf6ab8d..3905c5d0 100644 --- a/xrspatial/geotiff/_vrt.py +++ b/xrspatial/geotiff/_vrt.py @@ -1956,20 +1956,19 @@ def write_vrt(vrt_path: str, source_files: list[str], *, # Read metadata from all sources sources_meta = [] for src_path in source_files: - src = _FileSource(src_path) - data = src.read_all() - header = parse_header(data) - ifds = parse_all_ifds(data, header) - ifd = ifds[0] - # The writer is reading source metadata to populate the VRT XML; - # it does not decode pixels or run the masking step that the - # new #2441 default-rejection guards against. Pass the opt-in - # so a source TIFF with a non-finite / fractional ``GDAL_NODATA`` - # value can still be referenced by a VRT (the read-side default - # still rejects it when the resulting VRT is later opened). - geo = extract_geo_info(ifd, data, header.byte_order, - allow_invalid_nodata=True) - src.close() + with _FileSource(src_path) as src: + data = src.read_all() + header = parse_header(data) + ifds = parse_all_ifds(data, header) + ifd = ifds[0] + # The writer is reading source metadata to populate the VRT XML; + # it does not decode pixels or run the masking step that the + # new #2441 default-rejection guards against. Pass the opt-in + # so a source TIFF with a non-finite / fractional ``GDAL_NODATA`` + # value can still be referenced by a VRT (the read-side default + # still rejects it when the resulting VRT is later opened). + geo = extract_geo_info(ifd, data, header.byte_order, + allow_invalid_nodata=True) bps = resolve_bits_per_sample(ifd.bits_per_sample) diff --git a/xrspatial/geotiff/_vrt_validation.py b/xrspatial/geotiff/_vrt_validation.py index e1881957..2215ae26 100644 --- a/xrspatial/geotiff/_vrt_validation.py +++ b/xrspatial/geotiff/_vrt_validation.py @@ -557,85 +557,80 @@ def _check_mixed_source_crs( reference_source: str | None = None # None == VRT-declared for src_path in source_paths: - src_handle = None try: - try: - src_handle = _FileSource(src_path) + with _FileSource(src_path) as src_handle: data = src_handle.read_all() header = parse_header(data) ifds = parse_all_ifds(data, header) if not ifds: continue geo = extract_geo_info(ifds[0], data, header.byte_order) - except (OSError, ValueError, struct.error): - # Unreadable / malformed source. The real read path - # surfaces this via the ``missing_sources`` contract - # (default ``'raise'``); the validator deliberately - # stays quiet so the canonical error fires from there - # with the existing message and code path. - continue - - src_crs_raw = geo.crs_wkt - if not src_crs_raw and geo.crs_epsg is not None: - src_crs_raw = f"EPSG:{geo.crs_epsg}" - if not src_crs_raw: - # Source has no CRS at all. The VRT can legitimately - # provide one in this case (reader inherits the - # VRT-declared SRS), so skip. - continue + except (OSError, ValueError, struct.error): + # Unreadable / malformed source. The real read path + # surfaces this via the ``missing_sources`` contract + # (default ``'raise'``); the validator deliberately + # stays quiet so the canonical error fires from there + # with the existing message and code path. + continue + + src_crs_raw = geo.crs_wkt + if not src_crs_raw and geo.crs_epsg is not None: + src_crs_raw = f"EPSG:{geo.crs_epsg}" + if not src_crs_raw: + # Source has no CRS at all. The VRT can legitimately + # provide one in this case (reader inherits the + # VRT-declared SRS), so skip. + continue - try: - src_crs = _PyProjCRS.from_user_input(src_crs_raw) - except _PyProjCRSError: - # Cannot canonicalise the source CRS; cannot prove - # disagreement. The decode path will surface its own - # error if the broken CRS matters. - continue - - if reference_crs is None: - # No VRT-level SRS and this is the first source we - # could read; use it as the reference for the rest. - reference_crs = src_crs - reference_display = _format_crs_for_message( - src_crs_raw, src_crs, - ) - reference_source = src_path - continue + try: + src_crs = _PyProjCRS.from_user_input(src_crs_raw) + except _PyProjCRSError: + # Cannot canonicalise the source CRS; cannot prove + # disagreement. The decode path will surface its own + # error if the broken CRS matters. + continue + + if reference_crs is None: + # No VRT-level SRS and this is the first source we + # could read; use it as the reference for the rest. + reference_crs = src_crs + reference_display = _format_crs_for_message( + src_crs_raw, src_crs, + ) + reference_source = src_path + continue - if src_crs.equals(reference_crs): - continue + if src_crs.equals(reference_crs): + continue - src_display = _format_crs_for_message(src_crs_raw, src_crs) - if reference_source is None: - # Disagreement against the VRT-declared SRS. - raise VRTUnsupportedError( - f"VRT '{source}' source '{src_path}' carries CRS " - f"{src_display}, which does not match the " - f"VRT-declared ({reference_display}). The " - f"mosaic reader has no reprojection step, so the " - f"two sets of pixels cannot be composited into a " - f"single CRS without misplacing one of them. " - f"Reproject the source to the VRT-declared CRS " - f"(e.g. ``gdalwarp -t_srs``) before referencing it, " - f"or split the mosaic into per-CRS VRTs. See issue " - f"#2321." - ) - # Disagreement among sources (no VRT-level SRS). + src_display = _format_crs_for_message(src_crs_raw, src_crs) + if reference_source is None: + # Disagreement against the VRT-declared SRS. raise VRTUnsupportedError( - f"VRT '{source}' has no and its sources disagree " - f"on CRS: '{reference_source}' carries " - f"{reference_display} but '{src_path}' carries " - f"{src_display}. The mosaic reader has no reprojection " - f"step, so the sources cannot be composited into a " - f"single CRS. Reproject every source to a common CRS " - f"(e.g. ``gdalwarp -t_srs``) before assembling the VRT, " - f"or declare an authoritative at the VRT level " - f"and use sources whose CRS already matches it. See " - f"issue #2321." + f"VRT '{source}' source '{src_path}' carries CRS " + f"{src_display}, which does not match the " + f"VRT-declared ({reference_display}). The " + f"mosaic reader has no reprojection step, so the " + f"two sets of pixels cannot be composited into a " + f"single CRS without misplacing one of them. " + f"Reproject the source to the VRT-declared CRS " + f"(e.g. ``gdalwarp -t_srs``) before referencing it, " + f"or split the mosaic into per-CRS VRTs. See issue " + f"#2321." ) - finally: - if src_handle is not None: - src_handle.close() + # Disagreement among sources (no VRT-level SRS). + raise VRTUnsupportedError( + f"VRT '{source}' has no and its sources disagree " + f"on CRS: '{reference_source}' carries " + f"{reference_display} but '{src_path}' carries " + f"{src_display}. The mosaic reader has no reprojection " + f"step, so the sources cannot be composited into a " + f"single CRS. Reproject every source to a common CRS " + f"(e.g. ``gdalwarp -t_srs``) before assembling the VRT, " + f"or declare an authoritative at the VRT level " + f"and use sources whose CRS already matches it. See " + f"issue #2321." + ) def _format_crs_for_message(raw: str, crs) -> str: diff --git a/xrspatial/geotiff/tests/gpu/test_kernels_and_kwargs.py b/xrspatial/geotiff/tests/gpu/test_kernels_and_kwargs.py index 3a844475..b94ad671 100644 --- a/xrspatial/geotiff/tests/gpu/test_kernels_and_kwargs.py +++ b/xrspatial/geotiff/tests/gpu/test_kernels_and_kwargs.py @@ -1409,11 +1409,8 @@ def _parse_for_gds_1896(path: str): from xrspatial.geotiff._header import parse_all_ifds, parse_header, select_overview_ifd from xrspatial.geotiff._reader import _FileSource - fs = _FileSource(path) - try: + with _FileSource(path) as fs: raw = fs.read_all() - finally: - fs.close() header = parse_header(raw) ifds = parse_all_ifds(raw, header) ifd = select_overview_ifd(ifds, None) diff --git a/xrspatial/geotiff/tests/gpu/test_reader.py b/xrspatial/geotiff/tests/gpu/test_reader.py index 6e14a960..e00f6188 100644 --- a/xrspatial/geotiff/tests/gpu/test_reader.py +++ b/xrspatial/geotiff/tests/gpu/test_reader.py @@ -1408,11 +1408,8 @@ def _parse_for_gds_1909(path: str): ) from xrspatial.geotiff._reader import _FileSource - fs = _FileSource(path) - try: + with _FileSource(path) as fs: raw = fs.read_all() - finally: - fs.close() header = parse_header(raw) ifds = parse_all_ifds(raw, header) ifd = select_overview_ifd(ifds, None) diff --git a/xrspatial/geotiff/tests/test_file_source_context_2449.py b/xrspatial/geotiff/tests/test_file_source_context_2449.py new file mode 100644 index 00000000..5428c81a --- /dev/null +++ b/xrspatial/geotiff/tests/test_file_source_context_2449.py @@ -0,0 +1,70 @@ +# Tests for _FileSource context manager protocol (issue #2449). +import os +import struct + +import numpy as np +import pytest + +from xrspatial.geotiff import to_geotiff +from xrspatial.geotiff._sources import _FileSource, _mmap_cache + + +@pytest.fixture +def tiff_path(tmp_path): + arr = np.zeros((4, 4), dtype=np.uint8) + path = str(tmp_path / 'fs_ctx_2449.tif') + to_geotiff(arr, path, compression='none') + return path + + +def _refcount(path): + """Look up the cache refcount for *path*, or None if not cached.""" + real = os.path.realpath(path) + entry = _mmap_cache._entries.get(real) + return None if entry is None else entry[3] + + +def test_enter_returns_self(tiff_path): + src = _FileSource(tiff_path) + try: + assert src.__enter__() is src + finally: + src.close() + + +def test_exit_releases_entry(tiff_path): + _mmap_cache.clear() + with _FileSource(tiff_path) as src: + assert _refcount(tiff_path) == 1 + # mmap is usable inside the block + assert len(src.read_all()) == src.size + # After the with block, refcount returns to 0 (entry stays cached). + assert _refcount(tiff_path) == 0 + + +def test_exit_releases_on_exception(tiff_path): + _mmap_cache.clear() + with pytest.raises(struct.error): + with _FileSource(tiff_path): + assert _refcount(tiff_path) == 1 + struct.unpack('>I', b'') + assert _refcount(tiff_path) == 0 + + +def test_double_close_safe(tiff_path): + with _FileSource(tiff_path) as src: + src.close() + # __exit__ will call close() again; must not raise or + # over-decrement the cache refcount. + assert _refcount(tiff_path) == 0 + assert _refcount(tiff_path) == 0 + + +def test_nested_with_shares_cache_entry(tiff_path): + _mmap_cache.clear() + with _FileSource(tiff_path): + assert _refcount(tiff_path) == 1 + with _FileSource(tiff_path): + assert _refcount(tiff_path) == 2 + assert _refcount(tiff_path) == 1 + assert _refcount(tiff_path) == 0 diff --git a/xrspatial/geotiff/tests/test_polish_1488.py b/xrspatial/geotiff/tests/test_polish_1488.py index 39fce755..3a33d978 100644 --- a/xrspatial/geotiff/tests/test_polish_1488.py +++ b/xrspatial/geotiff/tests/test_polish_1488.py @@ -451,13 +451,10 @@ def test_auto_overview_capped(self, tmp_path): # Re-open and count IFDs (overviews + full-res). from xrspatial.geotiff._header import parse_all_ifds, parse_header from xrspatial.geotiff._reader import _FileSource - src = _FileSource(path) - try: + with _FileSource(path) as src: data = src.read_all() header = parse_header(data) ifds = parse_all_ifds(data, header) - finally: - src.close() # 1 full-res IFD + at most 8 overview IFDs. assert len(ifds) <= 1 + _MAX_OVERVIEW_LEVELS @@ -472,13 +469,10 @@ def test_explicit_overview_levels_not_capped(self, tmp_path): from xrspatial.geotiff._header import parse_all_ifds, parse_header from xrspatial.geotiff._reader import _FileSource - src = _FileSource(path) - try: + with _FileSource(path) as src: data = src.read_all() header = parse_header(data) ifds = parse_all_ifds(data, header) - finally: - src.close() # 10 explicit overviews + 1 full-res = 11 IFDs. assert len(ifds) == 11 diff --git a/xrspatial/geotiff/tests/write/test_overview.py b/xrspatial/geotiff/tests/write/test_overview.py index 29ca5f3d..2f676ae9 100644 --- a/xrspatial/geotiff/tests/write/test_overview.py +++ b/xrspatial/geotiff/tests/write/test_overview.py @@ -1501,13 +1501,10 @@ def test_cog_overview_block_order_rio_cogeo_2308(tmp_path, bands): def _ifd_dimensions(path): """Return (width, height) for every IFD in the file.""" - src = _FileSource(path) - try: + with _FileSource(path) as src: data = src.read_all() header = parse_header(data) ifds = parse_all_ifds(data, header) - finally: - src.close() return [(ifd.width, ifd.height) for ifd in ifds] @@ -1604,13 +1601,10 @@ def test_overview_pyramid_mean_values_are_correct(tmp_path): # The on-disk overview value should match the chained-halving result. # ``open_geotiff`` returns the full-resolution band; read the overview # IFD directly through the low-level path. - src = _FileSource(path) - try: + with _FileSource(path) as src: data = src.read_all() header = parse_header(data) ifds = parse_all_ifds(data, header) - finally: - src.close() # IFD 1 is the /4 overview (only one we requested). assert ifds[1].width == 16 and ifds[1].height == 16 # The byte-level check would require decoding tiles; the shape +