Skip to content

Commit

Permalink
MAINT: refactor array API testing
Browse files Browse the repository at this point in the history
[skip circle] [skip cirrus]
  • Loading branch information
lucascolley committed Apr 1, 2024
1 parent 8b22ba9 commit 8d7e623
Show file tree
Hide file tree
Showing 23 changed files with 142 additions and 78 deletions.
13 changes: 5 additions & 8 deletions doc/source/dev/api-dev/array_api.rst
Expand Up @@ -238,20 +238,18 @@ The following pytest markers are available:

The following is an example using the markers::

from scipy.conftest import array_api_compatible
from scipy._lib._array_api import array_api_compatible, skip_xp_backends
...
@pytest.mark.skip_xp_backends(np_only=True,
reasons=['skip reason'])
@skip_xp_backends(np_only=True, reasons=['skip reason'])
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
def test_toto1(self, xp):
a = xp.asarray([1, 2, 3])
b = xp.asarray([0, 2, 5])
toto(a, b)
...
@pytest.mark.skip_xp_backends('array_api_strict', 'cupy',
reasons=['skip reason 1',
'skip reason 2',])
@skip_xp_backends('array_api_strict', 'cupy',
reasons=['skip reason 1', 'skip reason 2',])
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
def test_toto2(self, xp):
Expand All @@ -268,10 +266,9 @@ When every test function in a file has been updated for array API
compatibility, one can reduce verbosity by telling ``pytest`` to apply the
markers to every test function using ``pytestmark``::

from scipy.conftest import array_api_compatible
from scipy._lib._array_api import array_api_compatible, skip_xp_backends
pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends
...
@skip_xp_backends(np_only=True, reasons=['skip reason'])
def test_toto1(self, xp):
Expand Down
45 changes: 38 additions & 7 deletions scipy/_lib/_array_api.py
Expand Up @@ -20,13 +20,44 @@
numpy as np_compat,
)

__all__ = ['array_namespace', '_asarray', 'size']


# To enable array API and strict array-like input validation
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
# To control the default device - for use in the test suite only
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")
try:
# try to get the experimental variables and pytest marks from conftest,
# will error if pytest is not installed.
from scipy.conftest import (
array_api_compatible, SCIPY_ARRAY_API, SCIPY_DEVICE, skip_xp_mark
)
except ImportError:
# define the experimental variables here instead

# To enable array API and strict array-like input validation
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False) # type: ignore[no-redef] # noqa: E501
# To control the default device - for use in the test suite only
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")

__all__ = [
'_GLOBAL_CONFIG',
'array_api_compatible',
'array_namespace',
'_asarray',
'atleast_nd',
'copy',
'cov',
'is_complex',
'is_cupy',
'is_numpy',
'is_torch',
'SCIPY_ARRAY_API',
'SCIPY_DEVICE',
'size',
'skip_xp_backends',
'xp_assert_close',
'xp_assert_equal',
'xp_assert_less',
'xp_unsupported_param_msg'
]

# If pytest is installed, take the marker from conftest and export it
skip_xp_backends = skip_xp_mark or None

_GLOBAL_CONFIG = {
"SCIPY_ARRAY_API": SCIPY_ARRAY_API,
Expand Down
6 changes: 3 additions & 3 deletions scipy/_lib/tests/test__util.py
Expand Up @@ -9,15 +9,15 @@
import pytest
from pytest import raises as assert_raises
import hypothesis.extra.numpy as npst
from hypothesis import given, strategies, reproduce_failure # noqa: F401
from scipy.conftest import array_api_compatible
from hypothesis import given, strategies # noqa: F401

from scipy._lib._array_api import xp_assert_equal
from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
getfullargspec_no_self, FullArgSpec,
rng_integers, _validate_int, _rename_parameter,
_contains_nan, _rng_html_rewrite, _lazywhere)

from scipy._lib._array_api import array_api_compatible, xp_assert_equal


def test__aligned_zeros():
niter = 10
Expand Down
20 changes: 11 additions & 9 deletions scipy/_lib/tests/test_array_api.py
@@ -1,18 +1,24 @@
import numpy as np
import pytest

from scipy.conftest import array_api_compatible
from scipy._lib._array_api import (
_GLOBAL_CONFIG, array_namespace, _asarray, copy, xp_assert_equal, is_numpy
_GLOBAL_CONFIG,
array_api_compatible,
array_namespace,
_asarray,
copy,
is_numpy,
xp_assert_equal,
)
import scipy._lib.array_api_compat.numpy as np_compat


@array_api_compatible
@pytest.mark.skipif(not _GLOBAL_CONFIG["SCIPY_ARRAY_API"],
reason="Array API test; set environment variable SCIPY_ARRAY_API=1 to run it")
class TestArrayAPI:

