Skip to content

Commit

Permalink
Add EMA Docs, fix common collection documentation (NVIDIA#5757)
Browse files Browse the repository at this point in the history
* Add EMA docs, fix docs due to incorrect import, fix doc format for common collection

Signed-off-by: SeanNaren <snarenthiran@nvidia.com>

* Address feedback

Signed-off-by: SeanNaren <snarenthiran@nvidia.com>

Signed-off-by: SeanNaren <snarenthiran@nvidia.com>
  • Loading branch information
SeanNaren authored and titu1994 committed Mar 24, 2023
1 parent fdf2e1f commit 950ea09
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 34 deletions.
51 changes: 51 additions & 0 deletions docs/source/common/callbacks.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
*********
Callbacks
*********

Exponential Moving Average (EMA)
================================

During training, EMA maintains a moving average of the trained parameters.
EMA parameters can produce significantly better results and faster convergence for a variety of different domains and models.

EMA is a simple calculation. EMA Weights are pre-initialized with the model weights at the start of training.

Every training update, the EMA weights are updated based on the new model weights.

.. math::
ema_w = ema_w * decay + model_w * (1-decay)
Enabling EMA is straightforward. We can pass the additional argument to the experiment manager at runtime.

.. code-block:: bash
python examples/asr/asr_ctc/speech_to_text_ctc.py \
model.train_ds.manifest_filepath=/path/to/my/train/manifest.json \
model.validation_ds.manifest_filepath=/path/to/my/validation/manifest.json \
trainer.devices=2 \
trainer.accelerator='gpu' \
trainer.max_epochs=50 \
exp_manager.ema.enable=True # pass this additional argument to enable EMA
To change the decay rate, pass the additional argument.

.. code-block:: bash
python examples/asr/asr_ctc/speech_to_text_ctc.py \
...
exp_manager.ema.enable=True \
exp_manager.ema.decay=0.999
We also offer other helpful arguments.

.. list-table::
:header-rows: 1

* - Argument
- Description
* - `exp_manager.ema.validate_original_weights=True`
- Validate the original weights instead of EMA weights.
* - `exp_manager.ema.every_n_steps=2`
- Apply EMA every N steps instead of every step.
* - `exp_manager.ema.cpu_offload=True`
- Offload EMA weights to CPU. May introduce significant slow-downs.
39 changes: 6 additions & 33 deletions docs/source/common/intro.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,10 @@ Common Collection

The common collection contains things that could be used across all collections.

Tokenizers
----------
.. automodule:: nemo.collections.common.tokenizers.AutoTokenizer
:special-members: __init__
.. automodule:: nemo.collections.common.tokenizers.SentencePieceTokenizer
:special-members: __init__
.. automodule:: nemo.collections.common.tokenizers.TokenizerSpec
:special-members: __init__
.. toctree::
:maxdepth: 8


Losses
------
.. automodule:: nemo.collections.common.losses.AggregatorLoss
:special-members: __init__

.. automodule:: nemo.collections.common.losses.CrossEntropyLoss
:special-members: __init__

.. automodule:: nemo.collections.common.losses.MSELoss
:special-members: __init__

.. automodule:: nemo.collections.common.losses.SmoothedCrossEntropyLoss
:special-members: __init__
.. automodule:: nemo.collections.common.losses.SpanningLoss
:special-members: __init__


Metrics
-------

.. autoclass:: nemo.collections.common.metrics.Perplexity
:show-inheritance:
:members:
:undoc-members:
callbacks
losses
metrics
tokenizers
16 changes: 16 additions & 0 deletions docs/source/common/losses.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
Losses
------
.. autoclass:: nemo.collections.common.losses.AggregatorLoss
:special-members: __init__

.. autoclass:: nemo.collections.common.losses.CrossEntropyLoss
:special-members: __init__

.. autoclass:: nemo.collections.common.losses.MSELoss
:special-members: __init__

.. autoclass:: nemo.collections.common.losses.SmoothedCrossEntropyLoss
:special-members: __init__

.. autoclass:: nemo.collections.common.losses.SpanningLoss
:special-members: __init__
7 changes: 7 additions & 0 deletions docs/source/common/metrics.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Metrics
-------

.. autoclass:: nemo.collections.common.metrics.Perplexity
:show-inheritance:
:members:
:undoc-members:
8 changes: 8 additions & 0 deletions docs/source/common/tokenizers.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Tokenizers
----------
.. autoclass:: nemo.collections.common.tokenizers.AutoTokenizer
:special-members: __init__
.. autoclass:: nemo.collections.common.tokenizers.SentencePieceTokenizer
:special-members: __init__
.. autoclass:: nemo.collections.common.tokenizers.TokenizerSpec
:special-members: __init__
2 changes: 1 addition & 1 deletion nemo/collections/common/callbacks/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

import pytorch_lightning as pl
import torch
from lightning_utilities.core.rank_zero import rank_zero_info
from pytorch_lightning import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info


class EMA(Callback):
Expand Down

0 comments on commit 950ea09

Please sign in to comment.