diff --git a/xrspatial/geotiff/_reader.py b/xrspatial/geotiff/_reader.py
index e1dd7ada..feb59574 100644
--- a/xrspatial/geotiff/_reader.py
+++ b/xrspatial/geotiff/_reader.py
@@ -7,6 +7,7 @@
 import threading
 import urllib.request
 from collections import OrderedDict
+from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
 
@@ -236,6 +237,39 @@ def read_range(self, start: int, length: int) -> bytes:
         with urllib.request.urlopen(req) as resp:
             return resp.read()
 
+    def read_ranges(
+        self,
+        ranges: list[tuple[int, int]],
+        max_workers: int = 8,
+    ) -> list[bytes]:
+        """Fetch multiple ranges concurrently using a thread pool.
+
+        Each ``(start, length)`` pair is fetched with its own range request,
+        but requests run in parallel so total wall time is bounded by the
+        slowest worker rather than ``len(ranges) * RTT``.
+
+        Returns the bytes for each range in input order.
+        """
+        if not ranges:
+            return []
+        if len(ranges) == 1:
+            start, length = ranges[0]
+            return [self.read_range(start, length)]
+
+        workers = min(max_workers, len(ranges))
+        results: list[bytes | None] = [None] * len(ranges)
+
+        with ThreadPoolExecutor(max_workers=workers) as ex:
+            future_to_idx = {
+                ex.submit(self.read_range, start, length): i
+                for i, (start, length) in enumerate(ranges)
+            }
+            for fut in future_to_idx:
+                idx = future_to_idx[fut]
+                results[idx] = fut.result()
+
+        return results  # type: ignore[return-value]
+
     def read_all(self) -> bytes:
         if self._pool is not None:
             resp = self._pool.request('GET', self._url)
