How to implement jax.lax.while with quax #15

Open
ymahlau opened this issue Apr 19, 2024 · 2 comments · May be fixed by #16

ymahlau commented Apr 19, 2024

Hi! I love the possibilities of quax and would like to use it for a unit system in my acoustic wave simulation. For this, I need to register a rule for jax.lax.while_p, since the simulation runs for many thousands of steps in a while loop. I was wondering if you could give some tips on the implementation. An MWE of my current attempt looks like this:

import jax
import jax.numpy as jnp
import quax
import jax.core as core

class ArrayWrapper(quax.ArrayValue):
    array: jax.Array

    def aval(self):
        return core.ShapedArray(self.array.shape, self.array.dtype)

    def materialise(self) -> jax.Array:
        raise ValueError("Refusing to materialise")
        
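@quax.register(jax.lax.while_p)
def _(*args, cond_nconsts: int, cond_jaxpr, body_nconsts: int, body_jaxpr):
    # while_p operands are laid out as: cond consts, body consts, then the loop carry.
    # Unwrap every ArrayWrapper so the primitive itself only sees plain arrays.
    new_args = [a.array if isinstance(a, ArrayWrapper) else a for a in args]
    is_quaxed = [isinstance(a, ArrayWrapper) for a in args]
    # cond_jaxpr and body_jaxpr were already traced against plain ShapedArray avals,
    # so re-binding them runs the loop without applying any quax rules inside it.
    out = jax.lax.while_p.bind(
        *new_args,
        cond_nconsts=cond_nconsts,
        cond_jaxpr=cond_jaxpr,
        body_nconsts=body_nconsts,
        body_jaxpr=body_jaxpr,
    )
    # Re-wrap each carry output whose initial value was an ArrayWrapper.
    return [
        ArrayWrapper(a) if quaxed else a
        for a, quaxed in zip(out, is_quaxed[cond_nconsts+body_nconsts:])
    ]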

def body_fn(a: jax.Array):
    return a + 1

def cond_fn(a: jax.Array):
    return a[1] < 10

def loop_fn(a: jax.Array):
    res = jax.lax.while_loop(
        body_fun=body_fn,
        cond_fun=cond_fn,
        init_val=a,
    )
    return res

a = ArrayWrapper(jnp.arange(10))
a = jax.jit(quax.quaxify(loop_fn))(a)
print(a.array)

Even though this code executes, it is far from optimal. Since body_fn and cond_fn are no longer quaxed, all the advantages of the unit system are lost. In particular, the expected behavior is that the code above raises an exception, because no rules are registered for the add, less-than, etc. primitives. But the code runs, because cond_fn and body_fn are no longer quaxed.
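To see concretely why the loop body escapes quax, one can inspect the equation that `jax.lax.while_loop` records. When `loop_fn` runs under `quax.quaxify`, `while_loop` traces `cond_fn` and `body_fn` against the abstract value returned by `ArrayWrapper.aval()`, so the `cond_jaxpr`/`body_jaxpr` handed to a `while_p` rule should be the same kind of jaxpr that a plain array of the same shape and dtype produces. This sketch uses only public JAX (`jax.make_jaxpr`, `jax.lax.while_p`) and a plain array in place of the wrapper.

```python
import jax
import jax.numpy as jnp


def body_fn(a):
    return a + 1


def cond_fn(a):
    return a[1] < 10


def loop_fn(a):
    return jax.lax.while_loop(cond_fn, body_fn, a)


# Trace the loop with a plain array and pick out the single while_p equation.
closed = jax.make_jaxpr(loop_fn)(jnp.arange(10))
(eqn,) = [e for e in closed.jaxpr.eqns if e.primitive is jax.lax.while_p]

# The body is an already-traced ClosedJaxpr over plain ShapedArray inputs.
body = eqn.params["body_jaxpr"]
print(type(body).__name__)
print([v.aval for v in body.jaxpr.invars])
# Ordinary primitives (such as add) that a rule which only re-binds never re-dispatches.
print([e.primitive.name for e in body.jaxpr.eqns])
```

Because `body_jaxpr` is fixed at trace time, re-binding it as in the MWE cannot trigger any quax dispatch for the `add` inside it.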

I don't know how one could integrate the quaxed functions into the XLA WhileOp primitive. Do you have any insights on how to achieve this?

Many thanks
Yannik

patrick-kidger (Owner) commented

So I think that for the higher-order primitives (lax.while_loop, lax.cond, ...) there probably shouldn't need to be an implementation-specific override. I think this should be something Quax just does automatically for all ArrayValues, in the same way that it already handles jax.jit and jax.custom_jvp.

I think an implementation should probably look to how those are handled for guidance, and try to do something that works without knowing the details of your specific ArrayValues.
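One possible reading of this suggestion, sketched under assumptions and not necessarily what the linked PR does: instead of re-binding the already-traced jaxprs, the `while_p` rule rebuilds a fresh `jax.lax.while_loop` whose cond and body re-evaluate those jaxprs through an interpreter, so every primitive inside the loop gets dispatched again. `rebuild_while` and its `interpret` argument are hypothetical names; passing `quax.quaxify` as `interpret` is the intended use, while the identity function gives plain JAX for checking. It assumes `jax.core.eval_jaxpr` is public in your JAX version and that `while_p` takes exactly the four parameters used in the MWE.

```python
import jax
import jax.numpy as jnp


def rebuild_while(interpret, *args, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
    """Re-issue a while_p equation as a fresh lax.while_loop (hypothetical helper).

    `interpret` wraps a function so that its primitives are dispatched again
    (assumed use: quax.quaxify); pass the identity to get plain JAX.
    """
    # while_p operands are laid out as: cond consts, body consts, loop carry.
    cond_consts = args[:cond_nconsts]
    body_consts = args[cond_nconsts:cond_nconsts + body_nconsts]
    init = tuple(args[cond_nconsts + body_nconsts:])

    def run_cond(*xs):
        # cond_jaxpr is a ClosedJaxpr returning a single boolean scalar.
        (pred,) = jax.core.eval_jaxpr(cond_jaxpr.jaxpr, cond_jaxpr.consts, *xs)
        return pred

    def run_body(*xs):
        # body_jaxpr returns one value per carry element.
        return jax.core.eval_jaxpr(body_jaxpr.jaxpr, body_jaxpr.consts, *xs)

    def cond_fun(carry):
        return interpret(run_cond)(*cond_consts, *carry)

    def body_fun(carry):
        # Return a tuple so the carry keeps the same pytree structure as `init`.
        return tuple(interpret(run_body)(*body_consts, *carry))

    return list(jax.lax.while_loop(cond_fun, body_fun, init))
```

Hooked into quax, this would be a rule registered with `@quax.register(jax.lax.while_p)` that simply returns `rebuild_while(quax.quaxify, *args, **params)`. Since an `ArrayWrapper` is a pytree, it could sit in the loop carry as-is, and inside the rebuilt body `quaxify` should route `add`, `lt`, etc. to registered rules (or fail, as the issue expects, when none exist). That behaviour inside quax is an assumption of this sketch.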

I'm afraid this isn't a topic I've thought about harder than this. Quax is a fairly experimental library, and this is the main thing that I haven't implemented.

ymahlau linked a pull request on Apr 20, 2024 that will close this issue
ymahlau (Author) commented Apr 20, 2024

Thanks for the quick answer! I added a PR with an implementation of the while primitive. It is a general override that does not need to know any user-specific implementation details. If you think this implementation looks good, I can also work on an implementation of cond, etc. if I find the time :)
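If the same rebuild-and-reinterpret idea were extended to `cond`, a sketch might look like the following. It assumes the `cond_p` equation's operands are the branch index followed by the branch operands, and that its `branches` parameter holds one `ClosedJaxpr` per branch; any other parameters (these have varied between JAX versions) are ignored. `rebuild_cond` and `interpret` are hypothetical names, `interpret` is meant to be something like `quax.quaxify` (identity for plain JAX), and `jax.core.eval_jaxpr` is assumed to be public.

```python
import jax
import jax.numpy as jnp


def rebuild_cond(interpret, index, *operands, branches, **unused_params):
    """Re-issue a cond_p equation as a fresh lax.switch (hypothetical helper).

    `interpret` wraps a function so its primitives are dispatched again
    (assumed use: quax.quaxify); pass the identity for plain JAX.
    Parameters other than `branches` are ignored in this sketch.
    """

    def make_branch(closed):
        # Evaluate one pre-traced branch jaxpr on the shared operands.
        def run(*xs):
            return jax.core.eval_jaxpr(closed.jaxpr, closed.consts, *xs)

        # Wrap the evaluation so primitives inside the branch are dispatched again.
        return lambda *xs: interpret(run)(*xs)

    # lax.switch selects branches[index]; every branch receives the same operands.
    return list(jax.lax.switch(index, [make_branch(b) for b in branches], *operands))
```

Using `lax.switch` rather than `lax.cond` keeps the helper valid for any number of branches, since `cond_p` stores `lax.cond` and `lax.switch` in the same form.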
