Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions xrspatial/geotiff/_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,15 @@ def _fp_predictor_encode_row(row_data, width, bps):
row_data[i] = np.uint8((np.int32(row_data[i]) - np.int32(row_data[i - 1])) & 0xFF)


@ngjit
def _fp_predictor_encode_rows(data, width, height, bps):
    """Apply ``_fp_predictor_encode_row`` to each row of a flat buffer.

    ``data`` is treated as ``height`` back-to-back rows, each spanning
    ``width * bps`` entries. Every row is handed to the per-row kernel as
    a slice of ``data``. Keeping this loop inside a jitted function means
    the per-row dispatch does not pay Python-level loop overhead.

    Parameters
    ----------
    data : array
        Flat buffer holding all rows consecutively.
    width : int
        Samples per row.
    height : int
        Number of rows to process.
    bps : int
        Bytes per sample; forwarded unchanged to the row kernel.
    """
    stride = width * bps
    offset = 0
    for _ in range(height):
        # Slice out exactly one row; the row kernel operates on that span.
        _fp_predictor_encode_row(data[offset:offset + stride], width, bps)
        offset += stride


def fp_predictor_encode(data: np.ndarray, width: int, height: int,
bytes_per_sample: int) -> np.ndarray:
"""Apply floating-point predictor (predictor=3).
Expand All @@ -715,10 +724,7 @@ def fp_predictor_encode(data: np.ndarray, width: int, height: int,
Encoded array.
"""
buf = np.ascontiguousarray(data)
row_len = width * bytes_per_sample
for row in range(height):
start = row * row_len
_fp_predictor_encode_row(buf[start:start + row_len], width, bytes_per_sample)
_fp_predictor_encode_rows(buf, width, height, bytes_per_sample)
return buf


Expand Down
70 changes: 70 additions & 0 deletions xrspatial/geotiff/tests/test_predictor_fp_write_1313.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""
from __future__ import annotations

import os
import struct

import numpy as np
Expand Down Expand Up @@ -201,3 +202,72 @@ def test_predictor3_multiband_round_trip(tmp_path):
else:
out_arr = out.values
np.testing.assert_array_equal(out_arr, arr)


def test_predictor3_large_round_trip_value_exact(tmp_path):
    """Round-trip a 1024x1024 float32 raster via deflate + predictor=3.

    The predictor=3 encode path drives its per-row kernel from an
    ``@ngjit`` wrapper rather than a Python ``for`` loop. This test
    writes a large smooth array, confirms the file's predictor tag is 3,
    reads the file back, and requires that dtype, shape, and the raw
    bytes of the result all match the input. Comparing raw bytes (not
    numeric values) means signed-zero flips, NaN payload changes, and
    any other bit-level divergence fail the test.
    """
    shape = (1024, 1024)
    arr = _smooth_float(shape, np.float32)
    tif_path = str(tmp_path / 'pred3_large_round_trip_1313.tif')
    to_geotiff(_da(arr), tif_path, compression='deflate', predictor=3)

    assert _read_predictor_tag(tif_path) == 3

    # Force a contiguous array before inspecting the bytes read back.
    out_arr = np.ascontiguousarray(open_geotiff(tif_path).values)
    assert out_arr.dtype == arr.dtype, (
        f"dtype drift: in={arr.dtype}, out={out_arr.dtype}"
    )
    assert out_arr.shape == arr.shape

    expected_bytes = arr.tobytes()
    actual_bytes = out_arr.tobytes()
    assert actual_bytes == expected_bytes, (
        "predictor=3 round-trip diverged at the bit level "
        "(signed zero, NaN payload, or actual corruption)"
    )


def test_predictor3_encode_within_2x_of_predictor2(tmp_path):
    """Opt-in perf guard: a predictor=3 write takes < 2x a predictor=2 write.

    With the row loop in Python, predictor=3 was roughly 2.5x slower
    than predictor=2; this check watches for that regression returning.
    It only runs when ``XRSPATIAL_RUN_PERF_TESTS=1`` is set, because
    wall-clock comparisons are unreliable on shared CI runners, throttled
    CPUs, debug builds, and noisy filesystems. The opt-in convention
    follows ``test_streaming_write_parallel.py``.
    """
    perf_enabled = os.environ.get('XRSPATIAL_RUN_PERF_TESTS') == '1'
    if not perf_enabled:
        pytest.skip(
            "set XRSPATIAL_RUN_PERF_TESTS=1 to run wall-clock perf tests")

    import time

    da = _da(_smooth_float((1024, 1024), np.float32))
    p2 = tmp_path / 'pred2_timing.tif'
    p3 = tmp_path / 'pred3_timing.tif'

    def _write(target, predictor):
        # One deflate-compressed write using the given predictor.
        to_geotiff(da, str(target), compression='deflate',
                   predictor=predictor)

    def _timed_write(target, predictor):
        # Wall-clock seconds for a single write.
        began = time.perf_counter()
        _write(target, predictor)
        return time.perf_counter() - began

    # Untimed first pass so numba compilation does not skew the timings.
    _write(p2, 2)
    _write(p3, 3)

    t_p2 = _timed_write(p2, 2)
    t_p3 = _timed_write(p3, 3)

    assert t_p3 < 2.0 * t_p2, (
        f'predictor=3 ({t_p3*1000:.1f} ms) is more than 2x slower than '
        f'predictor=2 ({t_p2*1000:.1f} ms); ngjit row loop may have regressed'
    )
Loading