Hi! I love the possibilities of quax and would like to use it for a unit system in my acoustic wave simulation. For this, it is necessary to register a function for jax.lax.while_p, since the simulation runs for many thousands of steps in a while loop. I was wondering if you could give some tips for the implementation. An MWE of my current attempt looks something like this:
import jax
import jax.numpy as jnp
import quax
import jax.core as core


class ArrayWrapper(quax.ArrayValue):
    array: jax.Array

    def aval(self):
        return core.ShapedArray(self.array.shape, self.array.dtype)

    def materialise(self) -> jax.Array:
        raise ValueError("Refusing to materialise")


@quax.register(jax.lax.while_p)
def _(*args, cond_nconsts: int, cond_jaxpr, body_nconsts: int, body_jaxpr):
    # Unwrap any ArrayWrapper inputs so the primitive sees raw arrays.
    new_args = [a.array if isinstance(a, ArrayWrapper) else a for a in args]
    is_quaxed = [isinstance(a, ArrayWrapper) for a in args]
    # Re-bind the original primitive; cond_jaxpr and body_jaxpr pass through unchanged.
    out = jax.lax.while_p.bind(
        *new_args,
        cond_nconsts=cond_nconsts,
        cond_jaxpr=cond_jaxpr,
        body_nconsts=body_nconsts,
        body_jaxpr=body_jaxpr,
    )
    # Skip the leading constants, then re-wrap outputs whose carry input was wrapped.
    return [
        ArrayWrapper(a) if quaxed else a
        for a, quaxed in zip(out, is_quaxed[cond_nconsts + body_nconsts:])
    ]


def body_fn(a: jax.Array):
    return a + 1


def cond_fn(a: jax.Array):
    return a[1] < 10


def loop_fn(a: jax.Array):
    res = jax.lax.while_loop(
        body_fun=body_fn,
        cond_fun=cond_fn,
        init_val=a,
    )
    return res


a = ArrayWrapper(jnp.arange(10))
a = jax.jit(quax.quaxify(loop_fn))(a)
print(a.array)
Even though this code executes, it is far from optimal. Since body_fn and cond_fn are no longer quaxed, all advantages of the unit system are lost. In particular, the expected behavior is that the code above raises an exception, because the primitives for addition, less-than, etc. are not registered. Instead, the code runs, since cond_fn and body_fn are evaluated on the unwrapped arrays.
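To make the problem concrete, here is a minimal sketch in plain JAX (no Quax involved) that inspects what the while primitive actually carries. It assumes the same toy loop as in the MWE; the names `closed`, `while_eqn`, `body_prims`, and `cond_prims` are only illustrative. As far as I can tell, the cond and body functions are already traced into closed jaxprs and stored as parameters of the `while` equation, so a rule that merely re-binds the primitive never gets a chance to dispatch on the `add` or `lt` inside them.

```python
import jax
import jax.numpy as jnp


def loop_fn(a):
    # Same toy loop as in the MWE: add 1 until a[1] reaches 10.
    return jax.lax.while_loop(lambda x: x[1] < 10, lambda x: x + 1, a)


# Trace the function and pick out the equation for the while primitive.
closed = jax.make_jaxpr(loop_fn)(jnp.arange(10))
while_eqn = next(e for e in closed.jaxpr.eqns if e.primitive.name == "while")

# The cond and body functions live in the equation's params as closed jaxprs.
print(sorted(while_eqn.params))

# Primitives that a re-bind of while_p passes through without any dispatch.
body_prims = [e.primitive.name for e in while_eqn.params["body_jaxpr"].jaxpr.eqns]
cond_prims = [e.primitive.name for e in while_eqn.params["cond_jaxpr"].jaxpr.eqns]
print("body:", body_prims)
print("cond:", cond_prims)
```

So any general solution would presumably have to re-interpret `cond_jaxpr` and `body_jaxpr` under Quax's dispatch, rather than forward them untouched.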
I don't know how one could integrate the quaxed functions into the XLA WhileOp primitive. Do you have any insights on how to achieve this?
Many thanks
Yannik
So I think that for the higher-order primitives (lax.while_loop, lax.cond, ...) there probably shouldn't need to be an implementation-specific override. I think this should be something that Quax just does automatically for all ArrayValues, in the same way that it already handles jax.jit and jax.custom_jvp.
I think an implementation should probably look to the existing handling of jax.jit and jax.custom_jvp for guidance, and try to do something that will work without knowing the details of your specific ArrayValues.
I'm afraid this isn't a topic I've thought about harder than this. Quax is a fairly experimental library, and this is the main thing that I haven't implemented.
Thanks for the quick answer! I added a PR with an implementation of the while primitive. This is a general override that does not need to know any implementation specifics of the user's ArrayValues. If you think this implementation looks good, I can also work on an implementation of cond, etc., if I find the time :)