Skip to content

Commit

Permalink
Update Getting Started + README (#152)
Browse files Browse the repository at this point in the history
* strucure docs

* Update getting started

* remove update_modules by default

* update getting started

* getting started low-level api

* update readme + correct logistic regression confusion

* update dependencies + version

* update ipython dev

* update readme

* random change

* turn off example testing
  • Loading branch information
cgarciae committed Feb 1, 2021
1 parent a8b917b commit 59acb3b
Show file tree
Hide file tree
Showing 31 changed files with 1,936 additions and 1,181 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ jobs:
- name: Upload coverage
uses: codecov/codecov-action@v1

- name: Test Examples
run: bash scripts/test-examples.sh
# - name: Test Examples
# run: bash scripts/test-examples.sh
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ cython_debug/
/test.*
/summaries
/runs
/docs/model
/docs/model/
/docs/models/
/models
/TODO
147 changes: 80 additions & 67 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Elegy

[![PyPI Status Badge](https://badge.fury.io/py/elegy.svg)](https://pypi.org/project/elegy/)
[![Coverage](https://img.shields.io/codecov/c/github/poets-ai/elegy?color=%2334D058)](https://codecov.io/gh/poets-ai/elegy)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/elegy)](https://pypi.org/project/elegy/)
Expand All @@ -10,16 +9,13 @@

-----------------

_Elegy is a Neural Networks framework based on Jax inspired by Keras._

Elegy implements the Keras API but makes changes to play better with Jax and gives more flexibility around [losses and metrics](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/) and excellent [module system](https://poets-ai.github.io/elegy/guides/module-system/) that makes it super easy to use. Elegy is in an early stage, feel free to send us your feedback!
_Elegy is a framework-agnostic Trainer interface for the Jax ecosystem._

#### Main Features

* **Familiar**: Elegy should feel very familiar to Keras users.
* **Flexible**: Elegy improves upon the basic Keras API by letting users optionally take more control over the definition of losses and metrics.
* **Easy-to-use**: Elegy maintains all the simplicity and ease of use that Keras brings with it.
* **Compatible**: Elegy strives to be compatible with the rest of the Jax ecosystem.
* **Flexible**: Elegy provides a functional Pytorch Lightning-like low-level API that provides maximal flexibility when needed.
* **Easy-to-use**: Elegy provides a Keras-like high-level API that makes it very easy to do common tasks.
* **Agnostic**: Elegy provides support a variety of frameworks including Flax, Haiku, and Optax on the high-level API, and it is 100% framework-agnostic on the low-level API.
* **Compatible**: Elegy can consume a wide variety of common data sources including TensorFlow Datasets, Pytorch DataLoaders, Python generators, and Numpy pytrees.

For more information take a look at the [Documentation](https://poets-ai.github.io/elegy).

Expand All @@ -32,22 +28,27 @@ pip install elegy

For Windows users we recommend the Windows subsystem for linux 2 [WSL2](https://docs.microsoft.com/es-es/windows/wsl/install-win10?redirectedfrom=MSDN) since [jax](https://github.com/google/jax/issues/438) does not support it yet.

## Quick Start
Elegy greatly simplifies the training of Deep Learning models compared to pure Jax where, due to Jax's functional nature, users have to do a lot of book keeping around the state of the model. In Elegy you just have to follow 3 basic steps:
## Quick Start: High-level API
In Elegy's high-level API provides a very simple interface you can use by implementing following steps:

**1.** Define the architecture inside an `elegy.Module`:
**1.** Define the architecture inside a `Module`. We will use Flax Linen for this example:
```python
class MLP(elegy.Module):
def call(self, x: jnp.ndarray) -> jnp.ndarray:
x = elegy.nn.Linear(300)(x)
import flax.linen as nn
import jax

class MLP(nn.Module):
@nn.compact
def call(self, x):
x = nn.Dense(300)(x)
x = jax.nn.relu(x)
x = elegy.nn.Linear(10)(x)
x = nn.Dense(10)(x)
return x
```
Note that we can define sub-modules on-the-fly directly in the `call` (forward) method.

**2.** Create a `Model` from this module and specify additional things like losses, metrics, and optimizers:
```python
import elegy, optax

model = elegy.Model(
module=MLP(),
loss=[
Expand All @@ -72,59 +73,71 @@ model.fit(
)
```

And you are done! For more information check out:


* Our [Getting Started](https://poets-ai.github.io/elegy/getting-started/) tutorial.
* Elegy's [Documentation](https://poets-ai.github.io/elegy).
* The [examples](https://github.com/poets-ai/elegy/tree/master/examples) directory.
* [What is Jax?](https://github.com/google/jax#what-is-jax)

## Why Jax & Elegy?

Given all the well-established Deep Learning framework like TensorFlow + Keras or Pytorch + Pytorch-Lightning/Skorch, it is fair to ask why we need something like Jax + Elegy? Here are some of the reasons why this framework exists.

#### Why Jax?

**Jax** is a linear algebra library with the perfect recipe:
* Numpy's familiar API
* The speed and hardware support of XLA
* Automatic Differentiation

The awesome thing about Jax is that Deep Learning is just a use-case that it happens to excel at but you can use it for most task you would use NumPy for. Jax is so compatible with Numpy that is array type actually inherits from `np.ndarray`.

In a sense, Jax takes the best of both TensorFlow and Pytorch in a principled manner: while both TF and Pytorch historically converged to the same set of features, their APIs still contain quirks they have to keep for compatibility.

#### Why Elegy?

We believe that **Elegy** can offer the best experience for coding Deep Learning applications by leveraging the power and familiarity of Jax API, an easy-to-use and succinct Module system, and packaging everything on top of a convenient Keras-like API. Elegy improves upon other Deep Learning frameworks in the following ways:

1. Its hook-based [Module System](https://poets-ai.github.io/elegy/guides/module-system/) makes it easier (less verbose) to write model code compared to Keras & Pytorch since it lets you declare sub-modules, parameters, and states directly on your `call` (forward) method. Thanks to this you get shape inference for free so there is no need for a `build` method (Keras) or propagating shape information all over the place (Pytorch). A naive implementation of `Linear` could be as simple as:
## Quick Start: Low-level API
In Elegy's low-level API provides lets you define exactly what goes on during training, testing, and inference. Lets define the `test_step` to implement a linear classifier in pure jax:

**1.** Calculate our loss, logs, and states:
```python
class Linear:
def __init__(self, units):
super().__init__()
self.units = units

def call(self, x):
w = self.add_parameter("w", [x.shape[-1], self.units], initializer=jnp.ones)
b = self.add_parameter("b", [self.units], initializer=jnp.ones)

return jnp.dot(x, w) + b
class LinearClassifier(elegy.Model):
# request parameters by name via depending injection.
# possible: net_params, x, y_true, net_states, metrics_states, sample_weight, class_weight, rng, states, initializing
def test_step(
self,
x, # inputs
y_true, # labels
states: elegy.States, # model state
initializing: bool, # if True we should initialize our parameters
):
# flatten + scale
x = jnp.reshape(x, (x.shape[0], -1)) / 255
# maybe initialize or use existing
if initializing:
w = jax.random.uniform(
jax.random.PRNGKey(42), shape=[np.prod(x.shape[1:]), 10]
)
b = jax.random.uniform(jax.random.PRNGKey(69), shape=[1])
else:
w, b = states.net_params
# model
logits = jnp.dot(x, w) + b
# categorical crossentropy loss
labels = jax.nn.one_hot(y_true, 10)
loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
accuracy=jnp.mean(jnp.argmax(logits, axis=-1) == y_true)
# metrics
logs = dict(
accuracy=accuracy,
loss=loss,
)
return loss, logs, states.update(rng=rng, net_params=(w, b))
```
2. It has a very flexible system for defining the inputs for [losses and metrics](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/) based on _dependency injection_ in opposition to Keras rigid requirement to have matching (output, label) pairs, and being unable to use additional information like inputs, parameters, and states in the definition of losses and metrics.
3. Its hook system preserve's [reference information](https://poets-ai.github.io/elegy/guides/module-system/) from a module to its sub-modules, parameters, and states while maintaining a functional API. This is crucial since most Jax-based frameworks like Flax and Haiku tend to loose this information which makes it very tricky to perform tasks like transfer learning where you need to mix a pre-trained models into a new model (easier to do if you keep references).

## Features
* `Model` estimator class
* `losses` module
* `metrics` module
* `regularizers` module
* `callbacks` module
* `nn` layers module
**2.** Instantiate our `LinearClassifier` with an optimizer:
```python
model = LinearClassifier(
optimizer=optax.rmsprop(1e-3),
)
```
**3.** Train the model using the `fit` method:
```python
model.fit(
x=X_train,
y=y_train,
epochs=100,
steps_per_epoch=200,
batch_size=64,
validation_data=(X_test, y_test),
shuffle=True,
callbacks=[elegy.callbacks.TensorBoard("summaries")]
)
```

For more information checkout the **Reference API** section in the [Documentation](https://poets-ai.github.io/elegy).
## More Info
* [Getting Started: High-level API](https://poets-ai.github.io/elegy/getting-started-high-level-api/) tutorial.
* [Getting Started: High-level API](https://poets-ai.github.io/elegy/getting-started-low-level-api/) tutorial.
* Elegy's [Documentation](https://poets-ai.github.io/elegy).
* The [examples](https://github.com/poets-ai/elegy/tree/master/examples) directory.
* [What is Jax?](https://github.com/google/jax#what-is-jax)

## Contributing
Deep Learning is evolving at an incredible pace, there is so much to do and so few hands. If you wish to contribute anything from a loss or metric to a new awesome feature for Elegy just open an issue or send a PR! For more information check out our [Contributing Guide](https://poets-ai.github.io/elegy/guides/contributing).
Expand All @@ -144,9 +157,9 @@ To cite this project:
```
@software{elegy2020repository,
author = {PoetsAI},
title = {Elegy: A Keras-like deep learning framework based on Jax},
title = {Elegy: A framework-agnostic Trainer interface for the Jax ecosystem},
url = {https://github.com/poets-ai/elegy},
version = {0.3.0},
version = {0.4.0},
year = {2020},
}
```
Expand Down
File renamed without changes.
88 changes: 88 additions & 0 deletions docs/contributing.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Contributing
This is a short guide on how to start contributing to Elegy along with some best practices for the project.

## Setup
We use `poetry >= 1.1.4`, the easiest way to setup a development environment is run:

```bash
poetry config virtualenvs.in-project true --local
poetry install
```

In order for Jax to recognize your GPU, you will probably have to install it again using the command below.

```bash
PYTHON_VERSION=cp38
CUDA_VERSION=cuda101 # alternatives: cuda100, cuda101, cuda102, cuda110, check your cuda version
PLATFORM=manylinux2010_x86_64 # alternatives: manylinux2010_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.55-$PYTHON_VERSION-none-$PLATFORM.whl
pip install --upgrade jax
```

#### Gitpod
An alternative way to contribute is using [gitpod](https://gitpod.io/) which creates a vscode-based cloud development enviroment.
To get started just login at gitpod, grant the appropriate permissions to github, and open the following link:

https://gitpod.io/#https://github.com/poets-ai/elegy

We have built a `python 3.8` enviroment and all development dependencies will install when the enviroment starts.

## Creating Losses and Metrics
For this you can follow these guidelines:

* Each loss / metric should be defined in its own file.
* Inherit from either `elegy.losses.loss.Loss` or `elegy.metrics.metric.Metric` or an existing class that inherits from them.
* Try to use an existing metric or loss as a template
* You must provide documentation for the following:
* The class definition.
* The `__init__` method.
* The `call` method.
* Try to port the documentation + signature from its Keras counter part.
* If so you must give credits to the original source file.
* You must include tests.
* If you there exists an equivalent loss/metric in Keras you must test numerical equivalence between both.

## Testing
To execute all the tests just run
```bash
pytest
```

## Documentation
We use `mkdocs`. If you create a new object that requires documentation please do the following:

* Add a markdown file inside `/docs/api` in the appropriate location according to the project's structure. This file must:
* Contain the path of function / class as header
* Use `mkdocstring` to render the API information.
* Example:
```markdown
# elegy.losses.BinaryCrossentropy

::: elegy.losses.BinaryCrossentropy
selection:
inherited_members: true
members:
- call
- __init__
```
* Add and entry to `mkdocs.yml` inside `nav` pointing to this file. Checkout `mkdocs.yml`.

To build and visualize the documentation locally run
```bash
mkdocs serve
```

## Creating a PR
Before sending a pull request make sure all test run and code is formatted with `black`:

```bash
black .
```

## Changelog
`CHANGELOG.md` is automatically generated using [github-changelog-generator](https://github.com/github-changelog-generator/github-changelog-generator), to update the changelog just run:
```bash
docker run -it --rm -v (pwd):/usr/local/src/your-app ferrarimarco/github-changelog-generator -u poets-ai -p elegy -t <TOKEN>
```
where `<TOKEN>` token can be obtained from Github at [Personal access tokens](https://github.com/settings/tokens), you only have to give permission for the `repo` section.
2 changes: 1 addition & 1 deletion docs/guides/module-system.md → docs/elegy-module.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# The Module System
# Elegy Module

This is a guide to Elegy's underlying Module System. It will help get a better understanding of how
Elegy interacts with Jax at the lower level, certain details about the hooks system and how it
Expand Down

0 comments on commit 59acb3b

Please sign in to comment.