Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: stats.circ___: add array-API support #20595

Merged
merged 12 commits into from May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/array_api.yml
Expand Up @@ -96,3 +96,4 @@ jobs:
python dev.py --no-build test -b all -t scipy._lib.tests.test_array_api
python dev.py --no-build test -b all -t scipy._lib.tests.test__util -- --durations 3 --timeout=60
python dev.py --no-build test -b all -t scipy.stats.tests.test_stats -- --durations 3 --timeout=60
python dev.py --no-build test -b all -t scipy.stats.tests.test_morestats -- --durations 3 --timeout=60
21 changes: 20 additions & 1 deletion scipy/_lib/_array_api.py
Expand Up @@ -283,13 +283,20 @@ def xp_assert_equal(actual, desired, check_namespace=True, check_dtype=True,
return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)


def xp_assert_close(actual, desired, rtol=1e-07, atol=0, check_namespace=True,
def xp_assert_close(actual, desired, rtol=None, atol=0, check_namespace=True,
check_dtype=True, check_shape=True, err_msg='', xp=None):
__tracebackhide__ = True # Hide traceback for py.test
if xp is None:
xp = array_namespace(actual)
desired = _strict_check(actual, desired, xp, check_namespace=check_namespace,
check_dtype=check_dtype, check_shape=check_shape)

floating = xp.isdtype(desired.dtype, ('real floating', 'complex floating'))
if rtol is None and floating:
rtol = xp.finfo(desired.dtype).eps**0.5
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
elif rtol is None:
rtol = 1e-7

if is_cupy(xp):
return xp.testing.assert_allclose(actual, desired, rtol=rtol,
atol=atol, err_msg=err_msg)
Expand Down Expand Up @@ -354,3 +361,15 @@ def xp_unsupported_param_msg(param):

def is_complex(x, xp):
    """Report whether `x` has a complex floating-point dtype.

    Parameters
    ----------
    x : array
        Object exposing a ``dtype`` attribute.
    xp : array namespace
        Namespace providing ``isdtype``; it is asked whether ``x.dtype``
        belongs to the ``'complex floating'`` kind.

    Returns
    -------
    The result of ``xp.isdtype(x.dtype, 'complex floating')``.
    """
    dtype = x.dtype
    return xp.isdtype(dtype, 'complex floating')


# temporary substitute for xp.minimum, which is not yet in all backends
# or covered by array_api_compat.
def xp_minimum(x1, x2):
    """Element-wise minimum of two arrays, standing in for ``xp.minimum``.

    The array namespace is inferred from the inputs rather than accepted as
    an argument, mirroring the signature of ``xp.minimum`` itself.

    Parameters
    ----------
    x1, x2 : array
        Input arrays; they are broadcast against each other.

    Returns
    -------
    array
        A copy of the broadcast `x1` in which every element is replaced by
        the corresponding element of `x2` wherever ``x2 < x1`` or `x2` is
        NaN. Elements of `x1` are kept everywhere else.
    """
    xp = array_namespace(x1, x2)
    x1, x2 = xp.broadcast_arrays(x1, x2)
    # Work on a copy so that the caller's `x1` is not written to.
    out = xp.asarray(x1, copy=True)
    take_x2 = xp.isnan(x2) | (x2 < x1)
    out[take_x2] = x2[take_x2]
    return out
23 changes: 13 additions & 10 deletions scipy/_lib/_util.py
Expand Up @@ -268,8 +268,8 @@ def check_random_state(seed):
if isinstance(seed, (np.random.RandomState, np.random.Generator)):
return seed

raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
' instance' % seed)
raise ValueError(f"'{seed}' cannot be used to seed a numpy.random.RandomState"
" instance")


