sign_test allow datasets (#199)
* confidence one-dim, allow datasets

* Update test_comparative.py

* add_as_coords
aaronspring committed Oct 8, 2020
1 parent 5252261 commit f628918
Showing 4 changed files with 54 additions and 29 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
@@ -11,6 +11,11 @@ Features
:py:func:`~xskillscore.me`. (:issue:`202`, :pr:`200`)
`Andrew Huang`_

Bug Fixes
~~~~~~~~~
- :py:func:`~xskillscore.sign_test` now works for ``xr.Dataset`` inputs.
(:issue:`198`, :pr:`199`) `Aaron Spring`_

Internal Changes
~~~~~~~~~~~~~~~~
- Added Python 3.7 and Python 3.8 to the CI. Use the latest version of Python 3
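For context on the bug fix noted above, a minimal usage sketch (data and variable names are illustrative, not from the commit): with this change, sign_test accepts xr.Dataset forecasts and observations in the same way it accepts xr.DataArray inputs.

import numpy as np
import xarray as xr
from xskillscore import sign_test

# Illustrative Dataset inputs: observations plus two competing forecasts.
time = np.arange(30)
obs = xr.Dataset({"var": ("time", np.random.normal(size=30))}, coords={"time": time})
noise_small = xr.DataArray(np.random.normal(scale=0.1, size=30), dims="time")
noise_large = xr.DataArray(np.random.normal(scale=0.5, size=30), dims="time")
f1 = obs + noise_small  # closer to the observations
f2 = obs + noise_large  # further from the observations

# Dataset inputs now work; before this fix only DataArrays were supported.
walk = sign_test(f1, f2, obs, time_dim="time", alpha=0.05, metric="mae")
print(walk)              # cumulative walk; trends upward when f1 is consistently better
print(walk.confidence)   # confidence bound attached as a coordinate (cf. the tests below)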
22 changes: 7 additions & 15 deletions docs/source/quick-start.ipynb

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions xskillscore/core/comparative.py
@@ -3,6 +3,8 @@
import scipy.stats as st
import xarray as xr

from .utils import _add_as_coord


def sign_test(
forecasts1,
@@ -79,8 +81,8 @@ def sign_test(
... coords=[('time', np.arange(30))])
>>> st = sign_test(f1, f2, o, time_dim='time', metric='mae', orientation='negative')
>>> st.plot()
>>> st['confidence'].plot(c='gray')
>>> (-1*st['confidence']).plot(c='gray')
>>> st['confidence'].plot(color='gray')
>>> (-1*st['confidence']).plot(color='gray')
References
----------
@@ -152,16 +154,16 @@ def _categorical_metric(observations, forecasts, dim):
metric_f1o = -metric_f1o
metric_f2o = -metric_f2o

sign_test = (1 * (metric_f1o < metric_f2o) - 1 * (metric_f2o < metric_f1o)).cumsum(
walk = (1 * (metric_f1o < metric_f2o) - 1 * (metric_f2o < metric_f1o)).cumsum(
time_dim
)

# Estimate 95% confidence interval -----
# Estimate 1 - alpha confidence interval -----
notnan = 1 * (metric_f1o.notnull() & metric_f2o.notnull())
N = notnan.cumsum(time_dim)
# z_alpha is the value at which the cumulative distribution of the standard
# Gaussian reaches 1 - alpha / 2 (two-sided interval)
confidence = st.norm.ppf(1 - alpha / 2) * xr.ufuncs.sqrt(N)
sign_test.coords["alpha"] = alpha
sign_test.coords["confidence"] = confidence
return sign_test
walk.coords["alpha"] = alpha
walk = _add_as_coord(walk, confidence, "confidence")
return walk
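For readers skimming the diff, here is the core of the computation above as a condensed standalone sketch for DataArray inputs (the commit itself routes the coordinate attachment through the _add_as_coord helper so that Dataset inputs work too):

import numpy as np
import scipy.stats as st


def sign_test_walk(err1, err2, time_dim="time", alpha=0.05):
    """Sketch of the sign_test core: err1/err2 are per-timestep errors of the
    two forecasts, already oriented so that smaller means better."""
    # +1 where forecast 1 beats forecast 2, -1 where it loses, 0 on ties or NaNs
    walk = (1 * (err1 < err2) - 1 * (err2 < err1)).cumsum(time_dim)

    # Count timesteps with valid (non-NaN) error pairs; under the null hypothesis
    # of equal skill the walk stays within +/- z_{1-alpha/2} * sqrt(N).
    notnan = 1 * (err1.notnull() & err2.notnull())
    N = notnan.cumsum(time_dim)
    confidence = st.norm.ppf(1 - alpha / 2) * np.sqrt(N)

    walk.coords["alpha"] = alpha
    # Plain coordinate assignment only covers DataArrays; the commit swaps this
    # for _add_as_coord so Datasets are handled as well.
    walk.coords["confidence"] = confidence
    return walk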
40 changes: 33 additions & 7 deletions xskillscore/tests/test_comparative.py
@@ -1,4 +1,6 @@
import numpy as np
import pytest
from dask import is_dask_collection

import xskillscore as xs
from xskillscore import sign_test
@@ -20,15 +22,31 @@ def logical(ds):
return ds > 0.5


def test_sign_test_raw(a_1d, a_1d_worse, b_1d):
"""Test sign_test where significance crossed (for np.random.seed(42) values)."""
@pytest.mark.parametrize("chunk", [True, False])
@pytest.mark.parametrize("input", ["Dataset", "multidim Dataset", "DataArray", "mixed"])
def test_sign_test_inputs(a_1d, a_1d_worse, b_1d, input, chunk):
"""Test sign_test with xr inputs and chunked."""
if "Dataset" in input:
name = "var"
a_1d = a_1d.to_dataset(name=name)
a_1d_worse = a_1d_worse.to_dataset(name=name)
b_1d = b_1d.to_dataset(name=name)
if input == "multidim Dataset":
a_1d["var2"] = a_1d["var"] * 2
a_1d_worse["var2"] = a_1d_worse["var"] * 2
b_1d["var2"] = b_1d["var"] * 2
elif input == "mixed":
name = "var"
a_1d = a_1d.to_dataset(name=name)
if chunk:
a_1d = a_1d.chunk()
a_1d_worse = a_1d_worse.chunk()
b_1d = b_1d.chunk()
actual = sign_test(
a_1d, a_1d_worse, b_1d, time_dim="time", alpha=0.05, metric="mae"
)
walk_larger_significance = actual > actual.confidence
crossing_after_timesteps = walk_larger_significance.argmax(dim="time")
# check the timestep after which the sign_test walk exceeds the confidence bound
assert crossing_after_timesteps == 3
# check dask collection preserved
assert is_dask_collection(actual) if chunk else not is_dask_collection(actual)
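The dask assertion above can also be demonstrated at the usage level; a small hypothetical snippet (requires dask; data is illustrative):

import numpy as np
import xarray as xr
from dask import is_dask_collection
from xskillscore import sign_test

time = np.arange(30)
f1 = xr.DataArray(np.random.normal(size=30), coords=[("time", time)]).chunk()
f2 = xr.DataArray(np.random.normal(size=30), coords=[("time", time)]).chunk()
obs = xr.DataArray(np.random.normal(size=30), coords=[("time", time)]).chunk()

walk = sign_test(f1, f2, obs, time_dim="time", alpha=0.05, metric="mae")
assert is_dask_collection(walk)  # result stays lazy
walk = walk.compute()            # evaluate explicitly when needed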


@pytest.mark.parametrize("observation", [True, False])
@@ -183,7 +201,6 @@ def test_sign_test_dim(a, a_worse, b):
assert len(actual.dims) == 1


@pytest.mark.xfail()
def test_sign_test_dim_fails(a_1d, a_1d_worse, b_1d):
"""Sign_test fails if no time_dim in dim."""
with pytest.raises(ValueError) as e:
@@ -194,3 +211,12 @@ def test_sign_test_dim_fails(a_1d, a_1d_worse, b_1d):
def test_sign_test_metric_correlation(a, a_worse, b):
"""Sign_test work for correlation metrics over other dimensions that time_dim."""
sign_test(a, a_worse, b, time_dim="time", dim=["lon", "lat"], metric="pearson_r")


def test_sign_test_NaNs_confidence(a, a_worse, b):
"""Sign_test confidence with NaNs."""
actual = sign_test(a, a_worse, b, time_dim="time", metric="mse")
a_nan = a.copy()
a_nan[1:3, 1:3, 1:3] = np.nan
actual_nan = sign_test(a_nan, a_worse, b, time_dim="time", metric="mse")
assert not (actual_nan.confidence == actual.confidence).all()
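The assertion above holds because the confidence bound scales with the square root of the cumulative count of valid (non-NaN) error pairs, so masking values shrinks the bound; a quick hypothetical check of the numbers:

import numpy as np
import scipy.stats as st

z = st.norm.ppf(1 - 0.05 / 2)   # two-sided 95% quantile, about 1.96
print(z * np.sqrt(10))          # bound after 10 valid timesteps, about 6.2
print(z * np.sqrt(8))           # bound if 2 of those timesteps were NaN, about 5.5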
