Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Jun 13, 2022
1 parent 3cc253a commit 225df77
Showing 1 changed file with 16 additions and 52 deletions.
68 changes: 16 additions & 52 deletions tests/test_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,11 @@ def test_spatial_average_for_lat_region_and_keep_weights(self):

xr.testing.assert_allclose(result, expected)

@requires_dask
def test_chunked_spatial_average_for_lat_region(self):
ds = self.ds.copy().chunk(2)
def test_spatial_average_for_lat_region(self):
ds = self.ds.copy()

# Specifying axis as a str instead of list of str.
result = ds.spatial.average(
"ts", axis=["Y"], lat_bounds=(-5.0, 5), keep_weights=True
)
result = ds.spatial.average("ts", axis=["Y"], lat_bounds=(-5.0, 5))

expected = self.ds.copy()
expected["ts"] = xr.DataArray(
Expand All @@ -160,19 +157,17 @@ def test_chunked_spatial_average_for_lat_region(self):
coords={"time": expected.time, "lon": expected.lon},
dims=["time", "lon"],
)
expected["lat_wts"] = xr.DataArray(
name="lat_wts",
data=np.array([0.0, 0.08715574, 0.08715574, 0.0]),
dims=["lat"],
)

xr.testing.assert_allclose(result, expected)
assert result.identical(expected)

def test_spatial_average_for_lat_region(self):
ds = self.ds.copy()
@requires_dask
def test_spatial_average_for_lat_region_and_keep_weights_with_dask(self):
ds = self.ds.copy().chunk(2)

# Specifying axis as a str instead of list of str.
result = ds.spatial.average("ts", axis=["Y"], lat_bounds=(-5.0, 5))
result = ds.spatial.average(
"ts", axis=["Y"], lat_bounds=(-5.0, 5), keep_weights=True
)

expected = self.ds.copy()
expected["ts"] = xr.DataArray(
Expand All @@ -182,8 +177,13 @@ def test_spatial_average_for_lat_region(self):
coords={"time": expected.time, "lon": expected.lon},
dims=["time", "lon"],
)
expected["lat_wts"] = xr.DataArray(
name="lat_wts",
data=np.array([0.0, 0.08715574, 0.08715574, 0.0]),
dims=["lat"],
)

assert result.identical(expected)
xr.testing.assert_allclose(result, expected)

def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self):
ds = self.ds.copy()
Expand Down Expand Up @@ -216,42 +216,6 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self):

xr.testing.assert_allclose(result, expected)

@requires_dask
def test_spatial_average_for_lat_region_and_keep_weights_with_dask(self):
ds = self.ds.copy().chunk(2)

result = ds.spatial.average(
"ts",
axis=["Y"],
lat_bounds=(-5.0, 5),
lon_bounds=(-170, -120.1),
keep_weights=True,
)

expected = self.ds.copy()
expected["ts"] = xr.DataArray(
data=np.array(
[[2.25, 2.25, 2.25, 2.25], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]
),
coords={"time": expected.time, "lon": expected.lon},
dims=["time", "lon"],
)

expected["lat_lon_wts"] = xr.DataArray(
name="ts_weights",
data=np.array(
[
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 4.34907156, 4.34907156, 0.0],
[0.0, 0.0, 0.0, 0.0],
]
),
dims=["lon", "lat"],
)

xr.testing.assert_allclose(result, expected)

def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self):
ds = self.ds.copy()

Expand Down

0 comments on commit 225df77

Please sign in to comment.