diff --git a/recipes/LibriSpeech/ASR/transducer/README.md b/recipes/LibriSpeech/ASR/transducer/README.md index 72cef5c7c9..aa6d964e82 100644 --- a/recipes/LibriSpeech/ASR/transducer/README.md +++ b/recipes/LibriSpeech/ASR/transducer/README.md @@ -25,14 +25,44 @@ If your GPU effectively supports fp16 (half-precision) computations, it is recom Enabling half precision can significantly reduce the peak VRAM requirements. For example, in the case of the Conformer Transducer recipe trained with Librispeech, the peak VRAM decreases from 39GB to 12GB when using fp16. According to our tests, the performance is not affected. - # Librispeech Results Dev. clean is evaluated with Greedy Decoding while the test sets are using Greedy Decoding OR a RNNLM + Beam Search. +Evaluation is performed in fp32. However, we found that during inference, fp16 or bf16 autocast has very little incidence on the WER. + +| Release | Hyperparams file | Train precision | Dev-clean Greedy | Test-clean Greedy | Test-other Greedy | Test-clean BS+RNNLM | Test-other BS+RNNLM | Model link | GPUs | +|:-------------:|:---------------------------:|:-:| :------:| :-----------:| :------------------:| :------------------:| :------------------:| :--------:| :-----------:| +| 2023-12-12 | conformer_transducer.yaml `streaming: True` | bf16 | 2.56% | 2.72% | 6.47% | \* | \* | https://drive.google.com/drive/folders/1QtQz1Bkd_QPYnf3CyxhJ57ovbSZC2EhN?usp=sharing | [4x A100SXM4 40GB](https://docs.alliancecan.ca/wiki/Narval/en) | + +\*: not evaluated due to performance issues, see [issue #2301](https://github.com/speechbrain/speechbrain/issues/2301) + +## Streaming model + +### WER vs chunk size & left context + +The following matrix presents the Word Error Rate (WER%) achieved on LibriSpeech +`test-clean` with various chunk sizes (in ms) and left context sizes (in # of +chunks). + +The relative difference is not trivial to interpret, because we are not testing +against a continuous stream of speech, but rather against utterances of various +lengths. This tends to bias results in favor of larger chunk sizes. + +The chunk size might not accurately represent expected latency due to slight +padding differences in streaming contexts. + +The left chunk size is not representative of the receptive field of the model. +Because the model caches the streaming context at different layers, the model +may end up forming indirect dependencies to audio many seconds ago. 
-| Release | Hyperparams file | Dev-clean Greedy | Test-clean Greedy | Test-other Greedy | Test-clean BS+RNNLM | Test-other BS+RNNLM | Model link | GPUs | -|:-------------:|:---------------------------:| :------:| :-----------:| :------------------:| :------------------:| :------------------:| :--------:| :-----------:| -| 2023-07-19 | conformer_transducer.yaml | 2.62 | 2.84 | 6.98 | 2.62 | 6.31 | https://drive.google.com/drive/folders/1QtQz1Bkd_QPYnf3CyxhJ57ovbSZC2EhN?usp=sharing | 3x3090 24GB | +| | full | cs=32 (1280ms) | 24 (960ms) | 16 (640ms) | 12 (480ms) | 8 (320ms) | +|:-----:|:----:|:-----:|:-----:|:-----:|:-----:|:-----:| +| full | 2.72%| - | - | - | - | - | +| lc=32 | - | 3.09% | 3.07% | 3.26% | 3.31% | 3.44% | +| 16 | - | 3.10% | 3.07% | 3.27% | 3.32% | 3.50% | +| 8 | - | 3.10% | 3.11% | 3.31% | 3.39% | 3.62% | +| 4 | - | 3.12% | 3.13% | 3.37% | 3.51% | 3.80% | +| 2 | - | 3.19% | 3.24% | 3.50% | 3.79% | 4.38% | # **About SpeechBrain** - Website: https://speechbrain.github.io/ diff --git a/recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml b/recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml index dad6b00e26..c7ad99c638 100644 --- a/recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml +++ b/recipes/LibriSpeech/ASR/transducer/hparams/conformer_transducer.yaml @@ -63,7 +63,7 @@ precision: fp32 # bf16, fp16 or fp32 batch_size: 8 grad_accumulation_factor: 4 sorting: random -avg_checkpoints: 1 # Number of checkpoints to average for evaluation +avg_checkpoints: 10 # Number of checkpoints to average for evaluation # Feature parameters sample_rate: 16000 @@ -71,6 +71,28 @@ n_fft: 512 n_mels: 80 win_length: 32 +# Streaming & dynamic chunk training options +# At least for the current architecture on LibriSpeech, we found out that +# non-streaming accuracy is very similar between `streaming: True` and +# `streaming: False`. +streaming: True # controls all Dynamic Chunk Training & chunk size & left context mechanisms + +# Configuration for Dynamic Chunk Training. +# In this model, a chunk is roughly equivalent to 40ms of audio. +dynchunktrain_config_sampler: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfigRandomSampler # yamllint disable-line rule:line-length + chunkwise_prob: 0.6 # Probability during a batch to limit attention and sample a random chunk size in the following range + chunk_size_min: 8 # Minimum chunk size (if in a DynChunkTrain batch) + chunk_size_max: 32 # Maximum chunk size (if in a DynChunkTrain batch) + limited_left_context_prob: 0.75 # If in a DynChunkTrain batch, the probability during a batch to restrict left context to a random number of chunks + left_context_chunks_min: 2 # Minimum left context size (in # of chunks) + left_context_chunks_max: 32 # Maximum left context size (in # of chunks) + # If you specify a valid/test config, you can optionally have evaluation be + # done with a specific DynChunkTrain configuration. + # valid_config: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfig + # chunk_size: 24 + # left_context_size: 16 + # test_config: ... 
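For reference, the sampler declared above is instantiated from the YAML and called once per batch in `compute_forward` (see the `train.py` changes below), returning either a `DynChunkTrainConfig` or `None` for a full-context batch. The following is a minimal, illustrative Python sketch of that call using the same probabilities and ranges as the YAML above; the `valid_config` shown corresponds to the commented-out option and is not enabled by the recipe by default.

```python
# Minimal sketch (not part of the recipe) of how the YAML sampler is consumed:
# the Brain calls it once per batch with the current stage and forwards the
# resulting config to the encoder.
import speechbrain as sb
from speechbrain.utils.dynamic_chunk_training import (
    DynChunkTrainConfig,
    DynChunkTrainConfigRandomSampler,
)

sampler = DynChunkTrainConfigRandomSampler(
    chunkwise_prob=0.6,
    chunk_size_min=8,
    chunk_size_max=32,
    limited_left_context_prob=0.75,
    left_context_chunks_min=2,
    left_context_chunks_max=32,
    # illustrative: evaluate with a fixed chunking during validation
    valid_config=DynChunkTrainConfig(chunk_size=24, left_context_size=16),
)

config = sampler(sb.Stage.TRAIN)
# e.g. DynChunkTrainConfig(chunk_size=17, left_context_size=4); it may also be
# None (full-context batch) or have left_context_size=None (unlimited left
# context), depending on the sampled probabilities.
```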
+ # Dataloader options train_dataloader_opts: batch_size: !ref diff --git a/recipes/LibriSpeech/ASR/transducer/train.py b/recipes/LibriSpeech/ASR/transducer/train.py index 319edaa7e8..4a0e889a41 100644 --- a/recipes/LibriSpeech/ASR/transducer/train.py +++ b/recipes/LibriSpeech/ASR/transducer/train.py @@ -58,7 +58,6 @@ def compute_forward(self, batch, stage): tokens_with_bos ) - # Forward pass feats = self.hparams.compute_features(wavs) # Add feature augmentation if specified. @@ -69,10 +68,25 @@ def compute_forward(self, batch, stage): ) current_epoch = self.hparams.epoch_counter.current + + # Old models may not have the streaming hparam, we don't break them in + # any other way so just check for its presence + if hasattr(self.hparams, "streaming") and self.hparams.streaming: + dynchunktrain_config = self.hparams.dynchunktrain_config_sampler( + stage + ) + else: + dynchunktrain_config = None + feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch) src = self.modules.CNN(feats) - x = self.modules.enc(src, wav_lens, pad_idx=self.hparams.pad_index) + x = self.modules.enc( + src, + wav_lens, + pad_idx=self.hparams.pad_index, + dynchunktrain_config=dynchunktrain_config, + ) x = self.modules.proj_enc(x) e_in = self.modules.emb(tokens_with_bos) @@ -270,7 +284,7 @@ def on_evaluate_start(self, max_key=None, min_key=None): super().on_evaluate_start() ckpts = self.checkpointer.find_checkpoints( - max_key=max_key, min_key=min_key + max_key=max_key, min_key=min_key, ) ckpt = sb.utils.checkpoints.average_checkpoints( ckpts, recoverable_name="model" diff --git a/speechbrain/decoders/transducer.py b/speechbrain/decoders/transducer.py index 9cc9db400c..b96eae9ff2 100644 --- a/speechbrain/decoders/transducer.py +++ b/speechbrain/decoders/transducer.py @@ -135,7 +135,9 @@ def forward(self, tn_output): hyps = self.searcher(tn_output) return hyps - def transducer_greedy_decode(self, tn_output): + def transducer_greedy_decode( + self, tn_output, hidden_state=None, return_hidden=False + ): """Transducer greedy decoder is a greedy decoder over batch which apply Transducer rules: 1- for each time step in the Transcription Network (TN) output: -> Update the ith utterance only if @@ -149,18 +151,43 @@ def transducer_greedy_decode(self, tn_output): Output from transcription network with shape [batch, time_len, hiddens]. + hidden_state : (torch.Tensor, torch.Tensor) + Hidden state to initially feed the decode network with. This is + useful in conjunction with `return_hidden` to be able to perform + beam search in a streaming context, so that you can reuse the last + hidden state as an initial state across calls. + + return_hidden : bool + Whether the return tuple should contain an extra 5th element with + the hidden state at of the last step. See `hidden_state`. + Returns ------- - torch.tensor + Tuple of 4 or 5 elements (if `return_hidden`). + + First element: List[List[int]] + List of decoded tokens + + Second element: torch.Tensor Outputs a logits tensor [B,T,1,Output_Dim]; padding has not been removed. + + Third element: None + nbest; irrelevant for greedy decode + + Fourth element: None + nbest scores; irrelevant for greedy decode + + Fifth element: Present if `return_hidden`, (torch.Tensor, torch.Tensor) + Tuple representing the hidden state required to call + `transducer_greedy_decode` where you left off in a streaming + context. 
""" hyp = { "prediction": [[] for _ in range(tn_output.size(0))], "logp_scores": [0.0 for _ in range(tn_output.size(0))], } # prepare BOS = Blank for the Prediction Network (PN) - hidden = None input_PN = ( torch.ones( (tn_output.size(0), 1), @@ -169,8 +196,13 @@ def transducer_greedy_decode(self, tn_output): ) * self.blank_id ) - # First forward-pass on PN - out_PN, hidden = self._forward_PN(input_PN, self.decode_network_lst) + + if hidden_state is None: + # First forward-pass on PN + out_PN, hidden = self._forward_PN(input_PN, self.decode_network_lst) + else: + out_PN, hidden = hidden_state + # For each time step for t_step in range(tn_output.size(1)): # do unsqueeze over since tjoint must be have a 4 dim [B,T,U,Hidden] @@ -210,13 +242,19 @@ def transducer_greedy_decode(self, tn_output): have_update_hyp, selected_hidden, hidden ) - return ( + ret = ( hyp["prediction"], torch.Tensor(hyp["logp_scores"]).exp().mean(), None, None, ) + if return_hidden: + # append the `(out_PN, hidden)` tuple to ret + ret += ((out_PN, hidden,),) + + return ret + def transducer_beam_search_decode(self, tn_output): """Transducer beam search decoder is a beam search decoder over batch which apply Transducer rules: 1- for each utterance: diff --git a/speechbrain/lobes/models/transformer/Branchformer.py b/speechbrain/lobes/models/transformer/Branchformer.py index 3f75701a34..c2dd297fc7 100644 --- a/speechbrain/lobes/models/transformer/Branchformer.py +++ b/speechbrain/lobes/models/transformer/Branchformer.py @@ -320,6 +320,7 @@ def forward( src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, pos_embs: Optional[torch.Tensor] = None, + dynchunktrain_config=None, ): """ Arguments @@ -335,6 +336,9 @@ def forward( If custom pos_embs are given it needs to have the shape (1, 2*S-1, E) where S is the sequence length, and E is the embedding dimension. """ + assert ( + dynchunktrain_config is None + ), "Dynamic Chunk Training unsupported for this encoder" if self.attention_type == "RelPosMHAXL": if pos_embs is None: diff --git a/speechbrain/lobes/models/transformer/Conformer.py b/speechbrain/lobes/models/transformer/Conformer.py index 009b87d932..9220583021 100755 --- a/speechbrain/lobes/models/transformer/Conformer.py +++ b/speechbrain/lobes/models/transformer/Conformer.py @@ -1,14 +1,19 @@ """Conformer implementation. Authors +------- * Jianyuan Zhong 2020 * Samuele Cornell 2021 +* Sylvain de Langen 2023 """ +from dataclasses import dataclass import torch import torch.nn as nn -from typing import Optional +import torch.nn.functional as F +from typing import Optional, List import speechbrain as sb +import math import warnings @@ -17,11 +22,51 @@ MultiheadAttention, PositionalwiseFeedForward, ) +from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig from speechbrain.lobes.models.transformer.hypermixing import HyperMixing from speechbrain.nnet.normalization import LayerNorm from speechbrain.nnet.activations import Swish +@dataclass +class ConformerEncoderLayerStreamingContext: + """Streaming metadata and state for a `ConformerEncoderLayer`. + + The multi-head attention and Dynamic Chunk Convolution require to save some + left context that gets inserted as left padding. + + See :class:`.ConvolutionModule` documentation for further details. + """ + + mha_left_context_size: int + """For this layer, specifies how many frames of inputs should be saved. + Usually, the same value is used across all layers, but this can be modified. 
+ """ + + mha_left_context: Optional[torch.Tensor] = None + """Left context to insert at the left of the current chunk as inputs to the + multi-head attention. It can be `None` (if we're dealing with the first + chunk) or `<= mha_left_context_size` because for the first few chunks, not + enough left context may be available to pad. + """ + + dcconv_left_context: Optional[torch.Tensor] = None + """Left context to insert at the left of the convolution according to the + Dynamic Chunk Convolution method. + + Unlike `mha_left_context`, here the amount of frames to keep is fixed and + inferred from the kernel size of the convolution module. + """ + + +@dataclass +class ConformerEncoderStreamingContext: + """Streaming metadata and state for a `ConformerEncoder`.""" + + layers: List[ConformerEncoderLayerStreamingContext] + """Streaming metadata and state for each layer of the encoder.""" + + class ConvolutionModule(nn.Module): """This is an implementation of convolution module in Conformer. @@ -91,6 +136,9 @@ def __init__( bias=bias, ) + # NOTE: there appears to be a mismatch compared to the Conformer paper: + # I believe the first LayerNorm below is supposed to be a BatchNorm. + self.after_conv = nn.Sequential( nn.LayerNorm(input_size), activation(), @@ -99,20 +147,176 @@ def __init__( nn.Dropout(dropout), ) - def forward(self, x, mask=None): - """ Processes the input tensor x and returns the output an output tensor""" - out = self.layer_norm(x) - out = out.transpose(1, 2) - out = self.bottleneck(out) - out = self.conv(out) + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + dynchunktrain_config: Optional[DynChunkTrainConfig] = None, + ): + """Applies the convolution to an input tensor `x`. + + Arguments + --------- + x: torch.Tensor + Input tensor to the convolution module. + mask: torch.Tensor, optional + Mask to be applied over the output of the convolution using + `masked_fill_`, if specified. + dynchunktrain_config: DynChunkTrainConfig, optional + If specified, makes the module support Dynamic Chunk Convolution + (DCConv) as implemented by + `Dynamic Chunk Convolution for Unified Streaming and Non-Streaming Conformer ASR `_. + This allows masking future frames while preserving better accuracy + than a fully causal convolution, at a small speed cost. + This should only be used for training (or, if you know what you're + doing, for masked evaluation at inference time), as the forward + streaming function should be used at inference time. + """ + + if dynchunktrain_config is not None: + # chances are chunking+causal is unintended; i don't know where it + # may make sense, but if it does to you, feel free to implement it. + assert ( + not self.causal + ), "Chunked convolution not supported with causal padding" + + # in a causal convolution, which is not the case here, an output + # frame would never be able to depend on a input frame from any + # point in the future. + + # but with the dynamic chunk convolution, we instead use a "normal" + # convolution but where, for any output frame, the future beyond the + # "current" chunk gets masked. + # see the paper linked in the documentation for details. + + chunk_size = dynchunktrain_config.chunk_size + batch_size = x.shape[0] + chunk_left_context = self.padding + + chunk_count = int(math.ceil(x.shape[1] / chunk_size)) + + # determine the amount of padding we need to insert at the right of + # the last chunk so that all chunks end up with the same size. 
+ if x.shape[1] % chunk_size != 0: + final_right_padding = chunk_size - (x.shape[1] % chunk_size) + else: + final_right_padding = 0 + + # compute the left context that can and should be added, for each + # chunk. for the first few chunks, we will need to add extra padding + applied_left_context = [ + min(chunk_left_context, i * chunk_size) + for i in range(chunk_count) + ] + + # build views of chunks with left context (but no 0-padding yet) + # the left context is treated as if it were left padding: we do not + # want to keep any convolution results centered on the left context + out = [ + x[ + :, + (i * chunk_size - applied_left_context[i]) : ( + (i + 1) * chunk_size + ), + ..., + ] + for i in range(chunk_count) + ] + + # TODO: it should be possible to insert some padding to stack all + # the tensors at this level. currently, this is rather inefficient + # as this as to be called on every individual chunk. + + out = [self.layer_norm(chk) for chk in out] + out = [chk.transpose(1, 2) for chk in out] + out = [self.bottleneck(chk) for chk in out] + out = [chk.transpose(1, 2) for chk in out] + + # TODO: experiment around reflect padding, which is difficult + # because small chunks have too little time steps to reflect from + + # pad zeroes manually along the time axis + out = [ + F.pad( + out[i], + ( + # last channel is the channel dim, so do not insert any + # padding at the start or end of that dimension + 0, + 0, + # add missing left 0-padding if we lacked left context + chunk_left_context - applied_left_context[i], + # add missing right 0-padding as we disable default padding + # also add missing frames of the rightmost chunk + self.padding + + (final_right_padding if i == len(out) - 1 else 0), + ), + value=0, + ) + for i in range(len(out)) + ] + + # we pack together chunks in a single tensor so that we can feed it + # to the convolution directly. this is much more performant than + # doing the same with lists. + + # -> [batch_size, num_chunks, chunk_size + lc + rpad, in_channels] + out = torch.stack(out, dim=1) + + # -> [batch_size * num_chunks, chunk_size + lc + rpad, in_channels] + out = torch.flatten(out, end_dim=1) + + # for the convolution: + # -> [batch_size * num_chunks, in_channels, chunk_size + lc + rpad] + out = out.transpose(1, 2) + + # let's keep backwards compat by pointing at the weights from the + # already declared Conv1d. + # in the prior steps, we manually applied: + # - left padding (known left context + zeroes if necessary) + # - right padding (zeroes) + # hence we're fully disabling conv1d's own padding. 
+ # -> [batch_size * num_chunks, out_channels, chunk_size + rpad] + out = F.conv1d( + out, + weight=self.conv.weight, + bias=self.conv.bias, + stride=self.conv.stride, + padding=0, + dilation=self.conv.dilation, + groups=self.conv.groups, + ) + + # -> [batch_size * num_chunks, chunk_size + rpad, out_channels] + out = out.transpose(1, 2) + + out = self.after_conv(out) + + # -> [batch_size, num_chunks, chunk_size, out_channels] + out = torch.unflatten(out, dim=0, sizes=(batch_size, -1)) + + # -> [batch_size, time_steps + extra right padding, out_channels] + out = torch.flatten(out, start_dim=1, end_dim=2) + + # -> [batch_size, time_steps, out_channels] + if final_right_padding > 0: + out = out[:, :-final_right_padding, :] + else: + out = self.layer_norm(x) + out = out.transpose(1, 2) + out = self.bottleneck(out) + out = self.conv(out) + + if self.causal: + # chomp + out = out[..., : -self.padding] + + out = out.transpose(1, 2) + out = self.after_conv(out) - if self.causal: - # chomp - out = out[..., : -self.padding] - out = out.transpose(1, 2) - out = self.after_conv(out) if mask is not None: out.masked_fill_(mask, 0.0) + return out @@ -231,7 +435,8 @@ def forward( x, src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, - pos_embs: Optional[torch.Tensor] = None, + pos_embs: torch.Tensor = None, + dynchunktrain_config: Optional[DynChunkTrainConfig] = None, ): """ Arguments @@ -244,8 +449,12 @@ def forward( The mask for the src keys per batch. pos_embs: torch.Tensor, torch.nn.Module, optional Module or tensor containing the input sequence positional embeddings + dynchunktrain_config: Optional[DynChunkTrainConfig] + Dynamic Chunk Training configuration object for streaming, + specifically involved here to apply Dynamic Chunk Convolution to + the convolution module. """ - conv_mask = None + conv_mask: Optional[torch.Tensor] = None if src_key_padding_mask is not None: conv_mask = src_key_padding_mask.unsqueeze(-1) # ffn module @@ -253,6 +462,7 @@ def forward( # muti-head attention module skip = x x = self.norm1(x) + x, self_attn = self.mha_layer( x, x, @@ -263,11 +473,100 @@ def forward( ) x = x + skip # convolution module - x = x + self.convolution_module(x, conv_mask) + x = x + self.convolution_module( + x, conv_mask, dynchunktrain_config=dynchunktrain_config + ) # ffn module x = self.norm2(x + 0.5 * self.ffn_module2(x)) return x, self_attn + def forward_streaming( + self, + x, + context: ConformerEncoderLayerStreamingContext, + pos_embs: torch.Tensor = None, + ): + """Conformer layer streaming forward (typically for + DynamicChunkTraining-trained models), which is to be used at inference + time. Relies on a mutable context object as initialized by + `make_streaming_context` that should be used across chunks. + Invoked by `ConformerEncoder.forward_streaming`. + + Arguments + --------- + x : torch.Tensor + Input tensor for this layer. Batching is supported as long as you + keep the context consistent. + context: ConformerEncoderStreamingContext + Mutable streaming context; the same object should be passed across + calls. + pos_embs: torch.Tensor, optional + Positional embeddings, if used.""" + + orig_len = x.shape[-2] + # ffn module + x = x + 0.5 * self.ffn_module1(x) + + # TODO: make the approach for MHA left context more efficient. + # currently, this saves the inputs to the MHA. + # the naive approach is suboptimal in a few ways, namely that the + # outputs for this left padding is being re-computed even though we + # discard them immediately after. 
+ + # left pad `x` with our MHA left context + if context.mha_left_context is not None: + x = torch.cat((context.mha_left_context, x), dim=1) + + # compute new MHA left context for the next call to our function + if context.mha_left_context_size > 0: + context.mha_left_context = x[ + ..., -context.mha_left_context_size :, : + ] + + # multi-head attention module + skip = x + x = self.norm1(x) + + x, self_attn = self.mha_layer( + x, x, x, attn_mask=None, key_padding_mask=None, pos_embs=pos_embs, + ) + x = x + skip + + # truncate outputs corresponding to the MHA left context (we only care + # about our chunk's outputs); see above to-do + x = x[..., -orig_len:, :] + + if context.dcconv_left_context is not None: + x = torch.cat((context.dcconv_left_context, x), dim=1) + + # compute new DCConv left context for the next call to our function + context.dcconv_left_context = x[ + ..., -self.convolution_module.padding :, : + ] + + # convolution module + x = x + self.convolution_module(x) + + # truncate outputs corresponding to the DCConv left context + x = x[..., -orig_len:, :] + + # ffn module + x = self.norm2(x + 0.5 * self.ffn_module2(x)) + return x, self_attn + + def make_streaming_context(self, mha_left_context_size: int): + """Creates a blank streaming context for this encoding layer. + + Arguments + --------- + mha_left_context_size : int + How many left frames should be saved and used as left context to the + current chunk when streaming + """ + return ConformerEncoderLayerStreamingContext( + mha_left_context_size=mha_left_context_size + ) + class ConformerEncoder(nn.Module): """This class implements the Conformer encoder. @@ -355,6 +654,7 @@ def forward( src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, pos_embs: Optional[torch.Tensor] = None, + dynchunktrain_config: Optional[DynChunkTrainConfig] = None, ): """ Arguments @@ -369,8 +669,11 @@ def forward( Module or tensor containing the input sequence positional embeddings If custom pos_embs are given it needs to have the shape (1, 2*S-1, E) where S is the sequence length, and E is the embedding dimension. + dynchunktrain_config: Optional[DynChunkTrainConfig] + Dynamic Chunk Training configuration object for streaming, + specifically involved here to apply Dynamic Chunk Convolution to the + convolution module. """ - if self.attention_type == "RelPosMHAXL": if pos_embs is None: raise ValueError( @@ -385,12 +688,71 @@ def forward( src_mask=src_mask, src_key_padding_mask=src_key_padding_mask, pos_embs=pos_embs, + dynchunktrain_config=dynchunktrain_config, + ) + attention_lst.append(attention) + output = self.norm(output) + + return output, attention_lst + + def forward_streaming( + self, + src: torch.Tensor, + context: ConformerEncoderStreamingContext, + pos_embs: Optional[torch.Tensor] = None, + ): + """Conformer streaming forward (typically for + DynamicChunkTraining-trained models), which is to be used at inference + time. Relies on a mutable context object as initialized by + `make_streaming_context` that should be used across chunks. + + Arguments + --------- + src : torch.Tensor + Input tensor. Batching is supported as long as you keep the context + consistent. + context: ConformerEncoderStreamingContext + Mutable streaming context; the same object should be passed across + calls. 
+ pos_embs: torch.Tensor, optional + Positional embeddings, if used.""" + + if self.attention_type == "RelPosMHAXL": + if pos_embs is None: + raise ValueError( + "The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory" + ) + + output = src + attention_lst = [] + for i, enc_layer in enumerate(self.layers): + output, attention = enc_layer.forward_streaming( + output, pos_embs=pos_embs, context=context.layers[i] ) attention_lst.append(attention) output = self.norm(output) return output, attention_lst + def make_streaming_context(self, mha_left_context_size: int): + """Creates a blank streaming context for the encoder. + + Arguments + --------- + mha_left_context_size : int + How many left frames should be saved and used as left context to the + current chunk when streaming. This value is replicated across all + layers. + """ + return ConformerEncoderStreamingContext( + layers=[ + layer.make_streaming_context( + mha_left_context_size=mha_left_context_size + ) + for layer in self.layers + ] + ) + class ConformerDecoderLayer(nn.Module): """This is an implementation of Conformer encoder layer. diff --git a/speechbrain/lobes/models/transformer/Transformer.py b/speechbrain/lobes/models/transformer/Transformer.py index 5b56d4f17b..52bb367cdb 100644 --- a/speechbrain/lobes/models/transformer/Transformer.py +++ b/speechbrain/lobes/models/transformer/Transformer.py @@ -530,6 +530,7 @@ def forward( src_mask: Optional[torch.Tensor] = None, src_key_padding_mask: Optional[torch.Tensor] = None, pos_embs: Optional[torch.Tensor] = None, + dynchunktrain_config=None, ): """ Arguments @@ -541,6 +542,10 @@ def forward( src_key_padding_mask : tensor The mask for the src keys per batch (optional). """ + assert ( + dynchunktrain_config is None + ), "Dynamic Chunk Training unsupported for this encoder" + output = src if self.layerdrop_prob > 0.0: keep_probs = self.rng.random(len(self.layers)) diff --git a/speechbrain/lobes/models/transformer/TransformerASR.py b/speechbrain/lobes/models/transformer/TransformerASR.py index c346240438..05a9d2e2be 100755 --- a/speechbrain/lobes/models/transformer/TransformerASR.py +++ b/speechbrain/lobes/models/transformer/TransformerASR.py @@ -4,9 +4,10 @@ * Jianyuan Zhong 2020 """ +from dataclasses import dataclass import torch # noqa 42 from torch import nn -from typing import Optional +from typing import Any, Optional from speechbrain.nnet.linear import Linear from speechbrain.nnet.containers import ModuleList from speechbrain.lobes.models.transformer.Transformer import ( @@ -17,6 +18,147 @@ ) from speechbrain.nnet.activations import Swish from speechbrain.dataio.dataio import length_to_mask +from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig + + +@dataclass +class TransformerASRStreamingContext: + """Streaming metadata and state for a `TransformerASR` instance.""" + + dynchunktrain_config: DynChunkTrainConfig + """Dynamic Chunk Training configuration holding chunk size and context size + information.""" + + encoder_context: Any + """Opaque encoder context information. It is constructed by the encoder's + `make_streaming_context` method and is passed to the encoder when using + `encode_streaming`. 
+ """ + + +def make_transformer_src_mask( + src: torch.Tensor, + causal: bool = False, + dynchunktrain_config: Optional[DynChunkTrainConfig] = None, +) -> Optional[torch.Tensor]: + """Prepare the source transformer mask that restricts which frames can + attend to which frames depending on causal or other simple restricted + attention methods. + + Arguments + --------- + src: torch.Tensor + The source tensor to build a mask from. The contents of the tensor are + not actually used currently; only its shape and other metadata (e.g. + device). + + causal: bool + Whether strict causality shall be used. Frames will not be able to + attend to any future frame. + + dynchunktrain_config: DynChunkTrainConfig, optional + Dynamic Chunk Training configuration. This implements a simple form of + chunkwise attention. Incompatible with `causal`.""" + + if causal: + assert dynchunktrain_config is None + return get_lookahead_mask(src) + + if dynchunktrain_config is not None: + # init a mask that masks nothing by default + # 0 == no mask, 1 == mask + src_mask = torch.zeros( + (src.shape[1], src.shape[1]), device=src.device, dtype=torch.bool, + ) + + # The following is not really the sole source used to implement this, + # but it helps introduce the concept. + # ref: Unified Streaming and Non-streaming Two-pass End-to-end Model + # for Speech Recognition + # https://arxiv.org/pdf/2012.05481.pdf + + timesteps = src.size(1) + + # mask the future at the right of each chunk + for t in range(timesteps): + # if we have a chunk size of 8 then: + # for 0..7 -> mask 8.. + # for 8..15 -> mask 16.. + # etc. + next_chunk_index = (t // dynchunktrain_config.chunk_size) + 1 + visible_range = next_chunk_index * dynchunktrain_config.chunk_size + src_mask[t, visible_range:] = True + + # mask the past at the left of each chunk (accounting for left context) + # only relevant if using left context + if not dynchunktrain_config.is_infinite_left_context(): + for t in range(timesteps): + chunk_index = t // dynchunktrain_config.chunk_size + chunk_first_t = chunk_index * dynchunktrain_config.chunk_size + + left_context_frames = ( + dynchunktrain_config.left_context_size + * dynchunktrain_config.chunk_size + ) + + frame_remaining_context = max( + 0, chunk_first_t - left_context_frames, + ) + + # end range is exclusive, so there is no off-by-one here + src_mask[t, :frame_remaining_context] = True + + return src_mask + + return None + + +def make_transformer_src_tgt_masks( + src, + tgt=None, + wav_len=None, + pad_idx=0, + causal: bool = False, + dynchunktrain_config: Optional[DynChunkTrainConfig] = None, +): + """This function generates masks for training the transformer model, + opiniated for an ASR context with encoding masks and, optionally, decoding + masks (if specifying `tgt`). + + Arguments + --------- + src : tensor + The sequence to the encoder (required). + tgt : tensor + The sequence to the decoder. + pad_idx : int + The index for token (default=0). + causal: bool + Whether strict causality shall be used. See `make_asr_src_mask` + dynchunktrain_config: DynChunkTrainConfig, optional + Dynamic Chunk Training configuration. 
See `make_asr_src_mask` + """ + src_key_padding_mask = None + + # mask out audio beyond the length of audio for each batch + if wav_len is not None: + abs_len = torch.round(wav_len * src.shape[1]) + src_key_padding_mask = ~length_to_mask(abs_len).bool() + + # mask out the source + src_mask = make_transformer_src_mask( + src, causal=causal, dynchunktrain_config=dynchunktrain_config + ) + + # If no decoder in the transformer... + if tgt is not None: + tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx) + tgt_mask = get_lookahead_mask(tgt) + else: + tgt_key_padding_mask = None + tgt_mask = None + + return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask class TransformerASR(TransformerInterface): @@ -152,9 +294,11 @@ def __init__( ), torch.nn.Dropout(dropout), ) - self.custom_tgt_module = ModuleList( - NormalizedEmbedding(d_model, tgt_vocab) - ) + + if num_decoder_layers > 0: + self.custom_tgt_module = ModuleList( + NormalizedEmbedding(d_model, tgt_vocab) + ) # reset parameters using xavier_normal_ self._init_params() @@ -183,7 +327,9 @@ def forward(self, src, tgt, wav_len=None, pad_idx=0): tgt_key_padding_mask, src_mask, tgt_mask, - ) = self.make_masks(src, tgt, wav_len, pad_idx=pad_idx) + ) = make_transformer_src_tgt_masks( + src, tgt, wav_len, causal=self.causal, pad_idx=pad_idx + ) src = self.custom_src_module(src) # add pos encoding to queries if are sinusoidal ones else @@ -229,38 +375,6 @@ def forward(self, src, tgt, wav_len=None, pad_idx=0): return encoder_out, decoder_out - def make_masks(self, src, tgt=None, wav_len=None, pad_idx=0): - """This method generates the masks for training the transformer model. - - Arguments - --------- - src : tensor - The sequence to the encoder (required). - tgt : tensor - The sequence to the decoder. - pad_idx : int - The index for token (default=0). - """ - src_key_padding_mask = None - if wav_len is not None: - abs_len = torch.round(wav_len * src.shape[1]) - src_key_padding_mask = ~length_to_mask(abs_len).bool() - - src_mask = None - - if self.causal: - src_mask = get_lookahead_mask(src) - - # If no decoder in the transformer... - if tgt is not None: - tgt_key_padding_mask = get_key_padding_mask(tgt, pad_idx=pad_idx) - tgt_mask = get_lookahead_mask(tgt) - else: - tgt_key_padding_mask = None - tgt_mask = None - - return src_key_padding_mask, tgt_key_padding_mask, src_mask, tgt_mask - @torch.no_grad() def decode(self, tgt, encoder_out, enc_len=None): """This method implements a decoding step for the transformer model. 
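To make the chunkwise attention restriction concrete, here is a small illustrative sketch of the mask returned by `make_transformer_src_mask` for a toy configuration (6 frames, `chunk_size=2`, `left_context_size=1`). The expected values in the comments follow directly from the loops above and are shown for illustration rather than as a test fixture.

```python
# Toy check of the chunked attention mask (True == "may not attend").
import torch
from speechbrain.lobes.models.transformer.TransformerASR import (
    make_transformer_src_mask,
)
from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig

# Only the shape/device of `src` matter to the mask builder.
src = torch.rand(1, 6, 64)
mask = make_transformer_src_mask(
    src,
    dynchunktrain_config=DynChunkTrainConfig(chunk_size=2, left_context_size=1),
)
print(mask.int())
# tensor([[0, 0, 1, 1, 1, 1],   <- frames 0-1: own chunk only (no left context yet)
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1],   <- frames 2-3: chunk 1 + 1 chunk of left context
#         [0, 0, 0, 0, 1, 1],
#         [1, 1, 0, 0, 0, 0],   <- frames 4-5: chunk 2 + 1 chunk of left context;
#         [1, 1, 0, 0, 0, 0]])  #    chunk 0 is now masked out
```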
@@ -302,7 +416,13 @@ def decode(self, tgt, encoder_out, enc_len=None): ) return prediction, multihead_attns[-1] - def encode(self, src, wav_len=None, pad_idx=0): + def encode( + self, + src, + wav_len=None, + pad_idx=0, + dynchunktrain_config: Optional[DynChunkTrainConfig] = None, + ): """ Encoder forward pass @@ -318,8 +438,18 @@ def encode(self, src, wav_len=None, pad_idx=0): bz, t, ch1, ch2 = src.shape src = src.reshape(bz, t, ch1 * ch2) - (src_key_padding_mask, _, src_mask, _,) = self.make_masks( - src, None, wav_len, pad_idx=pad_idx + ( + src_key_padding_mask, + _, + src_mask, + _, + ) = make_transformer_src_tgt_masks( + src, + None, + wav_len, + pad_idx=pad_idx, + causal=self.causal, + dynchunktrain_config=dynchunktrain_config, ) src = self.custom_src_module(src) @@ -333,11 +463,134 @@ def encode(self, src, wav_len=None, pad_idx=0): encoder_out, _ = self.encoder( src=src, + src_mask=src_mask, src_key_padding_mask=src_key_padding_mask, pos_embs=pos_embs_source, + dynchunktrain_config=dynchunktrain_config, + ) + + return encoder_out + + def encode_streaming(self, src, context: TransformerASRStreamingContext): + """ + Streaming encoder forward pass + + Arguments + --------- + src : torch.Tensor + The sequence (chunk) to the encoder. + + context : TransformerASRStreamingContext + Mutable reference to the streaming context. This holds the state + needed to persist across chunk inferences and can be built using + `make_streaming_context`. This will get mutated by this function. + + Returns + ------- + Encoder output for this chunk. + + Example + ------- + >>> import torch + >>> from speechbrain.lobes.models.transformer.TransformerASR import TransformerASR + >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig + >>> net = TransformerASR( + ... tgt_vocab=100, + ... input_size=64, + ... d_model=64, + ... nhead=8, + ... num_encoder_layers=1, + ... num_decoder_layers=0, + ... d_ffn=128, + ... attention_type="RelPosMHAXL", + ... positional_encoding=None, + ... encoder_module="conformer", + ... normalize_before=True, + ... causal=False, + ... ) + >>> ctx = net.make_streaming_context( + ... DynChunkTrainConfig(16, 24), + ... encoder_kwargs={"mha_left_context_size": 24}, + ... ) + >>> src1 = torch.rand([8, 16, 64]) + >>> src2 = torch.rand([8, 16, 64]) + >>> out1 = net.encode_streaming(src1, ctx) + >>> out1.shape + torch.Size([8, 16, 64]) + >>> ctx.encoder_context.layers[0].mha_left_context.shape + torch.Size([8, 16, 64]) + >>> out2 = net.encode_streaming(src2, ctx) + >>> out2.shape + torch.Size([8, 16, 64]) + >>> ctx.encoder_context.layers[0].mha_left_context.shape + torch.Size([8, 24, 64]) + >>> combined_out = torch.concat((out1, out2), dim=1) + >>> combined_out.shape + torch.Size([8, 32, 64]) + """ + + if src.dim() == 4: + bz, t, ch1, ch2 = src.shape + src = src.reshape(bz, t, ch1 * ch2) + + # HACK: our problem here is that the positional_encoding is computed + # against the size of our source tensor, but we only know how many left + # context frames we're injecting to the encoder within the encoder + # context. + # so this workaround does just that. + # + # i'm not sure how this would be best refactored, but an option would be + # to let the encoder get the pos embedding itself and have a way to + # cache it. + # + # additionally, positional encoding functions take in a whole source + # tensor just to get its attributes (size, device, type) but this is + # sort of silly for the embeddings that don't need one. + # so we craft a dummy empty (uninitialized) tensor to help... 
+ known_left_context = context.encoder_context.layers[0].mha_left_context + if known_left_context is None: + pos_encoding_dummy = src + else: + target_shape = list(src.shape) + target_shape[-2] += known_left_context.shape[-2] + pos_encoding_dummy = torch.empty(size=target_shape).to(src) + + src = self.custom_src_module(src) + if self.attention_type == "RelPosMHAXL": + pos_embs_source = self.positional_encoding(pos_encoding_dummy) + + elif self.positional_encoding_type == "fixed_abs_sine": + src = src + self.positional_encoding(pos_encoding_dummy) + pos_embs_source = None + + encoder_out, _ = self.encoder.forward_streaming( + src=src, pos_embs=pos_embs_source, context=context.encoder_context ) return encoder_out + def make_streaming_context( + self, dynchunktrain_config: DynChunkTrainConfig, encoder_kwargs={} + ): + """Creates a blank streaming context for this transformer and its + encoder. + + Arguments + --------- + dynchunktrain_config : DynChunkTrainConfig + Runtime chunkwise attention configuration. + + encoder_kwargs : dict + Parameters to be forward to the encoder's `make_streaming_context`. + Metadata required for the encoder could differ depending on the + encoder. + """ + return TransformerASRStreamingContext( + dynchunktrain_config=dynchunktrain_config, + encoder_context=self.encoder.make_streaming_context( + **encoder_kwargs, + ), + ) + def _init_params(self): for p in self.parameters(): if p.dim() > 1: @@ -373,7 +626,7 @@ def __init__(self, transformer, *args, **kwargs): super().__init__(*args, **kwargs) self.transformer = transformer - def forward(self, x, wav_lens=None, pad_idx=0): + def forward(self, x, wav_lens=None, pad_idx=0, **kwargs): """ Processes the input tensor x and returns an output tensor.""" - x = self.transformer.encode(x, wav_lens, pad_idx) + x = self.transformer.encode(x, wav_lens, pad_idx, **kwargs,) return x diff --git a/speechbrain/nnet/attention.py b/speechbrain/nnet/attention.py index 528522f673..f6351dd690 100644 --- a/speechbrain/nnet/attention.py +++ b/speechbrain/nnet/attention.py @@ -591,17 +591,28 @@ def forward( query + self.pos_bias_v.view(1, 1, self.num_heads, self.head_dim) ).transpose(1, 2) + # Moved the `* self.scale` mul from after the `attn_score` sum to prior + # to the matmul in order to lower overflow risks on fp16. + # This change is inspired by the following paper, but no other changes + # were ported from there so far. 
+ # ref: E.T.: Re-Thinking Self-Attention for Transformer Models on GPUs + # https://asherliu.github.io/docs/sc21a.pdf + # (batch, head, qlen, klen) - matrix_ac = torch.matmul(q_with_bias_u, key.permute(0, 2, 3, 1)) + matrix_ac = torch.matmul( + q_with_bias_u * self.scale, key.permute(0, 2, 3, 1) + ) # (batch, num_heads, klen, 2*klen-1) - matrix_bd = torch.matmul(q_with_bias_v, p_k.permute(0, 2, 3, 1)) + matrix_bd = torch.matmul( + q_with_bias_v * self.scale, p_k.permute(0, 2, 3, 1) + ) matrix_bd = self.rel_shift(matrix_bd) # shifting trick # if klen != qlen: # import ipdb # ipdb.set_trace( - attn_score = (matrix_ac + matrix_bd) * self.scale + attn_score = matrix_ac + matrix_bd # already scaled above # compute attention probability if attn_mask is not None: @@ -622,8 +633,25 @@ def forward( key_padding_mask.view(bsz, 1, 1, klen), self.attn_fill_value, ) - attn_score = F.softmax(attn_score, dim=-1) + attn_score = F.softmax(attn_score, dim=-1, dtype=torch.float32) attn_score = self.dropout_att(attn_score) + + # it is possible for us to hit full NaN when using chunked training + # so reapply masks, except with 0.0 instead as we are after the softmax + # because -inf would output 0.0 regardless anyway + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_score = attn_score.masked_fill(attn_mask, 0.0) + else: + # NOTE: the above fix is not implemented for this case as + # summing the mask with NaN would still result in NaN + pass + + if key_padding_mask is not None: + attn_score = attn_score.masked_fill( + key_padding_mask.view(bsz, 1, 1, klen), 0.0, + ) + x = torch.matmul( attn_score, value.transpose(1, 2) ) # (batch, head, time1, d_k) diff --git a/speechbrain/utils/dynamic_chunk_training.py b/speechbrain/utils/dynamic_chunk_training.py new file mode 100644 index 0000000000..5f0d8892d7 --- /dev/null +++ b/speechbrain/utils/dynamic_chunk_training.py @@ -0,0 +1,160 @@ +"""Configuration and utility classes for classes for Dynamic Chunk Training, as +often used for the training of streaming-capable models in speech recognition. + +The definition of Dynamic Chunk Training is based on that of the following +paper, though a lot of the literature refers to the same definition: +https://arxiv.org/abs/2012.05481 + +Authors +* Sylvain de Langen 2023 +""" + +import speechbrain as sb +from dataclasses import dataclass +from typing import Optional + +import torch + +# NOTE: this configuration object is intended to be relatively specific to +# Dynamic Chunk Training; if you want to implement a different similar type of +# chunking different from that, you should consider using a different object. +@dataclass +class DynChunkTrainConfig: + """Dynamic Chunk Training configuration object for use with transformers, + often in ASR for streaming. + + This object may be used both to configure masking at training time and for + run-time configuration of DynChunkTrain-ready models.""" + + chunk_size: int + """Size in frames of a single chunk, always `>0`. + If chunkwise streaming should be disabled at some point, pass an optional + streaming config parameter.""" + + left_context_size: Optional[int] = None + """Number of *chunks* (not frames) visible to the left, always `>=0`. + If zero, then chunks can never attend to any past chunk. + If `None`, the left context is infinite (but use + `.is_fininite_left_context` for such a check).""" + + def is_infinite_left_context(self) -> bool: + """Returns true if the left context is infinite (i.e. 
any chunk can + attend to any past frame).""" + return self.left_context_size is None + + +@dataclass +class DynChunkTrainConfigRandomSampler: + """Helper class to generate a DynChunkTrainConfig at runtime depending on the current + stage. + + Example + ------- + >>> from speechbrain.core import Stage + >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig + >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfigRandomSampler + >>> # for the purpose of this example, we test a scenario with a 100% + >>> # chance of the (24, None) scenario to occur + >>> sampler = DynChunkTrainConfigRandomSampler( + ... chunkwise_prob=1.0, + ... chunk_size_min=24, + ... chunk_size_max=24, + ... limited_left_context_prob=0.0, + ... left_context_chunks_min=16, + ... left_context_chunks_max=16, + ... test_config=DynChunkTrainConfig(32, 16), + ... valid_config=None + ... ) + >>> one_train_config = sampler(Stage.TRAIN) + >>> one_train_config + DynChunkTrainConfig(chunk_size=24, left_context_size=None) + >>> one_train_config.is_infinite_left_context() + True + >>> sampler(Stage.TEST) + DynChunkTrainConfig(chunk_size=32, left_context_size=16)""" + + chunkwise_prob: float + """When sampling (during `Stage.TRAIN`), the probability that a finite chunk + size will be used. + In the other case, any chunk can attend to the full past and future + context.""" + + chunk_size_min: int + """When sampling a random chunk size, the minimum chunk size that can be + picked.""" + + chunk_size_max: int + """When sampling a random chunk size, the maximum chunk size that can be + picked.""" + + limited_left_context_prob: float + """When sampling a random chunk size, the probability that the left context + will be limited. + In the other case, any chunk can attend to the full past context.""" + + left_context_chunks_min: int + """When sampling a random left context size, the minimum number of left + context chunks that can be picked.""" + + left_context_chunks_max: int + """When sampling a random left context size, the maximum number of left + context chunks that can be picked.""" + + test_config: Optional[DynChunkTrainConfig] = None + """The configuration that should be used for `Stage.TEST`. + When `None`, evaluation is done with full context (i.e. non-streaming).""" + + valid_config: Optional[DynChunkTrainConfig] = None + """The configuration that should be used for `Stage.VALID`. + When `None`, evaluation is done with full context (i.e. non-streaming).""" + + def _sample_bool(self, prob: float) -> bool: + """Samples a random boolean with a probability, in a way that depends on + PyTorch's RNG seed. + + Arguments + --------- + prob : float + Probability (0..1) to return True (False otherwise).""" + return torch.rand((1,)).item() < prob + + def __call__(self, stage: "sb.core.Stage") -> DynChunkTrainConfig: + """In training stage, samples a random DynChunkTrain configuration. + During validation or testing, returns the relevant configuration. + + Arguments + --------- + stage : speechbrain.core.Stage + Current stage of training or evaluation. + In training mode, a random DynChunkTrainConfig will be sampled + according to the specified probabilities and ranges. + During evaluation, the relevant DynChunkTrainConfig attribute will + be picked. 
+ """ + if stage == sb.core.Stage.TRAIN: + # When training for streaming, for each batch, we have a + # `dynamic_chunk_prob` probability of sampling a chunk size + # between `dynamic_chunk_min` and `_max`, otherwise output + # frames can see anywhere in the future. + if self._sample_bool(self.chunkwise_prob): + chunk_size = torch.randint( + self.chunk_size_min, self.chunk_size_max + 1, (1,), + ).item() + + if self._sample_bool(self.limited_left_context_prob): + left_context_chunks = torch.randint( + self.left_context_chunks_min, + self.left_context_chunks_max + 1, + (1,), + ).item() + else: + left_context_chunks = None + + return DynChunkTrainConfig(chunk_size, left_context_chunks) + return None + elif stage == sb.core.Stage.TEST: + return self.test_config + elif stage == sb.core.Stage.VALID: + return self.valid_config + else: + raise AttributeError(f"Unsupported stage found {stage}") diff --git a/speechbrain/utils/streaming.py b/speechbrain/utils/streaming.py new file mode 100644 index 0000000000..336f6e197a --- /dev/null +++ b/speechbrain/utils/streaming.py @@ -0,0 +1,233 @@ +"""Utilities to assist with designing and training streaming models. + +Authors +* Sylvain de Langen 2023 +""" + +import math +import torch +from typing import Callable, List + + +def split_fixed_chunks( + x: torch.Tensor, chunk_size: int, dim: int = -1 +) -> List[torch.Tensor]: + """Split an input tensor `x` into a list of chunk tensors of size + `chunk_size` alongside dimension `dim`. + Useful for splitting up sequences with chunks of fixed sizes. + + If dimension `dim` cannot be evenly split by `chunk_size`, then the last + chunk will be smaller than `chunk_size`. + + Arguments + --------- + x : torch.Tensor + The tensor to split into chunks, typically a sequence or audio signal. + + chunk_size : int + The size of each chunk, i.e. the max size of each chunk on dimension + `dim`. + + dim : int + Dimension to split alongside of, typically the time dimension. + + Returns + ------- + List[torch.Tensor] + A chunk list of tensors, see description and example. + Guarantees `.size(dim) <= chunk_size`. + + Example + ------- + >>> import torch + >>> from speechbrain.utils.streaming import split_fixed_chunks + >>> x = torch.zeros((16, 10000, 80)) + >>> chunks = split_fixed_chunks(x, 128, dim=1) + >>> len(chunks) + 79 + >>> chunks[0].shape + torch.Size([16, 128, 80]) + >>> chunks[-1].shape + torch.Size([16, 16, 80]) + """ + + num_chunks = math.ceil(x.size(dim) / chunk_size) + split_at_indices = [(i + 1) * chunk_size for i in range(num_chunks - 1)] + return torch.tensor_split(x, split_at_indices, dim=1) + + +def split_wav_lens( + chunk_lens: List[int], wav_lens: torch.Tensor +) -> List[torch.Tensor]: + """Converts a single `wav_lens` tensor into a list of `chunk_count` tensors, + typically useful when chunking signals with `split_fixed_chunks`. + + `wav_lens` represents the relative length of each audio within a batch, + which is typically used for masking. This function computes the relative + length at chunk level. + + Arguments + --------- + chunk_lens : List[int] + Length of the sequence of every chunk. For example, if `chunks` was + returned from `split_fixed_chunks(x, chunk_size, dim=1)`, then this + should be `[chk.size(1) for chk in chunks]`. + + wav_lens : torch.Tensor + Relative lengths of audio within a batch. For example, for an input + signal of 100 frames and a batch of 3 elements, `(1.0, 0.5, 0.25)` + would mean the batch holds audio of 100 frames, 50 frames and 25 frames + respectively. 
+ + Returns + ------- + List[torch.Tensor] + A list of chunked wav_lens, see description and example. + + Example + ------- + >>> import torch + >>> from speechbrain.utils.streaming import split_wav_lens, split_fixed_chunks + >>> x = torch.zeros((3, 20, 80)) + >>> chunks = split_fixed_chunks(x, 8, dim=1) + >>> len(chunks) + 3 + >>> # 20 frames, 13 frames, 17 frames + >>> wav_lens = torch.tensor([1.0, 0.65, 0.85]) + >>> chunked_wav_lens = split_wav_lens([c.size(1) for c in chunks], wav_lens) + >>> chunked_wav_lens + [tensor([1., 1., 1.]), tensor([1.0000, 0.6250, 1.0000]), tensor([1.0000, 0.0000, 0.2500])] + >>> # wav 1 covers 62.5% (5/8) of the second chunk's frames + """ + + chunk_wav_lens = [] + + seq_size = sum(chunk_lens) + wav_lens_frames = wav_lens * seq_size + + chunk_start_frame = 0 + for chunk_len in chunk_lens: + chunk_raw_len = (wav_lens_frames - chunk_start_frame) / chunk_len + chunk_raw_len = torch.clamp(chunk_raw_len, 0.0, 1.0) + chunk_wav_lens.append(chunk_raw_len) + + chunk_start_frame += chunk_len + + return chunk_wav_lens + + +def infer_dependency_matrix( + model: Callable, seq_shape: tuple, in_stride: int = 1 +): + """ + Randomizes parts of the input sequence several times in order to detect + dependencies between input frames and output frames, aka whether a given + output frame depends on a given input frame. + + This can prove useful to check whether a model behaves correctly in a + streaming context and does not contain accidental dependencies to future + frames that couldn't be known in a streaming scenario. + + Note that this can get very computationally expensive for very long + sequences. + + Furthermore, this expects inference to be fully deterministic, else false + dependencies may be found. This also means that the model must be in eval + mode, to inhibit things like dropout layers. + + Arguments + --------- + model : Callable + Can be a model or a function (potentially emulating streaming + functionality). Does not require to be a trained model, random weights + should usually suffice. + seq_shape : tuple + The function tries inferring by randomizing parts of the input sequence + in order to detect unwanted dependencies. + The shape is expected to look like `[batch_size, seq_len, num_feats]`, + where `batch_size` may be `1`. + in_stride : int + Consider only N-th input, for when the input sequences are very long + (e.g. raw audio) and the output is shorter (subsampled, filters, etc.) + + Returns + ------- + dependencies : torch.BoolTensor + Matrix representing whether an output is dependent on an input; index + using `[in_frame_idx, out_frame_idx]`. `True` indicates a detected + dependency. + """ + # TODO: document arguments + + bs, seq_len, feat_len = seq_shape + + base_seq = torch.rand(seq_shape) + with torch.no_grad(): + base_out = model(base_seq) + + if not model(base_seq).equal(base_out): + raise ValueError( + "Expected deterministic model, but inferring twice on the same " + "data yielded different results. Make sure that you use " + "`eval()` mode so that it does not include randomness." 
+ ) + out_len, _out_feat_len = base_out.shape[1:] + + deps = torch.zeros( + ((seq_len + (in_stride - 1)) // in_stride, out_len), dtype=torch.bool + ) + + for in_frame_idx in range(0, seq_len, in_stride): + test_seq = base_seq.clone() + test_seq[:, in_frame_idx, :] = torch.rand(bs, feat_len) + + with torch.no_grad(): + test_out = model(test_seq) + + for out_frame_idx in range(out_len): + if not torch.allclose( + test_out[:, out_frame_idx, :], base_out[:, out_frame_idx, :] + ): + deps[in_frame_idx // in_stride][out_frame_idx] = True + + return deps + + +def plot_dependency_matrix(deps): + """ + Returns a matplotlib figure of a dependency matrix generated by + `infer_dependency_matrix`. + + At a given point, a red square indicates that a given output frame (y-axis) + was to depend on a given input frame (x-axis). + + For example, a fully red image means that all output frames were dependent + on all the history. This could be the case of a bidirectional RNN, or a + transformer model, for example. + + Arguments + --------- + deps : torch.BoolTensor + Matrix returned by `infer_dependency_matrix` or one in a compatible + format. + """ + import matplotlib.pyplot as plt + from matplotlib.colors import ListedColormap + + cmap = ListedColormap(["white", "red"]) + + fig, ax = plt.subplots() + + ax.pcolormesh( + torch.permute(deps, (1, 0)), + cmap=cmap, + vmin=False, + vmax=True, + edgecolors="gray", + linewidth=0.5, + ) + ax.set_title("Dependency plot") + ax.set_xlabel("in") + ax.set_ylabel("out") + ax.set_aspect("equal") + return fig diff --git a/tests/integration/ASR_ConformerTransducer_streaming/example_asr_conformertransducer_streaming_experiment.py b/tests/integration/ASR_ConformerTransducer_streaming/example_asr_conformertransducer_streaming_experiment.py new file mode 100644 index 0000000000..893c511397 --- /dev/null +++ b/tests/integration/ASR_ConformerTransducer_streaming/example_asr_conformertransducer_streaming_experiment.py @@ -0,0 +1,281 @@ +#!/usr/bin/env/python3 +"""This minimal example trains a RNNT-based speech recognizer on a tiny dataset. +The encoder is based on a Conformer model with the use of Dynamic Chunk Training + (with a Dynamic Chunk Convolution within the convolution modules) that predict +phonemes. A greedy search is used on top of the output probabilities. +Given the tiny dataset, the expected behavior is to overfit the training dataset +(with a validation performance that stays high). +""" +import pathlib +import speechbrain as sb +from hyperpyyaml import load_hyperpyyaml +import torch + + +class ConformerTransducerBrain(sb.Brain): + def compute_forward(self, batch, stage): + """Forward computations from the waveform batches to the output probabilities.""" + batch = batch.to(self.device) + wavs, wav_lens = batch.sig + phn_with_bos, phn_with_bos_lens = batch.phn_encoded_bos + + # Add waveform augmentation if specified. + if stage == sb.Stage.TRAIN: + if hasattr(self.hparams, "wav_augment"): + wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens) + phn_with_bos = self.hparams.wav_augment.replicate_labels( + phn_with_bos + ) + + feats = self.hparams.compute_features(wavs) + + # Add feature augmentation if specified. 
+ if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"): + feats, fea_lens = self.hparams.fea_augment(feats, wav_lens) + phn_with_bos = self.hparams.fea_augment.replicate_labels( + phn_with_bos + ) + + current_epoch = self.hparams.epoch_counter.current + + # Old models may not have the streaming hparam, we don't break them in + # any other way so just check for its presence + if hasattr(self.hparams, "streaming") and self.hparams.streaming: + dynchunktrain_config = self.hparams.dynchunktrain_config_sampler( + stage + ) + else: + dynchunktrain_config = None + + feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch) + + src = self.modules.CNN(feats) + x = self.modules.enc( + src, + wav_lens, + pad_idx=self.hparams.pad_index, + dynchunktrain_config=dynchunktrain_config, + ) + x = self.modules.proj_enc(x) + + e_in = self.modules.emb(phn_with_bos) + e_in = torch.nn.functional.dropout( + e_in, + self.hparams.dec_emb_dropout, + training=(stage == sb.Stage.TRAIN), + ) + h, _ = self.modules.dec(e_in) + h = torch.nn.functional.dropout( + h, self.hparams.dec_dropout, training=(stage == sb.Stage.TRAIN) + ) + h = self.modules.proj_dec(h) + + # Joint network + # add labelseq_dim to the encoder tensor: [B,T,H_enc] => [B,T,1,H_enc] + # add timeseq_dim to the decoder tensor: [B,U,H_dec] => [B,1,U,H_dec] + joint = self.modules.Tjoint(x.unsqueeze(2), h.unsqueeze(1)) + + # Output layer for transducer log-probabilities + logits_transducer = self.modules.transducer_lin(joint) + + # Compute outputs + if stage == sb.Stage.TRAIN: + p_ctc = None + p_ce = None + + if self.hparams.ctc_weight > 0.0: + # Output layer for ctc log-probabilities + out_ctc = self.modules.proj_ctc(x) + p_ctc = self.hparams.log_softmax(out_ctc) + + if self.hparams.ce_weight > 0.0: + # Output layer for ctc log-probabilities + p_ce = self.modules.dec_lin(h) + p_ce = self.hparams.log_softmax(p_ce) + + return p_ctc, p_ce, logits_transducer, wav_lens + + best_hyps, scores, _, _ = self.hparams.Greedysearcher(x) + return logits_transducer, wav_lens, best_hyps + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss (Transducer+(CTC+NLL)) given predictions and targets.""" + + ids = batch.id + phn, phn_lens = batch.phn_encoded + phn_with_eos, phn_with_eos_lens = batch.phn_encoded_eos + + # Train returns 4 elements vs 3 for val and test + if len(predictions) == 4: + p_ctc, p_ce, logits_transducer, wav_lens = predictions + else: + logits_transducer, wav_lens, predicted_phn = predictions + + if stage == sb.Stage.TRAIN: + if hasattr(self.hparams, "wav_augment"): + phn = self.hparams.wav_augment.replicate_labels(phn) + phn_lens = self.hparams.wav_augment.replicate_labels(phn_lens) + phn_with_eos = self.hparams.wav_augment.replicate_labels( + phn_with_eos + ) + phn_with_eos_lens = self.hparams.wav_augment.replicate_labels( + phn_with_eos_lens + ) + if hasattr(self.hparams, "fea_augment"): + phn = self.hparams.fea_augment.replicate_labels(phn) + phn_lens = self.hparams.fea_augment.replicate_labels(phn_lens) + phn_with_eos = self.hparams.fea_augment.replicate_labels( + phn_with_eos + ) + phn_with_eos_lens = self.hparams.fea_augment.replicate_labels( + phn_with_eos_lens + ) + + if stage == sb.Stage.TRAIN: + CTC_loss = 0.0 + CE_loss = 0.0 + if p_ctc is not None: + CTC_loss = self.hparams.ctc_cost(p_ctc, phn, wav_lens, phn_lens) + if p_ce is not None: + CE_loss = self.hparams.ce_cost( + p_ce, phn_with_eos, length=phn_with_eos_lens + ) + loss_transducer = self.hparams.transducer_cost( + logits_transducer, 
phn, wav_lens, phn_lens + ) + loss = ( + self.hparams.ctc_weight * CTC_loss + + self.hparams.ce_weight * CE_loss + + (1 - (self.hparams.ctc_weight + self.hparams.ce_weight)) + * loss_transducer + ) + else: + loss = self.hparams.transducer_cost( + logits_transducer, phn, wav_lens, phn_lens + ) + + if stage != sb.Stage.TRAIN: + self.per_metrics.append( + ids, predicted_phn, phn, target_len=phn_lens + ) + + return loss + + def on_stage_start(self, stage, epoch=None): + "Gets called when a stage (either training, validation, test) starts." + if stage != sb.Stage.TRAIN: + self.per_metrics = self.hparams.per_stats() + + def on_stage_end(self, stage, stage_loss, epoch=None): + """Gets called at the end of a stage.""" + if stage == sb.Stage.TRAIN: + self.train_loss = stage_loss + if stage == sb.Stage.VALID and epoch is not None: + print("Epoch %d complete" % epoch) + print("Train loss: %.2f" % self.train_loss) + if stage != sb.Stage.TRAIN: + print(stage, "loss: %.2f" % stage_loss) + print(stage, "PER: %.2f" % self.per_metrics.summarize("error_rate")) + + +def data_prep(data_folder, hparams): + "Creates the datasets and their data processing pipelines." + + # 1. Declarations: + train_data = sb.dataio.dataset.DynamicItemDataset.from_json( + json_path=data_folder / "../annotation/ASR_train.json", + replacements={"data_root": data_folder}, + ) + valid_data = sb.dataio.dataset.DynamicItemDataset.from_json( + json_path=data_folder / "../annotation/ASR_dev.json", + replacements={"data_root": data_folder}, + ) + datasets = [train_data, valid_data] + label_encoder = sb.dataio.encoder.CTCTextEncoder() + label_encoder.expect_len(hparams["num_labels"]) + + # 2. Define audio pipeline: + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + sig = sb.dataio.dataio.read_audio(wav) + return sig + + sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline) + + # 3. Define text pipeline: + @sb.utils.data_pipeline.takes("phn") + @sb.utils.data_pipeline.provides( + "phn_list", "phn_encoded", "phn_encoded_bos", "phn_encoded_eos" + ) + def text_pipeline(phn): + phn_list = phn.strip().split() + yield phn_list + phn_encoded = label_encoder.encode_sequence_torch(phn_list) + yield phn_encoded + phn_encoded_bos = label_encoder.prepend_bos_index(phn_encoded).long() + yield phn_encoded_bos + phn_encoded_eos = label_encoder.append_eos_index(phn_encoded).long() + yield phn_encoded_eos + + sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline) + + # 3. Fit encoder: + # NOTE: In this minimal example, also update from valid data + label_encoder.insert_blank(index=hparams["blank_index"]) + label_encoder.insert_bos_eos( + bos_index=hparams["bos_index"], eos_label="" + ) + label_encoder.update_from_didataset(train_data, output_key="phn_list") + label_encoder.update_from_didataset(valid_data, output_key="phn_list") + + # 4. 
Set output:
+    sb.dataio.dataset.set_output_keys(
+        datasets,
+        ["id", "sig", "phn_encoded", "phn_encoded_bos", "phn_encoded_eos"],
+    )
+    return train_data, valid_data, label_encoder
+
+
+def main(device="cpu"):
+    experiment_dir = pathlib.Path(__file__).resolve().parent
+    hparams_file = experiment_dir / "hyperparams.yaml"
+    data_folder = "../../samples/ASR"
+    data_folder = (experiment_dir / data_folder).resolve()
+
+    # Load model hyperparameters:
+    with open(hparams_file) as fin:
+        hparams = load_hyperpyyaml(fin)
+
+    # Dataset creation
+    train_data, valid_data, label_encoder = data_prep(data_folder, hparams)
+
+    # Trainer initialization
+    transducer_brain = ConformerTransducerBrain(
+        hparams["modules"],
+        hparams["opt_class"],
+        hparams,
+        run_opts={"device": device},
+    )
+
+    # Training/validation loop
+    transducer_brain.fit(
+        range(hparams["number_of_epochs"]),
+        train_data,
+        valid_data,
+        train_loader_kwargs=hparams["dataloader_options"],
+        valid_loader_kwargs=hparams["dataloader_options"],
+    )
+    # Evaluation is run separately (now just evaluating on valid data)
+    transducer_brain.evaluate(valid_data)
+
+    # Check that model overfits for integration test
+    assert transducer_brain.train_loss < 90.0
+
+
+if __name__ == "__main__":
+    main()
+
+
+def test_error(device):
+    main(device)
diff --git a/tests/integration/ASR_ConformerTransducer_streaming/hyperparams.yaml b/tests/integration/ASR_ConformerTransducer_streaming/hyperparams.yaml
new file mode 100644
index 0000000000..a988432a34
--- /dev/null
+++ b/tests/integration/ASR_ConformerTransducer_streaming/hyperparams.yaml
@@ -0,0 +1,268 @@
+# Seed needs to be set at top of yaml, before objects with parameters are made
+seed: 3407
+__set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
+
+# Training parameters
+# To make Transformers converge, the global batch size should be large enough.
+# The global batch size is computed as batch_size * n_gpus * grad_accumulation_factor.
+# Empirically, we found that this value should be >= 128.
+# Please, set your parameters accordingly.
+number_of_epochs: 30
+lr: 1.0
+ctc_weight: 0.3 # Multitask with CTC for the encoder (0.0 = disabled)
+ce_weight: 0.0 # Multitask with CE for the decoder (0.0 = disabled)
+max_grad_norm: 5.0
+loss_reduction: 'batchmean'
+precision: fp32 # bf16, fp16 or fp32
+
+# Feature parameters
+sample_rate: 16000
+n_fft: 512
+n_mels: 80
+win_length: 32
+
+# Streaming & dynamic chunk training options
+# At least for the current architecture on LibriSpeech, we found out that
+# non-streaming accuracy is very similar between `streaming: True` and
+# `streaming: False`.
+streaming: True # controls all Dynamic Chunk Training & chunk size & left context mechanisms
+
+# Configuration for Dynamic Chunk Training.
+# In this model, a chunk is roughly equivalent to 40ms of audio.
+dynchunktrain_config_sampler: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfigRandomSampler # yamllint disable-line rule:line-length
+    chunkwise_prob: 0.6 # Probability during a batch to limit attention and sample a random chunk size in the following range
+    chunk_size_min: 2 # Minimum chunk size (if in a DynChunkTrain batch)
+    chunk_size_max: 8 # Maximum chunk size (if in a DynChunkTrain batch)
+    limited_left_context_prob: 0.75 # If in a DynChunkTrain batch, the probability during a batch to restrict left context to a random number of chunks
+    left_context_chunks_min: 1 # Minimum left context size (in # of chunks)
+    left_context_chunks_max: 8 # Maximum left context size (in # of chunks)
+    # If you specify a valid/test config, you can optionally have evaluation be
+    # done with a specific DynChunkTrain configuration.
+    # valid_config: !new:speechbrain.utils.dynamic_chunk_training.DynChunkTrainConfig
+    #    chunk_size: 24
+    #    left_context_size: 16
+    # test_config: ...
+
+dataloader_options:
+    batch_size: 1
+
+# Model parameters
+# Transformer
+d_model: 64
+joint_dim: 128
+nhead: 2
+num_encoder_layers: 1
+num_decoder_layers: 0
+d_ffn: 128
+transformer_dropout: 0.1
+activation: !name:torch.nn.GELU
+output_neurons: !ref <num_labels>
+dec_dim: 128
+dec_emb_dropout: 0.2
+dec_dropout: 0.1
+
+# Decoding parameters
+# Special tokens and labels
+blank_index: 0
+bos_index: 1
+pad_index: 1
+num_labels: 45
+beam_size: 10
+nbest: 1
+
+# If True uses torchaudio loss. Otherwise, the numba one
+use_torchaudio: True
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <number_of_epochs>
+
+normalize: !new:speechbrain.processing.features.InputNormalization
+    norm_type: global
+    update_until_epoch: 4
+
+compute_features: !new:speechbrain.lobes.features.Fbank
+    sample_rate: !ref <sample_rate>
+    n_fft: !ref <n_fft>
+    n_mels: !ref <n_mels>
+    win_length: !ref <win_length>
+
+# Speed perturbation
+speed_changes: [95, 100, 105] # List of speed changes for time-stretching
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Augmenter: Combines previously defined augmentations to perform data augmentation
+wav_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: False
+    concat_original: False
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 1
+    max_augmentations: 1
+    augment_prob: 1.0
+    augmentations: [!ref <speed_perturb>]
+
+
+# Time Drop
+time_drop_length_low: 15 # Min length for temporal chunk to drop in spectrogram
+time_drop_length_high: 25 # Max length for temporal chunk to drop in spectrogram
+time_drop_count_low: 5 # Min number of chunks to drop in time in the spectrogram
+time_drop_count_high: 5 # Max number of chunks to drop in time in the spectrogram
+time_drop_replace: "zeros" # Method of dropping chunks
+
+time_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: !ref <time_drop_length_low>
+    drop_length_high: !ref <time_drop_length_high>
+    drop_count_low: !ref <time_drop_count_low>
+    drop_count_high: !ref <time_drop_count_high>
+    replace: !ref <time_drop_replace>
+    dim: 1
+
+# Frequency Drop
+freq_drop_length_low: 25 # Min length for chunks to drop in frequency in the spectrogram
+freq_drop_length_high: 35 # Max length for chunks to drop in frequency in the spectrogram
+freq_drop_count_low: 2 # Min number of chunks to drop in frequency in the spectrogram
+freq_drop_count_high: 2 # Max number of chunks to drop in frequency in the spectrogram
+freq_drop_replace: "zeros" # Method of dropping chunks
+
+freq_drop: !new:speechbrain.augment.freq_domain.SpectrogramDrop
+    drop_length_low: !ref <freq_drop_length_low>
+    drop_length_high: !ref <freq_drop_length_high>
+    drop_count_low: !ref <freq_drop_count_low>
+    
drop_count_high: !ref <freq_drop_count_high>
+    replace: !ref <freq_drop_replace>
+    dim: 2
+
+# Time warp
+time_warp_window: 5 # Length of time warping window
+time_warp_mode: "bicubic" # Time warping method
+
+time_warp: !new:speechbrain.augment.freq_domain.Warping
+    warp_window: !ref <time_warp_window>
+    warp_mode: !ref <time_warp_mode>
+    dim: 1
+
+fea_augment: !new:speechbrain.augment.augmenter.Augmenter
+    parallel_augment: False
+    concat_original: False
+    repeat_augment: 1
+    shuffle_augmentations: False
+    min_augmentations: 3
+    max_augmentations: 3
+    augment_prob: 1.0
+    augmentations: [
+        !ref <time_drop>,
+        !ref <freq_drop>,
+        !ref <time_warp>]
+
+CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
+    input_shape: (8, 10, 80)
+    num_blocks: 2
+    num_layers_per_block: 1
+    out_channels: (64, 32)
+    kernel_sizes: (3, 3)
+    strides: (2, 2)
+    residuals: (False, False)
+
+Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
+    input_size: 640
+    tgt_vocab: !ref <output_neurons>
+    d_model: !ref <d_model>
+    nhead: !ref <nhead>
+    num_encoder_layers: !ref <num_encoder_layers>
+    num_decoder_layers: !ref <num_decoder_layers>
+    d_ffn: !ref <d_ffn>
+    dropout: !ref <transformer_dropout>
+    activation: !ref <activation>
+    encoder_module: conformer
+    attention_type: RelPosMHAXL
+    normalize_before: True
+    causal: False
+
+# We must call an encoder wrapper so the decoder isn't run (we don't have any)
+enc: !new:speechbrain.lobes.models.transformer.TransformerASR.EncoderWrapper
+    transformer: !ref <Transformer>
+
+# For MTL CTC over the encoder
+proj_ctc: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <joint_dim>
+    n_neurons: !ref <output_neurons>
+
+# Define some projection layers to make sure that enc and dec
+# output dim are the same before joining
+proj_enc: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <d_model>
+    n_neurons: !ref <joint_dim>
+    bias: False
+
+proj_dec: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <dec_dim>
+    n_neurons: !ref <joint_dim>
+    bias: False
+
+# Uncomment for MTL with CTC
+ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
+    blank_index: !ref <blank_index>
+    reduction: !ref <loss_reduction>
+
+emb: !new:speechbrain.nnet.embedding.Embedding
+    num_embeddings: !ref <output_neurons>
+    consider_as_one_hot: True
+    blank_id: !ref <blank_index>
+
+dec: !new:speechbrain.nnet.RNN.LSTM
+    input_shape: [null, null, !ref <output_neurons> - 1]
+    hidden_size: !ref <dec_dim>
+    num_layers: 1
+    re_init: True
+
+# For MTL
+ce_cost: !name:speechbrain.nnet.losses.nll_loss
+    label_smoothing: 0.1
+
+Tjoint: !new:speechbrain.nnet.transducer.transducer_joint.Transducer_joint
+    joint: sum # joint [sum | concat]
+    nonlinearity: !ref <activation>
+
+transducer_lin: !new:speechbrain.nnet.linear.Linear
+    input_size: !ref <joint_dim>
+    n_neurons: !ref <output_neurons>
+    bias: False
+
+log_softmax: !new:speechbrain.nnet.activations.Softmax
+    apply_log: True
+
+transducer_cost: !name:speechbrain.nnet.losses.transducer_loss
+    blank_index: !ref <blank_index>
+    use_torchaudio: !ref <use_torchaudio>
+
+modules:
+    CNN: !ref <CNN>
+    enc: !ref <enc>
+    emb: !ref <emb>
+    dec: !ref <dec>
+    Tjoint: !ref <Tjoint>
+    transducer_lin: !ref <transducer_lin>
+    normalize: !ref <normalize>
+    proj_ctc: !ref <proj_ctc>
+    proj_dec: !ref <proj_dec>
+    proj_enc: !ref <proj_enc>
+
+model: !new:torch.nn.ModuleList
+    - [!ref <CNN>, !ref <enc>, !ref <emb>, !ref <dec>, !ref <proj_enc>, !ref <proj_dec>, !ref <proj_ctc>, !ref <transducer_lin>]
+
+Greedysearcher: !new:speechbrain.decoders.transducer.TransducerBeamSearcher
+    decode_network_lst: [!ref <emb>, !ref <dec>, !ref <proj_dec>]
+    tjoint: !ref <Tjoint>
+    classifier_network: [!ref <transducer_lin>]
+    blank_id: !ref <blank_index>
+    beam_size: 1
+    nbest: 1
+
+opt_class: !name:torch.optim.Adadelta
+    lr: !ref <lr>
+
+error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
+
+per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats
diff --git a/tests/unittests/test_dynamic_chunk_training.py b/tests/unittests/test_dynamic_chunk_training.py
new file mode 100644
index 0000000000..6c74d10d1b
--- /dev/null
+++ b/tests/unittests/test_dynamic_chunk_training.py
@@ 
-0,0 +1,38 @@
+def test_dynchunktrain_sampler():
+    from speechbrain.core import Stage
+    from speechbrain.utils.dynamic_chunk_training import (
+        DynChunkTrainConfig,
+        DynChunkTrainConfigRandomSampler,
+    )
+
+    # sanity check and coverage for the random sampler
+
+    valid_cfg = DynChunkTrainConfig(16, 32)
+    test_cfg = DynChunkTrainConfig(16, 32)
+
+    sampler = DynChunkTrainConfigRandomSampler(
+        chunkwise_prob=1.0,
+        chunk_size_min=8,
+        chunk_size_max=8,
+        limited_left_context_prob=1.0,
+        left_context_chunks_min=16,
+        left_context_chunks_max=16,
+        valid_config=valid_cfg,
+        test_config=test_cfg,
+    )
+
+    sampled_train_config = sampler(Stage.TRAIN)
+    assert sampled_train_config.chunk_size == 8
+    assert sampled_train_config.left_context_size == 16
+
+    assert sampler(Stage.VALID) == valid_cfg
+    assert sampler(Stage.TEST) == test_cfg
+
+
+def test_dynchunktrain():
+    from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig
+
+    assert DynChunkTrainConfig(chunk_size=16).is_infinite_left_context()
+    assert not DynChunkTrainConfig(
+        chunk_size=16, left_context_size=4
+    ).is_infinite_left_context()
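
The following is a minimal usage sketch, not part of the patch itself, showing how the `infer_dependency_matrix` and `plot_dependency_matrix` utilities added in `speechbrain/utils/streaming.py` above could be used to check that a model is safe for streaming, i.e. that no output frame depends on a future input frame. The `ToyCausalModel` (a unidirectional LSTM) and the output filename are illustrative assumptions, not part of the PR.

```python
# Sketch only: verify causality of a toy model with the new streaming utilities.
import torch

from speechbrain.utils.streaming import (
    infer_dependency_matrix,
    plot_dependency_matrix,
)


class ToyCausalModel(torch.nn.Module):
    """A strictly causal toy model: output frame t only sees inputs <= t."""

    def __init__(self, num_feats=80, hidden_size=32):
        super().__init__()
        self.rnn = torch.nn.LSTM(num_feats, hidden_size, batch_first=True)

    def forward(self, x):
        out, _ = self.rnn(x)
        return out


model = ToyCausalModel().eval()  # eval() keeps inference deterministic

# seq_shape is [batch_size, seq_len, num_feats]; this runs one forward pass per
# input frame, so keep the sequence short.
deps = infer_dependency_matrix(model, seq_shape=(1, 24, 80))

# deps is indexed [in_frame_idx, out_frame_idx]; for a causal model, no output
# frame may depend on a *future* input frame, i.e. the strictly lower triangle
# (in_frame_idx > out_frame_idx) must be all False.
assert not deps.int().tril(diagonal=-1).any()

# Visual check: red cells below the diagonal would reveal future dependencies.
fig = plot_dependency_matrix(deps)
fig.savefig("dependency_plot.png")
```

The same check applies to the streaming Conformer: running it through its chunked inference path and plotting the matrix makes accidental lookahead (e.g. from convolution padding or unmasked attention) immediately visible.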