def _asarray_validated(a, check_finite=True,
Expand Down Expand Up @@ -716,10 +716,9 @@ def _contains_nan(a, nan_policy='propagate', use_summation=True,
if not_numpy:
use_summation = False # some array_likes ignore nans (e.g. pandas)
if policies is None:
policies = ['propagate', 'raise', 'omit']
policies = {'propagate', 'raise', 'omit'}
if nan_policy not in policies:
raise ValueError("nan_policy must be one of {%s}" %
', '.join("'%s'" % s for s in policies))
raise ValueError(f"nan_policy must be one of {set(policies)}.")

inexact = (xp.isdtype(a.dtype, "real floating")
or xp.isdtype(a.dtype, "complex floating"))
Expand Down Expand Up @@ -815,15 +814,19 @@ def _rng_spawn(rng, n_children):
return child_rngs


def _get_nan(*data):
def _get_nan(*data, xp=None):
if data == ():
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
return None
xp = array_namespace(*data) if xp is None else xp
# Get NaN of appropriate dtype for data
data = [np.asarray(item) for item in data]
data = [xp.asarray(item) for item in data]
try:
dtype = np.result_type(*data, np.half) # must be a float16 at least
min_float = getattr(xp, 'float16', xp.float32)
dtype = xp.result_type(*data, min_float) # must be at least a float
except DTypePromotionError:
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
# fallback to float64
return np.array(np.nan, dtype=np.float64)[()]
return np.array(np.nan, dtype=dtype)[()]
dtype = xp.float64
return xp.asarray(xp.nan, dtype=dtype)[()]


def normalize_axis_index(axis, ndim):
Expand Down
10 changes: 5 additions & 5 deletions scipy/cluster/tests/test_hierarchy.py
Expand Up @@ -769,7 +769,7 @@ def test_maxdists_one_cluster_linkage(self, xp):
Z = xp.asarray([[0, 1, 0.3, 4]], dtype=xp.float64)
MD = maxdists(Z)
expectedMD = calculate_maximum_distances(Z, xp)
xp_assert_close(MD, expectedMD, atol=1e-15)
xp_assert_close(MD, expectedMD, rtol=1e-7, atol=1e-15)
mdhaber marked this conversation as resolved.
Show resolved Hide resolved

def test_maxdists_Q_linkage(self, xp):
for method in ['single', 'complete', 'ward', 'centroid', 'median']:
Expand All @@ -781,7 +781,7 @@ def check_maxdists_Q_linkage(self, method, xp):
Z = linkage(X, method)
MD = maxdists(Z)
expectedMD = calculate_maximum_distances(Z, xp)
xp_assert_close(MD, expectedMD, atol=1e-15)
xp_assert_close(MD, expectedMD, rtol=1e-7, atol=1e-15)


class TestMaxInconsts:
Expand All @@ -808,7 +808,7 @@ def test_maxinconsts_one_cluster_linkage(self, xp):
R = xp.asarray([[0, 0, 0, 0.3]], dtype=xp.float64)
MD = maxinconsts(Z, R)
expectedMD = calculate_maximum_inconsistencies(Z, R, xp=xp)
xp_assert_close(MD, expectedMD, atol=1e-15)
xp_assert_close(MD, expectedMD, rtol=1e-7, atol=1e-15)

@skip_xp_backends(cpu_only=True)
def test_maxinconsts_Q_linkage(self, xp):
Expand All @@ -822,7 +822,7 @@ def check_maxinconsts_Q_linkage(self, method, xp):
R = inconsistent(Z)
MD = maxinconsts(Z, R)
expectedMD = calculate_maximum_inconsistencies(Z, R, xp=xp)
xp_assert_close(MD, expectedMD, atol=1e-15)
xp_assert_close(MD, expectedMD, rtol=1e-7, atol=1e-15)


class TestMaxRStat:
Expand Down Expand Up @@ -889,7 +889,7 @@ def check_maxRstat_Q_linkage(self, method, i, xp):
R = inconsistent(Z)
MD = maxRstat(Z, R, 1)
expectedMD = calculate_maximum_inconsistencies(Z, R, 1, xp)
xp_assert_close(MD, expectedMD, atol=1e-15)
xp_assert_close(MD, expectedMD, rtol=1e-7, atol=1e-15)


@skip_xp_backends(cpu_only=True)
Expand Down
4 changes: 2 additions & 2 deletions scipy/cluster/tests/test_vq.py
Expand Up @@ -413,9 +413,9 @@ def test_kmeans_and_kmeans2_random_seed(self, xp):
# test for kmeans
res1, _ = kmeans(data, 2, seed=seed1)
res2, _ = kmeans(data, 2, seed=seed2)
xp_assert_close(res1, res2, xp=xp) # should be same results
xp_assert_close(res1, res2) # should be same results
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
# test for kmeans2
for minit in ["random", "points", "++"]:
res1, _ = kmeans2(data, 2, minit=minit, seed=seed1)
res2, _ = kmeans2(data, 2, minit=minit, seed=seed2)
xp_assert_close(res1, res2, xp=xp) # should be same results
xp_assert_close(res1, res2) # should be same results
9 changes: 5 additions & 4 deletions scipy/optimize/tests/test_minpack.py
Expand Up @@ -321,7 +321,7 @@ def test_full_output(self):
args=(self.y_meas, self.x),
full_output=True)
params_fit, cov_x, infodict, mesg, ier = full_output
assert_(ier in (1,2,3,4), 'solution not found: %s' % mesg)
assert_(ier in (1,2,3,4), f'solution not found: {mesg}')
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved

def test_input_untouched(self):
p0 = array([0,0,0],dtype=float64)
Expand All @@ -330,7 +330,7 @@ def test_input_untouched(self):
args=(self.y_meas, self.x),
full_output=True)
params_fit, cov_x, infodict, mesg, ier = full_output
assert_(ier in (1,2,3,4), 'solution not found: %s' % mesg)
assert_(ier in (1,2,3,4), f'solution not found: {mesg}')
assert_array_equal(p0, p0_copy)

