Skip to content

Commit

Permalink
Adding cache-aware streaming Conformer with look-ahead support (NVIDI…
Browse files Browse the repository at this point in the history
…A#3888)

Signed-off-by: Anas Abou Allaban <aabouallaban@pm.me>
  • Loading branch information
VahidooX authored and piraka9011 committed Aug 25, 2022
1 parent c3f2bcb commit c66415b
Show file tree
Hide file tree
Showing 37 changed files with 2,036 additions and 156 deletions.
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -1504,7 +1504,7 @@ pipeline {
model.data_dir=/home/TestData/nlp/new_multiatis \
model.validation_ds.prefix=dev \
model.test_ds.prefix=dev \
trainer.gpus=[0] \
trainer.devices=[0] \
+trainer.fast_dev_run=true \
exp_manager.exp_dir=checkpoints2'
sh 'rm -rf checkpoints2'
Expand Down
50 changes: 49 additions & 1 deletion docs/source/asr/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,54 @@ You may find the example config files of Conformer-Transducer model with charact
``<NeMo_git_root>/examples/asr/conf/conformer/conformer_transducer_char.yaml`` and
with sub-word encoding at ``<NeMo_git_root>/examples/asr/conf/conformer/conformer_transducer_bpe.yaml``.

Cache-aware Streaming Conformer
-------------------------------

Buffered streaming uses overlapping chunks to make an offline ASR model to be used for streaming with reasonable accuracy. However, it uses significant amount of duplication in computations due to the overlapping chunks.
Also there is a accuracy gep between the offline model and the streaming one as there is inconsistency between how we train the model and how we perform inference for streaming.
The Cache-aware Streaming Conformer models would tackle and address these disadvantages. They are variants of Conformer which are trained with limited right context and it would make it possible to match the training and inference.
The cache-aware approach is supported for both the Conformer-CTC and Conformer-Transducer and enables the model to be used very efficiently for streaming.

Three categories of layers in Conformer have access to right tokens: 1-depthwise convolutions 2-self-attention, and 3-convolutions in downsampling layers.
Streaming Conformer models uses causal convolutions or convolutions with lower right context and also self-attention with limited right context to limit the effective right context for the input.
The model trained with such limitations can be used in streaming mode and give the exact same output and accuracy as when the whole audio is given to the model in offline mode.
These model can use caching mechanism to store and reuse the activations during streaming inference to avoid any duplications in the computations as much as possible.

We support the following three right context modeling:
* fully causal model with zero look-ahead: tokens would not see any future tokens. convolution layers are all causal and right tokens are masked for self-attention.
It gives zero latency but with limited accuracy.
To train such a model, you need to set `encoder.att_context_size=[left_context, 0]` and `encoder.conv_context_size=causal` in the config.

* regular look-ahead: convolutions would be able to see few future frames, and self-attention would also see the same number of future tokens.
In this approach the activations for the look-ahead part is not cached and recalculated in the next chunks. The right context in each layer should be a small number as multiple layers would increase the effective context size and then increase the look-ahead size and latency.
For example for a model of 17 layers with 4x downsampling and 10ms window shift, then even 2 right context in each layer means 17*2*10*4=1360ms look-ahead. Each step after the downsampling corresponds to 4*10=40ms.

* chunk-aware look-ahead: input is split into equal chunks. Convolutions are fully causal while self-attention layers would be able to see all the tokens in their corresponding chunk.
For example, in a model which chunk size of 20 tokens, tokens at the first position of each chunk would see all the next 19 tokens while the last token would see zero future tokens.
This approach is more efficient than regular look-ahead in terms of computations as the activations for most of the look-ahead part would be cached and there is close to zero duplications in the calculations.
In terms of accuracy, this approach gives similar or even better results in term of accuracy than regular look-ahead as each token in each layer have access to more tokens on average. That is why we recommend to use this approach for streaming.


** Note: Latencies are based on the assumption that the forward time of the network is zero.

Approaches with non-zero look-ahead can give significantly better accuracy by sacrificing latency. The latency can get controlled by the left context size.


In all modes, left context can be controlled by the number of tokens to be visible in the self-attention and the kernel size of the convolutions.
For example, if left context of self-attention in each layer is set to 20 tokens and there are 10 layers of Conformer, then effective left context is 20*10=200 tokens.
Left context of self-attention for regular look-ahead can be set as any number while it should be set as a multiplication of the right context in chunk-aware look-ahead.
For convolutions, if we use a left context of 30 in such model, then there would be 30*10=300 effective left context.
Left context of convolutions is dependent to the their kernel size while it can be any number for self-attention layers. Higher left context for self-attention means larger cache and more computations for the self-attention.
Self-attention left context of around 6 secs would give close result to have unlimited left context. For a model with 4x downsampling and shift window of 10ms in the preprocessor, each token corresponds to 4*10=40ms.

If striding approach is used for downsampling, all the convolutions in downsampling would be fully causal and don't see future tokens.
It is recommended to use stacking for streaming model which is significantly faster and uses less memory.

You may find the example config files of cache-aware streaming Conformer models at
``<NeMo_git_root>/examples/asr/conf/conformer/streaming/conformer_transducer_bpe_streaming.yaml`` for Transducer variant and
at ``<NeMo_git_root>/examples/asr/conf/conformer/streaming/conformer_ctc_bpe.yaml`` for CTC variant.


.. _LSTM-Transducer_model:

LSTM-Transducer
Expand All @@ -135,7 +183,7 @@ LSTM-Transducer
LSTM-Transducer is a model which uses RNNs (eg. LSTM) in the encoder. The architecture of this model is followed from suggestions in :cite:`asr-models-he2019streaming`.
It uses RNNT/Transducer loss/decoder. The encoder consists of RNN layers (LSTM as default) with lower projection size to increase the efficiency.
Layer norm is added between the layers to stabilize the training.
It can be trained/used in unidirectional or bidirectional mode. The unidirectional mode is fully causal and can be used easily for simple and efficient frame-wise streaming. However the accuracy of this model is generally lower than other models like Conformer and Citrinet.
It can be trained/used in unidirectional or bidirectional mode. The unidirectional mode is fully causal and can be used easily for simple and efficient streaming. However the accuracy of this model is generally lower than other models like Conformer and Citrinet.

This model supports both the sub-word level and character level encodings. You may find the example config file of RNNT model with wordpiece encoding at ``<NeMo_git_root>/examples/asr/conf/lstm/lstm_transducer_bpe.yaml``.
You can find more details on the config files for the RNNT models at ``LSTM-Transducer <./configs.html#lstm-transducer>``.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
help="Model downsampling factor, 8 for Citrinet models and 4 for Conformer models",
)
parser.add_argument(
'--max_steps_per_timestep', type=int, default=5, help='Maximum number of tokens decoded per acoustic timestepB'
'--max_steps_per_timestep', type=int, default=5, help='Maximum number of tokens decoded per acoustic timestep'
)
parser.add_argument('--stateful_decoding', action='store_true', help='Whether to perform stateful decoding')
parser.add_argument('--device', default=None, type=str, required=False)
Expand Down Expand Up @@ -175,10 +175,10 @@ def main(args):
torch.set_grad_enabled(False)
if args.asr_model.endswith('.nemo'):
logging.info(f"Using local ASR model from {args.asr_model}")
asr_model = nemo_asr.models.EncDecCTCModelBPE.restore_from(restore_path=args.asr_model)
asr_model = nemo_asr.models.ASRModel.restore_from(restore_path=args.asr_model)
else:
logging.info(f"Using NGC cloud ASR model {args.asr_model}")
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name=args.asr_model)
asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=args.asr_model)

cfg = copy.deepcopy(asr_model._cfg)
OmegaConf.set_struct(cfg.preprocessor, False)
Expand Down

0 comments on commit c66415b

Please sign in to comment.