Skip to content

Commit

Permalink
Update create_grid args to improve usability (#507)
Browse files Browse the repository at this point in the history
* Refactors create_bounds from BoundsAccessor

* Adds create_axis function

* Deprecates create_grid's **kwargs and implements new x, y, z arguments

* Fixes how create_grid creates the Dataset

* Updates create_*_grid methods to use new create_grid

* Fixes create_grid method signature

* Removes old documentation

* Adds proper deprecation notice to docstring

* Updates vertical regrid example to use new create_grid

* Apply suggestions from code review

Co-authored-by: Tom Vo <tomvothecoder@gmail.com>

* Fixes converting standard name to cf axis

* Fixes formatting

* Adds additional suggested fixes

---------

Co-authored-by: Tom Vo <tomvothecoder@gmail.com>
  • Loading branch information
jasonb5 and tomvothecoder committed Jul 7, 2023
1 parent b46c331 commit c2fd8fd
Show file tree
Hide file tree
Showing 6 changed files with 574 additions and 209 deletions.
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Below is a list of top-level API functions that are available in ``xcdat``.
compare_datasets
get_dim_coords
get_dim_keys
create_axis
create_gaussian_grid
create_global_mean_grid
create_grid
Expand Down
173 changes: 101 additions & 72 deletions docs/examples/regridding-vertical.ipynb

Large diffs are not rendered by default.

151 changes: 146 additions & 5 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import datetime
import re
import sys
import warnings
from unittest import mock

import numpy as np
Expand Down Expand Up @@ -770,13 +772,144 @@ def test_preserve_bounds(self):


class TestGrid:
@pytest.fixture(autouse=True)
def setUp(self):
self.lat_data = np.array([-45, 0, 45])
self.lat = xr.DataArray(self.lat_data.copy(), dims=["lat"], name="lat")

self.lat_bnds_data = np.array([[-67.5, -22.5], [-22.5, 22.5], [22.5, 67.5]])
self.lat_bnds = xr.DataArray(
self.lat_bnds_data.copy(), dims=["lat", "bnds"], name="lat_bnds"
)

self.lon_data = np.array([30, 60, 90, 120, 150])
self.lon = xr.DataArray(self.lon_data.copy(), dims=["lon"], name="lon")

self.lon_bnds_data = np.array(
[[15, 45], [45, 75], [75, 105], [105, 135], [135, 165]]
)
self.lon_bnds = xr.DataArray(
self.lon_bnds_data.copy(), dims=["lon", "bnds"], name="lon_bnds"
)

def test_create_axis(self):
expected_axis_attrs = {
"axis": "Y",
"units": "degrees_north",
"coordinate": "latitude",
"bounds": "lat_bnds",
}

axis, bnds = grid.create_axis("lat", self.lat_data)

assert np.array_equal(axis, self.lat_data)
assert bnds is not None
assert bnds.attrs["xcdat_bounds"] == "True"
assert axis.attrs == expected_axis_attrs

def test_create_axis_user_attrs(self):
expected_axis_attrs = {
"axis": "Y",
"units": "degrees_south",
"coordinate": "latitude",
"bounds": "lat_bnds",
"custom": "value",
}

axis, bnds = grid.create_axis(
"lat", self.lat_data, attrs={"custom": "value", "units": "degrees_south"}
)

assert np.array_equal(axis, self.lat_data)
assert bnds is not None
assert bnds.attrs["xcdat_bounds"] == "True"
assert axis.attrs == expected_axis_attrs

def test_create_axis_from_list(self):
axis, bnds = grid.create_axis("lat", self.lat_data, bounds=self.lat_bnds_data)

assert np.array_equal(axis, self.lat_data)
assert bnds is not None
assert np.array_equal(bnds, self.lat_bnds_data)

def test_create_axis_no_bnds(self):
expected_axis_attrs = {
"axis": "Y",
"units": "degrees_north",
"coordinate": "latitude",
}

axis, bnds = grid.create_axis("lat", self.lat_data, generate_bounds=False)

assert np.array_equal(axis, self.lat_data)
assert bnds is None
assert axis.attrs == expected_axis_attrs

def test_create_axis_user_bnds(self):
expected_axis_attrs = {
"axis": "Y",
"units": "degrees_north",
"coordinate": "latitude",
"bounds": "lat_bnds",
}

axis, bnds = grid.create_axis("lat", self.lat_data, bounds=self.lat_bnds_data)

assert np.array_equal(axis, self.lat_data)
assert bnds is not None
assert np.array_equal(bnds, self.lat_bnds_data)
assert "xcdat_bounds" not in bnds.attrs
assert axis.attrs == expected_axis_attrs

def test_create_axis_invalid_name(self):
with pytest.raises(
ValueError, match="The name 'mass' is not valid for an axis name."
):
grid.create_axis("mass", self.lat_data)

def test_empty_grid(self):
with pytest.raises(
ValueError, match="Must pass at least 1 coordinate to create a grid."
ValueError, match="Must pass at least 1 axis to create a grid."
):
grid.create_grid()

def test_unexpected_coordinate(self):
def test_create_grid(self):
new_grid = grid.create_grid(x=self.lon, y=self.lat)

assert np.array_equal(new_grid.lat, self.lat)
assert np.array_equal(new_grid.lon, self.lon)

def test_create_grid_with_bounds(self):
new_grid = grid.create_grid(
x=(self.lon, self.lon_bnds), y=(self.lat, self.lat_bnds)
)

assert np.array_equal(new_grid.lat, self.lat)
assert new_grid.lat.attrs["bounds"] == self.lat_bnds.name
assert np.array_equal(new_grid.lat_bnds, self.lat_bnds)

assert np.array_equal(new_grid.lon, self.lon)
assert new_grid.lon.attrs["bounds"] == self.lon_bnds.name
assert np.array_equal(new_grid.lon_bnds, self.lon_bnds)

def test_create_grid_user_attrs(self):
lev = xr.DataArray(np.linspace(1000, 1, 2), dims=["lev"], name="lev")

new_grid = grid.create_grid(z=lev, attrs={"custom": "value"})

assert "custom" in new_grid.attrs
assert new_grid.attrs["custom"] == "value"

def test_create_grid_wrong_axis_value(self):
with pytest.raises(
ValueError,
match=re.escape(
"Argument 'x' should be an xr.DataArray representing coordinates or a tuple (xr.DataArray, xr.DataArray) representing coordinates and bounds."
),
):
grid.create_grid(x=(self.lon, self.lon_bnds, self.lat)) # type: ignore[arg-type]

def test_deprecated_unexpected_coordinate(self):
lev = np.linspace(1000, 1, 2)

with pytest.raises(
Expand All @@ -785,16 +918,24 @@ def test_unexpected_coordinate(self):
):
grid.create_grid(lev=lev, mass=np.linspace(10, 20, 2))

def test_create_grid_lev(self):
def test_deprecated_create_grid_lev(self):
lev = np.linspace(1000, 1, 2)
lev_bnds = np.array([[1499.5, 500.5], [500.5, -498.5]])

new_grid = grid.create_grid(lev=(lev, lev_bnds))
with warnings.catch_warnings(record=True) as w:
new_grid = grid.create_grid(lev=(lev, lev_bnds))

assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert (
str(w[0].message)
== "**kwargs will be deprecated, see docstring and use 'x', 'y', or 'z' arguments"
)

assert np.array_equal(new_grid.lev, lev)
assert np.array_equal(new_grid.lev_bnds, lev_bnds)

def test_create_grid(self):
def test_deprecated_create_grid(self):
lat = np.array([-45, 0, 45])
lon = np.array([30, 60, 90, 120, 150])
lat_bnds = np.array([[-67.5, -22.5], [-22.5, 22.5], [22.5, 67.5]])
Expand Down
1 change: 1 addition & 0 deletions xcdat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from xcdat.dataset import decode_time, open_dataset, open_mfdataset # noqa: F401
from xcdat.regridder.accessor import RegridderAccessor # noqa: F401
from xcdat.regridder.grid import ( # noqa: F401
create_axis,
create_gaussian_grid,
create_global_mean_grid,
create_grid,
Expand Down
Loading

0 comments on commit c2fd8fd

Please sign in to comment.