Add ASR CTC streaming example (huggingface#15309)
* Single-epoch run

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Infinite dataset

* Trainer fix + distributed benchmark

* Benchmark fix

* unused import

* interleaved splits

* interleaved splits

* has_length util

* Move to research projects

* Leftover Sized checks

* Bump min version

* Unused import

* Revert trainer changes

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2 people authored and Steven committed Feb 18, 2022
1 parent 9efd777 commit 8c0ff32
Showing 5 changed files with 721 additions and 4 deletions.
57 changes: 57 additions & 0 deletions examples/pytorch/speech-recognition/README.md
@@ -127,6 +127,62 @@ python -m torch.distributed.launch \
On 8 V100 GPUs, this script should run in *ca.* 18 minutes and yield a CTC loss of **0.39** and word error rate
of **0.36**.


### Multi GPU CTC with Dataset Streaming

The following command shows how to use [Dataset Streaming mode](https://huggingface.co/docs/datasets/dataset_streaming.html)
to fine-tune [XLS-R](https://huggingface.co/transformers/master/model_doc/xls_r.html)
on [Common Voice](https://huggingface.co/datasets/common_voice) using 4 GPUs in half-precision.

Streaming mode imposes several constraints on training:
1. We need to construct a tokenizer beforehand and define it via `--tokenizer_name_or_path`.
2. `--num_train_epochs` has to be replaced by `--max_steps`. Similarly, all other epoch-based arguments have to be
replaced by step-based ones.
3. Full dataset shuffling on each epoch is not possible, since we don't have the whole dataset available at once.
However, the `--shuffle_buffer_size` argument controls how many examples we can pre-download before shuffling them.
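Before running the full command below, here is a minimal sketch of what these constraints look like with the `datasets` streaming API. It is an illustration only, not an excerpt from the example script; it assumes `datasets>=1.18.0` and reuses the dataset, config, and column names from the command that follows.

```python
from datasets import interleave_datasets, load_dataset

# Streaming mode: each split is an IterableDataset, so nothing is downloaded up front.
train = load_dataset("common_voice", "tr", split="train", streaming=True)
validation = load_dataset("common_voice", "tr", split="validation", streaming=True)

# A "train+validation" split is approximated by interleaving the two streams.
train_dataset = interleave_datasets([train, validation])

# Only `buffer_size` examples are pre-downloaded and shuffled at a time,
# which is what `--shuffle_buffer_size` controls in the script.
train_dataset = train_dataset.shuffle(buffer_size=500, seed=42)

# Preview a few streamed examples; the transcription lives in the "sentence" column.
for example in train_dataset.take(2):
    print(example["sentence"])
```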


```bash
python -m torch.distributed.launch \
--nproc_per_node 4 run_speech_recognition_ctc_streaming.py \
--dataset_name="common_voice" \
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
--tokenizer_name_or_path="anton-l/wav2vec2-tokenizer-turkish" \
--dataset_config_name="tr" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--output_dir="wav2vec2-xls-r-common_voice-tr-ft" \
--overwrite_output_dir \
--max_steps="5000" \
--per_device_train_batch_size="8" \
--gradient_accumulation_steps="2" \
--learning_rate="5e-4" \
--warmup_steps="500" \
--evaluation_strategy="steps" \
--text_column_name="sentence" \
--save_steps="500" \
--eval_steps="500" \
--logging_steps="1" \
--layerdrop="0.0" \
--eval_metrics wer cer \
--save_total_limit="1" \
--mask_time_prob="0.3" \
--mask_time_length="10" \
--mask_feature_prob="0.1" \
--mask_feature_length="64" \
--freeze_feature_encoder \
--chars_to_ignore , ? . ! - \; \: \" “ % ‘ ” � \
--max_duration_in_seconds="20" \
--shuffle_buffer_size="500" \
--fp16 \
--push_to_hub \
--do_train --do_eval \
--gradient_checkpointing
```

On 4 V100 GPUs, this script should run in *ca.* 3h 31min and yield a CTC loss of **0.35** and word error rate
of **0.29**.
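The reported word error rate is the usual WER = (substitutions + deletions + insertions) / number of reference words. As a quick sanity check, the `wer` metric requested via `--eval_metrics wer cer` can also be computed standalone; the snippet below is a hedged sketch (the predictions and references are made-up Turkish strings, and the `wer` metric needs the `jiwer` package installed):

```python
from datasets import load_metric

wer_metric = load_metric("wer")  # backed by the jiwer package

# Illustrative transcriptions only: one substituted word out of five reference words.
predictions = ["merhaba dünya", "bu bir test"]
references = ["merhaba dünya", "bu bir testtir"]

# compute() returns the error rate as a float (here 1 error / 5 words = 0.2).
print(wer_metric.compute(predictions=predictions, references=references))
```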

### Examples CTC

The following tables present a couple of example runs on the most popular speech-recognition datasets.
@@ -175,6 +231,7 @@ they can serve as a baseline to improve upon.
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) | 0.35 | - | 1 GPU V100 | 1h20min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-common_voice-tr-demo) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-common_voice-tr-demo/blob/main/run.sh) |
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.31 | - | 8 GPU V100 | 1h05 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft/blob/main/run.sh) |
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-1b](https://huggingface.co/facebook/wav2vec2-xls-r-1b) | 0.21 | - | 2 GPU Titan 24 GB RAM | 15h10 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xls-r-1b-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-1b-common_voice-tr-ft/blob/main/run.sh) |
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` in streaming mode | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.29 | - | 4 GPU V100 | 3h31 | [here](https://huggingface.co/anton-l/wav2vec2-xls-r-common_voice-tr-ft-stream) | [run.sh](https://huggingface.co/anton-l/wav2vec2-xls-r-common_voice-tr-ft-stream/blob/main/run.sh) |


#### Multilingual Librispeech CTC
2 changes: 1 addition & 1 deletion examples/pytorch/speech-recognition/requirements.txt
@@ -1,4 +1,4 @@
-datasets >= 1.13.3
+datasets >= 1.18.0
torch >= 1.5
torchaudio
librosa
@@ -51,7 +51,7 @@
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.17.0.dev0")

require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")


logger = logging.getLogger(__name__)
@@ -146,7 +146,8 @@ class DataTrainingArguments:
train_split_name: str = field(
default="train+validation",
metadata={
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train+validation'"
"help": "The name of the training data set split to use (via the datasets library). Defaults to "
"'train+validation'"
},
)
eval_split_name: str = field(
@@ -49,7 +49,7 @@
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.17.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

logger = logging.getLogger(__name__)