def test_array_namespace(self):
def test_array_namespace(self, xp):
x, y = np.array([0, 1, 2]), np.array([0, 1, 2])
xp = array_namespace(x, y)
assert 'array_api_compat.numpy' in xp.__name__
Expand All @@ -22,15 +28,14 @@ def test_array_namespace(self):
assert 'array_api_compat.numpy' in xp.__name__
_GLOBAL_CONFIG["SCIPY_ARRAY_API"] = True

@array_api_compatible
def test_asarray(self, xp):
x, y = _asarray([0, 1, 2], xp=xp), _asarray(np.arange(3), xp=xp)
ref = xp.asarray([0, 1, 2])
xp_assert_equal(x, ref)
xp_assert_equal(y, ref)

@pytest.mark.filterwarnings("ignore: the matrix subclass")
def test_raises(self):
def test_raises(self, xp):
msg = "of type `numpy.ma.MaskedArray` are not supported"
with pytest.raises(TypeError, match=msg):
array_namespace(np.ma.array(1), np.array(1))
Expand All @@ -45,13 +50,12 @@ def test_raises(self):
with pytest.raises(TypeError, match=msg):
array_namespace('abc')

def test_array_likes(self):
def test_array_likes(self, xp):
# should be no exceptions
array_namespace([0, 1, 2])
array_namespace(1, 2, 3)
array_namespace(1)

@array_api_compatible
def test_copy(self, xp):
for _xp in [xp, None]:
x = xp.asarray([1, 2, 3])
Expand All @@ -67,7 +71,6 @@ def test_copy(self, xp):
assert x[2] != y[2]
assert id(x) != id(y)

@array_api_compatible
@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float32', 'float64'])
@pytest.mark.parametrize('shape', [(), (3,)])
def test_strict_checks(self, xp, dtype, shape):
Expand Down Expand Up @@ -98,7 +101,6 @@ def test_strict_checks(self, xp, dtype, shape):
with pytest.raises(AssertionError, match="Shapes do not match."):
xp_assert_equal(x, y, **options)

@array_api_compatible
def test_check_scalar(self, xp):
if not is_numpy(xp):
pytest.skip("Scalars only exist in NumPy")
Expand Down
8 changes: 3 additions & 5 deletions scipy/cluster/tests/test_hierarchy.py
Expand Up @@ -47,11 +47,11 @@
_order_cluster_tree, _hierarchy, _LINKAGE_METHODS)
from scipy.spatial.distance import pdist
from scipy.cluster._hierarchy import Heap
from scipy.conftest import array_api_compatible
from scipy._lib._array_api import xp_assert_close, xp_assert_equal

from . import hierarchy_test_data

from scipy._lib._array_api import (
array_api_compatible, skip_xp_backends, xp_assert_close, xp_assert_equal,
)

# Matplotlib is not a scipy dependency but is optionally used in dendrogram, so
# check if it's available
Expand All @@ -65,9 +65,7 @@
except Exception:
have_matplotlib = False


pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends


class TestLinkage:
Expand Down
10 changes: 7 additions & 3 deletions scipy/cluster/tests/test_vq.py
Expand Up @@ -12,15 +12,19 @@
from scipy.cluster.vq import (kmeans, kmeans2, py_vq, vq, whiten,
ClusterError, _krandinit)
from scipy.cluster import _vq
from scipy.conftest import array_api_compatible
from scipy.sparse._sputils import matrix

from scipy._lib._array_api import (
SCIPY_ARRAY_API, copy, cov, xp_assert_close, xp_assert_equal
array_api_compatible,
copy,
cov,
SCIPY_ARRAY_API,
skip_xp_backends,
xp_assert_close,
xp_assert_equal,
)

pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends

