From def3f693805a5e5397de134cf4042b9d17b5e951 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Mon, 4 May 2026 12:49:51 -0700 Subject: [PATCH] Refresh transform and nodata attrs on resample output (#1465) At the end of resample(), refresh three attrs that the prior code carried forward unchanged from the input: - transform: re-emit the rasterio 6-tuple (res_x, 0.0, left, 0.0, -res_y, top) from the new edge coordinates. _new_coords now returns the edge values it already computed so the caller can use them. Gated on 'transform' in agg.attrs. - _FillValue: resample outputs float32 with NaN as the missing-data sentinel; replace the input value with NaN so metadata matches the data. Gated on '_FillValue' in agg.attrs. - nodatavals: same treatment for the rasterio plural form. Other attrs (crs, crs_wkt, scales, offsets, units, long_name) round trip unchanged because resample does not touch CRS, dtype scaling, or units. Six new tests in TestMetadataPropagation cover refresh, absence, and preservation of unrelated attrs. 
--- xrspatial/resample.py | 33 +++++++++- xrspatial/tests/test_resample.py | 106 +++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 3 deletions(-) diff --git a/xrspatial/resample.py b/xrspatial/resample.py index 5eb93978..2d421f48 100644 --- a/xrspatial/resample.py +++ b/xrspatial/resample.py @@ -864,14 +864,41 @@ def _new_coords(vals, n_out): edge_start = vals[0] - half_first edge_end = vals[-1] + half_last px = (edge_end - edge_start) / n_out - return np.linspace(edge_start + px / 2, edge_end - px / 2, n_out), px + coords = np.linspace(edge_start + px / 2, edge_end - px / 2, n_out) + return coords, px, edge_start, edge_end - new_y, py = _new_coords(y_vals, out_h) - new_x, px = _new_coords(x_vals, out_w) + new_y, py, y_edge_start, y_edge_end = _new_coords(y_vals, out_h) + new_x, px, x_edge_start, x_edge_end = _new_coords(x_vals, out_w) new_attrs = dict(agg.attrs) new_attrs['res'] = (abs(px), abs(py)) + # Refresh `transform` if the input had one. The rasterio 6-tuple is + # (res_x, 0.0, left, 0.0, -res_y, top). `top` is the maximum y edge + # of the grid, which is `y_edge_start` when y is descending and + # `y_edge_end` when y is ascending. `left` is the minimum x edge, + # which is `x_edge_start` when x is ascending and `x_edge_end` when + # x is descending. + if 'transform' in agg.attrs: + out_res_x = abs(px) + out_res_y = abs(py) + top = y_edge_start if y_vals[0] > y_vals[-1] else y_edge_end + left = x_edge_start if x_vals[0] < x_vals[-1] else x_edge_end + new_attrs['transform'] = ( + out_res_x, 0.0, left, 0.0, -out_res_y, top, + ) + + # Resample currently emits float32 with NaN as the missing-data + # sentinel regardless of input dtype. If the input declared a + # different sentinel via `_FillValue` or `nodatavals`, replace the + # value with NaN so the metadata matches the actual data. Leave the + # keys absent when the input did not have them. 
+ if '_FillValue' in agg.attrs: + new_attrs['_FillValue'] = float('nan') + if 'nodatavals' in agg.attrs: + old = agg.attrs['nodatavals'] + new_attrs['nodatavals'] = tuple(float('nan') for _ in old) + result = xr.DataArray( result_data, name=name, diff --git a/xrspatial/tests/test_resample.py b/xrspatial/tests/test_resample.py index d77e3061..6469295e 100644 --- a/xrspatial/tests/test_resample.py +++ b/xrspatial/tests/test_resample.py @@ -120,6 +120,112 @@ def test_res_attribute_updated(self, grid_8x8): assert pytest.approx(out.attrs['res'][1], abs=0.01) == 2.0 +# --------------------------------------------------------------------------- +# Metadata propagation +# --------------------------------------------------------------------------- + +class TestMetadataPropagation: + """resample() should refresh stale grid attrs and nodata sentinels.""" + + @staticmethod + def _raster_with_transform(transform): + """4x4 raster whose coords match the given rasterio 6-tuple.""" + res_x, _, left, _, neg_res_y, top = transform + res_y = -neg_res_y + x = np.array([left + (i + 0.5) * res_x for i in range(4)]) + y = np.array([top - (i + 0.5) * res_y for i in range(4)]) + data = np.arange(16, dtype=np.float32).reshape(4, 4) + return xr.DataArray( + data, + dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'transform': transform, 'res': (res_x, res_y)}, + ) + + def test_transform_refreshed_on_downsample(self): + raster = self._raster_with_transform( + (1.0, 0.0, 100.0, 0.0, -1.0, 200.0) + ) + out = resample(raster, scale_factor=0.5) + assert tuple(out.attrs['transform']) == ( + 2.0, 0.0, 100.0, 0.0, -2.0, 200.0, + ) + + def test_transform_absent_stays_absent(self): + # Raster is built inline and carries no transform attr. 
+ data = np.arange(16, dtype=np.float32).reshape(4, 4) + raster = xr.DataArray( + data, + dims=['y', 'x'], + coords={'y': np.linspace(3, 0, 4), 'x': np.linspace(0, 3, 4)}, + attrs={'res': (1.0, 1.0)}, + ) + out = resample(raster, scale_factor=0.5) + assert 'transform' not in out.attrs + + def test_fill_value_replaced_with_nan(self): + data = np.arange(16, dtype=np.float32).reshape(4, 4) + raster = xr.DataArray( + data, + dims=['y', 'x'], + coords={'y': np.linspace(3, 0, 4), 'x': np.linspace(0, 3, 4)}, + attrs={'res': (1.0, 1.0), '_FillValue': -9999}, + ) + out = resample(raster, scale_factor=0.5) + assert '_FillValue' in out.attrs + assert np.isnan(out.attrs['_FillValue']) + + def test_fill_value_absent_stays_absent(self): + data = np.arange(16, dtype=np.float32).reshape(4, 4) + raster = xr.DataArray( + data, + dims=['y', 'x'], + coords={'y': np.linspace(3, 0, 4), 'x': np.linspace(0, 3, 4)}, + attrs={'res': (1.0, 1.0)}, + ) + out = resample(raster, scale_factor=0.5) + assert '_FillValue' not in out.attrs + + def test_nodatavals_replaced_with_nan(self): + data = np.arange(16, dtype=np.float32).reshape(4, 4) + raster = xr.DataArray( + data, + dims=['y', 'x'], + coords={'y': np.linspace(3, 0, 4), 'x': np.linspace(0, 3, 4)}, + attrs={'res': (1.0, 1.0), 'nodatavals': (-9999,)}, + ) + out = resample(raster, scale_factor=0.5) + assert 'nodatavals' in out.attrs + nv = out.attrs['nodatavals'] + assert len(nv) == 1 + assert np.isnan(nv[0]) + + def test_other_attrs_preserved(self): + # crs, units, long_name, scales, offsets should round-trip. 
+ data = np.arange(16, dtype=np.float32).reshape(4, 4) + raster = xr.DataArray( + data, + dims=['y', 'x'], + coords={'y': np.linspace(3, 0, 4), 'x': np.linspace(0, 3, 4)}, + attrs={ + 'res': (1.0, 1.0), + 'crs': 'EPSG:4326', + 'crs_wkt': 'GEOGCS["WGS 84"]', + 'units': 'm', + 'long_name': 'elevation', + 'scales': (1.0,), + 'offsets': (0.0,), + }, + ) + out = resample(raster, scale_factor=0.5) + assert out.attrs['crs'] == 'EPSG:4326' + assert out.attrs['crs_wkt'] == 'GEOGCS["WGS 84"]' + assert out.attrs['units'] == 'm' + assert out.attrs['long_name'] == 'elevation' + assert out.attrs['scales'] == (1.0,) + assert out.attrs['offsets'] == (0.0,) + + # --------------------------------------------------------------------------- # Correctness: known values # ---------------------------------------------------------------------------