def test_wrong_shape_func_callable(self):
Expand Down Expand Up @@ -602,8 +602,9 @@ def _check_nan_policy(f, xdata_with_nan, xdata_without_nan,
assert_allclose(result_with_nan, result_without_nan)

# not valid policy test
error_msg = ("nan_policy must be one of "
"{'None', 'raise', 'omit'}")
# check for argument names in any order
error_msg = (r"nan_policy must be one of \{(?:'raise'|'omit'|None)"
r"(?:, ?(?:'raise'|'omit'|None))*\}")
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
with assert_raises(ValueError, match=error_msg):
curve_fit(**kwargs, nan_policy="hi")

Expand Down
80 changes: 45 additions & 35 deletions scipy/stats/_morestats.py
Expand Up @@ -7,11 +7,12 @@
from numpy import (isscalar, r_, log, around, unique, asarray, zeros,
arange, sort, amin, amax, sqrt, array, atleast_1d, # noqa: F401
compress, pi, exp, ravel, count_nonzero, sin, cos, # noqa: F401
arctan2, hypot)
arctan2, hypot) # noqa: F401

from scipy import optimize, special, interpolate, stats
from scipy._lib._bunch import _make_tuple_bunch
from scipy._lib._util import _rename_parameter, _contains_nan, _get_nan
from scipy._lib._array_api import array_namespace, xp_minimum, size as xp_size

from ._ansari_swilk_statistics import gscale, swilk
from . import _stats_py, _wilcoxon
Expand Down Expand Up @@ -130,8 +131,7 @@ def bayes_mvs(data, alpha=0.90):
"""
m, v, s = mvsdist(data)
if alpha >= 1 or alpha <= 0:
raise ValueError("0 < alpha < 1 is required, but alpha=%s was given."
% alpha)
raise ValueError(f"0 < alpha < 1 is required, but {alpha=} was given.")

m_res = Mean(m.mean(), m.interval(alpha))
v_res = Variance(v.mean(), v.interval(alpha))
Expand Down Expand Up @@ -454,7 +454,7 @@ def _parse_dist_kw(dist, enforce_subclass=True):
try:
dist = getattr(distributions, dist)
except AttributeError as e:
raise ValueError("%s is not a valid distribution name" % dist) from e
raise ValueError(f"{dist} is not a valid distribution name") from e
elif enforce_subclass:
msg = ("`dist` should be a stats.distributions instance or a string "
"with the name of such a distribution.")
Expand Down Expand Up @@ -831,7 +831,7 @@ def ppcc_plot(x, a, b, dist='tukeylambda', plot=None, N=80):
plot.plot(svals, ppcc, 'x')
_add_axis_labels_title(plot, xlabel='Shape Values',
ylabel='Prob Plot Corr. Coef.',
title='(%s) PPCC Plot' % dist)
title=f'({dist}) PPCC Plot')

return svals, ppcc

Expand Down Expand Up @@ -1323,7 +1323,7 @@ def _all(x):
'mle': _mle,
'all': _all}
if method not in methods.keys():
raise ValueError("Method %s not recognized." % method)
raise ValueError(f"Method {method} not recognized.")

optimfunc = methods[method]

Expand Down Expand Up @@ -4311,11 +4311,9 @@ def median_test(*samples, ties='below', correction=True, lambda_=1,
# a zero in the table of expected frequencies.
rowsums = table.sum(axis=1)
if rowsums[0] == 0:
raise ValueError("All values are below the grand median (%r)." %
grand_median)
raise ValueError(f"All values are below the grand median ({grand_median}).")
if rowsums[1] == 0:
raise ValueError("All values are above the grand median (%r)." %
grand_median)
raise ValueError(f"All values are above the grand median ({grand_median}).")
if ties == "ignore":
# We already checked that each sample has at least one value, but it
# is possible that all those values equal the grand median. If `ties`
Expand All @@ -4333,16 +4331,21 @@ def median_test(*samples, ties='below', correction=True, lambda_=1,
return MedianTestResult(stat, p, grand_median, table)


def _circfuncs_common(samples, high, low):
def _circfuncs_common(samples, high, low, xp=None):
xp = array_namespace(samples) if xp is None else xp
# Ensure samples are array-like and size is not zero
if samples.size == 0:
NaN = _get_nan(samples)
if xp_size(samples) == 0:
NaN = _get_nan(samples, xp=xp)
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
return NaN, NaN, NaN

if xp.isdtype(samples.dtype, 'integral'):
dtype = xp.asarray(1.).dtype # get default float type
samples = xp.asarray(samples, dtype=dtype)
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved

# Recast samples as radians that range between 0 and 2 pi and calculate
# the sine and cosine
sin_samp = sin((samples - low)*2.*pi / (high - low))
cos_samp = cos((samples - low)*2.*pi / (high - low))
sin_samp = xp.sin((samples - low)*2.*xp.pi / (high - low))
cos_samp = xp.cos((samples - low)*2.*xp.pi / (high - low))

return samples, sin_samp, cos_samp

Expand Down Expand Up @@ -4403,16 +4406,17 @@ def circmean(samples, high=2*pi, low=0, axis=None, nan_policy='propagate'):
>>> plt.show()

"""
samples, sin_samp, cos_samp = _circfuncs_common(samples, high, low)
sin_sum = sin_samp.sum(axis)
cos_sum = cos_samp.sum(axis)
res = arctan2(sin_sum, cos_sum)
xp = array_namespace(samples)
samples, sin_samp, cos_samp = _circfuncs_common(samples, high, low, xp=xp)
sin_sum = xp.sum(sin_samp, axis=axis)
cos_sum = xp.sum(cos_samp, axis=axis)
res = xp.atan2(sin_sum, cos_sum)
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved

