More Docs: Expand documentation for the low-level API guides. (#168)
* pred_step + test_step
* fix colab links
* use pmean in pmap example
Showing 8 changed files with 236 additions and 51 deletions.
# Default Implementation

### Methods
The default implementation favors composition by implementing one method in terms of another; specifically, it follows this call graph:
```
summary             predict          evaluate         fit               init
⬇️                   ⬇️               ⬇️               ⬇️                ⬇️
call_summary_step   call_pred_step   call_test_step   call_train_step   call_init_step
⬇️                   ⬇️               ⬇️               ⬇️                ⬇️
summary_step ➡️ pred_step ⬅️ test_step ⬅️ grad_step ⬅️ train_step ⬅️ init_step
```
This structure allows you, for example, to override `test_step` and still be able to use `fit`, since `train_step` (called by `fit`) will invoke your `test_step` via `grad_step`. It also means that if you implement `test_step` but not `pred_step`, there is a high chance both `predict` and `summary` will not work.

##### call_* methods
The `call_<method>` family of methods are _entrypoints_ that usually just redirect their inputs to `<method>`; you override these if you need to perform some computation only when the method in question is the entry point. For example, if you want to change the behavior of `evaluate` without affecting the behavior of `fit` while preserving most of the default implementation, you can override `call_test_step` to make the corresponding adjustments and then call `test_step`. Since `train_step` does not depend on `call_test_step`, the change will manifest during `evaluate` but not during `fit`.
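
For instance, here is a minimal sketch of that pattern (the exact argument list of `call_test_step` is assumed for illustration and may differ from Elegy's actual signature):

```python
class EvalScaledModel(elegy.Model):
    def call_test_step(self, x, y_true, states, initializing, training):
        # adjustment that only applies when `evaluate` is the entrypoint
        x = x / 255
        return self.test_step(x, y_true, states, initializing, training)
```

Because the `train_step` path used by `fit` never goes through `call_test_step`, training keeps the original behavior.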
# pred_step
The `pred_step` method computes the predictions of the main model; by overriding this method you can directly influence what happens during `predict`.

### Inputs
Any of the following input arguments are available for `pred_step`:

| name           | type     | description                              |
| :------------- | :------- | :--------------------------------------- |
| `x`            | `Any`    | Input data                                |
| `states`       | `States` | Current state of the model                |
| `initializing` | `bool`   | Whether the model is initializing or not  |
| `training`     | `bool`   | Whether the model is training or not      |

You must request the arguments you want by **name**.
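
For instance (a minimal sketch relying on the name-based injection just described), a `pred_step` that needs neither `initializing` nor `training` can simply omit them:

```python
class ScaleModel(elegy.Model):
    def pred_step(self, x, states):
        # only `x` and `states` are requested by name
        y_pred = x * 2.0  # placeholder computation
        return y_pred, states
```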

### Outputs
`pred_step` must output a tuple with the following values:

| name     | type     | description                  |
| :------- | :------- | :--------------------------- |
| `y_pred` | `Any`    | The predictions of the model |
| `states` | `States` | The new state of the model   |

### Callers

| method         | when                   |
| :------------- | :--------------------- |
| `predict`      | always                 |
| `test_step`    | default implementation |
| `summary_step` | default implementation |

### Examples
If for some reason you wish to create a pure JAX / Module-less model, you can define your own `Model` that implements `pred_step` like this:

```python
class LinearClassifier(elegy.Model):
    def pred_step(self, x, states, initializing):
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                jax.random.PRNGKey(42), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(jax.random.PRNGKey(69), shape=[1])
        else:
            w, b = states.net_params

        # model
        y_pred = jnp.dot(x, w) + b

        return y_pred, states.update(net_params=(w, b))

model = LinearClassifier(
    optimizer=optax.adam(1e-3),
    loss=elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
)

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    batch_size=64,
)
```
Here we implement the same `LinearClassifier` from the [basics](./basics) section, but we extracted the definition of the model into `pred_step` and let the default implementation of `test_step` take care of the `loss` and `metrics` that we provide to the `LinearClassifier`'s constructor.

### Default Implementation
The default implementation of `pred_step` does the following:

* Calls `api_module.init` or `api_module.apply` depending on the state of `initializing`. `api_module`, of type `GeneralizedModule`, is a wrapper over the `module` object passed by the user to the `Model`'s constructor.
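
In pseudocode terms, the branching looks roughly like the following sketch (`api_module` is the hypothetical wrapper named above; Elegy's real internals differ in their details):

```python
def pred_step_sketch(api_module, x, states, initializing):
    # on the first call, `init` builds the parameters from the data;
    # afterwards, `apply` reuses the existing ones
    if initializing:
        y_pred, net_states = api_module.init(x)
    else:
        y_pred, net_states = api_module.apply(x)
    return y_pred, states.update(net_states=net_states)
```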
# test_step
The `test_step` method computes the main `loss` of the model along with some `logs` for reporting; by overriding this method you can directly influence what happens during `evaluate`.

### Inputs
Any of the following input arguments are available for `test_step`:

| name            | type                | description                                 |
| :-------------- | :------------------ | :------------------------------------------ |
| `x`             | `Any`               | Input data                                   |
| `y_true`        | `Any`               | The target labels                            |
| `sample_weight` | `Optional[ndarray]` | The weight of each sample in the total loss  |
| `class_weight`  | `Optional[ndarray]` | The weight of each class in the total loss   |
| `states`        | `States`            | Current state of the model                   |
| `initializing`  | `bool`              | Whether the model is initializing or not     |
| `training`      | `bool`              | Whether the model is training or not         |

You must request the arguments you want by **name**.

### Outputs
`test_step` must output a tuple with the following values:

| name     | type                 | description                                 |
| :------- | :------------------- | :------------------------------------------ |
| `loss`   | `ndarray`            | The loss of the model over the data          |
| `logs`   | `Dict[str, ndarray]` | A dictionary with a set of values to report  |
| `states` | `States`             | The new state of the model                   |

### Callers

| method       | when                                               |
| :----------- | :------------------------------------------------- |
| `evaluate`   | always                                              |
| `grad_step`  | default implementation                              |
| `train_step` | default implementation during initialization only   |

### Examples
Let's review the example of `test_step` found in [basics](./basics):

```python
class LinearClassifier(elegy.Model):
    def test_step(self, x, y_true, states, initializing):
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                jax.random.PRNGKey(42), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(jax.random.PRNGKey(69), shape=[1])
        else:
            w, b = states.net_params

        # model
        logits = jnp.dot(x, w) + b

        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        # metrics
        logs = dict(accuracy=accuracy, loss=loss)

        # update states
        states = states.update(net_params=(w, b))

        return loss, logs, states

model = LinearClassifier(
    optimizer=optax.adam(1e-3)
)

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    batch_size=64,
)
```
In this case `test_step` defines both the "forward" pass of the model and the calculation of the losses and metrics in a single place. However, since we are not defining `pred_step`, we lose the ability to call `predict`, which might not be desirable. The optimal way to fix this is to extract the calculation of the logits into `pred_step` and call it from `test_step`:

```python
class LinearClassifier(elegy.Model):
    def pred_step(self, x, states, initializing):
        x = jnp.reshape(x, (x.shape[0], -1)) / 255

        # initialize or use existing parameters
        if initializing:
            w = jax.random.uniform(
                jax.random.PRNGKey(42), shape=[np.prod(x.shape[1:]), 10]
            )
            b = jax.random.uniform(jax.random.PRNGKey(69), shape=[1])
        else:
            w, b = states.net_params

        # model
        logits = jnp.dot(x, w) + b

        return logits, states.update(net_params=(w, b))

    def test_step(self, x, y_true, states, initializing):
        # call pred_step
        logits, states = self.pred_step(x, states, initializing)

        # categorical crossentropy loss
        labels = jax.nn.one_hot(y_true, 10)
        loss = jnp.mean(-jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1))
        accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y_true)

        # metrics
        logs = dict(accuracy=accuracy, loss=loss)

        # states was already updated inside pred_step
        return loss, logs, states

model = LinearClassifier(
    optimizer=optax.adam(1e-3),
)

model.fit(
    x=X_train,
    y=y_train,
    epochs=100,
    batch_size=64,
)
```
This not only creates a separation of concerns, it also favors code reuse, and we can now use `predict`, `evaluate`, and `fit` as intended.
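
For instance (assuming `X_test` and `y_test` exist; argument names follow the Keras-style API that Elegy mirrors):

```python
y_pred = model.predict(X_test)             # goes through pred_step
logs = model.evaluate(x=X_test, y=y_test)  # goes through test_step
```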

There are cases, however, where you might want to implement a forward pass inside `test_step` that is different from what you would define in `pred_step`. For example, you can create `VAE` or `GAN` models that use multiple modules to calculate the loss inside `test_step` (e.g. encoder, decoder, and discriminator) but only use the decoder inside `pred_step` to generate samples.
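
As a rough sketch of this idea, here is a simplified deterministic autoencoder rather than a full VAE (module structure, shapes, and the latent size are made up for illustration; it follows the same `states.net_params` conventions as the examples above):

```python
class AutoEncoder(elegy.Model):
    def pred_step(self, x, states, initializing):
        # here `x` is interpreted as a batch of latent codes:
        # only the decoder participates in `predict`
        w_dec, b_dec = states.net_params["decoder"]
        samples = jax.nn.sigmoid(jnp.dot(x, w_dec) + b_dec)
        return samples, states

    def test_step(self, x, states, initializing):
        x = jnp.reshape(x, (x.shape[0], -1)) / 255
        features = x.shape[-1]
        latents = 32  # made-up latent size

        if initializing:
            k1, k2 = jax.random.split(jax.random.PRNGKey(42))
            params = {
                "encoder": (
                    jax.random.uniform(k1, shape=[features, latents]),
                    jnp.zeros([latents]),
                ),
                "decoder": (
                    jax.random.uniform(k2, shape=[latents, features]),
                    jnp.zeros([features]),
                ),
            }
        else:
            params = states.net_params

        # encoder AND decoder are both used to compute the loss
        w_enc, b_enc = params["encoder"]
        w_dec, b_dec = params["decoder"]
        z = jax.nn.relu(jnp.dot(x, w_enc) + b_enc)
        x_rec = jax.nn.sigmoid(jnp.dot(z, w_dec) + b_dec)

        loss = jnp.mean((x - x_rec) ** 2)  # reconstruction loss
        logs = dict(loss=loss)

        return loss, logs, states.update(net_params=params)
```

During `fit` and `evaluate` only `test_step` runs, while `predict` decodes whatever latent codes you feed it.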
### Default Implementation
The default implementation of `test_step` does the following:

* Calls `pred_step` to get `y_pred`.
* Calls `api_loss.init` or `api_loss.apply` depending on the state of `initializing`. `api_loss`, of type `Losses`, computes the aggregated batch loss from the loss functions passed by the user through the `loss` argument in the `Model`'s constructor, and also computes a running mean of each loss individually, which is passed to `logs` for reporting.
* Calls `api_metrics.init` or `api_metrics.apply` depending on the state of `initializing`. `api_metrics`, of type `Metrics`, calculates the metrics passed by the user through the `metrics` argument in the `Model`'s constructor and passes their values to `logs` for reporting.
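
In pseudocode terms, the flow looks roughly like the following sketch (`api_loss` and `api_metrics` are the hypothetical wrappers named above; the signatures are illustrative, not Elegy's actual source):

```python
def test_step_sketch(self, x, y_true, states, initializing):
    # 1. forward pass via pred_step
    y_pred, states = self.pred_step(x, states, initializing)

    # 2. aggregated batch loss plus per-loss running means
    loss_fn = self.api_loss.init if initializing else self.api_loss.apply
    loss, loss_logs, states = loss_fn(y_true, y_pred, states)

    # 3. user-provided metrics
    metrics_fn = self.api_metrics.init if initializing else self.api_metrics.apply
    metric_logs, states = metrics_fn(y_true, y_pred, states)

    return loss, {**loss_logs, **metric_logs}, states
```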