Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multi-file dataset spatial average orientation and weights when lon bounds span prime meridian #495

Merged
merged 4 commits into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions tests/test_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,24 @@ def test_spatial_average_for_lat_region(self):

assert result.identical(expected)

def test_spatial_average_for_domain_wrapping_p_meridian_non_cf_conventions(
self,
):
ds = self.ds.copy()

# get spatial average for original dataset
ref = ds.spatial.average("ts").ts

# change first bound from -0.9375 to 359.0625
lon_bnds = ds.lon_bnds.copy()
lon_bnds[0, 0] = 359.0625
ds["lon_bnds"] = lon_bnds

# check spatial average with new (bad) bound
result = ds.spatial.average("ts").ts

assert result.identical(ref)

Comment on lines +183 to +200
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test that non-CF compliant bounds still lead to correct spatial averaging.

@requires_dask
def test_spatial_average_for_lat_region_and_keep_weights_with_dask(self):
ds = self.ds.copy().chunk(2)
Expand Down Expand Up @@ -269,26 +287,17 @@ def setup(self):
decode_times=True, cf_compliant=False, has_bounds=True
)

def test_bounds_reordered_when_upper_indexed_first(self):
def test_value_error_thrown_for_multiple_out_of_order_lon_bounds(self):
domain_bounds = xr.DataArray(
name="lon_bnds",
data=np.array(
[[-89.375, -90], [0.0, -89.375], [0.0, 89.375], [89.375, 90]]
),
data=np.array([[3, 1], [5, 3], [5, 7], [7, 9]]),
coords={"lat": self.ds.lat},
dims=["lat", "bnds"],
)
result = self.ds.spatial._force_domain_order_low_to_high(domain_bounds)

expected_domain_bounds = xr.DataArray(
name="lon_bnds",
data=np.array(
[[-90, -89.375], [-89.375, 0.0], [0.0, 89.375], [89.375, 90]]
),
coords={"lat": self.ds.lat},
dims=["lat", "bnds"],
)
assert result.identical(expected_domain_bounds)
# Check _get_longitude_weights raises error when there are
# > 1 out-of-order bounds for the dataset.
with pytest.raises(ValueError):
self.ds.spatial._get_longitude_weights(domain_bounds, region_bounds=None)
Comment on lines +299 to +300
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test here is much simpler than has been included in past iterations.

It basically checks to see if there is more than one longitude bound in which the lower bound (i.e., bound[0]) is greater than the upper bound (i.e., bounds[1]). This can happen once for longitude (when spanning a prime meridian) but not more than once for rectilinear datasets. If it happens more than once, a ValueError is thrown.


def test_raises_error_if_dataset_has_multiple_bounds_variables_for_an_axis(self):
ds = self.ds.copy()
Expand Down
56 changes: 18 additions & 38 deletions xcdat/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,6 @@ def get_weights(
"to reference a specific data variable's axis bounds."
)

# The logic for generating longitude weights depends on the
# bounds being ordered such that d_bounds[:, 0] < d_bounds[:, 1].
# They are re-ordered (if need be) for the purpose of creating
# weights.
d_bounds = self._force_domain_order_low_to_high(d_bounds)

Comment on lines -284 to -289
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not necessary for latitude because we use the absolute difference of the domain bounds. So the logic was updated in the _get_longitude_weights() function.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this also caused the error

xarray can't set arrays with multiple array indices to dask yet.

because d_bounds were not loaded / copied.

r_bounds = axis_bounds[key]["region"]

weights = axis_bounds[key]["weights_method"](d_bounds, r_bounds)
Expand Down Expand Up @@ -324,33 +318,6 @@ def _validate_axis_arg(self, axis: List[SpatialAxis]):
# Check the axis coordinate variable exists in the Dataset.
get_dim_coords(self._dataset, key)

def _force_domain_order_low_to_high(self, domain_bounds: xr.DataArray):
"""Reorders the ``domain_bounds`` low-to-high.

This method ensures all lower bound values are less than the upper bound
values (``domain_bounds[:, 1] < domain_bounds[:, 1]``).

Parameters
----------
domain_bounds: xr.DataArray
The bounds of an axis.

Returns
------
xr.DataArray
The bounds of an axis (re-ordered if applicable).
"""
index_bad_cells = np.where(domain_bounds[:, 1] - domain_bounds[:, 0] < 0)[0]

if len(index_bad_cells) > 0:
new_domain_bounds = domain_bounds.copy()
new_domain_bounds[index_bad_cells, 0] = domain_bounds[index_bad_cells, 1]
new_domain_bounds[index_bad_cells, 1] = domain_bounds[index_bad_cells, 0]

return new_domain_bounds

return domain_bounds

Comment on lines -327 to -353
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is removed and replaced with a simpler consistency check and existing logic.

def _validate_region_bounds(self, axis: SpatialAxis, bounds: RegionAxisBounds):
"""Validates the ``bounds`` arg based on a set of criteria.

Expand Down Expand Up @@ -441,12 +408,29 @@ def _get_longitude_weights(
-------
xr.DataArray
The longitude axis weights.

Raises
------
ValueError
If the there are multiple instances in which the
domain_bounds[:, 0] > domain_bounds[:, 1]
"""
p_meridian_index: Optional[np.ndarray] = None
d_bounds = domain_bounds.copy()

pm_cells = np.where(domain_bounds[:, 1] - domain_bounds[:, 0] < 0)[0]
if len(pm_cells) > 1:
raise ValueError(
"More than one longitude bound is out of order. Only one bound "
"value spanning the prime meridian is permitted in data on "
"a rectilinear grid."
)
d_bounds: xr.DataArray = self._swap_lon_axis(d_bounds, to=360) # type: ignore
p_meridian_index = _get_prime_meridian_index(d_bounds)
if p_meridian_index is not None:
d_bounds = _align_lon_bounds_to_360(d_bounds, p_meridian_index)

if region_bounds is not None:
d_bounds: xr.DataArray = self._swap_lon_axis(d_bounds, to=360) # type: ignore
r_bounds: np.ndarray = self._swap_lon_axis(
region_bounds, to=360
) # type:ignore
Expand All @@ -455,10 +439,6 @@ def _get_longitude_weights(
if is_region_circular:
r_bounds = np.array([0.0, 360.0])

p_meridian_index = _get_prime_meridian_index(d_bounds)
if p_meridian_index is not None:
d_bounds = _align_lon_bounds_to_360(d_bounds, p_meridian_index)

d_bounds = self._scale_domain_to_region(d_bounds, r_bounds)

weights = self._calculate_weights(d_bounds)
Expand Down
Loading