Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: refactor array API testing #20368

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 5 additions & 8 deletions doc/source/dev/api-dev/array_api.rst
Expand Up @@ -238,20 +238,18 @@ The following pytest markers are available:

The following is an example using the markers::

from scipy.conftest import array_api_compatible
from scipy._lib._array_api import array_api_compatible, skip_xp_backends
...
@pytest.mark.skip_xp_backends(np_only=True,
reasons=['skip reason'])
@skip_xp_backends(np_only=True, reasons=['skip reason'])
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
def test_toto1(self, xp):
a = xp.asarray([1, 2, 3])
b = xp.asarray([0, 2, 5])
toto(a, b)
...
@pytest.mark.skip_xp_backends('array_api_strict', 'cupy',
reasons=['skip reason 1',
'skip reason 2',])
@skip_xp_backends('array_api_strict', 'cupy',
reasons=['skip reason 1', 'skip reason 2',])
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
def test_toto2(self, xp):
Expand All @@ -268,10 +266,9 @@ When every test function in a file has been updated for array API
compatibility, one can reduce verbosity by telling ``pytest`` to apply the
markers to every test function using ``pytestmark``::

from scipy.conftest import array_api_compatible
from scipy._lib._array_api import array_api_compatible, skip_xp_backends

pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends
...
@skip_xp_backends(np_only=True, reasons=['skip reason'])
def test_toto1(self, xp):
Expand Down
45 changes: 38 additions & 7 deletions scipy/_lib/_array_api.py
Expand Up @@ -20,13 +20,44 @@
numpy as np_compat,
)

__all__ = ['array_namespace', '_asarray', 'size']


# To enable array API and strict array-like input validation
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
# To control the default device - for use in the test suite only
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")
try:
# try to get the experimental variables and pytest marks from conftest,
# will error if pytest is not installed.
import pytest
from scipy.conftest import (
array_api_compatible, SCIPY_ARRAY_API, SCIPY_DEVICE
)

skip_xp_backends = pytest.mark.skip_xp_backends
except ImportError:
# define the experimental variables here instead

# To enable array API and strict array-like input validation
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False) # type: ignore[no-redef] # noqa: E501
# To control the default device - for use in the test suite only
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")

__all__ = [
'_GLOBAL_CONFIG',
'array_api_compatible',
'array_namespace',
'_asarray',
'atleast_nd',
'copy',
'cov',
'is_complex',
'is_cupy',
'is_numpy',
'is_torch',
'SCIPY_ARRAY_API',
'SCIPY_DEVICE',
'size',
'skip_xp_backends',
'xp_assert_close',
'xp_assert_equal',
'xp_assert_less',
'xp_unsupported_param_msg'
]

_GLOBAL_CONFIG = {
"SCIPY_ARRAY_API": SCIPY_ARRAY_API,
Expand Down
6 changes: 3 additions & 3 deletions scipy/_lib/tests/test__util.py
Expand Up @@ -9,15 +9,15 @@
import pytest
from pytest import raises as assert_raises
import hypothesis.extra.numpy as npst
from hypothesis import given, strategies, reproduce_failure # noqa: F401
from scipy.conftest import array_api_compatible
from hypothesis import given, strategies # noqa: F401

from scipy._lib._array_api import xp_assert_equal
from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
getfullargspec_no_self, FullArgSpec,
rng_integers, _validate_int, _rename_parameter,
_contains_nan, _rng_html_rewrite, _lazywhere)

from scipy._lib._array_api import array_api_compatible, xp_assert_equal


def test__aligned_zeros():
niter = 10
Expand Down
20 changes: 11 additions & 9 deletions scipy/_lib/tests/test_array_api.py
@@ -1,18 +1,24 @@
import numpy as np
import pytest

from scipy.conftest import array_api_compatible
from scipy._lib._array_api import (
_GLOBAL_CONFIG, array_namespace, _asarray, copy, xp_assert_equal, is_numpy
_GLOBAL_CONFIG,
array_api_compatible,
array_namespace,
_asarray,
copy,
is_numpy,
xp_assert_equal,
)
import scipy._lib.array_api_compat.numpy as np_compat


@array_api_compatible
@pytest.mark.skipif(not _GLOBAL_CONFIG["SCIPY_ARRAY_API"],
reason="Array API test; set environment variable SCIPY_ARRAY_API=1 to run it")
class TestArrayAPI:

