Skip to content

Commit 0314e3d

Browse files
committed
Fix Array API tests for Numba backend.
1 parent 2bca00c commit 0314e3d

File tree

4 files changed

+44
-4
lines changed

4 files changed

+44
-4
lines changed

ci/Numba-array-api-xfails.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__
4242
array_api_tests/test_has_names.py::test_has_names[array_method-__setitem__]
4343
array_api_tests/test_indexing_functions.py::test_take
4444
array_api_tests/test_linalg.py::test_vecdot
45-
array_api_tests/test_operators_and_elementwise_functions.py::test_ceil
46-
array_api_tests/test_operators_and_elementwise_functions.py::test_trunc
4745
array_api_tests/test_set_functions.py::test_unique_all
4846
array_api_tests/test_set_functions.py::test_unique_inverse
4947
array_api_tests/test_signatures.py::test_func_signature[unique_all]

ci/test_array_api.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ source ci/clone_array_api_tests.sh
66
if [ "${SPARSE_BACKEND}" = "Finch" ]; then
77
python -c 'import finch'
88
fi
9-
ARRAY_API_TESTS_MODULE="sparse" pytest "$ARRAY_API_TESTS_DIR/array_api_tests/" -v -c "$ARRAY_API_TESTS_DIR/pytest.ini" --ci --max-examples=2 --derandomize --disable-deadline -o xfail_strict=True -n auto --xfails-file ../sparse/ci/${SPARSE_BACKEND}-array-api-xfails.txt --skips-file ../sparse/ci/${SPARSE_BACKEND}-array-api-skips.txt
9+
ARRAY_API_TESTS_MODULE="sparse" pytest "$ARRAY_API_TESTS_DIR/array_api_tests/" -v -c "$ARRAY_API_TESTS_DIR/pytest.ini" --ci --max-examples=2 --derandomize --disable-deadline --disable-warnings -o xfail_strict=True -n auto --xfails-file ../sparse/ci/${SPARSE_BACKEND}-array-api-xfails.txt --skips-file ../sparse/ci/${SPARSE_BACKEND}-array-api-skips.txt

sparse/numba_backend/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
bitwise_not,
55
bitwise_or,
66
bitwise_xor,
7-
can_cast,
87
ceil,
98
complex64,
109
complex128,
@@ -86,6 +85,7 @@
8685
astype,
8786
broadcast_arrays,
8887
broadcast_to,
88+
can_cast,
8989
concat,
9090
concatenate,
9191
dot,

sparse/numba_backend/_common.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,6 +1601,13 @@ def eye(N, M=None, k=0, dtype=float, format="coo", *, device=None, **kwargs):
16011601
k = int(k)
16021602

16031603
data_length = builtins.min(N, M)
1604+
if k > 0:
1605+
data_length = builtins.max(builtins.min(data_length, M - k), 0)
1606+
elif k < 0:
1607+
data_length = builtins.max(builtins.min(data_length, N + k), 0)
1608+
1609+
if data_length == 0:
1610+
return zeros((N, M), dtype=dtype, format=format, device=device)
16041611

16051612
if k > 0:
16061613
data_length = builtins.max(builtins.min(data_length, M - k), 0)
@@ -1854,6 +1861,41 @@ def empty_like(a, dtype=None, shape=None, format=None, *, device=None, **kwargs)
18541861
empty_like.__doc__ = zeros_like.__doc__
18551862

18561863

1864+
def can_cast(from_: SparseArray, to: np.dtype, /, *, casting: str = "safe") -> bool:
1865+
"""Determines if one data type can be cast to another data type
1866+
1867+
Parameters
1868+
----------
1869+
from_ : dtype or SparseArray
1870+
Source array or dtype.
1871+
to : dtype
1872+
Destination dtype.
1873+
casting: str
1874+
Casting kind
1875+
1876+
Returns
1877+
-------
1878+
out : bool
1879+
Whether or not a cast is possible.
1880+
1881+
Examples
1882+
--------
1883+
>>> x = sparse.ones((2, 3), dtype=sparse.int8)
1884+
>>> sparse.can_cast(x, sparse.float64)
1885+
True
1886+
1887+
See Also
1888+
--------
1889+
- [`numpy.can_cast`][] : NumPy equivalent function
1890+
"""
1891+
try:
1892+
from_ = np.dtype(from_)
1893+
except TypeError:
1894+
from_ = from_.dtype
1895+
1896+
return np.can_cast(from_, to, casting=casting)
1897+
1898+
18571899
def outer(a, b, out=None):
18581900
"""
18591901
Return outer product of two sparse arrays.

0 commit comments

Comments
 (0)