Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions xrspatial/reproject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,10 +675,16 @@ def reproject(
)

ydim, xdim = _find_spatial_dims(raster)
out_attrs = {
'crs': tgt_wkt,
'nodata': nd,
}
# Carry input attrs forward so units, long_name, scale_factor, etc.
# survive the transform. Pop attrs that are stale after reprojection:
# the affine `transform` and grid `res` describe the old grid, and
# `crs_wkt` would duplicate (or contradict) the canonical `crs` we re-emit.
out_attrs = {**raster.attrs}
out_attrs.pop('transform', None)
out_attrs.pop('crs_wkt', None)
out_attrs.pop('res', None)
out_attrs['crs'] = tgt_wkt
out_attrs['nodata'] = nd
if tgt_vertical_crs is not None:
out_attrs['vertical_crs'] = tgt_vertical_crs

Expand Down Expand Up @@ -1485,15 +1491,22 @@ def merge(
ydim = rasters[0].dims[-2]
xdim = rasters[0].dims[-1]

# Carry the first raster's attrs forward (matches the default
# strategy='first'). Drop attrs describing the old grid: `transform`,
# `res`, and the duplicate `crs_wkt` are no longer accurate.
out_attrs = {**rasters[0].attrs}
out_attrs.pop('transform', None)
out_attrs.pop('crs_wkt', None)
out_attrs.pop('res', None)
out_attrs['crs'] = tgt_wkt
out_attrs['nodata'] = nd

result = xr.DataArray(
result_data,
dims=[ydim, xdim],
coords={ydim: y_coords, xdim: x_coords},
name=name or 'merged',
attrs={
'crs': tgt_wkt,
'nodata': nd,
},
name=name or rasters[0].name or 'merged',
attrs=out_attrs,
)
return result

Expand Down
114 changes: 114 additions & 0 deletions xrspatial/tests/test_reproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -1756,6 +1756,120 @@ def test_detect_nodata_accepts_finite(self):
assert _detect_nodata(r, nodata=-9999) == -9999.0


class TestMetadataPreservation:
"""reproject() and merge() must carry input attrs forward."""

@staticmethod
def _raster_with_attrs(extra_attrs=None, h=8, w=8,
crs='EPSG:4326',
x_range=(-1, 1), y_range=(-1, 1),
name='dem'):
data = np.ones((h, w), dtype=np.float64)
attrs = {'crs': crs, 'nodata': np.nan}
if extra_attrs:
attrs.update(extra_attrs)
y = np.linspace(y_range[1], y_range[0], h)
x = np.linspace(x_range[0], x_range[1], w)
return xr.DataArray(
data, dims=['y', 'x'],
coords={'y': y, 'x': x},
name=name, attrs=attrs,
)

# reproject() ----------------------------------------------------------

def test_reproject_preserves_units_attr(self):
from xrspatial.reproject import reproject
raster = self._raster_with_attrs({'units': 'meters'})
result = reproject(raster, 'EPSG:4326', resolution=0.25)
assert result.attrs.get('units') == 'meters'

def test_reproject_preserves_scale_offset(self):
from xrspatial.reproject import reproject
raster = self._raster_with_attrs(
{'scale_factor': 0.1, 'add_offset': 10.0}
)
result = reproject(raster, 'EPSG:4326', resolution=0.25)
assert result.attrs.get('scale_factor') == 0.1
assert result.attrs.get('add_offset') == 10.0

def test_reproject_preserves_long_name(self):
from xrspatial.reproject import reproject
raster = self._raster_with_attrs({'long_name': 'elevation'})
result = reproject(raster, 'EPSG:4326', resolution=0.25)
assert result.attrs.get('long_name') == 'elevation'

def test_reproject_drops_stale_transform(self):
from xrspatial.reproject import reproject
raster = self._raster_with_attrs(
{'transform': (1.0, 0.0, 0.0, 0.0, -1.0, 0.0)}
)
result = reproject(raster, 'EPSG:3857')
assert 'transform' not in result.attrs

def test_reproject_drops_stale_res(self):
from xrspatial.reproject import reproject
raster = self._raster_with_attrs({'res': (1.0, 1.0)})
result = reproject(raster, 'EPSG:3857')
assert 'res' not in result.attrs

def test_reproject_overrides_crs(self):
from xrspatial.reproject import reproject
raster = self._raster_with_attrs(crs='EPSG:4326')
result = reproject(raster, 'EPSG:3857')
# Output crs is the new target CRS WKT, not the input EPSG:4326
assert 'crs' in result.attrs
out_crs = result.attrs['crs']
assert out_crs != 'EPSG:4326'
# WKT for 3857 mentions Mercator / pseudo-mercator
assert 'Mercator' in out_crs or '3857' in out_crs

def test_reproject_drops_stale_crs_wkt(self):
from xrspatial.reproject import reproject
raster = self._raster_with_attrs({'crs_wkt': 'OLD_DUPLICATE_WKT'})
result = reproject(raster, 'EPSG:3857')
assert 'crs_wkt' not in result.attrs

# merge() --------------------------------------------------------------

def test_merge_preserves_first_raster_attrs(self):
from xrspatial.reproject import merge
a = self._raster_with_attrs(
{'units': 'm', 'long_name': 'elev'},
x_range=(-5, 0), y_range=(-5, 5), name='dem_a',
)
b = self._raster_with_attrs(
{'units': 'feet'},
x_range=(0, 5), y_range=(-5, 5), name='dem_b',
)
result = merge([a, b], resolution=1.0)
assert result.attrs.get('units') == 'm'
assert result.attrs.get('long_name') == 'elev'

def test_merge_drops_stale_transform(self):
from xrspatial.reproject import merge
a = self._raster_with_attrs(
{'transform': (1.0, 0.0, 0.0, 0.0, -1.0, 0.0)},
x_range=(-5, 0), y_range=(-5, 5),
)
b = self._raster_with_attrs(
x_range=(0, 5), y_range=(-5, 5),
)
result = merge([a, b], resolution=1.0)
assert 'transform' not in result.attrs

def test_merge_name_falls_back_to_first_raster(self):
from xrspatial.reproject import merge
a = self._raster_with_attrs(
x_range=(-5, 0), y_range=(-5, 5), name='dem_a',
)
b = self._raster_with_attrs(
x_range=(0, 5), y_range=(-5, 5), name='dem_b',
)
result = merge([a, b], resolution=1.0)
assert result.name == 'dem_a'


# ---------------------------------------------------------------------------
# Backend parity: dask dtype + same-CRS dask merge + cupy
# ---------------------------------------------------------------------------
Expand Down