-
-
Notifications
You must be signed in to change notification settings - Fork 132
Move to Array API version 2023.12. #696
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
import sparse.numba_backend._info as _info | ||
|
||
from numpy import ( | ||
add, | ||
bitwise_and, | ||
|
@@ -9,6 +11,7 @@ | |
complex64, | ||
complex128, | ||
conj, | ||
copysign, | ||
cos, | ||
cosh, | ||
divide, | ||
|
@@ -23,6 +26,7 @@ | |
floor_divide, | ||
greater, | ||
greater_equal, | ||
hypot, | ||
iinfo, | ||
inf, | ||
int8, | ||
|
@@ -41,6 +45,8 @@ | |
logical_not, | ||
logical_or, | ||
logical_xor, | ||
maximum, | ||
minimum, | ||
multiply, | ||
nan, | ||
negative, | ||
|
@@ -50,6 +56,7 @@ | |
positive, | ||
remainder, | ||
sign, | ||
signbit, | ||
sin, | ||
sinh, | ||
sqrt, | ||
|
@@ -119,6 +126,7 @@ | |
std, | ||
sum, | ||
tensordot, | ||
unstack, | ||
var, | ||
vecdot, | ||
zeros, | ||
|
@@ -157,10 +165,16 @@ | |
where, | ||
) | ||
from ._dok import DOK | ||
from ._info import capabilities, default_device, default_dtypes, devices, dtypes | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we shouldn't import them to the main namespace. Let's keep them in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current array API tests require these in the main namespace as well. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm then I think it's a bug in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Array API standard doesn't say they need to be in the main namespace: https://data-apis.org/array-api/latest/API_specification/inspection.html#inspection-apis There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Demonstrated the failure in 42c9bfd, let's wait for |
||
from ._io import load_npz, save_npz | ||
from ._umath import elemwise | ||
from ._utils import random | ||
|
||
|
||
def __array_namespace_info__(): | ||
return _info | ||
|
||
|
||
__all__ = [ | ||
"COO", | ||
"DOK", | ||
|
@@ -196,19 +210,25 @@ | |
"broadcast_arrays", | ||
"broadcast_to", | ||
"can_cast", | ||
"capabilities", | ||
"ceil", | ||
"clip", | ||
"complex128", | ||
"complex64", | ||
"concat", | ||
"concatenate", | ||
"conj", | ||
"copysign", | ||
"cos", | ||
"cosh", | ||
"default_device", | ||
"default_dtypes", | ||
"devices", | ||
"diagonal", | ||
"diagonalize", | ||
"divide", | ||
"dot", | ||
"dtypes", | ||
"e", | ||
"einsum", | ||
"elemwise", | ||
|
@@ -230,6 +250,7 @@ | |
"full_like", | ||
"greater", | ||
"greater_equal", | ||
"hypot", | ||
"iinfo", | ||
"imag", | ||
"inf", | ||
|
@@ -258,8 +279,10 @@ | |
"matmul", | ||
"matrix_transpose", | ||
"max", | ||
"maximum", | ||
"mean", | ||
"min", | ||
"minimum", | ||
"moveaxis", | ||
"multiply", | ||
"nan", | ||
|
@@ -291,6 +314,7 @@ | |
"round", | ||
"save_npz", | ||
"sign", | ||
"signbit", | ||
"sin", | ||
"sinh", | ||
"sort", | ||
|
@@ -314,6 +338,7 @@ | |
"uint8", | ||
"unique_counts", | ||
"unique_values", | ||
"unstack", | ||
"var", | ||
"vecdot", | ||
"where", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -155,7 +155,7 @@ | |
bs = b.shape | ||
ndb = b.ndim | ||
equal = True | ||
if nda == 0 or ndb == 0: | ||
if not (builtins.all(-nda <= ax < nda for ax in axes_a) and builtins.all(-ndb <= ax < ndb for ax in axes_b)): | ||
pos = int(nda != 0) | ||
raise ValueError(f"Input {pos} operand does not have enough dimensions") | ||
if na != nb: | ||
|
@@ -2146,10 +2146,22 @@ | |
return x.reshape(shape=shape) | ||
|
||
|
||
def astype(x, dtype, /, *, copy=True): | ||
@_check_device | ||
def astype(x, dtype, /, *, copy=True, device=None): | ||
return x.astype(dtype, copy=copy) | ||
|
||
|
||
def unstack(x, /, *, axis=0): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||
axis = normalize_axis(axis, x.ndim) | ||
out = [] | ||
|
||
for i in range(x.shape[axis]): | ||
idx = (slice(None),) * axis + (i,) | ||
out.append(x[idx]) | ||
|
||
return tuple(out) | ||
|
||
|
||
@_support_numpy | ||
def squeeze(x, /, axis=None): | ||
"""Remove singleton dimensions from array. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we want to test these functions. Let's add simple tests, like in NumPy: https://github.com/numpy/numpy/pull/26572/files#diff-db073cec9b943fac08cf9720c471d90dcbb7a0e00f4717433314ec95bee60fe2 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import numpy as np | ||
|
||
from ._common import _check_device | ||
|
||
__all__ = [ | ||
"capabilities", | ||
"default_device", | ||
"default_dtypes", | ||
"devices", | ||
"dtypes", | ||
] | ||
|
||
_CAPABILITIES = { | ||
"boolean indexing": True, | ||
"data-dependent shapes": True, | ||
} | ||
|
||
_DEFAULT_DTYPES = { | ||
"cpu": { | ||
"real floating": np.dtype(np.float64), | ||
"complex floating": np.dtype(np.complex128), | ||
"integral": np.dtype(np.int64), | ||
"indexing": np.dtype(np.int64), | ||
} | ||
} | ||
|
||
|
||
def _get_dtypes_with_prefix(prefix: str): | ||
out = set() | ||
for a in np.__all__: | ||
if not a.startswith(prefix): | ||
continue | ||
try: | ||
dt = np.dtype(getattr(np, a)) | ||
out.add(dt) | ||
except (ValueError, TypeError, AttributeError): | ||
pass | ||
return sorted(out) | ||
|
||
|
||
_DTYPES = { | ||
"cpu": { | ||
"bool": [np.bool_], | ||
"signed integer": _get_dtypes_with_prefix("int"), | ||
"unsigned integer": _get_dtypes_with_prefix("uint"), | ||
"real floating": _get_dtypes_with_prefix("float"), | ||
"complex floating": _get_dtypes_with_prefix("complex"), | ||
} | ||
} | ||
|
||
for _dtdict in _DTYPES.values(): | ||
_dtdict["integral"] = _dtdict["signed integer"] + _dtdict["unsigned integer"] | ||
_dtdict["numeric"] = _dtdict["integral"] + _dtdict["real floating"] + _dtdict["complex floating"] | ||
|
||
del _dtdict | ||
|
||
|
||
def capabilities(): | ||
return _CAPABILITIES | ||
|
||
|
||
def default_device(): | ||
return "cpu" | ||
|
||
|
||
@_check_device | ||
def default_dtypes(*, device=None): | ||
if device is None: | ||
device = default_device() | ||
return _DEFAULT_DTYPES[device] | ||
|
||
|
||
def devices(): | ||
return ["cpu"] | ||
|
||
|
||
@_check_device | ||
def dtypes(*, device=None, kind=None): | ||
if device is None: | ||
device = default_device() | ||
|
||
device_dtypes = _DTYPES[device] | ||
|
||
if kind is None: | ||
return device_dtypes | ||
|
||
if isinstance(kind, str): | ||
return device_dtypes[kind] | ||
|
||
out = {} | ||
|
||
for k in kind: | ||
out[k] = device_dtypes[k] | ||
|
||
return out | ||
Uh oh!
There was an error while loading. Please reload this page.