Skip to content

Commit

Permalink
Merge pull request xarray-contrib#155 from aaronspring/AS_accessors_easy
Browse files Browse the repository at this point in the history
use args and kwargs in accessors
  • Loading branch information
raybellwaves committed Aug 26, 2020
2 parents 781ea26 + 785204e commit 49eafc6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 135 deletions.
172 changes: 45 additions & 127 deletions xskillscore/core/accessor.py
Expand Up @@ -43,194 +43,112 @@ def _in_ds(self, x):
else:
return self._obj[x]

def pearson_r(self, a, b, dim=None, weights=None, skipna=False, keep_attrs=False):
def pearson_r(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return pearson_r(
a, b, dim=dim, weights=weights, skipna=skipna, keep_attrs=keep_attrs
)
return pearson_r(a, b, *args, **kwargs)

def r2(self, a, b, dim=None, weights=None, skipna=False, keep_attrs=False):
def r2(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return r2(a, b, dim=dim, weights=weights, skipna=skipna, keep_attrs=keep_attrs)
return r2(a, b, *args, **kwargs)

def pearson_r_p_value(
self, a, b, dim=None, weights=None, skipna=False, keep_attrs=False
):
def pearson_r_p_value(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return pearson_r_p_value(
a, b, dim=dim, weights=weights, skipna=skipna, keep_attrs=keep_attrs
)
return pearson_r_p_value(a, b, *args, **kwargs)

def effective_sample_size(self, a, b, dim='time', skipna=False, keep_attrs=False):
def effective_sample_size(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return effective_sample_size(
a, b, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)
return effective_sample_size(a, b, *args, **kwargs)

def pearson_r_eff_p_value(self, a, b, dim='time', skipna=False, keep_attrs=False):
def pearson_r_eff_p_value(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return pearson_r_eff_p_value(
a, b, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)
return pearson_r_eff_p_value(a, b, *args, **kwargs)

def spearman_r(self, a, b, dim=None, weights=None, skipna=False, keep_attrs=False):
def spearman_r(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return spearman_r(
a, b, dim, weights=weights, skipna=skipna, keep_attrs=keep_attrs
)
return spearman_r(a, b, *args, **kwargs)

def spearman_r_p_value(
self, a, b, dim=None, weights=None, skipna=False, keep_attrs=False
):
def spearman_r_p_value(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return spearman_r_p_value(
a, b, dim, weights=weights, skipna=skipna, keep_attrs=keep_attrs
)
return spearman_r_p_value(a, b, *args, **kwargs)

def spearman_r_eff_p_value(self, a, b, dim='time', skipna=False, keep_attrs=False):
def spearman_r_eff_p_value(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return spearman_r_eff_p_value(
a, b, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)
return spearman_r_eff_p_value(a, b, *args, **kwargs)

def rmse(self, a, b, dim=None, weights=None, skipna=False, keep_attrs=False):
def rmse(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return rmse(
a, b, dim=dim, weights=weights, skipna=skipna, keep_attrs=keep_attrs
)
return rmse(a, b, *args, **kwargs)

def mse(self, a, b, dim=None, weights=None, skipna=False, keep_attrs=False):
def mse(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return mse(a, b, dim=dim, weights=weights, skipna=skipna, keep_attrs=keep_attrs)
return mse(a, b, *args, **kwargs)

def mae(self, a, b, dim=None, weights=None, skipna=False, keep_attrs=False):
def mae(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return mae(a, b, dim=dim, weights=weights, skipna=skipna, keep_attrs=keep_attrs)
return mae(a, b, *args, **kwargs)

def median_absolute_error(self, a, b, dim=None, skipna=False, keep_attrs=False):
def median_absolute_error(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return median_absolute_error(
a, b, dim=dim, skipna=skipna, keep_attrs=keep_attrs
)
return median_absolute_error(a, b, *args, **kwargs)

def mape(self, a, b, dim=None, weights=None, skipna=False, keep_attrs=False):
def mape(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return mape(
a, b, dim=dim, weights=weights, skipna=skipna, keep_attrs=keep_attrs
)
return mape(a, b, *args, **kwargs)

def smape(self, a, b, dim=None, weights=None, skipna=False, keep_attrs=False):
def smape(self, a, b, *args, **kwargs):
a = self._in_ds(a)
b = self._in_ds(b)
return smape(
a, b, dim=dim, weights=weights, skipna=skipna, keep_attrs=keep_attrs
)
return smape(a, b, *args, **kwargs)

def crps_gaussian(
self, observations, mu, sig, dim=None, weights=None, keep_attrs=False
):
def crps_gaussian(self, observations, mu, sig, *args, **kwargs):
observations = self._in_ds(observations)
mu = self._in_ds(mu)
sig = self._in_ds(sig)
return crps_gaussian(observations, mu, sig, dim=dim, weights=weights)

def crps_ensemble(
self,
observations,
forecasts,
member_weights=None,
issorted=False,
dim=None,
member_dim='member',
weights=None,
):
return crps_gaussian(observations, mu, sig, *args, **kwargs)

def crps_ensemble(self, observations, forecasts, *args, **kwargs):
observations = self._in_ds(observations)
forecasts = self._in_ds(forecasts)
return crps_ensemble(
observations,
forecasts,
member_weights=member_weights,
issorted=issorted,
member_dim=member_dim,
dim=dim,
weights=weights,
)
return crps_ensemble(observations, forecasts, *args, **kwargs)

def crps_quadrature(
self,
x,
cdf_or_dist,
xmin=None,
xmax=None,
tol=1e-6,
dim=None,
weights=None,
keep_attrs=False,
):
def crps_quadrature(self, x, cdf_or_dist, *args, **kwargs):
x = self._in_ds(x)
cdf_or_dist = self._in_ds(cdf_or_dist)
return crps_quadrature(
x, cdf_or_dist, xmin=xmin, xmax=xmax, tol=1e-6, dim=dim, weights=weights
)
return crps_quadrature(x, cdf_or_dist, *args, **kwargs)

def threshold_brier_score(
self,
observations,
forecasts,
threshold,
issorted=False,
dim=None,
member_dim='member',
weights=None,
keep_attrs=False,
self, observations, forecasts, threshold, *args, **kwargs
):
observations = self._in_ds(observations)
forecasts = self._in_ds(forecasts)
threshold = self._in_ds(threshold)
return threshold_brier_score(
observations,
forecasts,
threshold,
issorted=issorted,
dim=dim,
member_dim=member_dim,
weights=weights,
observations, forecasts, threshold, *args, **kwargs
)

def brier_score(self, observations, forecasts, dim=None, weights=None):
def brier_score(self, observations, forecasts, *args, **kwargs):
observations = self._in_ds(observations)
forecasts = self._in_ds(forecasts)
return brier_score(observations, forecasts, dim=dim, weights=weights)
return brier_score(observations, forecasts, *args, **kwargs)

def rank_histogram(self, observations, forecasts, dim=None, member_dim='member'):
def rank_histogram(self, observations, forecasts, *args, **kwargs):
observations = self._in_ds(observations)
forecasts = self._in_ds(forecasts)
return rank_histogram(observations, forecasts, dim=dim, member_dim=member_dim)

def discrimination(
self,
observations,
forecasts,
dim=None,
probability_bin_edges=np.linspace(0, 1 + 1e-8, 6),
):
return rank_histogram(observations, forecasts, *args, **kwargs)

def discrimination(self, observations, forecasts, *args, **kwargs):
observations = self._in_ds(observations)
forecasts = self._in_ds(forecasts)
return discrimination(
observations,
forecasts,
dim=dim,
probability_bin_edges=probability_bin_edges,
)
return discrimination(observations, forecasts, *args, **kwargs)
10 changes: 2 additions & 8 deletions xskillscore/tests/test_accessor_probabilistic.py
Expand Up @@ -80,22 +80,16 @@ def test_crps_ensemble_accessor(o, f, dask_bool, outer_bool):
assert_allclose(actual, expected)


@pytest.mark.parametrize('outer_bool', [False, True])
@pytest.mark.parametrize('dask_bool', [False, True])
def test_crps_quadrature_accessor(o, dask_bool, outer_bool):
def test_crps_quadrature_accessor(o, dask_bool):
cdf_or_dist = norm
if dask_bool:
o = o.chunk()
actual = crps_quadrature(o, cdf_or_dist)

ds = xr.Dataset()
ds['o'] = o
ds['cdf_or_dist'] = cdf_or_dist
if outer_bool:
ds = ds.drop_vars('cdf_or_dist')
expected = ds.xs.crps_quadrature('o', cdf_or_dist)
else:
expected = ds.xs.crps_quadrature('o', 'cdf_or_dist')
expected = ds.xs.crps_quadrature('o', cdf_or_dist)
assert_allclose(actual, expected)


Expand Down

0 comments on commit 49eafc6

Please sign in to comment.