Skip to content

Commit 185c956

Browse files
committed
WIP
1 parent afb5212 commit 185c956

File tree

8 files changed

+96
-27
lines changed

8 files changed

+96
-27
lines changed

sparse/numba_backend/_common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212

1313
from ._coo import as_coo
14+
from ._settings import SUPPORTED_ARRAY_TYPE
1415
from ._sparse_array import SparseArray
1516
from ._utils import (
1617
_zero_of_dtype,
@@ -30,6 +31,13 @@ def _is_scipy_sparse_obj(x):
3031
return bool(hasattr(x, "__module__") and x.__module__.startswith("scipy.sparse"))
3132

3233

34+
def _coerce_to_supported_dense(x) -> SUPPORTED_ARRAY_TYPE:
35+
if isinstance(x, SUPPORTED_ARRAY_TYPE):
36+
return x
37+
38+
return np.asarray(x)
39+
40+
3341
def _check_device(func):
3442
@wraps(func)
3543
def wrapped(*args, **kwargs):

sparse/numba_backend/_compressed/compressed.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ class GCXS(SparseArray, NDArrayOperatorsMixin):
132132

133133
__array_priority__ = 12
134134

135+
__array_members__ = ("data", "indices", "indptr")
136+
135137
def __init__(
136138
self,
137139
arg,

sparse/numba_backend/_coo/common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,15 @@ def asCOO(x, name="asCOO", check=True):
5555

5656

5757
def linear_loc(coords, shape):
58+
namespace = coords.__array_namespace__()
5859
if shape == () and len(coords) == 0:
5960
# `np.ravel_multi_index` is not aware of arrays, so cannot produce a
6061
# sensible result here (https://github.com/numpy/numpy/issues/15690).
6162
# Since `coords` is an array and not a sequence, we know the correct
6263
# dimensions.
63-
return np.zeros(coords.shape[1:], dtype=np.intp)
64+
return namespace.zeros(coords.shape[1:], dtype=namespace.intp)
6465

65-
return np.ravel_multi_index(coords, shape)
66+
return namespace.ravel_multi_index(coords, shape)
6667

6768

6869
def kron(a, b):

sparse/numba_backend/_coo/core.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ class COO(SparseArray, NDArrayOperatorsMixin): # lgtm [py/missing-equals]
195195

196196
__array_priority__ = 12
197197

198+
__array_members__ = ("data", "coords")
199+
198200
def __init__(
199201
self,
200202
coords,
@@ -207,6 +209,8 @@ def __init__(
207209
fill_value=None,
208210
idx_dtype=None,
209211
):
212+
from .._common import _coerce_to_supported_dense
213+
210214
if isinstance(coords, COO):
211215
self._make_shallow_copy_of(coords)
212216
if data is not None or shape is not None:
@@ -226,8 +230,8 @@ def __init__(
226230
self.enable_caching()
227231
return
228232

229-
self.data = np.asarray(data)
230-
self.coords = np.asarray(coords)
233+
self.data = _coerce_to_supported_dense(data)
234+
self.coords = _coerce_to_supported_dense(coords)
231235

232236
if self.coords.ndim == 1:
233237
if self.coords.size == 0 and shape is not None:
@@ -236,7 +240,7 @@ def __init__(
236240
self.coords = self.coords[None, :]
237241

238242
if self.data.ndim == 0:
239-
self.data = np.broadcast_to(self.data, self.coords.shape[1])
243+
self.data = self._component_namespace.broadcast_to(self.data, self.coords.shape[1])
240244

241245
if self.data.ndim != 1:
242246
raise ValueError("`data` must be a scalar or 1-dimensional.")
@@ -251,7 +255,9 @@ def __init__(
251255
shape = tuple(shape)
252256

253257
if shape and not self.coords.size:
254-
self.coords = np.zeros((len(shape) if isinstance(shape, Iterable) else 1, 0), dtype=np.intp)
258+
self.coords = self._component_namespace.zeros(
259+
(len(shape) if isinstance(shape, Iterable) else 1, 0), dtype=np.intp
260+
)
255261
super().__init__(shape, fill_value=fill_value)
256262
if idx_dtype:
257263
if not can_store(idx_dtype, max(shape)):
@@ -1307,10 +1313,10 @@ def _sort_indices(self):
13071313
"""
13081314
linear = self.linear_loc()
13091315

1310-
if (np.diff(linear) >= 0).all(): # already sorted
1316+
if (self._component_namespace.diff(linear) >= 0).all(): # already sorted
13111317
return
13121318

1313-
order = np.argsort(linear, kind="mergesort")
1319+
order = self._component_namespace.argsort(linear, kind="mergesort")
13141320
self.coords = self.coords[:, order]
13151321
self.data = self.data[order]
13161322

@@ -1336,16 +1342,16 @@ def _sum_duplicates(self):
13361342
# Inspired by scipy/sparse/coo.py::sum_duplicates
13371343
# See https://github.com/scipy/scipy/blob/main/LICENSE.txt
13381344
linear = self.linear_loc()
1339-
unique_mask = np.diff(linear) != 0
1345+
unique_mask = self._component_namespace.diff(linear) != 0
13401346

13411347
if unique_mask.sum() == len(unique_mask): # already unique
13421348
return
13431349

1344-
unique_mask = np.append(True, unique_mask)
1350+
unique_mask = self._component_namespace.append(True, unique_mask)
13451351

13461352
coords = self.coords[:, unique_mask]
1347-
(unique_inds,) = np.nonzero(unique_mask)
1348-
data = np.add.reduceat(self.data, unique_inds, dtype=self.data.dtype)
1353+
(unique_inds,) = self._component_namespace.nonzero(unique_mask)
1354+
data = self._component_namespace.add.reduceat(self.data, unique_inds, dtype=self.data.dtype)
13491355

13501356
self.data = data
13511357
self.coords = coords

sparse/numba_backend/_settings.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import importlib.util
12
import os
23

34
import numpy as np
@@ -17,4 +18,20 @@ def __array_function__(self, *args, **kwargs):
1718
return False
1819

1920

21+
def _supported_array_type() -> type[np.ndarray]:
22+
try:
23+
import cupy as cp
24+
25+
return np.ndarray | cp.ndarray
26+
except ImportError:
27+
return np.ndarray
28+
29+
30+
def _cupy_available() -> bool:
31+
return importlib.util.find_spec("cupy") is not None
32+
33+
2034
NEP18_ENABLED = _is_nep18_enabled()
35+
NUMPY_DEVICE = np.asarray(5).device
36+
SUPPORTED_ARRAY_TYPE = _supported_array_type()
37+
CUPY_AVAILABLE = _cupy_available()

sparse/numba_backend/_sparse_array.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import copy
23
import operator
34
import warnings
45
from abc import ABCMeta, abstractmethod
@@ -27,6 +28,7 @@ class SparseArray:
2728
"""
2829

2930
__metaclass__ = ABCMeta
31+
__array_members__: tuple[str, ...] = ()
3032

3133
def __init__(self, shape, fill_value=None):
3234
if not isinstance(shape, Iterable):
@@ -45,18 +47,54 @@ def __init__(self, shape, fill_value=None):
4547
else:
4648
self.fill_value = _zero_of_dtype(self.dtype)
4749

50+
self.device # noqa: B018
51+
4852
dtype = None
4953

5054
@property
5155
def device(self):
5256
data = getattr(self, "data", None)
53-
return getattr(data, "device", "cpu")
57+
device = getattr(data, "device", "cpu")
58+
assert all(getattr(self, m).device == device for m in self.__array_members__)
59+
return device
60+
61+
@property
62+
def _component_namespace(self):
63+
data = getattr(self, "data", None)
64+
namespace = getattr(data, "__array_namespace__", np)
65+
if namespace is not np:
66+
namespace = namespace()
67+
68+
assert all(getattr(self, m).__array_namespace__() == namespace for m in self.__array_members__)
69+
return namespace
5470

5571
def to_device(self, device, /, *, stream=None):
56-
if device != "cpu":
57-
raise ValueError("Only `device='cpu'` is supported.")
72+
if stream is not None:
73+
raise NotImplementedError("Only `stream=None` is supported at the moment.")
74+
75+
if device == self.device:
76+
return self
77+
78+
import cupy as cp
79+
80+
from ._settings import NUMPY_DEVICE
81+
82+
self_copy = copy.copy(self)
83+
if device == NUMPY_DEVICE:
84+
for member_name in self.__array_members__:
85+
member_array_gpu = getattr(self, member_name)
86+
member_array_cpu = cp.asnumpy(member_array_gpu)
87+
setattr(self_copy, member_array_cpu)
88+
89+
return self_copy
90+
91+
for member_name in self.__array_members__:
92+
member_array_source = getattr(self, member_name)
93+
with cp.cuda.Device(device):
94+
member_array_dest = cp.asarray(member_array_source)
95+
setattr(self_copy, member_array_dest)
5896

59-
return self
97+
return self_copy
6098

6199
@property
62100
@abstractmethod

sparse/numba_backend/_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,11 @@ def equivalent(x, y, /, loose=False):
431431
>>> equivalent(np.float64(0.0), np.float64(-0.0))
432432
np.False_
433433
"""
434-
x = np.asarray(x)
435-
y = np.asarray(y)
434+
from ._common import _coerce_to_supported_dense
435+
436+
x = _coerce_to_supported_dense(x)
437+
y = _coerce_to_supported_dense(y)
438+
namespace = x.__array_namespace__()
436439
# Can't contain NaNs
437440
dt = np.result_type(x.dtype, y.dtype)
438441
if not any(np.issubdtype(dt, t) for t in [np.floating, np.complexfloating]):
@@ -446,9 +449,9 @@ def equivalent(x, y, /, loose=False):
446449
return (x == y) | ((x != x) & (y != y))
447450

448451
if x.size == 0 or y.size == 0:
449-
shape = np.broadcast_shapes(x.shape, y.shape)
450-
return np.empty(shape, dtype=np.bool_)
451-
x, y = np.broadcast_arrays(x[..., None], y[..., None])
452+
shape = namespace.broadcast_shapes(x.shape, y.shape)
453+
return namespace.empty(shape, dtype=np.bool_)
454+
x, y = namespace.broadcast_arrays(x[..., None], y[..., None])
452455
return (x.astype(dt).view(np.uint8) == y.astype(dt).view(np.uint8)).all(axis=-1)
453456

454457

sparse/numba_backend/tests/test_coo.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,9 +1913,3 @@ def test_to_device():
19131913
s2 = s.to_device(s.device)
19141914

19151915
assert s is s2
1916-
1917-
1918-
def test_to_invalid_device():
1919-
s = sparse.random((5, 5), density=0.5)
1920-
with pytest.raises(ValueError, match=r"Only .* is supported."):
1921-
s.to_device("invalid_device")

0 commit comments

Comments
 (0)