Skip to content

Commit

Permalink
DOCFIX: Advanced AshPy sections (#63)
Browse files Browse the repository at this point in the history
* Add section on custom metrics to the docs
* Add section on custom callbacks to the docs
  • Loading branch information
mr-ubik committed Apr 14, 2020
1 parent 7ca6881 commit dd7c629
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 28 deletions.
239 changes: 239 additions & 0 deletions docs/source/advanced_ashpy.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
##############
Advanced AshPy
##############

**************
Custom Metrics
**************

AshPy Trainers can accept metrics that they will use for both logging and automatic model
selection.

Implementing a custom Metric in AshPy can be done via two approach:

1. Your metric is already available as a |keras.Metric| and you want to use it as is.
2. You need to write the implementation of the Metric from scratch or you need to alter
the default behavior we provide for AshPy Metrics.


Wrapping Keras Metrics
======================


In case number (1) what you want to do is to search for one of the Metrics provided by AshPy
and use it as a wrapper around the one you wish to use.

.. note::
Passing an :mod:`operator` funciton to the AshPy Metric will enable model selection using the
metric value.

The example below shows how to implement the Precision metric for an |ClassifierTrainer|.

.. code-block:: python
import operator
from ashpy.metrics import ClassifierMetric
from ashpy.trainers import ClassifierTrainer
from tensorflow.keras.metrics import Precision
precision = ClassifierMetric(
metric=tf.keras.metrics.Precision(),
model_selection_operator=operator.gt,
logdir=Path().cwd() / "log",
)
trainer = ClassifierTrainer(
...
metrics = [precision]
...
)
You can apply this technique to any object derived and behaving as a |keras.Metric|
(i.e. the Metrics present in `TensorFlow Addons`_)


Creating your own Metric
========================


As an example of a custom Metric we present the analysis of the :class:`ashpy.metrics.classifier.ClassifierLoss`.

.. code-block:: python
class ClassifierLoss(Metric):
"""A handy way to measure the classification loss."""
def __init__(
self,
name: str = "loss",
model_selection_operator: Callable = None,
logdir: Union[Path, str] = Path().cwd() / "log",
) -> None:
"""
Initialize the Metric.
Args:
name (str): Name of the metric.
model_selection_operator (:py:obj:`typing.Callable`): The operation that will
be used when `model_selection` is triggered to compare the metrics,
used by the `update_state`.
Any :py:obj:`typing.Callable` behaving like an :py:mod:`operator` is accepted.
.. note::
Model selection is done ONLY if an operator is specified here.
logdir (str): Path to the log dir, defaults to a `log` folder in the current
directory.
"""
super().__init__(
name=name,
metric=tf.metrics.Mean(name=name, dtype=tf.float32),
model_selection_operator=model_selection_operator,
logdir=logdir,
)
def update_state(self, context: ClassifierContext) -> None:
"""
Update the internal state of the metric, using the information from the context object.
Args:
context (:py:class:`ashpy.contexts.ClassifierContext`): An AshPy Context
holding all the information the Metric needs.
"""
updater = lambda value: lambda: self._metric.update_state(value)
for features, labels in context.dataset:
loss = context.loss(
context,
features=features,
labels=labels,
training=context.log_eval_mode == LogEvalMode.TRAIN,
)
self._distribute_strategy.experimental_run_v2(updater(loss))
* Each custom Metric should always inherit from |ashpy.Metric|.
* We advise that each custom Metric respescts the base :meth:`ashpy.metrics.metric.Metric.__init__()`
* Inside the :func:`super()` call be sure to provide one of the :mod:`tf.keras.metrics` `primitive` metrics
(i.e. :class:`tf.keras.metrics.Mean`, :class:`tf.keras.metrics.Sum`).

.. warning::
The :code:`name` argument of the :meth:`ashpy.metrics.metric.Metric.__init__()` is a :obj:`str` identifier
which should be unique across all the metrics used by your :class:`Trainer <ashpy.trainers.Trainer>`.


Custom Computation inside Metric.update_state()
-----------------------------------------------


* This method is invoked during the training and receives a |Context|.
* In this example, since we are working under the |ClassifierTrainer| we are using an |ClassifierContext|.
For more information on the |Context| family of objects see :ref:`ashpy-internals`.
* Inside this update_state state we won't be doing any fancy computation, we just retrieve
the loss value from the |ClassifierContext| and then we call the :code:`updater` lambda
from the fetched distribution strategy.
* The active distribution strategy is automatically retrieved during the :func:`super()`,
this guarantees that every object derived from an |ashpy.Metric| will work flawlessly
even in a distributed environment.
* :attr:`ashpy.metrics.metric.Metric.metric` (here referenced as :code:`self._metric` is the
`primitive` |keras.Metric| whose :code:`upadate_state()` method we will be using to simplify
our operations.
* Custom computation will almost always be done via iteration over the data offered by the
|Context|.

For a much more complex (but probably exhaustive) example have a look at the source code
of :class:`ashpy.metrics.SlicedWassersteinDistance <ashpy.metrics.sliced_wasserstein_metric.SlicedWassersteinDistance>`.

****************
Custom Callbacks
****************

Our |ashpy.Callback| is built on the same base structure as a |keras.Callback| exposing methods
acting as hooks for the same events.

* on_train_start
* on_epoch_start
* on_batch_start
* on_batch_end
* on_epoch_end
* on_train_end

Inside the :mod:`ashpy.callbacks` module we offer two `primitive` Callbacks classes to inherit from.

1. :class:`ashpy.callbacks.Callback <ashpy.callbacks.callback.Callback>`: is the most basic
form of callback and the basic block for all the other.
2. |ashpy.CounterCallback|: is derived
from :class:`ashpy.callbacks.Callback <ashpy.callbacks.callback.Callback>` and contains
built-in logic for triggering an event given a desired frequency.

Let's take a look at the following example which is the callback used to log GANs output to
TensorBoard - :class:`ashpy.callbacks.gan.LogImageGANCallback`

.. code-block:: python
class LogImageGANCallback(CounterCallback):
def __init__(
self,
event: Event = Event.ON_EPOCH_END,
name: str = "log_image_gan_callback",
event_freq: int = 1,
) -> None:
"""
Initialize the LogImageCallbackGAN.
Args:
event (:py:class:`ashpy.callbacks.events.Event`): event to consider.
event_freq (int): frequency of logging.
name (str): name of the callback.
"""
super(LogImageGANCallback, self).__init__(
event=event, fn=self._log_fn, name=name, event_freq=event_freq
)
def _log_fn(self, context: GANContext) -> None:
"""
Log output of the generator to Tensorboard.
Args:
context (:py:class:`ashpy.contexts.gan.GANContext`): current context.
"""
if context.log_eval_mode == LogEvalMode.TEST:
out = context.generator_model(context.generator_inputs, training=False)
elif context.log_eval_mode == LogEvalMode.TRAIN:
out = context.fake_samples
else:
raise ValueError("Invalid LogEvalMode")
log("generator", out, context.global_step)
Let's start with the :code:`__init__()` function, as for the Custom |ashpy.Metric| when
inheriting from either |ashpy.Callback| or |ashpy.CounterCallback| respect the common part of the signature:

* :code:`event`: In AshPy we use an Enum - :class:`ashpy.callbacks.Event <ashpy.callbacks.events.Event>` - to
choose the event type you want the |ashpy.Callback| to be triggered on.
* :code:`name`: Unique :obj:`str` identifier for the |ashpy.Callback|
* :code:`event_freq`: Simple :obj:`int` specifying the frequency.
* :code:`fn`: A :func:`callable()` this is the function that gets triggered. Inside AshPy we
converged on using a private method called ``_log_fn()`` in each of our derived Callbacks.
Whatever approach you choose, the function fed to :code:`fn` should have a |Context| as input.
For more information on the |Context| family of objects see :ref:`ashpy-internals`.

.. warning::
The :code:`name` argument of the :meth:`ashpy.callbacks.callback.Callback.__init__()` is a :obj:`str` identifier
which should be unique across all the callbacks used by your :class:`Trainer <ashpy.trainers.Trainer>`.


.. |ashpy.Callback| replace:: :class:`Callback <ashpy.callbacks.callback.Callback>`
.. |ashpy.CounterCallback| replace:: :class:`CounterCallback <ashpy.callbacks.counter_callback.CounterCallback>`
.. |ashpy.Metric| replace:: :class:`ashpy.metrics.Metric <ashpy.metrics.metric.Metric>`
.. |ClassifierContext| replace:: :class:`ClassifierContext <ashpy.contexts.classifier.ClassifierContext>`
.. |ClassifierTrainer| replace:: :class:`ClassifierTrainer <ashpy.trainers.classifier.ClassifierTrainer>`
.. |Context| replace:: :class:`Context <ashpy.contexts.context.Context>`
.. |keras.Callback| replace:: :class:`tf.keras.callbacks.Callback`
.. |keras.Metric| replace:: :class:`tf.keras.metrics.Metric`
.. |Metric.update_state()| replace:: :meth:`ashpy.metrics.metric.Metric.update_state()`


.. _TensorFlow Addons: https://www.tensorflow.org/addons/overview
3 changes: 2 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ Welcome to ASHPY's documentation!
home
write_the_docs
getting_started
advanced_ashpy
internals
api
internal
dependencies_graph
about

Expand Down
63 changes: 38 additions & 25 deletions docs/source/internal.rst → docs/source/internals.rst
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
.. _ashpy-internals:

###############
AshPy Internals
###############

The two main concepts of AshPy internals are :py:class:`Context <ashpy.contexts.context.Context>` and :py:class:`Executor <ashpy.losses.executor.Executor>`.
The two main concepts of AshPy internals are |Context| and |Executor|.

*******
Context
-------
*******

A :py:class:`Context <ashpy.contexts.context.Context>` is an object that contains all the needed information. Here needed depends on the application.
In AshPy the :code:`Context` concept links a generic training loop with the loss function calculation and the model evaluation.
A :code:`Context` is a useful class in which all the models, metrics, dataset and mode of your network are set.
A |Context| is an object that contains all the needed information. Here needed depends on the application.
In AshPy the |Context| concept links a generic training loop with the loss function calculation and the model evaluation.
A |Context| is a useful class in which all the models, metrics, dataset and mode of your network are set.
Passing the context around means that you can any time access to all what you need in order to perform any type of computation.

In AshPy we have (until now) three types of contexts:
Expand All @@ -18,9 +22,9 @@ In AshPy we have (until now) three types of contexts:
- `GANEncoder Context`_

Classifier Context
++++++++++++++++++
==================

The :py:class:`ClassifierContext <ashpy.contexts.classifier.ClassifierContext>` is very simple, it contains only:
The :class:`ClassifierContext <ashpy.contexts.classifier.ClassifierContext>` is rather straightforward containing only:

- classifier_model
- loss
Expand All @@ -30,13 +34,13 @@ The :py:class:`ClassifierContext <ashpy.contexts.classifier.ClassifierContext>`
- global_step
- ckpt

In this way the loss function (:py:class:`Executor <ashpy.losses.executor.Executor>`) can use the context in order to get the model
In this way the loss function (|Executor|) can use the context in order to get the model
and the needed information in order to correctly feed the model.

GAN Context
+++++++++++
===========

The basic :py:class:`GANContext <ashpy.contexts.gan.GANContext>` is composed by:
The basic :class:`GANContext <ashpy.contexts.gan.GANContext>` is composed by:

- dataset
- generator_model
Expand All @@ -51,18 +55,19 @@ The basic :py:class:`GANContext <ashpy.contexts.gan.GANContext>` is composed by:
As we can see we have all information needed to define our training and evaluation loop.

GANEncoder Context
++++++++++++++++++
==================

The :py:class:`GANEncoderContext <ashpy.contexts.gan.GANEncoderContext>` extends the GANContext, contains all the
The :class:`GANEncoderContext <ashpy.contexts.gan.GANEncoderContext>` extends the GANContext, contains all the
information of the base class plus:

- Encoder Model
- Encoder Loss

********
Executor
--------
********

The :py:class:`Executor <ashpy.losses.executor.Executor>` is the main concept behind the loss function implementation in AshPy.
The |Executor| is the main concept behind the loss function implementation in AshPy.
An Executor is a class that helps in order to better generalize a training loop.
With an Executor you can construct, for example, a custom loss function and put every computation you need inside it.
You should define a :code:`call` function inside your class and decorate it with :code:`@Executor.reduce` header, if needed.
Expand All @@ -75,11 +80,12 @@ An executor takes also care of the distribution strategy by reducing appropriate
`Tensorflow Guide`__).

An Executor Example
*******************
===================

In this example we will see the implementation of the Generator Binary CrossEntropy loss.

The :code:`__init__` method is straightforward, we need only to instantiate :py:class:`tf.losses.BinaryCrossentropy` object and then we pass it to our parent:
The :code:`__init__` method is straightforward, we need only to instantiate :class:`tf.losses.BinaryCrossentropy`
object and then we pass it to our parent:

.. code-block:: python
Expand Down Expand Up @@ -110,14 +116,16 @@ Then we need to implement the call function respecting the signature:
# mean everything
return tf.reduce_mean(value)
The function :py:func:`get_discriminator_inputs` returns the correct discriminator inputs using the context.
The discriminator input can be the output of the generator (unconditioned case) or the output of the generator together
with the condition (conditioned case).
The function :func:`get_discriminator_inputs` returns the correct discriminator inputs
using the context.
The discriminator input can be the output of the generator (unconditioned case) or the
output of the generator together with the condition (conditioned case).

The the :py:func:`call` uses the discriminator model inside the context in order to obtain the output of the
discriminator when evaluated in the `fake_inputs`.
The the :func:`call` uses the discriminator model inside the context in order to obtain
the output of the discriminator when evaluated in the `fake_inputs`.

After that the :py:func:`self._fn` (BinaryCrossentropy) is used to get the value of the loss. This loss is then averaged.
After that the :func:`self._fn` (BinaryCrossentropy) is used to get the value of the loss.
This loss is then averaged.

In this way the executor computes correctly the loss function.

Expand Down Expand Up @@ -147,8 +155,13 @@ If we want to use our executor in a distribution strategy the only modifications
The important things are:

- :code:`Executor.reduce_loss` decoration: uses the Executor decorator in order to correctly reduce the loss
- :code:`tf.reduce_mean(value, axis=1)` (last line), we perform only the mean over the axis 1. The output of the `call` function
should be a :py:class:`tf.Tensor` with shape (N, 1) or (N,). This is because the decorator performs the mean over the axis 0.
- :code:`tf.reduce_mean(value, axis=1)` (last line), we perform only the mean over the axis 1. The output of the ``call`` function
should be a :class:`tf.Tensor` with shape (N, 1) or (N,). This is because the decorator performs the mean over the axis 0.


.. |Context| replace:: :class:`Context <ashpy.contexts.context.Context>`
.. |Executor| replace:: :class:`Executor <ashpy.losses.executor.Executor>`

.. _tf_guide: https://www.tensorflow.org/beta/guide/distribute_strategy#using_tfdistributestrategy_with_custom_training_loops
__ tf_guide_
__ tf_guide_

4 changes: 2 additions & 2 deletions src/ashpy/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@

class Metric(ABC):
"""
Metric is the abstract class that every ash Metric must implement.
Metric is the abstract class that every AshPy Metric must implement.
AshPy Metrics wrap and extend Keras Metrics.
AshPy Metric wrap and extend :class:`tf.keras.metrics.Metric`.
"""

def __init__(
Expand Down

0 comments on commit dd7c629

Please sign in to comment.