Streamable Conformer-Transducer ASR model for LibriSpeech (#2140)
* Introduce DCT+DCConv logic

* DDP fix?

* Batch of changes and things brought back

* Streaming fixes (successfully trains)

* WIP streaming code

* WIP functional streaming code

* Fix left context

* Fix formatting

* Cleanups and docs in streaming utils

* Better comment hparams, change seed back to orig, improve naming

* uncomment averaging stuff; it was some ipython issue

* Remove pin_memory as it was not beneficial

* More cleanups, comments on context stuff

* More comments and TODOs

* encode_streaming docstring

* Dirty TransducerBeamSearcher change for streaming GS

* Fix precommit

* Fix encoders that do not support chunk_size

* Pre-commit again

* Make chunk_size type consistent

* Fix formatting of doctest in split_wav_lens

* Remove outdated TODO

* Add hasattr streaming to retain model backcompat

* Cleanup doc and naming for transducer_greedy_decode

* Cite paper for chunked attention

* Remove lost comment

* Update comment in self-attention

* Don't apply masked fill fix in the non-bool mask case

* Added TODO README update

* Revert change to custom_tgt_module; patching model instead

* Remove added entry in README

* Fix streaming conformer conv mismatch

* More conformer conv adjustments

* Adjust context size

* Remove outdated comment

* Fixed causal conformer decoder

* Fix linting

* Gate `custom_tgt_module` creation behind the presence of decoder layers

* Re-enable checkpoint averaging

* Change averaged ckpt count to 10

* Add new model results to README

* WIP refactor: Introduce DCTConfig dataclass

* Improved notice in README

* Formatting and linting fixes

* Attempt at fixing circular import?

* utils can't depend on core it seems; move dct

* Whoops, missed file

* Add DCT test, fix issues

* Remove now obsolete yaml variables for streaming

* Formatting

* Add dummy dct_config parameter to keep unsupported encoders working

* Linting fix

* Fix typo

* Add note on runtime autocast accuracy

* Fix very bad typo from refactor in YAML

* Fix hasattr streaming check

* Remove legacy comment

* Fix left context size calculation in new mask code

* Fix causal models in TransformerASR

* Remove comment on high-level inference code

* YAML formatting + commenting dynchunktrain stuff

* Remove outdated comment about DCConv left contexts

* Remove commented out debug prints from TransformerASR

* Move DCT into utils again

* Rename all(?) mentions of DCT to explicit dynamic chunk training

* Clarify padding logic

* Remove now-useless _do_conv, fix horrible formatting

* Slightly fix formatting further

* Add docstrings to forward_streaming methods

* Add a reference on Dynamic Chunk Training

* Rework conformer docstring docs

* Update conformer author list, fix doc formatting for authors

* Fix trailing whitespace in conformer

* Improved comments in Conformer.forward

* Added random dynchunktrain sampler example

* More explicit names for mask functions in TransformerASR

* Added docstring example on encode_streaming

* Pre-commit fix

* Fix typo in conformer

* Initial streaming integration test

* Precommit fix

* Fix indent in YAML

* More consistent spelling in streaming integration test
asumagic committed Dec 18, 2023
1 parent cc4cf4f commit b01944f
Showing 14 changed files with 1,812 additions and 76 deletions.
38 changes: 34 additions & 4 deletions recipes/LibriSpeech/ASR/transducer/README.md
@@ -25,14 +25,44 @@ If your GPU effectively supports fp16 (half-precision) computations, it is recom
Enabling half precision can significantly reduce the peak VRAM requirements. For example, in the case of the Conformer Transducer recipe trained with Librispeech, the peak VRAM decreases from 39GB to 12GB when using fp16.
According to our tests, the performance is not affected.


# Librispeech Results

Dev-clean is evaluated with Greedy Decoding, while the test sets are evaluated with either Greedy Decoding or Beam Search with an RNNLM.
Evaluation is performed in fp32. However, we found that fp16 or bf16 autocast during inference has very little impact on the WER.

| Release | Hyperparams file | Train precision | Dev-clean Greedy | Test-clean Greedy | Test-other Greedy | Test-clean BS+RNNLM | Test-other BS+RNNLM | Model link | GPUs |
|:-------------:|:---------------------------:|:-:| :------:| :-----------:| :------------------:| :------------------:| :------------------:| :--------:| :-----------:|
| 2023-12-12 | conformer_transducer.yaml `streaming: True` | bf16 | 2.56% | 2.72% | 6.47% | \* | \* | https://drive.google.com/drive/folders/1QtQz1Bkd_QPYnf3CyxhJ57ovbSZC2EhN?usp=sharing | [4x A100SXM4 40GB](https://docs.alliancecan.ca/wiki/Narval/en) |

<sub>\*: not evaluated due to performance issues, see [issue #2301](https://github.com/speechbrain/speechbrain/issues/2301)</sub>

## Streaming model

### WER vs chunk size & left context

The following matrix presents the Word Error Rate (WER%) achieved on LibriSpeech
`test-clean` with various chunk sizes (in ms) and left context sizes (in # of
chunks).

The relative difference is not trivial to interpret, because we are not testing
against a continuous stream of speech, but rather against utterances of various
lengths. This tends to bias results in favor of larger chunk sizes.

The chunk size might not accurately represent expected latency due to slight
padding differences in streaming contexts.

The left context size is not representative of the receptive field of the model.
Because the model caches the streaming context at different layers, it may end
up forming indirect dependencies on audio from many seconds earlier.

| Release | Hyperparams file | Dev-clean Greedy | Test-clean Greedy | Test-other Greedy | Test-clean BS+RNNLM | Test-other BS+RNNLM | Model link | GPUs |
|:-------------:|:---------------------------:| :------:| :-----------:| :------------------:| :------------------:| :------------------:| :--------:| :-----------:|
| 2023-07-19 | conformer_transducer.yaml | 2.62 | 2.84 | 6.98 | 2.62 | 6.31 | https://drive.google.com/drive/folders/1QtQz1Bkd_QPYnf3CyxhJ57ovbSZC2EhN?usp=sharing | 3x3090 24GB |

| lc ↓ / cs → | full | cs=32 (1280ms) | 24 (960ms) | 16 (640ms) | 12 (480ms) | 8 (320ms) |
|:-----:|:----:|:-----:|:-----:|:-----:|:-----:|:-----:|
| full | 2.72%| - | - | - | - | - |
| lc=32 | - | 3.09% | 3.07% | 3.26% | 3.31% | 3.44% |
| 16 | - | 3.10% | 3.07% | 3.27% | 3.32% | 3.50% |
| 8 | - | 3.10% | 3.11% | 3.31% | 3.39% | 3.62% |
| 4 | - | 3.12% | 3.13% | 3.37% | 3.51% | 3.80% |
| 2 | - | 3.19% | 3.24% | 3.50% | 3.79% | 4.38% |
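
As a reading aid: one unit of chunk size in this recipe corresponds to roughly 40 ms of audio at the encoder output, which is where the millisecond figures in the header come from (e.g. `cs=16` ≈ 640 ms). The left context is counted in chunks of the same size, so with `cs=16`, `lc=16` amounts to roughly 16 × 640 ms ≈ 10 s of directly attended left context per layer. To decode with one fixed configuration rather than the random training-time sampler, a minimal sketch (assuming the `DynChunkTrainConfig` class referenced by this recipe's YAML) could look like this:

```python
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig

# Decode in ~640 ms chunks while attending to at most 16 chunks of left context.
# The training script passes such a config to the encoder through its
# `dynchunktrain_config` argument (see the `self.modules.enc(...)` call below).
config = DynChunkTrainConfig(chunk_size=16, left_context_size=16)
```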

# **About SpeechBrain**
- Website: https://speechbrain.github.io/
@@ -63,14 +63,36 @@ precision: fp32 # bf16, fp16 or fp32
batch_size: 8
grad_accumulation_factor: 4
sorting: random
avg_checkpoints: 1 # Number of checkpoints to average for evaluation
avg_checkpoints: 10 # Number of checkpoints to average for evaluation

# Feature parameters
sample_rate: 16000
n_fft: 512
n_mels: 80
win_length: 32

# Streaming & dynamic chunk training options
# At least for the current architecture on LibriSpeech, we found out that
# non-streaming accuracy is very similar between `streaming: True` and
# `streaming: False`.
streaming: True # controls all Dynamic Chunk Training & chunk size & left context mechanisms

# Configuration for Dynamic Chunk Training.
# In this model, one unit of chunk size corresponds to roughly 40ms of audio
# (e.g. a chunk size of 32 spans ~1280ms).
dynchunktrain_config_sampler: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfigRandomSampler # yamllint disable-line rule:line-length
chunkwise_prob: 0.6 # Probability during a batch to limit attention and sample a random chunk size in the following range
chunk_size_min: 8 # Minimum chunk size (if in a DynChunkTrain batch)
chunk_size_max: 32 # Maximum chunk size (if in a DynChunkTrain batch)
limited_left_context_prob: 0.75 # If in a DynChunkTrain batch, the probability during a batch to restrict left context to a random number of chunks
left_context_chunks_min: 2 # Minimum left context size (in # of chunks)
left_context_chunks_max: 32 # Maximum left context size (in # of chunks)
# If you specify a valid/test config, you can optionally have evaluation be
# done with a specific DynChunkTrain configuration.
# valid_config: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfig
# chunk_size: 24
# left_context_size: 16
# test_config: ...

# Dataloader options
train_dataloader_opts:
batch_size: !ref <batch_size>
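To make the sampling behaviour above concrete, here is a simplified, illustrative sketch of what a sampler configured with these values could do for each training batch. It is not the actual `DynChunkTrainConfigRandomSampler` implementation; in particular, returning `None` for full-context batches and `left_context_size=None` meaning "unlimited left context" are assumptions made for the sketch.

```python
import random

from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig


def sample_batch_config():
    """Illustrative per-batch sampling for Dynamic Chunk Training."""
    # With probability 1 - chunkwise_prob (0.4 here), keep full context:
    # no chunking is applied to this batch (assumed to be signalled by None).
    if random.random() >= 0.6:  # chunkwise_prob: 0.6
        return None

    # Otherwise, sample a chunk size uniformly within the configured range.
    chunk_size = random.randint(8, 32)  # chunk_size_min .. chunk_size_max

    # Optionally also restrict the left context to a random number of chunks.
    left_context = None  # assumed to mean "unlimited left context"
    if random.random() < 0.75:  # limited_left_context_prob: 0.75
        left_context = random.randint(2, 32)  # left_context_chunks_min .. _max

    return DynChunkTrainConfig(
        chunk_size=chunk_size, left_context_size=left_context
    )
```

During the `VALID` and `TEST` stages, the commented-out `valid_config`/`test_config` entries above can instead pin evaluation to one fixed configuration.
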
20 changes: 17 additions & 3 deletions recipes/LibriSpeech/ASR/transducer/train.py
@@ -58,7 +58,6 @@ def compute_forward(self, batch, stage):
tokens_with_bos
)

# Forward pass
feats = self.hparams.compute_features(wavs)

# Add feature augmentation if specified.
@@ -69,10 +68,25 @@ )
)

current_epoch = self.hparams.epoch_counter.current

# Old models may not have the streaming hparam; we don't break them in
# any other way, so just check for its presence
if hasattr(self.hparams, "streaming") and self.hparams.streaming:
dynchunktrain_config = self.hparams.dynchunktrain_config_sampler(
stage
)
else:
dynchunktrain_config = None

feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)

src = self.modules.CNN(feats)
x = self.modules.enc(src, wav_lens, pad_idx=self.hparams.pad_index)
x = self.modules.enc(
src,
wav_lens,
pad_idx=self.hparams.pad_index,
dynchunktrain_config=dynchunktrain_config,
)
x = self.modules.proj_enc(x)

e_in = self.modules.emb(tokens_with_bos)
@@ -270,7 +284,7 @@ def on_evaluate_start(self, max_key=None, min_key=None):
super().on_evaluate_start()

ckpts = self.checkpointer.find_checkpoints(
max_key=max_key, min_key=min_key
max_key=max_key, min_key=min_key,
)
ckpt = sb.utils.checkpoints.average_checkpoints(
ckpts, recoverable_name="model"
50 changes: 44 additions & 6 deletions speechbrain/decoders/transducer.py
@@ -135,7 +135,9 @@ def forward(self, tn_output):
hyps = self.searcher(tn_output)
return hyps

def transducer_greedy_decode(self, tn_output):
def transducer_greedy_decode(
self, tn_output, hidden_state=None, return_hidden=False
):
"""Transducer greedy decoder is a greedy decoder over batch which apply Transducer rules:
1- for each time step in the Transcription Network (TN) output:
-> Update the ith utterance only if
@@ -149,18 +151,43 @@ def transducer_greedy_decode(self, tn_output):
Output from transcription network with shape
[batch, time_len, hiddens].
hidden_state : (torch.Tensor, torch.Tensor)
Hidden state to initially feed the decode network with. This is
useful in conjunction with `return_hidden` to perform decoding in a
streaming context, so that you can reuse the last hidden state as an
initial state across calls.
return_hidden : bool
Whether the return tuple should contain an extra 5th element with
the hidden state of the last step. See `hidden_state`.
Returns
-------
torch.tensor
Tuple of 4 or 5 elements (if `return_hidden`).
First element: List[List[int]]
List of decoded tokens
Second element: torch.Tensor
Outputs a logits tensor [B,T,1,Output_Dim]; padding
has not been removed.
Third element: None
nbest; irrelevant for greedy decode
Fourth element: None
nbest scores; irrelevant for greedy decode
Fifth element: Present if `return_hidden`, (torch.Tensor, torch.Tensor)
Tuple representing the hidden state required to call
`transducer_greedy_decode` where you left off in a streaming
context.
"""
hyp = {
"prediction": [[] for _ in range(tn_output.size(0))],
"logp_scores": [0.0 for _ in range(tn_output.size(0))],
}
# prepare BOS = Blank for the Prediction Network (PN)
hidden = None
input_PN = (
torch.ones(
(tn_output.size(0), 1),
@@ -169,8 +196,13 @@ )
)
* self.blank_id
)
# First forward-pass on PN
out_PN, hidden = self._forward_PN(input_PN, self.decode_network_lst)

if hidden_state is None:
# First forward-pass on PN
out_PN, hidden = self._forward_PN(input_PN, self.decode_network_lst)
else:
out_PN, hidden = hidden_state

# For each time step
for t_step in range(tn_output.size(1)):
# do unsqueeze since tjoint must have 4 dims [B,T,U,Hidden]
@@ -210,13 +242,19 @@
have_update_hyp, selected_hidden, hidden
)

return (
ret = (
hyp["prediction"],
torch.Tensor(hyp["logp_scores"]).exp().mean(),
None,
None,
)

if return_hidden:
# append the `(out_PN, hidden)` tuple to ret
ret += ((out_PN, hidden,),)

return ret
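
For illustration, a minimal sketch of chunk-by-chunk greedy decoding with the new `hidden_state` / `return_hidden` arguments; `searcher` (a configured `TransducerBeamSearcher`) and `encoder_chunks` (successive transcription-network outputs) are placeholders and not part of this diff:

```python
# Decode a stream of encoder outputs chunk by chunk, carrying the prediction
# network state across calls so each call resumes where the previous one left off.
hidden = None
transcripts = None

for tn_chunk in encoder_chunks:  # each chunk: [batch, time_len, hiddens]
    hyps, _avg_logp, _, _, hidden = searcher.transducer_greedy_decode(
        tn_chunk, hidden_state=hidden, return_hidden=True
    )
    if transcripts is None:
        transcripts = [list(h) for h in hyps]
    else:
        for running, new_tokens in zip(transcripts, hyps):
            running.extend(new_tokens)
```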

def transducer_beam_search_decode(self, tn_output):
"""Transducer beam search decoder is a beam search decoder over batch which apply Transducer rules:
1- for each utterance:
4 changes: 4 additions & 0 deletions speechbrain/lobes/models/transformer/Branchformer.py
@@ -320,6 +320,7 @@ def forward(
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
pos_embs: Optional[torch.Tensor] = None,
dynchunktrain_config=None,
):
"""
Arguments
@@ -335,6 +336,9 @@
If custom pos_embs are given it needs to have the shape (1, 2*S-1, E)
where S is the sequence length, and E is the embedding dimension.
"""
assert (
dynchunktrain_config is None
), "Dynamic Chunk Training unsupported for this encoder"

if self.attention_type == "RelPosMHAXL":
if pos_embs is None:
