Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@
_GPU_DEPRECATED_SENTINEL = object()
_ON_GPU_FAILURE_SENTINEL = object()

# Names of dims that ``to_geotiff`` / ``write_geotiff_gpu`` treat as the
# non-spatial band axis. Used both to remap ``(band, y, x)`` inputs to
# ``(y, x, band)`` before writing and to skip the band axis when inferring
# a GeoTransform from coords (see ``_coords_to_transform`` and issue #1643).
# Membership is an exact string comparison against these three names, so
# differently spelled or differently cased dim names are not recognised.
_BAND_DIM_NAMES = ('band', 'bands', 'channel')


def _wkt_to_epsg(wkt_or_proj: str) -> int | None:
"""Try to extract an EPSG code from a WKT or PROJ string.
Expand Down Expand Up @@ -191,9 +197,34 @@ def _coords_to_transform(da: xr.DataArray) -> GeoTransform | None:
on raster_type:
- PixelIsArea (default): origin = center - half_pixel (edge of pixel 0)
- PixelIsPoint: origin = center (center of pixel 0)

For 3D arrays the spatial dims are the two non-band dims. The helper
filters out any dim named ``band`` / ``bands`` / ``channel`` (see
``_BAND_DIM_NAMES``) regardless of position, so a ``(y, x, band)``,
``(band, y, x)``, or ``(y, band, x)`` DataArray returns the y/x
transform rather than picking up the band axis spacing as a pixel
size. ``to_geotiff`` itself remaps ``(band, y, x)`` arrays to
``(y, x, band)`` before writing pixel bytes, but it calls
:func:`_coords_to_transform` against the original DataArray, so the
helper must handle both layouts to keep the geo-transform consistent
with the file's coord arrays. See issue #1643.
"""
ydim = da.dims[-2]
xdim = da.dims[-1]
if da.ndim == 3:
# Drop the band-like dim and keep the two spatial dims in their
# original (y, x) order. Position-based fallback covers the case
# where none of the dims are named like a band axis.
spatial = tuple(d for d in da.dims if d not in _BAND_DIM_NAMES)
if len(spatial) == 2:
ydim, xdim = spatial[0], spatial[1]
else:
# No identifiable band dim; fall back to dims[-2:] so the
# original 2-D-style behaviour applies. This branch only
# triggers for unusual 3D layouts callers built by hand.
ydim = da.dims[-2]
xdim = da.dims[-1]
else:
ydim = da.dims[-2]
xdim = da.dims[-1]

if xdim not in da.coords or ydim not in da.coords:
return None
Expand Down Expand Up @@ -1166,7 +1197,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path, *,
if hasattr(raw, 'dask') and not cog and not _path_is_file_like:
dask_arr = raw
# Handle band-first dimension order (band, y, x) -> (y, x, band)
if raw.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
if raw.ndim == 3 and data.dims[0] in _BAND_DIM_NAMES:
import dask.array as da
dask_arr = da.moveaxis(raw, 0, -1)
if dask_arr.ndim not in (2, 3):
Expand Down Expand Up @@ -1215,7 +1246,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path, *,
else:
arr = np.asarray(raw)
# Handle band-first dimension order (band, y, x) -> (y, x, band)
if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
if arr.ndim == 3 and data.dims[0] in _BAND_DIM_NAMES:
arr = np.moveaxis(arr, 0, -1)
else:
if hasattr(data, 'get'):
Expand Down Expand Up @@ -2830,7 +2861,7 @@ def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray,
# this remap the writer treats arr.shape[2] as the band axis and
# produces a transposed file (issue #1580). The CPU writer does
# the same remap at the matching step in to_geotiff().
if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
if arr.ndim == 3 and data.dims[0] in _BAND_DIM_NAMES:
arr = cupy.ascontiguousarray(cupy.moveaxis(arr, 0, -1))