res = np.asarray(res)
res[res < 0] += 2*pi
res = res[()]
res = xp.asarray(res)
res[res < 0] += 2*xp.pi
res = res[()] if res.ndim == 0 else res

return res*(high - low)/2.0/pi + low
return res*(high - low)/2.0/xp.pi + low


@_axis_nan_policy_factory(
Expand Down Expand Up @@ -4482,12 +4486,15 @@ def circvar(samples, high=2*pi, low=0, axis=None, nan_policy='propagate'):
>>> plt.show()

"""
samples, sin_samp, cos_samp = _circfuncs_common(samples, high, low)
sin_mean = sin_samp.mean(axis)
cos_mean = cos_samp.mean(axis)
# hypot can go slightly above 1 due to rounding errors
xp = array_namespace(samples)
samples, sin_samp, cos_samp = _circfuncs_common(samples, high, low, xp=xp)
sin_mean = xp.mean(sin_samp, axis=axis)
cos_mean = xp.mean(cos_samp, axis=axis)
hypotenuse = (sin_mean**2. + cos_mean**2.)**0.5
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
# hypotenuse can go slightly above 1 due to rounding errors
with np.errstate(invalid='ignore'):
j-bowhay marked this conversation as resolved.
Show resolved Hide resolved
R = np.minimum(1, hypot(sin_mean, cos_mean))
one = xp.asarray(1., dtype=hypotenuse.dtype)
R = xp_minimum(one, hypotenuse)
mdhaber marked this conversation as resolved.
Show resolved Hide resolved

res = 1. - R
return res
Expand Down Expand Up @@ -4579,16 +4586,19 @@ def circstd(samples, high=2*pi, low=0, axis=None, nan_policy='propagate', *,
>>> plt.show()

"""
samples, sin_samp, cos_samp = _circfuncs_common(samples, high, low)
sin_mean = sin_samp.mean(axis) # [1] (2.2.3)
cos_mean = cos_samp.mean(axis) # [1] (2.2.3)
# hypot can go slightly above 1 due to rounding errors
xp = array_namespace(samples)
samples, sin_samp, cos_samp = _circfuncs_common(samples, high, low, xp=xp)
sin_mean = xp.mean(sin_samp, axis=axis) # [1] (2.2.3)
cos_mean = xp.mean(cos_samp, axis=axis) # [1] (2.2.3)
hypotenuse = (sin_mean**2. + cos_mean**2.)**0.5
# hypotenuse can go slightly above 1 due to rounding errors
with np.errstate(invalid='ignore'):
R = np.minimum(1, hypot(sin_mean, cos_mean)) # [1] (2.2.4)
one = xp.asarray(1., dtype=hypotenuse.dtype)
R = xp_minimum(one, hypotenuse) # [1] (2.2.4)

res = sqrt(-2*log(R))
res = xp.sqrt(-2*xp.log(R))
if not normalize:
res *= (high-low)/(2.*pi) # [1] (2.3.14) w/ (2.3.7)
res *= (high-low)/(2.*xp.pi) # [1] (2.3.14) w/ (2.3.7)
return res


Expand Down