ENH: special.logsumexp: add array API support #20935
As an aside, I realised we have not been keeping https://scipy.github.io/devdocs/dev/api-dev/array_api.html#currently-supported-functionality up to date. Do we want to try to get a doc update into 1.14?
Oops.
just one comment otherwise LGTM!
@@ -93,43 +96,49 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
1.6094379124341005, 1.6094379124341005
"""
a = _asarray_validated(a, check_finite=False)
I don't think we can quite replace _asarray_validated with the checks in array_namespace. Specifically, the sparse_ok=False part. I do think that it is worth incorporating _asarray_validated into array_namespace though, which should just mean adding a sparse_ok kwarg which gets passed to compliance_scipy, I think.
Adding a sparse check to array_namespace sounds like a good idea. If the goal of those checks is to exclude common array-like things that we don't want to support, I think we'd at least need to exclude sparse matrices (always).
IIUC eventually sparse arrays will have an array API compliant interface. That would mean they could be added into the array_api_compliant test matrix and supported properly. I suppose there is still a question of whether they should be allowed if the algorithm does not take advantage of the sparsity, and that will affect the default.
Can I open an issue about this instead? I think it is relevant to a lot of other functions (e.g. stats functions never used _asarray_validated, but I think all the same reasoning would apply). It will also take a few opinions to resolve, but it would be good to have access to _broadcast_promote in the meantime.
IIUC eventually sparse arrays will have an array API compliant interfaces
Not necessarily. My understanding is "let's try as far as possible, but some things might not make sense", for scipy.sparse (gh-18915). I think pydata/sparse is planning to have a compliant interface though.
Can I open an issue about this instead? I think it is relevant to a lot of other functions (e.g. stats functions never used _asarray_validated, but I think all the same reasoning would apply.
Sure, I was just wanting to avoid a regression, where sparse matrices were rejected but now pass through. That doesn't seem too problematic though, so as long as it is on your radar 👍.
An issue already exists, gh-18972, and I commented about this topic there last year.
sparse matrices were rejected but now pass through
They're still rejected, just with a little less obvious error.
from scipy import sparse, special
x = sparse.csc_array([[1, 2, 3, 0]]) # same with csc_matrix
special.logsumexp(x, axis=-1)
# TypeError: ufunc 'isfinite' not supported for the input types...
Yeah, this will be on my radar, and I see it's already tracked. Thanks!
scipy/special/_logsumexp.py
a, b = xp_broadcast_promote(a, b, ensure_writeable=True)
axis = tuple(range(a.ndim)) if axis is None else axis

if b is not None and xp.any(b==0):
How important is it to support staged-out computation as in JAX? Conditioning control flow on array contents (e.g. if xp.any(b==0)) will fail for staged-out operations.
- if b is not None and xp.any(b==0):
+ if b is not None:
should be fine.
That's OK with me.
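The suggestion above removes the only branch that depends on array values. A minimal NumPy sketch of why dropping the check is harmless, assuming the guarded block only masks the entries of a where b is zero (both helper names are hypothetical and are not taken from the PR):

```python
import numpy as np

def mask_zero_weights_branching(a, b):
    """Apply the mask only if some weight is zero (branches on array contents)."""
    a = np.asarray(a, dtype=float)
    if np.any(b == 0):  # value-dependent branch: not available when values are staged out
        a = np.where(b == 0, -np.inf, a)
    return a

def mask_zero_weights_branch_free(a, b):
    """Always apply the mask; it is a no-op when no weight is zero."""
    a = np.asarray(a, dtype=float)
    return np.where(b == 0, -np.inf, a)

a = np.array([1.0, 2.0, 3.0])
for b in (np.array([1.0, 0.0, 2.0]), np.array([1.0, 1.0, 1.0])):
    # Both variants agree whether or not b contains a zero.
    assert np.array_equal(mask_zero_weights_branching(a, b),
                          mask_zero_weights_branch_free(a, b))
```

The branch-free form does a little extra work when b has no zeros, but it never needs to inspect concrete values in Python.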
if a_max.ndim > 0:
a_max[~np.isfinite(a_max)] = 0
elif not np.isfinite(a_max):
a_max[~xp.isfinite(a_max)] = 0
To support libraries with immutable arrays, you could write this as
a_max = xp.where(xp.isfinite(a_max), a_max, 0)
although this will be less memory-efficient in the case of NumPy. Perhaps this is a good case for an array-compat utility?
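The two forms can be compared directly in NumPy. This is only a sketch with made-up values; it shows that the where form gives the same result while leaving the input untouched, at the cost of allocating a new array.

```python
import numpy as np

a_max = np.array([1.5, np.inf, -np.inf, np.nan])

# In-place form: requires a mutable array.
mutated = a_max.copy()
mutated[~np.isfinite(mutated)] = 0

# Functional form: also works for immutable arrays, but allocates a new array.
replaced = np.where(np.isfinite(a_max), a_max, 0)

# Both forms replace inf, -inf and nan with 0 and keep finite values.
assert np.array_equal(mutated, replaced)
```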
Right, this has come up several times, as recently as yesterday - data-apis/array-api-compat#144 (comment). I think we need a function that mutates if possible and copies otherwise, but IIRC others (#20085 (comment)) wanted to wait for something before adding a utility. Letting JAX fail for now will help us find the places where the utility can be used when it exists.
For now, we could also dispatch to jax.scipy.special.logsumexp if it passes the tests. I'll try that.
Still, it might be better to just exclude JAX from testing. Ultimately, it would be nice not to special-case JAX, and the exception in the test suite reminds us to check on JAX when we have the utility.
wanted to wait for something
It's not clear to me what we're waiting for?
cc @rgommers
Let's bring this discussion back to #20085 (comment).
just realised that we need to add these tests to the CI job! was wondering how some of the indexing was "passing" with JAX 😅
Oops. Will do. I should know by now that if an array API PR passes in the first run, something is definitely wrong.
It will be easier to review the latest commit separately with "Hide whitespace" on, since I moved all the
[skip cirrus] [skip circle]
Thanks @mdhaber. Before looking in detail at this, the immediate question I have is: why not add this to _support_alternative_backends.py? PyTorch, JAX and CuPy all have a logsumexp in their special namespaces. And if we'd like array-api-strict and/or array-api-compat to support it, we should move data-apis/array-api#596 forward.
I'm a little uncomfortable with how many very large diffs are coming through for array API/types support, and this is one function where I'd expect that if we do it right the diff (modulo boilerplate test changes) could be very small instead.
As the author of
As for moving data-apis/array-api#596 forward, that sounds great. But it's over a year old, and it will probably take at least another year for that implementation to become available in SciPy given that standards seem to be released annually and 2023 standard support still isn't in
The diff actually is quite readable if you were to take a look. It would be smaller and more boilerplate but for the array API standard not supporting simple things. For example,
But I have plenty of other work to do, so I'll just go ahead and close this for now and put follow-up work on hold. If you have general concerns about all these conversions, consider setting up a meeting. I'm not sure how everything can be one-for-one translations if the array API standard is not the same as NumPy, but maybe you will see a way to do it.
Hey Matt, it was a genuine (and fairly obvious, since it was your first idea as well) question, not a blocking review, so no need to close this straight away. It still isn't quite clear to me to what extent weights or complex support are needed. I had a quick look, and the
For complex support it's a little harder to tell without auditing all the code. Both are probably not that hard to implement in PyTorch - just wondering how pressing the need is.
It's still ~300 lines with whitespace hidden, and mixes in two separate refactors (input validation and tests in a class) in the same commit. If those refactors were separated out, this would probably look cleaner? Note that it's not that it's me who is most concerned here - I'm very aware that there's a subset of maintainers who don't have a personal interest in array types support; they're fine with it as long as it doesn't get in the way too much. These refactors sound like perfectly fine code cleanups in isolation, but I'm worried that to the casual reader they seem necessary for array types support and make that look much more complex/churny than it really is.
Having a higher-bandwidth conversation more regularly sounds like a great idea. Are you able to join the community meeting today?
Reference issue
Toward gh-20930
What does this implement/fix?
Adds array API support to special.logsumexp. I'll need this for integrate.qmc_quad.
@lucascolley @j-bowhay the xp_broadcast_promote utility should replace the need for promoting dtypes manually.
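For readers unfamiliar with the idea, combined broadcasting and dtype promotion can be sketched in plain NumPy. The helper name broadcast_promote_sketch is hypothetical, and the sketch does not reproduce the signature or behaviour of the xp_broadcast_promote utility in this PR. It assumes that integer inputs should be promoted to a floating type and that None arguments pass through unchanged.

```python
import numpy as np

def broadcast_promote_sketch(*arrays):
    """Broadcast inputs to a common shape and cast them to a common dtype.

    Hypothetical illustration only. None entries are returned as None.
    """
    present = [np.asarray(x) for x in arrays if x is not None]
    # Common dtype of all inputs; including float32 forces at least a floating type.
    dtype = np.result_type(*present, np.float32)
    shape = np.broadcast_shapes(*(x.shape for x in present))
    out = []
    for x in arrays:
        if x is None:
            out.append(None)
        else:
            # broadcast_to returns a read-only view; astype copies by default,
            # so the result is a fresh, writeable array.
            out.append(np.broadcast_to(np.asarray(x), shape).astype(dtype))
    return out

a, b = broadcast_promote_sketch(np.array([[1, 2, 3]]), np.array([[0.5], [2.0]]))
assert a.shape == b.shape == (2, 3)
assert a.dtype == b.dtype == np.float64
```

Doing both steps in one helper means the calling function can index into or assign to the results without first checking shapes and dtypes argument by argument.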