Add support for JAX/XLA/Trax? #332

Open
VivekPanyam opened this issue Apr 8, 2020 · 2 comments

@VivekPanyam
Collaborator

A few months ago, I noticed that JAX (https://github.com/google/jax) and Trax (https://github.com/google/trax) have been getting more popular.

Compiled JAX functions (https://github.com/google/jax#compilation-with-jit) can be turned into an XLA HLO proto (see google/jax#1871), which can then be run from C++.
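For reference, here is a rough sketch of the JAX side of this, assuming a recent JAX release where jitted functions expose `.lower()`. The `predict` function and the array shapes are made up for illustration. Note that `as_text()` returns the lowered module as text; how to get a serialized HLO proto out of it depends on the JAX version and is not shown here.

```python
import jax
import jax.numpy as jnp


def predict(w, x):
    # Hypothetical toy model: a single dense layer with a tanh activation.
    return jnp.tanh(x @ w)


# Example inputs; lowering only depends on their shapes and dtypes.
w = jnp.ones((3, 2), dtype=jnp.float32)
x = jnp.ones((4, 3), dtype=jnp.float32)

# Trace and lower the jitted function for these input shapes.
lowered = jax.jit(predict).lower(w, x)

# Text form of the lowered module that gets handed to XLA.
ir_text = lowered.as_text()

# Compile ahead of time and run the compiled executable from Python.
compiled = lowered.compile()
out = compiled(w, x)

print(ir_text[:200])
print(out.shape)
```

An XLA backend would need to load an artifact like this and execute it from C++ rather than calling `compiled` from Python.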

Trax can use TF, NumPy, or JAX under the hood, so I don't think we need to do much additional work to support it.

Concretely, we'd need to add a backend for XLA and packagers for JAX and Trax.

@VivekPanyam
Collaborator Author

Note: Flax (https://github.com/google/flax) is another DL library built on top of JAX.

@vkuzmin-uber
Contributor

Is my understanding correct that this is a GPU/TPU optimization only?
