Interest in add array-api module #10

Open
nstarman opened this issue Feb 11, 2024 · 13 comments

Comments

@nstarman
Contributor

nstarman commented Feb 11, 2024

Based on quax I wrote array-api-jax-compat for use with jax-quantity. Would you be interested in upstreaming array-api-jax-compat as a submodule, e.g. quax.array_api? With the submodule, users won't have to quaxify any function in the array-api themselves; they can just import quax.array_api as xp.

@patrick-kidger
Owner

patrick-kidger commented Feb 11, 2024

That's really nice!
So I'd like to avoid adding submodules directly to Quax itself -- unfortunately, I think adding anything here is just going to be too much for me to maintain.

However, once you've got user-facing docs etc., then I would be very happy to add a link to this from the README / advertise it as a downstream library in the Quax docs.

(On an unrelated note, by the way, I just went poking through jax-quantity and I can see you have magic methods like Quantity.__add__ etc. Heads-up that these were removed from Quax in the most recent hopefully-now-stable release. Unfortunately I couldn't see a way to keep these and have things remain consistent.)

@nstarman
Contributor Author

However, once you've got user-facing docs etc., then I would be very happy to add a link to this from the README / advertise it as a downstream library in the Quax docs.

Sounds good, thanks!

On an unrelated note, by the way, I just went poking through jax-quantity and I can see you have magic methods like Quantity.__add__ etc. Heads-up that these were removed from Quax in the most recent hopefully-now-stable release. Unfortunately I couldn't see a way to keep these and have things remain consistent.

Thanks for the heads up. Is there not a way to make it work, e.g. __eq__ = quaxify(jnp.equal) ?

@patrick-kidger
Owner

Is there not a way to make it work, e.g. __eq__ = quaxify(jnp.equal) ?

The problem with doing this

Not in a way that I can see working cleanly. If we were to go for that approach, then the following:

def foo(x):
    z = Quantity(...)
    return x + z

quaxify(foo)(y)

would actually result in two nested quaxifys! The first on foo, and the second on the __add__. That's probably not what we wanted!

In terms of what that would end up doing: rather than hitting the dispatch rule for add_p.bind(:Y, :Z) (where Y is the type of y and Z is the type of z), we'd end up double-dispatching: once for add_p.bind(:Array, :Z), and then again using the type Y inside the implementation of that rule.
And if that feels a bit complicated to wrap your head around, then this is the reason this approach got removed: too complicated and too much of a footgun.

What that means is that right now, the expectation is that you will create all your array-ish values outside the quaxify, and then pass them across the boundary to wrap them into JAX tracers.
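A toy, pure-Python sketch of that boundary may help. It is not Quax's implementation: `ToyQuantity`, `ToyTracer`, and `toy_quaxify` are hypothetical stand-ins, and the wrapper only tags formal arguments the way a transform would.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyQuantity:
    """Hypothetical stand-in for an array-ish value (not jax-quantity's class)."""
    value: float
    unit: str


@dataclass(frozen=True)
class ToyTracer:
    """Hypothetical stand-in for the tracer a transform wraps values in."""
    payload: object


def toy_quaxify(fn):
    """Wrap every ToyQuantity passed as a formal argument; leave the rest alone."""
    def wrapped(*args):
        traced = [ToyTracer(a) if isinstance(a, ToyQuantity) else a for a in args]
        return fn(*traced)
    return wrapped


def foo(x):
    z = ToyQuantity(2.0, "kpc")  # created inside: never crosses the boundary
    return type(x).__name__, type(z).__name__


kinds = toy_quaxify(foo)(ToyQuantity(1.0, "kpc"))
print(kinds)  # ('ToyTracer', 'ToyQuantity')
```

The argument `x` arrives wrapped, while `z` stays bare because the wrapper never sees it. Passing `z` as a second formal argument is what puts both values on the same footing.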

What might we do about this?

I'm not 100% happy with this approach, as it means that we have to pass all array-ish values as formal arguments. This makes it best suited for the approach of "transform an existing program" rather than "use array-ish values anywhere within a new program". But it was the only one I could find that seemed to be largely footgun free, act in a consistent way, etc.

For what it's worth, it should be possible to write some kind of "quax.wrap_arrayish_value_into_tracer" function that looks for an enclosing quaxify, and then wraps the value it is given with a suitable tracer. That would make the following work, for example:

def foo(x):
    z = Quantity(...)
    z2 = quax.wrap_arrayish_value_into_tracer(z)
    return x + z2

quaxify(foo)(y)

in other words, we get to pretend that we'd passed z across the quaxify boundary.

I don't love this though, as it (a) seems like an easy footgun to forget, (b) leaves the original z still bare, which means someone could monkey with it even though z2 is the thing that should really be used, and (c) makes most sense when you have a single quaxify wrapper... it's not like we have a jax.please_pretend_i_vmapped_this_value, after all.

I think we're still figuring out what using Quax looks like. Maybe a cleaner way to do things will present itself.

@nstarman
Contributor Author

Thanks for the detailed response!

For what it's worth, it should be possible to write some kind of "quax.wrap_arrayish_value_into_tracer" function that looks for an enclosing quaxify, and then wraps the value it is given with a suitable tracer.

I think this would be incredibly useful for Equinox so that the following would work. self sorta passes through the quax boundary, so making this work out of the box feels like less of a footgun. And then __add__ should work, maybe?

class Model(eqx.Module):

    y: Quantity

    @quaxify
    def foo(self, x: Quantity):
        return x + self.y ** 2

@patrick-kidger
Owner

FWIW, I think the one you've linked there should work already. All PyTrees (both self and x) are looked over for any quax.Values -- in this case that includes Quantity -- and these are wrapped.
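As a rough picture of what "looked over" means, here is a toy recursive wrapper in plain Python. It is only a sketch under assumptions: `ToyValue`, `ToyTracer`, `wrap_leaves`, and `toy_quaxify` are hypothetical names, and this hand-rolled traversal of tuples, lists, dicts, and dataclasses merely stands in for a PyTree traversal.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyValue:
    """Hypothetical array-ish value."""
    value: float
    unit: str


class ToyTracer:
    """Hypothetical tracer wrapping one ToyValue."""

    def __init__(self, payload):
        self.payload = payload


def wrap_leaves(tree):
    """Return a copy of `tree` with every ToyValue leaf wrapped in a ToyTracer."""
    if isinstance(tree, ToyValue):
        return ToyTracer(tree)
    if isinstance(tree, (list, tuple)):
        return type(tree)(wrap_leaves(t) for t in tree)
    if isinstance(tree, dict):
        return {k: wrap_leaves(v) for k, v in tree.items()}
    if dataclasses.is_dataclass(tree) and not isinstance(tree, type):
        changes = {f.name: wrap_leaves(getattr(tree, f.name))
                   for f in dataclasses.fields(tree)}
        return dataclasses.replace(tree, **changes)
    return tree  # anything else (including existing tracers) passes through


def toy_quaxify(fn):
    """Wrap all arguments, `self` included, before calling `fn`."""
    def wrapped(*args):
        return fn(*wrap_leaves(args))
    return wrapped


@dataclasses.dataclass(frozen=True)
class Model:
    y: ToyValue

    @toy_quaxify
    def foo(self, x):
        return isinstance(self.y, ToyTracer), isinstance(x, ToyTracer)


result = Model(ToyValue(1.0, "kpc")).foo(ToyValue(2.0, "kpc"))
print(result)  # (True, True)
```

Because `self` is just the first positional argument, its field `y` gets wrapped exactly like `x`, with no special-casing.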

@nstarman
Contributor Author

Awesome, then in galax we can now have G as a Quantity. I'm in the process of quantityizing galax.

@nstarman
Contributor Author

I have confirmed that wrapping into tracers will be necessary.
In a branch of https://github.com/GalacticDynamics/galax/blob/29e479ef6f4ab35227a66622fe0c8af2d896c61d/src/galax/potential/_potential/builtin.py#L105-L108 I have set it up so that G is a Quantity (from jax_quantity) and self.m(t) and self.c(t) return Quantity as well.

def _potential_energy(self, q: BatchQVec3, /, t: BatchFloatOrIntQScalar) -> BatchFloatQScalar:
    r = xp.linalg.vector_norm(q, axis=-1)
    out = -self._G * self.m(t) / (r + self.c(t))
    return out

After inserting print(self._G, r, self.m(t), self.c(t)), the output is

Quantity(value=f64[], unit=Unit("kpc3 / (Myr2 solMass2)"))
Traced<ShapedArray(float64[], weak_type=True)>with<_QuaxTrace(level=2/0)> with
  value = Quantity(value=f64[], unit=Unit("solMass"))
Traced<ShapedArray(float64[])>with<_QuaxTrace(level=2/0)> with
  value = Quantity(value=f64[], unit=Unit("kpc"))
Traced<ShapedArray(float64[], weak_type=True)>with<_QuaxTrace(level=2/0)> with
  value = Quantity(value=f64[], unit=Unit("kpc"))

As a result, _potential_energy returns a nested Quantity:

Quantity(
  value=Quantity(value=f64[], unit=Unit("solMass / kpc")),
  unit=Unit("kpc3 / (Myr2 solMass2)")
)

@nstarman
Contributor Author

Perhaps a function wrap_arrayish_value_into_tracer_like(arrayish, like_tracer) would be the way to go, since there are a number of different tracers that can get created?

@nstarman
Contributor Author

Ping @adrn @jnibauer, if y'all are interested.

@nstarman
Contributor Author

Perhaps a function wrap_arrayish_value_into_tracer_like(arrayish, like_tracer) would be the way to go, since there are a number of different tracers that can get created?

Or is there some way quaxify can auto-handle this for self?

@patrick-kidger
Owner

patrick-kidger commented Feb 16, 2024

Perhaps a function wrap_arrayish_value_into_tracer_like(arrayish, like_tracer) would be the way to go, since there are a number of different tracers that can get created?

I think I like this! I might even make it a classmethod on Value, so that one can write something like Quantity.quaxify(like_tracer, 5.0, unit=Unit("kpc"))
where the API is

class Value:
    @classmethod
    def quaxify(cls, like_tracer, /, *args, **kwargs):
        ...  # implementation here

Or maybe we really jump off the API deep-end and do something like Quantity[like_tracer](5.0, unit=Unit("kpc")), using the class-level [...] to indicate that we don't want to create this object directly, but would like to inherit this tracer.

In either case the goal is to help discourage ever creating an unwrapped array-ish value in the first place. I'm not 100% sold on any of these approaches here -- frankly maybe just the quax.wrap_arrayish_value_into_tracer_like free-function approach is the least confusing.
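To compare the three spellings side by side, here is a toy sketch in plain Python. None of it exists in Quax: `ToyTracer`, `ToyQuantity`, the `trace_id` field, and `wrap_like` are hypothetical, and "inheriting a tracer" is reduced to copying an id from `like_tracer`.

```python
from dataclasses import dataclass
from functools import partial


@dataclass(frozen=True)
class ToyTracer:
    """Hypothetical tracer: a value plus the id of the trace it belongs to."""
    value: object
    trace_id: int


@dataclass(frozen=True)
class ToyQuantity:
    """Hypothetical array-ish value."""
    value: float
    unit: str

    @classmethod
    def quaxify(cls, like_tracer, /, *args, **kwargs):
        # Classmethod spelling: build the value and wrap it in one step.
        return ToyTracer(cls(*args, **kwargs), like_tracer.trace_id)

    def __class_getitem__(cls, like_tracer):
        # Bracket spelling: ToyQuantity[like_tracer](...) defers to quaxify.
        return partial(cls.quaxify, like_tracer)


def wrap_like(arrayish, like_tracer):
    # Free-function spelling: wrap an already-built (bare) value.
    return ToyTracer(arrayish, like_tracer.trace_id)


existing = ToyTracer(1.0, trace_id=7)
a = ToyQuantity.quaxify(existing, 5.0, unit="kpc")
b = ToyQuantity[existing](5.0, unit="kpc")
c = wrap_like(ToyQuantity(5.0, "kpc"), existing)
print(a == b == c)  # True
```

All three produce the same wrapped result in this toy. The difference is that only the free function requires a bare `ToyQuantity` to exist first, which is the situation the first two spellings try to discourage.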

Or is there some way quaxify can auto-handle this for self?

I'm not sure what you're getting at with this one I'm afraid -- it will already handle self just like any other argument, by looking through it for any array-ish values.

@nstarman
Contributor Author

nstarman commented Feb 17, 2024

Or is there some way quaxify can auto-handle this for self?

I'm not sure what you're getting at with this one I'm afraid -- it will already handle self just like any other argument, by looking through it for any array-ish values.

I meant that if all the other inputs to the method are tracer objects, then it would auto-apply wrap_arrayish_value_into_tracer_like.

@patrick-kidger
Owner

patrick-kidger commented Feb 20, 2024

So if you have quaxify on a method:

class Foo:
    @quaxify
    def some_method(self, ...): ...

then it will already wrap all values into tracers. That's what quaxify does anyway! And moreover they will be tracers associated with this new transform.

(It sounds like you're after some other non-quaxify decorator that will use the tracers of a previous quaxify, and special-case self? I don't think I like the idea of special-casing self, to be honest -- basically nothing else ever does this.)

Let me know if you have any opinion on the other options though btw, as I don't currently have strong feelings between them!
