Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented jax.lax.while primitive #16

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

ymahlau
Copy link

@ymahlau ymahlau commented Apr 20, 2024

This PR adds an implementation of the jax.lax.while primitive

Closes #15

@patrick-kidger
Copy link
Owner

Huh! This... actually seems really simple. I was not expecting it to be this easy.

I think to merge this I would like to see some tests added, if that's okay? The ones that jump out are:

  • a basic test with quaxified carry.
  • something where the cond_fun and body_fun close over additional constant values.
  • combining this with jit/grad/vmap.

@ymahlau
Copy link
Author

ymahlau commented May 4, 2024

I agree that some test cases are necessary. To this end, I extended the unitful example and added it to the examples folder to have some wrapper that allows for easy testing. Then I added some test cases including constant values, jit and vmap.

@patrick-kidger
Copy link
Owner

Nice, thank you! My comments on this are:

  1. it looks like the pre-commit hooks are failing (formatting etc.). These should be easy enough to fix: pre-commit install; pre-commit run --all-files; git add -u; git commit.
  2. I have a very specific request: can you add some a test in which we (a) close over a variable in the body function whilst simultaneously (b) differentiating with respect to it? This is something which I have come to know can easily go wrong... so I'd like to be sure we get it right!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

How to implement jax.lax.while with quax
2 participants