
Path for Adopting the Array API spec #22352

Open
thomasjpfan opened this issue Feb 1, 2022 · 31 comments

Comments

@thomasjpfan
Member

thomasjpfan commented Feb 1, 2022

I have been experimenting with adopting the Array API spec into scikit-learn. The Array API is one way for scikit-learn to run on other hardware such as GPUs.

I have some POCs on my fork for LinearDiscriminantAnalysis and GaussianMixture. Overall, there is a runtime performance benefit when running on CuPy compared to NumPy, as shown in the notebooks for LDA (14x improvement) and GMM (7x improvement).

Official acceptance of the Array API in numpy is tracked as NEP 47.

Proposed User API

Here is the proposed API for dispatching. We require the array to adopt the Array API standard and we have a configuration option to turn on Array API dispatching:

# Create Array API arrays following spec
import cupy.array_api as xp
X_cu = xp.asarray(X_np)
y_cu = xp.asarray(y_np)

# Configure scikit-learn to dispatch
from sklearn import set_config
set_config(array_api_dispatch=True)

# Dispatches using `array_api`
lda_cu = LinearDiscriminantAnalysis()
lda_cu.fit(X_cu, y_cu)

This way the user can decide between the old behavior of potentially casting to NumPy and the new behavior of using the Array API if available.

Developer Experience

The Array API spec and the NumPy API overlap in many cases, but there are APIs we use in NumPy that are not in the Array API. There are a few ways to bridge this gap while keeping the code base maintainable:

  1. Wrap the Array-API namespace object to make it look "more like NumPy"
  2. Wrap the NumPy module to make it look "more like Array API"
  3. Helper functions everywhere

1 and 2 are not mutually exclusive. To demonstrate these options, I'll do a case study on unique. The Array API spec does not define a unique function; it defines unique_values instead.

Wrap the Array-API namespace object to make it look "more like NumPy"

def check_y(y):
    np, _ = get_namespace(y)  # Returns _ArrayAPIWrapper or NumPy
    classes = np.unique(y)

class _ArrayAPIWrapper:
    def __init__(self, array_namespace):
        self._array_namespace = array_namespace

    def unique(self, x):
        # Map the NumPy-style name to the spec's unique_values
        return self._array_namespace.unique_values(x)

Existing scikit-learn code does not need to change as much because the Array API "looks like NumPy"
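
For context, a minimal sketch of what such a get_namespace helper could look like under this option (hypothetical code, not the actual POC implementation):

import numpy as np

def get_namespace(*arrays):
    # If any input advertises an Array API namespace, wrap that namespace so
    # the missing NumPy-style functions (e.g. `unique`) are filled in;
    # otherwise fall back to plain NumPy.
    for x in arrays:
        if hasattr(x, "__array_namespace__"):
            return _ArrayAPIWrapper(x.__array_namespace__()), True
    return np, False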

Make NumPy object "look more like Array-API"

import numpy as np

def check_y(y):
    xp, _ = get_namespace(y)  # Returns Array API namespace or _NumPyApiWrapper
    classes = xp.unique_values(y)

class _NumPyApiWrapper:
    def unique_values(self, x):
        # Expose the spec's name on top of plain NumPy
        return np.unique(x)

We need to update scikit-learn to use these new functions from the Array API spec.

Helper functions everywhere

import numpy as np

def check_y(y):
    classes = _unique_values(y)

def _unique_values(x):
    xp, is_array_api = get_namespace(x)
    if is_array_api:
        return xp.unique_values(x)
    return np.unique(x)

We need to update scikit-learn to use these helper functions wherever the APIs diverge. Some notable functions that need a wrapper or helper include concat, astype, asarray, unique, errstate, may_share_memory, etc.

For my POCs, I mostly went with option 1, wrapping the Array API to look like NumPy. (I did wrap NumPy once, to get np.dtype, which is the same as array.astype.)

CC @scikit-learn/core-devs

Other API considerations
  • Type promotion is stricter with the Array API
import numpy.array_api as xp
X = xp.asarray([1])
y = xp.asarray([1.0])

# fails with a TypeError: mixed integer/floating-point promotion is not defined
X + y
  • No method chaining. (Array API arrays do not have most of NumPy's array methods, such as mean or any)
(X.mean(axis=1) > 1.0).any()

# becomes
xp.any(xp.mean(X, axis=1) > 1.0)
github-actions bot added the Needs Triage (Issue requires triage) label on Feb 1, 2022
@thomasjpfan thomasjpfan changed the title Path for Adopting Array API spec Path for Adopting the Array API spec Feb 1, 2022
@GaelVaroquaux
Member

Thanks @thomasjpfan, this is definitely a direction that I feel is important for scikit-learn, and your summary is super useful. The speed-ups are quite interesting (factors of roughly 8 to 10).

I have a few questions / comments:

  1. Do you think that in the long run, we could get away without the config flag "set_config(array_api_dispatch=True)"?
  2. The fact that numpy and the array API spec diverge on things such as unique seems like a bug in my eyes. Is it felt this way by the upstream actors?
  3. Likewise, the impossibility of supporting Dask and JAX is quite a hard setback

I'm totally enthusiastic about the performance prospects, but we need to recognize that it is going to make coding algorithms significantly harder. The reasonable route will probably be that supporting this is not mandatory when adding a method, and that we strive to improve as we go.

@ogrisel
Member

ogrisel commented Feb 1, 2022

In previous experiments with NEP 18 (accepted) / NEP 37 (more powerful but still a draft, and probably subsumed by the Array API / NEP 47 effort), we identified estimators that might benefit from an Array API-aware implementation.

I would be in favor of starting to accept such Array API-aware implementations in scikit-learn on an estimator-by-estimator basis, while making it an experimental feature subject to change without following the usual deprecation cycle.

I think estimators that support Array API dispatching should make that explicit somehow: in their docstring, probably with an estimator tag for a new common test, and on a dedicated page in the docs listing those estimators.
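
For illustration only, such a tag could look roughly like this (the tag name array_api_support is hypothetical):

from sklearn.base import BaseEstimator

class MyArrayAPIEstimator(BaseEstimator):
    def _more_tags(self):
        # Hypothetical tag consumed by a new common test and by the docs page
        return {"array_api_support": True}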

@lorentzenchr
Member

I have the same question as @GaelVaroquaux

Do you think that in the long run, we could get away without the config flag "set_config(array_api_dispatch=True)"?

From the Array API doc, I would have expected the following:

class LinearDiscriminantAnalysis():

    def fit(self, X, y):
        xp = X.__array_namespace__()
        # use xp instead of np
        ...

# Create Array API arrays following spec
import cupy.array_api as xp
X_cu = xp.asarray(X_np)
y_cu = xp.asarray(y_np)

# Dispatching is handled automatically using the `array_api`
lda = LinearDiscriminantAnalysis()
lda.fit(X_cu, y_cu)

@ogrisel
Member

ogrisel commented Feb 1, 2022

What is the current behavior when passing a CuPy array to LinearDiscriminantAnalysis in scikit-learn 1.0?

  • a ValueError?
  • or a silent copy to a NumPy array in main memory?

If it is the latter, then we need an explicit flag to turn on the new behavior, at least during the experimental phase.

@thomasjpfan
Member Author

thomasjpfan commented Feb 1, 2022

What is the current behavior when passing a CuPy array to LinearDiscriminantAnalysis in scikit-learn 1.0?

It raises either a TypeError or a ValueError, depending on which kind of CuPy array you pass in.

If one passes in an "Array API compatible CuPy array", a ValueError is raised. check_array calls numpy.asarray, which wraps the array into a 0-d object array:

import cupy.array_api as cu_xp
import numpy

X = cu_xp.asarray([1, 2, 3])

# The dtype is object, so `check_array` will raise a ValueError
print(numpy.asarray(X).dtype)
# object

If one passes a "normal" CuPy array, numpy.asarray fails with a TypeError because CuPy does not allow implicit conversion to NumPy:

import cupy
import numpy

X = cupy.asarray([1, 2, 3])

# TypeError: Implicit conversion to a NumPy array is not allowed.
numpy.asarray(X)

The errors raised are very specific to CuPy's Array API implementation. Other Array API implementations may raise different errors or silently convert. The Array API spec does not say how numpy.asarray should behave. NumPy defines __array__ on its Array API compatible array to make numpy.asarray work.

For example, calling numpy.asarray on an Array API compatible NumPy array, results in a silent conversion to a numpy.ndarray:

import numpy.array_api as np_xp
import numpy

X = np_xp.asarray([1, 2, 3])
X_convert = numpy.asarray(X)

print(type(X_convert), X_convert.dtype)
# <class 'numpy.ndarray'> int64

Note that numpy.array_api arrays follow the Array API spec, while numpy.ndarray does not. In other words, functions in numpy.array_api do not work on numpy.ndarray.

Do you think that in the long run, we could get away without the config flag "set_config(array_api_dispatch=True)"?

In the long run, I think we can get away without the flag. If a user wants to opt out of the Array API, they can convert their arrays to numpy.ndarray and pass those into scikit-learn.

In the short term, I think we need the flag so users can opt into experimental behavior.

@ogrisel
Member

ogrisel commented Feb 1, 2022

In the short term, I think we need the flag so users can opt into experimental behavior.

I 100% agree. Thanks for the detailed answer.

@thomasjpfan
Member Author

The fact that numpy and the array API spec diverge on things such as unique seems like a bug in my eyes. Is it felt this way by the upstream actors?

From my understanding, many of the choices made in the Array API spec were to drive consensus between the array libraries. In the case of unique, the polymorphic return values were painful for array libraries as described here: data-apis/array-api#212

Likewise, the impossibility of supporting Dask and JAX is quite a hard setback

For scikit-learn, the biggest blockers are the lack of unique and of boolean indexing, since their output shapes are data-dependent. I can see a way forward to support Dask and JAX as long as we "execute the computational graph when we need to" and somehow re-wrap the results back into their Array API compatible counterparts. In summary, I think we can revisit Dask and JAX in the future, but we should not block Array API adoption on them.
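
For illustration, the data-dependent shape problem with boolean indexing (shown here with plain NumPy):

import numpy as np

x = np.asarray([1, -2, 3, -4])
mask = x > 0
# The output shape depends on the values in `x`, not just on `x.shape`,
# which is hard for graph-based libraries like Dask or JAX.
print(x[mask].shape)  # (2,)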

but we need to recognize that it is going to make coding algorithms significantly harder.

This was my biggest concern when working on the POCs. For me, the developer experience is a mixed bag. Working around the diverging APIs will make developing and maintaining algorithms harder. On the other hand, the stricter type promotion rules force us to cast explicitly, and using a subset of NumPy functions makes the code base a little more consistent.

The reasonable route will probably be that supporting this is not mandatory when adding a method, and that we strive to improve as we go.

I agree this is the best way forward.

@amueller
Member

amueller commented Feb 10, 2022

@GaelVaroquaux

The fact that numpy and the array API spec diverge on things such as unique seems like a bug in my eyes. Is it felt this way by the upstream actors?

Do you mean numpy upstream actors or data api consortium upstream actors? They only partially overlap. I am not familiar with the numpy-internal discussions, but from a data api consortium perspective, people basically felt that it wouldn't be good to force everybody to repeat what are seen as deficiencies in numpy's API.

There is an in-depth discussion of unique in the thread that @thomasjpfan mentioned, but in short, returning different types / different lengths of tuples of arrays depending on keywords is generally seen as bad practice, and it makes static analysis basically impossible.
The variable-length output array is also an issue, but at least there we know the return type.

One thing that I haven't really thought about in depth is whether not having methods is a big deal. Right now that certainly requires some rewriting. Ideally that rewriting would be easy; I don't think we can entirely avoid it with any solution.

However, what I really don't like about the current situation is that we can't just implement everything in the array API and be done with it (after waiting for support for old numpy to drop), but that we likely have to maintain wrappers if we want to rely on functionality that is in numpy but not in the array API, like order and views.

I would really like to end up with a situation where we don't have to maintain separate code-paths for array API and non-array API, without defining our own flavor of the API (which is essentially what the wrapper does).

I think we should figure out which features of numpy we definitely need that are not in the array API, and if we can do without them, and if it's reasonable to add them to the standard, or at least an official flavor of the standard.

@amueller
Member

@thomasjpfan For posterity, do you want to edit your summary at the top to explain why we need the configuration during the deprecation phase (to have the user decide between the old behavior of potentially casting to numpy and the new behavior of using the array api if available)?

@amueller
Member

Also, for these two algorithms, do we need unique anywhere apart from preprocessing y in LDA? We might be able to support Dask and JAX partially by casting to NumPy for that step and doing the rest in the array library?

@asmeurer

For example, calling numpy.asarray on an Array API compatible NumPy array, results in a silent conversion to a numpy.ndarray:

I'm curious, do you see this as a good thing or a bad thing? I added __array__ to the numpy.array_api Array object as a convenience (it isn't required by the spec), but I can see how this does somewhat break with the "strict spec compliance" convention that numpy.array_api follows everywhere else.

Note that the more general question of how to wrap and unwrap input arrays into array API compliant objects is something that we're discussing for the spec.

@GaelVaroquaux
Member

GaelVaroquaux commented Feb 10, 2022 via email

@asmeurer

asmeurer commented Feb 10, 2022

@GaelVaroquaux the array API spec adds completely new functions for getting unique values: unique_all, unique_counts, unique_inverse, and unique_values (see https://data-apis.org/array-api/latest/API_specification/set_functions.html). This should make it possible to add these functions to NumPy without affecting the compatibility of the existing unique (which is not present in the spec). I don't know if anyone has proposed this yet. I personally agree that it would be great if the main NumPy namespace eventually converged to the array API, at least in the places where it wouldn't require major compatibility breaks.
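
For reference, a rough comparison, using numpy.array_api (the provisional module in recent NumPy versions) as one implementation of the spec:

import numpy as np
import numpy.array_api as xp

# NumPy: the return type of `unique` changes with keyword arguments
values, counts = np.unique(np.asarray([1, 2, 2, 3]), return_counts=True)

# Spec: each variant is a separate function with a fixed return type
res = xp.unique_counts(xp.asarray([1, 2, 2, 3]))
print(res.values, res.counts)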

@thomasjpfan
Member Author

I'm curious, do you see this as a good thing or a bad thing? I added __array__ to the numpy.array_api Array object as a convenience (it isn't required by the spec), but I can see how this does somewhat break with the "strict spec compliance" convention that numpy.array_api follows everywhere else.

@asmeurer I'm overall +0.5 on it. With __array__ on the numpy.array_api Array, numpy.asarray works, but it could be confusing for those who are adopting the API. One may expect that cupy.asarray would also work on a cupy.array_api Array.

I personally agree that it would be great if the main NumPy namespace eventually converged to the array API

Yes, if the NumPy namespace can adopt all the Array API functions, it will become easier for scikit-learn to adopt the spec.

I agree with @amueller that writing code against the Array API is the preferred end state where we do not have to maintain separate code-paths for array API and non-array API.

@asmeurer

One may expect that cupy.asarray would also work on a cupy.array_api Array.

Perhaps cupy just hasn't backported numpy/numpy#20527. CuPy generally tries to match NumPy exactly wherever possible. In general, though, the "proper" way to wrap/unwrap is something we're still discussing, as I mentioned.

Now calling numpy.asarray on a cupy array (or cupy.array_api array) is a different story. That's entirely up to CuPy if it wants to support that. The spec itself is silent on how different libraries interface with one another, other than explicit interchange via DLPack https://data-apis.org/array-api/latest/design_topics/data_interchange.html. In general libraryA.func(libraryB.array(...)) should not be expected to work.

@rgommers
Contributor

My point was indeed that the current status of the ecosystem puts us in an uncomfortable situation. A more comfortable one would be if a compromise could be reached that would enable numpy to implement the array API, so that we could code for the array API.

I personally agree that it would be great if the main NumPy namespace eventually converged to the array API, at least in the places where it wouldn't require major compatibility breaks.

Yes, if the NumPy namespace can adopt all the Array API functions, it will become easier for scikit-learn to adopt the spec.

This makes a lot of sense, and I think it's feasible. For context: immediately having NumPy support the array API standard in its main namespace was the initial goal when we started writing NEP 47. There were a few incompatible behaviors in the ndarray object that made this hard in the short term (casting-related, for example), so we reluctantly switched to a separate numpy.array_api namespace. However, we should now revisit making the main namespace as compatible as possible (new functions like the unique_* ones can easily be added, for example). And longer-term, the behaviors in the array API are preferred for numpy.ndarray as well, and we could over time get to full or almost-full compatibility. For example, one key issue is value-based casting - and there's now an experimental effort to try and get rid of that. It'll be painful and take quite a while, but it should be doable.
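
For context, a small example of the value-based casting behavior referred to here (under NumPy's legacy, pre-NEP 50 promotion rules):

import numpy as np

a = np.asarray([1, 2], dtype=np.int8)
print((a + 1).dtype)     # int8  -- the Python scalar 1 fits in int8
print((a + 1000).dtype)  # int16 -- the result dtype depends on the scalar's value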

@ogrisel
Member

ogrisel commented Feb 14, 2022

For example, one key issue is value-based casting - and there's now an experimental effort to try and get rid of that. It'll be painful and take quite a while, but it should be doable.

Do you have pointers for this specific point?

@seberg
Contributor

seberg commented Feb 14, 2022

However, we should now revisit making the main namespace as compatible as possible

Sorry to barge in. I do not disagree with bending NumPy towards that slowly and carefully. I have always believed in making NumPy as consistent and clear as possible as a way to ease transitions to any future implementation or alternatives. However, I disagree that NumPy should do this for the sake of allowing a pattern of try: ns = get_array_module(x); except: ns = np.

The ndarray object itself is a bit of a problem. The main point here is the value-based promotion rules, yes. But let me point out that for any implementation where you fall back to plain np, you also simply choose to ignore this subtle difference.

However, I find it irritating to say that bending the NumPy main namespace is of any importance. We have ndarray.__array_namespace__ and np.array_api. We have a defined way to get a "compatible namespace" without any need for modifications to the NumPy main namespace! It was specifically designed that way to allow it to differ?!
(Yes, np.array_api is not in itself useful because it doesn't work for ndarray, but that is a choice that could differ from ndarray.__array_namespace__ and, frankly, it is a choice we could even just delete and replace.)

Yes, for the ndarray object, there are probably reasons for bending (which means the promotion rules and possibly adding two awkward attributes). For the NumPy namespace? Sorry, but no, there is no reason to double down on that. In the initial discussions, there was some talk of helpers to allow working around the ndarray object limitations, but I guess that idea died?

@rgommers
Contributor

Do you have pointers for this specific point?

https://discuss.scientific-python.org/t/poll-future-numpy-behavior-when-mixing-arrays-numpy-scalars-and-python-scalars/202/13 is a recent poll by @seberg which is relevant. Overall it's a thing that's now possible because of all the dtype-related infrastructure/feature work by @seberg. I don't think there's a single tracking issue, but it's been discussed in community meetings.

@rgommers
Contributor

However, I find it irritating to say that bending the NumPy main namespace is of any importance.
...
Sorry, but no, there is no reason to double down on that.

This was the initial goal when we started to write NEP 47, and it seems that it would have major practical advantages - which is why multiple scikit-learn maintainers commented on this point. So asserting that there's no reason is kind of odd. I'll take the rest offline and over to the NumPy community meeting / mailing list, which is where this discussion belongs.

@jakirkham
Contributor

For those poking at Array API adoption in scikit-learn, I am curious whether random number generation is coming up and, if so, what use cases have emerged from that. Is there value in a standard random generation API and, if so, what things in particular are expected?

For example, recall that in this context ( #14036 ) a few things cropped up, in particular this type check ( #14036 (comment) ), which is probably motivated by some expectations of a random generation API. However, a few years have passed and I am sure much has changed, so maybe totally different things would come up today than did then.

@ogrisel
Member

ogrisel commented May 13, 2022

I haven't had the time to dig deeper but I believe it will be a limitation for some estimators. So indeed we might need to make check_random_state Array API aware if RandomState becomes part of the API spec.

However scikit-learn is still using the legacy RNG API via the numpy.random.RandomState class instead of the new API and implementation, so if an RNG API gets standardized it's likely to be the new API.

But then, the statefulness of numpy RNG objects might be a problem for other adopters of the Array API spec. JAX in particular uses only stateless RNGs, which I find really neat (especially when designing parallel code), but that imposes a completely different API that only JAX supports at the moment.

In either case, that will cause a lot of changes in scikit-learn, but maybe we can maintain a backward-compat API adapter.

@thomasjpfan
Member Author

For random state, there have been PRs to move toward adopting NumPy's Generator API: #22327, #22271. For the most part, the RandomState object and the Generator object have similar interfaces. The last remaining item is random integers, which would look something like SciPy's rng_integers.
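
For illustration, such a helper might look roughly like this (a hedged sketch; the name and defaults here are hypothetical):

import numpy as np

def _rng_integers(rng, low, high=None, size=None):
    # Dispatch to the matching "random integers" method of either RNG flavor
    if isinstance(rng, np.random.Generator):
        return rng.integers(low, high=high, size=size)
    # Legacy RandomState
    return rng.randint(low, high=high, size=size)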

When scikit-learn supports NumPy's Generator API, and if the Array API adopts NumPy's Generator API as the standard, then scikit-learn should work out of the box.

@ogrisel
Member

ogrisel commented May 13, 2022

When scikit-learn supports NumPy's Generator API, and if the Array API adopts NumPy's Generator API as the standard, then scikit-learn should work out of the box.

Thanks for the heads up. Indeed that would be nice.

What would be even nicer would be for the Array API to also spec an alternative stateless RNG API, and to use that one in scikit-learn instead of the hacks we do with re-seeding in warm-started parallel meta-estimators.

@NicolasHug
Member

What would be even nicer would be for the Array API to also spec an alternative stateless RNG API, and to use that one in scikit-learn instead of the hacks we do with re-seeding in warm-started parallel meta-estimators.

If you're interested @ogrisel, I tried to summarize all the problems that our stateful RNG is causing in https://github.com/NicolasHug/enhancement_proposals/blob/random_state_v2/slep011/proposal.rst#issues-with-the-legacy-design.

@rgommers
Contributor

@ogrisel I'm not sure I understand what you mean by a stateless RNG. Do you have a concrete example, in Python or elsewhere?

Re the hacks: with the new numpy.random API, couldn't you use the BitGenerator state-jumping feature instead of paying the cost of actually drawing as many random numbers as the non-warm-start version would have used?
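
For reference, a minimal sketch of the state-jumping feature being referred to:

import numpy as np

bit_gen = np.random.PCG64(12345)
rng_a = np.random.Generator(bit_gen)
rng_b = np.random.Generator(bit_gen.jumped())  # jumps ahead by 2**127 draws, giving an independent stream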

@ogrisel
Member

ogrisel commented May 13, 2022

Indeed, explicit jumping would probably be cleaner.

But beyond this specific code, I believe that the JAX stateless RNG API makes it easier to reason about RNGs, especially when some function calls are executed in parallel. It does not matter whether such parallelism is thread-based (with shared memory), process-based, or host-based (no shared RNG instance): in any case, the code is forced to deal with immutable RNG instances. Advancing the RNG (that is, generating random numbers) creates a child RNG instance while the parent stays in its original state.
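
For illustration, the JAX pattern being described looks roughly like this (requires jax; a minimal sketch):

from jax import random

key = random.PRNGKey(0)            # explicit, immutable RNG state
key, subkey = random.split(key)    # derive a child key; neither key is mutated in place
x = random.normal(subkey, (3,))    # drawing from `subkey` leaves `key` untouched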

@jakirkham
Contributor

Gotcha, so it is kind of like how Haskell works with monads then. Namely, there is an RNG state used as input and an RNG state returned as output, so one needs to keep threading this state through.

That said, even with a stateful API, it should be possible for consumers to do this kind of operation themselves: export the state after each random number call, track it, and feed it into subsequent random number generation calls.

@rgommers
Contributor

Thanks, that helps. Some thoughts after reviewing the JAX API:

  • It's not a stateless PRNG (there's no such thing), the state just has to be manually passed in by the user as the first argument (key).
  • It seems a little clumsier in single-threaded mode, because it takes two function calls instead of one to draw a second (set of) random number(s).
  • The multi-threaded behavior seems to be why keys must be manually split. NumPy's new API has those features as well. There's a very good explanation comparing the JAX API vs. the parallel NumPy API here: DOC: better document the spawn interface, compare and contrast it to Jax's "split" numpy/numpy#15656 (comment). It also touches on the difference between the two counter-based PRNGs (Threefry for JAX, Philox for NumPy - both from the same paper).
  • It looks like the discussion in scikit-learn around the numpy.random design is for the legacy API only. The new (Generator) APIs should do whatever you need?

@ogrisel
Member

ogrisel commented May 14, 2022

Thanks for the details on the splittable capabilities of the new numpy RNG API. I was not aware of this and the issue you linked is very interesting.

I think we should definitely move to the new numpy API in scikit-learn and try to leverage the splitting capabilities to make the parallel code naturally work the same way in sequential mode, with thread-based parallelism, and with process-based parallelism, as documented in https://numpy.org/devdocs/reference/random/parallel.html.
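
For reference, the spawning pattern from the linked NumPy docs looks roughly like this:

import numpy as np

seed_seq = np.random.SeedSequence(42)
child_seqs = seed_seq.spawn(4)                      # independent child sequences
rngs = [np.random.default_rng(s) for s in child_seqs]
samples = [rng.normal(size=3) for rng in rngs]      # one independent stream per worker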

@jakirkham
Contributor

This looks useful as well

https://numpy.org/devdocs/reference/random/multithreading.html

FWIW we are going through this upgrade in Dask too if that's helpful

dask/dask#9038
