
Commit

More Docs: Expand documentation for the low-level API guides. (#168)
* pred_step + test_step

* fix colab links

* use pmean in pmap example
cgarciae committed Mar 1, 2021
1 parent 690eee1 commit d0ff8f0
Showing 8 changed files with 236 additions and 51 deletions.
4 changes: 2 additions & 2 deletions docs/getting-started/high-level-api.ipynb
@@ -7,7 +7,7 @@
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/poets-ai/elegy/blob/master/docs/getting-started-basic-api.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
"<a href=\"https://colab.research.google.com/github/poets-ai/elegy/blob/master/docs/getting-started/high-level-api.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
@@ -561,4 +561,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
4 changes: 2 additions & 2 deletions docs/getting-started/low-level-api.ipynb
@@ -7,7 +7,7 @@
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/poets-ai/elegy/blob/master/docs/getting-started-low-level-api.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
"<a href=\"https://colab.research.google.com/github/poets-ai/elegy/blob/master/docs/getting-started/low-level-api.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
@@ -602,4 +602,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
39 changes: 18 additions & 21 deletions docs/low-level-api/basics.md
@@ -5,28 +5,25 @@ Elegy's low-level API allows you to override some core methods in `Model` that s
### Methods
This is the list of all the overrideable methods:

| Caller Methods | Overridable Method |
| :----------------------------------- | :----------------- |
| - `predict` <br>- `predict_on_batch` | `pred_step` |
| - `evaluate`<br>- `test_on_batch` | `test_step` |
| | `grad_step` |
| - `fit`<br>- `train_on_batch` | `train_step` |
| - `summary` | `summary_step` |

Each overrideable method has a default implementation which is what gives rise to the high-level API, the default implementation almost always implements a method in term of another in this manner:

```
pred_step ⬅ test_step ⬅ grad_step ⬅ train_step
pred_step ⬅ summary_step
```
This allows you to e.g. override `test_step` and still be able to use use `fit` since `train_step` (called by `fit`) will call your `test_step` via `grad_step`. It also means that e.g. if you implement `test_step` but not `pred_step` there is a high chance both `predict` and `summary` will not work as expected since both depend on `pred_step`.
| Caller | Method |
| :--------- | :------------- |
| `predict` | `pred_step` |
| `evaluate` | `test_step` |
| | `grad_step` |
| `fit` | `train_step` |
| `init` | `init_step` |
| `summary` | `summary_step` |
| | `states_step` |
| | `jit_step` |

Each method has a default implementation which is what gives rise to the high-level API.

### Example
Each overrideable methods takes some input + state, performs some `jax` operations + updates the state, and returns some outputs + the new state. Lets see a simple example of a linear classifier using `test_step`:
Most overrideable methods take some input & state, perform some `jax` operations & update the state, and return some outputs & the new state. Let's see a simple example of a linear classifier using `test_step`:

```python
class LinearClassifier(elegy.Model):
    def test_step(self, x, y_true, states, initializing) -> elegy.TestStep:
    def test_step(self, x, y_true, states, initializing):
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # initialize or use existing parameters
@@ -66,9 +63,9 @@ model.fit(
)
```

As you see here we perform everything from parameter initialization, modeling, calculating the main loss, and logging some metrics. Notes:
As you can see, here we perform everything from parameter initialization to modeling, calculating the main loss, and logging some metrics. Some notes about the previous example:

* The `states` argument of type `elegy.States` is an immutable Mapping to which you add / update fields via its `update` method (see the sketch after this list).
* `net_params` is one of the names used by the default implementation, check the [States](./states.md) guid for more information.
* `initializing` tells you whether you to initialize your parameters or fetch them from `states`, if you are using a Module framework this usually tells you whether to call `init` or `apply`.
* `test_step` returns 3 specific outputs, you should check the docs for each method to know what to return.
* `net_params` is one of the names used by the default implementation; check the [States](./states.md) guide for more information.
* `initializing` tells you whether to initialize the parameters of the model or fetch the current ones from `states`; if you are using a Module framework, this usually tells you whether to call `init` or `apply`.
* `test_step` should return 3 specific outputs (`loss`, `logs`, `states`); check the docs for each method to know what to return.
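To make the `States` note above concrete, here is a small sketch of the behavior it describes. This is purely illustrative and not Elegy's actual implementation: `update` never mutates in place, it returns a new mapping with the given fields added or replaced, and fields can be read back as attributes.

```python
# Toy stand-in for `elegy.States` (illustration only, not the real class).
class ToyStates(dict):
    def update(self, **fields):
        # unlike `dict.update`, return a *new* mapping instead of mutating
        return ToyStates(self, **fields)

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(name) from e


states = ToyStates()
new_states = states.update(net_params=(1.0, 2.0))  # new object, original untouched
assert "net_params" not in states
w, b = new_states.net_params                       # attribute-style access
```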
16 changes: 16 additions & 0 deletions docs/low-level-api/default-implementation.md
@@ -0,0 +1,16 @@
# Default Implementation

### Methods
The default implementation favors composition by implementing a method in terms of another; specifically, it follows this call graph:

```
     summary             predict            evaluate              fit                init
        ⬇️                   ⬇️                  ⬇️                  ⬇️                 ⬇️
call_summary_step     call_pred_step     call_test_step     call_train_step     call_init_step
        ⬇️                   ⬇️                  ⬇️                  ⬇️                 ⬇️
  summary_step  ➡️  pred_step  ⬅  test_step  ⬅  grad_step  ⬅  train_step  ⬅  init_step
```
This structure allows you, for example, to override `test_step` and still be able to use `fit`, since `train_step` (called by `fit`) will call your `test_step` via `grad_step`. It also means that if you implement `test_step` but not `pred_step` there is a high chance both `predict` and `summary` will not work.

##### call_* methods
The `call_<method>` family of methods are _entrypoints_ that usually just forward their inputs to `<method>`; you can choose to override these if you need to perform some computation only when the method in question is the entry point. For example, if you want to change the behavior of `evaluate` without affecting the behavior of `fit` while preserving most of the default implementation, you can override `call_test_step` to make the corresponding adjustments and then call `test_step`. Since `train_step` does not depend on `call_test_step`, the change will manifest during `evaluate` but not during `fit`.
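As a rough sketch of this pattern (the exact signature and return values of `call_test_step` may differ; the assumption below is that it returns the same `(loss, logs, states)` tuple as `test_step`), you can delegate to the default behavior and only add logic that should run when `evaluate` is the entry point:

```python
import elegy


class MyModel(elegy.Model):
    def call_test_step(self, *args, **kwargs):
        # delegate to the default chain: call_test_step -> test_step
        loss, logs, states = super().call_test_step(*args, **kwargs)

        # this extra log only shows up through `evaluate`, since `fit`
        # reaches `test_step` via `train_step` / `grad_step` instead
        logs = dict(logs, eval_only_loss=loss)

        return loss, logs, states
```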
79 changes: 58 additions & 21 deletions docs/low-level-api/pred_step.md
@@ -1,34 +1,71 @@
# pred_step
This method is tasked with taking the input data and calculatin the predictions.
The `pred_step` method computes the predictions of the main model; by overriding this method you can directly influence what happens during `predict`.

### Inputs
The following inputs are available for `pred_step`:
Any of the following input arguments are available for `pred_step`:

| name | types | description |
| :------------- | :------------- | :--------------------------------------- |
| `x` | `tp.Any` | Input data |
| `states` | `types.States` | Current state of the model |
| `initializing` | `bool` | Whether the model is initializing or not |
| `training` | `bool` | Whether the model is training or not |
| name | type | |
| :------------- | :------- | :--------------------------------------- |
| `x` | `Any` | Input data |
| `states` | `States` | Current state of the model |
| `initializing` | `bool` | Whether the model is initializing or not |
| `training` | `bool` | Whether the model is training or not |

### Output
You must request the arguments you want by **name**.
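For instance, an override does not have to declare every argument. Here is a minimal sketch (hypothetical model, relying on the by-name injection described above) that only asks for `x` and `states`:

```python
import elegy


class ScaledModel(elegy.Model):
    # `initializing` and `training` are simply not requested here
    def pred_step(self, x, states):
        y_pred = x / 255.0  # trivial stateless "forward pass"
        return y_pred, states
```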

### Outputs
`pred_step` must output a tuple with the following values:

| name | types | description |
| :------- | :------------- | :--------------------------- |
| `y_pred` | `tp.Any` | The predictions of the model |
| `states` | `types.States` | The new state of the model |
| name | type | |
| :------- | :------- | :--------------------------- |
| `y_pred` | `Any` | The predictions of the model |
| `states` | `States` | The new state of the model |


### Callers
| method | when |
| :----------------- | :-------------------------- |
| `predict` | always |
| `predict_on_batch` | always |
| `test_step` | default implementation only |
| `summary_step` | default implementation only |
| method | when |
| :------------- | :--------------------- |
| `predict` | always |
| `test_step` | default implementation |
| `summary_step` | default implementation |

### Examples
If for some reason you wish to create a pure jax / Module-less model, you can define your own Model that implements `pred_step` like this:

```python
class LinearClassifier(elegy.Model):
    def pred_step(self, x, states, initializing):
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                jax.random.PRNGKey(42), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(jax.random.PRNGKey(69), shape=[1])
        else:
            w, b = states.net_params

        # model
        y_pred = jnp.dot(x, w) + b

        return y_pred, states.update(net_params=(w, b))


model = LinearClassifier(
    optimizer=optax.adam(1e-3),
    loss=elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
)

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    batch_size=64,
)
```
Here we implement the same `LinearClassifier` from the [basics](./basics) section, but we extract the definition of the model into `pred_step` and let the default implementation of `test_step` take care of the `loss` and `metrics`, which we provide to the `LinearClassifier`'s constructor.

### Default Implementation
The default implementation of `pred_step` does the following:
* Calls `api_module.init` or `api_module.apply` depending on the state of `initializing`. `api_module` of type `GeneralizedModule` is a wrapper over the `module` object passed by the user to the `Model`'s constructor.
3 changes: 0 additions & 3 deletions docs/low-level-api/supporting-the-high-level-api.md

This file was deleted.

138 changes: 138 additions & 0 deletions docs/low-level-api/test_step.md
@@ -0,0 +1,138 @@
# test_step
The `test_step` method computes the main `loss` of the model along with some `logs` for reporting; by overriding this method you can directly influence what happens during `evaluate`.

### Inputs
Any of the following input arguments are available for `test_step`:

| name | type | |
| :-------------- | :------------------ | :------------------------------------------ |
| `x` | `Any` | Input data |
| `y_true` | `Any` | The target labels |
| `sample_weight` | `Optional[ndarray]` | The weight of each sample in the total loss |
| `class_weight` | `Optional[ndarray]` | The weight of each class in the total loss |
| `states` | `States` | Current state of the model |
| `initializing` | `bool` | Whether the model is initializing or not |
| `training` | `bool` | Whether the model is training or not |


You must request the arguments you want by **name**.

### Outputs
`test_step` must output a tuple with the following values:

| name | type | |
| :------- | :------------------- | :------------------------------------------ |
| `loss` | `ndarray` | The loss of the model over the data |
| `logs` | `Dict[str, ndarray]` | A dictionary with a set of values to report |
| `states` | `States` | The new state of the model |


### Callers
| method | when |
| :----------- | :------------------------------------------------ |
| `evaluate` | always |
| `grad_step` | default implementation |
| `train_step` | default implementation during initialization only |

### Examples
Let's review the example of `test_step` found in [basics](./basics):

```python
class LinearClassifier(elegy.Model):
    def test_step(self, x, y_true, states, initializing):
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                jax.random.PRNGKey(42), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(jax.random.PRNGKey(69), shape=[1])
        else:
            w, b = states.net_params

        # model
        logits = jnp.dot(x, w) + b

        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        # metrics
        logs = dict(accuracy=accuracy, loss=loss)

        # update states
        states = states.update(net_params=(w, b))

        return loss, logs, states

model = LinearClassifier(
    optimizer=optax.adam(1e-3)
)

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    batch_size=64,
)
```
In this case `test_step` defines both the "forward" pass of the model and calculates the losses and metrics in a single place. However, since we are not defining `pred_step` we lose the ability to call `predict`, which might not be desirable. The optimal way to fix this is to extract the calculation of the logits into `pred_step` and call it from `test_step`:

```python
class LinearClassifier(elegy.Model):
    def pred_step(self, x, states, initializing):
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                jax.random.PRNGKey(42), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(jax.random.PRNGKey(69), shape=[1])
        else:
            w, b = states.net_params

        # model
        logits = jnp.dot(x, w) + b

        return logits, states.update(net_params=(w, b))

    def test_step(self, x, y_true, states, initializing):
        # call pred_step, which also takes care of updating `net_params` in `states`
        logits, states = self.pred_step(x, states, initializing)

        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        # metrics
        logs = dict(accuracy=accuracy, loss=loss)

        return loss, logs, states

model = LinearClassifier(
    optimizer=optax.adam(1e-3),
)

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    batch_size=64,
)
```
This not only creates a separation of concerns, it also favors code reuse, and we can now use `predict`, `evaluate`, and `fit` as intended.
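For illustration, with both methods defined the usual entry points now reach our code through the default chain. The call signatures below follow the Keras-like high-level API used in these guides, and `X_test` / `y_test` are assumed to exist:

```python
y_pred = model.predict(X_test, batch_size=64)                 # -> pred_step
eval_logs = model.evaluate(X_test, y=y_test, batch_size=64)   # -> test_step
model.fit(x=X_train, y=y_train, epochs=100, batch_size=64)    # -> train_step -> grad_step -> test_step
```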

There are cases, however, where you might want to implement a forward pass inside `test_step` that is different from what you would define in `pred_step`. For example, you can create `VAE` or `GAN` Models that use multiple modules to calculate the loss inside `test_step` (e.g. encoder, decoder, and discriminator) but only use the decoder inside `pred_step` to generate samples.
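Below is a simplified, hypothetical sketch of that pattern (made-up names and shapes, a plain autoencoder rather than a real VAE/GAN): the loss in `test_step` uses an encoder and a decoder, while `pred_step` only uses the decoder to turn latent codes into samples. It assumes the parameters were already created via `test_step` before `pred_step` is called.

```python
import jax
import jax.numpy as jnp
import elegy


class TinyAutoencoder(elegy.Model):
    def test_step(self, x, states, initializing):
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        if initializing:
            w_enc = jax.random.uniform(jax.random.PRNGKey(42), shape=[x.shape[-1], 2])
            w_dec = jax.random.uniform(jax.random.PRNGKey(69), shape=[2, x.shape[-1]])
        else:
            w_enc, w_dec = states.net_params

        z = jnp.dot(x, w_enc)        # encoder
        x_rec = jnp.dot(z, w_dec)    # decoder
        loss = jnp.mean((x - x_rec) ** 2)
        logs = dict(loss=loss)

        return loss, logs, states.update(net_params=(w_enc, w_dec))

    def pred_step(self, x, states, initializing):
        # only the decoder is used here: `x` is interpreted as a batch of latent codes
        _, w_dec = states.net_params
        samples = jnp.dot(x, w_dec)
        return samples, states
```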

### Default Implementation
The default implementation of `test_step` does the following:
* Calls `pred_step` to get `y_pred`.
* Calls `api_loss.init` or `api_loss.apply` depending on the state of `initializing`. `api_loss` of type `Losses` computes the aggregated batch loss from the loss functions passed by the user through the `loss` argument in the `Model`'s constructor, and also computes a running mean of each loss individually which is passed for reporting to `logs`.
* Calls `api_metrics.init` or `api_metrics.apply` depending on the state of `initializing`. `api_metrics` of type `Metrics` calculates the metrics passed by the user through the `metrics` argument in the `Model`'s constructor and passes their values to `logs` for reporting.
4 changes: 2 additions & 2 deletions examples/elegy_mnist_conv_pmap.py
@@ -81,11 +81,11 @@ def jit_step(self):
    def grad_step(self, *args, **kwargs):
        loss, logs, states, grads = super().grad_step(*args, **kwargs)

        grads = jax.lax.psum(grads, axis_name="device")
        grads = jax.lax.pmean(grads, axis_name="device")

        return loss, logs, states, grads

    # here we override train_step instead of train_step
    # here we override call_train_step instead of train_step
    def call_train_step(
        self,
        x,
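As a side note on the `psum` → `pmean` change above, here is a small, self-contained illustration (toy values, independent of the example file) of the difference: with N devices, `psum` scales the combined gradients by N, while `pmean` keeps them at the same magnitude as single-device training.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()
per_device_grads = jnp.ones((n_devices, 3))  # one toy gradient vector per device

summed = jax.pmap(lambda g: jax.lax.psum(g, axis_name="device"), axis_name="device")(per_device_grads)
averaged = jax.pmap(lambda g: jax.lax.pmean(g, axis_name="device"), axis_name="device")(per_device_grads)

print(summed[0])    # [N., N., N.] -- grows with the number of devices
print(averaged[0])  # [1., 1., 1.] -- same scale regardless of device count
```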
