quaxify on a jax.grad #5
So I think the error here is that you've written … That is, … In particular, you should never have to mess around with things like … Regarding …
Thanks! I had originally tried … Assuming #6, consider:

```python
import jax
import jax.numpy as jnp
from jax_quantity import Quantity
from quax import quaxify

jax.config.update("jax_enable_x64", True)

x = jnp.array([1, 2, 3], dtype=jnp.float64)
q = Quantity(x, unit="m")

def func(q: Quantity) -> Quantity:
    return 5 * q**3

out = quaxify(jax.grad(func))(q[0])
out.value, out.unit
# > (Array(15., dtype=float64), Unit("m2"))
```

This works perfectly! 🎉 The problem arises when the function also builds a `Quantity` itself:

```python
def func(q: Quantity) -> Quantity:
    return 5 * q**3 + Quantity(jnp.array(1.0), "m3")

out = quaxify(jax.grad(func))(q[0])
# > ValueError: Cannot add a non-quantity and quantity.
```

This is my error message in:

```python
from jax import lax
from quax import DenseArrayValue, register

@register(lax.add_p)
def _add_p_vq(x: DenseArrayValue, y: Quantity) -> Quantity:
    # x = 0 is a special case
    if jnp.array_equal(x, 0):
        return y
    # otherwise we can't add a quantity to a normal value
    raise ValueError("Cannot add a non-quantity and quantity.")
```

What appears to be happening is that …

```python
def func(q: Quantity) -> Quantity:
    jax.debug.print("q {}, {}, {}", type(q), type(q.primal), type(q.primal.value))
    jax.debug.print("5 q ** 3: {}", type((5 * q**3).primal.value))
    return 5 * q**3 + Quantity(1.0, unit="m3")

out = quaxify(jax.grad(func))(q[0])
# > q <class 'jax._src.interpreters.ad.JVPTracer'>, <class 'quax._core._QuaxTracer'>, <class 'jax_quantity._core.Quantity'>
# > 5 q ** 3: <class 'jax_quantity._core.Quantity'>
```

If …:

```python
@register(lax.add_p)
def _add_p_qq(x: Quantity, y: Quantity) -> Quantity:
    return Quantity(lax.add(x.to_value(x.unit), y.to_value(x.unit)), unit=x.unit)
```

However, this doesn't appear to be the case.
Sorry for the noise.

```python
constant = Quantity(1.0, unit="m2")

@quaxify
def func(q: Quantity) -> Quantity:
    jax.debug.print("5 q ** 3: {}", (5 * q**3).value)
    jax.debug.print("c * q: {}", type((constant * q).value))
    return jax.lax.add(5 * q**3, constant * q)

out = func(q[0])
out.value, out.unit
# > 5 q ** 3: Quantity(value=f64[], unit=Unit("m3"))
# > c * q: <class 'quax._core._QuaxTracer'>
# > UnitConversionError: ... volume and area can't be added  # (my words)
```

So the `Quantity` multiplication is strange: the value is a `_QuaxTracer`, and the resulting unit is m^2, not m^3.
Hmm, it's not clear to me if you have a question or if you've resolved it? FWIW, I will comment that you are doing something that I'm still not completely happy with Quax's behaviour for, which is returning custom array-ish values from primitive rules. This is fine when working with a function … But it's not really clear how it should be defined for something like …

Problem 1: handling nested values

For example, maybe the … That's now pretty weird: when you wrote … not a …

Problem 2: easy bugs in implementations

As another example of the problems this causes, consider this line: … in which you've written …

Right now I'm still thinking about a plan to fix this. I'm not sure exactly what that will be yet, though! So I'd welcome any thoughts on what might be a nice way to do this. The end goal will probably be to end up with something Julia-like, where you can just nest array-ish values freely without worrying too much; possibly this can be accomplished by always putting Quax at the bottom of the interpreter stack. Details very much TBD; you have been warned :)
You're right either way 😆. I've half-solved the problem. Functions that don't internally construct and operate on a `Quantity` work great. Functions that do still raise an error.
Thanks! I'll change my annotations to …
@patrick-kidger, I found the root of the problem and have a question about how best to resolve the issue. Consider this MWE (a simplified form of …):

```python
import jax
import quax
from jax import lax

class MyArray(quax.ArrayValue):
    value: jax.Array
    unit: str

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the array."""
        return self.value.shape

    def materialise(self) -> None:
        raise RuntimeError("Refusing to materialise `MyArray`.")

    def aval(self) -> jax.core.ShapedArray:
        return jax.core.get_aval(self.value)

@quax.register(lax.mul_p)
def _(x: MyArray, y: MyArray) -> MyArray:
    unit = f"{x.unit}*{y.unit}"
    return MyArray(lax.mul(x.value, y.value), unit=unit)
```

If I define a function

```python
def func(q: MyArray) -> MyArray:
    c = MyArray(2.0, unit="m2")
    jax.debug.print("multiplying {} * {}", type(c), type(q))
    return c * q

x = MyArray(1.5, unit="m")
out = func(x)
# > multiplying <class '__main__.MyArray'> * <class '__main__.MyArray'>
# > MyArray(value=f64[], unit='m2*m')
```

However, if this function is quaxified:

```python
quaxfunc = quax.quaxify(func)
out = quaxfunc(x)
# > multiplying <class '__main__.MyArray'> * <class 'quax._core._QuaxTracer'>
```

The issue is that … I can hack around this by defining

```python
@quax.register(lax.mul_p)
def _(x: MyArray, y: quax.DenseArrayValue) -> MyArray:
    assert isinstance(y.array, quax._core._QuaxTracer)
    assert isinstance(y.array.value, MyArray)
    unit = f"{x.unit}*{y.array.value.unit}"
    return MyArray(lax.mul(x.value, y.array.value.value), unit=unit)
```

I don't see a comparable scenario in …
This looks expected. When you wrap …

```python
jaxpr = jax.make_jaxpr(func)(x)
print(jaxpr)
```

this defines a sequence of primitive operations, turning input into output. Then the Quaxified version -- … So in this case, the … Here, what you probably want to do is just not apply the … (FWIW this is a complexity I'm hoping will be tidied up when I rewrite things a bit, as per my previous comment.)
Unfortunately, since that function was an MWE of …
Thanks for the diagnosis! What is the proper way to get the …?

```python
import jax
import jax.numpy as jnp
from jax_quantity import Quantity
from quax import quaxify

jax.config.update("jax_enable_x64", True)

x = jnp.array([1, 2, 3], dtype=jnp.float64)
q = Quantity(x, unit="m")

def func(q: Quantity) -> Quantity:
    return Quantity(2.0, unit="m2") * q

out = func(q[0])
print(out.value, out.unit)
# > (Array(2., dtype=float64), Unit("m3"))

out = quaxify(jax.grad(func))(q[2])
print(out.value, out.unit)
# > TypeError: Gradient only defined for scalar-output functions. Output was Quantity(value=f64[], unit=Unit("m3")).
```

I think the way I deconstructed …
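The `TypeError` comes from `jax.grad` itself, which only accepts functions whose output is a single scalar array. A `Quantity` is a pytree, not an array. The sketch below reproduces this without Quax, using a hypothetical minimal `Quantity` (not the `jax_quantity` class) whose unit is stored as static pytree metadata. For simplicity, units are assumed to be an integer power of metres. Differentiating the magnitude (`.value`) succeeds. Note, however, that the returned cotangent is a `Quantity` carrying the *input's* unit: JAX builds it with the input's tree structure and does not compute the physically correct unit.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Quantity:
    """Hypothetical toy quantity (not jax_quantity's class).

    `value` is a pytree leaf; `unit` is an integer power of metres stored as
    static aux data, so it becomes part of the tree structure.
    """

    def __init__(self, value, unit: int):
        self.value = value
        self.unit = unit

    def tree_flatten(self):
        return (self.value,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        return cls(children[0], unit)

    def __mul__(self, other):
        if isinstance(other, Quantity):
            return Quantity(self.value * other.value, self.unit + other.unit)
        return Quantity(self.value * other, self.unit)

    __rmul__ = __mul__


def func(q: Quantity) -> Quantity:
    return Quantity(2.0, unit=2) * q  # m2 * m -> m3


q = Quantity(jnp.asarray(3.0), unit=1)

# Differentiating the Quantity-valued function fails: its output is a pytree,
# not a scalar array.
try:
    jax.grad(func)(q)
    grad_failed = False
except TypeError as err:
    grad_failed = True
    print(err)

# Differentiating the bare magnitude works. The cotangent has the *input's*
# tree structure, so it comes back as a Quantity with unit=1 (m), not m2.
g = jax.grad(lambda q: func(q).value)(q)
print(g.value, g.unit)
```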
Okay, returning to this issue! It has not been forgotten about... As above, this discussion has prompted me to do some harder thinking on the design choices for Quax. The just-released v0.0.3 should hopefully straighten things out a bit. It's very much a breaking release (!), but if it seems to work then hopefully we can standardise on this, and start building libraries on top of Quax in earnest. With respect to the topics discussed in this issue: …

Thanks!
In GalacticDynamics/unxt#4 I'm trying to get `jax.grad` to work on functions that accept `Quantity` arguments, and have run into some difficulties. The following doesn't work, …

returning the error

TypeError: Gradient only defined for scalar-output functions. Output was Quantity(value=f64[], unit=Unit("m2")).

This error was expected, since `grad` checks for scalar outputs (with `jax._src.api._check_scalar`). The underlying issue appeared to be that `_check_scalar` calls `concrete_aval`, which errors on `Quantity`. `quax`-compatible classes have an `aval()` method, so I hooked that up to a handler and registered it into `pytype_aval_mappings`: …
While this gets a few lines further in `grad`, unfortunately it causes a disagreement between pytree structures, with the error

TypeError: Tree structure of cotangent input PyTreeDef(*), does not match structure of primal output PyTreeDef(CustomNode(Quantity[('value',), ('unit',), (Unit("m2"),)], [*])).

I haven't figured out how to fix this issue. Any suggestions would be appreciated!
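Judging from the error message, after the aval workaround `grad` seeds the backward pass with a bare scalar cotangent (`PyTreeDef(*)`), while the function's primal output is a `Quantity` node. The sketch below uses a hypothetical toy `Quantity` pytree (not the `jax_quantity` class) with the unit stored as static aux data. It shows that these two structures never compare equal, and that the unit itself is part of the tree structure.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_structure


@register_pytree_node_class
class Quantity:
    """Hypothetical toy quantity: `value` is a pytree leaf, `unit` is static aux data."""

    def __init__(self, value, unit: str):
        self.value = value
        self.unit = unit

    def tree_flatten(self):
        return (self.value,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        return cls(children[0], unit)


# Structure of a bare scalar cotangent, i.e. PyTreeDef(*).
cotangent_def = tree_structure(jnp.asarray(1.0))
# Structure of a Quantity-valued primal output: a custom node wrapping one leaf.
primal_def = tree_structure(Quantity(jnp.asarray(1.0), "m2"))
# The static unit is part of the structure, so different units give different treedefs.
other_unit_def = tree_structure(Quantity(jnp.asarray(1.0), "m3"))

print(cotangent_def)
print(primal_def)
print(cotangent_def == primal_def, primal_def == other_unit_def)
```

This is one way to see why registering an aval for `Quantity` is not enough on its own: as long as the output is a `Quantity` pytree, the cotangent that `grad` builds and the structure of the output disagree.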
p.s. @dfm has figured out how to do `grad` on `Quantity` in `jpu` by shunting the units to `aux` data and re-assembling them afterwards. This solution works well, but it's unique to `Quantity`, requiring a custom `grad` function. I was hoping to get this working with `quaxify` in a way that didn't require dispatching with `plum` (in https://github.com/GalacticDynamics/array-api-jax-compat) to library-specific `grad` implementations, especially since it's not obvious what to dispatch on to map `func` to the `Quantity` implementation.