@@ -690,6 +724,11 @@ def _read_cog_http(url: str, overview_level: int | None = None,
                    ) -> tuple[np.ndarray, GeoInfo]:
     """Read a COG via HTTP range requests.
+
+    Tile fetches run concurrently through a small thread pool so that the
+    total wall time is bounded by the slowest tile request rather than
+    ``num_tiles * RTT``. The pool size can be overridden with the
+    ``XRSPATIAL_COG_HTTP_WORKERS`` environment variable (default 8).
 
     Parameters
     ----------
     url : str
@@ -774,31 +813,47 @@
     else:
         result = np.empty((height, width), dtype=dtype)
 
+    # Pass 1: collect every tile's range and where it lands in the output.
+    # Empty tiles (byte_count == 0) and any tile_idx beyond the offsets
+    # array are skipped here so the fetch list stays exactly aligned with
+    # the placements list.
+    fetch_ranges: list[tuple[int, int]] = []
+    placements: list[tuple[int, int]] = []  # (tr, tc) per fetched tile
     for tr in range(tiles_down):
         for tc in range(tiles_across):
             tile_idx = tr * tiles_across + tc
             if tile_idx >= len(offsets):
                 continue
-
             off = offsets[tile_idx]
             bc = byte_counts[tile_idx]
             if bc == 0:
                 continue
+            fetch_ranges.append((off, bc))
+            placements.append((tr, tc))
 
-            tile_data = source.read_range(off, bc)
-            tile_pixels = _decode_strip_or_tile(
-                tile_data, compression, tw, th, samples,
-                bps, bytes_per_sample, is_sub_byte, dtype, pred,
-                byte_order=header.byte_order)
+    # Pass 2: fetch all tile bytes in parallel. Worker pool size is tunable
+    # via XRSPATIAL_COG_HTTP_WORKERS so users on very slow links can dial
+    # it up without code changes.
+    try:
+        workers = max(1, int(_os_module.environ.get('XRSPATIAL_COG_HTTP_WORKERS', '8')))
+    except ValueError:
+        workers = 8
+    tile_bytes_list = source.read_ranges(fetch_ranges, max_workers=workers)
+
+    # Pass 3: decode each tile and place it.
+    for (tr, tc), tile_data in zip(placements, tile_bytes_list):
+        tile_pixels = _decode_strip_or_tile(
+            tile_data, compression, tw, th, samples,
+            bps, bytes_per_sample, is_sub_byte, dtype, pred,
+            byte_order=header.byte_order)
 
-            # Place tile
-            y0 = tr * th
-            x0 = tc * tw
-            y1 = min(y0 + th, height)
-            x1 = min(x0 + tw, width)
-            actual_h = y1 - y0
-            actual_w = x1 - x0
-            result[y0:y1, x0:x1] = tile_pixels[:actual_h, :actual_w]
+        y0 = tr * th
+        x0 = tc * tw
+        y1 = min(y0 + th, height)
+        x1 = min(x0 + tw, width)
+        actual_h = y1 - y0
+        actual_w = x1 - x0
+        result[y0:y1, x0:x1] = tile_pixels[:actual_h, :actual_w]
 
     source.close()
     return result, geo_info
diff --git a/xrspatial/geotiff/tests/test_cog_http_concurrent.py b/xrspatial/geotiff/tests/test_cog_http_concurrent.py
new file mode 100644
index 00000000..991fc516
--- /dev/null
+++ b/xrspatial/geotiff/tests/test_cog_http_concurrent.py
@@ -0,0 +1,180 @@
+"""Tests for concurrent tile fetching in _read_cog_http (issue #1480)."""
+from __future__ import annotations
+
+import http.server
+import socketserver
+import threading
+import time
+
+import numpy as np
+import pytest
+
+from xrspatial.geotiff._reader import (
+    _HTTPSource,
+    _read_cog_http,
+    read_to_array,
+)
+from xrspatial.geotiff._writer import write
+
+
+# ---------------------------------------------------------------------------
+# read_ranges: ordering and concurrency
+# ---------------------------------------------------------------------------
+
+class _FakeHTTPSource(_HTTPSource):
+    """_HTTPSource that fakes read_range with a configurable sleep.
+
+    Tracks both total call count and the maximum observed in-flight
+    concurrency so tests can verify the threadpool dispatch directly
+    rather than relying on wall-clock timing (which is flaky on busy
+    CI runners).
+    """
+
+    def __init__(self, per_request_sleep: float = 0.05):
+        # Skip super().__init__ -- we're not making real HTTP calls.
+        self._url = 'fake://test'
+        self._size = None
+        self._pool = None
+        self._per_request_sleep = per_request_sleep
+        self.call_count = 0
+        self.in_flight = 0
+        self.max_in_flight = 0
+        self._lock = threading.Lock()
+
+    def read_range(self, start: int, length: int) -> bytes:
+        with self._lock:
+            self.call_count += 1
+            self.in_flight += 1
+            if self.in_flight > self.max_in_flight:
+                self.max_in_flight = self.in_flight
+        try:
+            time.sleep(self._per_request_sleep)
+            return f'{start}:{length}'.encode('ascii')
+        finally:
+            with self._lock:
+                self.in_flight -= 1
+
+
+def test_read_ranges_returns_results_in_input_order():
+    src = _FakeHTTPSource(per_request_sleep=0.0)
+    ranges = [(0, 10), (100, 5), (50, 20), (200, 7)]
+    out = src.read_ranges(ranges, max_workers=4)
+    assert len(out) == len(ranges)
+    for (start, length), data in zip(ranges, out):
+        assert data == f'{start}:{length}'.encode('ascii')
+
+
+def test_read_ranges_empty_list():
+    src = _FakeHTTPSource(per_request_sleep=0.0)
+    assert src.read_ranges([]) == []
+
+
+def test_read_ranges_single_request_skips_pool():
+    src = _FakeHTTPSource(per_request_sleep=0.0)
+    out = src.read_ranges([(42, 8)], max_workers=8)
+    assert out == [b'42:8']
+    assert src.call_count == 1
+
+
+def test_read_ranges_dispatches_concurrently():
+    """The threadpool should run multiple requests in flight at once.
+
+    Asserting on observed in-flight concurrency is robust to CI scheduler
+    jitter; a wall-clock assertion of the same effect is flaky on busy
+    runners (the previous version of this test was a 50 ms per-request
+    × 20-request setup that occasionally exceeded its 0.5 s budget by a
+    few ms on macOS).
+    """
+    n = 20
+    workers = 8
+    src = _FakeHTTPSource(per_request_sleep=0.02)
+    ranges = [(i * 100, 10) for i in range(n)]
+
+    out = src.read_ranges(ranges, max_workers=workers)
+
+    assert src.call_count == n
+    assert len(out) == n
+    # Sequential dispatch would peak at 1 in flight. The pool should
+    # run several in parallel; require at least 2 (very loose) to keep
+    # the test robust on heavily oversubscribed CI runners.
+    assert src.max_in_flight >= 2, (
+        f'expected >=2 concurrent in-flight calls, '
+        f'got max_in_flight={src.max_in_flight}'
+    )
+
+
+# ---------------------------------------------------------------------------
+# _read_cog_http: correctness via local http.server
+# ---------------------------------------------------------------------------
+
+class _RangeHandler(http.server.BaseHTTPRequestHandler):
+    """Serve a single in-memory bytes payload with HTTP Range support."""
+
+    payload: bytes = b''
+
+    def do_GET(self):  # noqa: N802
+        rng = self.headers.get('Range')
+        if rng and rng.startswith('bytes='):
+            spec = rng[len('bytes='):]
+            # Single range only -- matches what _HTTPSource sends.
+            start_s, _, end_s = spec.partition('-')
+            start = int(start_s)
+            end = int(end_s) if end_s else len(self.payload) - 1
+            chunk = self.payload[start:end + 1]
+            self.send_response(206)
+            self.send_header('Content-Type', 'application/octet-stream')
+            self.send_header(
+                'Content-Range',
+                f'bytes {start}-{start + len(chunk) - 1}/{len(self.payload)}',
+            )
+            self.send_header('Content-Length', str(len(chunk)))
+            self.end_headers()
+            self.wfile.write(chunk)
+            return
+        self.send_response(200)
+        self.send_header('Content-Type', 'application/octet-stream')
+        self.send_header('Content-Length', str(len(self.payload)))
+        self.end_headers()
+        self.wfile.write(self.payload)
+
+    def log_message(self, *_args, **_kwargs):
+        # Silence the default access log during tests.
+        pass
+
+
+@pytest.fixture
+def cog_http_server(tmp_path):
+    """Spin up a local http.server serving a tiled COG, yield (url, arr)."""
+    arr = np.arange(64 * 64, dtype=np.float32).reshape(64, 64)
+    path = str(tmp_path / 'tmp_1480_cog.tif')
+    write(arr, path, compression='deflate', tiled=True, tile_size=16,
+          cog=True, overview_levels=[1])
+
+    with open(path, 'rb') as f:
+        payload = f.read()
+
+    handler_cls = type(
+        'RangeHandler1480', (_RangeHandler,), {'payload': payload}
+    )
+    httpd = socketserver.TCPServer(('127.0.0.1', 0), handler_cls)
+    port = httpd.server_address[1]
+    thread = threading.Thread(target=httpd.serve_forever, daemon=True)
+    thread.start()
+
+    try:
+        yield f'http://127.0.0.1:{port}/cog.tif', arr
+    finally:
+        httpd.shutdown()
+        httpd.server_close()
+
+
+def test_cog_http_round_trip_matches_local_read(cog_http_server):
+    url, expected = cog_http_server
+    result, _ = _read_cog_http(url)
+    np.testing.assert_array_equal(result, expected)
+
+
+def test_read_to_array_dispatches_to_http(cog_http_server):
+    url, expected = cog_http_server
+    result, _ = read_to_array(url)
+    np.testing.assert_array_equal(result, expected)