<a href="https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/speculative_decoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Speculative Decoding for 2x Faster Whisper Inference

OpenAI's [Whisper](https://huggingface.co/collections/openai/whisper-release-6501bba2cf999715fd953013) is a general purpose speech transcription model that achieves state-of-the-art results across a range of different benchmarks and audio conditions. While the transcription accuracy is exceptional, inference is slow: a 1 hour audio clip takes upwards of 6 minutes to transcribe on a 16GB T4 GPU, even after leveraging inference optimisations like [flash attention](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2), half-precision and chunking.

In this Google Colab, we demonstrate how Speculative Decoding can be employed to reduce the inference time of Whisper by a **factor of 2**, while mathematically ensuring that exactly the **same outputs** are obtained from the model. As a result, this method is a drop-in replacement for existing Whisper pipelines, since it provides a free 2x speed-up while maintaining the same accuracy.

## Speculative Decoding

Speculative Decoding was proposed in the paper [Fast Inference from Transformers via Speculative Decoding](https://arxiv.org/abs/2211.17192) by Yaniv Leviathan et al. from Google.
It works on the premise that a faster **assistant model** can be used to bootstrap the generation of a larger **main model**.

The assistant model first generates a sequence of candidate tokens $\hat{\boldsymbol{y}}_{1:N}$. While these candidate tokens are generated quickly, they may differ from those predicted by the main model. Therefore, the candidate tokens are verified by the main model in a single forward pass. By generating with the faster assistant model and only performing validation forward passes with the main model, the decoding process is sped up significantly. The $i$-th candidate token from the assistant model $\hat{y}_i$ is kept only if it and all previous candidate tokens $\hat{\boldsymbol{y}}_{<i}$ match the validation tokens predicted by the main model. At the first mismatch, the main model's own token is used and the remaining candidates are discarded. Consequently, speculative decoding ensures that the generated output exactly matches the sequence of tokens that would be generated by the main model, making it a natural replacement for existing inference pipelines that use the main model alone.
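To make the draft-and-verify loop concrete, here is a minimal sketch of greedy speculative decoding in pure Python. It is an illustration under stated assumptions, not the 🤗 Transformers implementation: `toy_main` and `toy_assistant` are hypothetical stand-ins for the two models, each mapping a token prefix to its next token, and the verification step is written as a loop where a real transformer would score all candidates in one forward pass.

```python
from typing import Callable, List

# Hypothetical type for a greedy "model": maps a token prefix to the next token.
NextTokenFn = Callable[[List[int]], int]


def greedy_decode(model: NextTokenFn, prompt: List[int], max_new_tokens: int) -> List[int]:
    """Plain autoregressive decoding: one model call per generated token."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(model(tokens))
    return tokens[len(prompt):]


def speculative_decode(
    main: NextTokenFn,
    assistant: NextTokenFn,
    prompt: List[int],
    max_new_tokens: int,
    num_candidates: int = 5,
) -> List[int]:
    """Greedy speculative decoding: draft with the assistant, verify with the main model."""
    tokens = list(prompt)
    target_len = len(prompt) + max_new_tokens
    while len(tokens) < target_len:
        # 1. The assistant drafts a block of candidate tokens autoregressively.
        draft: List[int] = []
        for _ in range(num_candidates):
            draft.append(assistant(tokens + draft))

        # 2. The main model predicts the token at every draft position
        #    (a single forward pass in a real transformer; a loop in this toy).
        validation = [main(tokens + draft[:i]) for i in range(len(draft) + 1)]

        # 3. Keep the longest prefix of candidates that matches the validation tokens.
        num_accepted = 0
        while num_accepted < len(draft) and draft[num_accepted] == validation[num_accepted]:
            num_accepted += 1
        tokens.extend(draft[:num_accepted])

        # 4. Append the main model's own token at the first mismatch
        #    (or its extra token if every candidate was accepted).
        tokens.append(validation[num_accepted])
    return tokens[len(prompt):target_len]


def toy_main(tokens: List[int]) -> int:
    """Hypothetical main model: a deterministic function of the prefix."""
    return (7 * sum(tokens) + 3 * len(tokens)) % 11


def toy_assistant(tokens: List[int]) -> int:
    """Hypothetical assistant: agrees with the main model except on every 4th position."""
    guess = toy_main(tokens)
    return guess if len(tokens) % 4 else (guess + 1) % 11


prompt = [1, 2, 3]
reference = greedy_decode(toy_main, prompt, max_new_tokens=20)
speculative = speculative_decode(toy_main, toy_assistant, prompt, max_new_tokens=20)
print(speculative == reference)
```

Because every accepted candidate equals the main model's own prediction for that position, the speculative output is identical to decoding with the main model alone, even though the toy assistant is deliberately wrong on some positions.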

In the remainder of this Colab, we'll go through a hands-on example for running speculative decoding using the Whisper model for speech transcription. For more details on how speculative decoding works, including animated illustrations of the generation process, refer to the [full blog post](https://huggingface.co/blog/whisper-spec-dec#speculative-decoding).

## Set-Up Environment

The runtime is already configured to use the free 16GB T4 GPU provided through Google Colab Free Tier, so all you need to do is hit the button `Connect T4` in the top right-hand corner of the screen.

Once we've done that, we can go ahead and install the necessary Python packages, including [🤗 Transformers](https://huggingface.co/docs/transformers/index) for loading and running the Whisper models, and [🤗 Datasets](https://huggingface.co/docs/datasets/index) for loading our benchmarking datasets:

In [1]:
!pip install --upgrade --quiet pip
!pip install --upgrade --quiet torch transformers datasets[audio] accelerate evaluate jiwer


## English Speech Transcription

### Baseline Implementation

We start by benchmarking Whisper [large-v2](https://huggingface.co/openai/whisper-large-v2) to get our baseline number for inference speed. We can load the main model and its corresponding processor via the convenient [`AutoModelForSpeechSeq2Seq`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForSpeechSeq2Seq) and [`AutoProcessor`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoProcessor) classes. We'll load the model in `float16` precision and keep the loading time to a minimum by passing [`low_cpu_mem_usage=True`](https://huggingface.co/docs/transformers/main_classes/model#large-model-loading). Finally, we'll pass the argument `attn_implementation="sdpa"` to benefit from Flash Attention speed-ups through PyTorch's [SDPA attention kernel](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html):

In [2]:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v2"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

Let's load the English speech transcription dataset that we will use for benchmarking. We'll load a small dataset consisting of 73 samples from the [LibriSpeech ASR](https://huggingface.co/datasets/librispeech_asr) validation-clean dataset. This amounts to ~9MB of data, so it's very lightweight and quick to download on device:

In [3]:
from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

For the benchmark, we only want to measure the generation time, so let's write a short helper function that measures
this step. The following function will return both the decoded tokens and the time it took to run the model:

In [4]:
import time

def generate_with_time(model, inputs, **kwargs):
    start_time = time.time()
    outputs = model.generate(**inputs, **kwargs)
    generation_time = time.time() - start_time
    return outputs, generation_time

We can now iterate over the audio samples in our dataset and sum up the overall generation time:

In [5]:
from tqdm import tqdm

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch_dtype)

    output, gen_time = generate_with_time(model, inputs)
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["text"]))

print(all_time)

100%|██████████| 73/73 [01:26<00:00,  1.19s/it]

77.25942993164062





Alright! We see that transcribing the 73 samples took 77 seconds. Let's check the WER of the predictions:

In [6]:
from evaluate import load

wer = load("wer")

print(wer.compute(predictions=predictions, references=references))

0.03507271171941831


Our final baseline numbers are 77 seconds for a WER of 3.5%.
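For readers unfamiliar with the metric, the word error rate is the number of word-level substitutions, deletions and insertions needed to turn the prediction into the reference, divided by the number of reference words. The following is a minimal pure-Python sketch of that definition; `word_error_rate` is a hypothetical helper for illustration, not the code behind `evaluate`'s `wer` metric, and it assumes both texts have already been normalised.

```python
from typing import List


def word_error_rate(references: List[str], predictions: List[str]) -> float:
    """Corpus-level WER: total word edit distance divided by total reference words."""
    total_errors = 0
    total_words = 0
    for reference, prediction in zip(references, predictions):
        ref_words = reference.split()
        pred_words = prediction.split()

        # Levenshtein distance over words, keeping only the previous row in memory.
        previous_row = list(range(len(pred_words) + 1))
        for i, ref_word in enumerate(ref_words, start=1):
            current_row = [i]
            for j, pred_word in enumerate(pred_words, start=1):
                substitution_cost = 0 if ref_word == pred_word else 1
                current_row.append(
                    min(
                        previous_row[j] + 1,  # deletion (reference word missing)
                        current_row[j - 1] + 1,  # insertion (extra predicted word)
                        previous_row[j - 1] + substitution_cost,  # match or substitution
                    )
                )
            previous_row = current_row

        total_errors += previous_row[-1]
        total_words += len(ref_words)
    return total_errors / total_words


# One deleted word and one substituted word over eight reference words.
print(word_error_rate(
    references=["the cat sat on the mat", "hello world"],
    predictions=["the cat sat on mat", "hello word"],
))
```

A WER of 3.5% therefore means roughly 3 to 4 word errors per 100 reference words.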

### Speculative Decoding

Now let's load the assistant model for speculative decoding. In this example, we'll use a distilled variant of Whisper, [distil-large-v2](https://huggingface.co/distil-whisper/distil-large-v2). The distilled model copies the entire encoder from Whisper, but only 2 of the 32 decoder layers. As such, it runs 6x faster than Whisper, while performing to within 1% WER on out-of-distribution test sets. This makes it an ideal choice of assistant model, since it has both high transcription accuracy and fast generation<a name="cite_ref-1"></a>[<sup>[1]</sup>](#cite_note-1).

Since Distil-Whisper uses exactly the same encoder as the Whisper model, we can share the encoder across the main and assistant models. We then only have to load the 2-layer decoder from Distil-Whisper as a "decoder-only" model. We can do this through the convenient [`AutoModelForCausalLM`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM) auto class. In practice, this results in only an 8% increase in VRAM over using the main model alone.

In [7]:
from transformers import AutoModelForCausalLM

assistant_model_id = "distil-whisper/distil-large-v2"

assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)

assistant_model.to(device);

------------------------------------------------------------------------

<a name="cite_note-1"></a>1. We intend to release an improved variant of Distil-Whisper with a stronger alignment in the token distribution that will improve speculative decoding performance further. Follow the [Distil-Whisper repository](https://github.com/huggingface/distil-whisper) for updates.

We can define a modified function for our speculative decoding benchmark. The only difference from the previous function is that we pass the assistant model to our call to `.generate`:

In [8]:
def assisted_generate_with_time(model, inputs, **kwargs):
    start_time = time.time()
    outputs = model.generate(**inputs, assistant_model=assistant_model, **kwargs)
    generation_time = time.time() - start_time
    return outputs, generation_time

Let's run the benchmark with speculative decoding, using Distil-Whisper as the assistant to Whisper:

In [9]:
all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch_dtype)

    output, gen_time = assisted_generate_with_time(model, inputs)
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["text"]))

print(all_time)

100%|██████████| 73/73 [00:53<00:00,  1.36it/s]

44.95286297798157





With speculative decoding, the inference time was just 45 seconds, 1.7x faster than before! Let's verify we have the same WER:

In [10]:
print(wer.compute(predictions=predictions, references=references))

0.03507271171941831


Perfect! 3.5% WER again. This confirms we have identical outputs to using the main model standalone.
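As a quick sanity check on the speed-up, the arithmetic can be reproduced from the two generation times printed by the benchmark cells above. The variable names below are chosen for this illustration only:

```python
# Total generation times (in seconds) printed by the two benchmark cells above.
baseline_seconds = 77.25942993164062
assisted_seconds = 44.95286297798157

# Speed-up factor and fraction of generation time saved.
speedup = baseline_seconds / assisted_seconds
time_saved = 1 - assisted_seconds / baseline_seconds

print(f"{speedup:.2f}x faster, {time_saved:.0%} less generation time")
```

These figures cover only the time spent inside `.generate`. The wall-clock times shown in the progress bars also include feature extraction and decoding, which speculative decoding does not change.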

Speculative decoding can also be incorporated into the 🤗 Transformers [pipeline](https://huggingface.co/docs/transformers/pipeline_tutorial) class, which provides a simple API for inference. Below, we instantiate the pipeline using the model and processor, and then use it to transcribe the first sample from the toy dataset. This can be extended to transcribe audio samples of arbitrary length, including with the use of batching:

In [11]:
from transformers import pipeline

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=4,
    generate_kwargs={"assistant_model": assistant_model},
    torch_dtype=torch_dtype,
    device=device,
)

sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])

 Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.


An end-to-end code snippet for running speculative decoding with Whisper and Distil-Whisper can be found on the [Distil-Whisper model card](https://huggingface.co/distil-whisper/distil-large-v2#speculative-decoding). It combines the stages of inference covered in this notebook into a single code example.

## Multilingual Speech Transcription

Distil-Whisper is an ideal assistant model for English speech transcription, since it performs to within 1% WER of the original Whisper model, while being 6x faster over short and long-form audio samples. However, the official Distil-Whisper checkpoints are English only, meaning they cannot be used for multilingual speech transcription. To use speculative decoding for multilingual speech transcription, one could either use one of the [official multilingual Whisper checkpoints](https://huggingface.co/openai/whisper-large-v2#model-details), or a fine-tuned variant of Whisper. As of the time of writing, there are over 5,000 [fine-tuned Whisper checkpoints](https://huggingface.co/models?other=whisper) on the Hugging Face Hub in over 100 languages. These provide an excellent starting point for selecting assistant Whisper checkpoints that perform very well on a single language. In this example, we'll use the smallest official multilingual checkpoint, Whisper [tiny](https://huggingface.co/openai/whisper-tiny).

Let's load the weights for our new assistant model, Whisper tiny. Since the encoder in Whisper tiny differs from that in large-v2, we'll load both the encoder and decoder using the `AutoModelForSpeechSeq2Seq` class:

In [12]:
assistant_model_id = "openai/whisper-tiny"

assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    assistant_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)

assistant_model.to(device);

For our benchmarking dataset, we'll load 73 samples from the Dutch (nl) split of the [VoxPopuli](https://huggingface.co/datasets/facebook/voxpopuli) dataset:

In [13]:
dataset = load_dataset("sanchit-gandhi/voxpopuli_dummy", "nl", split="validation")

Great! We can now re-run our benchmark for our baseline Whisper large-v2 model as before. The only change we make is that we pass the language and task arguments to our generate function, in order to ensure we perform speech transcription (not speech translation). Note that speculative decoding is fully compatible with the speech translation task. Simply set the task argument as required below:

In [None]:
all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch_dtype)

    output, gen_time = generate_with_time(model, inputs, language="nl", task="transcribe")
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))

wer_result = wer.compute(predictions=predictions, references=references)

print("Time:", all_time)
print("WER:", wer_result)


Right! We have our baseline time of 47 seconds and a WER of 12.8%. Let's re-run the generation process using speculative decoding:

In [None]:
all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch_dtype)

    output, gen_time = assisted_generate_with_time(model, inputs, language="nl", task="transcribe")
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))

wer_result = wer.compute(predictions=predictions, references=references)

print("Time:", all_time)
print("WER:", wer_result)

Again, we achieve 12.8% WER, but this time in just 31 seconds of inference time, representing a speed-up of 1.5x.

## Strategies for Efficient Speculative Decoding

#### Assistant Model

Our objective is to select an assistant model that is both fast **and** maintains the same token distribution as the main model. If you have a particular language in which you want to transcribe, an effective strategy is to train two Whisper models of different sizes, and use one as the assistant to the other:

* Fine-tune Whisper [large-v3](https://huggingface.co/openai/whisper-large-v3) to act as your main model
* Distil Whisper [large-v3](https://huggingface.co/openai/whisper-large-v3) on the same dataset to act as your assistant model

Additional training improves the WER performance of both the main and assistant model on your chosen language, while maximising the alignment in the token distributions. A complete guide to Whisper fine-tuning can be found [here](https://huggingface.co/blog/fine-tune-whisper), and distillation [here](https://github.com/huggingface/distil-whisper/tree/main/training).
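The trade-off between assistant speed and token alignment can be estimated with the walltime analysis from the speculative decoding paper linked earlier. The sketch below rests on that paper's simplifying assumption that each candidate token is accepted independently with the same probability. The symbols are not used elsewhere in this notebook: `alpha` is the acceptance rate, `gamma` the number of candidate tokens drafted per round, and `cost_ratio` the time of one assistant forward pass relative to one main model forward pass.

```python
def expected_tokens_per_round(alpha: float, gamma: int) -> float:
    """Expected tokens produced per verification pass of the main model."""
    if alpha == 1.0:
        # Every candidate is accepted, plus the main model's extra token.
        return float(gamma + 1)
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def expected_speedup(alpha: float, gamma: int, cost_ratio: float) -> float:
    """Expected walltime improvement over decoding with the main model alone."""
    # Each round costs gamma assistant passes plus one main model pass.
    round_cost = gamma * cost_ratio + 1
    return expected_tokens_per_round(alpha, gamma) / round_cost


# Illustrative values only, not measurements from this notebook.
for alpha in (0.0, 0.5, 0.8, 0.95):
    print(alpha, round(expected_speedup(alpha, gamma=5, cost_ratio=0.1), 2))
```

Under these assumptions, a fast assistant that rarely agrees with the main model (low `alpha`) makes inference slower than the main model alone, which is why both speed and alignment of the token distributions matter.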

#### Batch Size

It is worth noting that the largest speed gains with speculative decoding come with a batch size of 1. For batched speculative decoding, all candidate tokens **across the batch** must match the validation tokens in order for the tokens to be accepted. If a token in the batch at a given position does not agree, all candidate tokens from that position onwards are discarded. Consequently, speculative decoding favours lower batch sizes. In practice, we find that speculative decoding provides a speed-up up to a batch size of 4. Above batch size 4, speculative decoding is slower than using the main model alone. For full results, refer to Section D.3 of the [Distil-Whisper paper](https://arxiv.org/pdf/2311.00430.pdf).
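The batched acceptance rule described above can be sketched in a few lines of Python. This is an illustration of the rule as stated in the text, not the 🤗 Transformers implementation, and the helper names are hypothetical:

```python
from typing import List


def num_matching_tokens(candidates: List[int], validations: List[int]) -> int:
    """Length of the longest candidate prefix that matches the validation tokens."""
    count = 0
    for candidate, validation in zip(candidates, validations):
        if candidate != validation:
            break
        count += 1
    return count


def num_accepted_in_batch(
    batch_candidates: List[List[int]], batch_validations: List[List[int]]
) -> int:
    """Candidates are accepted only up to the earliest mismatch anywhere in the batch."""
    return min(
        num_matching_tokens(candidates, validations)
        for candidates, validations in zip(batch_candidates, batch_validations)
    )


batch_candidates = [[5, 6, 7, 8], [1, 2, 3, 4]]
batch_validations = [[5, 6, 7, 8], [1, 2, 9, 4]]

# The first sequence matches on all 4 candidates, the second only on its first 2,
# so the whole batch advances by just 2 candidate tokens this round.
print(num_accepted_in_batch(batch_candidates, batch_validations))
```

Since the batch advances at the pace of its worst-matching sequence, the expected number of accepted tokens per round shrinks as the batch grows.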