def test_array_namespace(self):
def test_array_namespace(self, xp):
x, y = np.array([0, 1, 2]), np.array([0, 1, 2])
xp = array_namespace(x, y)
assert 'array_api_compat.numpy' in xp.__name__
Expand All @@ -22,15 +28,14 @@ def test_array_namespace(self):
assert 'array_api_compat.numpy' in xp.__name__
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = True

@array_api_compatible
def test_asarray(self, xp):
x, y = _asarray([0, 1, 2], xp=xp), _asarray(np.arange(3), xp=xp)
ref = xp.asarray([0, 1, 2])
xp_assert_equal(x, ref)
xp_assert_equal(y, ref)

@pytest.mark.filterwarnings("ignore: the matrix subclass")
def test_raises(self):
def test_raises(self, xp):
msg = "of type `numpy.ma.MaskedArray` are not supported"
with pytest.raises(TypeError, match=msg):
array_namespace(np.ma.array(1), np.array(1))
Expand All @@ -45,13 +50,12 @@ def test_raises(self):
with pytest.raises(TypeError, match=msg):
array_namespace('abc')

def test_array_likes(self):
def test_array_likes(self, xp):
# should be no exceptions
array_namespace([0, 1, 2])
array_namespace(1, 2, 3)
array_namespace(1)

@array_api_compatible
def test_copy(self, xp):
for _xp in [xp, None]:
x = xp.asarray([1, 2, 3])
Expand All @@ -67,7 +71,6 @@ def test_copy(self, xp):
assert x[2] != y[2]
assert id(x) != id(y)

@array_api_compatible
@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float32', 'float64'])
@pytest.mark.parametrize('shape', [(), (3,)])
def test_strict_checks(self, xp, dtype, shape):
Expand Down Expand Up @@ -98,7 +101,6 @@ def test_strict_checks(self, xp, dtype, shape):
with pytest.raises(AssertionError, match="Shapes do not match."):
xp_assert_equal(x, y, **options)

@array_api_compatible
def test_check_scalar(self, xp):
if not is_numpy(xp):
pytest.skip("Scalars only exist in NumPy")
Expand Down
8 changes: 3 additions & 5 deletions scipy/cluster/tests/test_hierarchy.py
Expand Up @@ -47,11 +47,11 @@
_order_cluster_tree, _hierarchy, _LINKAGE_METHODS)
from scipy.spatial.distance import pdist
from scipy.cluster._hierarchy import Heap
from scipy.conftest import array_api_compatible
from scipy._lib._array_api import xp_assert_close, xp_assert_equal

from . import hierarchy_test_data

from scipy._lib._array_api import (
array_api_compatible, skip_xp_backends, xp_assert_close, xp_assert_equal,
)

# Matplotlib is not a scipy dependency but is optionally used in dendrogram, so
# check if it's available
Expand All @@ -65,9 +65,7 @@
except Exception:
have_matplotlib = False


pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends


class TestLinkage:
Expand Down
10 changes: 7 additions & 3 deletions scipy/cluster/tests/test_vq.py
Expand Up @@ -12,15 +12,19 @@
from scipy.cluster.vq import (kmeans, kmeans2, py_vq, vq, whiten,
ClusterError, _krandinit)
from scipy.cluster import _vq
from scipy.conftest import array_api_compatible
from scipy.sparse._sputils import matrix

from scipy._lib._array_api import (
SCIPY_ARRAY_API, copy, cov, xp_assert_close, xp_assert_equal
array_api_compatible,
copy,
cov,
SCIPY_ARRAY_API,
skip_xp_backends,
xp_assert_close,
xp_assert_equal,
)

pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends

