
ENH: special.logsumexp: add array API support #20935

Closed
wants to merge 4 commits

Conversation

@mdhaber (Contributor) commented Jun 10, 2024

Reference issue

Toward gh-20930

What does this implement/fix?

Adds array API support to special.logsumexp. I'll need this for integrate.qmc_quad.

@lucascolley @j-bowhay the xp_broadcast_promote utility should replace the need for promoting dtypes manually.
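
For illustration, a minimal sketch of what such a broadcast-and-promote helper does, written with generic array API calls (this is not SciPy's actual xp_broadcast_promote, whose signature and None/writeability handling differ):

```python
def broadcast_promote_sketch(xp, *arrays):
    # Promote all inputs to a common dtype, then broadcast them to a common
    # shape, so downstream code needs no per-argument dtype/shape handling.
    arrays = [xp.asarray(a) for a in arrays]
    dtype = xp.result_type(*arrays)
    arrays = [xp.astype(a, dtype) for a in arrays]
    return xp.broadcast_arrays(*arrays)

# e.g. with an array API namespace such as numpy >= 2.0:
# import numpy as xp
# a, b = broadcast_promote_sketch(xp, [[1, 2, 3]], 2.0)
```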

@mdhaber mdhaber added the scipy.special and array types labels Jun 10, 2024
@mdhaber mdhaber added the enhancement label Jun 10, 2024
@mdhaber mdhaber changed the title to ENH: special.logsumexp: add array API support Jun 10, 2024
@j-bowhay (Member) commented

As an aside, I realised we have not been keeping https://scipy.github.io/devdocs/dev/api-dev/array_api.html#currently-supported-functionality up to date. Do we want to try to get a doc update into 1.14?

@mdhaber (Contributor, Author) commented Jun 10, 2024

Oops. Clearly I knew that existed, but I forgot. Yes, would you like to do that, or should I? The offerings as of 1.14 are in the release notes. Update: see gh-20936.

@lucascolley (Member) left a comment:

Just one comment, otherwise LGTM!

@@ -93,43 +96,49 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
1.6094379124341005, 1.6094379124341005
"""
a = _asarray_validated(a, check_finite=False)

Member:

I don't think we can quite replace _asarray_validated with the checks in array_namespace. Specifically, the sparse_ok=False part. I do think it is worth incorporating _asarray_validated into array_namespace, though; that should just mean adding a sparse_ok kwarg that gets passed to compliance_scipy, I think.
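
For illustration, a rough sketch of that suggestion; the names mirror SciPy's internal helpers, but the signatures and bodies here are assumptions, not the actual code:

```python
from scipy import sparse

def compliance_scipy(arrays, sparse_ok=False):  # hypothetical signature
    # Reject input types that should not silently pass through.
    for a in arrays:
        if not sparse_ok and sparse.issparse(a):
            raise TypeError("Sparse arrays/matrices are not supported.")
        # ... existing checks (masked arrays, object dtype, ...) ...
    return arrays

def array_namespace(*arrays, sparse_ok=False):  # hypothetical signature
    arrays = compliance_scipy(arrays, sparse_ok=sparse_ok)
    # ... determine and return the array API namespace as before ...
```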

@mdhaber (Contributor, Author) commented Jun 11, 2024:

Adding a sparse check to array_namespace sounds like a good idea. If the goal of those checks is to exclude common array-like things that we don't want to support, I think we'd at least need to exclude sparse matrices (always).

IIUC eventually sparse arrays will have an array API compliant interface. That would mean they could be added into the array_api_compliant test matrix and supported properly. I suppose there is still a question of whether they should be allowed if the algorithm does not take advantage of the sparsity, and that will affect the default.

Can I open an issue about this instead? I think it is relevant to a lot of other functions (e.g. stats functions never used _asarray_validated, but I think all the same reasoning would apply). It will also take a few opinions to resolve, but it would be good to have access to _broadcast_promote in the meantime.

Member:

> IIUC eventually sparse arrays will have an array API compliant interface

Not necessarily. My understanding is "let's try as far as possible, but some things might not make sense" for scipy.sparse (gh-18915). I think pydata/sparse is planning to have a compliant interface, though.

> Can I open an issue about this instead? I think it is relevant to a lot of other functions (e.g. stats functions never used _asarray_validated, but I think all the same reasoning would apply).

Sure, I just wanted to avoid a regression where sparse matrices were rejected but now pass through. That doesn't seem too problematic though, as long as it is on your radar 👍.

An issue already exists, gh-18972, and I commented about this topic there last year.

@mdhaber (Contributor, Author) commented Jun 11, 2024:

> sparse matrices were rejected but now pass through

They're still rejected, just with a slightly less obvious error.

from scipy import sparse, special
x = sparse.csc_array([[1, 2, 3, 0]])  # same with csc_matrix
special.logsumexp(x, axis=-1)
# TypeError: ufunc 'isfinite' not supported for the input types...

Yeah, this will be on my radar, and I see it's already tracked. Thanks!

@lucascolley lucascolley added this to the 1.15.0 milestone Jun 11, 2024
a, b = xp_broadcast_promote(a, b, ensure_writeable=True)
axis = tuple(range(a.ndim)) if axis is None else axis

if b is not None and xp.any(b==0):

Member:

How important is it to support staged-out computation as in JAX? Conditioning control flow on array contents (e.g. if xp.any(b==0)) will fail for staged-out operations.
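
For context, a minimal sketch of the failure mode with a generic jax.jit function (not SciPy's code): under tracing, jnp.any(b == 0) is an abstract tracer, and using it in Python control flow raises.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scale_if_no_zeros(b):
    # A Python-level `if` on a traced value forces concretization, which
    # raises (e.g. TracerBoolConversionError) because the value is unknown
    # at trace time.
    if jnp.any(b == 0):
        return jnp.zeros_like(b)
    return 2.0 * b

scale_if_no_zeros(jnp.array([1.0, 0.0]))  # raises at trace time
```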

Contributor (Author):

Suggested change:

-    if b is not None and xp.any(b==0):
+    if b is not None:

should be fine.
That's OK with me.


-    if a_max.ndim > 0:
-        a_max[~np.isfinite(a_max)] = 0
-    elif not np.isfinite(a_max):
+    a_max[~xp.isfinite(a_max)] = 0

Member:

To support libraries with immutable arrays, you could write this as

a_max = xp.where(xp.isfinite(a_max), a_max, 0)

although this will be less memory-efficient in the case of NumPy. Perhaps this is a good case for an array-compat utility?

@mdhaber (Contributor, Author) commented Jun 11, 2024:

Right, this has come up several times, as recently as yesterday - data-apis/array-api-compat#144 (comment). I think we need a function that mutates if possible and copies otherwise, but IIRC others (#20085 (comment)) wanted to wait for something before adding a utility. Letting JAX fail for now will help us find the places where the utility can be used when it exists.
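
A hedged sketch of the kind of helper being discussed, with a hypothetical name (not an existing SciPy or array-api-compat function):

```python
def fill_nonfinite(x, fill, xp):
    """Replace non-finite entries of x with `fill`: in place when the
    backend allows mutation, otherwise via an out-of-place xp.where."""
    mask = xp.isfinite(x)
    try:
        x[~mask] = fill  # NumPy-like backends: mutate, no extra copy
        return x
    except TypeError:    # immutable arrays (e.g. JAX) reject item assignment
        return xp.where(mask, x, xp.asarray(fill, dtype=x.dtype))
```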

@mdhaber (Contributor, Author) commented Jun 11, 2024:

For now, we could also dispatch to jax.scipy.special.logsumexp if it passes the tests. I'll try that.

Still, it might be better to just exempt JAX from testing. Ultimately, it would be nice not to special-case JAX, and the exemption in the test suite reminds us to check on JAX when we have the utility.
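
A rough sketch of that dispatch idea (jax.scipy.special.logsumexp is JAX's own implementation; the wrapper below is illustrative, not the PR's code):

```python
def logsumexp_dispatch(a, axis=None, b=None, keepdims=False, return_sign=False):
    # Delegate to JAX's own implementation for JAX inputs.
    try:
        import jax
        is_jax_array = isinstance(a, jax.Array)
    except ImportError:
        is_jax_array = False
    if is_jax_array:
        from jax.scipy.special import logsumexp as jax_logsumexp
        return jax_logsumexp(a, axis=axis, b=b, keepdims=keepdims,
                             return_sign=return_sign)
    # Other backends would use the generic array API implementation
    # (omitted from this sketch).
    raise NotImplementedError
```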

Member:

> wanted to wait for something

It's not clear to me what we're waiting for?

Contributor (Author):

Let's bring this discussion back to #20085 (comment).

@lucascolley (Member) left a comment:

Just realised that we need to add these tests to the CI job! I was wondering how some of the indexing was "passing" with JAX 😅

@mdhaber (Contributor, Author) commented Jun 11, 2024

Oops. Will do. I should know by now that if an array API PR passes in the first run, something is definitely wrong.

@mdhaber (Contributor, Author) commented Jun 11, 2024

It will be easier to review the latest commit separately with "Hide whitespace" on, since I moved all the logsumexp tests into a class.

@mdhaber mdhaber mentioned this pull request Jun 11, 2024
@rgommers (Member) left a comment:

Thanks @mdhaber. Before looking in detail at this, the immediate question I have is: why not add this to _support_alternative_backends.py? PyTorch, JAX and CuPy all have a logsumexp in their special namespaces. And if we'd like array-api-strict and/or array-api-compat to support it, we should move data-apis/array-api#596 forward.

I'm a little uncomfortable with how many very large diffs are coming through for array API/types support, and this is one function where I'd expect that if we do it right the diff (modulo boilerplate test changes) could be very small instead.

@mdhaber mdhaber closed this Jun 11, 2024
@mdhaber (Contributor, Author) commented Jun 11, 2024

As the author of _support_alternative_backends, naturally, that was my first thought. After further consideration, I went this way because a) torch.logsumexp doesn't support weights or complex input, and b) this could be expressed rather simply in terms of array API calls, so converting this one implementation would allow all array API standard libraries to share an implementation. The cost is a handful of changes from NumPy-specific calls to array API calls.

As for moving data-apis/array-api#596 forward, that sounds great. But it's over a year old, and it will probably take at least another year for that implementation to become available in SciPy given that standards seem to be released annually and 2023 standard support still isn't in array_api_compat. And logsumexp is one of those functions that others depend on, so it's holding up other work. We could always change the implementation if and when it becomes available in the standard.

The diff actually is quite readable if you were to take a look. It would be smaller and more boilerplate if not for the array API standard lacking some simple things. For example, xp.max doesn't have an initial argument, so I worked around it. xp.real doesn't work on real dtypes, and that adds another line. And the input validation of this function was interspersed throughout the implementation, so I replaced it with a general function that we can call at the top and re-use in other functions.
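
For illustration, hedged sketches of the two workarounds mentioned, written with generic array API calls (not the PR's exact code); here a, out, axis, and xp stand for the input array, an intermediate result, the axis argument, and the array namespace:

```python
# xp.max has no `initial=` keyword, so non-finite maxima (e.g. from slices
# that are all -inf) are cleaned up after the reduction instead:
a_max = xp.max(a, axis=axis, keepdims=True)
a_max = xp.where(xp.isfinite(a_max), a_max, xp.zeros_like(a_max))

# xp.real is only guaranteed for complex dtypes in the standard, so guard it:
out = xp.real(out) if xp.isdtype(out.dtype, 'complex floating') else out
```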

But I have plenty of other work to do, so I'll just go ahead and close this for now and put follow-up work on hold. If you have general concerns about all these conversions, consider setting up a meeting. I'm not sure how everything can be a one-for-one translation if the array API standard is not the same as NumPy, but maybe you will see a way to do it.

@rgommers (Member) commented:

Hey Matt, it was a genuine (and fairly obvious, since it was your first idea as well) question, not a blocking review, so no need to close this straight away.

It still isn't quite clear to me to what extent weights or complex support are needed. I had a quick look, and the b input doesn't seem to be used outside of tests/docs:

% rg logsumexp -g'!_lib/boost_math' -g'!special/tests'
special/_logsumexp.py
4:__all__ = ["logsumexp", "softmax", "log_softmax"]
7:def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
58:    NumPy has a logaddexp function which is very similar to `logsumexp`, but
65:    >>> from scipy.special import logsumexp
67:    >>> logsumexp(a)
76:    >>> logsumexp(a, b=b)
83:    >>> logsumexp([1,2],b=[1,-1],return_sign=True)
86:    Notice that `logsumexp` does not directly support masked arrays. To use it
92:    >>> logsumexp(a.data, b=b), np.log(5)
172:    The `softmax` function is the gradient of `logsumexp`.

special/meson.build
342:  '_logsumexp.py',

special/__init__.py
765:   logsumexp -- Compute the log of the sum of exponentials of input elements.
819:from ._logsumexp import logsumexp, softmax, log_softmax
848:    'logsumexp',

stats/_morestats.py
854:    return special.logsumexp(logx, axis=0) - np.log(len(logx))
861:    logxmu = special.logsumexp([logx, logmean + pij], axis=0)
862:    return np.real(special.logsumexp(2 * logxmu, axis=0)) - np.log(len(logx))

stats/_continuous_distns.py
4146:            return -scale * (sc.logsumexp(-data / scale) - np.log(len(data)))
9875:# logsumexp trick for log(p + q) with only log(p) and log(q)
9877:    return sc.logsumexp([log_p, log_q], axis=0)
9882:    return sc.logsumexp([log_p, log_q+np.pi*1j], axis=0)

stats/_distn_infrastructure.py
1596:            # here we could use logcdf w/ logsumexp trick to take differences,

stats/_discrete_distns.py
8:from scipy.special import entr, logsumexp, betaln, gammaln as gamln, zeta
683:                # Integration over probability mass function using logsumexp
685:                res.append(logsumexp(self._logpmf(k2, tot, good, draw)))
695:                # Integration over probability mass function using logsumexp
697:                res.append(logsumexp(self._logpmf(k2, tot, good, draw)))

stats/_stats.pyx
799:cdef real logsumexp(real a, real b):
875:                estimate[j, k] = logsumexp(estimate[j, k],

integrate/_tanhsinh.py
401:            Sn = (special.logsumexp([Snm1 - np.log(2), Sn], axis=0) if log
670:    Sn = (special.logsumexp(fjwj + np.log(work.h), axis=-1) if work.log
705:        Snm1 = (special.logsumexp(fjwj, **axis_kwargs) + np.log(h) if work.log
724:        Snm2 = (special.logsumexp(fjwj, **axis_kwargs) + np.log(h) if work.log
739:        d1 = np.real(special.logsumexp([work.Sn, Snm1 + work.pi*1j], axis=0))
740:        d2 = np.real(special.logsumexp([work.Sn, Snm2 + work.pi*1j], axis=0))
853:def _logsumexp(x, axis=0):
854:    # logsumexp raises with empty array
861:        return special.logsumexp(x, axis=axis)
1165:    S = _logsumexp(fs, axis=-1) if log else np.sum(fs, axis=-1)
1179:    tol = _logsumexp((tol, rtol + lb.integral)) if log else tol + rtol*lb.integral
1224:        S = _logsumexp(S_terms, axis=0)
1226:        E = _logsumexp(E_terms, axis=0).real

integrate/_quadrature.py
10:from scipy.special import gammaln, logsumexp
1290:            return logsumexp(integrands) + np.log(dA)
1296:            return logsumexp(estimates) - np.log(n_estimates)
1305:            diff = logsumexp(temp, axis=0)
1306:            return np.real(0.5 * (logsumexp(2 * diff)

For complex support it's a little harder to tell without auditing all the code. Both are probably not that hard to implement in PyTorch - just wondering how pressing the need is.

> The diff actually is quite readable if you were to take a look.

It's still ~300 lines with whitespace hidden, and it mixes in two separate refactors (input validation and tests in a class) in the same commit. If those refactors were separated out, this would probably look cleaner? Note that it's not me who is most concerned here - I'm very aware that there's a subset of maintainers who don't have a personal interest in array types support; they're fine with it as long as it doesn't get in the way too much. These refactors sound like perfectly fine code cleanups in isolation, but I'm worried that to the casual reader they seem necessary for array types support and make that look much more complex/churny than it really is.

> If you have general concerns about all these conversions, consider setting up a meeting.

Having a higher-bandwidth conversation more regularly sounds like a great idea. Are you able to join the community meeting today?