# Prefer attrs['transform'] over the coord-derived transform: it
Expand Down
213 changes: 213 additions & 0 deletions xrspatial/geotiff/tests/test_coords_to_transform_3d_1643.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Regression test for issue #1643.

``_coords_to_transform`` previously used ``dims[-2]`` and ``dims[-1]`` to
look up y/x coords. On a 3D ``(y, x, band)`` DataArray that picked
``x`` and ``band``, so ``to_geotiff`` silently wrote a wrong
GeoTransform when ``attrs['transform']`` was absent. The helper now
filters out the band-like dim by name and uses the two remaining spatial
dims regardless of the band dim's position.
"""
from __future__ import annotations

import importlib.util

import numpy as np
import pytest
import xarray as xr

from xrspatial.geotiff import _coords_to_transform, open_geotiff, to_geotiff


def _gpu_available() -> bool:
    """Return True when cupy imports and ``cupy.cuda.is_available()`` is truthy.

    Never raises: any failure while probing is reported as ``False`` so
    this module can still be imported and collected on CPU-only machines.
    """
    try:
        # ``find_spec`` lives inside the ``try`` on purpose: it raises
        # ValueError when ``sys.modules['cupy']`` is a stub whose
        # ``__spec__`` is None or unset (e.g. a mocked GPU library).
        if importlib.util.find_spec("cupy") is None:
            return False
        import cupy
        return bool(cupy.cuda.is_available())
    except Exception:
        # Deliberately broad: a broken install or missing driver can fail
        # in many ways, and every one of them means "skip the GPU tests".
        return False


# Probed once at import time; the GPU round-trip test below uses it as
# its ``skipif`` condition.
_HAS_GPU = _gpu_available()


def _make_geo_da_3d(dims):
    """Build a small uint8 DataArray with georeferenced y/x coords.

    Parameters
    ----------
    dims : sequence of str
        Dimension names in the desired order. ``'y'`` gets length 10 with
        coords spanning 100..200 and ``'x'`` gets length 20 with coords
        spanning 500..700. Every other name is treated as a band-like
        axis of length 3 with integer coords ``0, 1, 2``.

    Returns
    -------
    xr.DataArray
        Array whose dims follow ``dims`` and whose coords are keyed by
        those same dim names.
    """
    # (first coord, last coord, length) for each spatial axis; keeping the
    # length next to the coord range stops shape and coords drifting apart.
    spatial = {
        'y': (100.0, 200.0, 10),
        'x': (500.0, 700.0, 20),
    }
    n_bands = 3

    shape = tuple(spatial[d][2] if d in spatial else n_bands for d in dims)
    arr = np.arange(int(np.prod(shape)), dtype=np.uint8).reshape(shape)

    coords = {
        name: np.linspace(first, last, size)
        for name, (first, last, size) in spatial.items()
        if name in dims
    }
    # Key the band coord by the dim name the caller actually passed: a
    # coord whose name is not one of the array's dims is rejected by
    # xarray, so a hard-coded 'band' key breaks 'bands' / 'channel'.
    coords.update((d, np.arange(n_bands)) for d in dims if d not in spatial)
    return xr.DataArray(arr, dims=list(dims), coords=coords)


def test_coords_to_transform_yxband_returns_yx_spacing():
    """Band-last ``(y, x, band)`` input yields the y/x pixel size."""
    transform = _coords_to_transform(_make_geo_da_3d(('y', 'x', 'band')))
    assert transform is not None

    # 20 x samples over 500..700 and 10 y samples over 100..200.
    expected_width = (700.0 - 500.0) / 19
    expected_height = (200.0 - 100.0) / 9
    np.testing.assert_allclose(transform.pixel_width, expected_width)
    np.testing.assert_allclose(transform.pixel_height, expected_height)


def test_coords_to_transform_bandyx_returns_yx_spacing():
    """Band-first ``(band, y, x)`` input yields the same y/x pixel size."""
    transform = _coords_to_transform(_make_geo_da_3d(('band', 'y', 'x')))
    assert transform is not None

    # 20 x samples over 500..700 and 10 y samples over 100..200.
    expected_width = (700.0 - 500.0) / 19
    expected_height = (200.0 - 100.0) / 9
    np.testing.assert_allclose(transform.pixel_width, expected_width)
    np.testing.assert_allclose(transform.pixel_height, expected_height)

Comment on lines +53 to +70

@pytest.mark.parametrize('band_name', ['band', 'bands', 'channel'])
def test_coords_to_transform_3d_band_name_variants(band_name):
    """Each accepted band-dim spelling is ignored when choosing y/x dims."""
    y_coords = np.linspace(100.0, 200.0, 10)
    x_coords = np.linspace(500.0, 700.0, 20)
    band_coords = np.arange(3)

    cube = xr.DataArray(
        np.zeros((10, 20, 3), dtype=np.uint8),
        dims=['y', 'x', band_name],
        coords={'y': y_coords, 'x': x_coords, band_name: band_coords},
    )

    transform = _coords_to_transform(cube)
    assert transform is not None
    expected_width = (700.0 - 500.0) / 19
    expected_height = (200.0 - 100.0) / 9
    np.testing.assert_allclose(transform.pixel_width, expected_width)
    np.testing.assert_allclose(transform.pixel_height, expected_height)


def test_coords_to_transform_2d_unchanged():
    """A plain 2D ``(y, x)`` array still produces the y/x pixel size."""
    y_coords = np.linspace(100.0, 200.0, 10)
    x_coords = np.linspace(500.0, 700.0, 20)
    plane = xr.DataArray(
        np.zeros((10, 20), dtype=np.uint8),
        dims=['y', 'x'],
        coords={'y': y_coords, 'x': x_coords},
    )

    transform = _coords_to_transform(plane)
    assert transform is not None
    expected_width = (700.0 - 500.0) / 19
    expected_height = (200.0 - 100.0) / 9
    np.testing.assert_allclose(transform.pixel_width, expected_width)
    np.testing.assert_allclose(transform.pixel_height, expected_height)


def test_to_geotiff_roundtrip_3d_yxband_no_transform_attr(tmp_path):
    """A band-last 3D write/read round-trip matches the 2D baseline.

    Writes a 2D ``(y, x)`` array and a 3D ``(y, x, band)`` array that share
    the same y/x coords, neither carrying ``attrs['transform']``, then
    checks the files read back with identical y/x coords and transform.
    """
    reference = xr.DataArray(
        np.zeros((10, 20), dtype=np.uint8),
        dims=['y', 'x'],
        coords={
            'y': np.linspace(100.0, 200.0, 10),
            'x': np.linspace(500.0, 700.0, 20),
        },
    )
    banded = _make_geo_da_3d(('y', 'x', 'band'))

    path_2d = str(tmp_path / 'roundtrip_1643_2d.tif')
    path_3d = str(tmp_path / 'roundtrip_1643_3d.tif')
    to_geotiff(reference, path_2d)
    to_geotiff(banded, path_3d)

    read_2d = open_geotiff(path_2d)
    read_3d = open_geotiff(path_3d)
    np.testing.assert_allclose(read_3d.y.values, read_2d.y.values)
    np.testing.assert_allclose(read_3d.x.values, read_2d.x.values)
    assert read_3d.attrs.get('transform') == read_2d.attrs.get('transform')


def test_to_geotiff_roundtrip_3d_bandyx_no_transform_attr(tmp_path):
    """(band, y, x) input round-trips with the correct transform.

    Writes a 2D ``(y, x)`` array and a band-first 3D array that share the
    same y/x coords, neither carrying ``attrs['transform']``, then checks
    that both files read back with the same y/x coords and the same
    ``attrs['transform']`` -- the same checks the band-last round-trip
    test makes.
    """
    da_3d = _make_geo_da_3d(('band', 'y', 'x'))
    da_2d = xr.DataArray(
        np.zeros((10, 20), dtype=np.uint8),
        dims=['y', 'x'],
        coords={
            'y': np.linspace(100.0, 200.0, 10),
            'x': np.linspace(500.0, 700.0, 20),
        },
    )

    p2 = str(tmp_path / 'roundtrip_1643_2d_b.tif')
    p3 = str(tmp_path / 'roundtrip_1643_3d_bandfirst.tif')
    to_geotiff(da_2d, p2)
    to_geotiff(da_3d, p3)

    rt2 = open_geotiff(p2)
    rt3 = open_geotiff(p3)
    np.testing.assert_allclose(rt3.y.values, rt2.y.values)
    np.testing.assert_allclose(rt3.x.values, rt2.x.values)
    # The docstring promises a correct transform, so compare it directly
    # instead of relying only on the coords derived from it.
    assert rt3.attrs.get('transform') == rt2.attrs.get('transform')


def test_to_geotiff_3d_without_transform_attr_does_not_invent_unit_pixels(
        tmp_path):
    """Regression sanity: the bad transform was pixel_width=1.0 (band
    axis spacing). Assert the round-tripped pixel_width is finite,
    non-unit, and matches the source x spacing.
    """
    da = _make_geo_da_3d(('y', 'x', 'band'))
    p = str(tmp_path / 'roundtrip_1643_3d_not_unit.tif')
    to_geotiff(da, p)
    rt = open_geotiff(p)
    pw = abs(float(rt.x.values[1] - rt.x.values[0]))

    assert np.isfinite(pw), f"round-tripped pixel_width={pw} is not finite"
    # Source x spacing is (700-500)/19 = ~10.526. The buggy path would
    # have produced pw=1.0 (the band axis spacing).
    assert pw > 1.5, (
        f"round-tripped pixel_width={pw} suggests the band-axis spacing "
        f"leaked into the GeoTransform; expected ~10.526")
    # The threshold above only rules out the known bad value; this pins
    # the width to the source spacing as the docstring states.
    np.testing.assert_allclose(pw, (700.0 - 500.0) / 19)


@pytest.mark.skipif(not _HAS_GPU, reason="cupy + CUDA required")
def test_write_geotiff_gpu_roundtrip_3d_no_transform_attr(tmp_path):
    """GPU-writer counterpart of the 3D round-trip regression check.

    Writes a cupy-backed ``(y, x, band)`` DataArray without
    ``attrs['transform']`` through ``write_geotiff_gpu`` and checks that
    the pixel width and height read back are both above 1.5, i.e. not
    the unit spacing of the band axis.
    """
    import cupy as cp

    from xrspatial.geotiff import write_geotiff_gpu

    host_pixels = np.arange(10 * 20 * 3, dtype=np.uint8).reshape(10, 20, 3)
    y_coords = np.linspace(100.0, 200.0, 10)
    x_coords = np.linspace(500.0, 700.0, 20)
    gpu_da = xr.DataArray(
        cp.asarray(host_pixels),
        dims=['y', 'x', 'band'],
        coords={'y': y_coords, 'x': x_coords, 'band': np.arange(3)},
    )

    out_path = str(tmp_path / 'roundtrip_1643_3d_gpu.tif')
    write_geotiff_gpu(gpu_da, out_path)
    result = open_geotiff(out_path)

    x_vals = result.x.values
    y_vals = result.y.values
    pw = abs(float(x_vals[1] - x_vals[0]))
    ph = abs(float(y_vals[1] - y_vals[0]))
    assert pw > 1.5, (
        f"GPU writer round-tripped pixel_width={pw}; expected ~10.526")
    assert ph > 1.5, (
        f"GPU writer round-tripped pixel_height={ph}; expected ~11.111")
Loading