[FEATURE] JAX integration #97

Open
nicholasjng opened this issue Aug 27, 2021 · 8 comments

@nicholasjng

Hello friends,

Is your feature request related to a problem? Please describe.
Let's bring Google's JAX to ZenML!

Describe the solution you'd like
I would like to build an example ZenML + JAX project (its exact subject is TBD at this point). If it's cloud-ready, all the better, although JAX on GCP has some sharp edges, as far as I understand.

Ideally, this could be accomplished with just a NumPy datasource plus a JAX trainer class, but that is a first hunch; let's hope karma doesn't strike me for this one.

Additional context
Admittedly, my mental model of many of ZenML's designs is still stuck in March, so I will need to spend some time going through the newer concepts. Once I'm ready and have made progress, I'll submit a PR with the aforementioned JAX example.

Also, I might have to unpin some requirements to get things to build from source on Apple M1, given the current lack of wheel support (think JAX itself, SciPy, or pandas). I'll check in and document what works here (or just switch to Linux instead).

What do you think?

@htahir1
Contributor

htahir1 commented Aug 27, 2021

Haha @nicholasjng, awesome to read this request. Love it. We'll update this thread as soon as we're ready for that undertaking! JAX is an awesome idea, and thanks for the request. Let's do it 💪

@htahir1
Contributor

htahir1 commented Mar 22, 2022

Closing due to inactivity. We are migrating such issues to the roadmap for further voting :-)

htahir1 closed this as completed Mar 22, 2022
@IanQS

IanQS commented May 29, 2022

Sorry to necro an old thread, but I looked through the roadmap and could not find it. What's the status of this: canceled, on hold, or completed?

Thank you :)

@htahir1
Contributor

htahir1 commented May 29, 2022

We didn't find enough demand, so we had to prioritize other things for now. However, I can resurface it if you're interested. Contributions are welcome here: what would a JAX integration look like? Something similar to the TensorFlow one, i.e., the ability to pass (materialize) JAX models through ZenML pipelines?

htahir1 reopened this May 29, 2022
@nicholasjng
Author

AFAIK there is no built-in way to load/save JAX models, as JAX itself only provides the mathematical machinery for applying transformations and differentiating. For my own uses, I exported some models to HDF5, but that has little to do with JAX.

It's possible that first-party NN libraries (Flax, Haiku, etc.) have some machinery for it, though. I think a small example of how to load, train, and save a model might suffice; what would I have to implement for that? I'll see if I can come up with something if you point me to the necessary components :)
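A minimal sketch of the save/load step discussed above, assuming the model parameters form a nested mapping of arrays (the shape Flax and Haiku parameter trees typically take). It swaps HDF5 for NumPy's `.npz` archive format so it needs nothing beyond NumPy, and it assumes no parameter name contains the `::` separator. The helpers `flatten_params`, `unflatten_params`, `save_params`, and `load_params` are hypothetical names for illustration, not part of JAX, Flax, Haiku, or ZenML.

```python
import os
import tempfile
from collections.abc import Mapping

import numpy as np

# Separator for flattened keys; assumes no parameter name contains it.
SEP = "::"


def flatten_params(params, prefix=""):
    """Flatten a nested mapping of arrays into {"dense::w": array, ...}."""
    flat = {}
    for key, value in params.items():
        path = f"{prefix}{SEP}{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            flat.update(flatten_params(value, path))
        else:
            # np.asarray turns jax.Array (or scalar) leaves into host NumPy arrays.
            flat[path] = np.asarray(value)
    return flat


def unflatten_params(flat):
    """Rebuild the nested dict structure from the flattened keys."""
    params = {}
    for path, value in flat.items():
        *parents, leaf = path.split(SEP)
        node = params
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return params


def save_params(params, path):
    """Write all leaves into a single uncompressed .npz archive."""
    np.savez(path, **flatten_params(params))


def load_params(path):
    """Read an archive written by save_params back into a nested dict."""
    with np.load(path) as data:
        return unflatten_params({key: data[key] for key in data.files})


if __name__ == "__main__":
    # Toy parameters standing in for a trained model's parameter tree.
    params = {"dense": {"w": np.ones((2, 3)), "b": np.zeros(3)}}
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "params.npz")
        save_params(params, path)
        restored = load_params(path)
    print(restored["dense"]["w"].shape)
```

Empty nested mappings are dropped on the round trip, and scalar leaves (e.g. a step counter stored as a Python int) come back as 0-d NumPy arrays. A custom materializer could plausibly wrap helpers like these, but that wiring is not shown here.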

@htahir1
Contributor

htahir1 commented Jun 8, 2022

@nicholasjng So sorry for the late reply; this slipped through. Happy to receive your contribution!

A good place to start would be the guide on adding your own example. IMO, the example would resemble the LightGBM example.

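You might need to implement a custom materializer (the ZenML component that handles writing an object to, and reading it back from, the artifact store) to get this to work.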

I think this is a good starting point for JAX integration. WDYT?

@nicholasjng
Author

Totally. I have some time on the weekend and am happy to take a look then. It may or may not be straightforward for me, though, depending on how usable ZenML is on M1. I'll get it done!

@htahir1
Contributor

htahir1 commented Jun 9, 2022

M1 is a bit of a problem, unfortunately. Let me know if it doesn't work, though!
