<a href="https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/whisper_compile.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Whisper: Torch Compile for 4x Faster Inference

The Whisper model in Transformers is *overhead bound*: the bottleneck in inference speed is that the CPU cannot issue instructions to the GPU fast enough to keep the GPU fully utilized. To address this, we need to provide the GPU with more operations at once. One way of achieving this is using [**torch compile**](https://pytorch.org/docs/stable/generated/torch.compile.html), a native PyTorch function for accelerating PyTorch models.

Torch compile takes a large region of a PyTorch graph and captures it into a single compiled region. The CPU then only needs to launch this single compiled region rather than many individual operations, which reduces the CPU overhead. Furthermore, torch compile generates faster kernels for the operations, speeding up computations so that they are limited by memory bandwidth (*memory bound*) rather than by CPU overhead.

In this Colab, we'll see how torch compile can be enabled for the Whisper model with just two lines of code. We'll run a benchmark that highlights the inference speed-up of close to 4x that torch compile can provide to Whisper models on the Hugging Face Hub.

## Set-Up

The runtime is configured to use an A100 GPU provided through the Google Colab Pro tier. You can either connect to an A100 by clicking the button `Connect A100` in the top right-hand corner of the screen, or select a different GPU by clicking `Runtime` -> `Change runtime type`.

Once we've done that, we can go ahead and install the necessary Python packages. We'll install [🤗 Transformers](https://huggingface.co/docs/transformers/index) for loading and running the Whisper models, [🤗 Datasets](https://huggingface.co/docs/datasets/index) for loading our benchmarking dataset, and [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) for fast model loading:

In [1]:
!pip install --upgrade --quiet pip
!pip install --upgrade --quiet transformers datasets[audio] accelerate

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.
gcsfs 2024.6.1 requires fsspec==2024.6.1, but you have fsspec 2024.5.0 which is incompatible.
google-colab 1.0.0 requires requests==2.31.0, but you have requests 2.32.3 which is incompatible.
ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 17.0.0 which is incompatible.

## Benchmarking

First, we'll load the Whisper model and its accompanying processor using the familiar 🤗 Transformers API.
In this example, we'll load the pre-trained Whisper [medium.en](https://huggingface.co/openai/whisper-medium.en) model, but you're free
to swap this for any one of the [10k Whisper checkpoints](https://huggingface.co/models?library=transformers&other=whisper&sort=trending)
on the Hugging Face Hub. To reduce the loading time, we'll pass the [low_cpu_mem_usage](https://huggingface.co/docs/transformers/v4.43.4/en/big_models#accelerates-big-model-inference) flag to `.from_pretrained`:

In [2]:
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium.en", torch_dtype=torch_dtype, low_cpu_mem_usage=True)
model.to(device)

processor = WhisperProcessor.from_pretrained("openai/whisper-medium.en")

For benchmarking, we'll load a small dataset consisting of 73 samples from the [LibriSpeech ASR](https://huggingface.co/datasets/librispeech_asr) validation-clean dataset. This amounts to ~9MB of data, so it's very lightweight and quick to download on-device:

In [3]:
from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset

Dataset({
    features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],
    num_rows: 73
})

To ensure the sampling rate of the audios matches the sampling rate of our model, we'll re-sample the audios to the sampling rate expected by Whisper (16kHz).
Note that the re-sampling is applied on-the-fly when the audios are loaded, with a no-op if the sampling rate already matches:

In [4]:
from datasets import Audio

sampling_rate = processor.feature_extractor.sampling_rate
dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))

We're ready to start benchmarking 📏 The following cell iterates over samples in our dataset one-by-one (i.e. with a batch size of one). For each sample, we perform three stages of inference:
1. Pre-processing the raw audio inputs to log-mel spectrograms
2. Auto-regressively generating the text tokens, conditional on the spectrogram inputs
3. Post-processing the generated tokens to text strings

For the purposes of benchmarking, we'll time the generation step, which is the portion performed by the Whisper model itself:

In [5]:
import time
from tqdm import tqdm
from torch.nn.attention import sdpa_kernel, SDPBackend

inference_time = 0.0
model.generation_config.max_new_tokens = 128

for sample in tqdm(dataset):
    # 1. Pre-process the audio inputs
    input_features = processor(sample["audio"]["array"], sampling_rate=16000, return_tensors="pt").input_features
    input_features = input_features.to(device, dtype=torch_dtype)

    # 2. Auto-regressively generate text tokens
    start = time.time()
    pred_ids = model.generate(input_features)
    inference_time += time.time() - start

    # 3. Post-process tokens to text
    pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)

print(inference_time)

  0%|          | 0/73 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|██████████| 73/73 [00:54<00:00,  1.34it/s]

36.376708984375

We have a baseline of **36.4 seconds** for our un-compiled model. Let's now apply torch compile and re-measure performance.
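Before moving on, a note on the timing pattern used in the loop above. CUDA kernels are launched asynchronously, so a wall-clock timer is most robust when the GPU is synchronized before the clock is started and again before it is stopped. The following is a minimal sketch of that pattern, not part of the original benchmark: the helper name `timed` and its `sync` parameter are hypothetical, and the demonstration uses a CPU-only stand-in so that it runs anywhere.

```python
import time
from typing import Any, Callable, Optional, Tuple


def timed(fn: Callable[[], Any], sync: Optional[Callable[[], None]] = None) -> Tuple[Any, float]:
    """Run `fn` once and return (result, elapsed seconds).

    `sync` is an optional callable that blocks until pending device work
    has finished. It is called before starting and before stopping the clock.
    """
    if sync is not None:
        sync()  # make sure earlier work is not attributed to this call
    start = time.perf_counter()
    result = fn()
    if sync is not None:
        sync()  # wait for the work launched by `fn` to complete
    return result, time.perf_counter() - start


# CPU-only stand-in so the sketch is runnable without a GPU
result, elapsed = timed(lambda: sum(range(1000)))
print(result)  # 499500

# Assumed usage inside the benchmark loop (requires torch and a CUDA device):
# pred_ids, dt = timed(lambda: model.generate(input_features), sync=torch.cuda.synchronize)
# inference_time += dt
```

Using `time.perf_counter` instead of `time.time` is a design choice for this sketch: it is a monotonic, high-resolution clock intended for measuring short durations.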

## Enable torch compile

The first step in enabling compile is self-explanatory: we need to apply the `torch.compile` transformation to the model forward pass. We'll set the compile mode to `reduce-overhead`, which uses [CUDA Graphs](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) to further reduce CPU overhead. We'll also set `fullgraph=True` to compile the entire model in one graph (i.e. with no graph breaks):

In [6]:
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

The second step involves setting the key-value (kv) cache. During decoding, the Whisper decoder computes the kv states for each new input token and saves them to be re-used at the next decoding step, forming a **kv-cache**. The default kv-cache implementation grows in length with each generated token, since we save a new set of kv states for each decoding step.

While dynamic shapes are compatible with a subset of `torch.compile` optimizations, they limit the extent to which the CPU overhead can be reduced. Thus, we'll switch the kv cache to a **static** implementation, which pre-allocates the entire kv-cache size to the maximum value and masks out the un-used parts from the attention computation. In doing so, this kv-cache implementation is compatible with the `reduce-overhead` setting from the previous step.

In [7]:
model.generation_config.cache_implementation = "static"
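
To make the idea of a static cache concrete, here is a toy, pure-Python sketch of the mechanism described above: a buffer pre-allocated to the maximum length, a pointer to the next free slot, and a mask that hides the unused positions. This is an illustration under simplifying assumptions (a single head, scalar key and value entries, a hypothetical class name `StaticCacheSketch`) and is not the implementation used in Transformers.

```python
class StaticCacheSketch:
    """Toy fixed-size kv-cache (illustrative only, not the Transformers class)."""

    def __init__(self, max_len: int):
        self.max_len = max_len
        # Pre-allocate the full buffers once: their size never changes afterwards
        self.keys = [0.0] * max_len
        self.values = [0.0] * max_len
        self.length = 0  # number of positions filled so far

    def update(self, key: float, value: float) -> None:
        """Write the kv state of the newest token into the next free slot."""
        if self.length >= self.max_len:
            raise ValueError("cache is full")
        self.keys[self.length] = key
        self.values[self.length] = value
        self.length += 1

    def attention_mask(self) -> list:
        """True for filled positions, False for the unused tail."""
        return [i < self.length for i in range(self.max_len)]


cache = StaticCacheSketch(max_len=4)
cache.update(0.1, 1.0)
cache.update(0.2, 2.0)

print(len(cache.keys))         # 4
print(cache.attention_mask())  # [True, True, False, False]
```

The buffer length stays at `max_len` no matter how many tokens have been generated, so every decoding step sees tensors of the same shape. That fixed shape is what makes the cache compatible with the `reduce-overhead` mode.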

Since torch compile performs "just in time" (JIT) compilation, the model is only compiled when it is first called, so we need to run a few warm-up steps before benchmarking. Here, we'll perform three warm-up steps, generating to our maximum number of permitted tokens each time:

In [8]:
max_new_tokens = model.generation_config.max_new_tokens

for _ in tqdm(range(3)):
    with sdpa_kernel(SDPBackend.MATH):
        model.generate(input_features, min_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens)

100%|██████████| 3/3 [03:11<00:00, 63.68s/it]


**Note:** this code-cell may take several minutes to run, particularly the first time it is called. To reduce the compilation time of subsequent runs, upgrade to `torch>2.4` and enable the [FX graph cache](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) with the flag `TORCHINDUCTOR_FX_GRAPH_CACHE=1`.
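
As a sketch of how the flag mentioned in the note could be set from within a notebook, the environment variable can be assigned in Python. The assumption here is that it is set before torch compiles anything; placing it before `import torch` in a fresh runtime is the most conservative ordering.

```python
import os

# Enable the FX graph cache so that later compilations can reuse cached graphs.
# Assumption: this runs at the very top of the notebook, before torch is imported.
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"

# import torch  # import and compile only after the variable has been set
```

The cache is stored on local disk, so on Colab it only helps for as long as the same runtime (and its disk) is kept alive.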

## Benchmarking with Compile

We're now ready to re-run our benchmark using the compiled implementation. The only change we'll make is using the [scaled dot product attention (SDPA) context manager](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) to switch the attention implementation from flash attention to the native PyTorch C++ implementation, which typically leads to better performance under compile:

In [9]:
inference_time = 0.0

for sample in tqdm(dataset):
    # 1. Pre-process the audio inputs
    input_features = processor(sample["audio"]["array"], sampling_rate=16000, return_tensors="pt").input_features
    input_features = input_features.to(device, dtype=torch_dtype)

    # 2. Auto-regressively generate text tokens
    start = time.time()
    with sdpa_kernel(SDPBackend.MATH):
        pred_ids = model.generate(input_features)
    inference_time += time.time() - start

    # 3. Post-process tokens to text
    pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True)

print(inference_time)

100%|██████████| 73/73 [00:12<00:00,  5.99it/s]

10.20078706741333

The inference time is reduced to just **10.2 seconds**, representing a speed-up of 3.6x with no degradation to model accuracy ⚡️
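
The speed-up figure follows directly from the two timings printed above. A short worked calculation, using the numbers from this particular run (they will differ on other hardware or checkpoints):

```python
baseline_s = 36.376708984375    # un-compiled run, printed by the first benchmark
compiled_s = 10.20078706741333  # compiled run, printed by the second benchmark
num_samples = 73                # size of the benchmark dataset

speedup = baseline_s / compiled_s
time_saved = 1 - compiled_s / baseline_s

print(f"speed-up: {speedup:.1f}x")        # speed-up: 3.6x
print(f"time saved: {time_saved:.0%}")    # time saved: 72%
print(f"mean latency: {baseline_s / num_samples:.2f}s -> {compiled_s / num_samples:.2f}s per sample")
```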

Recall that this optimization technique is model-agnostic: it can be applied to any Whisper model in the Transformers library. The speed-ups depend on the hardware and the model size, with the largest speed-ups typically seen for smaller models. However, even the largest Whisper model ([large-v3](https://huggingface.co/openai/whisper-large-v3)) obtains a >3x speed-up.

## Conclusion

In this Colab, we've broken down the steps for Whisper inference using torch compile, demonstrating a speed-up of close to 4x (3.6x in this benchmark) with two lines of additional code. For an end-to-end code example, refer to the [Whisper model card](https://huggingface.co/openai/whisper-large-v3#torch-compile).