# 1 Transformer in Audio

For audio tasks, the input and/or output sequences may be audio instead of text:

- **Automatic speech recognition (ASR)**: The input is speech, the output is text.
- **Speech synthesis (TTS)**: The input is text, the output is speech.
- **Audio classification**: The input is audio, the output is a class probability — one for each element in the sequence or a single class probability for the entire sequence.
- **Voice conversion or speech enhancement**: Both the input and output are audio.

There are a few different ways to handle audio so it can be used with a transformer. The main consideration is whether to use the audio in its raw form — as a waveform — or to process it as a spectrogram instead.
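To make the waveform versus spectrogram distinction concrete, here is a minimal sketch that turns a synthetic waveform into a magnitude spectrogram with a short-time Fourier transform. The function name, the 400-sample frame, the 160-sample hop and the 440 Hz test tone are all assumptions chosen for illustration; real feature extractors usually go further and apply a mel filterbank and a log.

```python
import numpy as np


def magnitude_spectrogram(waveform, frame_length=400, hop_length=160):
    """Return STFT magnitudes with shape (num_frames, frame_length // 2 + 1)."""
    window = np.hanning(frame_length)
    num_frames = 1 + (len(waveform) - frame_length) // hop_length
    # slice the waveform into overlapping, windowed frames
    frames = np.stack(
        [
            waveform[i * hop_length : i * hop_length + frame_length] * window
            for i in range(num_frames)
        ]
    )
    # one FFT per frame; keep only the magnitudes
    return np.abs(np.fft.rfft(frames, axis=1))


sampling_rate = 16_000
t = np.arange(sampling_rate) / sampling_rate  # one second of audio
waveform = np.sin(2 * np.pi * 440.0 * t)      # a 440 Hz tone (made-up input)

spec = magnitude_spectrogram(waveform)
peak_bin = spec.mean(axis=0).argmax()

print(waveform.shape)  # (16000,)
print(spec.shape)      # (98, 201)
print(peak_bin * sampling_rate / 400)  # 440.0, the frequency of the strongest bin
```

The waveform is a 1-dimensional sequence of 16,000 samples, while the spectrogram is a much shorter sequence of 98 frames, each described by 201 frequency bins.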

![transformers_in_audio](../images/03-transformers_in_audio.png)

## 1.1 transformer input

- text input:
  - The input text is first tokenized, giving a sequence of text tokens.
  - This sequence is sent through an input embedding layer to convert the tokens into 512-dimensional vectors.
  - Those embedding vectors are then passed into the transformer encoder.

- waveform input:
  - After normalizing, the sequence of audio samples is turned into an embedding using a small convolutional neural network, known as the feature encoder.
 
- Spectrogram input
 
In both cases, waveform as well as spectrogram input, there is a small network in front of the transformer that converts the input into embeddings and then the transformer takes over to do its thing.

## 1.2 transformer output

- text output
- Spectrogram output
- waveform output

# 2 Different transformer architectures

## 2.1 CTC model: alignment of audio inputs and text outputs

CTC or Connectionist Temporal Classification is a technique that is used with **encoder-only** transformer models for automatic speech recognition. Examples of such models are Wav2Vec2, HuBERT and M-CTC-T.

Here’s the rub: in speech, we don’t know the **alignment** of the audio inputs and text outputs. We know that the order in which the speech is spoken is the same as the order in which the text is transcribed (the alignment is said to be monotonic), but we don’t know how the characters in the transcription line up with the audio. This is where the CTC algorithm comes in.
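As a rough illustration of how CTC sidesteps the unknown alignment at decoding time, the sketch below applies the standard collapse rule to a sequence of per-frame predictions: merge consecutive repeats, then drop the blank token. The `_` blank symbol, the function name and the frame sequences are made up for this example; a real model would first take the argmax over its per-frame logits.

```python
from itertools import groupby

BLANK = "_"  # hypothetical symbol standing in for the CTC blank token


def ctc_greedy_collapse(frame_predictions, blank=BLANK):
    """Collapse per-frame predictions into text: merge repeats, then remove blanks."""
    merged = [token for token, _ in groupby(frame_predictions)]
    return "".join(token for token in merged if token != blank)


# twelve audio frames, one predicted character (or blank) per frame
frames = list("_BB_OO__O_K_")
print(ctc_greedy_collapse(frames))  # BOOK
```

The blank between the two runs of `O` is what lets the model emit a genuine double letter; without it, the repeats would be merged into a single `O`.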

## 2.2 seq2seq model

The CTC models discussed in the previous section used only the encoder part of the transformer architecture. When we also add the decoder to create an encoder-decoder model, this is referred to as a sequence-to-sequence model or seq2seq for short. The model maps a sequence of one kind of data to a sequence of another kind of data.

With encoder-only transformer models, the encoder makes a prediction for each element in the input sequence, so the input and output sequences always have the same length. In CTC models such as Wav2Vec2 the input waveform is first downsampled, but there is still one prediction for every 20 ms of audio.
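The 20 ms figure follows from the strides of the convolutional feature encoder. The sketch below works through the arithmetic, assuming the kernel sizes and strides of the base Wav2Vec2 configuration as I recall them (seven layers, total stride 320) and a 16 kHz sampling rate; treat these values as assumptions and check them against the model config you actually use.

```python
import math


def conv_output_length(input_length, kernel_sizes, strides):
    """Length of a sequence after a stack of unpadded 1-D convolutions."""
    length = input_length
    for kernel, stride in zip(kernel_sizes, strides):
        length = (length - kernel) // stride + 1
    return length


# assumed values (Wav2Vec2 base feature encoder)
KERNELS = (10, 3, 3, 3, 3, 2, 2)
STRIDES = (5, 2, 2, 2, 2, 2, 2)
SAMPLING_RATE = 16_000

total_stride = math.prod(STRIDES)
frame_ms = 1000 * total_stride / SAMPLING_RATE
num_frames = conv_output_length(SAMPLING_RATE, KERNELS, STRIDES)

print(total_stride)  # 320 input samples per output frame
print(frame_ms)      # 20.0 ms of audio per prediction
print(num_frames)    # 49 predictions for one second of audio
```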

With a seq2seq model, there is no such one-to-one correspondence and the input and output sequences can have different lengths. That makes seq2seq models suitable for NLP tasks such as text summarization or translation between different languages — but also for audio tasks such as speech recognition.

## 2.3 audio classification

The goal of audio classification is to predict a class label for an audio input. The model can predict a single class label that covers the entire input sequence, or it can predict a label for every audio frame — typically every 20 milliseconds of input audio — in which case the model’s output is a sequence of class label probabilities. An example of the former is detecting what bird is making a particular sound; an example of the latter is speaker diarization, where the model predicts which speaker is speaking at any given moment.
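The two output styles differ only in what happens to the per-frame scores. The sketch below uses made-up logits for five frames and two classes: per-frame classification takes an argmax on every frame, while sequence-level classification first pools over time. Mean pooling is an assumption here; it is one common choice, not the only one.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify_frames(frame_logits):
    """One label per frame (e.g. which speaker is active in each frame)."""
    return [frame.index(max(frame)) for frame in frame_logits]


def classify_sequence(frame_logits):
    """One label for the whole clip: mean-pool the logits over time, then softmax."""
    num_classes = len(frame_logits[0])
    pooled = [
        sum(frame[c] for frame in frame_logits) / len(frame_logits)
        for c in range(num_classes)
    ]
    probs = softmax(pooled)
    return probs.index(max(probs)), probs


# made-up logits: 5 frames x 2 classes
frame_logits = [[2.0, 0.1], [1.5, 0.3], [0.2, 1.9], [0.1, 2.2], [1.0, 0.9]]

print(classify_frames(frame_logits))       # [0, 0, 1, 1, 0]
label, probs = classify_sequence(frame_logits)
print(label)                               # 1
```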

# 3 ASR

## 3.1 Pre-trained models for automatic speech recognition

### 3.1.1 CTC model

Prior to 2022, CTC was the more popular of the two architectures, with encoder-only models such as Wav2Vec2, HuBERT and XLSR achieving breakthroughs in the pre-training / fine-tuning paradigm for speech.

Comparing the target text to the predicted transcription, we can see that all words sound correct, but some are not spelled accurately. For example:

- CHRISTMAUS vs. CHRISTMAS
- ROSE vs. ROAST
- SIMALYIS vs. SIMILES

This highlights the shortcoming of a CTC model. A CTC model is essentially an 'acoustic-only' model: it consists of an encoder which forms hidden-state representations from the audio inputs, and a linear layer which maps the hidden states to characters.

This means that the system almost entirely bases its prediction on the acoustic input it was given (the phonetic sounds of the audio), and so has a tendency to transcribe the audio in a phonetic way (e.g. CHRISTMAUS). 

It gives less importance to the language modelling context of previous and successive letters, and so is prone to phonetic spelling errors. 

### 3.1.2 Seq2Seq model

The need for large amounts of training data has been a bottleneck in the advancement of Seq2Seq architectures for speech. Labelled speech data is difficult to come by, with the largest annotated datasets at the time clocking in at just 10,000 hours. This all changed in 2022 with the release of Whisper.

```python
from datasets import load_dataset

dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
dataset
"""
Dataset({
    features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],
    num_rows: 73
})
"""

import torch
from transformers import pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
pipe = pipeline(
    "automatic-speech-recognition", model="openai/whisper-base", device=device
)

sample = dataset[2]

pipe(sample["audio"], max_new_tokens=256)
```

## 3.2 Long-Form Transcription and Timestamps

### 3.2.1 concatenate samples into a long audio

```python
import numpy as np

target_length_in_m = 5

# convert minutes to seconds (* 60), then to a number of samples (* sampling rate)
sampling_rate = pipe.feature_extractor.sampling_rate
target_length_in_samples = target_length_in_m * 60 * sampling_rate

# iterate over the dataset and concatenate samples until the target length is reached
long_audio = []
for sample in dataset:
    long_audio.extend(sample["audio"]["array"])
    if len(long_audio) > target_length_in_samples:
        break

long_audio = np.asarray(long_audio)

# how long is the result?
seconds = len(long_audio) / sampling_rate
minutes, seconds = divmod(seconds, 60)
print(f"Length of audio sample is {minutes} minutes {seconds:.2f} seconds")
# Length of audio sample is 5.0 minutes 17.22 seconds
```

```python
pipe(
    long_audio,
    max_new_tokens=256,
    generate_kwargs={"task": "transcribe"},
    chunk_length_s=30,
    batch_size=8,
)
```
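The `chunk_length_s=30` argument asks the pipeline to process the long audio in windows of 30 seconds. The sketch below only illustrates the windowing arithmetic for overlapping chunks; the function name, the 5 s overlap on each side and the stepping rule are assumptions made for this example, not the pipeline's actual code.

```python
def chunk_boundaries(num_samples, chunk_len, stride_len):
    """Return (start, end) sample indices of overlapping chunks covering the audio."""
    step = chunk_len - 2 * stride_len  # consecutive chunks share 2 * stride_len samples
    if step <= 0:
        raise ValueError("stride_len is too large for this chunk_len")
    boundaries = []
    start = 0
    while True:
        end = min(start + chunk_len, num_samples)
        boundaries.append((start, end))
        if end == num_samples:
            break
        start += step
    return boundaries


sampling_rate = 16_000
chunks = chunk_boundaries(
    num_samples=70 * sampling_rate,  # a made-up 70 s recording
    chunk_len=30 * sampling_rate,
    stride_len=5 * sampling_rate,
)
print([(start // sampling_rate, end // sampling_rate) for start, end in chunks])
# [(0, 30), (20, 50), (40, 70)]
```

The overlap gives each chunk some context on either side of its boundary, which makes it easier to stitch the per-chunk transcriptions back together without cutting words in half.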

### 3.2.2 add timestamp

```python
pipe(
    long_audio,
    max_new_tokens=256,
    generate_kwargs={"task": "transcribe"},
    chunk_length_s=30,
    batch_size=8,
    return_timestamps=True,
)["chunks"]
```
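Each entry of the returned `"chunks"` list is a dictionary with a `"timestamp"` pair (start, end) in seconds and the `"text"` for that segment. As a small usage sketch, the hypothetical helpers below render such a list as readable lines; the example chunks are made up, and handling a missing end time with a placeholder is an assumption.

```python
def format_timestamp(seconds):
    """Render a time in seconds as MM:SS.ss; a missing time becomes a placeholder."""
    if seconds is None:
        return "--:--.--"
    minutes, secs = divmod(seconds, 60)
    return f"{int(minutes):02d}:{secs:05.2f}"


def format_chunks(chunks):
    """Turn a list of {'timestamp': (start, end), 'text': ...} dicts into text lines."""
    lines = []
    for chunk in chunks:
        start, end = chunk["timestamp"]
        text = chunk["text"].strip()
        lines.append(f"[{format_timestamp(start)} -> {format_timestamp(end)}] {text}")
    return "\n".join(lines)


# made-up chunks in the same layout as the pipeline output
example_chunks = [
    {"timestamp": (0.0, 5.5), "text": " first made-up segment"},
    {"timestamp": (5.5, 65.25), "text": " second made-up segment"},
]
print(format_chunks(example_chunks))
# [00:00.00 -> 00:05.50] first made-up segment
# [00:05.50 -> 01:05.25] second made-up segment
```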

### 3.2.3 about low-resource languages

While the Whisper model performs extremely well on many high-resource languages, it has lower transcription and translation accuracy on low-resource languages, i.e. those with less readily available training data. Performance also varies across accents and dialects of the same language, and can be lower for speakers of particular genders, races, ages or other demographic groups (cf. the Whisper paper).

To boost the performance on low-resource languages, accents or dialects, we can take the pre-trained Whisper model and train it on a small corpus of appropriately selected data, in a process called fine-tuning. We’ll show that with as little as ten hours of additional data, we can improve the performance of the Whisper model by over 100% on a low-resource language. In the next section, we’ll cover the process behind selecting a dataset for fine-tuning.

## 3.3 choose dataset

Features of speech datasets to consider:

1. Number of hours
2. Domain
3. Speaking style
4. Transcription style

To know if our fine-tuned model is any good, we’ll need a rigorous way of evaluating it on unseen data and measuring its transcription accuracy. We’ll cover exactly this in the next section!

## 3.4 Evaluation metrics for ASR

If you’re familiar with the [Levenshtein distance](https://en.wikipedia.org/wiki/Levenshtein_distance) from NLP, the metrics for assessing speech recognition systems will be familiar!

When assessing speech recognition systems, we compare the system’s predictions to the target text transcriptions, annotating any errors that are present. We categorise these errors into one of three categories:

- Substitutions (S): where we transcribe the wrong word in our prediction (“sit” instead of “sat”)
- Insertions (I): where we add an extra word in our prediction
- Deletions (D): where we remove a word in our prediction

These error categories are the same for all speech recognition metrics. What differs is the level at which we compute these errors: we can either compute them on the word level or on the character level.

### 3.4.1 word level: WER

The word error rate (WER) metric is the ‘de facto’ metric for speech recognition.

```python
# pip install --upgrade evaluate jiwer

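from evaluate import load

# example pair (assumed for illustration): one substitution and one deletion
# against a six-word reference
reference = "the cat sat on the mat"
prediction = "the cat sit on the"

wer_metric = load("wer")

wer = wer_metric.compute(references=[reference], predictions=[prediction])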

print(wer)
# 0.3333333333333333
```

Since the WER is the ratio of errors to number of words (N), there is no upper limit on the WER!
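Written out, with $S$, $I$ and $D$ the numbers of substitutions, insertions and deletions and $N$ the number of words in the reference:

$$
\mathrm{WER} = \frac{S + I + D}{N}
$$

As an illustrative pair (assumed here), take the reference "the cat sat on the mat" ($N = 6$) and the prediction "the cat sit on the". There is one substitution ("sit" for "sat") and one deletion ("mat"), so

$$
\mathrm{WER} = \frac{1 + 0 + 1}{6} \approx 0.333
$$

Insertions are what make the WER unbounded. For the reference "hello" ($N = 1$) and the prediction "hello there world", there are two insertions, giving $\mathrm{WER} = 2 / 1 = 2$, i.e. 200%.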

### 3.4.2 word level: WAcc

Word accuracy (WAcc): we can flip the WER around to give us a metric where higher is better. Rather than measuring the word error rate, we measure the word accuracy of our system.
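In terms of the error counts ($S$ substitutions, $I$ insertions, $D$ deletions, $N$ reference words):

$$
\mathrm{WAcc} = 1 - \mathrm{WER} = 1 - \frac{S + I + D}{N}
$$

A system with a WER of 0.333 therefore has a WAcc of about 0.667. Because the WER can exceed 1, the WAcc can be negative, which is one reason it is quoted less often than the WER.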

### 3.4.3 character level: CER

Character error rate (CER): the CER assesses systems on the character level. This means we divide up our words into their individual characters, and annotate errors on a character-by-character basis.
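Both metrics are the same Levenshtein computation applied to different units. The sketch below is a dependency-free illustration (not the implementation used by the `evaluate` or `jiwer` libraries); the function names and the example sentences are assumptions.

```python
def edit_distance(reference, hypothesis):
    """Levenshtein distance between two sequences (of words or of characters)."""
    previous = list(range(len(hypothesis) + 1))
    for i, ref_item in enumerate(reference, start=1):
        current = [i]
        for j, hyp_item in enumerate(hypothesis, start=1):
            substitution = previous[j - 1] + (ref_item != hyp_item)
            deletion = previous[j] + 1
            insertion = current[j - 1] + 1
            current.append(min(substitution, deletion, insertion))
        previous = current
    return previous[-1]


def word_error_rate(reference, prediction):
    """Errors counted on words, divided by the number of reference words."""
    ref_words = reference.split()
    return edit_distance(ref_words, prediction.split()) / len(ref_words)


def character_error_rate(reference, prediction):
    """Errors counted on characters (spaces included), divided by reference length."""
    return edit_distance(list(reference), list(prediction)) / len(reference)


reference = "the cat sat on the mat"
prediction = "the cat sit on the"

print(round(word_error_rate(reference, prediction), 3))       # 0.333
print(round(character_error_rate(reference, prediction), 3))  # 0.227
```

The substituted word costs a full word error in the WER but only a single character error in the CER, so the CER gives partial credit for near-misses such as "sit" for "sat".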

### 3.4.5 Which metric should I use

In general, the WER is used far more than the CER for assessing speech systems. This is because the WER requires systems to have greater understanding of the context of the predictions. 

However, certain languages, such as Mandarin and Japanese, are written without spaces between words, so there is no unambiguous notion of a ‘word’ and the WER is not meaningful. Here, we revert to using the CER.

### 3.4.6 Normalisation and orthographic transcriptions

When a model is trained on text that keeps its casing and punctuation, the predicted transcriptions will be fully formatted with casing and punctuation, a style referred to as orthographic.

There is a happy medium between normalising and not normalising: we can train our systems on orthographic transcriptions, and then normalise the predictions and targets before computing the WER.

This way, we train our systems to predict fully formatted text, but also benefit from the WER improvements we get by normalising the transcriptions.


```python
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()

reference = "HE TELLS US THAT AT THIS FESTIVE SEASON OF THE YEAR WITH CHRISTMAS AND ROAST BEEF LOOMING BEFORE US SIMILES DRAWN FROM EATING AND ITS RESULTS OCCUR MOST READILY TO THE MIND"
normalized_reference = normalizer(reference)
normalized_reference
"""
'he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind'
"""
```

## 3.5 evaluate whisper-small on Dhivehi ('dv')

```python
# 1 prepare pipeline
# ====================

from transformers import pipeline
import torch

if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    torch_dtype=torch_dtype,
    device=device,
)

# 2 prepare dataset
# ====================

from huggingface_hub import notebook_login

notebook_login()

from datasets import load_dataset

common_voice_test = load_dataset(
    "mozilla-foundation/common_voice_13_0", "dv", split="test"
)

# 3 predict directly with the original (pre-trained) model
# ====================

from tqdm import tqdm
from transformers.pipelines.pt_utils import KeyDataset

all_predictions = []

# run streamed inference
for prediction in tqdm(
    pipe(
        KeyDataset(common_voice_test, "audio"),
        max_new_tokens=128,
        generate_kwargs={"task": "transcribe"},
        batch_size=32,
    ),
    total=len(common_voice_test),
):
    all_predictions.append(prediction["text"])

# 4 evaluate metrics: orthographic WER
# ====================

from evaluate import load

wer_metric = load("wer")

wer_ortho = 100 * wer_metric.compute(
    references=common_voice_test["sentence"], predictions=all_predictions
)
wer_ortho
# 167.29577268612022

# 5 evaluate metrics: normalised WER
# ====================

from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()

# compute the normalised WER
all_predictions_norm = [normalizer(pred) for pred in all_predictions]
all_references_norm = [normalizer(label) for label in common_voice_test["sentence"]]

# filter out samples whose reference is empty after normalisation
all_predictions_norm = [
    all_predictions_norm[i]
    for i in range(len(all_predictions_norm))
    if len(all_references_norm[i]) > 0
]
all_references_norm = [
    all_references_norm[i]
    for i in range(len(all_references_norm))
    if len(all_references_norm[i]) > 0
]

wer = 100 * wer_metric.compute(
    references=all_references_norm, predictions=all_predictions_norm
)

wer
# 125.69809089960707
```

## 3.6 fine-tuning model

### 3.6.1 prepare env

```python
from huggingface_hub import notebook_login

notebook_login()
"""
Login successful
Your token has been saved to /root/.huggingface/token
"""
```

### 3.6.2 prepare dataset

```python

# load dataset

from datasets import load_dataset, DatasetDict

common_voice = DatasetDict()

common_voice["train"] = load_dataset(
    "mozilla-foundation/common_voice_13_0", "dv", split="train+validation"
)
common_voice["test"] = load_dataset(
    "mozilla-foundation/common_voice_13_0", "dv", split="test"
)

print(common_voice)

# process columns
common_voice = common_voice.select_columns(["audio", "sentence"])
```

### 3.6.3 prepare pipeline: Model for Transfer Learning

**ASR pipeline**

The ASR pipeline can be decomposed into three stages:

- The feature extractor, which pre-processes the raw audio inputs into log-mel spectrograms
- The model, which performs the sequence-to-sequence mapping
- The tokenizer, which post-processes the predicted tokens into text

**fine-tune it on a new language**

When you fine-tune it on a new language, Whisper does a good job at leveraging its knowledge of the other 96 languages it’s pre-trained on. Largely speaking, all modern languages will be linguistically similar to at least one of the 96 languages Whisper already knows, so we’ll fall under this paradigm of cross-lingual knowledge representation.

To fine-tune Whisper on a new language, we need to find the most similar language that Whisper was pre-trained on. The Wikipedia article for Dhivehi states that Dhivehi is closely related to the Sinhalese language of Sri Lanka. If we check the language codes again, we can see that Sinhalese is present in the Whisper language set, so we can safely set our language argument to "sinhalese".

```python
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small", language="sinhalese", task="transcribe"
)
```

### 3.6.4 preprocess dataset

```python
common_voice["train"].features
"""
{'audio': Audio(sampling_rate=48000, mono=True, decode=True, id=None),
 'sentence': Value(dtype='string', id=None)}
"""

# resample
from datasets import Audio

sampling_rate = processor.feature_extractor.sampling_rate
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=sampling_rate))

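# ignore examples that are too long (Whisper processes at most 30 s of audio at a time)
# NOTE: the `input_length` column is created by `prepare_dataset` in 3.6.5,
# so this filter has to run after that `map` call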
max_input_length = 30.0

def is_audio_in_length_range(length):
    return length < max_input_length

common_voice["train"] = common_voice["train"].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)
```

### 3.6.5 extract input features (log-mel) and tokenize labels

Now we can write a function to prepare our data ready for the model:
- We load and resample the audio data on a sample-by-sample basis by calling `example["audio"]`. As explained above, 🤗 Datasets performs any necessary resampling operations on the fly.
- We use the feature extractor to compute the log-mel spectrogram input features from our 1-dimensional audio array.
- We encode the transcriptions to label ids through the use of the tokenizer.

```python
def prepare_dataset(example):
    audio = example["audio"]

    example = processor(
        audio=audio["array"],
        sampling_rate=audio["sampling_rate"],
        text=example["sentence"],
    )

    # compute the length of the input audio sample in seconds
    example["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    return example

common_voice = common_voice.map(
    prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=1
)
```
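To get a feel for the size of the `input_features` produced above, the sketch below works through the shape arithmetic. The constants are assumptions based on the default Whisper feature extractor settings as I recall them (16 kHz audio, 30 s padded window, hop of 160 samples, 80 mel bins for `whisper-small`); check `processor.feature_extractor` for the values in your own setup.

```python
# assumed defaults of the Whisper feature extractor
SAMPLING_RATE = 16_000   # samples per second
CHUNK_LENGTH_S = 30      # every input is padded or truncated to this many seconds
HOP_LENGTH = 160         # samples between consecutive spectrogram frames
NUM_MEL_BINS = 80        # mel filterbank channels


def input_feature_shape():
    """Shape (mel bins, frames) of one log-mel input under the assumptions above."""
    num_samples = CHUNK_LENGTH_S * SAMPLING_RATE
    num_frames = num_samples // HOP_LENGTH
    return (NUM_MEL_BINS, num_frames)


frame_ms = 1000 * HOP_LENGTH / SAMPLING_RATE

print(input_feature_shape())  # (80, 3000)
print(frame_ms)               # 10.0 ms of audio per spectrogram frame
```

Because every example is padded to the same 30 s window, audio longer than that would be truncated, which is why examples above `max_input_length` are filtered out of the training set.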

### 3.6.6 define a data collator

**Training and evaluation**

Now that we’ve prepared our data, we’re ready to dive into the training pipeline. The 🤗 Trainer will do much of the heavy lifting for us. All we have to do is:

- Define a data collator: the data collator takes our pre-processed data and prepares PyTorch tensors ready for the model.
- Define evaluation metrics: during evaluation, we want to evaluate the model using the word error rate (WER) metric. We need to define a `compute_metrics` function that handles this computation.
- Load a pre-trained checkpoint: we need to load a pre-trained checkpoint and configure it correctly for training.
- Define the training arguments: these will be used by the 🤗 Trainer in constructing the training schedule.

Once we’ve fine-tuned the model, we will evaluate it on the test data to verify that we have correctly trained it to transcribe speech in Dhivehi.

```python
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [
            {"input_features": feature["input_features"][0]} for feature in features
        ]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # if a bos token was prepended in the previous tokenization step,
        # cut it here as it is added again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
```
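The label handling in the collator can be hard to follow through the tensor operations, so here is a dependency-free sketch of the same three steps on plain Python lists: pad to the longest sequence, replace padded positions with `-100`, and strip a leading BOS token if every row has one. The function name and the token ids are made up for illustration.

```python
IGNORE_INDEX = -100  # label value that the loss function skips


def pad_and_mask_labels(label_sequences, pad_token_id, bos_token_id):
    """Pad label id lists to equal length and mask the padding with IGNORE_INDEX."""
    max_len = max(len(seq) for seq in label_sequences)
    batch = []
    for seq in label_sequences:
        num_pad = max_len - len(seq)
        input_ids = seq + [pad_token_id] * num_pad
        attention_mask = [1] * len(seq) + [0] * num_pad
        # mask by position (attention mask), not by token value
        batch.append(
            [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, attention_mask)]
        )
    # drop the first token only if it is BOS in every row
    if all(row[0] == bos_token_id for row in batch):
        batch = [row[1:] for row in batch]
    return batch


# made-up token ids: 1 = BOS, 0 = PAD
labels = pad_and_mask_labels([[1, 7, 8, 9], [1, 5]], pad_token_id=0, bos_token_id=1)
print(labels)  # [[7, 8, 9], [5, -100, -100]]
```

Masking by the attention mask rather than by comparing against the pad id keeps any real label token that happens to share its id with the padding token.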

### 3.6.7 define Evaluation Metrics

```python
import evaluate

metric = evaluate.load("wer")

from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # compute the orthographic WER
    wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)

    # compute the normalised WER
    pred_str_norm = [normalizer(pred) for pred in pred_str]
    label_str_norm = [normalizer(label) for label in label_str]
    # filter so that only samples with a non-empty reference are evaluated
    pred_str_norm = [
        pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0
    ]
    label_str_norm = [
        label_str_norm[i]
        for i in range(len(label_str_norm))
        if len(label_str_norm[i]) > 0
    ]

    wer = 100 * metric.compute(predictions=pred_str_norm, references=label_str_norm)

    return {"wer_ortho": wer_ortho, "wer": wer}
```

### 3.6.8 Load a pre-trained checkpoint

```python

# 1 load checkpoint

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
```

### 3.6.9 define the training arguments

For more detail on the training arguments, refer to the Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments).

```python
# some setting

from functools import partial

# train args: disable cache during training since it's incompatible with gradient checkpointing
model.config.use_cache = False

# generate args: set the language and task for generation, and re-enable the cache
model.generate = partial(
    model.generate, language="sinhalese", task="transcribe", use_cache=True
)

# train args
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-dv",  # name of the output directory (and of the repo on the HF Hub)
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=50,
    max_steps=500,  # increase to 4000 if you have your own GPU or a paid Colab plan
    gradient_checkpointing=True,
    fp16=True,
    fp16_full_eval=True,
    evaluation_strategy="steps",  # called `eval_strategy` in newer transformers releases
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=500,
    eval_steps=500,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)

# We can forward the training arguments to the 🤗 Trainer along with our model, dataset, data collator and compute_metrics function:

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)
```
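Two of the training arguments above are easier to reason about with a little arithmetic: the learning-rate schedule and the effective batch size. The sketch below is an illustration under stated assumptions, not the Trainer's own code: it assumes `constant_with_warmup` ramps the learning rate linearly from 0 over `warmup_steps` and then holds it, and that the effective batch size is the product of the per-device batch size, the accumulation steps and the number of devices.

```python
import math


def constant_with_warmup(step, base_lr=1e-5, warmup_steps=50):
    """Learning rate at a given optimizer step: linear warmup, then constant."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr


def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of examples that contribute to each optimizer update."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


for step in (0, 25, 50, 500):
    print(step, constant_with_warmup(step))

print(effective_batch_size(16, 1))  # 16
print(effective_batch_size(8, 2))   # 16, same update size with half the memory per step
```

This is the reasoning behind the comment on `gradient_accumulation_steps`: halving the batch size and doubling the accumulation steps keeps the effective batch size unchanged.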

### 3.6.10 train

```python
trainer.train()
```

### 3.6.11 share model

```python
kwargs = {
    "dataset_tags": "mozilla-foundation/common_voice_13_0",
    "dataset": "Common Voice 13",  # a readable name for the training dataset
    "language": "dv",
    "model_name": "Whisper Small Dv - Sanchit Gandhi",  # a "pretty" name for your model
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
}

trainer.push_to_hub(**kwargs)
```

## 3.7 build app demo

- [gradio app demo](https://huggingface.co/learn/audio-course/en/chapter5/demo)

# 4 more doc

* [Whisper Talk](https://www.youtube.com/live/fZMiD8sDzzg?feature=share) by Jong Wook Kim: a presentation on the Whisper model, explaining the motivation, architecture, training and results, delivered by Whisper author Jong Wook Kim
* [End-to-End Speech Benchmark (ESB)](https://arxiv.org/abs/2210.13352): a paper that comprehensively argues for using the orthographic WER as opposed to the normalised WER for evaluating ASR systems and presents an accompanying benchmark
* [Fine-Tuning Whisper for Multilingual ASR](https://huggingface.co/blog/fine-tune-whisper): an in-depth blog post that explains how the Whisper model works in more detail, and the pre- and post-processing steps involved with the feature extractor and tokenizer
* [Fine-tuning MMS Adapter Models for Multi-Lingual ASR](https://huggingface.co/blog/mms_adapters): an end-to-end guide for fine-tuning Meta AI's new [MMS](https://ai.facebook.com/blog/multilingual-model-speech-recognition/) speech recognition models, freezing the base model weights and only fine-tuning a small number of *adapter* layers
* [Boosting Wav2Vec2 with n-grams in 🤗 Transformers](https://huggingface.co/blog/wav2vec2-with-ngram): a blog post for combining CTC models with external language models (LMs) to combat spelling and punctuation errors