Streamable Conformer-Transducer ASR model for LibriSpeech #2140

Merged · 84 commits · Dec 18, 2023

Commits:
- cedfd46 Introduce DCT+DCConv logic (asumagic, Jul 5, 2023)
- b59accb DDP fix? (asumagic, Jul 5, 2023)
- 9241781 Batch of changes and things brought back (asumagic, Jul 8, 2023)
- a504541 Streaming fixes (successfully trains) (asumagic, Jul 12, 2023)
- 008055b WIP streaming code (asumagic, Jul 18, 2023)
- 4fc70f4 WIP functional streaming code (asumagic, Aug 22, 2023)
- 4a7e95f Fix left context (asumagic, Aug 22, 2023)
- 83616a1 Fix formatting (asumagic, Aug 23, 2023)
- 9b3a0d3 Cleanups and docs in streaming utils (asumagic, Aug 23, 2023)
- 4ee692f Better comment hparams, change seed back to orig, improve naming (asumagic, Aug 23, 2023)
- b459187 uncomment averaging stuff; it was some ipython issue (asumagic, Aug 23, 2023)
- fa5edea Remove pin_memory as it was not beneficial (asumagic, Aug 23, 2023)
- ed36776 More cleanups, comments on context stuff (asumagic, Aug 23, 2023)
- a256508 More comments and TODOs (asumagic, Aug 23, 2023)
- 0e69129 encode_streaming docstring (asumagic, Aug 23, 2023)
- d05a771 Dirty TransducerBeamSearcher change for streaming GS (asumagic, Aug 28, 2023)
- afb96db Fix precommit (asumagic, Aug 29, 2023)
- 6585ae8 Fix encoders that do not support chunk_size (asumagic, Aug 29, 2023)
- 98d0ddf Pre-commit again (asumagic, Aug 29, 2023)
- c23435f Make chunk_size type consistent (asumagic, Aug 29, 2023)
- dd264a6 Fix formatting of doctest in split_wav_lens (asumagic, Aug 29, 2023)
- 107688e Remove outdated TODO (asumagic, Aug 29, 2023)
- 8b88dc9 Add hasattr streaming to retain model backcompat (asumagic, Aug 29, 2023)
- c4c730d Cleanup doc and naming for transducer_greedy_decode (asumagic, Aug 29, 2023)
- a02ed5f Cite paper for chunked attention (asumagic, Aug 29, 2023)
- be92a12 Remove lost comment (asumagic, Aug 29, 2023)
- 382b97b Update comment in self-attention (asumagic, Aug 29, 2023)
- 12f89bf Don't apply masked fill fix in the non-bool mask case (asumagic, Aug 29, 2023)
- ee444a0 Added TODO README update (asumagic, Aug 29, 2023)
- 1013c71 Revert change to custom_tgt_module; patching model instead (asumagic, Aug 30, 2023)
- 10ff215 Remove added entry in README (asumagic, Aug 31, 2023)
- b16754f Fix streaming conformer conv mismatch (asumagic, Nov 20, 2023)
- e5785d8 More conformer conv adjustments (asumagic, Nov 21, 2023)
- 9633156 Adjust context size (asumagic, Nov 21, 2023)
- c1fbb8f Remove outdated comment (asumagic, Nov 23, 2023)
- 7446706 Fixed causal conformer decoder (asumagic, Nov 28, 2023)
- 1f91e85 Fix linting (asumagic, Nov 28, 2023)
- d96a92e Gate `custom_tgt_module` creation behind the presence of decoder layers (asumagic, Dec 4, 2023)
- ddb6d5b Re-enable checkpoint averaging (asumagic, Dec 4, 2023)
- 6ce59c3 Change averaged ckpt count to 10 (asumagic, Dec 4, 2023)
- 4f52a6f Add new model results to README (asumagic, Dec 14, 2023)
- de7d997 WIP refactor: Introduce DCTConfig dataclass (asumagic, Dec 14, 2023)
- d9b6f88 Improved notice in README (asumagic, Dec 14, 2023)
- ffb820c Merge branch 'unstable-v0.6' into streaming-asr-v2 (asumagic, Dec 14, 2023)
- 11fec0c Formatting and linting fixes (asumagic, Dec 14, 2023)
- 65255c8 Attempt at fixing circular import? (asumagic, Dec 14, 2023)
- 0ec5417 utils can't depend on core it seems; move dct (asumagic, Dec 14, 2023)
- 61606ac Whoops, missed file (asumagic, Dec 14, 2023)
- fe38e5b Add DCT test, fix issues (asumagic, Dec 15, 2023)
- c2fc373 Remove now obsolete yaml variables for streaming (asumagic, Dec 15, 2023)
- 2d31242 Formatting (asumagic, Dec 15, 2023)
- faa79e9 Add dummy dct_config parameter to keep unsupported encoders working (asumagic, Dec 15, 2023)
- 90f1367 Linting fix (asumagic, Dec 15, 2023)
- bcc6b2c Fix typo (asumagic, Dec 15, 2023)
- 2577cc6 Add note on runtime autocast accuracy (asumagic, Dec 15, 2023)
- 0c8e382 Fix very bad typo from refactor in YAML (asumagic, Dec 15, 2023)
- db73114 Fix hasattr streaming check (asumagic, Dec 15, 2023)
- 74496e6 Remove legacy comment (asumagic, Dec 15, 2023)
- 4558232 Fix left context size calculation in new mask code (asumagic, Dec 15, 2023)
- 8da79f3 Fix causal models in TransformerASR (asumagic, Dec 15, 2023)
- bd9f506 Remove comment on high-level inference code (asumagic, Dec 15, 2023)
- 0a49c01 YAML formatting + commenting dynchunktrain stuff (asumagic, Dec 15, 2023)
- 28cfbb6 Remove outdated comment about DCConv left contexts (asumagic, Dec 15, 2023)
- 246b8b4 Remove commented out debug prints from TransformerASR (asumagic, Dec 15, 2023)
- 49e73ec Move DCT into utils again (asumagic, Dec 18, 2023)
- 2faf306 Rename all(?) mentions of DCT to explicit dynamic chunk training (asumagic, Dec 18, 2023)
- c577c5b Clarify padding logic (asumagic, Dec 18, 2023)
- 176f1d8 Remove now-useless _do_conv, fix horrible formatting (asumagic, Dec 18, 2023)
- f17470e Slightly fix formatting further (asumagic, Dec 18, 2023)
- 86850ad Add docstrings to forward_streaming methods (asumagic, Dec 18, 2023)
- 392ed08 Add a reference on Dynamic Chunk Training (asumagic, Dec 18, 2023)
- 721f147 Rework conformer docstring docs (asumagic, Dec 18, 2023)
- a459180 Update conformer author list, fix doc formatting for authors (asumagic, Dec 18, 2023)
- b2f6b5c Fix trailing whitespace in conformer (asumagic, Dec 18, 2023)
- 86b8bba Improved comments in Conformer.forward (asumagic, Dec 18, 2023)
- 9c63fe2 Added random dynchunktrain sampler example (asumagic, Dec 18, 2023)
- eee3752 More explicit names for mask functions in TransformerASR (asumagic, Dec 18, 2023)
- 17d4f5f Added docstring example on encode_streaming (asumagic, Dec 18, 2023)
- ccc00f6 Pre-commit fix (asumagic, Dec 18, 2023)
- dbcdf9a Fix typo in conformer (asumagic, Dec 18, 2023)
- eb67e1a Initial streaming integration test (asumagic, Dec 18, 2023)
- f6213ec Precommit fix (asumagic, Dec 18, 2023)
- 56c69ff Fix indent in YAML (asumagic, Dec 18, 2023)
- e70bb1d More consistent spelling in streaming integration test (asumagic, Dec 18, 2023)
38 changes: 34 additions & 4 deletions recipes/LibriSpeech/ASR/transducer/README.md
@@ -25,14 +25,44 @@ If your GPU effectively supports fp16 (half-precision) computations, it is recom[…]
Enabling half precision can significantly reduce the peak VRAM requirements. For example, in the case of the Conformer Transducer recipe trained with Librispeech, the peak VRAM decreases from 39GB to 12GB when using fp16.
According to our tests, the performance is not affected.


# Librispeech Results

Dev-clean is evaluated with greedy decoding, while the test sets are evaluated with either greedy decoding or beam search with an RNNLM.
Evaluation is performed in fp32. However, we found that fp16 or bf16 autocast at inference has very little impact on the WER.

| Release | Hyperparams file | Train precision | Dev-clean Greedy | Test-clean Greedy | Test-other Greedy | Test-clean BS+RNNLM | Test-other BS+RNNLM | Model link | GPUs |
|:-------------:|:---------------------------:|:-:| :------:| :-----------:| :------------------:| :------------------:| :------------------:| :--------:| :-----------:|
| 2023-12-12 | conformer_transducer.yaml `streaming: True` | bf16 | 2.56% | 2.72% | 6.47% | \* | \* | https://drive.google.com/drive/folders/1QtQz1Bkd_QPYnf3CyxhJ57ovbSZC2EhN?usp=sharing | [4x A100SXM4 40GB](https://docs.alliancecan.ca/wiki/Narval/en) |

<sub>\*: not evaluated due to performance issues, see [issue #2301](https://github.com/speechbrain/speechbrain/issues/2301)</sub>

## Streaming model

### WER vs chunk size & left context

The following matrix presents the Word Error Rate (WER%) achieved on LibriSpeech
`test-clean` with various chunk sizes (in ms) and left context sizes (in # of
chunks).

The relative difference is not trivial to interpret, because we are not testing
against a continuous stream of speech, but rather against utterances of various
lengths. This tends to bias results in favor of larger chunk sizes.

The chunk size might not accurately represent expected latency due to slight
padding differences in streaming contexts.

The left context size is not representative of the receptive field of the
model. Because the model caches the streaming context at different layers, it
may end up forming indirect dependencies on audio from many seconds earlier.

| Release | Hyperparams file | Dev-clean Greedy | Test-clean Greedy | Test-other Greedy | Test-clean BS+RNNLM | Test-other BS+RNNLM | Model link | GPUs |
|:-------------:|:---------------------------:| :------:| :-----------:| :------------------:| :------------------:| :------------------:| :--------:| :-----------:|
| 2023-07-19 | conformer_transducer.yaml | 2.62 | 2.84 | 6.98 | 2.62 | 6.31 | https://drive.google.com/drive/folders/1QtQz1Bkd_QPYnf3CyxhJ57ovbSZC2EhN?usp=sharing | 3x3090 24GB |
| left context \ chunk size | full | cs=32 (1280ms) | cs=24 (960ms) | cs=16 (640ms) | cs=12 (480ms) | cs=8 (320ms) |
|:-----:|:----:|:-----:|:-----:|:-----:|:-----:|:-----:|
| full  | 2.72% | -     | -     | -     | -     | -     |
| lc=32 | -     | 3.09% | 3.07% | 3.26% | 3.31% | 3.44% |
| lc=16 | -     | 3.10% | 3.07% | 3.27% | 3.32% | 3.50% |
| lc=8  | -     | 3.10% | 3.11% | 3.31% | 3.39% | 3.62% |
| lc=4  | -     | 3.12% | 3.13% | 3.37% | 3.51% | 3.80% |
| lc=2  | -     | 3.19% | 3.24% | 3.50% | 3.79% | 4.38% |
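To reproduce one cell of this matrix, evaluation can be pinned to a fixed configuration through the `valid_config`/`test_config` entries shown in the recipe YAML below. A minimal sketch, assuming the `DynChunkTrainConfig` constructor arguments from that YAML:

```python
# Minimal sketch: evaluate at a fixed chunk size and left context, i.e.
# one cell of the matrix above. Constructor arguments follow the
# commented valid_config/test_config example in the recipe YAML.
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig

# cs=24 chunks (~960ms) with lc=16 chunks of left context: 3.07% above
test_config = DynChunkTrainConfig(chunk_size=24, left_context_size=16)

# The config is then threaded down to the encoder, as in train.py below:
# x = self.modules.enc(src, wav_lens, pad_idx=..., dynchunktrain_config=test_config)
```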

# **About SpeechBrain**
- Website: https://speechbrain.github.io/
@@ -63,14 +63,36 @@ precision: fp32 # bf16, fp16 or fp32
batch_size: 8
grad_accumulation_factor: 4
sorting: random
avg_checkpoints: 1 # Number of checkpoints to average for evaluation
avg_checkpoints: 10 # Number of checkpoints to average for evaluation

# Feature parameters
sample_rate: 16000
n_fft: 512
n_mels: 80
win_length: 32

# Streaming & dynamic chunk training options
# At least for the current architecture on LibriSpeech, we found out that
# non-streaming accuracy is very similar between `streaming: True` and
# `streaming: False`.
streaming: True # controls all Dynamic Chunk Training & chunk size & left context mechanisms

# Configuration for Dynamic Chunk Training.
# In this model, a chunk is roughly equivalent to 40ms of audio.
dynchunktrain_config_sampler: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfigRandomSampler # yamllint disable-line rule:line-length
chunkwise_prob: 0.6 # Per-batch probability to limit attention to chunks, with a chunk size sampled in the range below
chunk_size_min: 8 # Minimum chunk size (if in a DynChunkTrain batch)
chunk_size_max: 32 # Maximum chunk size (if in a DynChunkTrain batch)
limited_left_context_prob: 0.75 # If in a DynChunkTrain batch, the probability that the left context is restricted to a random number of chunks
left_context_chunks_min: 2 # Minimum left context size (in # of chunks)
left_context_chunks_max: 32 # Maximum left context size (in # of chunks)
# If you specify a valid/test config, you can optionally have evaluation be
# done with a specific DynChunkTrain configuration.
# valid_config: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfig
# chunk_size: 24
# left_context_size: 16
# test_config: ...

# Dataloader options
train_dataloader_opts:
batch_size: !ref <batch_size>
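In rough terms, the sampler configured above draws one regime per batch: ordinary full context, or chunked attention with a randomly sampled chunk size and an optionally restricted left context. With roughly 40ms of audio per chunk, the sampled chunk sizes of 8 to 32 correspond to about 320ms to 1280ms, matching the README matrix. A hedged sketch of that sampling logic follows; the real implementation is `DynChunkTrainConfigRandomSampler` in `speechbrain.utils.dynamic_chunk_training`, and whether `DynChunkTrainConfig` accepts `left_context_size=None` for "unrestricted" is an assumption here:

```python
# Hedged sketch of the per-batch sampling performed by
# DynChunkTrainConfigRandomSampler, using the hyperparameters above.
import random

from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig

def sample_dynchunktrain_config() -> "DynChunkTrainConfig | None":
    if random.random() >= 0.6:  # 1 - chunkwise_prob: full-context batch
        return None
    chunk_size = random.randint(8, 32)  # chunk_size_min..chunk_size_max
    left_context = None  # assumed to mean "unrestricted left context"
    if random.random() < 0.75:  # limited_left_context_prob
        left_context = random.randint(2, 32)  # left_context_chunks_min..max
    return DynChunkTrainConfig(chunk_size=chunk_size, left_context_size=left_context)
```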
20 changes: 17 additions & 3 deletions recipes/LibriSpeech/ASR/transducer/train.py
@@ -58,7 +58,6 @@ def compute_forward(self, batch, stage):
tokens_with_bos
)

# Forward pass
feats = self.hparams.compute_features(wavs)

# Add feature augmentation if specified.
@@ -69,10 +68,25 @@
)

current_epoch = self.hparams.epoch_counter.current

# Old models may not have the `streaming` hparam; we don't break them in
# any other way, so just check for its presence
if hasattr(self.hparams, "streaming") and self.hparams.streaming:
dynchunktrain_config = self.hparams.dynchunktrain_config_sampler(
stage
)
else:
dynchunktrain_config = None

feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)

src = self.modules.CNN(feats)
x = self.modules.enc(src, wav_lens, pad_idx=self.hparams.pad_index)
x = self.modules.enc(
src,
wav_lens,
pad_idx=self.hparams.pad_index,
dynchunktrain_config=dynchunktrain_config,
)
x = self.modules.proj_enc(x)

e_in = self.modules.emb(tokens_with_bos)
@@ -270,7 +284,7 @@ def on_evaluate_start(self, max_key=None, min_key=None):
super().on_evaluate_start()

ckpts = self.checkpointer.find_checkpoints(
max_key=max_key, min_key=min_key
max_key=max_key, min_key=min_key,
)
ckpt = sb.utils.checkpoints.average_checkpoints(
ckpts, recoverable_name="model"
50 changes: 44 additions & 6 deletions speechbrain/decoders/transducer.py
@@ -135,7 +135,9 @@ def forward(self, tn_output):
hyps = self.searcher(tn_output)
return hyps

def transducer_greedy_decode(self, tn_output):
def transducer_greedy_decode(
self, tn_output, hidden_state=None, return_hidden=False
):
"""Transducer greedy decoder is a greedy decoder over batch which apply Transducer rules:
1- for each time step in the Transcription Network (TN) output:
-> Update the ith utterance only if
@@ -149,18 +151,43 @@ def transducer_greedy_decode(self, tn_output):
Output from transcription network with shape
[batch, time_len, hiddens].

hidden_state : (torch.Tensor, torch.Tensor)
Hidden state to initially feed the decode network with. This is
useful in conjunction with `return_hidden` to be able to decode
in a streaming context, so that you can reuse the last
hidden state as an initial state across calls.

return_hidden : bool
Whether the return tuple should contain an extra 5th element with
the hidden state of the last step. See `hidden_state`.

Returns
-------
tuple
Tuple of 4 or 5 elements (if `return_hidden`).

First element: List[List[int]]
List of decoded tokens

Second element: torch.Tensor
Outputs a logits tensor [B,T,1,Output_Dim]; padding
has not been removed.

Third element: None
nbest; irrelevant for greedy decode

> **Review comment (Collaborator):** Then why do we return it?
>
> **asumagic (author):** To match the beam search API, AFAIK. This is not new, though, it was just undocumented before.

Fourth element: None
nbest scores; irrelevant for greedy decode

Fifth element: Present if `return_hidden`, (torch.Tensor, torch.Tensor)
Tuple representing the hidden state required to call
`transducer_greedy_decode` where you left off in a streaming
context.
"""
hyp = {
"prediction": [[] for _ in range(tn_output.size(0))],
"logp_scores": [0.0 for _ in range(tn_output.size(0))],
}
# prepare BOS = Blank for the Prediction Network (PN)
hidden = None
input_PN = (
torch.ones(
(tn_output.size(0), 1),
@@ -169,8 +196,13 @@
)
* self.blank_id
)
# First forward-pass on PN
out_PN, hidden = self._forward_PN(input_PN, self.decode_network_lst)

if hidden_state is None:
# First forward-pass on PN
out_PN, hidden = self._forward_PN(input_PN, self.decode_network_lst)
else:
out_PN, hidden = hidden_state

# For each time step
for t_step in range(tn_output.size(1)):
# unsqueeze since tjoint must take a 4-dim input [B, T, U, Hidden]
@@ -210,13 +242,19 @@
have_update_hyp, selected_hidden, hidden
)

return (
ret = (
hyp["prediction"],
torch.Tensor(hyp["logp_scores"]).exp().mean(),
None,
None,
)

if return_hidden:
# append the `(out_PN, hidden)` tuple to ret
ret += ((out_PN, hidden,),)

return ret

def transducer_beam_search_decode(self, tn_output):
"""Transducer beam search decoder is a beam search decoder over batch which apply Transducer rules:
1- for each utterance:
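Taken together, `hidden_state` and `return_hidden` let greedy decoding resume across encoder chunks. A hedged usage sketch; the `searcher` object and the source of `encoder_chunks` are assumptions, not part of this diff:

```python
# Hedged sketch: chunkwise greedy decoding with the new streaming API.
# Assumptions: `searcher` is a TransducerBeamSearcher, and
# `encoder_chunks` yields [batch, chunk_frames, hiddens] tensors from a
# streaming encoder; neither is defined in this diff.
def streaming_greedy_transcribe(searcher, encoder_chunks, batch_size):
    transcripts = [[] for _ in range(batch_size)]
    hidden = None  # first call: the prediction network starts from scratch
    for tn_chunk in encoder_chunks:
        # The 5th element (present because return_hidden=True) is the
        # (out_PN, hidden) tuple to feed back in as hidden_state.
        hyps, _avg_logp, _, _, hidden = searcher.transducer_greedy_decode(
            tn_chunk, hidden_state=hidden, return_hidden=True
        )
        for i, hyp in enumerate(hyps):
            transcripts[i].extend(hyp)  # tokens decoded in this chunk
    return transcripts
```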
4 changes: 4 additions & 0 deletions speechbrain/lobes/models/transformer/Branchformer.py
@@ -320,6 +320,7 @@ def forward(
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
pos_embs: Optional[torch.Tensor] = None,
dynchunktrain_config=None,
):
"""
Arguments
@@ -335,6 +336,9 @@
If custom pos_embs are given it needs to have the shape (1, 2*S-1, E)
where S is the sequence length, and E is the embedding dimension.
"""
assert (
dynchunktrain_config is None
), "Dynamic Chunk Training unsupported for this encoder"

if self.attention_type == "RelPosMHAXL":
if pos_embs is None: