[FEATURE] JAX integration #97
Comments
Haha @nicholasjng, awesome to read this request. Love it. We'll notify this thread as soon as we're ready for that undertaking! JAX is an awesome idea and thanks for the request. Let's do it 💪
Closing due to inactivity. We are migrating such issues to the roadmap for further voting :-)
Sorry to necro an old thread, but I tried looking through the roadmap and could not find it. What's the status of this? Canceled, on hold, or completed? Thank you :)
We didn't find enough demand, so we had to prioritize other things for now. However, I can resurface it if you're interested. Contributions are welcome here. What would a JAX integration look like? Similar to the TensorFlow one, i.e., having the ability to pass (materialize) JAX models through ZenML pipelines?
AFAIK there is no built-in way to load/save JAX models, as JAX just contains the mathematical machinery for applying transformations and differentiating. For my own uses, I did export some models into HDF5, but that has little to do with JAX. It's possible that NN libraries built on JAX (Flax, Haiku, etc.) have some machinery for it, though. I think a small example on how to load/train/save a model might suffice. What do you have to implement for that? I'll see if I can come up with something if you point me to the necessary components :)
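To make the load/train/save idea above concrete, here is a minimal sketch, not a ZenML integration. It assumes the "model" is just a flat dict of parameter arrays and stores it with NumPy's `.npz` format rather than HDF5. The names `init_params`, `fit`, `save_params`, and `load_params` are hypothetical helpers made up for illustration; they are not part of JAX or ZenML.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

LEARNING_RATE = 0.1  # illustrative value, not tuned


def init_params(key, n_features):
    """Build a flat dict of parameters for a linear model (hypothetical helper)."""
    return {
        "w": jax.random.normal(key, (n_features,)) * 0.01,
        "b": jnp.zeros(()),
    }


def loss_fn(params, x, y):
    """Mean squared error of the linear model."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    """One step of plain gradient descent on the parameter dict."""
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )


def fit(params, x, y, steps=200):
    """Run a fixed number of gradient steps and return the new parameters."""
    for _ in range(steps):
        params = train_step(params, x, y)
    return params


def save_params(params, path):
    """Write each parameter array to a .npz file under its dict key."""
    np.savez(path, **{name: np.asarray(value) for name, value in params.items()})


def load_params(path):
    """Read the .npz file back into a dict of JAX arrays."""
    with np.load(path) as data:
        return {name: jnp.asarray(data[name]) for name in data.files}


# Synthetic NumPy data standing in for a real data source.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 3)).astype(np.float32)
y = x @ np.array([1.0, -2.0, 0.5], dtype=np.float32) + 0.3

params = fit(init_params(jax.random.PRNGKey(0), 3), x, y)

path = os.path.join(tempfile.mkdtemp(), "params.npz")
save_params(params, path)
restored = load_params(path)
print("loss after reload:", float(loss_fn(restored, x, y)))
```

A nested parameter tree (as produced by Flax or Haiku) would need to be flattened first, for example with `jax.tree_util.tree_flatten`, or saved with that library's own serialization tools. A custom materializer would presumably wrap a save/load pair like this one.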
@nicholasjng So sorry for the late reply, this slipped through. Happy to receive your contribution! A good place to start would be the guide to adding your own example. IMO, the example would resemble something like the LightGBM example. You might need to implement a custom materializer to get this to work. I think this is a good starting point for JAX integration. WDYT?
Totally. I got some time on the weekend, happy to take a look then. It may or may not be as straightforward for me though, depending on the usability of ZenML on M1. I'll get it done!
M1 is a bit of a problem, unfortunately. Let me know if it doesn't work, though!
Salam / merhabalar friends,
Is your feature request related to a problem? Please describe.
Let's bring Google's JAX to ZenML!
Describe the solution you'd like
I would like to build an example (what exactly the example is about is TBD at this point) ZenML+JAX project - if it's cloud-ready, all the better (although JAX on GCP has some sharp edges as far as I understood).
Ideally, this could be accomplished by only a NumPy datasource plus a JAX trainer class, but that is a first hunch - let's hope karma does not strike me for this one.
Additional context
Admittedly I am still in March with my mental model of a lot of ZenML's designs, so I will need to spend some time going through the newer concepts. Once I'm ready and have made progress, I'll submit a PR with the aforementioned JAX example.
Also, I might have to unpin some requirements to get stuff to build from source (Apple M1), due to the present lack of wheel support (think JAX itself, scipy, or pandas) - I'll check in and document what works here (or just go to Linux instead).
What do you think?