Some custom primitives do not work with the newest version of jax #112

Closed
3 tasks
fehiepsi opened this issue Apr 17, 2019 · 2 comments · Fixed by #132
Labels
bug Something isn't working

Comments

@fehiepsi
Member

fehiepsi commented Apr 17, 2019

The error happens because the new version of jax requires a specific pattern for using jit's static_argnums. See google/jax#595 for more context. The failing functions include:

  • xlogy
  • xlog1py
  • standard_gamma
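For readers unfamiliar with static_argnums, here is a minimal sketch of how it is commonly passed to jit in recent JAX releases. It does not reproduce the exact pattern discussed in google/jax#595, and `scaled_ones` is a hypothetical helper, not one of the functions listed above.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper (not from this project): `n` determines the output
# shape, so it has to be known at trace time and is marked as static.
@partial(jax.jit, static_argnums=(1,))
def scaled_ones(scale, n):
    return scale * jnp.ones(n)


out = scaled_ones(2.0, 3)  # `scale` is traced, `n` is treated as a constant
```

Each distinct value of a static argument triggers a fresh compilation, so this is best suited to arguments such as shapes that take few values.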

In addition, I think we can simplify the implementation of these functions by using the custom_transform decorator, as in cumsum and cumprod.
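As a rough illustration of the decorator-based approach, the sketch below defines xlogy with a hand-written derivative rule. It assumes a recent JAX release and uses `jax.custom_jvp`, which is not the custom_transform API mentioned above, and it is not this project's actual implementation.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def xlogy(x, y):
    # Convention: x * log(y) is defined as 0 where x == 0, even if y == 0.
    safe_y = jnp.where(x == 0.0, 1.0, y)
    return x * jnp.log(safe_y)


@xlogy.defjvp
def xlogy_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = xlogy(x, y)
    # d/dx = log(y) and d/dy = x / y; this sketch does not special-case
    # the derivative at y == 0.
    tangent_out = jnp.log(y) * x_dot + (x / y) * y_dot
    return primal_out, tangent_out


# Gradients with respect to both arguments at a regular point.
gx, gy = jax.grad(xlogy, argnums=(0, 1))(2.0, 3.0)
```

Because the rule is attached with a decorator, no separate primitive registration is needed, which is the kind of simplification suggested above.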

cc @neerajprad

fehiepsi added the bug label Apr 17, 2019
@neerajprad
Member

Are these functions failing with jax master? We aren't using jit, at least in xlogy and xlog1py, so that doesn't seem to immediately conflict with the JAX change. It also seems like a bunch of functions have been reorganized, so it's probably worth tackling in one shot with the next jax release?

@fehiepsi
Member Author

Yes, they fail with jax master, maybe for another reason. I also think we just need to address these kinds of issues in one shot before our alpha release. It is almost there, I guess. :)