TESTDATA_2D = np.array([
-2.2, 1.17, -1.63, 1.69, -2.04, 4.38, -3.09, 0.95, -1.7, 4.79, -1.68, 0.68,
Expand Down
9 changes: 5 additions & 4 deletions scipy/cluster/vq.py
Expand Up @@ -67,14 +67,15 @@
import warnings
import numpy as np
from collections import deque
from scipy._lib._array_api import (
_asarray, array_namespace, size, atleast_nd, copy, cov
)

from scipy._lib._util import check_random_state, rng_integers
from scipy.spatial.distance import cdist

from . import _vq

from scipy._lib._array_api import (
_asarray, array_namespace, atleast_nd, copy, cov, size,
)

__docformat__ = 'restructuredtext'

__all__ = ['whiten', 'vq', 'kmeans', 'kmeans2']
Expand Down
9 changes: 8 additions & 1 deletion scipy/conftest.py
@@ -1,4 +1,5 @@
# Pytest customization
from __future__ import annotations
import json
import os
import warnings
Expand All @@ -12,7 +13,6 @@
from scipy._lib._fpumode import get_fpu_mode
from scipy._lib._testutils import FPUModeChangeWarning
from scipy._lib import _pep440
from scipy._lib._array_api import SCIPY_ARRAY_API, SCIPY_DEVICE


def pytest_configure(config):
Expand Down Expand Up @@ -107,6 +107,11 @@ def check_fpu_mode(request):
# Array API backend handling
xp_available_backends = {'numpy': np}

# To enable array API and strict array-like input validation
SCIPY_ARRAY_API: str | bool = os.environ.get("SCIPY_ARRAY_API", False)
# To control the default device - for use in the test suite only
SCIPY_DEVICE = os.environ.get("SCIPY_DEVICE", "cpu")

if SCIPY_ARRAY_API and isinstance(SCIPY_ARRAY_API, str):
# fill the dict of backends with available libraries
try:
Expand Down Expand Up @@ -151,6 +156,8 @@ def check_fpu_mode(request):

array_api_compatible = pytest.mark.parametrize("xp", xp_available_backends.values())

skip_xp_mark = pytest.mark.skip_xp_backends


@pytest.fixture
def skip_xp_backends(xp, request):
Expand Down
7 changes: 4 additions & 3 deletions scipy/fft/_basic_backend.py
@@ -1,8 +1,9 @@
import numpy as np
from . import _pocketfft

from scipy._lib._array_api import (
array_namespace, is_numpy, xp_unsupported_param_msg, is_complex
array_namespace, is_complex, is_numpy, xp_unsupported_param_msg,
)
from . import _pocketfft
import numpy as np


def _validate_fft_args(workers, plan, norm):
Expand Down
2 changes: 1 addition & 1 deletion scipy/fft/_helper.py
@@ -1,9 +1,9 @@
from functools import update_wrapper, lru_cache
import inspect

import numpy as np
from ._pocketfft import helper as _helper

import numpy as np
from scipy._lib._array_api import array_namespace


Expand Down
3 changes: 2 additions & 1 deletion scipy/fft/_realtransforms_backend.py
@@ -1,7 +1,8 @@
from scipy._lib._array_api import array_namespace
import numpy as np
from . import _pocketfft

from scipy._lib._array_api import array_namespace

__all__ = ['dct', 'idct', 'dst', 'idst', 'dctn', 'idctn', 'dstn', 'idstn']


Expand Down
11 changes: 8 additions & 3 deletions scipy/fft/tests/test_basic.py
Expand Up @@ -6,14 +6,19 @@
from numpy.random import random
from numpy.testing import assert_array_almost_equal, assert_allclose
from pytest import raises as assert_raises

import scipy.fft as fft
from scipy.conftest import array_api_compatible

from scipy._lib._array_api import (
array_namespace, size, xp_assert_close, xp_assert_equal
array_api_compatible,
array_namespace,
size,
skip_xp_backends,
xp_assert_close,
xp_assert_equal,
)

pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends


# Expected input dtypes. Note that `scipy.fft` is more flexible for numpy,
Expand Down
3 changes: 1 addition & 2 deletions scipy/fft/tests/test_fftlog.py
Expand Up @@ -5,8 +5,7 @@
from scipy.fft._fftlog import fht, ifht, fhtoffset
from scipy.special import poch

from scipy.conftest import array_api_compatible
from scipy._lib._array_api import xp_assert_close
from scipy._lib._array_api import array_api_compatible, xp_assert_close

pytestmark = array_api_compatible

Expand Down
12 changes: 8 additions & 4 deletions scipy/fft/tests/test_helper.py
Expand Up @@ -4,18 +4,21 @@
Modified for Array API, 2023
"""
from scipy.fft._helper import next_fast_len, _init_nd_shape_and_axes

from numpy.testing import assert_equal
from pytest import raises as assert_raises
import pytest
import numpy as np
import sys
from scipy.conftest import array_api_compatible
from scipy._lib._array_api import xp_assert_close, SCIPY_DEVICE

from scipy import fft
from scipy.fft._helper import next_fast_len, _init_nd_shape_and_axes

from scipy._lib._array_api import (
array_api_compatible, SCIPY_DEVICE, skip_xp_backends, xp_assert_close,
)

pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends")]
skip_xp_backends = pytest.mark.skip_xp_backends

_5_smooth_numbers = [
2, 3, 4, 5, 6, 8, 9, 10,
Expand All @@ -24,6 +27,7 @@
2**3 * 3**3 * 5**2,
]


def test_next_fast_len():
for n in _5_smooth_numbers:
assert_equal(next_fast_len(n), n)
Expand Down

0 comments on commit 8d7e623

Please sign in to comment.