TESTDATA_2D = np.array([
-2.2, 1.17, -1.63, 1.69, -2.04, 4.38, -3.09, 0.95, -1.7, 4.79, -1.68, 0.68,
Expand Down
9 changes: 5 additions & 4 deletions scipy/cluster/vq.py
Expand Up @@ -67,14 +67,15 @@
import warnings
import numpy as np
from collections import deque
from scipy._lib._array_api import (
_asarray, array_namespace, size, atleast_nd, copy, cov
)

from scipy._lib._util import check_random_state, rng_integers
from scipy.spatial.distance import cdist

from . import _vq

from scipy._lib._array_api import (
_asarray, array_namespace, atleast_nd, copy, cov, size,
)

__docformat__ = 'restructuredtext'

__all__ = ['whiten', 'vq', 'kmeans', 'kmeans2']
Expand Down
7 changes: 6 additions & 1 deletion scipy/conftest.py
@@ -1,4 +1,5 @@
# Pytest customization
from __future__ import annotations
import json
import os
import warnings
Expand All @@ -12,7 +13,6 @@
from scipy._lib._fpumode import get_fpu_mode
from scipy._lib._testutils import FPUModeChangeWarning
from scipy._lib import _pep440
from scipy._lib._array_api import SCIPY_ARRAY_API, SCIPY_DEVICE


def pytest_configure(config):
Expand Down Expand Up @@ -107,6 +107,11 @@ def check_fpu_mode(request):
# Array API backend handling
xp_available_backends = {'numpy': np}

# To enable array API and strict array-like input validation
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
# To control the default device - for use in the test suite only
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")

if SCIPY_ARRAY_API and isinstance(SCIPY_ARRAY_API, str):
# fill the dict of backends with available libraries
try:
Expand Down
7 changes: 4 additions & 3 deletions scipy/fft/_basic_backend.py
@@ -1,8 +1,9 @@
import numpy as np
from . import _pocketfft

from scipy._lib._array_api import (
array_namespace, is_numpy, xp_unsupported_param_msg, is_complex
array_namespace, is_complex, is_numpy, xp_unsupported_param_msg,
)
from . import _pocketfft
import numpy as np


def _validate_fft_args(workers, plan, norm):
Expand Down
2 changes: 1 addition & 1 deletion scipy/fft/_helper.py
@@ -1,9 +1,9 @@
from functools import update_wrapper, lru_cache
import inspect

import numpy as np
from ._pocketfft import helper as _helper

import numpy as np
from scipy._lib._array_api import array_namespace


Expand Down
3 changes: 2 additions & 1 deletion scipy/fft/_realtransforms_backend.py
@@ -1,7 +1,8 @@
from scipy._lib._array_api import array_namespace
import numpy as np
from . import _pocketfft

from scipy._lib._array_api import array_namespace

__all__ = ['dct', 'idct', 'dst', 'idst', 'dctn', 'idctn', 'dstn', 'idstn']


Expand Down
11 changes: 8 additions & 3 deletions scipy/fft/tests/test_basic.py
Expand Up @@ -6,14 +6,19 @@
from numpy.random import random
from numpy.testing import assert_array_almost_equal, assert_allclose
from pytest import raises as assert_raises

import scipy.fft as fft
from scipy.conftest import array_api_compatible

from scipy._lib._array_api import (
array_namespace, size, xp_assert_close, xp_assert_equal
array_api_compatible,
array_namespace,
size,
skip_xp_backends,
xp_assert_close,
xp_assert_equal,
)

pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends


# Expected input dtypes. Note that `scipy.fft` is more flexible for numpy,
Expand Down
3 changes: 1 addition & 2 deletions scipy/fft/tests/test_fftlog.py
Expand Up @@ -5,8 +5,7 @@
from scipy.fft._fftlog import fht, ifht, fhtoffset
from scipy.special import poch

from scipy.conftest import array_api_compatible
from scipy._lib._array_api import xp_assert_close
from scipy._lib._array_api import array_api_compatible, xp_assert_close

pytestmark = array_api_compatible

Expand Down
12 changes: 8 additions & 4 deletions scipy/fft/tests/test_helper.py
Expand Up @@ -4,18 +4,21 @@
Modified for Array API, 2023

"""
from scipy.fft._helper import next_fast_len, _init_nd_shape_and_axes

from numpy.testing import assert_equal
from pytest import raises as assert_raises
import pytest
import numpy as np
import sys
from scipy.conftest import array_api_compatible
from scipy._lib._array_api import xp_assert_close, SCIPY_DEVICE

from scipy import fft
from scipy.fft._helper import next_fast_len, _init_nd_shape_and_axes

from scipy._lib._array_api import (
array_api_compatible, SCIPY_DEVICE, skip_xp_backends, xp_assert_close,
)

pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends

_5_smooth_numbers = [
2, 3, 4, 5, 6, 8, 9, 10,
Expand All @@ -24,6 +27,7 @@
2**3 * 3**3 * 5**2,
]


def test_next_fast_len():
for n in _5_smooth_numbers:
assert_equal(next_fast_len(n), n)
Expand Down
7 changes: 4 additions & 3 deletions scipy/fft/tests/test_real_transforms.py
Expand Up @@ -6,11 +6,12 @@
from scipy.fft import dct, idct, dctn, idctn, dst, idst, dstn, idstn
import scipy.fft as fft
from scipy import fftpack
from scipy.conftest import array_api_compatible
from scipy._lib._array_api import copy, xp_assert_close

from scipy._lib._array_api import (
array_api_compatible, copy, skip_xp_backends, xp_assert_close,
)

pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends

SQRT_2 = math.sqrt(2)

Expand Down