Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions xrspatial/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,14 +864,41 @@ def _new_coords(vals, n_out):
edge_start = vals[0] - half_first
edge_end = vals[-1] + half_last
px = (edge_end - edge_start) / n_out
return np.linspace(edge_start + px / 2, edge_end - px / 2, n_out), px
coords = np.linspace(edge_start + px / 2, edge_end - px / 2, n_out)
return coords, px, edge_start, edge_end

new_y, py = _new_coords(y_vals, out_h)
new_x, px = _new_coords(x_vals, out_w)
new_y, py, y_edge_start, y_edge_end = _new_coords(y_vals, out_h)
new_x, px, x_edge_start, x_edge_end = _new_coords(x_vals, out_w)

new_attrs = dict(agg.attrs)
new_attrs['res'] = (abs(px), abs(py))

# Refresh `transform` if the input had one. The rasterio 6-tuple is
# (res_x, 0.0, left, 0.0, -res_y, top). `top` is the upper edge of
# the first row, which is `y_edge_start` when y is descending and
# `y_edge_end` when y is ascending. `left` is the lower edge of the
# first column, which is `x_edge_start` when x is ascending and
# `x_edge_end` when x is descending.
if 'transform' in agg.attrs:
out_res_x = abs(px)
out_res_y = abs(py)
top = y_edge_start if y_vals[0] > y_vals[-1] else y_edge_end
left = x_edge_start if x_vals[0] < x_vals[-1] else x_edge_end
new_attrs['transform'] = (
out_res_x, 0.0, left, 0.0, -out_res_y, top,
)

# Resample currently emits float32 with NaN as the missing-data
# sentinel regardless of input dtype. If the input declared a
# different sentinel via `_FillValue` or `nodatavals`, replace the
# value with NaN so the metadata matches the actual data. Leave the
# keys absent when the input did not have them.
if '_FillValue' in agg.attrs:
new_attrs['_FillValue'] = float('nan')
if 'nodatavals' in agg.attrs:
old = agg.attrs['nodatavals']
new_attrs['nodatavals'] = tuple(float('nan') for _ in old)

result = xr.DataArray(
result_data,
name=name,
Expand Down
106 changes: 106 additions & 0 deletions xrspatial/tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,112 @@ def test_res_attribute_updated(self, grid_8x8):
assert pytest.approx(out.attrs['res'][1], abs=0.01) == 2.0


# ---------------------------------------------------------------------------
# Metadata propagation
# ---------------------------------------------------------------------------

class TestMetadataPropagation:
"""resample() should refresh stale grid attrs and nodata sentinels."""

@staticmethod
def _raster_with_transform(transform):
"""4x4 raster whose coords match the given rasterio 6-tuple."""
res_x, _, left, _, neg_res_y, top = transform
res_y = -neg_res_y
x = np.array([left + (i + 0.5) * res_x for i in range(4)])
y = np.array([top - (i + 0.5) * res_y for i in range(4)])
data = np.arange(16, dtype=np.float32).reshape(4, 4)
return xr.DataArray(
data,
dims=['y', 'x'],
coords={'y': y, 'x': x},
attrs={'transform': transform, 'res': (res_x, res_y)},
)

def test_transform_refreshed_on_downsample(self):
raster = self._raster_with_transform(
(1.0, 0.0, 100.0, 0.0, -1.0, 200.0)
)
out = resample(raster, scale_factor=0.5)
assert tuple(out.attrs['transform']) == (
2.0, 0.0, 100.0, 0.0, -2.0, 200.0,
)

def test_transform_absent_stays_absent(self):
# grid_4x4 fixture has no transform attr.
data = np.arange(16, dtype=np.float32).reshape(4, 4)
raster = xr.DataArray(
data,
dims=['y', 'x'],
coords={'y': np.linspace(3, 0, 4), 'x': np.linspace(0, 3, 4)},
attrs={'res': (1.0, 1.0)},
)
out = resample(raster, scale_factor=0.5)
assert 'transform' not in out.attrs

def test_fill_value_replaced_with_nan(self):
data = np.arange(16, dtype=np.float32).reshape(4, 4)
raster = xr.DataArray(
data,
dims=['y', 'x'],
coords={'y': np.linspace(3, 0, 4), 'x': np.linspace(0, 3, 4)},
attrs={'res': (1.0, 1.0), '_FillValue': -9999},
)
out = resample(raster, scale_factor=0.5)
assert '_FillValue' in out.attrs
assert np.isnan(out.attrs['_FillValue'])

def test_fill_value_absent_stays_absent(self):
data = np.arange(16, dtype=np.float32).reshape(4, 4)
raster = xr.DataArray(
data,
dims=['y', 'x'],
coords={'y': np.linspace(3, 0, 4), 'x': np.linspace(0, 3, 4)},
attrs={'res': (1.0, 1.0)},
)
out = resample(raster, scale_factor=0.5)
assert '_FillValue' not in out.attrs

def test_nodatavals_replaced_with_nan(self):
data = np.arange(16, dtype=np.float32).reshape(4, 4)
raster = xr.DataArray(
data,
dims=['y', 'x'],
coords={'y': np.linspace(3, 0, 4), 'x': np.linspace(0, 3, 4)},
attrs={'res': (1.0, 1.0), 'nodatavals': (-9999,)},
)
out = resample(raster, scale_factor=0.5)
assert 'nodatavals' in out.attrs
nv = out.attrs['nodatavals']
assert len(nv) == 1
assert np.isnan(nv[0])

def test_other_attrs_preserved(self):
# crs, units, long_name, scales, offsets should round-trip.
data = np.arange(16, dtype=np.float32).reshape(4, 4)
raster = xr.DataArray(
data,
dims=['y', 'x'],
coords={'y': np.linspace(3, 0, 4), 'x': np.linspace(0, 3, 4)},
attrs={
'res': (1.0, 1.0),
'crs': 'EPSG:4326',
'crs_wkt': 'GEOGCS["WGS 84"]',
'units': 'm',
'long_name': 'elevation',
'scales': (1.0,),
'offsets': (0.0,),
},
)
out = resample(raster, scale_factor=0.5)
assert out.attrs['crs'] == 'EPSG:4326'
assert out.attrs['crs_wkt'] == 'GEOGCS["WGS 84"]'
assert out.attrs['units'] == 'm'
assert out.attrs['long_name'] == 'elevation'
assert out.attrs['scales'] == (1.0,)
assert out.attrs['offsets'] == (0.0,)


# ---------------------------------------------------------------------------
# Correctness: known values
# ---------------------------------------------------------------------------
Expand Down
Loading