Framework Agnostic API: Introduces a new low-level API, removes the dependency between Model and Module, adds support for Flax and Haiku, simplifies hooks. #139

Merged
merged 90 commits into from Jan 31, 2021

Conversation

cgarciae
Collaborator

@cgarciae cgarciae commented Jan 13, 2021

As noted below, this PR contains the following features:

  • It turns Elegy into a framework-agnostic library by removing the dependencies between elegy.Model and elegy.Module. It proposes the GeneralizedModule API and implements it for Flax, Haiku, and Elegy Module types, as well as regular Python functions.
  • It introduces a new low-level API similar to PyTorch Lightning that lets users manually override the core parts of the training loop when maximal flexibility is required.
  • General changes that enable the framework-agnostic mindset.
  • Many quality-of-life changes, such as standardization of hooks and simplification of the Module system.
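
To make the GeneralizedModule idea more concrete, here is a minimal sketch in plain Python. It is not Elegy's actual implementation: it only assumes that the common interface boils down to a pure init/apply pair. The names generalize, FunctionModule, ToyDense and ToyDenseModule are hypothetical, and ToyDense merely stands in for a third-party module type such as a Flax or Haiku module.

```python
from functools import singledispatch
from types import FunctionType


class GeneralizedModule:
    """Hypothetical common interface: a pure init/apply pair."""

    def init(self, *args):
        raise NotImplementedError

    def apply(self, params, *args):
        raise NotImplementedError


class FunctionModule(GeneralizedModule):
    """Adapter for a regular Python function, which has no parameters."""

    def __init__(self, fn):
        self.fn = fn

    def init(self, *args):
        return {}

    def apply(self, params, *args):
        return self.fn(*args)


class ToyDense:
    """Stand-in for a third-party module type (an assumption, not a real library)."""

    def __init__(self, scale):
        self.scale = scale


class ToyDenseModule(GeneralizedModule):
    """Adapter that exposes ToyDense through the common interface."""

    def __init__(self, module):
        self.module = module

    def init(self, *args):
        return {"scale": self.module.scale}

    def apply(self, params, x):
        return params["scale"] * x


@singledispatch
def generalize(obj):
    """Pick an adapter based on the type of `obj`."""
    raise TypeError(f"no GeneralizedModule adapter for {type(obj).__name__}")


@generalize.register(FunctionType)
def _(obj):
    return FunctionModule(obj)


@generalize.register(ToyDense)
def _(obj):
    return ToyDenseModule(obj)


# The training code only ever sees init/apply, whatever the original type was.
outputs = []
for raw in (lambda x: x + 1, ToyDense(scale=3.0)):
    module = generalize(raw)
    params = module.init(2.0)
    outputs.append(module.apply(params, 2.0))
```

With this kind of dispatch, supporting another framework would only mean registering one more adapter, while the code that consumes init/apply stays unchanged.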

Tasks:

  • Create hooks module
  • Refactor Model with low-level API and remove Module dependencies
  • Refactor Module to use hooks
  • Create GeneralizedModule and GeneralizedOptimizer interfaces
  • Implement GeneralizedModule for flax.linen.Module
  • Implement GeneralizedModule for elegy.Module
  • Implement GeneralizedModule for haiku.Module
  • Implement GeneralizedOptimizer for optax.GradientTransformation
  • Implement GeneralizedOptimizer for elegy.Optimizer
  • Fix Model.summary
  • Fix tests
  • Fix examples
  • Fix README
  • Fix guides
  • Fix docstrings

@cgarciae cgarciae marked this pull request as draft January 13, 2021 13:35
@alexander-g
Contributor

Care to explain? Will this break the existing API?

@cgarciae
Collaborator Author

cgarciae commented Jan 15, 2021

Hey @alexander-g, I was going to tag you here with some info once I got an MVP but here is what is happening:

I've been thinking a lot about both #138 and #128, and more generally on how to make Elegy more appealing to the research-oriented Jax ecosystem.

Current proposal:

  1. Implement a low-level API, similar to PyTorch Lightning and Keras, that gives you full control when needed. The API consists of 4 methods you can override: init, pred_step, test_step, train_step. Check out their signatures here:
  2. Make Elegy "Framework Agnostic". This means that:
    • Low-level training code should be compatible with most Module frameworks like Flax / Haiku. More generally, you should be able to use pure Jax if so desired.
    • High-level code (the main Keras API) should be able to support Modules from most frameworks for the main network architecture, losses and metrics.
    • Due to Jax's nature, the last two points work better with frameworks that have functional interfaces.
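
As a rough illustration of the four overridable methods, here is a sketch in plain Python (no Jax, so it runs anywhere). The signatures below are assumptions made for this example and are not the ones in the PR; LowLevelModel, LinearRegression and the fit loop are likewise hypothetical.

```python
class LowLevelModel:
    """Hypothetical base class with the four overridable methods."""

    def init(self, x, y):
        raise NotImplementedError

    def pred_step(self, params, x):
        raise NotImplementedError

    def test_step(self, params, x, y):
        raise NotImplementedError

    def train_step(self, params, x, y):
        raise NotImplementedError

    def fit(self, x, y, epochs):
        # The generic loop only relies on init and train_step.
        params = self.init(x, y)
        losses = []
        for _ in range(epochs):
            loss, params = self.train_step(params, x, y)
            losses.append(loss)
        return params, losses


class LinearRegression(LowLevelModel):
    """User code: full manual control over every step."""

    def __init__(self, lr):
        self.lr = lr

    def init(self, x, y):
        return {"w": 0.0, "b": 0.0}

    def pred_step(self, params, x):
        return [params["w"] * xi + params["b"] for xi in x]

    def test_step(self, params, x, y):
        preds = self.pred_step(params, x)
        return sum((p - t) ** 2 for p, t in zip(preds, y)) / len(x)

    def train_step(self, params, x, y):
        preds = self.pred_step(params, x)
        n = len(x)
        # Hand-written gradients of the mean squared error.
        grad_w = sum(2 * (p - t) * xi for p, t, xi in zip(preds, y, x)) / n
        grad_b = sum(2 * (p - t) for p, t in zip(preds, y)) / n
        loss = self.test_step(params, x, y)
        new_params = {
            "w": params["w"] - self.lr * grad_w,
            "b": params["b"] - self.lr * grad_b,
        }
        return loss, new_params


xs = [0.0, 1.0, 2.0, 3.0]
ys = [2.0 * xi + 1.0 for xi in xs]  # ground truth: w = 2, b = 1
params, losses = LinearRegression(lr=0.05).fit(xs, ys, epochs=500)
```

Note that the parameters are passed in and returned explicitly rather than stored on the object, which is the functional style the last bullet refers to.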

Pros:

  • It is more compatible with Elegy's original intent of being a "Trainer Interface".
  • By being more framework agnostic users can now leverage all the cool stuff being developed by the broader ecosystem and we also have more potential users.
  • We can focus more on the high-level API and leave the improvements of the Module System to other libraries.

Cons:

  • We have to modify Elegy's module system to provide a functionally pure interface again (init, apply).
  • We have to generalize and simplify our hooks system; in particular, add_summary should take its path as input instead of calculating it.
  • Right now it is not clear how to implement Model.summary in a generic fashion.

Pointers:

  • Check out this example on how to create a low-level fit-only Model using pure Jax: examples/logistic_regression.py
  • Check out this test that uses a flax.linen.Module inside elegy.Model and successfully runs predict.

Open Questions:

  • Should we gradually adopt a 3rd-party Module system for the implementation of our metrics and applications? It would be nice if we could work together with that other library; the Flax team seems pretty open and responsive.
  • Should we sacrifice some features like Model.summary for the sake of compatibility? Maybe we can find a way to fix it, but even if it loses some info, I think being framework agnostic will increase Elegy's impact.

@cgarciae
Collaborator Author

@alexander-g @charlielito

What do you think about this? Very open to suggestions.

@alexander-g
Contributor

It sounds promising and I support this direction.
However, for me personally, Module.summary, or rather Module.slice, which depends on it, is quite important. Being able to extract specific outputs from a module quickly is a big plus of Keras compared to PyTorch for me. In PyTorch you have to rewrite the forward() method and know where to insert relu or pooling functions, which is annoying.
If we can find another method to do this that would be perfectly fine with me. One might experiment with jax.named_call to group jax functions and then analyze the resulting jaxprs. But that's very hypothetical for now.

@cgarciae
Collaborator Author

@alexander-g I believe we can maintain the functionality needed for Module.slice. Thinking more about this, I think we can even keep Model.summary working, at least for Elegy Modules.

@SamuelMarks
Contributor

I suppose you could go a little further and centralise imports to one part of the codebase, say an __init__.py, and in that section have a bunch of options like:

from os import environ

if environ["ELEGY_ENGINE"] == "np":
    pass  # use `numpy`
elif environ["ELEGY_ENGINE"] == "jax":
    pass  # use `JAX`
elif environ["ELEGY_ENGINE"] == "arrow":
    pass  # use apache arrow
elif environ["ELEGY_ENGINE"] == "tensorflow":
    pass  # use `tf.experimental.numpy`, and maybe other tf stuff also

Most formulas are equivalent, so they should work easily (à la the K from the old Keras implementation).

@cgarciae
Collaborator Author

Hey @SamuelMarks, right now by "framework agnostic" I mean that it should be able to work with most of the other Jax-based Module frameworks such as Flax, Haiku, etc.

@alexander-g
Contributor

> I left it like this because it's more in line with how you would manually have to do it when using a Module from any framework instead of a GeneralizedModule.

Just because other frameworks make it harder for their users doesn't mean Elegy has to copy it. It's not like it's mutually exclusive, the user could still call init() manually if they want.

@cgarciae
Collaborator Author

> Just because other frameworks make it harder for their users doesn't mean Elegy has to copy it. It's not like it's mutually exclusive, the user could still call init() manually if they want.

For now I think there is still one advantage of init + apply: safety. If we have a single type-dependent apply for GeneralizedModule then init can be called instead of apply by accident if the user messes up the types. If you have a bug you could potentially be always calling init which could be confusing.
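
A toy sketch of that failure mode, in plain Python with hypothetical names (merged_apply is not an Elegy function, and a dict stands in for real parameters):

```python
def init(x):
    """Create fresh parameters."""
    return {"w": 1.0}


def apply(params, x):
    """Explicit apply: refuses to run without initialized parameters."""
    if not isinstance(params, dict):
        raise TypeError("apply() requires initialized params; call init() first")
    return params["w"] * x


def merged_apply(params, x):
    """Hypothetical single entry point that dispatches on the type of `params`."""
    if params is None:  # looks uninitialized, so silently take the init path
        params = init(x)
    return params["w"] * x


trained = {"w": 3.0}
lost = None  # simulates a bug: the trained params were dropped somewhere upstream

# The merged version re-initializes and returns a plausible-looking value.
silent_result = merged_apply(lost, 2.0)
correct_result = merged_apply(trained, 2.0)

# The explicit version fails loudly on the same bug.
try:
    apply(lost, 2.0)
    explicit_failed = False
except TypeError:
    explicit_failed = True
```

Here the merged call returns a value computed from freshly initialized parameters instead of the trained ones, with no error to point at the bug, while the explicit apply raises immediately.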

That said, I am very intrigued by the idea, maybe we can explore it in the future.

@cgarciae
Collaborator Author

@charlielito Examples are now being tested.

@cgarciae
Collaborator Author

It has been a very intense PR and some tasks are still missing, but I think I will merge this to master and continue the remaining tasks in subsequent PRs before the release, to make potential contributions easier.

@cgarciae cgarciae merged commit 06c0bc4 into master Jan 31, 2021
@cgarciae cgarciae changed the title framework-agnostic Framework Agnostic API: This PR introduces a new low-level API, remove the dependency between Model and Module, add support for Flax and Haiku, simplifies hooks. Jan 31, 2021
@cgarciae cgarciae changed the title Framework Agnostic API: This PR introduces a new low-level API, remove the dependency between Model and Module, add support for Flax and Haiku, simplifies hooks. Framework Agnostic API: Introduces a new low-level API, remove the dependency between Model and Module, add support for Flax and Haiku, simplifies hooks. Jan 31, 2021
@cgarciae cgarciae changed the title Framework Agnostic API: Introduces a new low-level API, remove the dependency between Model and Module, add support for Flax and Haiku, simplifies hooks. Framework Agnostic API: Introduces a new low-level API, removes the dependency between Model and Module, adds support for Flax and Haiku, simplifies hooks. Jan 31, 2021
@cgarciae cgarciae deleted the framework-agnostic branch January 31, 2021 22:31
alexander-g added a commit to alexander-g/elegy that referenced this pull request Feb 1, 2021
cgarciae pushed a commit that referenced this pull request Feb 1, 2021
@cgarciae cgarciae mentioned this pull request Feb 1, 2021

6 participants